Image Classification with Knet
In this article we'll use Julia and the Knet package to classify an image with a pretrained deep residual network.
The Knet project is a machine learning library developed in Julia at Koç University. With Knet, models are defined using plain Julia code, including high-level language features (loops, conditionals, recursion, closures) that many other frameworks restrict. The library uses automatic differentiation, careful memory management, and efficient GPU kernels to achieve good performance.
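To give a feel for this style, here is a minimal, hypothetical sketch (not part of this article's model) of a linear classifier written as ordinary Julia functions, with the gradient generated by Knet's grad; the names linpredict, linloss, and lingrad are illustrative only.

using Knet

# A tiny linear classifier defined in plain Julia.
linpredict(w, x) = w[1] * x .+ w[2]

# Negative log likelihood loss on a minibatch (nll is provided by Knet).
linloss(w, x, y) = nll(linpredict(w, x), y)

# grad (from AutoGrad, re-exported by Knet) returns a function computing dloss/dw.
lingrad = grad(linloss)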
1. Setup
Much of what we need is provided by Nextjournal and its default Julia environment. We'll just have to install Knet and a few dependencies. We currently install a branch of MAT.jl which works on Julia 1.0 but is unoptimized—this should be updated once the relevant PR is merged.
pkg"update" pkg"add OffsetArrays FFTViews ArgParse Images Knet" pkg"add https://github.com/halleysfifthinc/MAT.jl#v0.7-update" pkg"build MAT OffsetArrays FFTViews ArgParse Images Knet" using MAT, OffsetArrays, FFTViews, ArgParse, Images, Knet
We'll also download the pretrained model weights to /results.
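A minimal sketch of that download step is below; the URL is an assumption, taken from the MatConvNet model zoo that Knet's resnet.jl example downloads from.

# Sketch: fetch the pretrained ResNet-152 weights (MatConvNet .mat format) into /results.
resnet_url = "http://www.vlfeat.org/matconvnet/models/imagenet-resnet-152-dag.mat"
model_path = "/results/imagenet-resnet-152-dag.mat"
isfile(model_path) || download(resnet_url, model_path)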
2. Model Definition
We'll pull from the Knet example files: the ResNet-152 model implemented in resnet.jl. A residual network is built like a standard layered network, except that each building block of a few layers also receives its own input through a shortcut connection. These shortcuts counteract the degradation of training accuracy seen in very deep plain networks, without adding significant computational complexity.
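As a conceptual sketch only (the real blocks below also include convolutions, batch normalization, and downsampling), a residual block adds its input back to the output of a small stack of layers before the nonlinearity:

# Conceptual residual block: f stands for a small stack of layers;
# its output is added to the block's input before the ReLU (relu is provided by Knet).
residual_block(f, x) = relu.(x .+ f(x))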
import ArgParse, Images, MAT, ImageMagick, Knet

# mode, 0=>train, 1=>test
function resnet152(w,x,ms; mode=1)
    # layer 1
    conv1 = reslayerx1(w[1:3],x,ms; padding=3, stride=2, mode=mode)
    pool1 = pool(conv1; window=3, stride=2)
    # layer 2,3,4,5
    r2 = reslayerx5(w[4:33], pool1, ms; strides=[1,1,1,1], mode=mode)
    r3 = reslayerx5(w[34:108], r2, ms; mode=mode)
    r4 = reslayerx5(w[109:435], r3, ms; mode=mode)
    r5 = reslayerx5(w[436:465], r4, ms; mode=mode)
    # fully connected layer
    pool5 = pool(r5; stride=1, window=7, mode=2)
    fc1000 = w[466] * mat(pool5) .+ w[467]
end
Next come the layer-definition functions and some utilities, including an image preprocessing function taken from vgg.jl.
# Batch Normalization Layer
# works both for convolutional and fully connected layers
# mode, 0=>train, 1=>test
function batchnorm(w, x, ms; mode=1, epsilon=1e-5)
    mu, sigma = nothing, nothing
    if mode == 0
        d = ndims(x) == 4 ? (1,2,4) : (2,)
        s = prod(size(x,d...))
        mu = sum(x, dims=d) / s
        x0 = x .- mu
        x1 = x0 .* x0
        sigma = sqrt.(epsilon .+ sum(x1, dims=d) / s)
    elseif mode == 1
        mu = popfirst!(ms)
        sigma = popfirst!(ms)
    end
    # we need getval in backpropagation
    push!(ms, AutoGrad.value(mu), AutoGrad.value(sigma))
    xhat = (x .- mu) ./ sigma
    return w[1] .* xhat .+ w[2]
end

function reslayerx0(w,x,ms; padding=0, stride=1, mode=1)
    b = conv4(w[1],x; padding=padding, stride=stride)
    bx = batchnorm(w[2:3],b,ms; mode=mode)
end

function reslayerx1(w,x,ms; padding=0, stride=1, mode=1)
    relu.(reslayerx0(w,x,ms; padding=padding, stride=stride, mode=mode))
end

function reslayerx2(w,x,ms; pads=[0,1,0], strides=[1,1,1], mode=1)
    ba = reslayerx1(w[1:3],x,ms; padding=pads[1], stride=strides[1], mode=mode)
    bb = reslayerx1(w[4:6],ba,ms; padding=pads[2], stride=strides[2], mode=mode)
    bc = reslayerx0(w[7:9],bb,ms; padding=pads[3], stride=strides[3], mode=mode)
end

function reslayerx3(w,x,ms; pads=[0,0,1,0], strides=[2,2,1,1], mode=1) # 12
    a = reslayerx0(w[1:3],x,ms; stride=strides[1], padding=pads[1], mode=mode)
    b = reslayerx2(w[4:12],x,ms; strides=strides[2:4], pads=pads[2:4], mode=mode)
    relu.(a .+ b)
end

function reslayerx4(w,x,ms; pads=[0,1,0], strides=[1,1,1], mode=1)
    relu.(x .+ reslayerx2(w,x,ms; pads=pads, strides=strides, mode=mode))
end

function reslayerx5(w,x,ms; strides=[2,2,1,1], mode=1)
    x = reslayerx3(w[1:12],x,ms; strides=strides, mode=mode)
    for k = 13:9:length(w)
        x = reslayerx4(w[k:k+8],x,ms; mode=mode)
    end
    return x
end
function get_params(params, atype)
    len = length(params["value"])
    ws, ms = [], []
    for k = 1:len
        name = params["name"][k]
        value = convert(Array{Float32}, params["value"][k])
        if endswith(name, "moments")
            push!(ms, reshape(value[:,1], (1,1,size(value,1),1)))
            push!(ms, reshape(value[:,2], (1,1,size(value,1),1)))
        elseif startswith(name, "bn")
            push!(ws, reshape(value, (1,1,length(value),1)))
        elseif startswith(name, "fc") && endswith(name, "filter")
            push!(ws, transpose(reshape(value,(size(value,3),size(value,4)))))
        elseif startswith(name, "conv") && endswith(name, "bias")
            push!(ws, reshape(value, (1,1,length(value),1)))
        else
            push!(ws, value)
        end
    end
    map(wi->convert(atype, wi), ws), map(mi->convert(atype, mi), ms)
end

# From vgg.jl
function data(img, averageImage)
    if occursin("://", img)
        @info "Downloading $img"
        img = download(img)
    end
    a0 = load(img)
    new_size = ntuple(i->div(size(a0,i)*224,minimum(size(a0))),2)
    a1 = Images.imresize(a0, new_size)
    i1 = div(size(a1,1)-224,2)
    j1 = div(size(a1,2)-224,2)
    b1 = a1[i1+1:i1+224,j1+1:j1+224]
    c1 = permutedims(channelview(b1), (3,2,1))
    d1 = convert(Array{Float32}, c1)
    e1 = reshape(d1[:,:,1:3], (224,224,3,1))
    f1 = (255 * e1 .- averageImage)
    g1 = permutedims(f1, [2,1,3,4])
end
3. Classification
Let's see how this model does. As is tradition, we shall cat.
We'll wrap the main model-handling part of the code in a predict() function, so that the frontend won't be trying to display a 240 MB array—otherwise, this is fairly straightforward.
using Knet, MAT, Images

o = Dict(
    :atype => KnetArray{Float32},
    :model => "/results/imagenet-resnet-152-dag.mat", # pretrained weights downloaded above
    :image => "cat.jpg",                              # placeholder: path or URL of the image to classify
    :top   => 10
)

function predict(o)
    @info "Reading $(o[:model])"
    model = matread(abspath(o[:model]))
    avgimg = model["meta"]["normalization"]["averageImage"]
    avgimg = convert(Array{Float32}, avgimg)
    description = model["meta"]["classes"]["description"]
    w, ms = get_params(model["params"], o[:atype])
    @info "Reading $(o[:image])"
    img = data(o[:image], avgimg)
    img = convert(o[:atype], img)
    @info "Classifying."
    y1 = resnet152(w,img,ms)
    return y1, description
end

y1, description = predict(o)
z1 = vec(Array(y1))
s1 = sortperm(z1,rev=true)
p1 = exp.(logp(z1))
using Printf

for ind in s1[1:o[:top]]
    print("$(description[ind]): $(@sprintf("%.2f",p1[ind]*100))%\n")
end
Well, it's pretty sure there's some type of cat! It also has an inkling that there's some sort of textile in there. Nifty.