
Gradients & backward pass on the graph

Definition

Let $f : \mathcal{T} \to \mathbb{R}$ be a differentiable function, where $\mathcal{T}$ is the space of tensors. The gradient of $f$ at $A \in \mathcal{T}$ is the tensor $\nabla f(A) \in \mathcal{T}$ defined by:

$$\forall\, dA \in \mathcal{T} \text{ such that } \mathcal{d}_{dA}(-1) = \mathcal{d}_A(-1), \qquad \mathrm{D}f(A)[dA] = \langle \nabla f(A),\, dA \rangle$$

This reads as the variation of $f$ evaluated at $A$ for any small variation $dA$.

For scalars, this is the expected definition:

$dA = h \in \mathbb{R}$ and $\mathrm{D}f(A)[h] = \langle f'(A), h \rangle = h\,f'(A) \approx f(A+h) - f(A)$ for $h$ small enough, so that $\nabla f(A) = f'(A)$.
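The definition can be checked numerically. The sketch below is only an illustration: the choice of `tanh` for the scalar case, the function $f(A) = \sum_{ij} A_{ij}^2$ for the matrix case, and the use of nested Python lists as stand-ins for tensors are all assumptions made here, not part of the original text. It compares the true variation $f(A + dA) - f(A)$ with the linear term $\langle \nabla f(A), dA \rangle$.

```python
import math

def inner(X, Y):
    """Inner product <X, Y>: sum of elementwise products of two matrices."""
    return sum(x * y for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

# Scalar case: f = tanh, so grad f(a) = f'(a) = 1 - tanh(a)^2.
def scalar_gap(a, h):
    """|(f(a + h) - f(a)) - h * f'(a)| for f = tanh."""
    linear_term = h * (1.0 - math.tanh(a) ** 2)
    return abs((math.tanh(a + h) - math.tanh(a)) - linear_term)

# The gap shrinks like h^2, i.e. faster than h itself.
scalar_gaps = [scalar_gap(0.3, h) for h in (1e-1, 1e-2, 1e-3)]

# Matrix case: f(A) = sum_ij A_ij^2, so grad f(A) = 2A.
def f(A):
    return inner(A, A)

A = [[1.0, -2.0], [0.5, 3.0]]
dA = [[1e-4, -2e-4], [3e-4, 1e-4]]  # small variation, same shape as A
A_plus = [[a + d for a, d in zip(ra, rd)] for ra, rd in zip(A, dA)]

actual = f(A_plus) - f(A)                                   # true variation
linear = inner([[2.0 * a for a in row] for row in A], dA)   # <grad f(A), dA>
matrix_gap = abs(actual - linear)                           # second-order remainder <dA, dA>
```

In both cases the remainder is of second order in the variation, which is what "for $h$ small enough" means in practice.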

When we introduced the graph, we saw that each node represents an operation and points to other operations (or to a scalar at the end of the graph). This scalar, often denoted $L$, is the value of a loss function, which we will come back to later. What matters here is that $L$ measures the performance of the model: the lower $L$ is, the better the model performs. We therefore seek to optimize the parameters with respect to $L$.

Standard operations - gradients

To do this, we compute the gradient of $L$ with respect to each tensor, following the topological order given above. We start with $L$ itself, since its gradient with respect to itself is trivially equal to 1.

Next, we compute the gradient of each node in reverse topological order, reusing the gradients already computed.

For each operation, the goal is to compute the gradient of $L$ with respect to $A$ and $B$ (or $A$ alone) given the gradient of $L$ with respect to $f(A, B)$ (or $f(A)$).
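The traversal described above can be sketched on scalar nodes. Everything in this block is hypothetical and chosen for illustration: the `Node` class, the `add` and `mul` helpers, and the example expression $L = xy + x$ do not come from the original text. The point is only to show the seed gradient of 1 on $L$ and the reverse topological sweep.

```python
class Node:
    """A scalar node of the computation graph (hypothetical minimal class)."""

    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents          # nodes this node was computed from
        self.local_grads = local_grads  # d(self)/d(parent), one per parent
        self.grad = 0.0                 # will hold dL/d(self)

def add(a, b):
    return Node(a.value + b.value, (a, b), (1.0, 1.0))

def mul(a, b):
    return Node(a.value * b.value, (a, b), (b.value, a.value))

def topological_order(root):
    """Nodes reachable from root, each listed after all of its parents."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:
            visit(parent)
        order.append(node)

    visit(root)
    return order

def backward(loss):
    """Accumulate dL/d(node) into every node, in reverse topological order."""
    loss.grad = 1.0  # gradient of L with respect to itself
    for node in reversed(topological_order(loss)):
        for parent, local in zip(node.parents, node.local_grads):
            parent.grad += node.grad * local

x, y = Node(2.0), Node(3.0)
L = add(mul(x, y), x)  # L = x * y + x
backward(L)
# x.grad is y + 1 and y.grad is x, as expected from differentiating by hand
```

Because `x` feeds two operations, its gradient is the sum of the contributions coming back from each of them, which is why the sketch accumulates with `+=`.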

Addition

Let $L:\mathcal T\to\mathbb R$ be differentiable and $f:\mathcal T^2\to\mathcal T$, $f(A,B)=A+B$. Let $Y=f(A,B)$ and $G=\nabla_Y L\big|_{Y=f(A,B)} \in \mathcal T$.

A part of the graph (the application of $f$) therefore looks like this:

First, let us note that

$$dY = \mathrm Df(A, B)[dA, dB] = dA+dB \quad\text{since } f(A+dA, B+dB)-f(A, B) = dA+dB$$

Expanding $L\circ f$ with the chain rule and the definition of the differential, we have:

$$
\begin{aligned}
\mathrm D(L \circ f)(A,B)[\mathrm dA,\mathrm dB] &= \mathrm DL\big(f(A,B)\big)\big[\,\mathrm Df(A,B)[\mathrm dA,\mathrm dB]\,\big] \\
&= \mathrm DL(Y)[\mathrm dY] = \langle G, \mathrm dY \rangle \\
&= \langle G,\ \mathrm dA+\mathrm dB\rangle = \langle G, \mathrm dA \rangle + \langle G, \mathrm dB \rangle
\end{aligned}
$$

Furthermore, by definition of the gradient,

$$\mathrm D(L \circ f)(A,B)[\mathrm dA,\mathrm dB] = \langle \nabla_A L,\,\mathrm dA\rangle + \langle \nabla_B L,\,\mathrm dB\rangle.$$

By identification, valid for any $(\mathrm dA,\mathrm dB)$ (Riesz):

$$\boxed{\ \nabla_A L = \nabla_B L = G\ } \qquad\text{where } G=\nabla_Y L\big|_{Y=A+B}$$
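The boxed result can be compared against finite differences. This is a sketch under assumptions: the loss $L(Y) = \tfrac12 \sum_{ij} Y_{ij}^2$ (for which $G = Y$), the helper `numerical_grad`, and the sample matrices are all illustrative choices, and nested lists stand in for tensors.

```python
def loss(Y):
    """Example loss L(Y) = 1/2 * sum_ij Y_ij^2; its gradient is G = Y."""
    return 0.5 * sum(y * y for row in Y for y in row)

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def numerical_grad(fn, M, eps=1e-5):
    """Central finite differences of the scalar fn with respect to each entry of M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            up = fn(M)
            M[i][j] = old - eps
            down = fn(M)
            M[i][j] = old  # restore the entry before moving on
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def max_abs_diff(X, Y):
    return max(abs(x - y) for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

A = [[1.0, -2.0], [0.5, 3.0]]
B = [[0.0, 4.0], [-1.0, 2.0]]

G = add(A, B)  # grad_Y L at Y = A + B, since grad_Y L = Y for this loss
grad_A = numerical_grad(lambda M: loss(add(M, B)), A)
grad_B = numerical_grad(lambda M: loss(add(A, M)), B)

err_A = max_abs_diff(grad_A, G)  # both should be close to zero
err_B = max_abs_diff(grad_B, G)
```

Both inputs receive the same upstream gradient $G$ unchanged, which is the sense in which addition simply "copies" the gradient to its operands.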

Multiplication

Let $L:\mathcal T\to\mathbb R$ be differentiable and $f:\mathcal T^2\to\mathcal T$, $f(A,B)=A@B$ (tensor multiplication). Let $Y=f(A,B)$ and $G=\nabla_Y L\big|_{Y=f(A,B)} \in \mathcal T$.

A part of the graph (the application of $f$) therefore looks like this:

First, let us note that

$$dY = \mathrm Df(A, B)[dA, dB] = A@dB + dA@B$$

since $f(A+dA, B+dB)-f(A, B) = A@dB + dA@B + dA@dB$, and the second-order term $dA@dB$ is negligible for small variations.

Expanding $L\circ f$ with the chain rule and the definition of the differential, we have:

$$
\begin{aligned}
\mathrm D(L \circ f)(A,B)[\mathrm dA,\mathrm dB] &= \mathrm DL\big(f(A,B)\big)\big[\,\mathrm Df(A,B)[\mathrm dA,\mathrm dB]\,\big] \\
&= \mathrm DL(Y)[\mathrm dY] = \langle G, \mathrm dY \rangle \\
&= \langle G,\ A@\mathrm dB + \mathrm dA@B\rangle = \langle G, A@\mathrm dB \rangle + \langle G, \mathrm dA@B \rangle \\
&= \langle A^\top @ G, \mathrm dB \rangle + \langle G @ B^\top, \mathrm dA \rangle \quad\text{(definition of the scalar product)}
\end{aligned}
$$

The last step moves $A$ and $B$ to the other side of the scalar product: for matrices, $\langle G, A@\mathrm dB\rangle = \sum_{i,j,k} G_{ij}A_{ik}\,\mathrm dB_{kj} = \langle A^\top @ G, \mathrm dB\rangle$, and the $\mathrm dA$ term is handled the same way.

Furthermore, by definition of the gradient,

$$\mathrm D(L \circ f)(A,B)[\mathrm dA,\mathrm dB] = \langle \nabla_A L,\,\mathrm dA\rangle + \langle \nabla_B L,\,\mathrm dB\rangle.$$

By identification, valid for any $(\mathrm dA,\mathrm dB)$ (Riesz):

$$\boxed{\ \nabla_A L = G @ B^\top, \quad \nabla_B L = A^\top @ G\ } \qquad\text{where } G=\nabla_Y L\big|_{Y=A@B}$$
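As for addition, the two formulas can be checked against finite differences. The sketch assumes plain 2-D matrices stored as nested lists, the illustrative loss $L(Y) = \tfrac12 \sum_{ij} Y_{ij}^2$ (so that $G = Y$), and hypothetical helpers `matmul`, `transpose` and `numerical_grad`; none of these come from the original text.

```python
def matmul(A, B):
    """Matrix product A @ B for nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(M):
    return [list(col) for col in zip(*M)]

def loss(Y):
    """Example loss L(Y) = 1/2 * sum_ij Y_ij^2; its gradient is G = Y."""
    return 0.5 * sum(y * y for row in Y for y in row)

def numerical_grad(fn, M, eps=1e-5):
    """Central finite differences of the scalar fn with respect to each entry of M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            up = fn(M)
            M[i][j] = old - eps
            down = fn(M)
            M[i][j] = old  # restore the entry before moving on
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def max_abs_diff(X, Y):
    return max(abs(x - y) for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

A = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]    # shape (2, 3)
B = [[2.0, -1.0], [0.0, 1.0], [1.0, 4.0]]  # shape (3, 2)

G = matmul(A, B)                       # grad_Y L at Y = A @ B for this loss
grad_A = matmul(G, transpose(B))       # formula: G @ B^T, shape (2, 3)
grad_B = matmul(transpose(A), G)       # formula: A^T @ G, shape (3, 2)

err_A = max_abs_diff(grad_A, numerical_grad(lambda M: loss(matmul(M, B)), A))
err_B = max_abs_diff(grad_B, numerical_grad(lambda M: loss(matmul(A, M)), B))
```

A useful sanity check when implementing these formulas is the shape: each gradient must have exactly the shape of the tensor it belongs to, which fixes where the transpose goes.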

Function application

Let $L:\mathcal T\to\mathbb R$ be differentiable and $f:\mathcal T\to\mathcal T$ be applied element by element (e.g., tanh, ReLU).

We can therefore define $f'$ as the derivative of $f$ on $\mathbb{R}$ (itself applicable to a tensor element by element): in reality $f: \mathbb{R} \to \mathbb{R}$, and we have simply extended it to act on each element of the tensor.

Let $Y=f(A)$ and $G=\nabla_Y L\big|_{Y=f(A)} \in \mathcal T$.

A part of the graph (the application of $f$) therefore looks like this:

First, note that since $f$ acts point by point, with $i = (i_1, \dots, i_N)$ a multi-index (so that $A_i$ is a scalar):

$$\big[\mathrm Df(A)[\mathrm dA]\big]_{i}=f'(A_i)\,\mathrm dA_i \quad\Longrightarrow\quad \mathrm dY = \mathrm Df(A)[\mathrm dA]=f'(A)\odot \mathrm dA,$$

where $\odot$ denotes element-by-element (Hadamard) multiplication.

Expanding $L\circ f$ with the chain rule and the definition of the differential, we have:

$$
\begin{aligned}
\mathrm D(L \circ f)(A)[\mathrm dA] &= \mathrm DL\big(f(A)\big)\big[\,\mathrm Df(A)[\mathrm dA]\,\big] \\
&= \mathrm DL(Y)[\mathrm dY] = \langle G, \mathrm dY \rangle \\
&= \langle G,\ f'(A) \odot \mathrm dA\rangle = \langle f'(A)\odot G, \mathrm dA \rangle
\end{aligned}
\tag{*}
$$

Furthermore, by definition of the gradient,

$$\mathrm D(L \circ f)(A)[\mathrm dA] = \langle \nabla_A L,\,\mathrm dA\rangle .$$

By identification, valid for any $\mathrm dA$:

$$\boxed{\ \nabla_A L = f'(A) \odot G\ } \qquad\text{where } G=\nabla_Y L\big|_{Y=f(A)}$$

(*) is simple to derive: $\langle X, Y \odot Z \rangle = \sum_{b,i,j} x_{bij}\,y_{bij}\,z_{bij}$, with $b$ the batch index. Each term is a product of scalars, so the factors commute, and therefore $\langle X, Y \odot Z \rangle =\langle Y \odot X , Z \rangle$.
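Both the identity (*) and the boxed gradient can be checked on a small example. The sketch assumes $f = \tanh$, the illustrative loss $L(Y) = \tfrac12 \sum Y^2$ (so that $G = Y$), and 2-D nested lists without a batch index; the helper names are hypothetical.

```python
import math

def inner(X, Y):
    """Inner product <X, Y>: sum of elementwise products."""
    return sum(x * y for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

def hadamard(X, Y):
    """Element-by-element product X (.) Y."""
    return [[x * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def apply(fn, M):
    """Extend a scalar function to act on each element of M."""
    return [[fn(m) for m in row] for row in M]

def tanh_prime(a):
    return 1.0 - math.tanh(a) ** 2

def loss(Y):
    """Example loss L(Y) = 1/2 * sum Y^2; its gradient is G = Y."""
    return 0.5 * inner(Y, Y)

def numerical_grad(fn, M, eps=1e-5):
    """Central finite differences of the scalar fn with respect to each entry of M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            up = fn(M)
            M[i][j] = old - eps
            down = fn(M)
            M[i][j] = old  # restore the entry before moving on
            grad[i][j] = (up - down) / (2 * eps)
    return grad

A = [[0.2, -1.0, 0.7], [1.5, 0.0, -0.3]]
G = apply(math.tanh, A)                        # grad_Y L at Y = tanh(A)
grad_formula = hadamard(apply(tanh_prime, A), G)  # f'(A) (.) G
grad_numeric = numerical_grad(lambda M: loss(apply(math.tanh, M)), A)
grad_err = max(abs(a - b) for ra, rb in zip(grad_formula, grad_numeric)
               for a, b in zip(ra, rb))

# Identity (*): <X, Y (.) Z> == <Y (.) X, Z>
X = [[1.0, 2.0], [3.0, -1.0]]
Yt = [[0.5, -2.0], [4.0, 1.0]]
Z = [[2.0, 0.0], [-1.0, 3.0]]
lhs = inner(X, hadamard(Yt, Z))
rhs = inner(hadamard(Yt, X), Z)
```

Adding a batch index changes nothing in the argument: the sum simply runs over one more axis.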