Identifying Image Subjects with Keras

Train a Keras/Tensorflow Convolutional Neural Network using the CIFAR-10 image dataset

Micah Dombrowski

Keras is a high-level Python neural network API, designed for ease of use and support for multiple backends. The default backend is TensorFlow. We'll use TensorFlow to classify the CIFAR-10 image dataset.

1.
Setup

Since we'll run this training on a GPU, we'll need tensorflow-gpu and the CUDA neural network libraries. We need h5py for storing the training results in HDF5 format. Let's also add graphical packages for visualizing our training models.

# tensorflow-gpu: GPU build of TensorFlow; h5py: HDF5 I/O for saved models and
# checkpoints; graphviz + pydot: graph-drawing backends for keras.utils.plot_model.
conda install -qy -c anaconda tensorflow-gpu h5py graphviz pydot
# keras itself, plus dill (imported by the notebook in place of pickle).
pip install keras dill
Done
setup
Bash














































































The CIFAR-10 dataset is a collection of 60,000 color, 32x32-pixel images in ten classes, 10,000 of which are in a test batch. Keras can automatically download the dataset, but let's save some time by doing that now—it will cache the download, and we'll lock this cell and the setup cell above after the first run so that we won't need to download it again.

# Pre-fetch CIFAR-10 once so Keras caches the archive locally; later calls to
# load_data() are then served from that cache instead of the network.
import keras.datasets.cifar10 as cifar10
cifar10.load_data()
Done
download data
Python Setup


We'll run 128-image batches and set up two training runs: a long, 500-epoch run to do the main work, and a short, 5-epoch run as an example.

# --- training hyperparameters ---
batch_size = 128          # images per gradient step
num_classes = 10          # CIFAR-10 label count
epochs_longrun = 500      # upper bound for the main run (EarlyStopping may end it sooner)
epochs_shortrun = 5       # quick demonstration run

# --- output locations ---
save_dir = "/work"        # checkpoints are written beneath this directory
res_dir = "/results"      # final model and training history are written here
model_name = 'convnet_cifar10'
Done
settings
Python

Load the data and get it into a reasonable shape. Also set up a function to find the best checkpoint file, another to give us a look at the images we're analyzing, and finally set up to do real-time input-data augmentation.

from __future__ import print_function
import tensorflow as tf
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from keras.layers.pooling import MaxPooling2D
from keras.callbacks import ModelCheckpoint,EarlyStopping
from keras.utils import to_categorical
from keras.models import load_model

import os
import dill as pickle
import numpy as np

# set random seeds for reproducibility
# NOTE(review): GPU kernels and the multi-worker generator used for training
# may still introduce run-to-run variation despite these seeds.
tf.reset_default_graph()
tf.set_random_seed(343)
np.random.seed(343)

# Reduce TensorFlow's C++ log verbosity (level "2", presumably hiding INFO and
# WARNING messages).
# NOTE(review): this is set after `import tensorflow`, so anything logged during
# the import itself is not filtered; move it above the import if that matters.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# The data, shuffled and split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
# (one-hot rows with num_classes columns)
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# Scale pixels to [0, 1], assuming 8-bit (0-255) input values.
# Any image fed to the trained model later must be scaled the same way.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255.0
x_test /= 255.0

# Load label names to use in prediction results
# The file lives inside the cached CIFAR-10 archive; its path is built
# relative to the Keras cache directory below.
label_list_path = 'datasets/cifar-10-batches-py/batches.meta'

# Locate the Keras cache directory: ~/.keras, or /tmp/.keras if the home
# directory is not writable.
# (keras_dir is already user-expanded; the second expanduser is a no-op.)
keras_dir = os.path.expanduser(os.path.join('~', '.keras'))
datadir_base = os.path.expanduser(keras_dir)
if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
label_list_path = os.path.join(datadir_base, label_list_path)

# `pickle` is dill here. The result is a dict; its "label_names" entry is the
# list of class names used by cifar_grid() and the custom-image prediction.
# NOTE: unpickling can execute arbitrary code; acceptable only because this
# file comes from the dataset archive Keras downloaded, not from user input.
with open(label_list_path, mode='rb') as f:
    labels = pickle.load(f)

# Directory for the per-improvement weight files written during training.
ckpt_dir = os.path.join(save_dir,"checkpoints")
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

# Final saved model and the pickled training history.
model_path = os.path.join(res_dir, model_name + ".kerasave")
hist_path = os.path.join(res_dir, model_name + ".kerashist")

# Function to find the best (highest-accuracy) checkpoint file
def last_ckpt(dir):
  """Return the path of the best checkpoint in *dir*, or "" if there is none.

  Checkpoint files are expected to be named
  ``weights-improvement-<epoch>-<val_acc>.hdf5``, the pattern written by the
  ModelCheckpoint callback. The file with the highest val_acc wins; ties are
  broken by the highest epoch number.

  ``.hdf5`` files that do not follow that pattern (e.g. a manually saved
  ``final_model.hdf5``) are ignored instead of aborting the search.

  Args:
    dir: directory to scan (not recursive).

  Returns:
    ``os.path.join(dir, best_file_name)``, or "" when no checkpoint is found.
  """
  suffix = ".hdf5"
  best_key = None
  best_name = ""
  for name in os.listdir(dir):
    if not name.endswith(suffix):
      continue
    # "weights-improvement-05-0.812300" -> [..., "05", "0.812300"]
    parts = name[:-len(suffix)].split("-")
    if len(parts) != 4:
      continue
    try:
      key = (float(parts[3]), int(parts[2]))  # (val_acc, epoch)
    except ValueError:
      # Right shape but non-numeric fields: not one of our checkpoints.
      continue
    if best_key is None or key > best_key:
      best_key, best_name = key, name

  return os.path.join(dir, best_name) if best_name else ""

import matplotlib.pyplot as plt
from math import *

# Visualizing CIFAR-10: takes indices and shows the images in a grid
def cifar_grid(X,Y,inds,n_col, predictions=None):
  """Draw the images X[inds] in a grid with *n_col* columns, titled by class.

  Args:
    X: image array; ``X[i, ...]`` is drawn with ``imshow``.
    Y: one-hot label array aligned with X; a cell's title is the class name
       at ``argmax(Y[i])``.
    inds: sequence (list or 1-D array) of indices into X/Y, laid out
       row-major. Its length need not be a multiple of n_col; leftover
       cells in the last row are left blank.
    n_col: number of grid columns; rows are added as needed.
    predictions: optional class-score array with the same shape as Y. When
       given, cells whose predicted class differs from the true class are
       re-titled with the prediction in red.

  Returns:
    The matplotlib Figure, or None (after printing a message) if
    ``predictions`` does not have Y's shape.
  """
  if predictions is not None and Y.shape != predictions.shape:
    print("Predictions must equal Y in length!")
    return(None)
  N = len(inds)
  # At least one row so plt.subplots() is valid even for an empty selection.
  n_row = max(1, int(ceil(1.0*N/n_col)))
  # squeeze=False keeps `axes` 2-D even when n_row or n_col is 1.
  fig, axes = plt.subplots(n_row, n_col, figsize=(10,10), squeeze=False)

  clabels = labels["label_names"]
  for j in range(n_row):
    for k in range(n_col):
      ax = axes[j][k]
      ax.set_axis_off()
      i_inds = j*n_col+k
      # The last row has unused cells when N is not a multiple of n_col;
      # check before indexing into inds.
      if i_inds >= N:
        continue
      i_data = inds[i_inds]
      ax.imshow(X[i_data,...], interpolation='nearest')
      label = clabels[np.argmax(Y[i_data,...])]
      ax.set_title(label)
      if predictions is not None:
        pred = clabels[np.argmax(predictions[i_data,...])]
        if label != pred:
          # Misclassified: replace the true-label title with the prediction.
          ax.set_title(pred, color='red')

  fig.set_tight_layout(True)
  return(fig)

print('Using real-time data augmentation.')
# Augmentation applied on the fly to batches drawn from datagen.flow():
# random shifts of up to 10% of the width/height, plus random left-right
# flips. Every normalization option and rotation is switched off.
datagen = ImageDataGenerator(
    # normalization (all disabled)
    featurewise_center=False,
    featurewise_std_normalization=False,
    samplewise_center=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    # geometric augmentation
    rotation_range=0,          # degrees; 0 disables rotation
    width_shift_range=0.1,     # fraction of total width
    height_shift_range=0.1,    # fraction of total height
    horizontal_flip=True,      # random left-right flips
    vertical_flip=False,       # no upside-down flips
)

# fit() computes dataset statistics used by featurewise normalization and ZCA
# whitening; with those disabled it presumably does little, but it is kept so
# the options above can be switched on without further changes.
datagen.fit(x_train)
Done
load modules and data
Python






Let's take a gander at a random selection of training images.

# 36 distinct random training images. Sampling without replacement means the
# grid never shows the same image twice, and the index range is built once
# rather than once per draw.
indices = np.random.choice(len(x_train), size=36, replace=False)

cifar_grid(x_train,y_train,indices,6)
Done
sample images
Python

We'll use a simple convolutional network model (still under development), with the addition of the data augmentation defined above, and a checkpoint-writing callback that saves the weights whenever validation accuracy improves on the best value seen so far.

model = Sequential()

# Feature extractor: 3x3 ReLU convolutions with max-pooling and dropout.
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                 input_shape=x_train.shape[1:]))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

# Classifier head. The output width uses num_classes (the same value passed
# to to_categorical) rather than a hard-coded 10, so the two stay in sync.
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))


# initiate Adam optimizer (small learning rate with a slight per-update decay)
opt = Adam(lr=0.0001, decay=1e-6)

# Compile for multi-class classification with the Adam optimizer above.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# checkpoint callback: save the weights whenever validation accuracy beats the
# best value seen so far. last_ckpt() parses the epoch and val_acc embedded in
# this file name to choose a starting checkpoint.
# NOTE(review): the 'val_acc' name assumes an older Keras that logs the metric
# as 'acc'; newer releases use 'val_accuracy' — confirm for the installed version.
filepath = os.path.join(ckpt_dir,
    "weights-improvement-{epoch:02d}-{val_acc:.6f}.hdf5")
checkpoint = ModelCheckpoint(
  filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max'
)
print("Saving improvement checkpoints to \n\t{0}".format(filepath))

# early stop callback, given a bit more leeway: stop after 25 epochs without
# an improvement of at least 1e-6 in the monitored quantity. No monitor is
# given, so the Keras default (val_loss) is used.
stahp = EarlyStopping(min_delta=0.000001, patience=25)
Done
define model
Python



Finally, let's take a look at our model, with both a text summary and a flow chart.

from keras.utils import plot_model
# Render the layer graph (requires pydot + graphviz) into the results
# directory, top-to-bottom, with layer names and tensor shapes.
plot_model(model, to_file=os.path.join(res_dir, "model.svg"),
           show_layer_names=True, show_shapes=True, rankdir="TB")
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()
Done
model analysis
Python



































2.
Training

Now we're ready to train using the GPU. We'll set up two code branches: one to do some serious long-term training (takes hours), and one which just runs a few additional epochs as an example. Note that both act as endpoints to their inheritance branches, because the underlying system that allows inheritance does not currently know how to handle GPU activity. So, the training cells will save their results to /results/ (with checkpoints under /work/checkpoints/), and then for analysis and visualization we'll just need to load that data. We'll also pickle the training history for the long run to a file, in case we want to take a look at that.

Both paths are able to load the most accurate of any existing checkpoints in /work/checkpoints—this would allow for additional refinement of accuracy, although such runs would probably require removal or relaxation of the EarlyStopping callback.

2.1.
Long Training

# NOTE(review): `settings.epochs_longrun` uses the hosting platform's
# cell-reference syntax (the value defined in the "settings" cell), not plain Python.
epochs = settings.epochs_longrun

# Resume from the best existing checkpoint, if any (last_ckpt returns ""
# when the checkpoint directory has none).
cpf = last_ckpt(ckpt_dir)
if cpf != "":
  print("Loading starting weights from \n\t{0}".format(cpf))
  model.load_weights(cpf)

# Fit the model on the batches generated by datagen.flow().
# Each epoch runs len(x_train) // batch_size augmented batches (any final
# partial batch is skipped). Validation uses the raw, un-augmented test arrays.
# The checkpoint callback saves improvements, and EarlyStopping (stahp) may end
# training before `epochs` is reached.
hist = model.fit_generator(datagen.flow(x_train, y_train,
    batch_size=batch_size,shuffle=True),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs,verbose=2,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint,stahp])

# Save model and weights
model.save(model_path)
#print('Saved trained model at %s ' % model_path)

# Persist the per-epoch metric history (hist.history) for the
# training-analysis plots.
with open(hist_path, 'wb') as f:
  pickle.dump(hist.history, f)
Locked
long training
Python


























































































































































































































































































































































































































































































































































































































































































































































































































































convnet_cifar10.kerashist
Open in new window
convnet_cifar10.kerasave
Open in new window

2.2.
Short Example

# NOTE(review): `settings.epochs_shortrun` uses the hosting platform's
# cell-reference syntax (the value defined in the "settings" cell).
epochs = settings.epochs_shortrun

# load results of long training run
# NOTE(review): `long training.convnet_cifar10.kerasave` is a platform file
# reference to the model saved by the long-training cell, not Python syntax.
model.load_weights(long training.convnet_cifar10.kerasave)

# Fit the model on the batches generated by datagen.flow().
# A few extra augmented epochs starting from the long run's weights. Only the
# checkpoint callback is used here, so there is no early stopping.
hist = model.fit_generator(datagen.flow(x_train, y_train,
    batch_size=batch_size),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs,verbose=2,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint])

# Save model and weights
# NOTE(review): same model_path as the long run; this relies on the two
# branches producing separate result sets.
model.save(model_path)
print('Saved trained model at %s ' % model_path)
Done
short training
Python

















convnet_cifar10.kerasave
Open in new window

3.
Results

Alrighty, now we can take a look at the trained model. The load_model() function will give us back our full, trained model for evaluation and prediction.

model = load_model(short training.convnet_cifar10.kerasave)

# Evaluate model with test data set
evaluation = model.evaluate_generator(datagen.flow(x_test, y_test,
    batch_size=batch_size, shuffle=False),
    steps=x_test.shape[0] // batch_size, workers=4)

# Print out final values of all metrics
key2name = {'acc':'Accuracy', 'loss':'Loss', 
    'val_acc':'Validation Accuracy', 'val_loss':'Validation Loss'}
results = []
for i,key in enumerate(model.metrics_names):
    results.append('%s = %.2f' % (key2name[key], evaluation[i]))
print(", ".join(results))
Done
results analysis
Python


We can sample the prediction results with images.

num_predictions = 36

model = load_model(short training.convnet_cifar10.kerasave)

predict_gen = model.predict_generator(datagen.flow(x_test, y_test,
    batch_size=batch_size, shuffle=False),
    steps=(x_test.shape[0] // batch_size)+1, workers=4)

indices = [np.random.choice(range(len(x_test))) 
           for i in range(num_predictions)]

cifar_grid(x_test,y_test,indices,6, predictions=predict_gen)
Done
prediction examples
Python

And hey, let's take a look at the training history (we'll look at the long training so it's an interesting history).

import matplotlib
import matplotlib.pyplot as plt

# Load the metric history pickled by the long-training cell.
# NOTE(review): `long training.convnet_cifar10.kerashist` is a platform file
# reference, not Python syntax.
with open(long training.convnet_cifar10.kerashist, 'rb') as f:
  hist = pickle.load(f)

# Human-readable panel titles for the history keys plotted below.
key2name = {'acc':'Accuracy', 'loss':'Loss', 
    'val_acc':'Validation Accuracy', 'val_loss':'Validation Loss'}

fig = plt.figure()

# One panel per metric in a 2x2 grid, each plotted against epoch index.
things = ['acc','loss','val_acc','val_loss']
for i,thing in enumerate(things):
  trace = hist[thing]
  plt.subplot(2,2,i+1)
  plt.plot(range(len(trace)),trace)
  plt.title(key2name[thing])

fig.set_tight_layout(True)
fig  # last expression in the cell: displays the figure
Done
training analysis
Python

Finally, let's see if our trained network can correctly identify the subject of an uploaded image. This is the Internet, so it must be a cat.

from keras import backend as K
sess = K.get_session()

model = load_model(short training.convnet_cifar10.kerasave)

img = tf.read_file(Qat.jpg)
img = tf.image.decode_jpeg(img, channels=3)
img.set_shape([None, None, 3])
img = tf.image.resize_images(img, (32, 32))
img = img.eval(session=sess) # convert to numpy array
img = np.expand_dims(img, 0) # make 'batch' of 1

pred = model.predict(img)
pred = labels["label_names"][np.argmax(pred)]
Done
custom analysis
Python

Magic 8-ball says this image contains a cat. Huzzah!