# Chapter 9, Convolutional Network

![cnn.png][nextjournal#file#f9d92965-5f49-4f75-baa3-0815201701ae]

# A Convolutional Network for recognising handwritten digits

Based on code from the Flux model zoo, which can be found at <https://github.com/FluxML/model-zoo>.

## Load packages and data

First we load the required packages:

```julia id=4668a2dd-8a86-4bfe-8bea-aea99755cf92
using Flux, Flux.Data.MNIST, Images, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf
using Plots
```

Now we read in the data:

```julia id=bdd4a1ed-f91d-4d14-9889-92d1ef43434a
labels = MNIST.labels();
images = MNIST.images();
```

```julia id=5555dd6c-9e1b-41bf-9bc1-d888c9bcb915
length(labels)
```

```julia id=aba39bea-29fb-43de-9f8d-5bad07e8dbd8
typeof(images[1])
```

```julia id=af73815e-5e39-41b6-a278-3bb336cd72f1
display(images[1])
```

![result][nextjournal#output#af73815e-5e39-41b6-a278-3bb336cd72f1#result]

```julia id=5895c946-8bb0-4e37-99a6-a04cfce763c6
size(images[1])
```

```julia id=a757431d-daf1-4a43-a8f6-5d79d89e3ff4
labels[1]
```

## Batch the data

```julia id=1fe3e372-deb7-4724-bc15-b978aa905e73
"""
    make_minibatch(X, Y, idxs)

Bundle the images `X[idxs]` together with their labels `Y[idxs]` into one minibatch.

Returns a tuple `(X_batch, Y_batch)` where
- `X_batch` is a `Float32` array of size `(size(X[1])..., 1, length(idxs))`:
  the two image dimensions, a single channel, and one slice per sample;
- `Y_batch` is the one-hot encoding of `Y[idxs]` over the digit classes `0:9`.
"""
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    for (i, idx) in enumerate(idxs)
        # Convert the image's pixel values to Float32 and store it as sample i.
        X_batch[:, :, :, i] = Float32.(X[idx])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end
```

```julia id=5c7b5524-7d61-42d8-bb02-8ee842bfa357
# Prepare the training set as a collection of minibatches.
# `MNIST.images`/`MNIST.labels` choose the dataset by name (`:train` by default,
# `:test` for the test set), not by index range -- passing `1:10000` would
# silently load the *test* set. So load the training set and keep its first
# 10,000 samples.
train_images = MNIST.images()[1:10000]
train_labels = MNIST.labels()[1:10000]
batch_size = 500
mb_idxs = partition(1:length(train_images), batch_size)
train_set = [make_minibatch(train_images, train_labels, i) for i in mb_idxs];
```

```julia id=df560a8a-658d-45e6-b67f-bdfb928b1b77
length(train_images)
```

```julia id=fe651124-bf64-4394-b92a-7c7afc947508
# Prepare the test set as one giant minibatch:
test_images = MNIST.images(:test)
test_labels = MNIST.labels(:test)
test_set = make_minibatch(test_images, test_labels, 1:length(test_images));
```

```julia id=568f8e08-4517-44f1-b832-b7d36a4579ae
length(test_images)
```

## Set up the convolutional model

```julia id=b031705c-bffc-4134-94e7-53153b4a2e69
# Define our model.
# We will use a simple convolutional architecture with
# three iterations of Conv -> ReLU -> MaxPool,
# followed by a final Dense layer that
# feeds into a softmax probability output.
@info("Constructing model...")
model = Chain(
    # First convolution, operating upon a 28x28 image
    Conv((3, 3), 1=>16, pad=(1,1), relu),
    MaxPool((2,2)),

    # Second convolution, operating upon a 14x14 image
    Conv((3, 3), 16=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Third convolution, operating upon a 7x7 image
    Conv((3, 3), 32=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Flatten the 4d activation tensor into a 2d matrix.
    # At this point it should be (3, 3, 32, N), and 3 * 3 * 32 = 288
    # is where we get the input size of the `Dense` layer below:
    x -> reshape(x, :, size(x, 4)),
    Dense(288, 10),

    # Finally, softmax to get nice probabilities
    softmax,
)
```

```julia id=bf147ca4-7819-4817-9dad-7e79a6590311
# Make sure our model is nicely precompiled
# before starting our training loop
model(train_set[1][1])
```

## Define the loss function and select an optimiser

```julia id=35f4400c-7767-4a81-8518-c79ae23a9a9d
# `crossentropy` already reduces over the whole batch to a single scalar,
# so no extra `sum` is needed.
loss(x, y) = Flux.crossentropy(model(x), y)

opt = ADAM(0.001) # alternative: Momentum(0.01)

# Fraction of samples whose most probable predicted class equals the true class.
accuracy(x, y) = mean(Flux.onecold(model(x), 1:10) .== Flux.onecold(y, 1:10))

n_epochs = 1
```

## Train the model

Training for one epoch takes about 3 minutes without a GPU. Accuracy on the test set reaches 50% after 1 epoch, 77% after 2, 87% after 3, 91% after 4 and 93% after 5; increase `n_epochs` to train for longer.

```julia id=cc71499e-4625-453b-bdd9-0eb0bd7cda26
@Flux.epochs n_epochs Flux.train!(
    loss, params(model), train_set, opt,
    # The callback runs after every minibatch and evaluates the whole test set
    # each time; wrap it in `throttle(..., 10)` to evaluate at most every 10 s.
    cb = () -> @show(accuracy(test_set...))
)
```

## Display the results

```julia id=df02d83b-ccb6-4f64-aa19-125c3ab509d3
# Class indices 1..10 (index k corresponds to digit k - 1) for every test image.
pred_test_labels = Flux.onecold(model(test_set[1]), 1:10)
true_test_labels = Flux.onecold(test_set[2], 1:10)
acc = mean(pred_test_labels .== true_test_labels)

# Confusion matrix with rows = true label and columns = predicted label.
# `heatmap` draws matrix rows along the y axis and columns along the x axis,
# so this orientation matches the axis labels of the plot below.
cm = zeros(Int64, 10, 10)
for (t, p) in zip(true_test_labels, pred_test_labels)
    cm[t, p] += 1
end
```

```julia id=4a960c80-6252-4780-b258-aca2bbd8ac46
cm
```

```julia id=93362408-e951-4912-b75c-18f4acc2bc68
p2 = heatmap(cm, c=:dense,
             title=@sprintf("Confusion Matrix, accuracy = %.2f%%", 100 * acc),
             ylabel="True label", xlabel="Predicted label",
             xticks=(1:10, 0:9), yticks=(1:10, 0:9))
```

![result][nextjournal#output#93362408-e951-4912-b75c-18f4acc2bc68#result]

```julia id=2eea3f3f-2157-4727-b105-51b753491249
```

[nextjournal#file#f9d92965-5f49-4f75-baa3-0815201701ae]:
[nextjournal#output#af73815e-5e39-41b6-a278-3bb336cd72f1#result]:
[nextjournal#output#93362408-e951-4912-b75c-18f4acc2bc68#result]:
This notebook was exported from https://nextjournal.com/a/LXMhRNsUE5X94AHZxDrK8?change-id=CfiV2G7ptFbAbrbFMeUSPe ```edn nextjournal-metadata {:article {:settings nil, :nodes {"1fe3e372-deb7-4724-bc15-b978aa905e73" {:compute-ref #uuid "17af7b06-95a5-40fa-a0c3-a0e0030f7062", :exec-duration 693, :id "1fe3e372-deb7-4724-bc15-b978aa905e73", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "2eea3f3f-2157-4727-b105-51b753491249" {:id "2eea3f3f-2157-4727-b105-51b753491249", :kind "code", :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "34e7994b-249d-48b9-ba73-05176dc99cd6" {:environment [:environment {:article/nextjournal.id #uuid "02b99285-162d-4560-b63e-b52fa0afc716", :change/nextjournal.id #uuid "5d52be1e-02be-4f2a-883f-e2c8824a8216", :node/id "82010b96-8284-48c0-8587-6c625e0e6c08"}], :environment? true, :id "34e7994b-249d-48b9-ba73-05176dc99cd6", :kind "runtime", :language "julia", :resources {:machine-type "n1-standard-4"}, :type :jupyter, :docker/environment-image "docker.nextjournal.com/environment@sha256:2511d20983ff11c9409af4ba4fb80b0fef33fe246dcf79367e6159348610f631"}, "35f4400c-7767-4a81-8518-c79ae23a9a9d" {:compute-ref #uuid "2c676128-ce57-4fbe-97eb-ac13a1c8137e", :exec-duration 39, :id "35f4400c-7767-4a81-8518-c79ae23a9a9d", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "4668a2dd-8a86-4bfe-8bea-aea99755cf92" {:compute-ref #uuid "8068232c-c0bb-4f0f-93dd-bf44c94bda5f", :exec-duration 19707, :id "4668a2dd-8a86-4bfe-8bea-aea99755cf92", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "4a960c80-6252-4780-b258-aca2bbd8ac46" {:compute-ref #uuid "6ca6a64d-be64-4b2b-a345-9264bb6dc401", :exec-duration 11, :id "4a960c80-6252-4780-b258-aca2bbd8ac46", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "5555dd6c-9e1b-41bf-9bc1-d888c9bcb915" {:compute-ref #uuid 
"c3ace0de-6ba8-4e22-a72b-563383dfb586", :exec-duration 1009, :id "5555dd6c-9e1b-41bf-9bc1-d888c9bcb915", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "568f8e08-4517-44f1-b832-b7d36a4579ae" {:compute-ref #uuid "4de69ffc-b43b-47b1-9430-a0040161a3ec", :exec-duration 12, :id "568f8e08-4517-44f1-b832-b7d36a4579ae", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "5895c946-8bb0-4e37-99a6-a04cfce763c6" {:compute-ref #uuid "056ee8b2-3c95-4151-b314-8e7ba21b8282", :exec-duration 617, :id "5895c946-8bb0-4e37-99a6-a04cfce763c6", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "5c7b5524-7d61-42d8-bb02-8ee842bfa357" {:compute-ref #uuid "651c9ed9-ead5-4c49-bf14-a0c939caddd7", :exec-duration 808, :id "5c7b5524-7d61-42d8-bb02-8ee842bfa357", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "93362408-e951-4912-b75c-18f4acc2bc68" {:compute-ref #uuid "25abf9e7-07b0-4400-86b7-2d7576f9f60f", :exec-duration 732, :id "93362408-e951-4912-b75c-18f4acc2bc68", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "a757431d-daf1-4a43-a8f6-5d79d89e3ff4" {:compute-ref #uuid "b71cd51e-8e57-4175-9aef-d955311636a9", :exec-duration 7, :id "a757431d-daf1-4a43-a8f6-5d79d89e3ff4", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "aba39bea-29fb-43de-9f8d-5bad07e8dbd8" {:compute-ref #uuid "f9770a59-4804-48c8-99cd-3335ec7e1087", :exec-duration 682, :id "aba39bea-29fb-43de-9f8d-5bad07e8dbd8", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "af73815e-5e39-41b6-a278-3bb336cd72f1" {:compute-ref #uuid "60ab40f2-7784-42e3-8724-ade757ceb3ce", :exec-duration 21196, :id "af73815e-5e39-41b6-a278-3bb336cd72f1", :kind "code", :output-log-lines {:stdout 2}, :runtime [:runtime 
"34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "b031705c-bffc-4134-94e7-53153b4a2e69" {:compute-ref #uuid "be12206d-c982-4c40-83ff-55e6c6e28d9b", :exec-duration 2527, :id "b031705c-bffc-4134-94e7-53153b4a2e69", :kind "code", :output-log-lines {:stdout 2}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"], :stdout-collapsed? false}, "bdd4a1ed-f91d-4d14-9889-92d1ef43434a" {:compute-ref #uuid "6e9d355c-bba1-40f2-8e82-0101854f19ea", :exec-duration 11739, :id "bdd4a1ed-f91d-4d14-9889-92d1ef43434a", :kind "code", :output-log-lines {:stdout 24}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "bf147ca4-7819-4817-9dad-7e79a6590311" {:compute-ref #uuid "284f5523-e217-4cff-a3ba-448cf5a39791", :exec-duration 10829, :id "bf147ca4-7819-4817-9dad-7e79a6590311", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "cc71499e-4625-453b-bdd9-0eb0bd7cda26" {:compute-ref #uuid "7f6580dc-2b09-4497-8d39-0bfdb919af66", :exec-duration 164067, :id "cc71499e-4625-453b-bdd9-0eb0bd7cda26", :kind "code", :output-log-lines {:stdout 22}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "df02d83b-ccb6-4f64-aa19-125c3ab509d3" {:compute-ref #uuid "b6f64761-0843-4efc-90cc-fdbaace666fd", :exec-duration 7407, :id "df02d83b-ccb6-4f64-aa19-125c3ab509d3", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "df560a8a-658d-45e6-b67f-bdfb928b1b77" {:compute-ref #uuid "1119749f-be5e-4d1f-9419-226e5290a203", :exec-duration 8, :id "df560a8a-658d-45e6-b67f-bdfb928b1b77", :kind "code", :output-log-lines {}, :runtime [:runtime "34e7994b-249d-48b9-ba73-05176dc99cd6"]}, "f9d92965-5f49-4f75-baa3-0815201701ae" {:id "f9d92965-5f49-4f75-baa3-0815201701ae", :kind "file"}, "fe651124-bf64-4394-b92a-7c7afc947508" {:compute-ref #uuid "24c2a06c-d1ae-42a2-985d-a900215f31b3", :exec-duration 443, :id "fe651124-bf64-4394-b92a-7c7afc947508", :kind "code", :output-log-lines {}, :runtime [:runtime 
"34e7994b-249d-48b9-ba73-05176dc99cd6"]}}, :nextjournal/id #uuid "02b9d6f6-884d-4c1d-929f-a5101060c1cf", :article/change {:nextjournal/id #uuid "5e7c5a79-05a4-414a-a6cb-23844e50392d"}}} ```