Homework 4
1. Loading the data
The data we will use is a simplified and reduced version of the MNIST set of handwritten digits, one of the most commonly-used data sets in machine learning.
This version of the data was kindly provided by Prof. Raj Rao, University of Michigan.
(i) Load the training data from training_digits.jld using the JLD.jl package, and explore it.
The data consists of a 3-dimensional array ("tensor"):
- The third component runs from 1 to 10, and specifies the digit from 0 to 9.
- The columns (the first component) are vectors of length 256 representing grayscale images of 16×16 pixels.
- The second component runs over the number of training examples.
In the same file are the corresponding labels.
Load the test data from test_digits.jld.
import JLD
# Assumes both .jld files are in the working directory.
train = JLD.load("training_digits.jld")["TRAIN_DIGITS"];
test = JLD.load("test_digits.jld");
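As a sanity check on the layout described above, the following sketch builds a synthetic array of the same shape (256 × 319 × 10, where the 319 examples per digit are taken from the later cells) and pulls out one image. The array `fake_train` and the helper `digit_image` are hypothetical stand-ins for illustration; they are not part of the data file.

```julia
# Synthetic stand-in for TRAIN_DIGITS: pixels × examples × digit class.
fake_train = zeros(256, 319, 10)
fake_train[:, 5, 8] .= 1.0   # pretend example 5 of digit 7 is an all-ones image

# Hypothetical helper: the idx-th example of `digit` (0 to 9) as a 16×16 image.
digit_image(data, digit, idx) = reshape(data[:, idx, digit + 1], 16, 16)

img = digit_image(fake_train, 7, 5)
println(size(img))           # (16, 16)
println(all(img .== 1.0))    # true
```

Note the `digit + 1`: the third index runs from 1 to 10 while the digits run from 0 to 9.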
(ii) Visualize some of the digits by reshaping them to a 16×16 array and plotting the resulting image, e.g. using matshow in PyPlot. Use e.g. @manipulate. Note that some of the data is all zero.
using Interact, WebSockets, WebIO
using Plots
gr(leg=false, ticks=nothing, grid=false, border=false);
for digit=0:9, idx=1:319
    heatmap(rotl90(reshape(train[:,idx,digit+1], 16, 16)), c=ColorGradient([:white, :black]), aspectratio=:equal, size=(100, 100))
end
good_idxs = Int[]
raw_digits = Int[]
for i in 1:3190
    digit = Int(floor((i-1)/319))
    idx = i - 319*digit
    if !all(train[:,idx,digit+1] .== 0)
        append!(good_idxs, i)
        append!(raw_digits, digit)
    end
end
digits = [d == digit+1 ? 1 : 0 for digit in raw_digits, d in 1:10];
anim = @animate for i in good_idxs
    digit = Int(floor((i-1)/319))
    idx = i - 319*digit
    heatmap(rotl90(reshape(train[:,idx,digit+1], 16, 16)), c=ColorGradient([:white, :black]), aspectratio=:equal, size=(100, 100))
end every 10
gif(anim, "/results/res.gif", show_msg = false);
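The index arithmetic used above (a linear index `i` in 1:3190 split into a digit and an example index) can also be written with `divrem`. This is only a sketch of an equivalent formulation; the helper name `split_index` is hypothetical, and the 319 examples per digit are taken from the loops above.

```julia
# Hypothetical helper: map a linear index i in 1:3190 to (digit, idx),
# where digit is in 0:9 and idx is in 1:319.
function split_index(i; per_digit = 319)
    digit, r = divrem(i - 1, per_digit)
    return digit, r + 1
end

println(split_index(1))      # (0, 1)
println(split_index(319))    # (0, 319)
println(split_index(320))    # (1, 1)
println(split_index(3190))   # (9, 319)
```

Integer division avoids the round trip through floating point in `Int(floor((i-1)/319))`, although both give the same result for these sizes.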
2. Least squares
One of the first ideas that we could have is to look for a simple function that classifies each digit as e.g. "being a 7 or not being a 7". We could thus look for a linear map (the simplest type of map) from $\mathbb{R}^{256}$ to $\mathbb{R}$ that maps an image $x$ to $1$ or $0$, depending on whether the image does or does not correspond to a digit of type $7$.
(iii) Express the "being a 7" problem as a matrix-vector multiplication over a single matrix containing all the data. Use least squares (\) to solve the problem.
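A minimal sketch of the "being a 7" problem on synthetic data, assuming one image per row of a matrix `A` and a 0/1 target vector `b`. The sizes, the random data and the names `A`, `labels`, `b` and `w` are illustrative assumptions, not the homework's actual values.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n, d = 50, 8                    # toy sizes; the homework has d = 256
A = randn(n, d)                 # one (synthetic) image per row
labels = rand(0:9, n)           # synthetic digit labels
b = Float64.(labels .== 7)      # 1.0 where the digit is a 7, else 0.0

w = A \ b                       # least-squares solution of A*w ≈ b

# At a least-squares solution the residual is orthogonal to the columns of A.
println(norm(A' * (A*w - b)) < 1e-8)
```

Since `A` has more rows than columns, `\` returns the least-squares solution rather than an exact one, which is why the check is on the normal equations and not on `A*w == b`.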
begin Ned
size(W) = 10, 256
size(x) = 256, 1707
size(y) = 10, 1707
end Ned
A = zeros(1707, 256)
for i in 1:length(good_idxs)
    gidx = good_idxs[i]
    digit = Int(floor((gidx-1)/319))
    idx = gidx - 319*digit
    A[i,:] = train[:,idx,digit+1]
end
x = A'
y = digits'
W = y / x;
W' ≈ x' \ y' # using backslash notation
(iv) Stacking up all such problems vertically gives a problem where the vector $x$ is mapped onto a one-hot vector $y$ representing the digit. It turns out that some of the data vectors are all zero. Remove these and make one single matrix by horizontally stacking the data, and another by stacking the corresponding one-hot vectors.
Using least squares, find the matrix $W$ that best solves $W x_i = y_i$ for each $i$. (NB: Take care about the dimensions. What are you solving for?)
Use the resulting matrix to classify the test digits. What proportion does it get right?
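The whole pipeline (one-hot targets, least-squares fit, classification by largest score) can be sketched on synthetic data as follows. The `prototypes`, the noise level and the toy sizes are assumptions made up for illustration; only the shapes mirror the homework, with one sample per column.

```julia
using LinearAlgebra, Random

Random.seed!(2)
n, d = 200, 12                       # toy sizes; the homework has d = 256
labels = rand(0:9, n)                # synthetic digit labels
prototypes = randn(d, 10)            # one made-up "typical image" per digit
X = prototypes[:, labels .+ 1] .+ 0.1 .* randn(d, n)   # d × n, one sample per column

Y = Float64.((0:9) .== labels')      # 10 × n; column i is the one-hot vector of labels[i]
W = Y / X                            # 10 × d least-squares solution of W*X ≈ Y

# Predicted digit = index of the largest score in each column, shifted to 0:9.
y_hat = [argmax(col) - 1 for col in eachcol(W * X)]
accuracy = sum(y_hat .== labels) / n
println("training accuracy: ", accuracy)
println(W ≈ (X' \ Y')')              # the same solve written with backslash
```

The unknown is the matrix `W`, which multiplies the data from the left; that is why the solve uses `/` (or a transposed `\`) rather than a plain `X \ Y`.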
using Printf
y_raw = W*x;
y_hat = mapslices(argmax, y_raw, dims = 1) .- 1
incorrect_idxs = findall(x -> x != 0, (y_hat - raw_digits'))
@printf("Got %.1f%% right in training set", 100*(1-length(incorrect_idxs)/length(y_hat)))
y_raw = W*test["digits"];
y_hat = mapslices(argmax, y_raw, dims = 1) .- 1
incorrect_idxs = findall(x -> x != 0, vec(y_hat - test["labels"]))
@printf("Got %.1f%% right in testing set", 100*(1-length(incorrect_idxs)/length(y_hat)))
sorted_incorrect_idxs = sort(incorrect_idxs, by= i->test["labels"][i])
anim = @animate for idx in sorted_incorrect_idxs
    heatmap(rotl90(reshape(test["digits"][:,idx], 16, 16)), c=ColorGradient([:white, :black]), aspectratio=:equal, size=(100, 100))
end
gif(anim, "/results/res.gif", show_msg = false);
(v) Instead of just $Wx$, do the same with $Wx + b$, by adding an extra $1$ at the bottom of each $x$.
x_b = [x; ones(1, size(x, 2))]
W_b = y / x_b
y_raw = W_b*[test["digits"]; ones(1, size(test["digits"], 2))]
y_hat = mapslices(argmax, y_raw, dims = 1) .- 1
incorrect_idxs = findall(x -> x != 0, vec(y_hat - test["labels"]))
@printf("Got %.1f%% right in testing set", 100*(1-length(incorrect_idxs)/length(y_hat)))
begin Ned
Why does this make no difference? Perhaps because the corner pixels tended to be -1.0 already, and so already served as a constant term?
end Ned
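One way to probe the guess in the note above is a toy experiment: if some input row is exactly constant, appending a row of ones adds nothing new to the row space of the data, so the fitted predictions should coincide. The sketch below uses made-up data in which row 1 is forced to -1.0 in every sample; it does not check whether the real digit images behave this way.

```julia
using LinearAlgebra, Random

Random.seed!(3)
n, d = 100, 6
X = randn(d, n)
X[1, :] .= -1.0                 # a "pixel" that is -1.0 in every sample
Y = randn(3, n)                 # arbitrary targets, for illustration only

X_b = [X; ones(1, n)]           # append the bias row of ones

W   = Y / X
W_b = Y / X_b                   # rank-deficient: the ones row is -1 times row 1

# Compare the fitted predictions of the two models.
println(isapprox(W * X, W_b * X_b; atol = 1e-6))
```

On real data the corner pixels are presumably only approximately constant, so the two fits would then be close rather than identical.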
(vi) Use @manipulate to scroll through misclassified images and discuss their features.
incorrect_by_label = [filter(i -> test["labels"][i] == j, incorrect_idxs) for j in 0:9]
tfont = Plots.Font("Arial", 4, :hcenter, :vcenter, 0.0, RGB(0,0,0))
for i in 0:9
    plots = [heatmap(rotl90(reshape(test["digits"][:,idx], 16, 16)), c=ColorGradient([:white, :black]), aspectratio=:equal, title=string(idx), titlefont=tfont) for idx in incorrect_by_label[i+1]]
    n = length(incorrect_by_label[i+1])
    plot(plots..., size=(50*sqrt(n), 50*sqrt(n)))
end
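The grouping step above can be traced by hand on a tiny example. The label vectors here are invented for illustration; `by_label[j+1]` collects the indices of misclassified samples whose true digit is `j`.

```julia
true_labels = [0, 1, 1, 7, 7, 7, 9]
predicted   = [0, 1, 7, 7, 1, 9, 9]

# Indices where the prediction disagrees with the true label.
incorrect = findall(true_labels .!= predicted)

# Group those indices by true digit (entry j+1 holds digit j).
by_label = [filter(i -> true_labels[i] == j, incorrect) for j in 0:9]

println(incorrect)       # [3, 5, 6]
println(by_label[2])     # digit 1: [3]
println(by_label[8])     # digit 7: [5, 6]
```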
test["digits"] # output the array, so that we're able to reference it!
2.1. Create a plot with Python using the above Julia Data
digits = nil↩ # reference the above julia array
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
x = np.matrix(digits)
img = np.reshape(x[2, :], (16, 16))
fig, ax = plt.subplots()
plt.imshow(img, cmap = 'Greys')
fig