16.2. Training

Training the Transformer resembles training an RNN-based encoder-decoder model. Each training step proceeds as follows:

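  1. The Transformer receives a batch of source sequences and the corresponding target sequences. Each target is split into the decoder input (every token except the last) and the expected output (every token except the first).
  2. The model predicts every position of the target sequence ($\text{predictions}$) in a single parallel pass. This works because the decoder is fed the whole shifted target at once and its self-attention is masked so that each position sees only earlier tokens. An RNN-based decoder, by contrast, must process the target one token at a time.
  3. As with other models, the Transformer computes a loss from the difference between its predictions ($\text{predictions}$) and the actual target sequence ($\text{expected\_target\_sentences}$).
  4. The optimizer then uses the gradients of this loss to update the model's internal parameters (weights and biases) so that the loss decreases.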
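The shift in step 1 and the masking behind step 2 can be sketched on a toy batch. This NumPy version mirrors the slicing inside `train()` and builds a look-ahead mask. The token IDs (1 = `<start>`, 2 = `<end>`, 0 = padding) and the helper `look_ahead_mask` are hypothetical, and the chapter's own `create_masks` may differ in detail.

```python
import numpy as np

# Toy batch of target sentences as token IDs (hypothetical vocabulary):
# 1 = <start>, 2 = <end>, 0 = padding.
target_sentences = np.array([
    [1, 7, 8, 9, 2],
    [1, 5, 6, 2, 0],
])

# Teacher forcing: the decoder reads tokens 0..n-2 and must predict 1..n-1.
decoder_input = target_sentences[:, :-1]    # drop the last column
expected_target = target_sentences[:, 1:]   # drop the <start> column


def look_ahead_mask(size):
    """Return a (size, size) matrix with 1.0 where attention is blocked.

    Position i may attend only to positions 0..i, so every entry above
    the main diagonal is masked.
    """
    return np.triu(np.ones((size, size)), k=1)


mask = look_ahead_mask(decoder_input.shape[1])
print(decoder_input)
print(expected_target)
print(mask)
```

Row *i* of `mask` blocks every column after *i*, so the prediction for position *i* depends only on tokens up to *i*. In TensorFlow the same matrix can be built with `1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)`. The padded second sentence still produces a padding ID in `expected_target`; such positions are normally excluded from the loss.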
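```python
import sys

import tensorflow as tf

# Both inputs are batches of token IDs with shape (batch_size, seq_len).
# None lets batch size and sequence length vary between calls without
# retracing the tf.function.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]


@tf.function(input_signature=train_step_signature)
def train(source_sentences, target_sentences):

    # Teacher forcing: the decoder reads the target without its last token
    # and learns to predict the target shifted one position to the left.
    expected_target_sentences = target_sentences[:, 1:]
    target_sentences = target_sentences[:, :-1]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        source_sentences, target_sentences
    )

    with tf.GradientTape() as tape:
        # True is passed as the training flag, so layers such as dropout
        # run in training mode.
        predictions, _, _ = transformer(
            source_sentences, target_sentences, True,
            enc_padding_mask, combined_mask, dec_padding_mask,
        )
        loss = loss_function(expected_target_sentences, predictions)

    # Backpropagate the loss and update every trainable variable.
    gradients = tape.gradient(loss, transformer.trainable_variables)

    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Accumulate running averages for reporting.
    train_loss(loss)
    train_accuracy(expected_target_sentences, predictions)


#
# Set n_epochs to at least 10 when you train the model.
#
# If n_epochs = 0, the model uses the trained parameters saved in the last
# checkpoint, so you can run machine translation without retraining.
if len(sys.argv) == 2:
    n_epochs = int(sys.argv[1])
else:
    n_epochs = 10


for epoch in range(1, n_epochs + 1):

    # Clear the metrics accumulated during the previous epoch.
    # (Recent TensorFlow/Keras versions name this method reset_state().)
    train_loss.reset_states()
    train_accuracy.reset_states()

    for (batch, (source_sentences, target_sentences)) in enumerate(dataset):
        train(source_sentences, target_sentences)
```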
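`train()` relies on `loss_function` and `train_accuracy`, which are defined elsewhere. When target batches are zero-padded, a common choice is to average the cross-entropy only over non-padding positions. The following NumPy sketch assumes padding ID 0 and that `predictions` are unnormalized logits of shape `(batch, seq_len, vocab_size)`. `masked_loss` and `masked_accuracy` are hypothetical names, not necessarily how this chapter implements its loss and metric.

```python
import numpy as np

PAD_ID = 0  # assumption: padding token ID


def log_softmax(logits):
    # Subtract the per-row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def masked_loss(expected, logits):
    """Mean cross-entropy over non-padding positions.

    expected: int array (batch, seq_len) of target token IDs.
    logits:   float array (batch, seq_len, vocab_size).
    """
    log_probs = log_softmax(logits)
    # Log-probability assigned to the correct token at each position.
    token_log_probs = np.take_along_axis(
        log_probs, expected[..., None], axis=-1
    )[..., 0]
    mask = (expected != PAD_ID).astype(logits.dtype)
    return -(token_log_probs * mask).sum() / mask.sum()


def masked_accuracy(expected, logits):
    """Fraction of non-padding positions where the argmax equals the target."""
    correct = logits.argmax(axis=-1) == expected
    mask = expected != PAD_ID
    return (correct & mask).sum() / mask.sum()


expected = np.array([[3, 1, 0]])            # the last position is padding
uniform = np.zeros((1, 3, 4))               # equal scores over a 4-token vocabulary
print(masked_loss(expected, uniform))       # equals log(4), about 1.386
print(masked_accuracy(expected, uniform))   # 0.0: the only argmax match is the masked pad
```

In TensorFlow the per-token term is typically obtained with `tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')` and then masked and averaged in the same way.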