Image Classification with Keras
A Keras/Tensorflow Convolutional Network Applied to the CIFAR-10 Dataset
In this article you'll learn how to train a neural network to classify images.

The Internet demands cat pictures. Many humans can recognize a cat fairly easily, but can we train computers to do so? Well, let's ask our Magic 8-Ball — the classifier trained below — what the above image contains:

```python
exec(open("functions.py").read())
import keras, tensorflow

sess = keras.backend.get_session()
model = keras.models.load_model("convnet_cifar10.kerasave")
_, _, _, _, labels = setup_load_cifar()

# load the photo and shrink it to the 32x32 size the network expects
img = tensorflow.read_file("Qat.jpg")
img = tensorflow.image.decode_jpeg(img, channels=3)
img.set_shape([None, None, 3])
img = tensorflow.image.resize_images(img, (32, 32))
img = img.eval(session=sess)  # convert to numpy array
img = np.expand_dims(img, 0)  # make 'batch' of 1

pred = model.predict(img)
pred = labels["label_names"][np.argmax(pred)]
pred
```

To accomplish this grand innovation we've made use of the high-level Keras API over the Tensorflow framework.
The Keras project is a high-level Python neural network API. It's designed to be both user friendly and modular, supporting multiple backends. The default Keras backend is Tensorflow, a symbolic math library which is widely used for machine learning and neural network tasks. We'll be training our Keras/Tensorflow setup to classify the CIFAR-10 image dataset, which is 10% cat pictures.
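If you want to confirm which backend your installation is driving, Keras exposes it directly. A quick check, assuming a stock Keras 2 / TensorFlow 1.x install like the environment used below:

```python
import keras
import tensorflow as tf

print(keras.backend.backend())  # prints "tensorflow" when TF is the active backend
print(tf.__version__)           # the TensorFlow version Keras is running on
```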
Setup
For our environment, we'll simply transclude the Nextjournal default Tensorflow reusable environment, which includes Keras and everything else we need.
The CIFAR-10 dataset is a collection of 60,000 color, 32x32-pixel images in ten classes, 10,000 of which are in a test batch. Keras can automatically download the dataset, but we'll save time by uploading it, and mounting the file to the right place in the runtime settings.
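For reference, the automatic download is a one-liner — on first call Keras fetches the archive into ~/.keras/datasets, which is exactly the step our uploaded copy short-circuits:

```python
from keras.datasets import cifar10

# Downloaded and cached on first call; later calls load from the cache.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, x_test.shape)  # (50000, 32, 32, 3), (10000, 32, 32, 3)
```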
A lot of initialization code will be needed multiple times, so we'll put it in its own runtime as functions, and then import the desired code cells where we need it.
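Concretely, each importing cell runs the shared file's source in its own namespace; a minimal sketch of the pattern (the file names correspond to the exported cells used throughout):

```python
# Runs the shared cell's source in this runtime's namespace,
# making its functions and variables available here.
exec(open("functions.py").read())
```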
We'll run 128-image batches and set up two training runs: a long, 500-epoch run to do the main work, and a short, 5-epoch run as an example.
```python
batch_size = 128
num_classes = 10
epochs_shortrun = 5
epochs_longrun = 500
random_seed = 343
save_dir = "/work"
res_dir = "/results"
model_name = "convnet_cifar10"

# setup paths
import os
ckpt_dir = os.path.join(save_dir, "checkpoints")
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
model_path = os.path.join(res_dir, model_name + ".kerasave")
hist_path = os.path.join(res_dir, model_name + ".kerashist")
```

Load the data and get it into a reasonable shape. Also set up a function to find the best checkpoint file, another to give us a look at the images we're analyzing, and finally set up to do real-time input-data augmentation.
exec(open("settings.py").read())import numpy as npimport dill as picklefrom math import *def setup_tf(seed): import os,random,numpy as np,tensorflow as tf tf.reset_default_graph() # set random seeds for reproducibility os.environ['PYTHONHASHSEED']=str(seed) random.seed(seed) np.random.seed(seed) tf.set_random_seed(seed) def setup_load_cifar(verbose=False): import os,shutil from keras.datasets import cifar10 from keras.utils import to_categorical # The data, shuffled and split between train and test sets: (x_train, y_train), (x_test, y_test) = cifar10.load_data() if verbose: print("x_train shape: {}, {} train samples, {} test samples.\n".format( x_train.shape, x_train.shape[0], x_test.shape[0])) # Convert class vectors to binary class matrices. y_train = to_categorical(y_train, num_classes) y_test = to_categorical(y_test, num_classes) x_train = x_train.astype("float32") x_test = x_test.astype("float32") x_train /= 255.0 x_test /= 255.0 # Load label names to use in prediction results label_list_path = "datasets/cifar-10-batches-py/batches.meta" keras_dir = os.path.expanduser(os.path.join("~", ".keras")) datadir_base = os.path.expanduser(keras_dir) if not os.access(datadir_base, os.W_OK): datadir_base = os.path.join("/tmp", ".keras") label_list_path = os.path.join(datadir_base, label_list_path) with open(label_list_path, mode="rb") as f: labels = pickle.load(f) return x_train, y_train, x_test, y_test, labelsdef setup_data_aug(): print("Using real-time data augmentation.\n") # This will do preprocessing and realtime data augmentation: from keras.preprocessing.image import ImageDataGenerator datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset samplewise_center=False, # set each sample mean to 0 featurewise_std_normalization=False, # divide inputs by std of the dataset samplewise_std_normalization=False, # divide each input by its std zca_whitening=False, # apply ZCA whitening rotation_range=0, # randomly rotate images in the range # (degrees, 0 to 180) width_shift_range=0.1, # randomly shift images horizontally # (fraction of total width) height_shift_range=0.1, # randomly shift images vertically # (fraction of total height) horizontal_flip=True, # randomly flip images vertical_flip=False # randomly flip images ) return datagen# Function to find latest checkpoint filedef last_ckpt(dir): fl = os.listdir(dir) fl = [x for x in fl if x.endswith(".hdf5")] cf = "" if len(fl) > 0: accs = [float(x.split("-")[3][0:-5]) for x in fl] m = max(accs) iaccs = [i for i, j in enumerate(accs) if j == m] fl = [fl[x] for x in iaccs] epochs = [int(x.split("-")[2]) for x in fl] cf = fl[epochs.index(max(epochs))] cf = os.path.join(dir,cf) return cf#Visualizing CIFAR 10, takes indicides and shows in a griddef cifar_grid(X,Y,inds,n_col,labels,predictions=None): import matplotlib.pyplot as plt if predictions is not None: if Y.shape != predictions.shape: print("Predictions must equal Y in length!\n") return(None) N = len(inds) n_row = int(ceil(1.0*N/n_col)) fig, axes = plt.subplots(n_row,n_col,figsize=(10,10)) clabels = labels["label_names"] for j in range(n_row): for k in range(n_col): i_inds = j*n_col+k i_data = inds[i_inds] axes[j][k].set_axis_off() if i_inds < N: axes[j][k].imshow(X[i_data,...], interpolation="nearest") label = clabels[np.argmax(Y[i_data,...])] axes[j][k].set_title(label) if predictions is not None: pred = clabels[np.argmax(predictions[i_data,...])] if label != pred: label += " n" axes[j][k].set_title(pred, color="red") fig.set_tight_layout(True) return 
figLet's take a gander at a random selection of training images.
exec(open("functions.py").read())x_train, y_train, x_test, y_test, labels = setup_load_cifar(verbose=True)indices = [np.random.choice(range(len(x_train))) for i in range(36)]cifar_grid(x_train,y_train,indices,6,labels)We'll use a simple convolutional network model (still under development), with the addition of the data augmentation defined above, and a checkpoint-writing callback that's keyed to significant accuracy improvements.
```python
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from keras.layers.pooling import MaxPooling2D
from keras.callbacks import ModelCheckpoint, EarlyStopping

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation="relu",
                 input_shape=x_train.shape[1:]))
model.add(Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, kernel_size=(3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1024, activation="relu"))
model.add(Dropout(0.5))
model.add(Dense(10, activation="softmax"))

# initiate Adam optimizer
opt = Adam(lr=0.0001, decay=1e-6)

# Train the model using the Adam optimizer
model.compile(loss="categorical_crossentropy",
              optimizer=opt,
              metrics=["accuracy"])

# checkpoint callback
filepath = os.path.join(ckpt_dir,
                        "weights-improvement-{epoch:02d}-{val_acc:.6f}.hdf5")
checkpoint = ModelCheckpoint(filepath, monitor="val_acc", verbose=1,
                             save_best_only=True, mode="max")
print("Saving improvement checkpoints to \n\t{0}".format(filepath))

# early stop callback, given a bit more leeway
stahp = EarlyStopping(min_delta=0.00001, patience=25)
```

Finally, let's take a look at our model, with both a text summary and a flow chart.
exec(open("model.py").read())from keras.utils import plot_modelplot_model(model, to_file="/results/model.svg", show_layer_names=True, show_shapes=True, rankdir="TB")print(model.summary())
Training
Now we're ready to train using the GPU. We'll put this in a separate runtime configured to use a dedicated GPU compute node. This will have an initialization cell, and then two training cells: one to do some serious long-term training (takes hours), and one which just runs a few additional epochs as an example.
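Before committing to a multi-hour run, it's worth confirming that the runtime actually sees the GPU. A quick check, assuming TensorFlow 1.x as in the environment above:

```python
import tensorflow as tf

# True when a CUDA-capable device is visible to this TensorFlow build
print(tf.test.is_gpu_available())
print(tf.test.gpu_device_name())  # e.g. "/device:GPU:0", or "" if none found
```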
exec(open("functions.py").read())#os.environ["CUDA_VISIBLE_DEVICES"] = "" # for testingos.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"x_train, y_train, x_test, y_test, labels = setup_load_cifar(verbose=True)setup_tf(random_seed)datagen = setup_data_aug()# Compute quantities required for feature-wise normalization# (std, mean, and principal components if ZCA whitening is applied).datagen.fit(x_train)exec(open("model.py").read()) The training cells will save their models and weights to /results, and then for analysis and visualization we'll just need to load that data. We'll also pickle the training history for the long run to /results so we can take a look at that.
Long Training
Keep this cell locked to prevent spurious re-runs (the full run takes about 2.5 hours on a K80).
```python
epochs = epochs_longrun

cpf = last_ckpt(ckpt_dir)
if cpf != "":
    print("Loading starting weights from \n\t{0}".format(cpf))
    model.load_weights(cpf)

# Fit the model on the batches generated by datagen.flow().
hist = model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs, verbose=2,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint, stahp])

# Save model and weights
model.save(model_path)
#print('Saved trained model at %s ' % model_path)

with open(hist_path, 'wb') as f:
    pickle.dump(hist.history, f)
```

Short Example
```python
epochs = epochs_shortrun

# load results of long training run
model.load_weights("convnet_cifar10.kerasave")

# Fit the model on the batches generated by datagen.flow().
hist = model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=batch_size),
    steps_per_epoch=x_train.shape[0] // batch_size,
    epochs=epochs, verbose=2,
    validation_data=(x_test, y_test),
    workers=4, callbacks=[checkpoint])

# Save model and weights
model.save(model_path)
print('Saved trained model at %s ' % model_path)
```

Results
Alrighty, now we can take a look at the trained model. The load_model() function will give us back our full, trained model for evaluation and prediction.
exec(open("functions.py").read())from keras.models import load_modelx_train, y_train, x_test, y_test, labels = setup_load_cifar()datagen = setup_data_aug()datagen.fit(x_train)model = load_model(convnet_cifar10.kerasave)# Evaluate model with test data setevaluation = model.evaluate_generator(datagen.flow(x_test, y_test, batch_size=batch_size, shuffle=False), steps=x_test.shape[0] // batch_size, workers=4)# Print out final values of all metricskey2name = {'acc':'Accuracy', 'loss':'Loss', 'val_acc':'Validation Accuracy', 'val_loss':'Validation Loss'}results = []for i,key in enumerate(model.metrics_names): results.append('%s = %.2f' % (key2name[key], evaluation[i]))print(", ".join(results))We can sample the prediction results with images.
```python
num_predictions = 36

model = load_model("convnet_cifar10.kerasave")
predict_gen = model.predict_generator(
    datagen.flow(x_test, y_test, batch_size=batch_size, shuffle=False),
    steps=(x_test.shape[0] // batch_size) + 1, workers=4)

indices = [np.random.choice(range(len(x_test))) for i in range(num_predictions)]
cifar_grid(x_test, y_test, indices, 6, labels, predictions=predict_gen)
```
And hey, let's take a look at the training history (we'll look at the long training so it's an interesting history).

```python
import matplotlib.pyplot as plt

with open("convnet_cifar10.kerashist", 'rb') as f:
    hist = pickle.load(f)

key2name = {'acc': 'Accuracy', 'loss': 'Loss',
            'val_acc': 'Validation Accuracy', 'val_loss': 'Validation Loss'}

fig = plt.figure()
things = ['acc', 'loss', 'val_acc', 'val_loss']
for i, thing in enumerate(things):
    trace = hist[thing]
    plt.subplot(2, 2, i + 1)
    plt.plot(range(len(trace)), trace)
    plt.title(key2name[thing])
fig.set_tight_layout(True)
fig
```

Finally, we're back to where we started! Now we can test our trained network against new images. Going back to our cat photo at the top...

```python
import tensorflow, keras

sess = keras.backend.get_session()
model = keras.models.load_model("convnet_cifar10.kerasave")
_, _, _, _, labels = setup_load_cifar()

img = tensorflow.read_file("Qat.jpg")
img = tensorflow.image.decode_jpeg(img, channels=3)
img.set_shape([None, None, 3])
img = tensorflow.image.resize_images(img, (32, 32))
img = img.eval(session=sess)  # convert to numpy array
img = np.expand_dims(img, 0)  # make 'batch' of 1

pred = model.predict(img)
pred = labels["label_names"][np.argmax(pred)]
pred
```

...and once again, the printed label assures us of exactly what we've acquired.