Jules Merckx / Oct 22 2018
PyTorch seq2seq machine translation in Flux
using Flux, CuArrays
mutable struct Lang
    name
    word2index
    word2count
    index2word
    n_words
end

# A fresh vocabulary starts with the three reserved tokens SOS, EOS and UNK.
Lang(name) = Lang(
    name,
    Dict{String, Int}(),
    Dict{String, Int}(),
    Dict{Int, String}(1 => "SOS", 2 => "EOS", 3 => "UNK"),
    3)

# Calling a Lang on a sentence adds every unseen word to the vocabulary
# and updates the word counts.
function (l::Lang)(sentence::String)
    for word in split(sentence, " ")
        if word ∉ keys(l.word2index)
            l.word2index[word] = l.n_words + 1
            l.word2count[word] = 1
            l.index2word[l.n_words + 1] = word
            l.n_words += 1
        else
            l.word2count[word] += 1
        end
    end
end
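A quick, illustrative check (not part of the original notebook): feeding two sentences to a Lang registers each new word after the three reserved tokens and counts repeats.

# Illustrative only: exercise the Lang functor on two toy sentences.
demo = Lang("eng")
demo("i am happy .")
demo("i am tired .")
demo.n_words           # 8: three reserved tokens plus "i", "am", "happy", ".", "tired"
demo.word2count["i"]   # 2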
# Lowercase the sentence, pad punctuation with a space, and collapse
# everything that is not a letter or basic punctuation into a space.
function normalizeString(s)
    s = strip(lowercase(s))
    s = replace(s, r"([.!?])" => s" \1")
    s = replace(s, r"[^a-zA-Z.!?]+" => " ")
    return s
end
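For illustration (the outputs in the comments are what the two regex passes should produce):

normalizeString("I'm happy.")   # "i m happy ."
normalizeString("Go.")          # "go ."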
function readLangs(lang1, lang2; rev=false)
    println("Reading lines...")
    lines = readlines(FILE)
    pairs = [normalizeString.(pair) for pair in split.(lines, "\t")]
    if rev
        pairs = reverse.(pairs)
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    end
    return(input_lang, output_lang, pairs)
end
eng_prefixes = [
    "i am ", "i m ",
    "he is ", "he s ",
    "she is ", "she s ",
    "you are ", "you re ",
    "we are ", "we re ",
    "they are ", "they re "]

function filterPair(p)
    return(false ∉ (length.(split.(p, " ")) .<= MAX_LENGTH) &&
           true ∈ (startswith.(p[1], eng_prefixes)))
end
function prepareData(lang1, lang2; rev=false)
    input_lang, output_lang, pairs = readLangs(lang1, lang2; rev=rev)
    println("Read $(length(pairs)) sentence pairs.")
    pairs = [pair for pair in pairs if filterPair(pair)]
    println("Trimmed to $(length(pairs)) sentence pairs.\n")
    println("Counting words...")
    for pair in pairs
        # fra.txt stores English<TAB>French, so with rev=false the French
        # side (pair[2]) feeds input_lang and the English side output_lang.
        input_lang(pair[2])
        output_lang(pair[1])
    end
    println("Counted words:")
    println("• ", input_lang.name, ": ", input_lang.n_words)
    println("• ", output_lang.name, ": ", output_lang.n_words)
    return(input_lang, output_lang, pairs)
end
FILE = "D:/Downloads/fra-eng/fra.txt"
MAX_LENGTH = 10
fr, eng, pairs = prepareData("fr", "eng");
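With the data in place, a small sanity check plus a sentence-to-indices helper in the spirit of the PyTorch tutorial's indexesFromSentence. The helper name, the UNK fallback and the EOS handling are a sketch of mine, not code from the original notebook.

# Sanity check: every surviving pair passes the length/prefix filter.
filterPair(pairs[1])   # true by construction
pairs[1]               # an (English, French) pair of normalized sentences

# Sketch of a helper the training loop will need: map a sentence to
# vocabulary indices, fall back to UNK (index 3) for unknown words,
# and append EOS (index 2).
function indexesFromSentence(lang::Lang, sentence)
    idxs = [get(lang.word2index, word, 3) for word in split(sentence, " ")]
    return [idxs; 2]
end

indexesFromSentence(eng, pairs[1][1])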
initWeight(dims...) = param(rand(dims...) .- 0.5) * sqrt(24.0 / sum(dims))

# A bare embedding layer: one column of the weight matrix per vocabulary index.
struct Embed; w; end
Embed(vocab::Int, embed::Int) = Embed(initWeight(embed, vocab))
(e::Embed)(x::Int) = e.w[:, x]
(e::Embed)(x::Array{Int}) = hcat(e.(x)...)

# Register the embedding weights so Flux can collect them with params(...).
Flux.@treelike Embed
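A minimal usage sketch of Embed (illustrative, with a tiny made-up vocabulary): indexing with a single Int returns one embedding column, and indexing with a vector of Ints returns a matrix with one column per token.

# Illustrative only: a 4-dimensional embedding over a 10-word vocabulary.
emb = Embed(10, 4)
size(emb(3))           # (4,)   one embedding vector
size(emb([1, 2, 3]))   # (4, 3) one column per token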
struct EncoderRNN
    hidden_size
    embedding
    rnn
end

EncoderRNN(input_size, hidden_size) = EncoderRNN(
    hidden_size,
    Embed(input_size, hidden_size),
    GRU(hidden_size, hidden_size))

# Embed the input token and push it through the GRU; the stateful GRU
# keeps its hidden state between calls.
function (e::EncoderRNN)(x)
    x = e.embedding(x)
    x = e.rnn(x)
    return(x)
end

Flux.@treelike EncoderRNN
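A sketch of how the encoder could be driven over one input sentence. The hidden size of 256 mirrors the PyTorch tutorial; collecting the per-token outputs into encoder_outputs, and the indexesFromSentence helper from above, are my assumptions rather than code from the original notebook.

# Illustrative only: run the encoder token by token over one French sentence.
encoder = EncoderRNN(fr.n_words, 256)
Flux.reset!(encoder.rnn)                             # start from a fresh hidden state
input_idxs = indexesFromSentence(fr, pairs[1][2])
encoder_outputs = [encoder(i) for i in input_idxs]   # one hidden-size vector per token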
struct DecoderRNN
    hidden_size
    embedding
    rnn
    linear
end

DecoderRNN(hidden_size, output_size) = DecoderRNN(
    hidden_size,
    Embed(output_size, hidden_size),
    Flux.GRUCell(hidden_size, hidden_size),
    Dense(hidden_size, output_size))

# One decoding step: embed the previous token, advance the GRU cell and
# project the new hidden state onto the output vocabulary.
function (d::DecoderRNN)(hidden, x)
    x = d.embedding(x)
    x = relu.(x)
    x, hidden = d.rnn(hidden, x)
    x = softmax(d.linear(x))
    return(hidden, x)
end

Flux.@treelike DecoderRNN
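A single greedy step of the plain (non-attention) decoder, fed the SOS token and the encoder's final hidden state. This is a sketch building on the encoder demo above; the variable names are mine.

# Illustrative only: one decoding step with the plain decoder.
decoder = DecoderRNN(256, eng.n_words)
hidden = encoder.rnn.state                  # final encoder hidden state
hidden, probs = decoder(hidden, 1)          # 1 is the SOS index
eng.index2word[argmax(Flux.data(probs))]    # most probable next token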
struct AttnDecoderRNN
    hidden_size
    output_size
    dropout_p
    max_length
    embedding
    attn
    attn_combine
    rnn
    out
end

AttnDecoderRNN(hidden_size, output_size, dropout_p=0.1; max_length=MAX_LENGTH) =
    AttnDecoderRNN(
        hidden_size,
        output_size,
        dropout_p,
        max_length,
        Embed(output_size, hidden_size),
        Dense(hidden_size*2, max_length),
        Dense(hidden_size*2, hidden_size),
        Flux.GRUCell(hidden_size, hidden_size),
        Dense(hidden_size, output_size))

# One decoding step with attention: attention weights are computed from the
# embedded input and the current hidden state, applied to the encoder
# outputs, combined with the embedding, and fed through the GRU cell.
function (d::AttnDecoderRNN)(x, hidden, encoder_outputs)
    embedded = d.embedding(x)
    embedded = Dropout(d.dropout_p)(embedded)
    attn_weights = softmax(d.attn([embedded; hidden]))
    attn_applied = reshape(reduce(+, encoder_outputs .* attn_weights, dims=1), :, d.hidden_size)'
    output = [attn_applied; embedded]
    output = relu.(d.attn_combine(output))
    hidden, output = d.rnn(hidden, output)
    output = softmax(d.out(output))
    return(output, hidden, attn_weights)
end

Flux.@treelike AttnDecoderRNN
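Finally, the encoder and attention decoder as they would be instantiated for training. Hidden size 256 and dropout 0.1 follow the PyTorch tutorial; the training and evaluation loops are not part of this section, and the variable names are mine.

# Illustrative only: the model pair used for training.
hidden_size = 256
encoder1 = EncoderRNN(fr.n_words, hidden_size)
attn_decoder1 = AttnDecoderRNN(hidden_size, eng.n_words, 0.1)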