9.2. Computing the gradients for Back Propagation Through Time
9.2.1. Loss function
We will use the mean squared error (MSE) as our loss function. The MSE is defined as follows:
$$ L = \frac{1}{2} (y^{(T)} - Y^{(T)})^{2} \tag{9.3} $$
9.2.2. Forward Computational Graph
Fig.9-4 illustrates the forward computational graph of the GRU.
For an explanation of computational graphs, see the Appendix.
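As a rough illustration of the forward graph, the sketch below evaluates one GRU step using the gate equations that appear in the derivatives of the next subsection. It is a minimal sketch, not the reference implementation: it assumes the scalar case (hidden size 1 and input size 1), so the Hadamard product and matrix products reduce to ordinary multiplication, and the names `gru_step` and `params` as well as all parameter values are hypothetical.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def gru_step(x, h_prev, p):
    """One scalar GRU step; p maps parameter names to floats (illustrative)."""
    z = sigmoid(p["Wz"] * x + p["Uz"] * h_prev + p["bz"])            # update gate z^(t)
    r = sigmoid(p["Wr"] * x + p["Ur"] * h_prev + p["br"])            # reset gate r^(t)
    h_hat = math.tanh(p["W"] * x + (r * h_prev) * p["U"] + p["b"])   # candidate state
    h = (1.0 - z) * h_prev + z * h_hat                               # new hidden state h^(t)
    return z, r, h_hat, h


# Arbitrary values chosen only for the demonstration.
params = {"Wz": 0.5, "Uz": -0.3, "bz": 0.1,
          "Wr": 0.8, "Ur": 0.4, "br": -0.2,
          "W": 1.2, "U": -0.7, "b": 0.05}

h = 0.0  # assumed initial hidden state
for x in [0.5, -1.0, 0.25]:
    z, r, h_hat, h = gru_step(x, h, params)
    print(f"z={z:.4f} r={r:.4f} h_hat={h_hat:.4f} h={h:.4f}")
```

Returning the intermediate values `z`, `r`, and `h_hat` alongside `h` mirrors the fact that the backward pass needs every node of the forward graph.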
9.2.3. Backward Computational Graph
To build the backward computational graph, we calculate the derivative of each node in the graph above with respect to its inputs.
Since the derivatives of the dense layer have been calculated in Section 4.1, we focus on calculating the derivatives of the GRU.
$$ \begin{align}
\frac{\partial h^{(t)}}{\partial z^{(t)}} &= \frac{\partial ( (1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)} )}{\partial z^{(t)}} = - h^{(t-1)} + \hat{h}^{(t)} = \hat{h}^{(t)} - h^{(t-1)} \\
\frac{\partial z^{(t)}}{\partial W_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial W_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t x^{(t)} \\
\frac{\partial z^{(t)}}{\partial U_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial U_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t h^{(t-1)} \\
\frac{\partial z^{(t)}}{\partial b_{z}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial b_{z}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \\
\frac{\partial z^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \sigma(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) }{\partial h^{(t-1)}} = \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t U_{z} \\
\frac{\partial h^{(t)}}{\partial \hat{h}^{(t)}} &= \frac{\partial ((1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)})}{\partial \hat{h}^{(t)}} = z^{(t)} \\
\frac{\partial \hat{h}^{(t)}}{\partial W} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial W} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \ {}^t x^{(t)} \\
\frac{\partial \hat{h}^{(t)}}{\partial U} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial U} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) ( r^{(t)} \odot h^{(t-1)}) \\
\frac{\partial \hat{h}^{(t)}}{\partial b} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial b} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \\
\frac{\partial \hat{h}^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial h^{(t-1)}} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (r^{(t)} \ {}^t U) \\
\frac{\partial \hat{h}^{(t)}}{\partial r^{(t)}} &= \frac{\partial \tanh(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) }{\partial r^{(t)}} = \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \\
\frac{\partial r^{(t)}}{\partial W_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial W_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t x^{(t)} \\
\frac{\partial r^{(t)}}{\partial U_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial U_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t h^{(t-1)} \\
\frac{\partial r^{(t)}}{\partial b_{r}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial b_{r}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \\
\frac{\partial r^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial \sigma(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) }{\partial h^{(t-1)}} = \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t U_{r} \\
\frac{\partial h^{(t)}}{\partial h^{(t-1)}} &= \frac{\partial ( (1 - z^{(t)}) \odot h^{(t-1)} + z^{(t)} \odot \hat{h}^{(t)} )}{\partial h^{(t-1)}} = (1 - z^{(t)})
\end{align} $$
To avoid confusion with the time-step superscript $(T)$, this section writes the transpose of a vector or matrix $A$ as ${}^tA$ instead of $A^{T}$.
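The local derivatives above can be sanity-checked numerically. The following is a minimal sketch under stated assumptions rather than part of the original derivation: it works in the scalar case (so $\odot$ and transposes reduce to ordinary multiplication), uses the standard identities $\sigma'(a) = \sigma(a)(1 - \sigma(a))$ and $\tanh'(a) = 1 - \tanh^{2}(a)$, and compares each analytic value with a central finite difference. The helper names `blend` and `numeric_derivative` and the test values are hypothetical.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def d_sigmoid(a):
    s = sigmoid(a)
    return s * (1.0 - s)


def d_tanh(a):
    return 1.0 - math.tanh(a) ** 2


def numeric_derivative(f, a, eps=1e-6):
    """Central finite difference of a scalar function f at a."""
    return (f(a + eps) - f(a - eps)) / (2.0 * eps)


def blend(z, h_prev, h_hat):
    """h = (1 - z) * h_prev + z * h_hat, the last node of the forward graph."""
    return (1.0 - z) * h_prev + z * h_hat


z, h_prev, h_hat = 0.3, 0.8, -0.4  # arbitrary test values

# Each entry pairs a finite-difference estimate with the analytic derivative.
checks = {
    "dh/dz": (numeric_derivative(lambda v: blend(v, h_prev, h_hat), z),
              h_hat - h_prev),
    "dh/dh_hat": (numeric_derivative(lambda v: blend(z, h_prev, v), h_hat),
                  z),
    "dh/dh_prev": (numeric_derivative(lambda v: blend(z, v, h_hat), h_prev),
                   1.0 - z),
    "sigmoid'": (numeric_derivative(sigmoid, 0.7), d_sigmoid(0.7)),
    "tanh'": (numeric_derivative(math.tanh, 0.7), d_tanh(0.7)),
}

for name, (numeric, analytic) in checks.items():
    print(f"{name}: numeric={numeric:.6f} analytic={analytic:.6f}")
```

Note that $\partial h^{(t)} / \partial h^{(t-1)} = 1 - z^{(t)}$ covers only the direct connection; the dependence of $z^{(t)}$, $r^{(t)}$, and $\hat{h}^{(t)}$ on $h^{(t-1)}$ is handled by the separate derivatives listed above.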
Using these derivatives, we can build the backward computational graph shown in Fig.9-5.
To simplify the discussion, we define the following quantities:
$$ \begin{align} \text{grad}_{dense}^{(T)} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial h^{(T)}} \tag{9.4} \\ dh^{(t)} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial h^{(t)}} \tag{9.5} \end{align} $$
- $\text{grad}_{dense}^{(T)}$ is the gradient propagated from the dense layer.
- $dh^{(t)}$ is the gradient of the loss with respect to $h^{(t)}$; for $t < T$ it is propagated back from the hidden state $h^{(t+1)}$.
In this document, “$d$” denotes the gradient (e.g., $dL, dh$), not the total derivative.
The following relation is satisfied by definition:
$$ dh^{(T)} = \text{grad}_{dense}^{(T)} $$
Next, we will calculate $dh^{(T-1)}$. Fig.9-6 illustrates the relationship between $h^{(T)}$ and $h^{(T-1)}$, which is extracted from Fig.9-5.
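As shown in Fig.9-6, $dh^{(T-1)}$ can be calculated from $dh^{(T)}$. There are four paths from $h^{(T)}$ back to $h^{(T-1)}$ (through $z^{(T)}$, through $\hat{h}^{(T)}$ directly, through $\hat{h}^{(T)}$ via $r^{(T)}$, and the direct connection), so we sum their contributions: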
$$ \begin{align}
dh^{(T-1)} &= \frac{\partial L}{\partial h^{(T)}} \left[ \frac{\partial h^{(T)}}{\partial z^{(T)}} \frac{\partial z^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial \hat{h}^{(T)}} \frac{\partial \hat{h}^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial \hat{h}^{(T)}} \frac{\partial \hat{h}^{(T)}}{\partial r^{(T)}} \frac{\partial r^{(T)}}{\partial h^{(T-1)}} + \frac{\partial h^{(T)}}{\partial h^{(T-1)}} \right] \\
&= dh^{(T)} (\hat{h}^{(T)} - h^{(T-1)}) \sigma'(W_{z} x^{(T)} + U_{z} h^{(T-1)} + b_{z}) \ {}^t U_{z} \\
& \quad + \ dh^{(T)} z^{(T)} \tanh'(W x^{(T)} + ( r^{(T)} \odot h^{(T-1)}) U + b) (r^{(T)} \ {}^t U) \\
& \quad + \ dh^{(T)} z^{(T)} \tanh'(W x^{(T)} + ( r^{(T)} \odot h^{(T-1)}) U + b) (h^{(T-1)} \ {}^t U) \sigma'(W_{r} x^{(T)} + U_{r} h^{(T-1)} + b_{r}) \ {}^t U_{r} \\
& \quad + \ dh^{(T)} (1 - z^{(T)}) \tag{9.6}
\end{align} $$
Similarly, $dh^{(t)}$ can also be calculated recursively:
$$ dh^{(t)} = \begin{cases}
\text{grad}_{dense}^{(T)} & t = T \\ \\
\begin{align}
& dh^{(t+1)} (\hat{h}^{(t+1)} - h^{(t)}) \sigma'(W_{z} x^{(t+1)} + U_{z} h^{(t)} + b_{z}) \ {}^t U_{z} \\
& \quad + \ dh^{(t+1)} z^{(t+1)} \tanh'(W x^{(t+1)} + ( r^{(t+1)} \odot h^{(t)}) U + b) (r^{(t+1)} \ {}^t U) \\
& \quad + \ dh^{(t+1)} z^{(t+1)} \tanh'(W x^{(t+1)} + ( r^{(t+1)} \odot h^{(t)}) U + b) (h^{(t)} \ {}^t U) \sigma'(W_{r} x^{(t+1)} + U_{r} h^{(t)} + b_{r}) \ {}^t U_{r} \\
& \quad + \ dh^{(t+1)} (1 - z^{(t+1)})
\end{align} & 0 \le t \lt T
\end{cases} \tag{9.7} $$
9.2.4. Gradients
Using the results above, we finally obtain the gradients shown below:
$$ \begin{align}
dU_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U_{z}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial U_{z}} \\
&= \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t h^{(t-1)} \tag{9.8} \\
dW_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W_{z}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial W_{z}} \\
&= dh^{(0)} \hat{h}^{(0)} \sigma'(W_{z} x^{(0)} + b_{z}) \ {}^t x^{(0)} + \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \ {}^t x^{(t)} \tag{9.9} \\
db_{z} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b_{z}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial z^{(t)}} \frac{\partial z^{(t)}}{\partial b_{z}} \\
&= dh^{(0)} \hat{h}^{(0)} \sigma'(W_{z} x^{(0)} + b_{z}) + \sum_{t=1}^{T} dh^{(t)} (\hat{h}^{(t)} - h^{(t-1)}) \sigma'(W_{z} x^{(t)} + U_{z} h^{(t-1)} + b_{z}) \tag{9.10} \\
dU_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial U_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t h^{(t-1)} \tag{9.11} \\
dW_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial W_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \ {}^t x^{(t)} \tag{9.12} \\
db_{r} &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b_{r}} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial r^{(t)}} \frac{\partial r^{(t)}} {\partial b_{r}} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) (h^{(t-1)} \ {}^t U) \sigma'(W_{r} x^{(t)} + U_{r} h^{(t-1)} + b_{r}) \tag{9.13} \\
dU &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial U} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial U} \\
&= \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) ( r^{(t)} \odot h^{(t-1)}) \tag{9.14} \\
dW &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial W} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial W} \\
&= dh^{(0)} z^{(0)} \tanh'(W x^{(0)} + b) \ {}^t x^{(0)} + \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \ {}^t x^{(t)} \tag{9.15} \\
db &\stackrel{\mathrm{def}}{=} \frac{\partial L}{\partial b} = \sum_{t=0}^{T} \frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}} {\partial \hat{h}^{(t)}} \frac{\partial \hat{h}^{(t)}} {\partial b} \\
&= dh^{(0)} z^{(0)} \tanh'(W x^{(0)} + b) + \sum_{t=1}^{T} dh^{(t)} z^{(t)} \tanh'(W x^{(t)} + ( r^{(t)} \odot h^{(t-1)}) U + b) \tag{9.16}
\end{align} $$
These expressions are consistent with taking the hidden state before the first step, $h^{(-1)}$, to be zero: the $t = 0$ terms in (9.9), (9.10), (9.15), and (9.16) are written out separately without it, and every sum whose summand contains a factor of $h^{(t-1)}$ starts at $t = 1$.
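The recursion (9.7) and the gradients (9.8) to (9.16) can be checked against finite differences. The sketch below is illustrative only and rests on several assumptions that are not taken from the text: it uses the scalar case (hidden size 1), an initial hidden state of zero, and a hypothetical linear dense layer $y = V h^{(T)} + c$ standing in for the one from Section 4.1, so that $\text{grad}_{dense}^{(T)} = (y - Y) V$. All function names and parameter values are made up for the example.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def forward(p, xs):
    """Run the scalar GRU over xs (h^(-1) assumed 0) and cache every step."""
    h_prev, cache = 0.0, []
    for x in xs:
        z = sigmoid(p["Wz"] * x + p["Uz"] * h_prev + p["bz"])
        r = sigmoid(p["Wr"] * x + p["Ur"] * h_prev + p["br"])
        h_hat = math.tanh(p["W"] * x + (r * h_prev) * p["U"] + p["b"])
        cache.append((x, h_prev, z, r, h_hat))
        h_prev = (1.0 - z) * h_prev + z * h_hat
    return h_prev, cache


def loss(p, xs, target):
    h_last, _ = forward(p, xs)
    y = p["V"] * h_last + p["c"]  # hypothetical linear dense layer
    return 0.5 * (y - target) ** 2


def backward(p, xs, target):
    """Back propagation through time for the nine GRU parameters."""
    h_last, cache = forward(p, xs)
    y = p["V"] * h_last + p["c"]
    dh = (y - target) * p["V"]  # grad_dense^(T) under the assumed dense layer
    grads = {k: 0.0 for k in ("Wz", "Uz", "bz", "Wr", "Ur", "br", "W", "U", "b")}
    for x, h_prev, z, r, h_hat in reversed(cache):
        dz = dh * (h_hat - h_prev) * z * (1.0 - z)   # dh * dh/dz * sigma'
        da = dh * z * (1.0 - h_hat ** 2)             # dh * dh/dh_hat * tanh'
        dr = da * h_prev * p["U"] * r * (1.0 - r)    # da * dh_hat/dr * sigma'
        grads["Wz"] += dz * x
        grads["Uz"] += dz * h_prev
        grads["bz"] += dz
        grads["Wr"] += dr * x
        grads["Ur"] += dr * h_prev
        grads["br"] += dr
        grads["W"] += da * x
        grads["U"] += da * r * h_prev
        grads["b"] += da
        # Eq. (9.7): sum of the four paths back to the previous hidden state.
        dh = dz * p["Uz"] + da * r * p["U"] + dr * p["Ur"] + dh * (1.0 - z)
    return grads


def numeric_grad(p, xs, target, key, eps=1e-6):
    """Central finite difference of the loss with respect to p[key]."""
    plus = {**p, key: p[key] + eps}
    minus = {**p, key: p[key] - eps}
    return (loss(plus, xs, target) - loss(minus, xs, target)) / (2.0 * eps)


# Arbitrary values chosen only for the demonstration.
params = {"Wz": 0.5, "Uz": -0.3, "bz": 0.1, "Wr": 0.8, "Ur": 0.4, "br": -0.2,
          "W": 1.2, "U": -0.7, "b": 0.05, "V": 1.5, "c": -0.1}
xs, target = [0.5, -1.0, 0.25, 0.9], 0.3

grads = backward(params, xs, target)
max_err = max(abs(grads[k] - numeric_grad(params, xs, target, k)) for k in grads)
print("largest analytic/numeric discrepancy:", max_err)
```

Because the first cached `h_prev` is zero, the contributions to `Uz`, `Ur`, `U`, `Wr`, and `br` at the first step vanish automatically, which matches the sums in (9.8) and (9.11) to (9.14) starting at $t = 1$. In the vector case the same loop applies, with the products replaced by the Hadamard and transposed matrix products written in the equations.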