Hugh Murrell / Aug 15 2019
Chapter 10, Recurrent Network.
An LSTM Recurrent Network
Based on code by Mike Innes from the Flux model zoo, found at https://github.com/FluxML/model-zoo
Load packages and data
First we load the required packages:
using Flux
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
Now we read in the text (the poems of Catullus), one-hot encode each character, and append an extra '_' symbol to the alphabet to serve as the stop/padding character:
text = collect(String(read("input_catullus.txt")))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet);
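To see what the encoding produces, here is a minimal sketch on a toy alphabet (the names toy_alphabet and v are illustrative, not part of the chapter's code):

using Flux: onehot

toy_alphabet = ['a', 'b', 'c', '_']
v = onehot('b', toy_alphabet)  # a OneHotVector, equivalent to Bool[0, 1, 0, 0]
println(v)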
N = length(alphabet)
seqlen = 50
nbatch = 50
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen));
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));
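Here chunk cuts the text into nbatch parallel streams, batchseq batches corresponding time steps together (padding with the stop symbol), and partition groups the time steps into subsequences of length seqlen. Ys is Xs shifted one character ahead, so the network learns to predict the next character. A small sketch with integers in place of one-hot vectors (purely illustrative) shows the reshaping:

using Flux: chunk, batchseq
using Base.Iterators: partition

seq = collect(1:10)           # stand-in for the encoded text
streams = chunk(seq, 2)       # 2 parallel streams: [1,2,3,4,5] and [6,7,8,9,10]
steps = batchseq(streams, 0)  # 5 time steps: [1,6], [2,7], ..., [5,10]
println(collect(partition(steps, 3)))  # subsequences of at most 3 time steps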
Define the LSTM model
# Define our model.
# We will use two LSTMs
# followed by a final Dense layer that
# feeds into a softmax probability output.
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)
Chain(Recur(LSTMCell(59, 128)), Recur(LSTMCell(128, 128)), Dense(128, 59), NNlib.softmax)
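Each call to m consumes one one-hot encoded character (one column per batch entry) and returns a distribution over the alphabet. As a quick sanity check (illustrative, and assuming the letter 'a' occurs in the corpus):

Flux.reset!(m)                 # clear the hidden state first
p = m(onehot('a', alphabet))   # a length-N vector of probabilities
println(sum(p))                # ≈ 1.0, thanks to the softmax layer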
Define the loss function
# `loss()` sums the crossentropy loss over the time steps of a batch;
# `m.(xs)` broadcasts the model over the seqlen time steps.
function loss(xs, ys)
  los = sum(crossentropy.(m.(xs), ys))
  Flux.truncate!(m)  # truncate the gradient chain between batches
  return los
end
loss (generic function with 1 method)
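Before training, the summed crossentropy on one batch should sit near seqlen * log(N), the loss of a uniform guess over the alphabet. An illustrative check (assuming the data defined above are in scope):

Flux.reset!(m)
println(loss(Xs[1], Ys[1]))  # roughly 50 * log(59) ≈ 204 for an untrained model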
Select an optimiser
opt = ADAM(0.01)
tx, ty = (Xs[5], Ys[5])
evalcb = () -> @show loss(tx, ty)  # print the loss on a fixed batch
#5 (generic function with 1 method)
Train the model
epochs = 50
for i = 1:epochs
  Flux.train!(loss, params(m), zip(Xs, Ys), opt, cb = throttle(evalcb, 30))
end
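The throttle(evalcb, 30) wrapper limits how often the callback fires: at most once every 30 seconds, however many batches go by in between. A standalone sketch of the behaviour (illustrative):

using Flux: throttle

slow = throttle(() -> println("tick"), 30)
for _ in 1:1000
  slow()  # prints "tick" on the first call, then stays silent for 30 seconds
end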
Define a sampling function
# Sampling function: feed the model its own sampled output,
# one character at a time, starting from a random character.
# Note that the `temp` keyword is accepted but not used here.
function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    c = wsample(alphabet, m(onehot(c, alphabet)).data)
  end
  return String(take!(buf))
end
sample (generic function with 1 method)
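Note that the temp keyword above is accepted but never used. One way a temperature could be applied (a hypothetical variant, sample_t, not part of the original code) is to rescale the log-probabilities before sampling:

# Hypothetical variant: temp < 1 sharpens the distribution,
# temp > 1 flattens it towards uniform sampling.
function sample_t(m, alphabet, len; temp = 1)
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    p = m(onehot(c, alphabet)).data
    c = wsample(alphabet, softmax(log.(p) ./ temp))
  end
  return String(take!(buf))
end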
Sample from the model
sample(m, alphabet, 10000) |> println