9.2. Computing the gradients for Back Propagation Through Time

9.2.1. Loss function

We use the mean squared error (MSE) as the loss function. Because the model is many-to-one, only the output at the last time step $T$ enters the loss:

$$ L = \frac{1}{2} (y^{(T)} - Y^{(T)})^{2} \tag{9.3} $$
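As a small illustration, the loss of Eq. (9.3) and its derivative can be computed in NumPy. This sketch assumes that $y^{(T)}$ is the model output and $Y^{(T)}$ the target, both stored as NumPy arrays of the same shape (for vectors the square becomes a sum over components). The function name `mse_loss` is illustrative. The factor $\frac{1}{2}$ is what makes the derivative simply $y^{(T)} - Y^{(T)}$.

```python
import numpy as np

def mse_loss(y, Y):
    """Return L = 1/2 * sum((y - Y)**2) and its gradient dL/dy = y - Y.

    Assumes y is the prediction y^(T) and Y is the target Y^(T).
    """
    diff = y - Y
    return 0.5 * float(np.sum(diff ** 2)), diff

y = np.array([0.8, -0.2])  # hypothetical prediction y^(T)
Y = np.array([1.0, 0.0])   # hypothetical target Y^(T)
loss, dy = mse_loss(y, Y)
print(loss, dy)
```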

9.2.2. Forward Computational Graph

Fig.9-4 illustrates the forward computational graph of the GRU.

Fig.9-4: Forward Computational Graph of Many-to-One GRU

For an explanation of computational graphs, see the Appendix.
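The forward graph can be sketched as a NumPy loop that also caches the intermediate nodes the backward pass will need. This is a minimal sketch under stated assumptions: the gate equations are read off the expressions differentiated in Section 9.2.3, vectors are 1-D arrays treated as column vectors (so the reset-gated term $(r^{(t)} \odot h^{(t-1)})U$ is written `U @ (r * h)`), the initial state $h^{(-1)}$ is zero, and the parameter-dict keys are hypothetical names.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_forward(xs, p):
    """Many-to-one GRU forward pass over xs = [x^(0), ..., x^(T)].

    p is a dict with the hypothetical keys Wz, Uz, bz, Wr, Ur, br, W, U, b.
    Returns all hidden states and a per-step cache of graph nodes.
    """
    h = np.zeros_like(p["bz"])  # assumption: initial state h^(-1) = 0
    hs, cache = [], []
    for x in xs:
        z = sigmoid(p["Wz"] @ x + p["Uz"] @ h + p["bz"])         # update gate
        r = sigmoid(p["Wr"] @ x + p["Ur"] @ h + p["br"])         # reset gate
        h_hat = np.tanh(p["W"] @ x + p["U"] @ (r * h) + p["b"])  # candidate
        cache.append((x, h, z, r, h_hat))                        # h is h^(t-1)
        h = (1.0 - z) * h + z * h_hat                            # new h^(t)
        hs.append(h)
    return hs, cache

# Usage with random (hypothetical) sizes and parameters
rng = np.random.default_rng(0)
n_in, n_h, T = 3, 4, 5
p = {k: rng.normal(scale=0.5, size=(n_h, n_in)) for k in ("Wz", "Wr", "W")}
p.update({k: rng.normal(scale=0.5, size=(n_h, n_h)) for k in ("Uz", "Ur", "U")})
p.update({k: np.zeros(n_h) for k in ("bz", "br", "b")})
xs = [rng.normal(size=n_in) for _ in range(T + 1)]
hs, cache = gru_forward(xs, p)
print(hs[-1].shape)
```

Only `hs[-1]`, that is $h^{(T)}$, is fed to the dense layer in the many-to-one setting; the cache keeps $h^{(t-1)}$, $z^{(t)}$, $r^{(t)}$ and $\hat{h}^{(t)}$ for every step.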

9.2.3. Backward Computational Graph

To build the backward computational graph, we calculate the derivatives between adjacent nodes of the forward graph in Fig.9-4.

Since the derivatives of the dense layer have been calculated in Section 4.1, we focus on calculating the derivatives of the GRU.

$$ \begin{align}
\frac{\partial h^{(t)}}{\partial z^{(t)}} &= \frac{\partial ( (1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)} )}{\partial z^{(t)}} = - h^{(t-1)} + \hat{h}^{(t)} = \hat{h}^{(t)} - h^{(t-1)} \\
\frac{\partial z^{(t)}}{\partial W_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial W_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t x^{(t)} \\
\frac{\partial z^{(t)}}{\partial U_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial U_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t h^{(t-1)} \\
\frac{\partial z^{(t)}}{\partial b_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial b_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \\
\frac{\partial z^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial h^{(t-1)}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t U_{z} \\
\frac{\partial h^{(t)}}{\partial \hat{h}^{(t)}} &= \frac{\partial ((1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)})}{\partial \hat{h}^{(t)}} = z^{(t)} \\
\frac{\partial \hat{h}^{(t)}}{\partial W} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial W} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \ {}^t x^{(t)} \\
\frac{\partial \hat{h}^{(t)}}{\partial U} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial U} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) ( r^{(t)} \odot h^{(t-1)}) \\
\frac{\partial \hat{h}^{(t)}}{\partial b} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial b} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \\
\frac{\partial \hat{h}^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial h^{(t-1)}} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (r^{(t)} \ {}^t U) \\
\frac{\partial \hat{h}^{(t)}}{\partial r^{(t)}} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial r^{(t)}} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \\
\frac{\partial r^{(t)}}{\partial W_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial W_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t x^{(t)} \\
\frac{\partial r^{(t)}}{\partial U_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial U_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t h^{(t-1)} \\
\frac{\partial r^{(t)}}{\partial b_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial b_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \\
\frac{\partial r^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial h^{(t-1)}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t U_{r} \\
\frac{\partial h^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial ( (1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)} )}{\partial h^{(t-1)}} = (1 - z^{(t)})
\end{align} $$
Note

To avoid confusion, we express the transpose of a vector or matrix $ A $ as $ {}^tA$, instead of $A^{T}$, in this section.
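The elementwise local derivatives above, those involving only $\odot$, $\sigma$ and $\tanh$, can be checked numerically with central differences. This is a minimal NumPy sketch; the helper names `update` and `num_diff` are illustrative. Because every operation checked here acts elementwise, perturbing all entries at once recovers the diagonal of each Jacobian.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def update(z, h_prev, h_hat):
    """h^(t) = (1 - z) * h^(t-1) + z * h_hat, all elementwise."""
    return (1.0 - z) * h_prev + z * h_hat

def num_diff(f, x, eps=1e-6):
    """Central difference of an elementwise function f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

rng = np.random.default_rng(1)
z, h_prev, h_hat, a = rng.uniform(-1, 1, size=(4, 5))
z = sigmoid(z)  # keep the gate inside (0, 1)

# dh/dz = h_hat - h^(t-1)
assert np.allclose(num_diff(lambda v: update(v, h_prev, h_hat), z), h_hat - h_prev)
# dh/dh_hat = z
assert np.allclose(num_diff(lambda v: update(z, h_prev, v), h_hat), z)
# direct path dh/dh^(t-1) = 1 - z
assert np.allclose(num_diff(lambda v: update(z, v, h_hat), h_prev), 1 - z)
# sigma'(a) = sigma(a)(1 - sigma(a)) and tanh'(a) = 1 - tanh(a)^2
assert np.allclose(num_diff(sigmoid, a), sigmoid(a) * (1 - sigmoid(a)))
assert np.allclose(num_diff(np.tanh, a), 1 - np.tanh(a) ** 2)
print("local derivatives OK")
```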

Using these derivatives, we can build the backward computational graph shown in Fig.9-5.

Fig.9-5: Backward Computational Graph of Many-to-One GRU

To simplify the discussion that follows, we introduce these definitions:

$$ \begin{align} \text{grad}_{dense}^{(T)} \stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial h^{(T)}} \tag{9.4} \\ dh^{(t)} \stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial h^{(t)}} \tag{9.5} \end{align} $$
  • $\text{grad}_{dense}^{(T)}$ is the gradient propagated back from the dense layer.
  • $dh^{(t)}$ is the gradient of the loss with respect to the hidden state $h^{(t)}$; for $t < T$, it is propagated back from $h^{(t+1)}$.
Notation

In this document, the prefix “$d$” denotes the gradient of the loss $L$ with respect to a variable (e.g., $dh^{(t)} = \partial L / \partial h^{(t)}$), not the total derivative.

The following relation is satisfied by definition:

$$ dh^{(T)} = \text{grad}_{dense}^{(T)} $$
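The dense-layer gradient itself comes from Section 4.1. Purely for illustration, the sketch below assumes a dense layer without activation, $y^{(T)} = V h^{(T)} + c$ (the names `V` and `c` are hypothetical), combined with the loss of Eq. (9.3). Under that assumption $\text{grad}_{dense}^{(T)} = {}^tV (y^{(T)} - Y^{(T)})$, which a finite-difference check confirms.

```python
import numpy as np

# Assumption: dense layer y = V h + c with no activation, L = 1/2 * ||y - Y||^2.
def dense_backward(V, c, h_T, Y):
    """Return grad_dense^(T) = dL/dh^(T) = V^T (y - Y)."""
    y = V @ h_T + c
    return V.T @ (y - Y)

def loss(V, c, h_T, Y):
    y = V @ h_T + c
    return 0.5 * np.sum((y - Y) ** 2)

rng = np.random.default_rng(2)
V, c = rng.normal(size=(2, 4)), rng.normal(size=2)
h_T, Y = rng.normal(size=4), rng.normal(size=2)
dh_T = dense_backward(V, c, h_T, Y)  # dh^(T) = grad_dense^(T)

# Central-difference check of each component of dL/dh^(T)
eps = 1e-6
num = np.array([(loss(V, c, h_T + eps * e, Y) - loss(V, c, h_T - eps * e, Y)) / (2 * eps)
                for e in np.eye(4)])
print(np.allclose(dh_T, num))  # True
```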

Next, we calculate $dh^{(T-1)}$. Fig.9-6, extracted from Fig.9-5, illustrates the relationship between $h^{(T)}$ and $h^{(T-1)}$.

Fig.9-6: Relationship Between $h^{(T)}$ and $h^{(T-1)}$

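As shown in Fig.9-6, $dh^{(T-1)}$ can be calculated from $dh^{(T)}$. There are four paths from $h^{(T)}$ to $h^{(T-1)}$ (through $z^{(T)}$, through $\hat{h}^{(T)}$ directly, through $\hat{h}^{(T)}$ via $r^{(T)}$, and the direct path), so we sum their contributions: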

$$ \begin{align} dh^{(T-1)} &= \frac{\partial L}{\partial h^{(T)}} \left[ \frac{\partial h^{(T)}}{\partial z^{(T)}} \frac{\partial z^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial \hat{h}^{(T)}} \frac{\partial \hat{h}^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial \hat{h}^{(T)}} \frac{\partial \hat{h}^{(T)}}{\partial r^{(T)}} \frac{\partial r^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial h^{(T-1)}} \right] \\ &= dh^{(T)} (\hat{h}^{(T)} - h^{(T-1)}) \sigma'(W_{z} x^{(T)} + U_{z} h^{(T-1)} + b_{z}) \ {}^t U_{z} \\ & \quad + \ dh^{(T)} z^{(T)} \tanh'(W x^{(T)} + ( r^{(T)} \odot h^{(T-1)}) U + b) (r^{(T)} \ {}^t U) \\ & \quad + \ dh^{(T)} z^{(T)} \tanh'(W x^{(T)} + ( r^{(T)} \odot h^{(T-1)}) U + b) (h^{(T-1)} \ {}^t U) \sigma'(W_{r} x^{(T)} + U_{r} h^{(T-1)} + b_{r}) \ {}^t U_{r} \\ & \quad + \ dh^{(T)} (1 - z^{(T)}) \tag{9.6} \end{align} $$
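Eq. (9.6) can be translated almost term by term into NumPy. In this sketch, 1-D arrays act as column vectors, so the products with ${}^tU_{z}$, ${}^tU$ and ${}^tU_{r}$ become `Uz.T @ ...`, `U.T @ ...` and `Ur.T @ ...`, while the remaining products are elementwise. Writing the reset-gated term as `U @ (r * h)` and the parameter-dict keys are assumptions. The check uses a linear test loss $L = dh^{(T)} \cdot h^{(T)}$, whose gradient with respect to $h^{(T)}$ is exactly $dh^{(T)}$.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h_prev, p):
    """One GRU step; returns h^(t) and the gate values (z, r, h_hat)."""
    z = sigmoid(p["Wz"] @ x + p["Uz"] @ h_prev + p["bz"])
    r = sigmoid(p["Wr"] @ x + p["Ur"] @ h_prev + p["br"])
    h_hat = np.tanh(p["W"] @ x + p["U"] @ (r * h_prev) + p["b"])
    return (1 - z) * h_prev + z * h_hat, (z, r, h_hat)

def dh_prev_from(dh, x, h_prev, p):
    """dh^(T-1) as the sum of the four paths of Eq. (9.6)."""
    _, (z, r, h_hat) = gru_step(x, h_prev, p)
    a_z = dh * (h_hat - h_prev) * z * (1 - z)          # dh (h_hat - h) sigma'
    a_h = dh * z * (1 - h_hat ** 2)                    # dh z tanh'
    g_rh = p["U"].T @ a_h                              # gradient w.r.t. r * h^(T-1)
    path1 = p["Uz"].T @ a_z                            # via z^(T)
    path2 = r * g_rh                                   # via h_hat^(T) directly
    path3 = p["Ur"].T @ (g_rh * h_prev * r * (1 - r))  # via h_hat^(T) and r^(T)
    path4 = dh * (1 - z)                               # direct path
    return path1 + path2 + path3 + path4

# Usage with random (hypothetical) sizes and parameters
rng = np.random.default_rng(3)
n_in, n_h = 3, 4
p = {k: rng.normal(scale=0.5, size=(n_h, n_in)) for k in ("Wz", "Wr", "W")}
p.update({k: rng.normal(scale=0.5, size=(n_h, n_h)) for k in ("Uz", "Ur", "U")})
p.update({k: rng.normal(scale=0.1, size=n_h) for k in ("bz", "br", "b")})
x, h_prev, dh = rng.normal(size=n_in), rng.normal(size=n_h), rng.normal(size=n_h)
analytic = dh_prev_from(dh, x, h_prev, p)

# Finite-difference check with the linear test loss L = dh . h^(T)
eps = 1e-6
numeric = np.array([(dh @ gru_step(x, h_prev + eps * e, p)[0]
                     - dh @ gru_step(x, h_prev - eps * e, p)[0]) / (2 * eps)
                    for e in np.eye(n_h)])
print(np.allclose(analytic, numeric))  # True
```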

Similarly, $dh^{(t)}$ can also be calculated recursively:

$$ dh^{(t)} = \begin{cases} \text{grad}_{dense}^{(T)} & t = T \\ \\ \begin{aligned} & dh^{(t+1)} (\hat{h}^{(t+1)} - h^{(t)}) \sigma'(W_{z} x^{(t+1)} + U_{z} h^{(t)} + b_{z}) \ {}^t U_{z} \\ & \quad + \ dh^{(t+1)} z^{(t+1)} \tanh'(W x^{(t+1)} + ( r^{(t+1)} \odot h^{(t)}) U + b) (r^{(t+1)} \ {}^t U) \\ & \quad + \ dh^{(t+1)} z^{(t+1)} \tanh'(W x^{(t+1)} + ( r^{(t+1)} \odot h^{(t)}) U + b) (h^{(t)} \ {}^t U) \sigma'(W_{r} x^{(t+1)} + U_{r} h^{(t)} + b_{r}) \ {}^t U_{r} \\ & \quad + \ dh^{(t+1)} (1 - z^{(t+1)}) \end{aligned} & 0 \le t \lt T \end{cases} \tag{9.7} $$
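Eq. (9.7) turns into a backward loop over $t = T, \dots, 1$, each iteration producing $dh^{(t-1)}$ from $dh^{(t)}$. The sketch below assumes 1-D arrays used as column vectors, the reset-gated term written as `U @ (r * h)`, a zero initial state $h^{(-1)}$, and hypothetical parameter names. It recomputes the forward pass, then verifies $dh^{(0)}$ by perturbing $h^{(0)}$ and restarting the forward pass from there with a linear test loss $g \cdot h^{(T)}$, where `g` stands in for $\text{grad}_{dense}^{(T)}$.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def step(x, h_prev, p):
    """One GRU step; returns h^(t) and the gate values (z, r, h_hat)."""
    z = sigmoid(p["Wz"] @ x + p["Uz"] @ h_prev + p["bz"])
    r = sigmoid(p["Wr"] @ x + p["Ur"] @ h_prev + p["br"])
    h_hat = np.tanh(p["W"] @ x + p["U"] @ (r * h_prev) + p["b"])
    return (1 - z) * h_prev + z * h_hat, (z, r, h_hat)

def step_back(dh, x, h_prev, p):
    """Four-path sum of Eq. (9.7): dh^(t-1) from dh^(t), x^(t) and h^(t-1)."""
    _, (z, r, h_hat) = step(x, h_prev, p)
    a_z = dh * (h_hat - h_prev) * z * (1 - z)
    g_rh = p["U"].T @ (dh * z * (1 - h_hat ** 2))
    return (p["Uz"].T @ a_z + r * g_rh
            + p["Ur"].T @ (g_rh * h_prev * r * (1 - r)) + dh * (1 - z))

def all_dh(xs, p, grad_dense):
    """Return [dh^(0), ..., dh^(T)] with dh^(T) = grad_dense and h^(-1) = 0."""
    h, h_prevs = np.zeros_like(p["bz"]), []
    for x in xs:                          # forward pass, remembering h^(t-1)
        h_prevs.append(h)
        h, _ = step(x, h, p)
    dhs = [grad_dense]
    for t in range(len(xs) - 1, 0, -1):   # t = T, ..., 1 yields dh^(t-1)
        dhs.append(step_back(dhs[-1], xs[t], h_prevs[t], p))
    return dhs[::-1]

def loss_from(t, h_t, xs, p, g):
    """Linear test loss g . h^(T), restarting the forward pass at h^(t)."""
    h = h_t
    for x in xs[t + 1:]:
        h, _ = step(x, h, p)
    return g @ h

# Usage with random (hypothetical) sizes and parameters
rng = np.random.default_rng(4)
n_in, n_h, T = 3, 4, 5
p = {k: rng.normal(scale=0.5, size=(n_h, n_in)) for k in ("Wz", "Wr", "W")}
p.update({k: rng.normal(scale=0.5, size=(n_h, n_h)) for k in ("Uz", "Ur", "U")})
p.update({k: rng.normal(scale=0.1, size=n_h) for k in ("bz", "br", "b")})
xs = [rng.normal(size=n_in) for _ in range(T + 1)]
g = rng.normal(size=n_h)                  # stands in for grad_dense^(T)
dhs = all_dh(xs, p, g)

h0, _ = step(xs[0], np.zeros(n_h), p)     # h^(0)
eps = 1e-6
num0 = np.array([(loss_from(0, h0 + eps * e, xs, p, g)
                  - loss_from(0, h0 - eps * e, xs, p, g)) / (2 * eps)
                 for e in np.eye(n_h)])
print(np.allclose(dhs[0], num0))  # True
```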

9.2.4. Gradients

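Using the results above, we obtain the following gradients. The $t = 0$ terms are written separately because the initial hidden state $h^{(-1)}$ is zero, so that, for example, $\hat{h}^{(0)} - h^{(-1)} = \hat{h}^{(0)}$ and $U_{z} h^{(-1)} = 0$. For the same reason, the $t = 0$ terms of $dU_{z}$, $dU_{r}$, $dW_{r}$, $db_{r}$ and $dU$ vanish, and their sums start at $t = 1$.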

$$ \begin{align}
dU_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U_{z}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial U_{z}} \\
&= \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t h^{(t-1)} \tag{9.8} \\
dW_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W_{z}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial W_{z}} \\
&= dh^{(0)} \hat{h}^{(0)} \sigma'(W_{z} x^{(0)} + b_{z}) \ {}^t x^{(0)} + \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t x^{(t)} \tag{9.9} \\
db_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b_{z}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial b_{z}} \\
&= dh^{(0)} \hat{h}^{(0)} \sigma'(W_{z} x^{(0)} + b_{z}) + \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \tag{9.10} \\
dU_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial U_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t h^{(t-1)} \tag{9.11} \\
dW_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial W_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t x^{(t)} \tag{9.12} \\
db_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial b_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \tag{9.13} \\
dU &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial U} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) ( r^{(t)} \odot h^{(t-1)}) \tag{9.14} \\
dW &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial W} \\
&= dh^{(0)} z^{(0)} \tanh'(W x^{(0)} + b) \ {}^t x^{(0)} + \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \ {}^t x^{(t)} \tag{9.15} \\
db &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial b} \\
&= dh^{(0)} z^{(0)} \tanh'(W x^{(0)} + b) + \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \tag{9.16}
\end{align} $$
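Putting everything together, the gradients (9.8)-(9.16) can be accumulated in the same backward loop that propagates $dh^{(t)}$. This is a minimal NumPy sketch, not the document's reference implementation. It assumes column vectors, so each $\sigma'(\cdot) \ {}^t x^{(t)}$ becomes an outer product `np.outer(a, x)`, and Eq. (9.14) becomes `np.outer(a_h, r * h_prev)` once the reset-gated term is written `U @ (r * h)`. It also assumes a zero initial state, hypothetical parameter names, and a linear surrogate loss $g \cdot h^{(T)}$ in place of the dense layer, with `g` standing in for $\text{grad}_{dense}^{(T)}$. Every entry of every gradient is compared with a central difference.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def forward(xs, p):
    """Run the GRU over xs = [x^(0), ..., x^(T)] starting from h^(-1) = 0.

    Returns h^(T) and, for every step, the cache (x^(t), h^(t-1), z, r, h_hat).
    """
    h, cache = np.zeros_like(p["bz"]), []
    for x in xs:
        z = sigmoid(p["Wz"] @ x + p["Uz"] @ h + p["bz"])
        r = sigmoid(p["Wr"] @ x + p["Ur"] @ h + p["br"])
        h_hat = np.tanh(p["W"] @ x + p["U"] @ (r * h) + p["b"])
        cache.append((x, h, z, r, h_hat))
        h = (1 - z) * h + z * h_hat
    return h, cache

def bptt(xs, p, grad_dense):
    """Accumulate Eqs. (9.8)-(9.16) while propagating dh^(t) with Eq. (9.7)."""
    _, cache = forward(xs, p)
    grads = {k: np.zeros_like(v) for k, v in p.items()}
    dh = grad_dense                                    # dh^(T)
    for x, h_prev, z, r, h_hat in reversed(cache):     # t = T, ..., 0
        a_z = dh * (h_hat - h_prev) * z * (1 - z)      # dL/d(pre-activation of z)
        a_h = dh * z * (1 - h_hat ** 2)                # dL/d(pre-activation of h_hat)
        g_rh = p["U"].T @ a_h                          # dL/d(r * h^(t-1))
        a_r = g_rh * h_prev * r * (1 - r)              # dL/d(pre-activation of r)
        for suffix, a in (("z", a_z), ("r", a_r), ("", a_h)):
            grads["W" + suffix] += np.outer(a, x)      # dW_z, dW_r, dW
            grads["b" + suffix] += a                   # db_z, db_r, db
        grads["Uz"] += np.outer(a_z, h_prev)           # Eq. (9.8)
        grads["Ur"] += np.outer(a_r, h_prev)           # Eq. (9.11)
        grads["U"] += np.outer(a_h, r * h_prev)        # Eq. (9.14)
        # Four-path sum of Eq. (9.7): becomes dh^(t-1) for the next iteration
        dh = p["Uz"].T @ a_z + r * g_rh + p["Ur"].T @ a_r + dh * (1 - z)
    return grads

# Usage with random (hypothetical) sizes and parameters
rng = np.random.default_rng(5)
n_in, n_h, T = 3, 4, 5
p = {k: rng.normal(scale=0.5, size=(n_h, n_in)) for k in ("Wz", "Wr", "W")}
p.update({k: rng.normal(scale=0.5, size=(n_h, n_h)) for k in ("Uz", "Ur", "U")})
p.update({k: rng.normal(scale=0.1, size=n_h) for k in ("bz", "br", "b")})
xs = [rng.normal(size=n_in) for _ in range(T + 1)]
g = rng.normal(size=n_h)          # stands in for grad_dense^(T)
grads = bptt(xs, p, g)

def surrogate_loss():
    """Linear loss g . h^(T); its gradient w.r.t. h^(T) is exactly g."""
    return g @ forward(xs, p)[0]

eps, ok = 1e-6, True
for k, v in p.items():
    for idx in np.ndindex(v.shape):
        old = v[idx]
        v[idx] = old + eps
        plus = surrogate_loss()
        v[idx] = old - eps
        minus = surrogate_loss()
        v[idx] = old                  # restore the parameter
        ok = ok and bool(np.isclose(grads[k][idx], (plus - minus) / (2 * eps), atol=1e-7))
print(ok)  # True
```

Because $h^{(-1)} = 0$ in this sketch, the $t = 0$ iteration adds nothing to `Uz`, `Ur`, `U`, `Wr` and `br`, which matches the sums in Eqs. (9.8) and (9.11)-(9.14) that start at $t = 1$.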