16.2. Training
The training process for the Transformer model shares similarities with RNN-based encoder-decoder models. It involves the following steps:
- The Transformer receives both a source sequence and a target sequence as input.
- During training, the model predicts all tokens of the target sequence ($\text{predictions}$) in parallel (teacher forcing), unlike RNN-based encoder-decoder models, which predict one token at a time.
- Similar to other models, the Transformer calculates a loss value based on the difference between its predicted output ($\text{predictions}$) and the actual target sequence ($\text{expected_target_sentences}$).
- The optimizer then adjusts the model’s internal parameters (weights and biases) to minimize this loss value.
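The target shift used by the training step can be illustrated with a small example (the token ids are hypothetical; assume 1 is the start token and 2 is the end token):

```python
# A target sentence as token ids: [start, w1, w2, w3, end]
target = [1, 37, 52, 99, 2]

# The decoder is fed the sequence without the last token...
decoder_input = target[:-1]   # [1, 37, 52, 99]

# ...and is trained to predict the sequence shifted one step ahead.
expected = target[1:]         # [37, 52, 99, 2]
```

At each position the model thus learns to predict the next token given all preceding ground-truth tokens, which is what allows every position to be predicted in parallel during training.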
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train(source_sentences, target_sentences):
    # Shift the target: the decoder receives the sequence without the last
    # token and learns to predict the sequence without the first token.
    expected_target_sentences = target_sentences[:, 1:]
    target_sentences = target_sentences[:, :-1]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        source_sentences, target_sentences
    )
    with tf.GradientTape() as tape:
        predictions, _, _ = transformer(
            source_sentences, target_sentences, True,
            enc_padding_mask, combined_mask, dec_padding_mask
        )
        loss = loss_function(expected_target_sentences, predictions)
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    train_loss(loss)
    train_accuracy(expected_target_sentences, predictions)
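The training step above calls create_masks, which is not shown here. A minimal sketch of what such a function typically computes, following the standard TensorFlow Transformer recipe and assuming token id 0 marks padding, might look like this:

```python
import tensorflow as tf

def create_padding_mask(seq):
    # 1.0 where the token id is 0 (padding); shaped to broadcast over
    # the attention logits: (batch, 1, 1, seq_len).
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    # Upper-triangular 1s block attention to future target positions.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

def create_masks(source_sentences, target_sentences):
    # Encoder self-attention and decoder cross-attention both mask
    # padding in the source sequence.
    enc_padding_mask = create_padding_mask(source_sentences)
    dec_padding_mask = create_padding_mask(source_sentences)
    # Decoder self-attention masks both future positions and target padding.
    look_ahead_mask = create_look_ahead_mask(tf.shape(target_sentences)[1])
    dec_target_padding_mask = create_padding_mask(target_sentences)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
```

The combined mask is the element-wise maximum of the look-ahead mask and the target padding mask, so a position is masked if either rule forbids attending to it.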
#
# Set n_epochs to at least 10 when training.
#
# If n_epochs = 0, this model uses the trained parameters saved in the last
# checkpoint, allowing you to perform machine translation without retraining.
#
if len(sys.argv) == 2:
    n_epochs = int(sys.argv[1])
else:
    n_epochs = 10
for epoch in range(1, n_epochs + 1):
    # Clear the running metrics at the start of each epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()
    for (batch, (source_sentences, target_sentences)) in enumerate(dataset):
        train(source_sentences, target_sentences)
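The comments above refer to parameters saved in a checkpoint, but the checkpointing code itself is not shown. A minimal sketch of how saving and restoring could work with tf.train.CheckpointManager, using a toy variable in place of the transformer and optimizer (the directory name and max_to_keep value are assumptions):

```python
import tensorflow as tf

# In the real training script this would be
# tf.train.Checkpoint(transformer=transformer, optimizer=optimizer).
step = tf.Variable(0)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, "./checkpoints_demo", max_to_keep=3)

# Restore the latest checkpoint if one exists; with n_epochs = 0 this is
# what lets the model translate without retraining.
if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)

# After each epoch (or every few epochs), save the current parameters.
step.assign_add(1)
save_path = manager.save()
```

Restoring before the loop and saving after each epoch also makes training resumable after an interruption.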