Jules Merckx / Oct 22 2018
Remix of Julia by Nextjournal

PyTorch-style seq2seq machine translation, ported to Flux

Data file: fra.txt (tab-separated sentence pairs, see FILE below)
using Flux, CuArrays
# Vocabulary bookkeeping for one language, mirroring the PyTorch tutorial's
# `Lang` class. Indices 1..3 are reserved for the special tokens SOS/EOS/UNK.
mutable struct Lang
    name::String                   # language identifier, e.g. "eng"
    word2index::Dict{String, Int}  # word → vocabulary index
    word2count::Dict{String, Int}  # word → number of occurrences seen
    index2word::Dict{Int, String}  # vocabulary index → word
    n_words::Int                   # vocabulary size, including special tokens
end

# Fresh vocabulary containing only the three special tokens.
Lang(name) = Lang(
    name,
    Dict{String, Int}(),
    Dict{String, Int}(),
    Dict{Int, String}(1 => "SOS", 2 => "EOS", 3 => "UNK"),
    3)

"""
    (l::Lang)(sentence::String)

Register every space-separated word of `sentence`: unseen words get the next
free index (starting at 4, after the special tokens), known words only have
their count bumped. Mutates `l`; returns `nothing`.
"""
function (l::Lang)(sentence::String)
    for word in split(sentence, " ")
        # The original line lost its `∉` during extraction; `!haskey` is the
        # equivalent (and avoids materialising `keys`).
        if !haskey(l.word2index, word)
            l.word2index[word] = l.n_words + 1
            l.word2count[word] = 1
            l.index2word[l.n_words + 1] = word
            l.n_words += 1
        else
            l.word2count[word] += 1
        end
    end
    return nothing
end
# Lowercase and trim a sentence, pad sentence-final punctuation with a leading
# space, and collapse every run of non-ASCII-letter/non-punctuation characters
# into a single space (accented letters are lost — ASCII-only, as in the
# PyTorch tutorial).
function normalizeString(s)
    lowered = strip(lowercase(s))
    spaced = replace(lowered, r"([.!?])" => s" \1")
    cleaned = replace(spaced, r"[^a-zA-Z.!?]+" => " ")
    return cleaned
end
# Read the tab-separated corpus at `FILE`, normalise each column of each line,
# and return `(input_lang, output_lang, pairs)`. With `rev=true` the columns
# of every pair are reversed and the two Lang objects are swapped accordingly.
function readLangs(lang1, lang2; rev=false)
    println("Reading lines...")
    raw_lines = readlines(FILE)
    sentence_pairs = [normalizeString.(cols) for cols in split.(raw_lines, "\t")]

    if rev
        sentence_pairs = reverse.(sentence_pairs)
        return (Lang(lang2), Lang(lang1), sentence_pairs)
    end
    return (Lang(lang1), Lang(lang2), sentence_pairs)
end
# English sentence prefixes used to trim the corpus down to simple
# present-tense sentences (same set as the PyTorch seq2seq tutorial).
# `const` so this global is concretely typed when read inside filterPair.
const eng_prefixes = [
    "i am ", "i m ",
    "he is ", "he s ",
    "she is ", "she s ",
    "you are ", "you re ",
    "we are ", "we re ",
    "they are ", "they re "]

"""
    filterPair(p)

Return `true` when both sentences of the pair `p` are at most `MAX_LENGTH`
space-separated words long and the English sentence `p[1]` starts with one of
`eng_prefixes`. The original line lost its `∉`/`∈` operators during
extraction (`false ∉ …` / `true ∈ …`); `all`/`any` express the same tests.
"""
function filterPair(p)
    short_enough = all(length.(split.(p, " ")) .<= MAX_LENGTH)
    has_prefix = any(startswith.(p[1], eng_prefixes))
    return short_enough && has_prefix
end
# Full preprocessing pipeline: read the corpus, drop pairs rejected by
# filterPair, then count every remaining word into the two Lang vocabularies
# (column 2 into `input_lang`, column 1 into `output_lang`).
# Returns `(input_lang, output_lang, pairs)`.
function prepareData(lang1, lang2; rev=false)
    input_lang, output_lang, sentence_pairs = readLangs(lang1, lang2; rev=rev)
    println("Read $(length(sentence_pairs)) sentence pairs.")

    sentence_pairs = filter(filterPair, sentence_pairs)
    println("Trimmed to $(length(sentence_pairs)) sentence pairs.\n")

    println("Counting words...")
    for p in sentence_pairs
        input_lang(p[2])
        output_lang(p[1])
    end
    println("Counted words:")
    println("• ", input_lang.name, ": ", input_lang.n_words)
    println("• ", output_lang.name, ": ", output_lang.n_words)

    return (input_lang, output_lang, sentence_pairs)
end
# Path to the tab-separated English↔French corpus read by readLangs.
# `const` keeps these globals concretely typed for the functions that read them.
const FILE = "D:/Downloads/fra-eng/fra.txt"
# Maximum sentence length (in words) accepted by filterPair.
const MAX_LENGTH = 10

# Build the vocabularies and the filtered sentence-pair list.
# NOTE(review): prepareData counts pair[2] into the first returned Lang and
# pair[1] into the second — confirm the column order of fra.txt actually
# matches the "fr", "eng" naming used here.
fr, eng, pairs = prepareData("fr", "eng");
# Uniform(-0.5, 0.5) initialisation scaled by sqrt(24 / sum(dims)) — the
# Xavier-style scheme from the tutorial. The scaling is applied *before*
# wrapping in `param`, so the stored weight is itself the trainable leaf;
# the original multiplied the tracked parameter afterwards, which made the
# stored array a derived tracked value whose hidden, unscaled leaf received
# the gradients instead.
initWeight(dims...) = param((rand(dims...) .- 0.5) .* sqrt(24.0 / sum(dims)))

# Minimal embedding-lookup layer: column `i` of `w` is the vector for token `i`.
struct Embed
    w
end

# (embed × vocab) weight matrix, initialised by initWeight.
Embed(vocab::Int, embed::Int) = Embed(initWeight(embed, vocab))

# Single-token lookup → one column vector.
function (e::Embed)(token::Int)
    return e.w[:, token]
end

# Batch lookup → columns concatenated in iteration order.
function (e::Embed)(tokens::Array{Int})
    columns = map(e, tokens)
    return hcat(columns...)
end

Flux.@treelike Embed
# Encoder: embedding lookup followed by a (stateful) GRU layer.
struct EncoderRNN
    hidden_size
    embedding
    rnn
end

EncoderRNN(input_size, hidden_size) = EncoderRNN(
    hidden_size,
    Embed(input_size, hidden_size),
    GRU(hidden_size, hidden_size))

# One encoder step: embed the input token(s), feed through the GRU, and
# return the GRU output (the GRU carries its own hidden state).
(e::EncoderRNN)(x) = e.rnn(e.embedding(x))

Flux.@treelike EncoderRNN
# Plain (attention-free) decoder: embedding → relu → GRU cell → linear → softmax.
struct DecoderRNN
    hidden_size
    embedding
    rnn
    linear
end

DecoderRNN(hidden_size, output_size) = DecoderRNN(
    hidden_size,
    Embed(output_size, hidden_size),
    Flux.GRUCell(hidden_size, hidden_size),
    Dense(hidden_size, output_size))

# One decoder step. Takes the previous hidden state and the current input
# token(s); returns `(new_hidden, output_distribution)`.
function (d::DecoderRNN)(hidden, x)
    activated = relu.(d.embedding(x))
    activated, hidden = d.rnn(hidden, activated)
    probs = softmax(d.linear(activated))
    return (hidden, probs)
end

Flux.@treelike DecoderRNN
# Attention decoder (Bahdanau-style, as in the PyTorch seq2seq tutorial):
# attention weights are predicted from [embedding; hidden] and applied to the
# encoder outputs before the GRU cell.
struct AttnDecoderRNN
    hidden_size
    output_size
    dropout_p
    max_length
    embedding
    attn          # scores over encoder positions, one per max_length slot
    attn_combine  # mixes attended context back with the embedding
    rnn
    out
end

function AttnDecoderRNN(hidden_size, output_size, dropout_p=0.1; max_length=MAX_LENGTH)
    return AttnDecoderRNN(
        hidden_size,
        output_size,
        dropout_p,
        max_length,
        Embed(output_size, hidden_size),
        Dense(hidden_size * 2, max_length),
        Dense(hidden_size * 2, hidden_size),
        Flux.GRUCell(hidden_size, hidden_size),
        Dense(hidden_size, output_size))
end

# One attention-decoder step. `x` is the current input token, `hidden` the
# previous GRU hidden state, `encoder_outputs` the stacked encoder states for
# the source sentence. Returns (output distribution, new hidden, attn weights).
function (d::AttnDecoderRNN)(x, hidden, encoder_outputs)
    embedded = d.embedding(x)
    # NOTE(review): a fresh Dropout layer is constructed on every call, so
    # dropout is always active — there is no train/test mode switch here;
    # confirm this is intended before evaluating the model.
    embedded = Dropout(d.dropout_p)(embedded)
    # Attention weights over encoder positions, predicted from the
    # concatenated embedding and hidden state (length = max_length).
    attn_weights = softmax(d.attn([embedded; hidden]))
    # Weighted sum of encoder outputs, reshaped/transposed back to a column.
    # assumes encoder_outputs rows align with attn_weights entries — TODO
    # confirm the expected shapes against the encoder loop that fills them.
    attn_applied = reshape(reduce(+, encoder_outputs.*attn_weights, dims=1), :, d.hidden_size)'
    output = [attn_applied; embedded]
    # NOTE(review): `relu` here is NOT broadcast (`relu.`) unlike DecoderRNN;
    # NNlib's relu is scalar, so verify this works on the Flux version in use.
    output = relu(d.attn_combine(output))
    hidden, output = d.rnn(hidden, output)
    output = softmax(d.out(output))
    return(output, hidden, attn_weights)
end
Flux.@treelike AttnDecoderRNN