Machine Translation using Sequence-to-Sequence Learning (on GPU)

An Annotated Introduction to LSTM-based Encoder-Decoder Models

In this article we're teaching a Recurrent Neural Network (RNN) model based on two LSTM layers to translate English sentences into German, inspired by a great tutorial on the official Keras blog.

In sequence-to-sequence learning we want to convert input sequences, which in the general case have arbitrary length, into sequences in another domain. An obvious application of this is machine translation.

'Go on.' -> [Sequence-to-Sequence Model] -> 'Mach weiter.'

Data Preparation

Luckily there are quite a few datasets available here. Each one consists of sentence pairs delimited by a tab. The German-English dataset we'll be using contains sentence pairs prepared by the Tatoeba Project. But before we can feed our model any English sentences, we first have to process the data so that it can serve as input to Keras' LSTM layers.

We have already uploaded the dataset (a single .txt file), so we can load the sentence pairs from it by creating a reference (via @...) in our Python runtime.


Let's create a list of lines by splitting the text file at every occurrence of '\n'.

with open('deu-eng.txt', 'r', encoding='utf-8') as f:
  lines = f.read().split('\n')
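A detail worth knowing: if the file ends with a newline, split('\n') produces a trailing empty string. This is presumably why the sampling loop further down stops at len(lines) - 1. Here is a quick check on a made-up two-line string. The sample text is an assumption and is not read from the actual dataset file.

```python
# A tiny, hypothetical stand-in for the contents of deu-eng.txt.
text = 'Go.\tGeh.\nGo on.\tMach weiter.\n'

lines = text.split('\n')
print(lines)  # the last element is an empty string

# Dropping the trailing empty element leaves only real sentence pairs,
# each split into [input, target] at the tab.
pairs = [line.split('\t') for line in lines[:len(lines) - 1]]
print(pairs)
```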

Let's look at an example.


Sweet! So every line contains both the input (English) and the target (German) sentence, separated by '\t'.


Let's go ahead and split each line into an input text and a target text. Since we'll do the translation character by character, we also want to compute the set of every character we encounter in the dataset, for both inputs and targets.

input_texts = []
target_texts = []
input_characters = set()
target_characters = set()

Now we can loop over every sample we choose for training and fill the lists and sets. We'll also add '\t' (start-of-sequence) and '\n' (end-of-sequence) characters to every target text. These will later help our model determine when to start and, more importantly, when to end a sequence. We need this because we don't know a priori how long the output sequences should be, so we teach our model to decide that by itself.

num_samples = 10000
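for line in lines[: min(num_samples, len(lines) - 1)]:
  input_text, target_text = line.split('\t')
  # '\t' marks the start and '\n' the end of every target sequence.
  target_text = '\t' + target_text + '\n'
  input_texts.append(input_text)
  target_texts.append(target_text)
  # Collect every character seen on the input and target side.
  for char in input_text:
    if char not in input_characters:
      input_characters.add(char)
  for char in target_text:
    if char not in target_characters:
      target_characters.add(char)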
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

By now we have 10000 samples of input/target texts. Across the input texts there are 70 unique characters, while the target texts contain 87.


This is partly due to those weird German umlauts, and partly because every target text also contains the '\t' and '\n' markers we added.

Another characteristic of the German language is that sentences tend to be a bit longer than their English counterparts (maybe you can find a way to test that hypothesis for yourself). In any case, the longest target sequence in the 10000 sample sentences we're using contains 53 characters (including the two markers), while the longest input sequence contains only 16.
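One simple way to test the length hypothesis is to compare average sentence lengths and count how often the German side is longer. The two pairs below are hypothetical placeholders. In the notebook you would use the real input_texts and target_texts, remembering that every target text carries the two extra marker characters.

```python
# Hypothetical sample pairs (English, German); replace with the real data.
pairs = [
    ('Go on.', 'Mach weiter.'),
    ('Go.', 'Geh.'),
]

def mean_length(texts):
    """Average number of characters per sentence."""
    return sum(len(t) for t in texts) / len(texts)

english = [en for en, _ in pairs]
# For the article's target_texts, strip the '\t'/'\n' markers first,
# e.g. [t[1:-1] for t in target_texts].
german = [de for _, de in pairs]

# Fraction of pairs in which the German sentence is longer.
longer = sum(len(de) > len(en) for en, de in pairs) / len(pairs)
print(mean_length(english), mean_length(german), longer)
```

With only two pairs this proves nothing, of course. Running the same code on the full 10000 samples would give a more meaningful answer.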

But these texts still don't work as input to our model, because we need to convert the characters into numeric values. In our case one-hot encodings are fine, but for more involved (for example word-level) models, learned embeddings such as Word2Vec would make more sense.

Here we first tokenize our characters by assigning each unique character to an integer value.

input_token_index = dict(
  [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
  [(char, i) for i, char in enumerate(target_characters)])
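To see what the token index and the one-hot encoding look like, here is a minimal sketch on a toy alphabet. The character set and the one_hot helper are hypothetical and only mirror what the article does with input_characters and input_token_index.

```python
import numpy as np

# Toy character set standing in for input_characters (an assumption).
characters = sorted(set('go on.'))   # [' ', '.', 'g', 'n', 'o']
token_index = {char: i for i, char in enumerate(characters)}

def one_hot(text, max_len):
    """Encode text as a (max_len, vocab_size) array with a single 1. per
    character; rows past the end of the text stay all-zero (padding)."""
    encoded = np.zeros((max_len, len(characters)), dtype='float32')
    for t, char in enumerate(text):
        encoded[t, token_index[char]] = 1.
    return encoded

x = one_hot('go.', max_len=5)
print(token_index)
print(x)

# argmax over each non-padding row recovers the original characters.
decoded = ''.join(characters[i] for i in x[:3].argmax(axis=1))
print(decoded)
```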

With that we can start creating numeric data. We'll need input data for both the encoder and the decoder of the model, as well as the target data (used only in the decoder part).

import numpy as np

# One-hot arrays of shape (samples, timesteps, vocabulary size).
encoder_input_data = np.zeros(
  (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
  dtype='float32')
decoder_input_data = np.zeros(
  (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
  dtype='float32')
decoder_target_data = np.zeros(
  (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
  dtype='float32')

The encoder_input_data array holds 10000 samples. Each sample is padded to the maximum input sequence length (16) and filled with the respective one-hot-encoded tokens (in this case vectors of length 70).

The decoder_input_data and decoder_target_data arrays are constructed the same way, only with the decoder's maximum sequence length (53) and vocabulary size (87). Why the decoder needs two arrays, one for its inputs and one for its targets, will become apparent later when we look at the model in more detail.
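One-hot arrays are mostly zeros and can get large. Here is a quick calculation of their size using the shapes from the statistics printed above. It assumes float32 entries (4 bytes each) and computes sizes without allocating anything.

```python
from math import prod

# Shapes taken from the article's printed statistics: 10000 samples,
# max lengths 16 (encoder) and 53 (decoder), 70 and 87 unique tokens.
num_samples = 10000
encoder_shape = (num_samples, 16, 70)
decoder_shape = (num_samples, 53, 87)
bytes_per_float32 = 4

def megabytes(shape):
    """Size in MB (10**6 bytes) of a float32 array with the given shape."""
    return prod(shape) * bytes_per_float32 / 1e6

print(megabytes(encoder_shape))   # encoder_input_data
print(megabytes(decoder_shape))   # decoder_input_data, and again decoder_target_data
```

Together that is roughly 414 MB for 10000 short sentences. This is one reason integer indices fed into an embedding layer tend to be preferred once vocabularies or datasets grow.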

Time to fill in the data with the actual tokens.

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
  for t, char in enumerate(input_text):
    encoder_input_data[i, t, input_token_index[char]] = 1.
  for t, char in enumerate(target_text):
    # decoder_target_data is ahead of decoder_input_data by one timestep
    decoder_input_data[i, t, target_token_index[char]] = 1.
    if t > 0:
      # decoder_target_data will be ahead by one timestep
      # and will not include the start character.
      decoder_target_data[i, t - 1, target_token_index[char]] = 1.

With that, every input sentence has been turned into a sequence of length 16 that holds a one-hot vector for each character, with all-zero rows padding the shorter sentences.
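The one-timestep offset between decoder_input_data and decoder_target_data in the loop above means the decoder is fed the correct previous character and trained to predict the next one. This setup is commonly called teacher forcing. The toy target string below is hypothetical but uses the same '\t'/'\n' markers.

```python
# Toy target with the article's start ('\t') and end ('\n') markers.
target_text = '\tab\n'

# What the decoder is fed at each timestep vs. what it should predict.
# The target side skips the start character, so it is one step ahead.
decoder_input_chars = list(target_text)
decoder_target_chars = list(target_text[1:])

for t, (fed, expected) in enumerate(zip(decoder_input_chars, decoder_target_chars)):
    print(t, repr(fed), '->', repr(expected))
```

The final input character '\n' has no successor. In the article's arrays, that timestep is simply left as an all-zero row in decoder_target_data.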


Building the Model

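Now it's time to take a closer look at our encoder-decoder model, which consists of two LSTMs. The first serves as the encoder: it reads the input sequence and produces internal state vectors that serve as conditioning for the decoder. The second serves as the decoder: starting from those states, it is trained to predict the next character of the target sequence given the previous characters.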
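To make "internal state vectors" concrete, here is a minimal NumPy sketch of a standard LSTM cell run over a short one-hot sequence. It keeps only the final hidden state h and cell state c, which roughly corresponds to what LSTM(latent_dim, return_state=True) hands on as state_h and state_c. The function names, the stacked-weight layout and the gate order (input, forget, candidate, output) are assumptions for illustration. Keras' internal implementation may differ in details.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM cell (hypothetical helper).

    x: (input_dim,), h_prev/c_prev: (units,),
    W: (input_dim, 4*units), U: (units, 4*units), b: (4*units,).
    """
    z = x @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4)            # input, forget, candidate, output
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c = f * c_prev + i * g                 # new cell state
    h = o * np.tanh(c)                     # new hidden state
    return h, c

def encode(sequence, W, U, b, units):
    """Run the cell over a (timesteps, input_dim) sequence and return only
    the final (h, c), i.e. what an encoder passes on to the decoder."""
    h = np.zeros(units)
    c = np.zeros(units)
    for x in sequence:
        h, c = lstm_step(x, h, c, W, U, b)
    return h, c

rng = np.random.default_rng(0)
input_dim, units = 5, 3
W = rng.normal(size=(input_dim, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
sequence = np.eye(input_dim)[[0, 2, 1, 4]]   # four one-hot "characters"
state_h, state_c = encode(sequence, W, U, b, units)
print(state_h, state_c)
```

However long the input is, the encoder's summary is just these two fixed-size vectors (of size latent_dim = 256 in the model below).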

import keras, tensorflow
from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np
batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Compile Model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.summary()
# Train model
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs)
# Save model
model.save('/results/seq2seq_eng-ger.h5')

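To use the trained model for translation, we run it in inference mode (sampling):

1) encode the input and retrieve the initial decoder state

2) run one step of the decoder with this initial state and a "start of sequence" token as input; the output will be the next target token

3) repeat with the current target token and current states until the end-of-sequence character is produced or the maximum length is reached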
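These three steps form a greedy decoding loop. Below is a framework-free sketch of that loop. The greedy_decode function and the toy encoder/decoder functions are hypothetical, and they only stand in for encoder_model.predict and decoder_model.predict. Unlike the article's decode_sequence, this version leaves the end token out of the result.

```python
def greedy_decode(encode, step, start_token, end_token, max_len):
    """Greedy sampling: encode once, then repeatedly feed back the
    most probable token together with the updated states."""
    states = encode()                  # 1) initial decoder state from the encoder
    token = start_token                # 2) begin with the start-of-sequence token
    output = []
    while len(output) < max_len:
        probs, states = step(token, states)                     # one decoder step
        token = max(range(len(probs)), key=probs.__getitem__)   # argmax
        if token == end_token:
            break
        output.append(token)           # 3) repeat with the new token and states
    return output

# Hypothetical toy "decoder" over 4 tokens: it always predicts token + 1,
# so starting from token 0 it emits 1, 2 and then the end token 3.
def toy_encode():
    return 0                           # dummy state: number of steps taken

def toy_step(token, state):
    probs = [0.0] * 4
    probs[(token + 1) % 4] = 1.0
    return probs, state + 1

print(greedy_decode(toy_encode, toy_step, start_token=0, end_token=3, max_len=10))
```

The max_len guard plays the same role as the max_decoder_seq_length check in the article's loop. It stops decoding if the model never produces the end marker.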

encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_outputs, state_h, state_c = decoder_lstm(
  decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = Model(
  [decoder_inputs] + decoder_states_inputs,
  [decoder_outputs] + decoder_states)
# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
  (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
  (i, char) for char, i in target_token_index.items())
def decode_sequence(input_seq):
  # Encode the input as state vectors.
  states_value = encoder_model.predict(input_seq)
  # Generate empty target sequence of length 1.
  target_seq = np.zeros((1, 1, num_decoder_tokens))
  # Populate the first character of target sequence with the start character.
  target_seq[0, 0, target_token_index['\t']] = 1.
  # Sampling loop for a batch of sequences
  # (to simplify, here we assume a batch of size 1).
  stop_condition = False
  decoded_sentence = ''
  while not stop_condition:
    output_tokens, h, c = decoder_model.predict(
      [target_seq] + states_value)
    # Sample a token
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reverse_target_char_index[sampled_token_index]
    decoded_sentence += sampled_char
    # Exit condition: either hit max length
    # or find stop character.
    if (sampled_char == '\n' or 
        len(decoded_sentence) > max_decoder_seq_length):
      stop_condition = True
    # Update the target sequence (of length 1).
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, sampled_token_index] = 1.
    # Update states
    states_value = [h, c]
  return decoded_sentence
for seq_index in range(10):
  # Take one sequence (part of the training set)
  # for trying out decoding.
  input_seq = encoder_input_data[seq_index: seq_index + 1]
  decoded_sentence = decode_sequence(input_seq)
  print('Input sentence:', input_texts[seq_index])
  print('Decoded sentence:', decoded_sentence)