An Achromatic Approach to Compressing CNN Filters Using Pattern-specific Receptive Fields

Team Members: Guanzhong Chen, Shiyu Liu, Guansu (Frances) Niu, Cangcheng Tang, Zhi Wang

GitHub Repo: https://github.com/tangcc35/Canned_Pineapple

Screencast: https://youtu.be/rAFYVpmmK1Q

Final Blog Post: bit.ly/CNN_Compression

Introduction

In image recognition, many datasets consist of gray-scale pictures, such as chest radiographs. Current practice is to train on such images with models that were essentially designed for color pictures, such as DenseNet, which introduces many redundant parameters. This project therefore aimed to develop a methodology for modifying models trained on color images so that they can be applied to gray-scale images.

import tensorflow as tf
import keras
from keras.datasets import cifar10
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, Flatten, Input, AveragePooling2D, Activation, SpatialDropout2D
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, SeparableConv2D
from keras.layers import Concatenate
from keras.optimizers import Adam, RMSprop, SGD
from keras import regularizers
from keras import backend as K
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
from sklearn.preprocessing import MinMaxScaler
from skimage.metrics import structural_similarity as ssim
from sklearn.cluster import DBSCAN, SpectralClustering, KMeans
from sklearn.metrics import mean_squared_error
from scipy import spatial
import re
import math
mmscaler = MinMaxScaler()

In order to change the architecture of the DenseNet model to fit gray-scale images, this project visualizes the filters in each conv layer. The original model and dataset are therefore loaded at the beginning.

1. Load Original DenseNet Model

# Hyperparameters
batch_size = 128
num_classes = 10
epochs = 100
l = 12
num_filter = 36 #added 24 more filters
compression = 0.5 
dropout_rate = 0.2
img_height, img_width, channel = 32, 32, 3
# Dense Block
# removed the dropout
def add_denseblock(input, num_filter = 12, dropout_rate = 0.2):
    global compression
    temp = input
    for _ in range(l):
        BatchNorm = BatchNormalization()(temp)
        relu = Activation('relu')(BatchNorm)
        Conv2D_3_3 = Conv2D(int(num_filter*compression), (3,3), use_bias=False ,padding='same')(relu)
        #if dropout_rate > 0:
        #    Conv2D_3_3 = SpatialDropout2D(dropout_rate)(Conv2D_3_3)
        concat = Concatenate(axis=-1)([temp,Conv2D_3_3])
        
        temp = concat
        
    return temp
def add_transition(input, num_filter = 12, dropout_rate = 0.2):
    global compression
    BatchNorm = BatchNormalization()(input)
    relu = Activation('relu')(BatchNorm)
    Conv2D_BottleNeck = Conv2D(int(num_filter*compression), (1,1), use_bias=False, kernel_regularizer = regularizers.l1() ,padding='same')(relu)
    #if dropout_rate > 0:
    #    Conv2D_BottleNeck = SpatialDropout2D(dropout_rate)(Conv2D_BottleNeck)
    avg = AveragePooling2D(pool_size=(2,2))(Conv2D_BottleNeck)
    
    return avg
# converted the last Dense layer to a fully convolutional network, as use of Dense layers was prohibited
def output_layer(input):
    global compression
    BatchNorm = BatchNormalization()(input)
    relu = Activation('relu')(BatchNorm)
    AvgPooling = AveragePooling2D(pool_size=(2,2))(relu)
    temp = Conv2D(num_classes, kernel_size = (2,2))(AvgPooling)
    output = Activation('softmax')(temp)
    flat = Flatten()(output)
    
    return flat
num_filter = 36
dropout_rate = 0.2
l = 12
input = Input(shape=(img_height, img_width, channel,))
First_Conv2D = Conv2D(num_filter, (3,3), use_bias=False ,padding='same')(input)
First_Block = add_denseblock(First_Conv2D, num_filter, dropout_rate)
First_Transition = add_transition(First_Block, num_filter, dropout_rate)
Second_Block = add_denseblock(First_Transition, num_filter, dropout_rate)
Second_Transition = add_transition(Second_Block, num_filter, dropout_rate)
Third_Block = add_denseblock(Second_Transition, num_filter, dropout_rate)
Third_Transition = add_transition(Third_Block, num_filter, dropout_rate)
Last_Block = add_denseblock(Third_Transition,  num_filter, dropout_rate)
output = output_layer(Last_Block)
model = Model(inputs=[input], outputs=[output])
model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) (None, 32, 32, 3) 0 __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 32, 36) 972 input_1[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 32, 32, 36) 144 conv2d_1[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 32, 32, 36) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 32, 32, 18) 5832 activation_1[0][0] __________________________________________________________________________________________________ concatenate_1 (Concatenate) (None, 32, 32, 54) 0 conv2d_1[0][0] conv2d_2[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 32, 32, 54) 216 concatenate_1[0][0] __________________________________________________________________________________________________ activation_2 (Activation) (None, 32, 32, 54) 0 batch_normalization_2[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 32, 32, 18) 8748 activation_2[0][0] __________________________________________________________________________________________________ concatenate_2 (Concatenate) (None, 32, 32, 72) 0 concatenate_1[0][0] conv2d_3[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 32, 32, 72) 288 concatenate_2[0][0] __________________________________________________________________________________________________ activation_3 (Activation) (None, 32, 32, 72) 0 batch_normalization_3[0][0] __________________________________________________________________________________________________ conv2d_4 (Conv2D) (None, 32, 32, 18) 11664 activation_3[0][0] __________________________________________________________________________________________________ concatenate_3 (Concatenate) (None, 32, 32, 90) 0 concatenate_2[0][0] conv2d_4[0][0] __________________________________________________________________________________________________ batch_normalization_4 (BatchNor (None, 32, 32, 90) 360 concatenate_3[0][0] __________________________________________________________________________________________________ activation_4 (Activation) (None, 32, 32, 90) 0 batch_normalization_4[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 32, 32, 18) 14580 activation_4[0][0] __________________________________________________________________________________________________ concatenate_4 (Concatenate) (None, 32, 32, 108) 0 concatenate_3[0][0] conv2d_5[0][0] __________________________________________________________________________________________________ batch_normalization_5 (BatchNor (None, 32, 32, 108) 432 concatenate_4[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 32, 32, 108) 0 
batch_normalization_5[0][0] __________________________________________________________________________________________________ conv2d_6 (Conv2D) (None, 32, 32, 18) 17496 activation_5[0][0] __________________________________________________________________________________________________ concatenate_5 (Concatenate) (None, 32, 32, 126) 0 concatenate_4[0][0] conv2d_6[0][0] __________________________________________________________________________________________________ batch_normalization_6 (BatchNor (None, 32, 32, 126) 504 concatenate_5[0][0] __________________________________________________________________________________________________ activation_6 (Activation) (None, 32, 32, 126) 0 batch_normalization_6[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 32, 32, 18) 20412 activation_6[0][0] __________________________________________________________________________________________________ concatenate_6 (Concatenate) (None, 32, 32, 144) 0 concatenate_5[0][0] conv2d_7[0][0] __________________________________________________________________________________________________ batch_normalization_7 (BatchNor (None, 32, 32, 144) 576 concatenate_6[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 32, 32, 144) 0 batch_normalization_7[0][0] __________________________________________________________________________________________________ conv2d_8 (Conv2D) (None, 32, 32, 18) 23328 activation_7[0][0] __________________________________________________________________________________________________ concatenate_7 (Concatenate) (None, 32, 32, 162) 0 concatenate_6[0][0] conv2d_8[0][0] __________________________________________________________________________________________________ batch_normalization_8 (BatchNor (None, 32, 32, 162) 648 concatenate_7[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 32, 32, 162) 0 batch_normalization_8[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 32, 32, 18) 26244 activation_8[0][0] __________________________________________________________________________________________________ concatenate_8 (Concatenate) (None, 32, 32, 180) 0 concatenate_7[0][0] conv2d_9[0][0] __________________________________________________________________________________________________ batch_normalization_9 (BatchNor (None, 32, 32, 180) 720 concatenate_8[0][0] __________________________________________________________________________________________________ activation_9 (Activation) (None, 32, 32, 180) 0 batch_normalization_9[0][0] __________________________________________________________________________________________________ conv2d_10 (Conv2D) (None, 32, 32, 18) 29160 activation_9[0][0] __________________________________________________________________________________________________ concatenate_9 (Concatenate) (None, 32, 32, 198) 0 concatenate_8[0][0] conv2d_10[0][0] __________________________________________________________________________________________________ batch_normalization_10 (BatchNo (None, 32, 32, 198) 792 concatenate_9[0][0] __________________________________________________________________________________________________ activation_10 (Activation) (None, 32, 32, 198) 0 batch_normalization_10[0][0] 
__________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 32, 32, 18) 32076 activation_10[0][0] __________________________________________________________________________________________________ concatenate_10 (Concatenate) (None, 32, 32, 216) 0 concatenate_9[0][0] conv2d_11[0][0] __________________________________________________________________________________________________ batch_normalization_11 (BatchNo (None, 32, 32, 216) 864 concatenate_10[0][0] __________________________________________________________________________________________________ activation_11 (Activation) (None, 32, 32, 216) 0 batch_normalization_11[0][0] __________________________________________________________________________________________________ conv2d_12 (Conv2D) (None, 32, 32, 18) 34992 activation_11[0][0] __________________________________________________________________________________________________ concatenate_11 (Concatenate) (None, 32, 32, 234) 0 concatenate_10[0][0] conv2d_12[0][0] __________________________________________________________________________________________________ batch_normalization_12 (BatchNo (None, 32, 32, 234) 936 concatenate_11[0][0] __________________________________________________________________________________________________ activation_12 (Activation) (None, 32, 32, 234) 0 batch_normalization_12[0][0] __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 32, 32, 18) 37908 activation_12[0][0] __________________________________________________________________________________________________ concatenate_12 (Concatenate) (None, 32, 32, 252) 0 concatenate_11[0][0] conv2d_13[0][0] __________________________________________________________________________________________________ batch_normalization_13 (BatchNo (None, 32, 32, 252) 1008 concatenate_12[0][0] __________________________________________________________________________________________________ activation_13 (Activation) (None, 32, 32, 252) 0 batch_normalization_13[0][0] __________________________________________________________________________________________________ conv2d_14 (Conv2D) (None, 32, 32, 18) 4536 activation_13[0][0] __________________________________________________________________________________________________ average_pooling2d_1 (AveragePoo (None, 16, 16, 18) 0 conv2d_14[0][0] __________________________________________________________________________________________________ batch_normalization_14 (BatchNo (None, 16, 16, 18) 72 average_pooling2d_1[0][0] __________________________________________________________________________________________________ activation_14 (Activation) (None, 16, 16, 18) 0 batch_normalization_14[0][0] __________________________________________________________________________________________________ conv2d_15 (Conv2D) (None, 16, 16, 18) 2916 activation_14[0][0] __________________________________________________________________________________________________ concatenate_13 (Concatenate) (None, 16, 16, 36) 0 average_pooling2d_1[0][0] conv2d_15[0][0] __________________________________________________________________________________________________ batch_normalization_15 (BatchNo (None, 16, 16, 36) 144 concatenate_13[0][0] __________________________________________________________________________________________________ activation_15 (Activation) (None, 16, 16, 36) 0 batch_normalization_15[0][0] 
__________________________________________________________________________________________________ conv2d_16 (Conv2D) (None, 16, 16, 18) 5832 activation_15[0][0] __________________________________________________________________________________________________ concatenate_14 (Concatenate) (None, 16, 16, 54) 0 concatenate_13[0][0] conv2d_16[0][0] __________________________________________________________________________________________________ batch_normalization_16 (BatchNo (None, 16, 16, 54) 216 concatenate_14[0][0] __________________________________________________________________________________________________ activation_16 (Activation) (None, 16, 16, 54) 0 batch_normalization_16[0][0] __________________________________________________________________________________________________ conv2d_17 (Conv2D) (None, 16, 16, 18) 8748 activation_16[0][0] __________________________________________________________________________________________________ concatenate_15 (Concatenate) (None, 16, 16, 72) 0 concatenate_14[0][0] conv2d_17[0][0] __________________________________________________________________________________________________ batch_normalization_17 (BatchNo (None, 16, 16, 72) 288 concatenate_15[0][0] __________________________________________________________________________________________________ activation_17 (Activation) (None, 16, 16, 72) 0 batch_normalization_17[0][0] __________________________________________________________________________________________________ conv2d_18 (Conv2D) (None, 16, 16, 18) 11664 activation_17[0][0] __________________________________________________________________________________________________ concatenate_16 (Concatenate) (None, 16, 16, 90) 0 concatenate_15[0][0] conv2d_18[0][0] __________________________________________________________________________________________________ batch_normalization_18 (BatchNo (None, 16, 16, 90) 360 concatenate_16[0][0] __________________________________________________________________________________________________ activation_18 (Activation) (None, 16, 16, 90) 0 batch_normalization_18[0][0] __________________________________________________________________________________________________ conv2d_19 (Conv2D) (None, 16, 16, 18) 14580 activation_18[0][0] __________________________________________________________________________________________________ concatenate_17 (Concatenate) (None, 16, 16, 108) 0 concatenate_16[0][0] conv2d_19[0][0] __________________________________________________________________________________________________ batch_normalization_19 (BatchNo (None, 16, 16, 108) 432 concatenate_17[0][0] __________________________________________________________________________________________________ activation_19 (Activation) (None, 16, 16, 108) 0 batch_normalization_19[0][0] __________________________________________________________________________________________________ conv2d_20 (Conv2D) (None, 16, 16, 18) 17496 activation_19[0][0] __________________________________________________________________________________________________ concatenate_18 (Concatenate) (None, 16, 16, 126) 0 concatenate_17[0][0] conv2d_20[0][0] __________________________________________________________________________________________________ batch_normalization_20 (BatchNo (None, 16, 16, 126) 504 concatenate_18[0][0] __________________________________________________________________________________________________ activation_20 (Activation) (None, 16, 16, 126) 0 batch_normalization_20[0][0] 
__________________________________________________________________________________________________ conv2d_21 (Conv2D) (None, 16, 16, 18) 20412 activation_20[0][0] __________________________________________________________________________________________________ concatenate_19 (Concatenate) (None, 16, 16, 144) 0 concatenate_18[0][0] conv2d_21[0][0] __________________________________________________________________________________________________ batch_normalization_21 (BatchNo (None, 16, 16, 144) 576 concatenate_19[0][0] __________________________________________________________________________________________________ activation_21 (Activation) (None, 16, 16, 144) 0 batch_normalization_21[0][0] __________________________________________________________________________________________________ conv2d_22 (Conv2D) (None, 16, 16, 18) 23328 activation_21[0][0] __________________________________________________________________________________________________ concatenate_20 (Concatenate) (None, 16, 16, 162) 0 concatenate_19[0][0] conv2d_22[0][0] __________________________________________________________________________________________________ batch_normalization_22 (BatchNo (None, 16, 16, 162) 648 concatenate_20[0][0] __________________________________________________________________________________________________ activation_22 (Activation) (None, 16, 16, 162) 0 batch_normalization_22[0][0] __________________________________________________________________________________________________ conv2d_23 (Conv2D) (None, 16, 16, 18) 26244 activation_22[0][0] __________________________________________________________________________________________________ concatenate_21 (Concatenate) (None, 16, 16, 180) 0 concatenate_20[0][0] conv2d_23[0][0] __________________________________________________________________________________________________ batch_normalization_23 (BatchNo (None, 16, 16, 180) 720 concatenate_21[0][0] __________________________________________________________________________________________________ activation_23 (Activation) (None, 16, 16, 180) 0 batch_normalization_23[0][0] __________________________________________________________________________________________________ conv2d_24 (Conv2D) (None, 16, 16, 18) 29160 activation_23[0][0] __________________________________________________________________________________________________ concatenate_22 (Concatenate) (None, 16, 16, 198) 0 concatenate_21[0][0] conv2d_24[0][0] __________________________________________________________________________________________________ batch_normalization_24 (BatchNo (None, 16, 16, 198) 792 concatenate_22[0][0] __________________________________________________________________________________________________ activation_24 (Activation) (None, 16, 16, 198) 0 batch_normalization_24[0][0] __________________________________________________________________________________________________ conv2d_25 (Conv2D) (None, 16, 16, 18) 32076 activation_24[0][0] __________________________________________________________________________________________________ concatenate_23 (Concatenate) (None, 16, 16, 216) 0 concatenate_22[0][0] conv2d_25[0][0] __________________________________________________________________________________________________ batch_normalization_25 (BatchNo (None, 16, 16, 216) 864 concatenate_23[0][0] __________________________________________________________________________________________________ activation_25 (Activation) (None, 16, 16, 216) 0 batch_normalization_25[0][0] 
__________________________________________________________________________________________________ conv2d_26 (Conv2D) (None, 16, 16, 18) 34992 activation_25[0][0] __________________________________________________________________________________________________ concatenate_24 (Concatenate) (None, 16, 16, 234) 0 concatenate_23[0][0] conv2d_26[0][0] __________________________________________________________________________________________________ batch_normalization_26 (BatchNo (None, 16, 16, 234) 936 concatenate_24[0][0] __________________________________________________________________________________________________ activation_26 (Activation) (None, 16, 16, 234) 0 batch_normalization_26[0][0] __________________________________________________________________________________________________ conv2d_27 (Conv2D) (None, 16, 16, 18) 4212 activation_26[0][0] __________________________________________________________________________________________________ average_pooling2d_2 (AveragePoo (None, 8, 8, 18) 0 conv2d_27[0][0] __________________________________________________________________________________________________ batch_normalization_27 (BatchNo (None, 8, 8, 18) 72 average_pooling2d_2[0][0] __________________________________________________________________________________________________ activation_27 (Activation) (None, 8, 8, 18) 0 batch_normalization_27[0][0] __________________________________________________________________________________________________ conv2d_28 (Conv2D) (None, 8, 8, 18) 2916 activation_27[0][0] __________________________________________________________________________________________________ concatenate_25 (Concatenate) (None, 8, 8, 36) 0 average_pooling2d_2[0][0] conv2d_28[0][0] __________________________________________________________________________________________________ batch_normalization_28 (BatchNo (None, 8, 8, 36) 144 concatenate_25[0][0] __________________________________________________________________________________________________ activation_28 (Activation) (None, 8, 8, 36) 0 batch_normalization_28[0][0] __________________________________________________________________________________________________ conv2d_29 (Conv2D) (None, 8, 8, 18) 5832 activation_28[0][0] __________________________________________________________________________________________________ concatenate_26 (Concatenate) (None, 8, 8, 54) 0 concatenate_25[0][0] conv2d_29[0][0] __________________________________________________________________________________________________ batch_normalization_29 (BatchNo (None, 8, 8, 54) 216 concatenate_26[0][0] __________________________________________________________________________________________________ activation_29 (Activation) (None, 8, 8, 54) 0 batch_normalization_29[0][0] __________________________________________________________________________________________________ conv2d_30 (Conv2D) (None, 8, 8, 18) 8748 activation_29[0][0] __________________________________________________________________________________________________ concatenate_27 (Concatenate) (None, 8, 8, 72) 0 concatenate_26[0][0] conv2d_30[0][0] __________________________________________________________________________________________________ batch_normalization_30 (BatchNo (None, 8, 8, 72) 288 concatenate_27[0][0] __________________________________________________________________________________________________ activation_30 (Activation) (None, 8, 8, 72) 0 batch_normalization_30[0][0] 
__________________________________________________________________________________________________ conv2d_31 (Conv2D) (None, 8, 8, 18) 11664 activation_30[0][0] __________________________________________________________________________________________________ concatenate_28 (Concatenate) (None, 8, 8, 90) 0 concatenate_27[0][0] conv2d_31[0][0] __________________________________________________________________________________________________ batch_normalization_31 (BatchNo (None, 8, 8, 90) 360 concatenate_28[0][0] __________________________________________________________________________________________________ activation_31 (Activation) (None, 8, 8, 90) 0 batch_normalization_31[0][0] __________________________________________________________________________________________________ conv2d_32 (Conv2D) (None, 8, 8, 18) 14580 activation_31[0][0] __________________________________________________________________________________________________ concatenate_29 (Concatenate) (None, 8, 8, 108) 0 concatenate_28[0][0] conv2d_32[0][0] __________________________________________________________________________________________________ batch_normalization_32 (BatchNo (None, 8, 8, 108) 432 concatenate_29[0][0] __________________________________________________________________________________________________ activation_32 (Activation) (None, 8, 8, 108) 0 batch_normalization_32[0][0] __________________________________________________________________________________________________ conv2d_33 (Conv2D) (None, 8, 8, 18) 17496 activation_32[0][0] __________________________________________________________________________________________________ concatenate_30 (Concatenate) (None, 8, 8, 126) 0 concatenate_29[0][0] conv2d_33[0][0] __________________________________________________________________________________________________ batch_normalization_33 (BatchNo (None, 8, 8, 126) 504 concatenate_30[0][0] __________________________________________________________________________________________________ activation_33 (Activation) (None, 8, 8, 126) 0 batch_normalization_33[0][0] __________________________________________________________________________________________________ conv2d_34 (Conv2D) (None, 8, 8, 18) 20412 activation_33[0][0] __________________________________________________________________________________________________ concatenate_31 (Concatenate) (None, 8, 8, 144) 0 concatenate_30[0][0] conv2d_34[0][0] __________________________________________________________________________________________________ batch_normalization_34 (BatchNo (None, 8, 8, 144) 576 concatenate_31[0][0] __________________________________________________________________________________________________ activation_34 (Activation) (None, 8, 8, 144) 0 batch_normalization_34[0][0] __________________________________________________________________________________________________ conv2d_35 (Conv2D) (None, 8, 8, 18) 23328 activation_34[0][0] __________________________________________________________________________________________________ concatenate_32 (Concatenate) (None, 8, 8, 162) 0 concatenate_31[0][0] conv2d_35[0][0] __________________________________________________________________________________________________ batch_normalization_35 (BatchNo (None, 8, 8, 162) 648 concatenate_32[0][0] __________________________________________________________________________________________________ activation_35 (Activation) (None, 8, 8, 162) 0 batch_normalization_35[0][0] 
__________________________________________________________________________________________________ conv2d_36 (Conv2D) (None, 8, 8, 18) 26244 activation_35[0][0] __________________________________________________________________________________________________ concatenate_33 (Concatenate) (None, 8, 8, 180) 0 concatenate_32[0][0] conv2d_36[0][0] __________________________________________________________________________________________________ batch_normalization_36 (BatchNo (None, 8, 8, 180) 720 concatenate_33[0][0] __________________________________________________________________________________________________ activation_36 (Activation) (None, 8, 8, 180) 0 batch_normalization_36[0][0] __________________________________________________________________________________________________ conv2d_37 (Conv2D) (None, 8, 8, 18) 29160 activation_36[0][0] __________________________________________________________________________________________________ concatenate_34 (Concatenate) (None, 8, 8, 198) 0 concatenate_33[0][0] conv2d_37[0][0] __________________________________________________________________________________________________ batch_normalization_37 (BatchNo (None, 8, 8, 198) 792 concatenate_34[0][0] __________________________________________________________________________________________________ activation_37 (Activation) (None, 8, 8, 198) 0 batch_normalization_37[0][0] __________________________________________________________________________________________________ conv2d_38 (Conv2D) (None, 8, 8, 18) 32076 activation_37[0][0] __________________________________________________________________________________________________ concatenate_35 (Concatenate) (None, 8, 8, 216) 0 concatenate_34[0][0] conv2d_38[0][0] __________________________________________________________________________________________________ batch_normalization_38 (BatchNo (None, 8, 8, 216) 864 concatenate_35[0][0] __________________________________________________________________________________________________ activation_38 (Activation) (None, 8, 8, 216) 0 batch_normalization_38[0][0] __________________________________________________________________________________________________ conv2d_39 (Conv2D) (None, 8, 8, 18) 34992 activation_38[0][0] __________________________________________________________________________________________________ concatenate_36 (Concatenate) (None, 8, 8, 234) 0 concatenate_35[0][0] conv2d_39[0][0] __________________________________________________________________________________________________ batch_normalization_39 (BatchNo (None, 8, 8, 234) 936 concatenate_36[0][0] __________________________________________________________________________________________________ activation_39 (Activation) (None, 8, 8, 234) 0 batch_normalization_39[0][0] __________________________________________________________________________________________________ conv2d_40 (Conv2D) (None, 8, 8, 18) 4212 activation_39[0][0] __________________________________________________________________________________________________ average_pooling2d_3 (AveragePoo (None, 4, 4, 18) 0 conv2d_40[0][0] __________________________________________________________________________________________________ batch_normalization_40 (BatchNo (None, 4, 4, 18) 72 average_pooling2d_3[0][0] __________________________________________________________________________________________________ activation_40 (Activation) (None, 4, 4, 18) 0 batch_normalization_40[0][0] 
__________________________________________________________________________________________________ conv2d_41 (Conv2D) (None, 4, 4, 18) 2916 activation_40[0][0] __________________________________________________________________________________________________ concatenate_37 (Concatenate) (None, 4, 4, 36) 0 average_pooling2d_3[0][0] conv2d_41[0][0] __________________________________________________________________________________________________ batch_normalization_41 (BatchNo (None, 4, 4, 36) 144 concatenate_37[0][0] __________________________________________________________________________________________________ activation_41 (Activation) (None, 4, 4, 36) 0 batch_normalization_41[0][0] __________________________________________________________________________________________________ conv2d_42 (Conv2D) (None, 4, 4, 18) 5832 activation_41[0][0] __________________________________________________________________________________________________ concatenate_38 (Concatenate) (None, 4, 4, 54) 0 concatenate_37[0][0] conv2d_42[0][0] __________________________________________________________________________________________________ batch_normalization_42 (BatchNo (None, 4, 4, 54) 216 concatenate_38[0][0] __________________________________________________________________________________________________ activation_42 (Activation) (None, 4, 4, 54) 0 batch_normalization_42[0][0] __________________________________________________________________________________________________ conv2d_43 (Conv2D) (None, 4, 4, 18) 8748 activation_42[0][0] __________________________________________________________________________________________________ concatenate_39 (Concatenate) (None, 4, 4, 72) 0 concatenate_38[0][0] conv2d_43[0][0] __________________________________________________________________________________________________ batch_normalization_43 (BatchNo (None, 4, 4, 72) 288 concatenate_39[0][0] __________________________________________________________________________________________________ activation_43 (Activation) (None, 4, 4, 72) 0 batch_normalization_43[0][0] __________________________________________________________________________________________________ conv2d_44 (Conv2D) (None, 4, 4, 18) 11664 activation_43[0][0] __________________________________________________________________________________________________ concatenate_40 (Concatenate) (None, 4, 4, 90) 0 concatenate_39[0][0] conv2d_44[0][0] __________________________________________________________________________________________________ batch_normalization_44 (BatchNo (None, 4, 4, 90) 360 concatenate_40[0][0] __________________________________________________________________________________________________ activation_44 (Activation) (None, 4, 4, 90) 0 batch_normalization_44[0][0] __________________________________________________________________________________________________ conv2d_45 (Conv2D) (None, 4, 4, 18) 14580 activation_44[0][0] __________________________________________________________________________________________________ concatenate_41 (Concatenate) (None, 4, 4, 108) 0 concatenate_40[0][0] conv2d_45[0][0] __________________________________________________________________________________________________ batch_normalization_45 (BatchNo (None, 4, 4, 108) 432 concatenate_41[0][0] __________________________________________________________________________________________________ activation_45 (Activation) (None, 4, 4, 108) 0 batch_normalization_45[0][0] 
__________________________________________________________________________________________________ conv2d_46 (Conv2D) (None, 4, 4, 18) 17496 activation_45[0][0] __________________________________________________________________________________________________ concatenate_42 (Concatenate) (None, 4, 4, 126) 0 concatenate_41[0][0] conv2d_46[0][0] __________________________________________________________________________________________________ batch_normalization_46 (BatchNo (None, 4, 4, 126) 504 concatenate_42[0][0] __________________________________________________________________________________________________ activation_46 (Activation) (None, 4, 4, 126) 0 batch_normalization_46[0][0] __________________________________________________________________________________________________ conv2d_47 (Conv2D) (None, 4, 4, 18) 20412 activation_46[0][0] __________________________________________________________________________________________________ concatenate_43 (Concatenate) (None, 4, 4, 144) 0 concatenate_42[0][0] conv2d_47[0][0] __________________________________________________________________________________________________ batch_normalization_47 (BatchNo (None, 4, 4, 144) 576 concatenate_43[0][0] __________________________________________________________________________________________________ activation_47 (Activation) (None, 4, 4, 144) 0 batch_normalization_47[0][0] __________________________________________________________________________________________________ conv2d_48 (Conv2D) (None, 4, 4, 18) 23328 activation_47[0][0] __________________________________________________________________________________________________ concatenate_44 (Concatenate) (None, 4, 4, 162) 0 concatenate_43[0][0] conv2d_48[0][0] __________________________________________________________________________________________________ batch_normalization_48 (BatchNo (None, 4, 4, 162) 648 concatenate_44[0][0] __________________________________________________________________________________________________ activation_48 (Activation) (None, 4, 4, 162) 0 batch_normalization_48[0][0] __________________________________________________________________________________________________ conv2d_49 (Conv2D) (None, 4, 4, 18) 26244 activation_48[0][0] __________________________________________________________________________________________________ concatenate_45 (Concatenate) (None, 4, 4, 180) 0 concatenate_44[0][0] conv2d_49[0][0] __________________________________________________________________________________________________ batch_normalization_49 (BatchNo (None, 4, 4, 180) 720 concatenate_45[0][0] __________________________________________________________________________________________________ activation_49 (Activation) (None, 4, 4, 180) 0 batch_normalization_49[0][0] __________________________________________________________________________________________________ conv2d_50 (Conv2D) (None, 4, 4, 18) 29160 activation_49[0][0] __________________________________________________________________________________________________ concatenate_46 (Concatenate) (None, 4, 4, 198) 0 concatenate_45[0][0] conv2d_50[0][0] __________________________________________________________________________________________________ batch_normalization_50 (BatchNo (None, 4, 4, 198) 792 concatenate_46[0][0] __________________________________________________________________________________________________ activation_50 (Activation) (None, 4, 4, 198) 0 batch_normalization_50[0][0] 
__________________________________________________________________________________________________ conv2d_51 (Conv2D) (None, 4, 4, 18) 32076 activation_50[0][0] __________________________________________________________________________________________________ concatenate_47 (Concatenate) (None, 4, 4, 216) 0 concatenate_46[0][0] conv2d_51[0][0] __________________________________________________________________________________________________ batch_normalization_51 (BatchNo (None, 4, 4, 216) 864 concatenate_47[0][0] __________________________________________________________________________________________________ activation_51 (Activation) (None, 4, 4, 216) 0 batch_normalization_51[0][0] __________________________________________________________________________________________________ conv2d_52 (Conv2D) (None, 4, 4, 18) 34992 activation_51[0][0] __________________________________________________________________________________________________ concatenate_48 (Concatenate) (None, 4, 4, 234) 0 concatenate_47[0][0] conv2d_52[0][0] __________________________________________________________________________________________________ batch_normalization_52 (BatchNo (None, 4, 4, 234) 936 concatenate_48[0][0] __________________________________________________________________________________________________ activation_52 (Activation) (None, 4, 4, 234) 0 batch_normalization_52[0][0] __________________________________________________________________________________________________ average_pooling2d_4 (AveragePoo (None, 2, 2, 234) 0 activation_52[0][0] __________________________________________________________________________________________________ conv2d_53 (Conv2D) (None, 1, 1, 10) 9370 average_pooling2d_4[0][0] __________________________________________________________________________________________________ activation_53 (Activation) (None, 1, 1, 10) 0 conv2d_53[0][0] __________________________________________________________________________________________________ flatten_1 (Flatten) (None, 10) 0 activation_53[0][0] ================================================================================================== Total params: 995,230 Trainable params: 981,658 Non-trainable params: 13,572 __________________________________________________________________________________________________
# `path` is the directory containing the pretrained weights
model.load_weights(path + '/190epochs.h5')
# determine Loss function and Optimizer
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(0.01, momentum = 0.7),
              metrics=['accuracy'])

2. Load Image

# Load CIFAR-10 data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Convert to grayscale, then broadcast back to 3 identical channels
# so the input shape still matches the RGB model
x_train = tf.image.rgb_to_grayscale(x_train, name=None)
x_train = tf.broadcast_to(x_train, [50000, 32, 32, 3])
x_test = tf.image.rgb_to_grayscale(x_test, name=None)
x_test = tf.broadcast_to(x_test, [10000, 32, 32, 3])
x_train = x_train.numpy()
x_test = x_test.numpy()
img_height, img_width, channel = x_train.shape[1], x_train.shape[2], x_train.shape[3]
# Convert labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
img_tensor = x_train[0]
img_tensor = np.expand_dims(img_tensor, axis=0)
plt.imshow(img_tensor[0])

3. Reduce Number of Filters

This section shows the method for reducing the number of filters in the convolutional layers of the DenseNet model.

3.1 Visualize Filters

The first step is to visualize the filters in each conv layer. Each filter is represented by a synthetic image generated via gradient ascent in input space: the image that activates ("lights up") that filter the most.

This notebook only shows the process for the first conv layer of the model. In the actual project, the team reduced the filters in the first two conv layers, because many of their filters look similar in grayscale; the deeper the layer, the fewer filters look similar.

The first conv layer was then modified to take only a single gray-scale channel as input.

def deprocess_image_grayscale_for_plot(x):
    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)
    # convert to RGB array
    x *= 255
    x = np.clip(x, 0, 255).astype('uint8')
    # convert to grayscale via a weighted sum of the RGB channels
    # (ITU-R BT.601 luma coefficients) rather than a plain average
    x = np.dot(x[...,:3], [0.2989, 0.5870, 0.1140])
    return x
def deprocess_image_grayscale(x):
    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)
    # convert to grayscale via a weighted sum of the RGB channels
    # (ITU-R BT.601 luma coefficients) rather than a plain average
    x = np.dot(x[...,:3], [0.2989, 0.5870, 0.1140])
    return x
def deprocess_image_color(x):
    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)
    # convert to RGB array
    x *= 255
    x = np.clip(x, 0, 255).astype('uint8')
    return x
def generate_pattern(model, layer_name, filter_index, size=32, grayscale=True, for_plot=False):
    # Build a loss function that maximizes the activation
    # of the nth filter of the layer considered.
    layer_output = model.get_layer(layer_name).output
    loss = K.mean(layer_output[:, :, :, filter_index])
    # Compute the gradient of the input picture wrt this loss
    grads = K.gradients(loss, model.input)[0]
    # Normalization trick: we normalize the gradient
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)
    # This function returns the loss and grads given the input picture
    iterate = K.function([model.input], [loss, grads])
    
    # We start from a gray image with some noise
    input_img_data = np.random.random((1, size, size, 3)) * 20 + 128.
    # Run gradient ascent for 100 steps with step size 0.5
    step = 0.5
    for i in range(100):
        loss_value, grads_value = iterate([input_img_data])
        input_img_data += grads_value * step
        
    img = input_img_data[0]
    if grayscale:
        if not for_plot:
            return deprocess_image_grayscale(img)
        else:
            return deprocess_image_grayscale_for_plot(img)
    if not grayscale:
        return deprocess_image_color(img)
def plot_filters(model_name, layer_list, img_size, mar, 
                 row = 3, col = 6, 
                 greyscale_flag = True, plot_flag = True, col_customized=None):
    """Plot all filters in a layer"""
    for layer_name in layer_list:
        size = img_size
        margin = mar
        if not greyscale_flag:
            results = np.zeros((row * size + (row - 1) * margin, col * size + (col - 1) * margin, 3))
        else:
            # This is an empty (black) image where we will store our results.
            results = np.zeros((row * size + (row - 1) * margin, col * size + (col - 1) * margin))
        for i in range(row):  # iterate over the rows of our results grid
            for j in range(col):  # iterate over the columns of our results grid
                # Generate the pattern for filter `j + (i * col)` in `layer_name`
                filter_img = generate_pattern(model_name, layer_name, j + (i * col), size=size, grayscale=greyscale_flag, for_plot=plot_flag)

                # Put the result in the square `(i, j)` of the results grid
                horizontal_start = i * size + i * margin
                horizontal_end = horizontal_start + size
                vertical_start = j * size + j * margin
                vertical_end = vertical_start + size
                if not greyscale_flag:
                    results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img / 255
                else:
                    results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img / 255
                #print(np.amin(filter_img[:]), np.amax(filter_img[:]))
            print('Finished filter no.', j + (i * col), "in", layer_name, end='.  ')
        # Display the results grid
        print(layer_name)
        plt.figure(figsize=(15, 15))
        if col_customized:
            plt.imshow(results, cmap=col_customized)
        else:
            plt.imshow(results)
        plt.show()

3.1.1 All Filters in Color

The figure below shows the filters in the first conv layer. Each square represents a specific filter.

plot_filters(model, ["conv2d_1"], 128, 5, row = 6, col = 6, greyscale_flag = False, plot_flag = True)
Finished filter no. 5 in conv2d_1. Finished filter no. 11 in conv2d_1. Finished filter no. 17 in conv2d_1. Finished filter no. 23 in conv2d_1. Finished filter no. 29 in conv2d_1. Finished filter no. 35 in conv2d_1. conv2d_1

3.1.2 All Filters in Grayscale

Since the input images are gray-scale, these representations were converted to grayscale as well by taking a weighted sum of the RGB channel values [1] (a minimal sketch of this conversion follows below). This is the same method used to convert color images to gray-scale in image preprocessing. The processed images then represent what each filter is looking for in a gray-scale feature space. The figure below shows the gray-scale filters obtained by transforming the figure above; again, each square represents a specific filter. To increase the contrast for better visualization, the gray-scale filters are rendered in blue and yellow.
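A minimal sketch of that conversion, using the same BT.601 weights as the deprocessing functions above:

import numpy as np
rgb = np.random.rand(32, 32, 3)                        # any RGB image in [0, 1]
gray = np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])  # weighted sum over channels
print(gray.shape)                                      # (32, 32)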

plot_filters(model, ["conv2d_1"], 128, 5, row = 6, col = 6, greyscale_flag = True, plot_flag = True)
Finished filter no. 5 in conv2d_1. Finished filter no. 11 in conv2d_1. Finished filter no. 17 in conv2d_1. Finished filter no. 23 in conv2d_1. Finished filter no. 29 in conv2d_1. Finished filter no. 35 in conv2d_1. conv2d_1

3.2 Cluster Filters

To delete filters responsible for color information and to reduce the number of filters focusing on similar patterns, we need to cluster the filters into groups. To represent each filter, we again use the input that maximizes its response; each filter thus becomes a 32-by-32 matrix, which is flattened into one dimension and treated as a single observation. We then calculate the distance between each pair of these images as a measure of their similarity.

For the first convolutional layer, which focuses entirely on brightness, we used MSE, which is proportional to the squared Euclidean distance, because we are estimating the relative overall brightness difference. For the following layers, which focus more on image patterns, we tried the Structural Similarity Index (SSIM) and the Image Euclidean Distance (IMED). Both take relative pixel positions into account when comparing pixel values, which alleviates the problem of two images containing similar patterns that are shifted by a few pixels.
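For reference, the Image Euclidean Distance between two flattened images $x$ and $y$ can be written as

$$d_{\mathrm{IMED}}(x, y) = \sqrt{(x - y)^{\top} G \,(x - y)}, \qquad G_{kl} = \frac{1}{2\pi\sigma^{2}} \exp\!\left(-\frac{\lVert P_k - P_l \rVert^{2}}{2\sigma^{2}}\right),$$

where $P_k$ is the 2-D position of pixel $k$ and $\sigma$ (the parameter `r` in the code below) controls how quickly the influence of neighboring pixels decays; with $G = I$ this reduces to the ordinary Euclidean distance.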

DBSCAN [2] was used as the clustering method. Previous studies used k-means to cluster CNN kernels [3], but the team found that k-means always distributes kernels evenly across clusters, which is not suitable here, where a convolutional layer has as few as 18 filters. Two measures were used to compute pairwise distances between filters in matrix form: MSE [4] in the first layer and IMED [5] in the second. A threshold (the DBSCAN eps) was set to decide which filters could be grouped together, with the aim of keeping slightly more than 50% of the filters in each layer.
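As a minimal sketch of this step (the toy distance matrix and eps value below are illustrative, not the project's actual numbers):

import numpy as np
from sklearn.cluster import DBSCAN

# Toy 4x4 symmetric distance matrix: filters 0 and 1 are near-duplicates;
# filters 2 and 3 are far from everything else.
D = np.array([[0.00, 0.02, 0.90, 0.80],
              [0.02, 0.00, 0.85, 0.75],
              [0.90, 0.85, 0.00, 0.60],
              [0.80, 0.75, 0.60, 0.00]])

# min_samples=1 means no point is labeled as noise (-1): every filter
# either joins a nearby cluster or forms its own.
labels = DBSCAN(eps=0.05, min_samples=1, metric='precomputed').fit(D).labels_
print(labels)  # [0 0 1 2] -> filters 0 and 1 merge; 2 and 3 stay separate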

# Get filter information to an array
def get_filter_array(model_name, layer_name, filter_num):
    filter_collection = []
    for i in range(filter_num):
        filter_img_temp = generate_pattern(model_name, layer_name, i, size=32, grayscale=True, for_plot=False)
        filter_collection.append(filter_img_temp)
    filter_collection = np.array(filter_collection)
    return filter_collection
def get_G(img_h, r):
    '''calculate Pixel Distance for IMED'''
    G_matrix = np.zeros((img_h**2, img_h**2))
    for i1 in range(img_h):
        for i2 in range(img_h):
            for j1 in range(img_h):
                for j2 in range(img_h):
                    pixel_dist = 1/(2*math.pi*(r**2)) * math.exp(-((i1-i2)**2 + (j1-j2)**2)/(2*r**2))
                    G_matrix[img_h*i1+j1, img_h*i2+j2] = pixel_dist
    return G_matrix
G = get_G(32, 1)
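The quadruple loop above runs in O(img_h⁴) Python iterations; as an optional aside, a vectorized NumPy sketch that should produce the same matrix:

def get_G_vectorized(img_h, r):
    '''Vectorized equivalent of get_G (assumes the same pixel-indexing convention).'''
    idx = np.arange(img_h)
    i, j = np.meshgrid(idx, idx, indexing='ij')
    pos = np.stack([i.ravel(), j.ravel()], axis=1)            # (img_h**2, 2) pixel coordinates
    sq_dist = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dist / (2 * r ** 2)) / (2 * math.pi * r ** 2)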
def get_imed(img1, img2):
    '''Calculate Image distance'''
    img1 = img1.flatten()
    img2 = img2.flatten()
    distance = math.sqrt(np.dot(np.dot(np.transpose(img1-img2), G), (img1-img2)))
    return distance

MSE, SSIM, cosine similarity, brightness difference, and IMED are all ways to measure the distance between two matrices. After trying each of them, the team found that MSE worked best in the first conv layer, which can be seen as a color palette, while IMED gave the best result in the second layer, which still contains most of the color information.

def calculate_dist(filter_collection, method="mse"):
    filter_num = filter_collection.shape[0]
    print("Calculated Distance Matrix Using", method, "Method.")
    print("Filter number:", filter_num)
    if method == "ssim": # calculate distance matrix using ssim
        filter_distance_matrix_ssim = np.zeros((filter_num, filter_num))
        for i in range(filter_num):
            for j in range(filter_num):
                filter_distance_matrix_ssim[i, j] = ssim(filter_collection[i], filter_collection[j])
        filter_distance_matrix_ssim = 1 - filter_distance_matrix_ssim
        return filter_distance_matrix_ssim
    if method == "mse": # calculate distance matrix using mse
        filter_distance_matrix_mse = np.zeros((filter_num, filter_num))
        for i in range(filter_num):
            for j in range(filter_num):
                filter_distance_matrix_mse[i, j] = mean_squared_error(filter_collection[i].flatten(), filter_collection[j].flatten())
        return filter_distance_matrix_mse
    if method == "consine": # calculate distance matrix using cosine
        filter_distance_matrix_cosine = np.zeros((filter_num, filter_num))
        for i in range(filter_num):
            for j in range(filter_num):
                filter_distance_matrix_cosine[i, j] = \
                1 - spatial.distance.cosine(filter_collection[i].flatten(), filter_collection[j].flatten())
        filter_distance_matrix_cosine = 1 - filter_distance_matrix_cosine
        return filter_distance_matrix_cosine
    
    if method == "brightness": # calculate distance matrix using brightness
        filter_distance_matrix_brightness = np.zeros((filter_num, filter_num))
        for i in range(filter_num):
            for j in range(filter_num):
                filter_distance_matrix_brightness[i, j] = \
                np.sum(filter_collection[i]) - np.sum(filter_collection[j])
        filter_distance_matrix_brightness = mmscaler.fit_transform(filter_distance_matrix_brightness.flatten().reshape(-1, 1)).reshape((filter_num, filter_num))
        return filter_distance_matrix_brightness
    if method == "imed":
        filter_distance_matrix_imed = np.zeros((filter_num, filter_num))
        for i in range(filter_num):
            for j in range(filter_num):
                filter_distance_matrix_imed[i, j] = get_imed(filter_collection[i].flatten(), filter_collection[j].flatten())
        return filter_distance_matrix_imed
def get_cluster_(distance_mat, weight, row, col, min_samp=1):
    """Cluster filters with DBSCAN on a min-max scaled, precomputed distance matrix."""
    h = len(distance_mat)
    distance_mat = mmscaler.fit_transform(distance_mat.flatten().reshape(-1, 1)).reshape(h, h)
    clustering = DBSCAN(eps=weight,
                        min_samples=min_samp, metric='precomputed').fit(distance_mat)
    filter_clusters = clustering.labels_
    return filter_clusters
conv2d_1 = get_filter_array(model, 'conv2d_1', 36)
conv2d_1_mse = calculate_dist(conv2d_1, method="mse")
Calculated Distance Matrix Using mse Method. Filter number: 36

In the first conv layer, there are 36 filters in total, and they were clustered into 20 groups.

filter_clusters = get_cluster_(conv2d_1_mse, weight=0.027, row=6, col=6, min_samp=1)
print('Number of clusters:', max(filter_clusters) + 1)
filter_clusters.reshape(6, 6)
Number of clusters: 20
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  6,  7,  6,  8,  6],
       [ 6,  7,  9, 10,  7,  6],
       [ 7,  6, 11,  6, 12, 13],
       [ 6,  6, 10, 14,  6, 15],
       [16, 17, 18, 19,  7, 13]])

3.3 Merge Filters

Every filter is a three-dimensional tensor. For the convolutional layers being shrunk, the filters were simply averaged within each cluster. However, this also reduces the number of output channels, causing a shape mismatch with the next layer. Therefore, in the following convolutional layer, the weights in each filter were summed across the third (input-channel) dimension according to how the previous layer was clustered.
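As a toy illustration of this bookkeeping (the shapes and cluster labels here are made up for the example, not taken from the model):

import numpy as np

# A Conv2D kernel has shape (kh, kw, in_channels, out_channels)
w = np.random.rand(3, 3, 3, 4)            # 4 filters over 3 input channels
clusters = np.array([0, 0, 1, 1])         # filters {0, 1} and {2, 3} merge

# Step 1: average filters within each cluster -> out_channels: 4 -> 2
merged = np.stack([w[..., clusters == c].mean(axis=-1) for c in range(2)], axis=-1)
print(merged.shape)                       # (3, 3, 3, 2)

# Step 2: in the *next* layer, sum the kernel weights over the in_channels
# axis (axis=2) according to these same clusters, so that layer's input
# shrinks to match the reduced output above.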

def merge_filters(original_weights, cluster_res, 
                  cluster_res_prev=np.array([0, 0, 0]), layer_type='conv'):
    """
    Merge original weights, 
    based on clustering results and number of in_channels
    """
    # copy the clustering results so the caller's arrays are not mutated
    cluster_res = np.array(cluster_res, copy=True)
    cluster_res_prev = np.array(cluster_res_prev, copy=True)
    # clean clustering result for outliers (-1), record each as a new cluster
    max_cluster = max(cluster_res)
    for i, cls in enumerate(cluster_res):
        if cls == -1:
            max_cluster += 1
            cluster_res[i] = max_cluster
        
    # clean prev clustering result for outliers, record each as a new cluster
    max_cluster_prev = max(cluster_res_prev)
    for i, cls in enumerate(cluster_res_prev):
        if cls == -1:
            max_cluster_prev += 1
            cluster_res_prev[i] = max_cluster_prev
    if layer_type == 'conv':
        # average over the 4th (output-channel) dimension
        # zero tensor to record weights
        clustered_weights = np.zeros(list(original_weights.shape[:3]) + [max_cluster + 1])
        # load new filters with averaged weights
        for filter_idx in range(max_cluster + 1):
            # get indices of the filters belonging to this cluster
            idx = tf.constant([i for i in range(len(cluster_res)) if cluster_res[i] == filter_idx])
            # reduce-average over the filters in the cluster
            clustered_filter = tf.reduce_mean(tf.gather(original_weights, idx, axis=3), axis=3)
            clustered_weights[:, :, :, filter_idx] = clustered_filter
        # sum over the 3rd (input-channel) dimension according to previous cluster results
        # zero tensor to record weights
        reduced_clustered_weights = np.zeros(list(original_weights.shape[:2]) + [max_cluster_prev + 1] + [max_cluster + 1])
        # load new filters with summed weights
        for filter_idx in range(max_cluster_prev + 1):
            # get indices of the input channels belonging to this cluster
            idx = tf.constant([i for i in range(len(cluster_res_prev)) if cluster_res_prev[i] == filter_idx])
            # reduce-sum over the input channels in the cluster
            summed_filter = tf.reduce_sum(tf.gather(clustered_weights, idx, axis=2), axis=2)
            reduced_clustered_weights[:, :, filter_idx, :] = summed_filter
        return reduced_clustered_weights
    if layer_type == 'bn':
        # average each of the four BN parameter vectors (gamma, beta,
        # moving mean, moving variance) within each cluster
        clustered_weights = [np.zeros([max_cluster + 1]), 
                             np.zeros([max_cluster + 1]), 
                             np.zeros([max_cluster + 1]), 
                             np.zeros([max_cluster + 1])]
        # load new parameters with averaged weights
        for bn_idx, bn_weights in enumerate(clustered_weights):
            for filter_idx in range(max_cluster + 1):
                # get indices of the channels belonging to this cluster
                idx = tf.constant([i for i in range(len(cluster_res)) if cluster_res[i] == filter_idx])
                # reduce-average over the channels in the cluster
                clustered_filter = tf.reduce_mean(tf.gather(original_weights[bn_idx], idx, axis=0), axis=0)
                bn_weights[filter_idx] = clustered_filter
        return clustered_weights
Shift+Enter to run

The figure below is an example showing the pipeline of merging the filters of the first convolutional layer in a color model.

Here we merge the first convolutional layer according to our clustering result filter_clusters. The filters are first averaged within each cluster, then reduce-summed along their third (input-channel) dimension to match the single-channel input. This gives us the new weights test_conv_1 for the first convolutional layer.

test_conv_1 = merge_filters(model.get_layer('conv2d_1').weights[0], filter_clusters, np.array([0, 0, 0]), layer_type='conv')
print('Shape of clustered 1st conv layer:', test_conv_1.shape)
Shift+Enter to run
Shape of clustered 1st conv layer: (3, 3, 1, 20)
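
As a quick sanity check (our addition, not a cell from the original notebook), summing a kernel's weights over its input channels is exact when the input really is gray-scale, i.e. identical across the RGB channels:

# Sanity check: for an input whose three channels are identical, a kernel
# summed over its input-channel axis gives the same response as the
# original three-channel kernel on the replicated input.
np.random.seed(0)
w = np.random.normal(size=(3, 3, 3))            # (kh, kw, in_channels)
patch_gray = np.random.normal(size=(3, 3))      # one gray-scale receptive field
patch_rgb = np.repeat(patch_gray[:, :, None], 3, axis=2)

resp_rgb = np.sum(w * patch_rgb)                # original 3-channel response
resp_gray = np.sum(w.sum(axis=2) * patch_gray)  # merged single-channel response
print(np.allclose(resp_rgb, resp_gray))         # True

CIFAR-10 images are not truly gray-scale (only their first channel is fed to the new model later), so in practice this equality holds only approximately.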

Next comes the batch normalization layer. We reduce-average its parameters according to the same clustering result, filter_clusters, so that its dimensions match the merged first conv layer.

test_bn_1 = merge_filters(model.get_layer('batch_normalization_1').weights, filter_clusters, layer_type='bn')
print('Shape of clustered 1st bn layer:', test_bn_1[0].shape)
Shift+Enter to run
Shape of clustered 1st bn layer: (20,)

Now we reduce-sum the second conv layer's input-channel dimension according to how we clustered the first conv layer.

test_conv_2 = merge_filters(model.get_layer('conv2d_2').weights[0], np.arange(18), filter_clusters, layer_type='conv')
print('Shape of 2nd conv layer:', test_conv_2.shape)
Shift+Enter to run
Shape of 2nd conv layer: (3, 3, 20, 18)
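
For this and all deeper layers, the reduce-sum is only approximate: it assumes the activations of filters within one cluster are nearly identical, so the summed weights applied to the cluster average come close to the original response. A small numeric sketch of our own illustrates the approximation:

# Illustration: if channels in a cluster produce similar activations y_i,
# then sum_i(w_i * y_i) ~= (sum_i w_i) * mean(y_i), which is exactly what
# averaging the previous layer and reduce-summing this layer's weights computes.
np.random.seed(1)
y = 0.8 + 0.01 * np.random.randn(5)   # five similar channel activations
w_next = np.random.randn(5)           # next-layer weights on those channels

original = np.dot(w_next, y)
merged = w_next.sum() * y.mean()      # one merged channel, summed weights
print(original, merged)               # close, but not identical

The gap between the two values is one source of the information loss discussed in the Shortcomings section.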

3.4 Load Weights

This section loads the preserved weights from the original model and the adjusted weights from the clustering step. These weights are used together to build the new model and to compare results with the old model.

num_filter = 36
dropout_rate = 0.2
l = 12
input = Input(shape=(img_height, img_width, 1,))
First_Conv2D = Conv2D(test_conv_1.shape[-1], (3,3), use_bias=False, padding='same')(input)
First_Block = add_denseblock(First_Conv2D, num_filter, dropout_rate)
First_Transition = add_transition(First_Block, num_filter, dropout_rate)
Second_Block = add_denseblock(First_Transition, num_filter, dropout_rate)
Second_Transition = add_transition(Second_Block, num_filter, dropout_rate)
Third_Block = add_denseblock(Second_Transition, num_filter, dropout_rate)
Third_Transition = add_transition(Third_Block, num_filter, dropout_rate)
Last_Block = add_denseblock(Third_Transition,  num_filter, dropout_rate)
output = output_layer(Last_Block)
model_clustered = Model(inputs=[input], outputs=[output])
Shift+Enter to run
model_clustered.summary()
Shift+Enter to run
Model: "model_4" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_3 (InputLayer) (None, 32, 32, 1) 0 __________________________________________________________________________________________________ conv2d_55 (Conv2D) (None, 32, 32, 20) 180 input_3[0][0] __________________________________________________________________________________________________ batch_normalization_53 (BatchNo (None, 32, 32, 20) 80 conv2d_55[0][0] __________________________________________________________________________________________________ activation_54 (Activation) (None, 32, 32, 20) 0 batch_normalization_53[0][0] __________________________________________________________________________________________________ conv2d_56 (Conv2D) (None, 32, 32, 18) 3240 activation_54[0][0] __________________________________________________________________________________________________ concatenate_49 (Concatenate) (None, 32, 32, 38) 0 conv2d_55[0][0] conv2d_56[0][0] __________________________________________________________________________________________________ batch_normalization_54 (BatchNo (None, 32, 32, 38) 152 concatenate_49[0][0] __________________________________________________________________________________________________ activation_55 (Activation) (None, 32, 32, 38) 0 batch_normalization_54[0][0] __________________________________________________________________________________________________ conv2d_57 (Conv2D) (None, 32, 32, 18) 6156 activation_55[0][0] __________________________________________________________________________________________________ concatenate_50 (Concatenate) (None, 32, 32, 56) 0 concatenate_49[0][0] conv2d_57[0][0] __________________________________________________________________________________________________ batch_normalization_55 (BatchNo (None, 32, 32, 56) 224 concatenate_50[0][0] __________________________________________________________________________________________________ activation_56 (Activation) (None, 32, 32, 56) 0 batch_normalization_55[0][0] __________________________________________________________________________________________________ conv2d_58 (Conv2D) (None, 32, 32, 18) 9072 activation_56[0][0] __________________________________________________________________________________________________ concatenate_51 (Concatenate) (None, 32, 32, 74) 0 concatenate_50[0][0] conv2d_58[0][0] __________________________________________________________________________________________________ batch_normalization_56 (BatchNo (None, 32, 32, 74) 296 concatenate_51[0][0] __________________________________________________________________________________________________ activation_57 (Activation) (None, 32, 32, 74) 0 batch_normalization_56[0][0] __________________________________________________________________________________________________ conv2d_59 (Conv2D) (None, 32, 32, 18) 11988 activation_57[0][0] __________________________________________________________________________________________________ concatenate_52 (Concatenate) (None, 32, 32, 92) 0 concatenate_51[0][0] conv2d_59[0][0] __________________________________________________________________________________________________ batch_normalization_57 (BatchNo (None, 32, 32, 92) 368 concatenate_52[0][0] __________________________________________________________________________________________________ activation_58 (Activation) (None, 
32, 32, 92) 0 batch_normalization_57[0][0] __________________________________________________________________________________________________ conv2d_60 (Conv2D) (None, 32, 32, 18) 14904 activation_58[0][0] __________________________________________________________________________________________________ concatenate_53 (Concatenate) (None, 32, 32, 110) 0 concatenate_52[0][0] conv2d_60[0][0] __________________________________________________________________________________________________ batch_normalization_58 (BatchNo (None, 32, 32, 110) 440 concatenate_53[0][0] __________________________________________________________________________________________________ activation_59 (Activation) (None, 32, 32, 110) 0 batch_normalization_58[0][0] __________________________________________________________________________________________________ conv2d_61 (Conv2D) (None, 32, 32, 18) 17820 activation_59[0][0] __________________________________________________________________________________________________ concatenate_54 (Concatenate) (None, 32, 32, 128) 0 concatenate_53[0][0] conv2d_61[0][0] __________________________________________________________________________________________________ batch_normalization_59 (BatchNo (None, 32, 32, 128) 512 concatenate_54[0][0] __________________________________________________________________________________________________ activation_60 (Activation) (None, 32, 32, 128) 0 batch_normalization_59[0][0] __________________________________________________________________________________________________ conv2d_62 (Conv2D) (None, 32, 32, 18) 20736 activation_60[0][0] __________________________________________________________________________________________________ concatenate_55 (Concatenate) (None, 32, 32, 146) 0 concatenate_54[0][0] conv2d_62[0][0] __________________________________________________________________________________________________ batch_normalization_60 (BatchNo (None, 32, 32, 146) 584 concatenate_55[0][0] __________________________________________________________________________________________________ activation_61 (Activation) (None, 32, 32, 146) 0 batch_normalization_60[0][0] __________________________________________________________________________________________________ conv2d_63 (Conv2D) (None, 32, 32, 18) 23652 activation_61[0][0] __________________________________________________________________________________________________ concatenate_56 (Concatenate) (None, 32, 32, 164) 0 concatenate_55[0][0] conv2d_63[0][0] __________________________________________________________________________________________________ batch_normalization_61 (BatchNo (None, 32, 32, 164) 656 concatenate_56[0][0] __________________________________________________________________________________________________ activation_62 (Activation) (None, 32, 32, 164) 0 batch_normalization_61[0][0] __________________________________________________________________________________________________ conv2d_64 (Conv2D) (None, 32, 32, 18) 26568 activation_62[0][0] __________________________________________________________________________________________________ concatenate_57 (Concatenate) (None, 32, 32, 182) 0 concatenate_56[0][0] conv2d_64[0][0] __________________________________________________________________________________________________ batch_normalization_62 (BatchNo (None, 32, 32, 182) 728 concatenate_57[0][0] __________________________________________________________________________________________________ activation_63 (Activation) (None, 32, 32, 182) 0 
batch_normalization_62[0][0] __________________________________________________________________________________________________ conv2d_65 (Conv2D) (None, 32, 32, 18) 29484 activation_63[0][0] __________________________________________________________________________________________________ concatenate_58 (Concatenate) (None, 32, 32, 200) 0 concatenate_57[0][0] conv2d_65[0][0] __________________________________________________________________________________________________ batch_normalization_63 (BatchNo (None, 32, 32, 200) 800 concatenate_58[0][0] __________________________________________________________________________________________________ activation_64 (Activation) (None, 32, 32, 200) 0 batch_normalization_63[0][0] __________________________________________________________________________________________________ conv2d_66 (Conv2D) (None, 32, 32, 18) 32400 activation_64[0][0] __________________________________________________________________________________________________ concatenate_59 (Concatenate) (None, 32, 32, 218) 0 concatenate_58[0][0] conv2d_66[0][0] __________________________________________________________________________________________________ batch_normalization_64 (BatchNo (None, 32, 32, 218) 872 concatenate_59[0][0] __________________________________________________________________________________________________ activation_65 (Activation) (None, 32, 32, 218) 0 batch_normalization_64[0][0] __________________________________________________________________________________________________ conv2d_67 (Conv2D) (None, 32, 32, 18) 35316 activation_65[0][0] __________________________________________________________________________________________________ concatenate_60 (Concatenate) (None, 32, 32, 236) 0 concatenate_59[0][0] conv2d_67[0][0] __________________________________________________________________________________________________ batch_normalization_65 (BatchNo (None, 32, 32, 236) 944 concatenate_60[0][0] __________________________________________________________________________________________________ activation_66 (Activation) (None, 32, 32, 236) 0 batch_normalization_65[0][0] __________________________________________________________________________________________________ conv2d_68 (Conv2D) (None, 32, 32, 18) 4248 activation_66[0][0] __________________________________________________________________________________________________ average_pooling2d_5 (AveragePoo (None, 16, 16, 18) 0 conv2d_68[0][0] __________________________________________________________________________________________________ batch_normalization_66 (BatchNo (None, 16, 16, 18) 72 average_pooling2d_5[0][0] __________________________________________________________________________________________________ activation_67 (Activation) (None, 16, 16, 18) 0 batch_normalization_66[0][0] __________________________________________________________________________________________________ conv2d_69 (Conv2D) (None, 16, 16, 18) 2916 activation_67[0][0] __________________________________________________________________________________________________ concatenate_61 (Concatenate) (None, 16, 16, 36) 0 average_pooling2d_5[0][0] conv2d_69[0][0] __________________________________________________________________________________________________ batch_normalization_67 (BatchNo (None, 16, 16, 36) 144 concatenate_61[0][0] __________________________________________________________________________________________________ activation_68 (Activation) (None, 16, 16, 36) 0 batch_normalization_67[0][0] 
__________________________________________________________________________________________________ conv2d_70 (Conv2D) (None, 16, 16, 18) 5832 activation_68[0][0] __________________________________________________________________________________________________ concatenate_62 (Concatenate) (None, 16, 16, 54) 0 concatenate_61[0][0] conv2d_70[0][0] __________________________________________________________________________________________________ batch_normalization_68 (BatchNo (None, 16, 16, 54) 216 concatenate_62[0][0] __________________________________________________________________________________________________ activation_69 (Activation) (None, 16, 16, 54) 0 batch_normalization_68[0][0] __________________________________________________________________________________________________ conv2d_71 (Conv2D) (None, 16, 16, 18) 8748 activation_69[0][0] __________________________________________________________________________________________________ concatenate_63 (Concatenate) (None, 16, 16, 72) 0 concatenate_62[0][0] conv2d_71[0][0] __________________________________________________________________________________________________ batch_normalization_69 (BatchNo (None, 16, 16, 72) 288 concatenate_63[0][0] __________________________________________________________________________________________________ activation_70 (Activation) (None, 16, 16, 72) 0 batch_normalization_69[0][0] __________________________________________________________________________________________________ conv2d_72 (Conv2D) (None, 16, 16, 18) 11664 activation_70[0][0] __________________________________________________________________________________________________ concatenate_64 (Concatenate) (None, 16, 16, 90) 0 concatenate_63[0][0] conv2d_72[0][0] __________________________________________________________________________________________________ batch_normalization_70 (BatchNo (None, 16, 16, 90) 360 concatenate_64[0][0] __________________________________________________________________________________________________ activation_71 (Activation) (None, 16, 16, 90) 0 batch_normalization_70[0][0] __________________________________________________________________________________________________ conv2d_73 (Conv2D) (None, 16, 16, 18) 14580 activation_71[0][0] __________________________________________________________________________________________________ concatenate_65 (Concatenate) (None, 16, 16, 108) 0 concatenate_64[0][0] conv2d_73[0][0] __________________________________________________________________________________________________ batch_normalization_71 (BatchNo (None, 16, 16, 108) 432 concatenate_65[0][0] __________________________________________________________________________________________________ activation_72 (Activation) (None, 16, 16, 108) 0 batch_normalization_71[0][0] __________________________________________________________________________________________________ conv2d_74 (Conv2D) (None, 16, 16, 18) 17496 activation_72[0][0] __________________________________________________________________________________________________ concatenate_66 (Concatenate) (None, 16, 16, 126) 0 concatenate_65[0][0] conv2d_74[0][0] __________________________________________________________________________________________________ batch_normalization_72 (BatchNo (None, 16, 16, 126) 504 concatenate_66[0][0] __________________________________________________________________________________________________ activation_73 (Activation) (None, 16, 16, 126) 0 batch_normalization_72[0][0] 
__________________________________________________________________________________________________ conv2d_75 (Conv2D) (None, 16, 16, 18) 20412 activation_73[0][0] __________________________________________________________________________________________________ concatenate_67 (Concatenate) (None, 16, 16, 144) 0 concatenate_66[0][0] conv2d_75[0][0] __________________________________________________________________________________________________ batch_normalization_73 (BatchNo (None, 16, 16, 144) 576 concatenate_67[0][0] __________________________________________________________________________________________________ activation_74 (Activation) (None, 16, 16, 144) 0 batch_normalization_73[0][0] __________________________________________________________________________________________________ conv2d_76 (Conv2D) (None, 16, 16, 18) 23328 activation_74[0][0] __________________________________________________________________________________________________ concatenate_68 (Concatenate) (None, 16, 16, 162) 0 concatenate_67[0][0] conv2d_76[0][0] __________________________________________________________________________________________________ batch_normalization_74 (BatchNo (None, 16, 16, 162) 648 concatenate_68[0][0] __________________________________________________________________________________________________ activation_75 (Activation) (None, 16, 16, 162) 0 batch_normalization_74[0][0] __________________________________________________________________________________________________ conv2d_77 (Conv2D) (None, 16, 16, 18) 26244 activation_75[0][0] __________________________________________________________________________________________________ concatenate_69 (Concatenate) (None, 16, 16, 180) 0 concatenate_68[0][0] conv2d_77[0][0] __________________________________________________________________________________________________ batch_normalization_75 (BatchNo (None, 16, 16, 180) 720 concatenate_69[0][0] __________________________________________________________________________________________________ activation_76 (Activation) (None, 16, 16, 180) 0 batch_normalization_75[0][0] __________________________________________________________________________________________________ conv2d_78 (Conv2D) (None, 16, 16, 18) 29160 activation_76[0][0] __________________________________________________________________________________________________ concatenate_70 (Concatenate) (None, 16, 16, 198) 0 concatenate_69[0][0] conv2d_78[0][0] __________________________________________________________________________________________________ batch_normalization_76 (BatchNo (None, 16, 16, 198) 792 concatenate_70[0][0] __________________________________________________________________________________________________ activation_77 (Activation) (None, 16, 16, 198) 0 batch_normalization_76[0][0] __________________________________________________________________________________________________ conv2d_79 (Conv2D) (None, 16, 16, 18) 32076 activation_77[0][0] __________________________________________________________________________________________________ concatenate_71 (Concatenate) (None, 16, 16, 216) 0 concatenate_70[0][0] conv2d_79[0][0] __________________________________________________________________________________________________ batch_normalization_77 (BatchNo (None, 16, 16, 216) 864 concatenate_71[0][0] __________________________________________________________________________________________________ activation_78 (Activation) (None, 16, 16, 216) 0 batch_normalization_77[0][0] 
__________________________________________________________________________________________________ conv2d_80 (Conv2D) (None, 16, 16, 18) 34992 activation_78[0][0] __________________________________________________________________________________________________ concatenate_72 (Concatenate) (None, 16, 16, 234) 0 concatenate_71[0][0] conv2d_80[0][0] __________________________________________________________________________________________________ batch_normalization_78 (BatchNo (None, 16, 16, 234) 936 concatenate_72[0][0] __________________________________________________________________________________________________ activation_79 (Activation) (None, 16, 16, 234) 0 batch_normalization_78[0][0] __________________________________________________________________________________________________ conv2d_81 (Conv2D) (None, 16, 16, 18) 4212 activation_79[0][0] __________________________________________________________________________________________________ average_pooling2d_6 (AveragePoo (None, 8, 8, 18) 0 conv2d_81[0][0] __________________________________________________________________________________________________ batch_normalization_79 (BatchNo (None, 8, 8, 18) 72 average_pooling2d_6[0][0] __________________________________________________________________________________________________ activation_80 (Activation) (None, 8, 8, 18) 0 batch_normalization_79[0][0] __________________________________________________________________________________________________ conv2d_82 (Conv2D) (None, 8, 8, 18) 2916 activation_80[0][0] __________________________________________________________________________________________________ concatenate_73 (Concatenate) (None, 8, 8, 36) 0 average_pooling2d_6[0][0] conv2d_82[0][0] __________________________________________________________________________________________________ batch_normalization_80 (BatchNo (None, 8, 8, 36) 144 concatenate_73[0][0] __________________________________________________________________________________________________ activation_81 (Activation) (None, 8, 8, 36) 0 batch_normalization_80[0][0] __________________________________________________________________________________________________ conv2d_83 (Conv2D) (None, 8, 8, 18) 5832 activation_81[0][0] __________________________________________________________________________________________________ concatenate_74 (Concatenate) (None, 8, 8, 54) 0 concatenate_73[0][0] conv2d_83[0][0] __________________________________________________________________________________________________ batch_normalization_81 (BatchNo (None, 8, 8, 54) 216 concatenate_74[0][0] __________________________________________________________________________________________________ activation_82 (Activation) (None, 8, 8, 54) 0 batch_normalization_81[0][0] __________________________________________________________________________________________________ conv2d_84 (Conv2D) (None, 8, 8, 18) 8748 activation_82[0][0] __________________________________________________________________________________________________ concatenate_75 (Concatenate) (None, 8, 8, 72) 0 concatenate_74[0][0] conv2d_84[0][0] __________________________________________________________________________________________________ batch_normalization_82 (BatchNo (None, 8, 8, 72) 288 concatenate_75[0][0] __________________________________________________________________________________________________ activation_83 (Activation) (None, 8, 8, 72) 0 batch_normalization_82[0][0] 
__________________________________________________________________________________________________ conv2d_85 (Conv2D) (None, 8, 8, 18) 11664 activation_83[0][0] __________________________________________________________________________________________________ concatenate_76 (Concatenate) (None, 8, 8, 90) 0 concatenate_75[0][0] conv2d_85[0][0] __________________________________________________________________________________________________ batch_normalization_83 (BatchNo (None, 8, 8, 90) 360 concatenate_76[0][0] __________________________________________________________________________________________________ activation_84 (Activation) (None, 8, 8, 90) 0 batch_normalization_83[0][0] __________________________________________________________________________________________________ conv2d_86 (Conv2D) (None, 8, 8, 18) 14580 activation_84[0][0] __________________________________________________________________________________________________ concatenate_77 (Concatenate) (None, 8, 8, 108) 0 concatenate_76[0][0] conv2d_86[0][0] __________________________________________________________________________________________________ batch_normalization_84 (BatchNo (None, 8, 8, 108) 432 concatenate_77[0][0] __________________________________________________________________________________________________ activation_85 (Activation) (None, 8, 8, 108) 0 batch_normalization_84[0][0] __________________________________________________________________________________________________ conv2d_87 (Conv2D) (None, 8, 8, 18) 17496 activation_85[0][0] __________________________________________________________________________________________________ concatenate_78 (Concatenate) (None, 8, 8, 126) 0 concatenate_77[0][0] conv2d_87[0][0] __________________________________________________________________________________________________ batch_normalization_85 (BatchNo (None, 8, 8, 126) 504 concatenate_78[0][0] __________________________________________________________________________________________________ activation_86 (Activation) (None, 8, 8, 126) 0 batch_normalization_85[0][0] __________________________________________________________________________________________________ conv2d_88 (Conv2D) (None, 8, 8, 18) 20412 activation_86[0][0] __________________________________________________________________________________________________ concatenate_79 (Concatenate) (None, 8, 8, 144) 0 concatenate_78[0][0] conv2d_88[0][0] __________________________________________________________________________________________________ batch_normalization_86 (BatchNo (None, 8, 8, 144) 576 concatenate_79[0][0] __________________________________________________________________________________________________ activation_87 (Activation) (None, 8, 8, 144) 0 batch_normalization_86[0][0] __________________________________________________________________________________________________ conv2d_89 (Conv2D) (None, 8, 8, 18) 23328 activation_87[0][0] __________________________________________________________________________________________________ concatenate_80 (Concatenate) (None, 8, 8, 162) 0 concatenate_79[0][0] conv2d_89[0][0] __________________________________________________________________________________________________ batch_normalization_87 (BatchNo (None, 8, 8, 162) 648 concatenate_80[0][0] __________________________________________________________________________________________________ activation_88 (Activation) (None, 8, 8, 162) 0 batch_normalization_87[0][0] 
__________________________________________________________________________________________________ conv2d_90 (Conv2D) (None, 8, 8, 18) 26244 activation_88[0][0] __________________________________________________________________________________________________ concatenate_81 (Concatenate) (None, 8, 8, 180) 0 concatenate_80[0][0] conv2d_90[0][0] __________________________________________________________________________________________________ batch_normalization_88 (BatchNo (None, 8, 8, 180) 720 concatenate_81[0][0] __________________________________________________________________________________________________ activation_89 (Activation) (None, 8, 8, 180) 0 batch_normalization_88[0][0] __________________________________________________________________________________________________ conv2d_91 (Conv2D) (None, 8, 8, 18) 29160 activation_89[0][0] __________________________________________________________________________________________________ concatenate_82 (Concatenate) (None, 8, 8, 198) 0 concatenate_81[0][0] conv2d_91[0][0] __________________________________________________________________________________________________ batch_normalization_89 (BatchNo (None, 8, 8, 198) 792 concatenate_82[0][0] __________________________________________________________________________________________________ activation_90 (Activation) (None, 8, 8, 198) 0 batch_normalization_89[0][0] __________________________________________________________________________________________________ conv2d_92 (Conv2D) (None, 8, 8, 18) 32076 activation_90[0][0] __________________________________________________________________________________________________ concatenate_83 (Concatenate) (None, 8, 8, 216) 0 concatenate_82[0][0] conv2d_92[0][0] __________________________________________________________________________________________________ batch_normalization_90 (BatchNo (None, 8, 8, 216) 864 concatenate_83[0][0] __________________________________________________________________________________________________ activation_91 (Activation) (None, 8, 8, 216) 0 batch_normalization_90[0][0] __________________________________________________________________________________________________ conv2d_93 (Conv2D) (None, 8, 8, 18) 34992 activation_91[0][0] __________________________________________________________________________________________________ concatenate_84 (Concatenate) (None, 8, 8, 234) 0 concatenate_83[0][0] conv2d_93[0][0] __________________________________________________________________________________________________ batch_normalization_91 (BatchNo (None, 8, 8, 234) 936 concatenate_84[0][0] __________________________________________________________________________________________________ activation_92 (Activation) (None, 8, 8, 234) 0 batch_normalization_91[0][0] __________________________________________________________________________________________________ conv2d_94 (Conv2D) (None, 8, 8, 18) 4212 activation_92[0][0] __________________________________________________________________________________________________ average_pooling2d_7 (AveragePoo (None, 4, 4, 18) 0 conv2d_94[0][0] __________________________________________________________________________________________________ batch_normalization_92 (BatchNo (None, 4, 4, 18) 72 average_pooling2d_7[0][0] __________________________________________________________________________________________________ activation_93 (Activation) (None, 4, 4, 18) 0 batch_normalization_92[0][0] 
__________________________________________________________________________________________________ conv2d_95 (Conv2D) (None, 4, 4, 18) 2916 activation_93[0][0] __________________________________________________________________________________________________ concatenate_85 (Concatenate) (None, 4, 4, 36) 0 average_pooling2d_7[0][0] conv2d_95[0][0] __________________________________________________________________________________________________ batch_normalization_93 (BatchNo (None, 4, 4, 36) 144 concatenate_85[0][0] __________________________________________________________________________________________________ activation_94 (Activation) (None, 4, 4, 36) 0 batch_normalization_93[0][0] __________________________________________________________________________________________________ conv2d_96 (Conv2D) (None, 4, 4, 18) 5832 activation_94[0][0] __________________________________________________________________________________________________ concatenate_86 (Concatenate) (None, 4, 4, 54) 0 concatenate_85[0][0] conv2d_96[0][0] __________________________________________________________________________________________________ batch_normalization_94 (BatchNo (None, 4, 4, 54) 216 concatenate_86[0][0] __________________________________________________________________________________________________ activation_95 (Activation) (None, 4, 4, 54) 0 batch_normalization_94[0][0] __________________________________________________________________________________________________ conv2d_97 (Conv2D) (None, 4, 4, 18) 8748 activation_95[0][0] __________________________________________________________________________________________________ concatenate_87 (Concatenate) (None, 4, 4, 72) 0 concatenate_86[0][0] conv2d_97[0][0] __________________________________________________________________________________________________ batch_normalization_95 (BatchNo (None, 4, 4, 72) 288 concatenate_87[0][0] __________________________________________________________________________________________________ activation_96 (Activation) (None, 4, 4, 72) 0 batch_normalization_95[0][0] __________________________________________________________________________________________________ conv2d_98 (Conv2D) (None, 4, 4, 18) 11664 activation_96[0][0] __________________________________________________________________________________________________ concatenate_88 (Concatenate) (None, 4, 4, 90) 0 concatenate_87[0][0] conv2d_98[0][0] __________________________________________________________________________________________________ batch_normalization_96 (BatchNo (None, 4, 4, 90) 360 concatenate_88[0][0] __________________________________________________________________________________________________ activation_97 (Activation) (None, 4, 4, 90) 0 batch_normalization_96[0][0] __________________________________________________________________________________________________ conv2d_99 (Conv2D) (None, 4, 4, 18) 14580 activation_97[0][0] __________________________________________________________________________________________________ concatenate_89 (Concatenate) (None, 4, 4, 108) 0 concatenate_88[0][0] conv2d_99[0][0] __________________________________________________________________________________________________ batch_normalization_97 (BatchNo (None, 4, 4, 108) 432 concatenate_89[0][0] __________________________________________________________________________________________________ activation_98 (Activation) (None, 4, 4, 108) 0 batch_normalization_97[0][0] 
__________________________________________________________________________________________________ conv2d_100 (Conv2D) (None, 4, 4, 18) 17496 activation_98[0][0] __________________________________________________________________________________________________ concatenate_90 (Concatenate) (None, 4, 4, 126) 0 concatenate_89[0][0] conv2d_100[0][0] __________________________________________________________________________________________________ batch_normalization_98 (BatchNo (None, 4, 4, 126) 504 concatenate_90[0][0] __________________________________________________________________________________________________ activation_99 (Activation) (None, 4, 4, 126) 0 batch_normalization_98[0][0] __________________________________________________________________________________________________ conv2d_101 (Conv2D) (None, 4, 4, 18) 20412 activation_99[0][0] __________________________________________________________________________________________________ concatenate_91 (Concatenate) (None, 4, 4, 144) 0 concatenate_90[0][0] conv2d_101[0][0] __________________________________________________________________________________________________ batch_normalization_99 (BatchNo (None, 4, 4, 144) 576 concatenate_91[0][0] __________________________________________________________________________________________________ activation_100 (Activation) (None, 4, 4, 144) 0 batch_normalization_99[0][0] __________________________________________________________________________________________________ conv2d_102 (Conv2D) (None, 4, 4, 18) 23328 activation_100[0][0] __________________________________________________________________________________________________ concatenate_92 (Concatenate) (None, 4, 4, 162) 0 concatenate_91[0][0] conv2d_102[0][0] __________________________________________________________________________________________________ batch_normalization_100 (BatchN (None, 4, 4, 162) 648 concatenate_92[0][0] __________________________________________________________________________________________________ activation_101 (Activation) (None, 4, 4, 162) 0 batch_normalization_100[0][0] __________________________________________________________________________________________________ conv2d_103 (Conv2D) (None, 4, 4, 18) 26244 activation_101[0][0] __________________________________________________________________________________________________ concatenate_93 (Concatenate) (None, 4, 4, 180) 0 concatenate_92[0][0] conv2d_103[0][0] __________________________________________________________________________________________________ batch_normalization_101 (BatchN (None, 4, 4, 180) 720 concatenate_93[0][0] __________________________________________________________________________________________________ activation_102 (Activation) (None, 4, 4, 180) 0 batch_normalization_101[0][0] __________________________________________________________________________________________________ conv2d_104 (Conv2D) (None, 4, 4, 18) 29160 activation_102[0][0] __________________________________________________________________________________________________ concatenate_94 (Concatenate) (None, 4, 4, 198) 0 concatenate_93[0][0] conv2d_104[0][0] __________________________________________________________________________________________________ batch_normalization_102 (BatchN (None, 4, 4, 198) 792 concatenate_94[0][0] __________________________________________________________________________________________________ activation_103 (Activation) (None, 4, 4, 198) 0 batch_normalization_102[0][0] 
__________________________________________________________________________________________________ conv2d_105 (Conv2D) (None, 4, 4, 18) 32076 activation_103[0][0] __________________________________________________________________________________________________ concatenate_95 (Concatenate) (None, 4, 4, 216) 0 concatenate_94[0][0] conv2d_105[0][0] __________________________________________________________________________________________________ batch_normalization_103 (BatchN (None, 4, 4, 216) 864 concatenate_95[0][0] __________________________________________________________________________________________________ activation_104 (Activation) (None, 4, 4, 216) 0 batch_normalization_103[0][0] __________________________________________________________________________________________________ conv2d_106 (Conv2D) (None, 4, 4, 18) 34992 activation_104[0][0] __________________________________________________________________________________________________ concatenate_96 (Concatenate) (None, 4, 4, 234) 0 concatenate_95[0][0] conv2d_106[0][0] __________________________________________________________________________________________________ batch_normalization_104 (BatchN (None, 4, 4, 234) 936 concatenate_96[0][0] __________________________________________________________________________________________________ activation_105 (Activation) (None, 4, 4, 234) 0 batch_normalization_104[0][0] __________________________________________________________________________________________________ average_pooling2d_8 (AveragePoo (None, 2, 2, 234) 0 activation_105[0][0] __________________________________________________________________________________________________ conv2d_107 (Conv2D) (None, 1, 1, 10) 9370 average_pooling2d_8[0][0] __________________________________________________________________________________________________ activation_106 (Activation) (None, 1, 1, 10) 0 conv2d_107[0][0] __________________________________________________________________________________________________ flatten_2 (Flatten) (None, 10) 0 activation_106[0][0] ================================================================================================== Total params: 962,214 Trainable params: 949,058 Non-trainable params: 13,156 __________________________________________________________________________________________________
# load previous weights, keyed by the clustered model's layer names
# (both models share the same layer ordering)
layer_weight_dic = {}
for layer_idx in range(len(model.layers)):
    layer_name_clustered = model_clustered.layers[layer_idx].name
    layer_weights_prev = model.layers[layer_idx].get_weights()
    layer_weight_dic[layer_name_clustered] = layer_weights_prev
Shift+Enter to run

Because the conv layers are densely connected within each dense block, we use a loop to update all the downstream layers affected by the clustering.

# update to new weights
layer_weight_dic['conv2d_55'] = [test_conv_1]
layer_weight_dic['conv2d_56'] = [test_conv_2]
layer_weight_dic['batch_normalization_53'] = test_bn_1
# layer names in the clustered model are offset from the original model's:
# conv2d_{k} here corresponds to conv2d_{k - 54} in the original, and
# batch_normalization_{k} to batch_normalization_{k - 52}
bn_count = 54
conv_count = 57
num_clusters = max(filter_clusters) + 1
for i in range(12):
    # total input-channel dimension expected by this conv layer
    ttl_dim = model_clustered.get_layer(f'conv2d_{conv_count}').weights[0].shape[2]
    # merge the conv layer's input channels; channels beyond the clustered
    # first layer keep their own (identity) clusters
    tmp_wt_conv = merge_filters(model.get_layer(f'conv2d_{conv_count - 54}').weights[0], 
                            np.arange(18),
                            np.concatenate((filter_clusters, np.arange(num_clusters, ttl_dim))), 
                            layer_type='conv')
    layer_weight_dic[f'conv2d_{conv_count}'] = [tmp_wt_conv]
    # merge the matching batch normalization layer the same way
    tmp_wt_bn = merge_filters(model.get_layer(f'batch_normalization_{bn_count - 52}').weights, 
                          np.concatenate((filter_clusters, np.arange(num_clusters, ttl_dim))), 
                          layer_type='bn')
    layer_weight_dic[f'batch_normalization_{bn_count}'] = tmp_wt_bn
    bn_count += 1
    conv_count += 1
Shift+Enter to run
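
Before loading, an optional sanity check (our addition, not a cell from the original notebook) can confirm that every prepared weight matches the shape its layer expects:

# Optional sanity check: every prepared weight should match the shape
# of the corresponding variable in the clustered model.
for layer in model_clustered.layers:
    expected = [K.int_shape(w) for w in layer.weights]
    prepared = [np.asarray(w).shape for w in layer_weight_dic[layer.name]]
    assert expected == prepared, f'shape mismatch in layer {layer.name}'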

Now we can load our updated weights with the correct dimensions.

# load new weights to model
for layer in model_clustered.layers:
    layer.set_weights(layer_weight_dic[layer.name])
Shift+Enter to run

And compile the clustered model.

# determine Loss function and Optimizer for the new model
model_clustered.compile(loss='categorical_crossentropy',
              optimizer=SGD(0.01, momentum = 0.7),
              metrics=['accuracy'])
Shift+Enter to run

4.Testing

4.1 Raw Results

The original DenseNet model gives an accuracy of 83%, while the initial result of the clustered gray-scale model is 73%, showing a significant loss of information during the process. The pipeline above preserves most of the image-pattern information, but there is another important layer type in DenseNet: batch normalization. These layers can be viewed as per-channel brightness and contrast adjustments. However, these features are not represented in the gradient ascent results, and they are hard to cluster. Therefore, these parameters were re-estimated by training only the batch normalization layers for one epoch, just enough to recover the proper contrast and brightness. A small illustration of this view of batch normalization follows.
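
As a minimal illustration of the contrast/brightness analogy (ours, not the notebook's own code), the inference-time batch normalization transform can be written out directly:

# Illustrative only: the inference-time BatchNorm transform, written out.
# gamma rescales each channel (a contrast adjustment) and beta shifts it
# (a brightness adjustment); merging filters changes the channel statistics,
# which is why these four vectors are re-estimated below.
def batchnorm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

x = np.linspace(-1, 1, 5)
print(batchnorm_inference(x, gamma=2.0, beta=0.5, moving_mean=0.0, moving_var=1.0))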

4.1.1 Original Model

_, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print('Original acc:', test_accuracy)
Shift+Enter to run
Original acc: 0.8320000171661377

4.1.2 Clustered Model

_, test_accuracy_clustered = model_clustered.evaluate(
  np.expand_dims(x_test[:, :, :, 0], axis=3), 
  y_test, 
  verbose=0)
print('Clustered acc:', test_accuracy_clustered)
Shift+Enter to run
Clustered acc: 0.7335000038146973

4.2 Only Train BN Layers

for layer in model_clustered.layers:
    if 'batch_normalization' not in layer.name:
        layer.trainable = False
# determine Loss function and Optimizer for the new model
model_clustered.compile(loss='categorical_crossentropy',
              optimizer=SGD(0.01, momentum = 0.7),
              metrics=['accuracy'])
Shift+Enter to run
model_clustered.fit(np.expand_dims(x_train[:, :, :, 0], axis=3), 
                    y_train, epochs = 1, verbose=0, 
                    validation_data = (np.expand_dims(x_test[:, :, :, 0], axis=3), y_test))
Shift+Enter to run
_, test_accuracy_clustered = model_clustered.evaluate(
  np.expand_dims(x_test[:, :, :, 0], axis=3), 
  y_test, 
  verbose=0)
print('Clustered acc after bn adjustment:', test_accuracy_clustered)
Shift+Enter to run
Clustered acc after bn adjustment: 0.8543000221252441

0.85 is the final accuracy of the new gray-scale model, slightly higher than the 0.83 achieved by the baseline color model.

Shortcomings

  1. Depending on the method used to merge filters, the information loss during the process can be significant; the sketch after this list shows one way to quantify it.

  2. Gradient ascent gives slightly different results every time, so uncertainty exists in the process.
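
As a rough way to quantify point 1 (our sketch, not part of the original notebook), one can measure the mean squared error between each original filter and the cluster centroid that replaces it:

# Average squared error per weight between the original filters and the
# cluster centroids that replace them; larger values mean more merging loss.
def merging_mse(weights, labels):
    # weights: array of shape (kh, kw, in_channels, n_filters)
    labels = np.asarray(labels)
    total = 0.0
    for cls in np.unique(labels):
        members = weights[..., labels == cls]
        centroid = members.mean(axis=-1, keepdims=True)
        total += np.sum((members - centroid) ** 2)
    return total / weights.size

w1 = np.asarray(model.get_layer('conv2d_1').get_weights()[0])
print('Merging MSE for conv2d_1:', merging_mse(w1, filter_clusters))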

Potential Improvements

  1. In order to reduce information loss, the method can be improved by tuning the clustering threshold (the DBSCAN eps) and by exploring other distance measures and clustering algorithms.

  2. In order to minimize the uncertainty due to gradient ascent, the visualization output can be averaged over multiple ascent runs, as in the sketch after this list.
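
A minimal sketch of improvement 2, assuming the TF1-style Keras backend used throughout this notebook (under TF2 eager execution one would use tf.GradientTape instead); filter_pattern is our hypothetical stand-in for the notebook's gradient-ascent visualization routine:

# Classic filter visualization: gradient ascent on the input image to
# maximize the mean activation of one filter.
def filter_pattern(model, layer_name, filter_idx, steps=30, step_size=1.0):
    layer_output = model.get_layer(layer_name).output
    loss = K.mean(layer_output[:, :, :, filter_idx])
    grads = K.gradients(loss, model.input)[0]
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)   # normalize the gradient
    iterate = K.function([model.input], [loss, grads])
    img = np.random.random((1, 32, 32, 3)) * 0.2 + 0.4  # gray starting image
    for _ in range(steps):
        _, g = iterate([img])
        img += step_size * g
    return img[0]

# average several runs to smooth out the randomness of the initialization
patterns = [filter_pattern(model, 'conv2d_1', 0) for _ in range(5)]
avg_pattern = np.mean(patterns, axis=0)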

Works Cited

[1] Luma (video). (2019, July 3). In Wikipedia. Retrieved from https://en.wikipedia.org/wiki/Luma_(video)

[2] 2.3. Clustering. (n.d.). In scikit-learn user guide. Retrieved from https://scikit-learn.org/stable/modules/clustering.html#dbscan

[3] Son, S., Nah, S., & Lee, K. M. (2018). Clustering convolutional kernels to compress deep neural networks. In Computer Vision – ECCV 2018, Lecture Notes in Computer Science (pp. 225–240). doi: 10.1007/978-3-030-01237-3_14

[4] llvll. (2016, January 19). llvll/imgcluster [GitHub repository]. Retrieved from https://github.com/llvll/imgcluster

[5] Wang, L., Zhang, Y., & Feng, J. (2005). On the Euclidean distance of images. IEEE Transactions on Pattern Analysis and Machine Intelligence, 27(8), 1334–1339. doi: 10.1109/tpami.2005.165
