Hugh Murrell / Aug 15 2019

Chapter 10, Recurrent Network.

An LSTM Recurrent Network

Based on code by Mike Innes from the Flux model zoo, found at: https://github.com/FluxML/model-zoo

Load packages and data

First we load the required packages:

using Flux
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample          # weighted sampling, used for text generation
using Base.Iterators: partition   # split sequences into fixed-length pieces

Now we read in the data:

# read the corpus as a vector of characters, build the alphabet
# (reserving '_' as a padding symbol), and one-hot encode the text
text = collect(String(read("input_catullus.txt")))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet);
N = length(alphabet)
seqlen = 50
nbatch = 50

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen));
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));
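
To see what this pipeline does, here is a minimal sketch on a toy integer sequence (the names toy, streams and steps are illustrative only): chunk splits the text into nbatch parallel streams, batchseq zips them into one batch per time step (padding short streams, here with 0, in the real pipeline with stop), and partition cuts the result into sub-sequences of seqlen steps.

# toy illustration with integers instead of one-hot vectors
toy = collect(1:10)
streams = chunk(toy, 2)          # two streams: [1,2,3,4,5] and [6,7,8,9,10]
steps = batchseq(streams, 0)     # per-step batches: [1,6], [2,7], [3,8], [4,9], [5,10]
collect(partition(steps, 3))     # length-3 sub-sequences (the last is shorter)

In the real pipeline each batch is an N by nbatch one-hot matrix, and Ys is simply Xs shifted one character ahead, so the network is trained to predict the next character.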

Define the LSTM model

# Define our model.  
# We will use two LSTMs
# followed by a final Dense layer that
# feeds into a softmax probability output.
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)
Chain(Recur(LSTMCell(59, 128)), Recur(LSTMCell(128, 128)), Dense(128, 59), NNlib.softmax)
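
Before training we can sanity-check the model's plumbing. As a quick sketch: feeding one time-slice of the first batch through m should give an N by nbatch matrix whose columns are probability distributions over the alphabet (59 symbols for this corpus, as the echoed Chain above shows).

Flux.reset!(m)     # clear the hidden state before probing
yhat = m(Xs[1][1]) # one forward step on the first time-slice
size(yhat)         # (59, 50): one distribution per batch column
sum(yhat[:, 1])    # approximately 1.0, thanks to the final softmax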

Define the loss function

# loss() sums the crossentropy over every time step of the sequence;
# truncate! then cuts the gradient tape so that backpropagation
# through time stops at the sequence boundary.
function loss(xs, ys)
  los = sum(crossentropy.(m.(xs), ys))
  Flux.truncate!(m)
  return los
end
loss (generic function with 1 method)

Select an optimiser

opt = ADAM(0.01)
tx, ty = (Xs[5], Ys[5])            # a fixed batch for monitoring progress
evalcb = () -> @show loss(tx, ty)  # callback: report the loss on that batch
#5 (generic function with 1 method)
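
The throttle wrapper used in the training loop below stops this callback from flooding the console. A small sketch of its behaviour (Flux's throttle fires on the leading edge of the time window):

tick = throttle(() -> println("tick"), 30)
tick()   # prints "tick" immediately
tick()   # suppressed: still inside the 30 second window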

Train the model

epochs = 50
for i = 1:epochs
    # one full pass over the (input, target) batches per epoch;
    # the throttled callback reports the loss at most every 30 seconds
    Flux.train!(loss, params(m), zip(Xs, Ys), opt,
            cb = throttle(evalcb, 30))
end

Define a sampling function

# Sampling: generate len characters by repeatedly feeding the last
# character back into the network and drawing the next one from its
# predicted distribution. temp rescales that distribution: values
# below 1 sharpen it, values above 1 flatten it.
function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    p = m(onehot(c, alphabet)).data
    p = p .^ (1 / temp)            # apply the temperature
    c = wsample(alphabet, p ./ sum(p))
  end
  return String(take!(buf))
end
sample (generic function with 1 method)

Sample from the model

sample(m, alphabet, 10000) |> println
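
With the temperature parameter wired into sample, a value below 1 keeps the output closer to the model's high-probability characters. A hypothetical usage:

sample(m, alphabet, 500; temp = 0.5) |> println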