Image Classification with PyTorch

A Reproducibility Study

In this article we'll explore to what extent re-implementing an experiment in a different framework yields the same results. To that end, we'll remix the Image Classification with Keras article and port the code to PyTorch.


In order to classify the CIFAR-10 dataset using PyTorch, we of course first have to install PyTorch. For this purpose we'll use the original article's environment and extend it by installing PyTorch and Torchvision. Let's run pip freeze to see which packages are already there.

# List the packages already present in the environment inherited from the
# original article.
pip freeze

Alright, let's now install PyTorch and Torchvision and export the environment for further use.

# Add PyTorch and Torchvision on top of the existing environment.
pip install torch torchvision

Great! Luckily the CIFAR-10 dataset is already there from the original article.

import torch
import torchvision
# Convert each image to a tensor, then normalize every RGB channel with
# mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose([
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): the original `root=` argument was an unresolved notebook file
# reference to the downloaded cifar-10-python.tar.gz, which is not valid
# Python. torchvision expects `root` to be a directory; with download=False
# the extracted "cifar-10-batches-py" folder presumably has to be inside it
# already -- confirm the location for your setup.
train_data = torchvision.datasets.CIFAR10(
  root="/results",
  train=True, download=False, transform=transform)

# The loader construction had lost its call; batches of 4 shuffled images are
# produced by 2 worker processes.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                                           shuffle=True, num_workers=2)

The CIFAR-10 dataset is a collection of 60,000 color, 32x32-pixel images in ten classes, 10,000 of which are in a test batch. Keras can automatically download the dataset, but we'll save time by downloading it once to /results, and just copying that file to the right place when it's needed.

# Download the CIFAR-10 python archive once and keep it in /results.
# The URL had been dropped after the line continuation; this is the dataset's
# public download location.
wget --progress=dot:giga -O /results/cifar-10-python.tar.gz \
  https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

A lot of initialization code will be needed multiple times, so we'll put it in its own runtime as functions, and then import the desired code cells where we need them.

We'll run 128-image batches and set up two training runs: a long, 500-epoch run to do the main work, and a short, 5-epoch run as an example.

# Training hyper-parameters shared by the cells below.
batch_size = 128      # images per mini-batch
num_classes = 10      # number of classes used when one-hot encoding labels
epochs_shortrun = 5   # epoch count for the short example run
epochs_longrun = 500  # epoch count for the long training run

# Base directories and the base name of the saved model/history files.
save_dir = "/work"
res_dir = "/results"
model_name = "convnet_cifar10"

# setup paths
import os

# Improvement checkpoints live in <save_dir>/checkpoints. The `if` below had
# lost its body; create the directory on first use so that later checkpoint
# writes and directory listings do not fail.
ckpt_dir = os.path.join(save_dir,"checkpoints")
if not os.path.isdir(ckpt_dir):
  os.makedirs(ckpt_dir)

# Saved model file and pickled training history, both under <res_dir>.
model_path = os.path.join(res_dir, model_name + ".kerasave")
hist_path = os.path.join(res_dir, model_name + ".kerashist")

Load the data and get it into a reasonable shape. Also set up a function to find the best checkpoint file, another to give us a look at the images we're analyzing, and finally set up to do real-time input-data augmentation.

import numpy as np
import dill as pickle
from math import *

def setup_tf(seed=0):
  """Seed the Python and NumPy random number generators.

  The original body of this function is missing (only its comment survived),
  which left a `def` without a body -- a syntax error.

  NOTE(review): judging by the function name, TensorFlow-level seeding
  presumably belonged here as well. It is not restored because the original
  calls are not visible -- confirm and add it if needed.

  Args:
    seed: integer seed applied to both generators.
  """
  import random
  # set random seeds for reproducibility
  random.seed(seed)
  np.random.seed(seed)

def setup_load_cifar(verbose=False,
                     archive_path="/results/cifar-10-python.tar.gz"):
  """Load CIFAR-10 through Keras and prepare it for training.

  The pre-downloaded archive is copied into the Keras dataset cache under the
  file name Keras looks for, so the copy made once in /results is reused.
  Images are converted to float32 and divided by 255.0; class vectors are
  one-hot encoded using the module-level `num_classes`.

  The original copy call contained an unresolved notebook file reference
  (not valid Python). The source is now the `archive_path` parameter, whose
  default is the location the article's wget step writes to.
  NOTE(review): confirm that default for your environment.

  Args:
    verbose: if true, print the training array shape and the sample counts.
    archive_path: path of the downloaded cifar-10-python.tar.gz archive.

  Returns:
    The tuple (x_train, y_train, x_test, y_test, labels), where `labels` is
    the object unpickled from the dataset's "batches.meta" file.
  """
  import os,shutil
  from keras.datasets import cifar10
  from keras.utils import to_categorical
  datadir = os.path.join(os.path.expanduser("~"), ".keras", "datasets")
  # the name keras looks for
  datafile = os.path.join(datadir, "cifar-10-batches-py.tar.gz")
  if not os.path.isfile(datafile):
    # The cache directory may not exist yet on a fresh machine.
    if not os.path.isdir(datadir):
      os.makedirs(datadir)
    shutil.copyfile(archive_path, datafile)
  # The data, shuffled and split between train and test sets:
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  if verbose:
    print("x_train shape: {}, {} train samples, {} test samples.\n".format(
      x_train.shape, x_train.shape[0], x_test.shape[0]))
  # Convert class vectors to binary class matrices.
  y_train = to_categorical(y_train, num_classes)
  y_test = to_categorical(y_test, num_classes)
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")
  x_train /= 255.0
  x_test /= 255.0
  # Load label names to use in prediction results
  label_list_path = "datasets/cifar-10-batches-py/batches.meta"
  keras_dir = os.path.expanduser(os.path.join("~", ".keras"))
  datadir_base = os.path.expanduser(keras_dir)
  # Fall back to /tmp/.keras when the home cache directory is not writable.
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join("/tmp", ".keras")
  label_list_path = os.path.join(datadir_base, label_list_path)
  # Unpickling executes code from the file: only ever read the locally
  # cached dataset metadata here, never an untrusted file.
  with open(label_list_path, mode="rb") as f:
    labels = pickle.load(f)
  return x_train, y_train, x_test, y_test, labels

def setup_data_aug():
  """Build the Keras ImageDataGenerator used for real-time augmentation.

  Only random horizontal and vertical shifts (0.1 of the image size) and
  random horizontal flips are enabled; centering, std normalization, ZCA
  whitening, rotation and vertical flips are all switched off.

  The original call was never closed (its closing parenthesis was missing)
  and mixed tabs with spaces; both are repaired here.

  Returns:
    The configured ImageDataGenerator instance.
  """
  print("Using real-time data augmentation.\n")
  # This will do preprocessing and realtime data augmentation:
  from keras.preprocessing.image import ImageDataGenerator
  datagen = ImageDataGenerator(
    featurewise_center=False, # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False, # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False, # apply ZCA whitening
    rotation_range=0, # randomly rotate images in the range
                      # (degrees, 0 to 180)
    width_shift_range=0.1, # randomly shift images horizontally
                           # (fraction of total width)
    height_shift_range=0.1, # randomly shift images vertically
                            # (fraction of total height)
    horizontal_flip=True, # randomly flip images horizontally
    vertical_flip=False   # randomly flip images vertically
  )
  return datagen

# Function to find the best checkpoint file
def last_ckpt(dir):
  """Return the path of the best ".hdf5" checkpoint found in `dir`.

  Checkpoint names are split on "-": the third field is read as the epoch
  and the fourth as the accuracy followed by the ".hdf5" suffix. The file
  with the highest accuracy wins; ties go to the highest epoch.

  Files ending in ".hdf5" whose names do not follow that layout are skipped.
  Previously a single stray file aborted the search with an IndexError or
  ValueError.

  Args:
    dir: directory to scan (the name shadows the builtin but is kept so
      existing keyword callers keep working).

  Returns:
    Full path of the selected checkpoint, or "" when none qualifies.
  """
  best_key = None
  best_name = ""
  for name in os.listdir(dir):
    if not name.endswith(".hdf5"):
      continue
    fields = name.split("-")
    try:
      # [0:-5] strips the ".hdf5" suffix from the accuracy field
      key = (float(fields[3][0:-5]), int(fields[2]))
    except (IndexError, ValueError):
      continue
    # Tuples compare accuracy first and epoch second; a strict ">" keeps the
    # first file seen when both are equal.
    if best_key is None or key > best_key:
      best_key = key
      best_name = name
  if best_name == "":
    return ""
  return os.path.join(dir, best_name)

# Visualizing CIFAR 10: takes indices and shows the images in a grid
def cifar_grid(X,Y,inds,n_col, predictions=None):
  """Draw the images selected by `inds` in a grid with `n_col` columns.

  Class names are read from the module-level `labels["label_names"]`. When
  `predictions` is given, every image whose predicted class (argmax of its
  prediction row) differs from its true class (argmax of its `Y` row) gets
  the predicted class name as a red title.

  Fixes over the previous version:
    * cells of a partly filled last row no longer index past the end of
      `inds` (the lookup used to happen before the bounds check);
    * a grid with a single row or a single column works, because the axes
      array is always requested as 2-D;
    * a dead `label += " n"` store was removed.

  Args:
    X: images, indexed along the first axis.
    Y: true class rows, indexed along the first axis.
    inds: indices into X/Y of the images to show.
    n_col: number of grid columns.
    predictions: optional predicted class rows, indexed like Y.

  Returns:
    The matplotlib figure holding the grid.
  """
  import matplotlib.pyplot as plt
  if predictions is not None:
    if Y.shape != predictions.shape:
      print("Predictions must equal Y in length!\n")
  N = len(inds)
  n_row = (N + n_col - 1) // n_col  # ceil(N / n_col) in integer arithmetic
  fig, axes = plt.subplots(n_row,n_col,figsize=(10,10), squeeze=False)
  clabels = labels["label_names"]
  for pos, i_data in enumerate(inds):
    ax = axes[pos // n_col][pos % n_col]
    ax.imshow(X[i_data,...], interpolation="nearest")
    if predictions is not None:
      label = clabels[np.argmax(Y[i_data,...])]
      pred = clabels[np.argmax(predictions[i_data,...])]
      if label != pred:
        ax.set_title(pred, color="red")
  return fig

Let's take a gander at a random selection of training images.

x_train, y_train, x_test, y_test, labels = setup_load_cifar(verbose=True)

# Draw 36 random training-image indices (repeats are possible). Passing the
# length directly avoids converting range(len(x_train)) to an array on every
# single draw while sampling from the same index range.
indices = [np.random.choice(len(x_train)) for i in range(36)]

# NOTE(review): the display call was missing from this cell; a 6-column grid
# mirrors the later prediction cell -- confirm.
cifar_grid(x_train, y_train, indices, 6)


We'll use a simple convolutional network model (still under development), with the addition of the data augmentation defined above, and a checkpoint-writing callback that's keyed to significant accuracy improvements.

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from keras.layers.pooling import MaxPooling2D
from keras.callbacks import ModelCheckpoint,EarlyStopping

# Four 3x3 ReLU convolutions (32, 64, 128, 128 filters) with 2x2 max pooling
# after the last three, followed by a dense classifier head.
model = Sequential()

# The first layer's call was cut off mid-argument list. It has to declare the
# input shape, taken here from the training images.
model.add(Conv2D(32, kernel_size=(3, 3), activation="relu",
                 input_shape=x_train.shape[1:]))
model.add(Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))

# Collapse the feature maps into a vector before the dense layers, so the
# final layer yields one 10-way output per image.
# NOTE(review): Dropout is imported next to Flatten but unused here;
# confirm whether dropout layers were intended.
model.add(Flatten())
model.add(Dense(1024, activation="relu"))
model.add(Dense(10, activation="softmax"))

# initiate Adam optimizer
opt = Adam(lr=0.0001, decay=1e-6)

# Let's train the model using Adam. The opening of the compile call was
# missing; categorical cross-entropy matches the one-hot labels and the
# softmax output layer.
model.compile(loss="categorical_crossentropy",
              optimizer=opt, metrics=["accuracy"])

# checkpoint callback
# The file-name template had been cut off. last_ckpt() splits checkpoint names
# on "-" and reads the third field as the epoch and the fourth as the accuracy
# followed by ".hdf5", so the template has to keep that layout.
# NOTE(review): the exact prefix and number formats are a reconstruction.
filepath = os.path.join(ckpt_dir,
  "weights-improvement-{epoch:02d}-{val_acc:.6f}.hdf5")
checkpoint = ModelCheckpoint(
  filepath, monitor="val_acc", verbose=1, save_best_only=True, mode="max")
print("Saving improvement checkpoints to \n\t{0}".format(filepath))

# early stop callback, given a bit more leeway
stahp = EarlyStopping(min_delta=0.00001, patience=25)

Finally, let's take a look at our model, with both a text summary and a flow chart.

from keras.utils import plot_model
# Text summary of the model (announced in the text above, but the call was
# missing from the cell).
model.summary()
# Flow chart of the model, written as an SVG file.
plot_model(model, to_file="/results/model.svg", 
           show_layer_names=True, show_shapes=True, rankdir="TB")


Now we're ready to train using the GPU. We'll put this in a separate runtime configured to use a dedicated GPU compute node. This will have an initialization cell, and then two training cells: one to do some serious long-term training (takes hours), and one which just runs a few additional epochs as an example.

from __future__ import print_function

#os.environ["CUDA_VISIBLE_DEVICES"] = "" # for testing
# Lower TensorFlow's native log verbosity.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

x_train, y_train, x_test, y_test, labels = setup_load_cifar(verbose=True)

datagen = setup_data_aug()
# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
# The call this comment describes was missing from the cell.
datagen.fit(x_train)

The training cells will save their models and weights to /results, and then for analysis and visualization we'll just need to load that data. We'll also pickle the training history for the long run to /results so we can take a look at that.

Long Training

epochs = epochs_longrun

# Resume from the best checkpoint of an earlier run, if there is one. The
# message was printed but the weights were never actually loaded.
cpf = last_ckpt(ckpt_dir)
if cpf != "":
  print("Loading starting weights from \n\t{0}".format(cpf))
  model.load_weights(cpf)

# Fit the model on the batches generated by datagen.flow().
# The flow() call had lost its batch size and closing parenthesis, and the
# epoch count set above was never passed on.
hist = model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=batch_size),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint,stahp])

# Save model and weights
model.save(model_path)
#print('Saved trained model at %s ' % model_path)

# Keep the per-epoch metrics so the history can be plotted later.
with open(hist_path, 'wb') as f:
  pickle.dump(hist.history, f)

Short Example

epochs = epochs_shortrun

# load results of long training run
# NOTE(review): the original argument was an unresolved notebook reference to
# the long run's "convnet_cifar10.kerasave" output; model_path is where this
# script writes that file -- confirm.
model.load_weights(model_path)

# Fit the model on the batches generated by datagen.flow().
# The flow() call had lost its batch size and closing parenthesis, and the
# epoch count set above was never passed on.
hist = model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=batch_size),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint])

# Save model and weights
model.save(model_path)
print('Saved trained model at %s ' % model_path)


Alrighty, now we can take a look at the trained model. The load_model() function will give us back our full, trained model for evaluation and prediction.

from __future__ import print_function

from keras.models import load_model

x_train, y_train, x_test, y_test, labels = setup_load_cifar()
datagen = setup_data_aug()

# NOTE(review): the original argument was an unresolved notebook reference to
# the short run's "convnet_cifar10.kerasave" output; model_path is where this
# script writes that file -- confirm.
model = load_model(model_path)

# Evaluate model with test data set
# NOTE(review): the test images pass through the augmenting generator, and
# the floor division leaves a final partial batch unevaluated -- confirm
# that both are intended.
evaluation = model.evaluate_generator(datagen.flow(x_test, y_test,
    batch_size=batch_size, shuffle=False),
    steps=x_test.shape[0] // batch_size, workers=4)

# Print out final values of all metrics
key2name = {'acc':'Accuracy', 'loss':'Loss', 
    'val_acc':'Validation Accuracy', 'val_loss':'Validation Loss'}
# Pair every metric name with its value in the same position.
results = ['%s = %.2f' % (key2name[key], value)
           for key, value in zip(model.metrics_names, evaluation)]
print(", ".join(results))

We can sample the prediction results with images.

num_predictions = 36

# NOTE(review): the original argument was an unresolved notebook reference to
# the short run's "convnet_cifar10.kerasave" output; model_path is where this
# script writes that file -- confirm.
model = load_model(model_path)

# One extra step picks up the final partial batch, so every test image gets
# a prediction row.
predict_gen = model.predict_generator(datagen.flow(x_test, y_test,
    batch_size=batch_size, shuffle=False),
    steps=(x_test.shape[0] // batch_size)+1, workers=4)

# Passing the length directly avoids converting range(len(x_test)) to an
# array on every draw while sampling from the same index range.
indices = [np.random.choice(len(x_test))
           for i in range(num_predictions)]

cifar_grid(x_test,y_test,indices,6, predictions=predict_gen)

And hey, let's take a look at the training history (we'll look at the long training so it's an interesting history).

import matplotlib.pyplot as plt

# NOTE(review): the original argument was an unresolved notebook reference to
# the long run's "convnet_cifar10.kerashist" output; hist_path is where this
# script pickles that history -- confirm.
# Unpickling executes code from the file: only read history files written by
# the training cell.
with open(hist_path, 'rb') as f:
  hist = pickle.load(f)

key2name = {'acc':'Accuracy', 'loss':'Loss', 
    'val_acc':'Validation Accuracy', 'val_loss':'Validation Loss'}

fig = plt.figure()

# One panel per recorded metric, laid out two per row.
# NOTE(review): the loop body ended right after reading the trace; the
# plotting statements below are a reconstruction -- confirm the layout.
things = ['acc','loss','val_acc','val_loss']
for i,thing in enumerate(things):
  trace = hist[thing]
  plt.subplot(len(things) // 2, 2, i + 1)
  plt.plot(range(len(trace)), trace)
  plt.title(key2name[thing])

plt.tight_layout()


Finally, we're back to where we started! Now we can test our trained network against new images. Going back to our cat photo at the top...

# This cell uses both packages by their top-level names.
import keras
import tensorflow

sess = keras.backend.get_session()

# NOTE(review): the original argument was an unresolved notebook reference to
# the long run's "convnet_cifar10.kerasave" output; model_path is where this
# script writes that file -- confirm.
model = keras.models.load_model(model_path)
_,_,_,_,labels = setup_load_cifar()

# TODO(review): the reference to the photo was lost ("Missing Reference").
# The file name below is a placeholder; point it at the image to classify.
image_path = "cat.jpg"
img = tensorflow.read_file(image_path)
img = tensorflow.image.decode_jpeg(img, channels=3)
img.set_shape([None, None, 3])
img = tensorflow.image.resize_images(img, (32, 32))
img = img.eval(session=sess) # convert to numpy array
# The training images were divided by 255.0 in setup_load_cifar(); apply the
# same scaling here so the network sees inputs in the range it was trained on.
img = img / 255.0
img = np.expand_dims(img, 0) # make 'batch' of 1

pred = model.predict(img)
pred = labels["label_names"][np.argmax(pred)]

...again, we are assured that we have acquired a nil photo.