Nick Doiron / Jul 03 2019
Remix of Python by Nextjournal

Hindu-Arabic MNIST and XAI

Heavily based on Seldon.IO's Alibi

and Arabic Handwritten Digits Database https://www.kaggle.com/mloey1/ahdd1/

from the American University in Cairo http://datacenter.aucegypt.edu/shazeem/

import sys; sys.version.split()[0]
# -> '3.6.8'
# Environment setup (run in a shell): pip install keras alibi matplotlib numpy tensorflow
import keras
from keras import backend as K
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input, UpSampling2D
from keras.models import Model, load_model
from keras.utils import to_categorical
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi.explainers import CEM

import os, csv
# Nextjournal file-reference cells collapsed to plain paths: the AHDD
# training CSVs are expected to sit next to this script.
TRAIN_LABELS_CSV = 'csvTrainLabel60k.csv'
TRAIN_IMAGES_CSV = 'csvTrainImages60k.csv'

# Read the first 20,000 training labels (one integer per CSV row).
labels = []
with open(TRAIN_LABELS_CSV) as label_file:
    csv_reader = csv.reader(label_file, delimiter=',')
    for row in csv_reader:
        if len(labels) >= 20000:
            break
        labels.append(int(row[0]))
def _row_to_image(row):
    """Rebuild one 28x28 image (nested lists, values scaled to [0, 1])
    from a flat CSV row of 784 pixel strings.

    Kept byte-for-byte equivalent to the original notebook logic: the
    AHDD CSVs apparently need this shifted reconstruction, where the
    first 27 values seed rows 1..27 and every following run of 28
    values adds one pixel to each of the 28 rows.
    """
    pixels = [[]]
    current_row = 0
    current_col = 0
    for pixel in row:
        if current_col == 0:
            pixels.append([int(pixel) / 255])
        else:
            pixels[current_row].append(int(pixel) / 255)
        current_row += 1
        # first pass rolls over after 27 pixels, later passes after 28
        if current_row == 28 or (current_col == 0 and current_row == 27):
            current_col += 1
            current_row = 0
    return pixels

# Decode the first 20,000 training images.
imagers = []
with open('csvTrainImages60k.csv') as images_list:
    csv_r = csv.reader(images_list, delimiter=',')
    for row in csv_r:
        if len(imagers) >= 20000:
            break
        imagers.append(_row_to_image(row))
imagers2 = np.asarray(imagers)
#print(imagers2[4])
# Preview one decoded training digit in grayscale.
plt.clf()
plt.gray()
plt.imshow(imagers2[3])
plt.gcf()  # Nextjournal displays the returned figure
# Rescale pixel values into [xmin, xmax] = [-0.5, 0.5] (the range the
# explainers below use) and one-hot encode the labels.
xmin, xmax = -.5, .5
x_train = imagers2.reshape(imagers2.shape + (1,))
lo, hi = x_train.min(), x_train.max()
x_train = (x_train - lo) / (hi - lo) * (xmax - xmin) + xmin
y_train = to_categorical(np.asarray(labels))
def cnn_model():
    """Build and compile the conv-net classifier (28x28x1 -> 10-way softmax).

    Two Conv/MaxPool/Dropout stages followed by a 256-unit dense layer.
    Wrapped in a factory function for consistency with sc_model() and
    ae_model() defined later in this file.
    """
    x_in = Input(shape=(28, 28, 1))
    x = Conv2D(filters=64, kernel_size=2, padding='same', activation='relu')(x_in)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Conv2D(filters=32, kernel_size=2, padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_out = Dense(10, activation='softmax')(x)

    cnn = Model(inputs=x_in, outputs=x_out)
    cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return cnn

cnn = cnn_model()
cnn.summary()
cnn.fit(x_train, y_train, batch_size=64, epochs=5, verbose=0)
cnn.save('mnist_cnn.h5')  # reloaded later by the CEM / counterfactual sections
# Nextjournal file references for the 10k-image test split.
TEST_LABELS_CSV = 'csvTestLabel10k.csv'
TEST_IMAGES_CSV = 'csvTestImages10k.csv'

# Read every test label (one integer per CSV row).
testlabels = []
with open(TEST_LABELS_CSV) as test_label_file:
    csv_reader = csv.reader(test_label_file, delimiter=',')
    for row in csv_reader:
        testlabels.append(int(row[0]))
def _parse_image_row(row):
    """Decode one flat CSV row of 784 pixel strings into a 28x28 nested
    list scaled to [0, 1] — the same shifted reconstruction the training
    loader uses (first 27 values seed rows 1..27, then 28 at a time)."""
    pixels = [[]]
    current_row = 0
    current_col = 0
    for pixel in row:
        if current_col == 0:
            pixels.append([int(pixel) / 255])
        else:
            pixels[current_row].append(int(pixel) / 255)
        current_row += 1
        if current_row == 28 or (current_col == 0 and current_row == 27):
            current_col += 1
            current_row = 0
    return pixels

# Decode all test images, then normalize exactly like the training set.
testimgs = []
with open('csvTestImages10k.csv') as test_images_file:
    csv_r = csv.reader(test_images_file, delimiter=',')
    for row in csv_r:
        testimgs.append(_parse_image_row(row))
testimgs2 = np.asarray(testimgs)
x_test = np.reshape(testimgs2, testimgs2.shape + (1,))
x_test = ((x_test - x_test.min()) / (x_test.max() - x_test.min())) * (xmax - xmin) + xmin
testlabels2 = np.asarray(testlabels)
y_test = to_categorical(np.asarray(testlabels2))
cnn.evaluate(x_test, y_test, verbose=0)
# -> [0.05234822061315062, 0.9844]  (test loss, test accuracy)

TrustScore

from alibi.confidence import TrustScore

# Fit the trust-score estimator on the CNN's softmax outputs for the
# training set, over all 10 digit classes.
ts = TrustScore()
x_train_cnn = cnn.predict(x_train)
ts.fit(x_train_cnn, y_train, classes=10)

Making a simpler model

def sc_model():
    """Build a minimal softmax classifier (flatten -> dense 10).

    This is the deliberately simple model whose test predictions get
    trust-scored against the CNN.
    """
    inputs = Input(shape=(28, 28, 1))
    flat = Flatten()(inputs)
    outputs = Dense(10, activation='softmax')(flat)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model
sc = sc_model()
sc.summary()
# Quick training run for the simple classifier.
sc.fit(x_train, y_train, batch_size=128, epochs=5, verbose=0)
# -> <keras.callba...x7f939ee20b00>  (the History object returned by fit)
sc.evaluate(x_test, y_test, verbose=0)
# -> [0.5244448517799377, 0.9001]  (test loss, test accuracy — weaker than the CNN)
# Trust-score the simple model's test predictions (k=5 neighbors), then
# pick the n lowest- and highest-scoring examples for display.
x_test_cnn = cnn.predict(x_test)
y_pred = sc.predict(x_test)
score, closest_class = ts.score(x_test_cnn, y_pred, k=5)

n = 5
order = np.argsort(score)
idx_min, idx_max = order[:n], order[-n:]
score_min, score_max = score[idx_min], score[idx_max]
closest_min, closest_max = closest_class[idx_min], closest_class[idx_max]
pred_min = np.argmax(y_pred[idx_min], axis=1)
pred_max = np.argmax(y_pred[idx_max], axis=1)
imgs_min, imgs_max = x_test[idx_min], x_test[idx_max]
label_min = np.argmax(y_test[idx_min], axis=1)
label_max = np.argmax(y_test[idx_max], axis=1)

Worst Trust Scores (all labels)

# Show the n test digits with the lowest trust scores.
plt.clf()
plt.figure(figsize=(20, 6))
for col in range(n):
    axis = plt.subplot(1, n, col + 1)
    plt.imshow(imgs_min[col].reshape(28, 28))
    plt.title('Model prediction: {} \n Label: {} \n Trust score: {:.3f}\n Closest other class: {}'.format(
        pred_min[col], label_min[col], score_min[col], closest_min[col]))
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)
plt.gcf()

Best Trust Scores (all labels)

# Show the n test digits with the highest trust scores.
plt.figure(figsize=(20, 4))
for col in range(n):
    axis = plt.subplot(1, n, col + 1)
    plt.imshow(imgs_max[col].reshape(28, 28))
    plt.title('Model prediction: {} \n Label: {} \n Trust score: {:.3f}'.format(
        pred_max[col], label_max[col], score_max[col]))
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)
plt.gcf()

Worst Trust Scores (one label)

Let's focus on only one label

# Sanity check that y_test rows are one-hot: y_test[1][1] == 1 evaluates to True.
# Collect every test example whose true label is myDigit.
myDigit = 5
myDigitImages = []
myDigitLabels = []
for index in range(len(y_test)):
    # one-hot row -> integer label. Replaces the original hand-rolled
    # scan (which also pointlessly incremented the for-loop variable).
    label = int(np.argmax(y_test[index]))
    if label == myDigit:
        myDigitLabels.append(myDigit)
        myDigitImages.append(x_test[index])
x_test_dig = np.asarray(myDigitImages)
# NOTE(review): every label here is 5, so to_categorical infers only 6
# columns; downstream code only argmax-es these rows, so that is harmless.
y_test_dig = to_categorical(np.asarray(myDigitLabels))
# Trust-score only the single-digit subset and select the n extremes.
x_test_dig_cnn = cnn.predict(x_test_dig)
y_pred_dig = sc.predict(x_test_dig)
score, closest_class = ts.score(x_test_dig_cnn, y_pred_dig, k=5)

n = 5
order = np.argsort(score)
idx_min, idx_max = order[:n], order[-n:]
score_min, score_max = score[idx_min], score[idx_max]
closest_min, closest_max = closest_class[idx_min], closest_class[idx_max]
pred_min = np.argmax(y_pred_dig[idx_min], axis=1)
pred_max = np.argmax(y_pred_dig[idx_max], axis=1)
imgs_min, imgs_max = x_test_dig[idx_min], x_test_dig[idx_max]
label_min = np.argmax(y_test_dig[idx_min], axis=1)
label_max = np.argmax(y_test_dig[idx_max], axis=1)
# Plot the least-trusted examples within the chosen digit.
plt.clf()
plt.figure(figsize=(20, 6))
for col in range(n):
    axis = plt.subplot(1, n, col + 1)
    plt.imshow(imgs_min[col].reshape(28, 28))
    plt.title('Model prediction: {} \n Label: {} \n Trust score: {:.3f}\n Closest other class: {}'.format(
        pred_min[col], label_min[col], score_min[col], closest_min[col]))
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)
plt.gcf()

Autoencoder (for pertinent positive and negative)

def ae_model():
    """Build a small convolutional autoencoder (28x28x1 in and out),
    compiled with MSE loss; used by CEM as its reconstruction model."""
    inp = Input(shape=(28, 28, 1))

    # encoder: two 16-filter convs, pooled down to a 14x14x1 code
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(inp)
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
    h = MaxPooling2D((2, 2), padding='same')(h)
    encoded = Conv2D(1, (3, 3), activation=None, padding='same')(h)

    # decoder: mirror of the encoder, upsampled back to 28x28x1
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
    h = UpSampling2D((2, 2))(h)
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
    decoded = Conv2D(1, (3, 3), activation=None, padding='same')(h)

    autoencoder = Model(inp, decoded)
    autoencoder.compile(optimizer='adam', loss='mse')
    return autoencoder

ae = ae_model()
ae.summary()
# Train the autoencoder to reconstruct the training digits.
ae.fit(x_train, x_train, batch_size=128, epochs=10, validation_data=(x_test, x_test), verbose=0)
ae.save('mnist_ae.h5')

Sample Autoencoder Outputs

# Compare a few originals (top row) with their reconstructions (bottom row).
plt.clf()
decoded_imgs = ae.predict(x_test)
n = 5
plt.figure(figsize=(20, 4))
for col in range(1, n + 1):
    # original digit
    axis = plt.subplot(2, n, col)
    plt.imshow(x_test[col].reshape(28, 28))
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)
    # its reconstruction
    axis = plt.subplot(2, n, col + n)
    plt.imshow(decoded_imgs[col].reshape(28, 28))
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)
plt.show()
plt.gcf()

The CEM explainer

# Pick one test instance and configure CEM (hyperparameter values mirror
# alibi's CEM example; see the alibi docs for each knob's meaning).
idx = 8
X = x_test[idx].reshape((1,) + x_test[idx].shape)
mode = 'PN'  # start with a pertinent negative
shape = (1,) + x_train.shape[1:]
kappa = 0.
beta = .1
gamma = 100
c_init = 1.
c_steps = 10
max_iterations = 1000
feature_range = (x_train.min(),x_train.max())
clip = (-1000.,1000.)
lr = 1e-2
no_info_val = -1.
# Fresh TF session shared with Keras for the explainer's graph work.
sess = tf.Session()
K.set_session(sess)
sess.run(tf.global_variables_initializer())

# Reload the saved models into this session, run CEM, then tear the
# session down before plotting.
cnn = load_model('mnist_cnn.h5')
ae = load_model('mnist_ae.h5')
cem = CEM(sess, cnn, mode, shape, kappa=kappa, beta=beta, feature_range=feature_range, 
          gamma=gamma, ae_model=ae, max_iterations=max_iterations, 
          c_init=c_init, c_steps=c_steps, learning_rate_init=lr, clip=clip, no_info_val=no_info_val)
explanation = cem.explain(X, verbose=False)

sess.close()
K.clear_session()
plt.clf()
# Show the unmodified instance and the CNN's prediction for it.
print('Original instance prediction: {}'.format(explanation['X_pred']))
plt.imshow(explanation['X'].reshape(28, 28))
plt.gcf()

Pertinent Negative

plt.clf()
# CEM's pertinent-negative result for the instance (mode is 'PN' here).
print('Pertinent negative prediction: {}'.format(explanation[mode + '_pred']))
plt.imshow(explanation[mode].reshape(28, 28))
plt.gcf()

Pertinent Positive

# Repeat the CEM run in 'PP' (pertinent positive) mode with a new session.
mode = 'PP'
# initialize TensorFlow session before model definition
sess = tf.Session()
K.set_session(sess)
sess.run(tf.global_variables_initializer())

# define models
cnn = load_model('mnist_cnn.h5')
ae = load_model('mnist_ae.h5')

# initialize CEM explainer and explain instance
cem = CEM(sess, cnn, mode, shape, kappa=kappa, beta=beta, feature_range=feature_range, 
          gamma=gamma, ae_model=ae, max_iterations=max_iterations, 
          c_init=c_init, c_steps=c_steps, learning_rate_init=lr, clip=clip, no_info_val=no_info_val)
explanation = cem.explain(X, verbose=False)

sess.close()
K.clear_session()
plt.clf()
print('Pertinent positive prediction: {}'.format(explanation[mode + '_pred']))
plt.imshow(explanation[mode].reshape(28, 28))
plt.gcf()

Counterfactual

def ae_model():
    """Build the autoencoder split into separate encoder and decoder
    models (CounterFactualProto needs the encoder on its own).

    Returns (autoencoder, encoder, decoder); only the end-to-end
    autoencoder is compiled (MSE reconstruction loss).
    """
    # encoder: 28x28x1 -> 14x14x1
    enc_in = Input(shape=(28, 28, 1))
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(enc_in)
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
    h = MaxPooling2D((2, 2), padding='same')(h)
    enc_out = Conv2D(1, (3, 3), activation=None, padding='same')(h)
    encoder = Model(enc_in, enc_out)

    # decoder: 14x14x1 -> 28x28x1
    dec_in = Input(shape=(14, 14, 1))
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(dec_in)
    h = UpSampling2D((2, 2))(h)
    h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
    dec_out = Conv2D(1, (3, 3), activation=None, padding='same')(h)
    decoder = Model(dec_in, dec_out)

    # end-to-end model chains the two
    autoencoder = Model(enc_in, decoder(encoder(enc_in)))
    autoencoder.compile(optimizer='adam', loss='mse')
    return autoencoder, encoder, decoder
ae, enc, dec = ae_model()
ae.summary()
# Short training run; the encoder is saved separately because the
# counterfactual explainer loads it on its own below.
ae.fit(x_train, x_train, batch_size=128, epochs=4, validation_data=(x_test, x_test), verbose=1)
ae.save('mnist_ae.h5')
enc.save('mnist_enc.h5')
plt.clf()
# Instance to build a counterfactual for.
X = x_test[52].reshape((1,) + x_test[52].shape)
plt.imshow(X.reshape(28, 28))
plt.gcf()
from alibi.explainers import CounterFactualProto

# CounterFactualProto hyperparameters (gamma/theta weight the autoencoder
# and prototype loss terms -- see the alibi docs for details).
shape = (1,) + x_train.shape[1:]
gamma = 100.
theta = 100.
c_init = 1.
c_steps = 2
max_iterations = 500
feature_range = (x_train.min(),x_train.max())

# set random seed
np.random.seed(1)
tf.set_random_seed(1)

# define models
cnn = load_model('mnist_cnn.h5')
ae = load_model('mnist_ae.h5')
enc = load_model('mnist_enc.h5')

sess = K.get_session()

# initialize explainer, fit and generate counterfactual
cf = CounterFactualProto(sess, cnn, shape, gamma=gamma, theta=theta,
                         ae_model=ae, enc_model=enc, max_iterations=max_iterations,
                         feature_range=feature_range, c_init=c_init, c_steps=c_steps)
cf.fit(x_train)  # find class prototypes
explanation = cf.explain(X)

sess.close()
K.clear_session()
plt.clf()
# Display the counterfactual image and the class it was pushed toward.
print('Counterfactual prediction: {}'.format(explanation['cf']['class']))
print('Closest prototype class: {}'.format(cf.id_proto))
plt.imshow(explanation['cf']['X'].reshape(28, 28))
plt.gcf()