Draft from Oct 03 2017, 23:05.

Image Classification with TFLearn

A TFLearn/Tensorflow Deep Residual Network Applied to the CIFAR-10 Dataset

Micah Dombrowski

The TFLearn project is a higher-level API on top of Tensorflow, allowing for faster and more natural construction of neural networks, as well as easy training, evaluation, and prediction. In this article, we'll use TFLearn to classify the CIFAR-10 dataset.

1. Setup

To run TFLearn we'll need the tensorflow-gpu backend, as well as the Python HDF5 package (h5py) and SciPy. We'll get TFLearn and a pickle alternative (dill) via pip, and also upgrade matplotlib for some visualization.

conda install -qy -c jjh_cio_testing/label/in_defaults \
    tensorflow-gpu \
    h5py \
    scipy
pip install --upgrade tflearn dill matplotlib
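
Before moving on, it can be worth confirming that the packages can actually be found. The helper below is a small sketch and not part of the original setup: `missing_packages` is a name made up here, it assumes Python 3 (for `importlib.util.find_spec`), and it only checks that each module can be located, not that the GPU backend works.

```python
import importlib.util

def missing_packages(names):
    """Return the top-level module names from `names` that cannot be located."""
    return [name for name in names
            if importlib.util.find_spec(name) is None]

# Import names (not conda/pip package names) used in this article.
required = ["tensorflow", "tflearn", "h5py", "scipy", "dill", "matplotlib"]
print("Missing:", missing_packages(required))
```

An empty list means every module was found; anything listed needs another look at the install step.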

The CIFAR-10 dataset is a collection of 60,000 color, 32x32-pixel images in ten classes, 10,000 of which are in a test batch. TFLearn can automatically download the dataset, but let's save some time by doing it ourselves and uploading the file to the article's permanent storage.
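
To make those numbers concrete, here is a small sketch of how a single image is laid out in the Python version of the dataset: each image is a flat row of 3,072 values (1,024 red, then 1,024 green, then 1,024 blue, each plane in row-major order). The name `row_to_image` is made up for illustration; the arrays we get from TFLearn's loader are already in 32x32x3 form, so this is only to show what happens under the hood.

```python
WIDTH = HEIGHT = 32
CHANNELS = 3
PLANE = WIDTH * HEIGHT  # 1,024 values per color channel

def row_to_image(row):
    """Convert a flat 3,072-value row into image[y][x][channel]."""
    if len(row) != PLANE * CHANNELS:
        raise ValueError("expected %d values" % (PLANE * CHANNELS))
    return [[[row[c * PLANE + y * WIDTH + x] for c in range(CHANNELS)]
             for x in range(WIDTH)]
            for y in range(HEIGHT)]

# 60,000 images in total, 10,000 of them held out for testing.
n_total, n_test = 60000, 10000
n_train = n_total - n_test
print(n_train, "training images of", PLANE * CHANNELS, "values each")
```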

cifar-10-python.tar.gz

Then we'll copy the file to our local filesystem and extract it so that TFLearn will pick it up.

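# Later cells read the data from /cifar, so extract it there.
# Assumes the uploaded tarball sits in the current working directory.
mkdir -p /cifar
cp cifar-10-python.tar.gz /cifar/cifar-10-python.tar.gz
cd /cifar
tar -zvxf cifar-10-python.tar.gz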

# Residual blocks
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 5
batch_size = 1000
num_classes = 10
epochs_shortrun = 5
epochs_longrun = 200
save_dir = "/files"
res_dir = "/results"
model_name = 'resnet_cifar10'
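
The comment at the top of these settings lists three depths for three values of n. They follow a 6n + 2 rule: three stages of n residual blocks, two convolutions per block, plus the first convolution and the final fully connected layer. The sketch below just checks that arithmetic; `resnet_depth` is a name introduced here, and the two-convolutions-per-block count is an assumption about the block type used.

```python
def resnet_depth(n_blocks):
    """Weighted layers: 3 stages * n blocks * 2 convs, plus first conv and final dense layer."""
    return 3 * n_blocks * 2 + 2

for n_blocks in (5, 9, 18):
    print("n = %d -> %d layers" % (n_blocks, resnet_depth(n_blocks)))
```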

from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.datasets import cifar10
import tensorflow as tf

import os, sys, tarfile, glob, shutil
import dill as pickle
import numpy as np

# set random seeds for reproducibility
tf.reset_default_graph()
tflearn.init_graph(seed=343)
tf.set_random_seed(343)
np.random.seed(343)

# makes Tensorflow shush about SSE and such
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Data loading
(X, Y), (testX, testY) = cifar10.load_data(dirname="/cifar")
Y = tflearn.data_utils.to_categorical(Y, num_classes)
testY = tflearn.data_utils.to_categorical(testY, num_classes)

save_fn = model_name + ".tfsave"
save_file = os.path.join(save_dir, save_fn)

ckpt_dir = os.path.join(save_dir,"checkpoints")
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

tblog_dir = os.path.join(save_dir,"tflogs")
if not os.path.isdir(tblog_dir):
    os.makedirs(tblog_dir)
event_dir = os.path.join(tblog_dir,model_name)

# Function to find latest checkpoint file
def last_ckpt(dir):
  fl = os.listdir(dir)
  fl = [x for x in fl if x.endswith(".index")]
  cf = ""
  if len(fl) > 0:
    steps = [float(x.split("-")[1][0:-6]) for x in fl]
    m = max(steps)
    cf = fl[steps.index(m)]
    cf = os.path.join(dir,cf)
  
  return(cf)

import matplotlib.pyplot as plt
from math import ceil

with open("/cifar/cifar-10-batches-py/batches.meta", 'rb') as fo:
    labels = pickle.load(fo)

# Visualize CIFAR-10: takes a list of indices and shows the images in a grid
def cifar_grid(X,Y,inds,n_col, predictions=None):
  if predictions is not None:
    if Y.shape != predictions.shape:
      print("Predictions must equal Y in length!")
      return(None)
  N = len(inds)
  n_row = int(ceil(1.0*N/n_col))
  # squeeze=False keeps axes two-dimensional even for a single row or column
  fig, axes = plt.subplots(n_row,n_col,figsize=(10,10),squeeze=False)
  
  clabels = labels["label_names"]
  for j in range(n_row):
    for k in range(n_col):
      i_inds = j*n_col+k
      
      axes[j][k].set_axis_off()
      if i_inds < N:
        # look up the data index only for grid cells that hold an image
        i_data = inds[i_inds]
        axes[j][k].imshow(X[i_data,...], interpolation='nearest')
        label = clabels[np.argmax(Y[i_data,...])]
        axes[j][k].set_title(label)
        if predictions is not None:
          pred = clabels[np.argmax(predictions[i_data,...])]
          if label != pred:
            # wrong prediction: show the predicted label in red
            axes[j][k].set_title(pred, color='red')
  
  fig.set_tight_layout(True)
  return(fig)

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True, 
    mean=[ 0.49139968, 0.48215841, 0.44653091 ])

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
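
The two augmentation steps can be pictured with a small pure-Python sketch. It works on a single-channel image (a list of rows) for brevity, assumes the padding is filled with zeros, and uses made-up helper names `pad_image` and `random_crop_flip`; it illustrates the idea rather than reproducing TFLearn's implementation.

```python
import random

def pad_image(img, pad):
    """Zero-pad a 2-D image (list of rows) by `pad` pixels on every side."""
    width = len(img[0])
    blank = [0] * (width + 2 * pad)
    rows = [[0] * pad + list(row) + [0] * pad for row in img]
    return ([list(blank) for _ in range(pad)] + rows +
            [list(blank) for _ in range(pad)])

def random_crop_flip(img, pad=4, rng=random):
    """Pad, crop back to the original size at a random offset, then maybe mirror."""
    height, width = len(img), len(img[0])
    padded = pad_image(img, pad)
    top = rng.randint(0, 2 * pad)    # randint bounds are inclusive
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + width] for row in padded[top:top + height]]
    if rng.random() < 0.5:           # left-right flip half of the time
        crop = [row[::-1] for row in crop]
    return crop

image = [[y * 32 + x for x in range(32)] for y in range(32)]
augmented = random_crop_flip(image, pad=4, rng=random.Random(343))
print(len(augmented), "x", len(augmented[0]))
```

The output keeps the 32x32 size, so the network input shape does not change; only the content shifts by up to four pixels in each direction.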

Let's look at a random selection of the dataset images.

indices = [np.random.choice(range(len(X))) for i in range(36)]

cifar_grid(X,Y,indices,6)
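
The grid function recovers each class name by taking the argmax of a one-hot label row, which is the form `to_categorical` produced when the data was loaded. A minimal pure-Python sketch of that round trip follows; the helper names `one_hot` and `argmax` are ours, and the class list is the standard CIFAR-10 ordering rather than something read from `batches.meta`.

```python
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

def one_hot(label, num_classes):
    """Return a row of zeros with a single 1.0 at position `label`."""
    row = [0.0] * num_classes
    row[label] = 1.0
    return row

def argmax(values):
    """Index of the largest entry (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

encoded = one_hot(3, len(CLASS_NAMES))
print(encoded, "->", CLASS_NAMES[argmax(encoded)])
```

The same argmax is applied to the network's softmax output later on, which is why predictions and labels can be compared class name to class name.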

The model we're using is a deep residual network design taken from the TFLearn example scripts.

# Building Residual Network
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, n, 16)
net = tflearn.residual_block(net, 1, 32, downsample=True)
net = tflearn.residual_block(net, n-1, 32)
net = tflearn.residual_block(net, 1, 64, downsample=True)
net = tflearn.residual_block(net, n-1, 64)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, num_classes, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
                         loss='categorical_crossentropy')

# define the early-stop callback
class EarlyStoppingCallback(tflearn.callbacks.Callback):
    def __init__(self, val_loss_thresh, val_loss_patience):
        """ minimum loss improvement setup """
        self.val_loss_thresh = val_loss_thresh
        self.val_loss_last = float('inf')
        self.val_loss_patience = val_loss_patience
        self.val_loss_squint = 0
    
    def on_batch_end(self, training_state, snapshot=False):
        """ loss improvement threshold w/ patience """
        # Apparently this can happen.
        if training_state.val_loss is None: return
        
        if (self.val_loss_last - training_state.val_loss) < self.val_loss_thresh:
          # unacceptable!
          if self.val_loss_squint >= self.val_loss_patience:
            raise StopIteration
          else:
            self.val_loss_squint += 1
        else:
          # we good again - reset
          self.val_loss_last = training_state.val_loss
          self.val_loss_squint = 0
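
To see how the threshold and patience interact, the callback's logic can be replayed on a plain list of validation losses, one entry per callback invocation. This is a sketch for building intuition only: `stop_step` is a name invented here and the loss values are made up.

```python
def stop_step(val_losses, thresh, patience):
    """Index at which the callback logic would raise StopIteration, or None."""
    last = float("inf")
    squint = 0
    for step, loss in enumerate(val_losses):
        if loss is None:          # no validation result yet
            continue
        if (last - loss) < thresh:
            if squint >= patience:
                return step       # patience exhausted
            squint += 1
        else:
            last = loss           # real improvement: reset the counter
            squint = 0
    return None

losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6]
print(stop_step(losses, thresh=0.001, patience=2))
```

With a patience of 2 the run stops on the third consecutive check that fails to improve on the best recorded loss, so the late drop to 0.6 is never seen.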

2. Training

Now we're ready to train using the GPU. Note that a GPU-using cell acts as an endpoint to its inheritance branch, because the underlying system that allows inheritance does not currently know how to handle GPU activity. So, the training cells will save their results to /files/, and then for analysis and visualization we'll just need to load that data.

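Checkpoints go to /files/checkpoints, and a later run could resume from the most recent one (the loading code below is commented out for now). This would allow for additional refinement of accuracy, although such runs would probably require removal or relaxation of the EarlyStopping callback.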

# Initialize model
ckpt_file = os.path.join(ckpt_dir,"model.ckpt")
model = tflearn.DNN(net, checkpoint_path=ckpt_file,
                    max_checkpoints=10, clip_gradients=0.,
                    tensorboard_dir=tblog_dir,tensorboard_verbose=0)

# disabled until directories can be written to /results
#cff = last_ckpt(ckpt_dir)
#if cff != "":
#  print("Loading ",cff,"...")
#  model.load(cff)

# Initialize our callback.
early_stopping_cb = EarlyStoppingCallback(
  val_loss_thresh=0.001, val_loss_patience=25)

print("Starting to train...")

# checkpoints disabled until directories can be written to /results
try:
  model.fit(X, Y, n_epoch=epochs_longrun, validation_set=(testX, testY),
     snapshot_epoch=True, snapshot_step=None,
     show_metric=True, batch_size=batch_size, shuffle=True,
     run_id='resnet_cifar10',
     callbacks=early_stopping_cb)
except StopIteration:
    print("Got bored, stopping early.")

print("Training complete.")

model.save(save_file)

# can only save single files to /results, so let's tar the saves
tar_file = os.path.join(res_dir,model_name)+".tar.bz2"
with tarfile.open(tar_file, "w:bz2") as tar:
  for name in [x for x in os.listdir(save_dir) 
               if x.startswith(save_fn)]:
    tar.add(os.path.join(save_dir, name), arcname=name)

# copy the most recent events file to /results for history plotting
evfiles = sorted(filter(os.path.isfile, glob.glob(os.path.join(event_dir, 
                                           "events.out.tfevents.*"))),
                 key=os.path.getmtime)  # sorted() also accepts Python 3's lazy filter
shutil.copyfile(evfiles[-1],  # glob already returns full paths
                os.path.join(res_dir,model_name+".tfevents"))
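
A quick back-of-the-envelope check on the training schedule, using the settings from earlier. This is a sketch that assumes the optimizer's step counter advances once per minibatch and that the staircase decay multiplies the rate by `lr_decay` every `decay_step` steps; the function names are ours.

```python
import math

def steps_per_epoch(n_samples, batch_size):
    """Number of minibatches needed to cover the training set once."""
    return int(math.ceil(n_samples / float(batch_size)))

def staircase_lr(step, base_lr=0.1, lr_decay=0.1, decay_step=32000):
    """Learning rate after `step` minibatches under staircase exponential decay."""
    return base_lr * lr_decay ** (step // decay_step)

per_epoch = steps_per_epoch(50000, 1000)   # 50,000 training images, batch_size
total = per_epoch * 200                    # epochs_longrun
print(per_epoch, total, staircase_lr(total))
```

Under these assumptions a full 200-epoch run takes 10,000 steps, so the first decay at step 32,000 would not be reached with a batch size of 1,000; a smaller batch size or a smaller `decay_step` would change that.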

resnet_cifar10.tar.bz2
resnet_cifar10.tfevents

3. Results

Alrighty, now we can take a look at the trained model. We'll use the definitions from earlier, initialize the model, then load weights from a save file for evaluation and prediction.

# Initialize model
model = tflearn.DNN(net, clip_gradients=0.)

# unpack the saved weights that the training cell wrote to /results
tar_file = os.path.join(res_dir,model_name)+".tar.bz2"
with tarfile.open(tar_file, "r:bz2") as tar:
  tar.extractall("./")

if (os.path.isfile(save_fn+".index") and
    os.path.isfile(save_fn+".meta") and
    os.path.isfile(save_fn+".data-00000-of-00001")):
  print("Loading {0}...".format(save_fn))
  model.load(save_fn)
else:
  cff = last_ckpt(ckpt_dir)
  if cff != "":
    print("Loading {0}...".format(cff))
    model.load(cff)
  else:
    raise ValueError("No results files found!")

acc = model.evaluate(testX, testY)

# While we've got it set up, run predictions 
# for the test batch and save to file
predY = model.predict(testX)
with open("/results/test_predictions.dat","wb") as f:
  pickle.dump(predY,f)

print("Average accuracy: {:.2f}%.".format(acc[0]*100))
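
The same accuracy figure can be recomputed from the saved predictions by comparing the argmax of each prediction row with the argmax of the matching one-hot label. A minimal sketch on made-up toy data (the `accuracy` helper is ours, not a TFLearn function):

```python
def accuracy(pred_probs, one_hot_labels):
    """Fraction of rows whose highest-probability class matches the label."""
    def argmax(values):
        return max(range(len(values)), key=lambda i: values[i])
    hits = sum(1 for p, y in zip(pred_probs, one_hot_labels)
               if argmax(p) == argmax(y))
    return hits / float(len(one_hot_labels))

# Toy data: three classes, four samples, one wrong prediction (row 3).
preds = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.4, 0.5, 0.1], [0.3, 0.3, 0.4]]
truth = [[1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]]
print("Accuracy: {:.2f}%".format(accuracy(preds, truth) * 100))
```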

test_predictions.dat

We can sample the prediction results with images. This cell can't inherit state from the prior cell, but that cell has to be run first to create our data file.

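# predictions were written to /results by the evaluation cell
with open("/results/test_predictions.dat","rb") as f:
  predY = pickle.load(f)

indices = [np.random.choice(range(len(testX))) for i in range(36)]

# show the random sample; wrong predictions are titled in red
cifar_grid(testX,testY,indices,6, predictions=predY)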

TFLearn tells Tensorflow to write to tfevents files during training. These are geared towards the interactive analysis and visualization program tensorboard, but with a bit of work we can pull the training history out of these files and plot it ourselves.

from tensorflow.tensorboard.backend.event_processing import event_accumulator

# need to copy out to get .tfevents extension, because...raisins
shutil.copy(os.path.join(res_dir,model_name+".tfevents"),"/tmp/hist.tfevents")

ea = event_accumulator.EventAccumulator("/tmp/hist.tfevents",
  size_guidance={ # 0 means keep every scalar event instead of downsampling
  event_accumulator.SCALARS: 0
})

ea.Reload() # loads events from file

# fiddly stuff to inspect tags/scalar data entries
#print(ea.Tags())
#print([x for x in ea.Tags()['scalars'] if not x.startswith("Momentum")])

# pull out four metrics, plot
hist = {
  'Accuracy' : [x.value for x in ea.Scalars('Accuracy')],
  'Validation Accuracy' : [x.value for x in 
                           ea.Scalars('Accuracy/Validation')],
  'Loss' : [x.value for x in ea.Scalars('Loss')],
  'Validation Loss' : [x.value for x in ea.Scalars('Loss/Validation')]
}

import matplotlib
import matplotlib.pyplot as plt

fig = plt.figure()
keys = ['Accuracy', 'Loss', 'Validation Accuracy', 'Validation Loss']
for i,thing in enumerate(keys):
  trace = hist[thing]
  plt.subplot(2,2,i+1)
  plt.plot(range(len(trace)),trace)
  plt.title(thing)

fig.set_tight_layout(True)
fig
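
Per-step traces pulled from an events file can be noisy, so it is sometimes easier to read the trend after light smoothing. The helper below is an optional sketch, not something the plot above uses: `moving_average` is a made-up name and the window size is arbitrary.

```python
def moving_average(trace, window):
    """Trailing mean over up to `window` most recent points."""
    smoothed = []
    running = 0.0
    for i, value in enumerate(trace):
        running += value
        if i >= window:
            running -= trace[i - window]   # drop the point leaving the window
        smoothed.append(running / min(i + 1, window))
    return smoothed

print(moving_average([4.0, 2.0, 3.0, 1.0], window=2))
```

It could be applied to any of the lists in the history dictionary before plotting, for example the training loss, at the cost of slightly lagging the raw curve.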