Cassava Leaf Disease Classification 2/3
During these weeks, we've explored and implemented TFRecord loading, data augmentation, and several pre-trained models, comparing their accuracy to decide which one to use.
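As a rough sketch of the TFRecord side, the records can be parsed into a tf.data pipeline roughly as below; the file pattern, GCS_PATH placeholder, and the feature keys 'image' and 'target' are assumptions about the competition files, not our exact code:

import tensorflow as tf

# Assumed feature keys; the competition TFRecords may use different names.
feature_description = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'target': tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, feature_description)
    image = tf.io.decode_jpeg(example['image'], channels=3)
    image = tf.image.resize(image, [512, 512])
    return image, example['target']

filenames = tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/*.tfrec')  # GCS_PATH is a placeholder
train_ds = (tf.data.TFRecordDataset(filenames)
            .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.AUTOTUNE))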
We've used the following data augmentation methods:
tf.image.stateless_random_brightness
tf.image.stateless_random_contrast
tf.image.stateless_random_flip_left_right
tf.image.stateless_random_flip_up_down
tf.image.stateless_random_hue
tf.image.stateless_random_saturation
These functions can be found in the tf.image documentation. They are applied in sequence to randomly change the brightness, contrast, horizontal and vertical flip, hue, and saturation of our data. These are the kinds of variations you would get from photographing the same plant with different devices, under different sunlight, or from a different angle, so they should not affect the disease prediction. Therefore, these functions are used to expand the training dataset.
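As a rough sketch, the six stateless ops can be chained inside one mapping function; the delta and range values below are illustrative assumptions, not our tuned settings:

import tensorflow as tf

def augment(image, label, seed):
    # Derive an independent seed for each op from one base seed.
    seeds = tf.random.experimental.stateless_split(seed, num=6)
    image = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seeds[0])
    image = tf.image.stateless_random_contrast(image, lower=0.8, upper=1.2, seed=seeds[1])
    image = tf.image.stateless_random_flip_left_right(image, seed=seeds[2])
    image = tf.image.stateless_random_flip_up_down(image, seed=seeds[3])
    image = tf.image.stateless_random_hue(image, max_delta=0.1, seed=seeds[4])
    image = tf.image.stateless_random_saturation(image, lower=0.8, upper=1.2, seed=seeds[5])
    return image, label

In a tf.data pipeline, the per-example seed can come from a tf.random.Generator or a zipped counter dataset, so the augmentation stays deterministic for a given seed.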
We compared the validation accuracy of models built on InceptionV3, VGG16, ResNet50, DenseNet121, EfficientNetB2, and MobileNet, each trained for 5 epochs. ResNet gives the best accuracy so far, and EfficientNet converges the fastest, reaching 60% accuracy in the first epoch.
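A minimal sketch of how such a comparison loop might be wired up is below; the small shared head, the pooling choice, and the train_ds/val_ds dataset names are assumptions rather than our exact code:

# Sketch only: swap different ImageNet backbones under the same small head.
candidates = {
    'InceptionV3': (tf.keras.applications.InceptionV3,
                    tf.keras.applications.inception_v3.preprocess_input),
    'VGG16': (tf.keras.applications.VGG16,
              tf.keras.applications.vgg16.preprocess_input),
    'ResNet50': (tf.keras.applications.ResNet50,
                 tf.keras.applications.resnet50.preprocess_input),
    # ...DenseNet121, EfficientNetB2, and MobileNet are added the same way.
}

for name, (build_base, preprocess) in candidates.items():
    base = build_base(weights='imagenet', include_top=False, pooling='avg')
    base.trainable = False  # keep the ImageNet weights frozen
    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(preprocess, input_shape=[512, 512, 3]),
        base,
        tf.keras.layers.Dense(5, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    print(name)
    model.fit(train_ds, validation_data=val_ds, epochs=5)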
InceptionV3:
Epoch 1/5
104/104 [==============================] - 59s 441ms/step - loss: 4.6581 - sparse_categorical_accuracy: 0.4789 - val_loss: 4.4103 - val_sparse_categorical_accuracy: 0.5305
Epoch 2/5
104/104 [==============================] - 38s 368ms/step - loss: 2.9648 - sparse_categorical_accuracy: 0.5875 - val_loss: 3.0086 - val_sparse_categorical_accuracy: 0.6079
Epoch 3/5
104/104 [==============================] - 38s 368ms/step - loss: 2.0176 - sparse_categorical_accuracy: 0.5958 - val_loss: 1.6765 - val_sparse_categorical_accuracy: 0.5989
Epoch 4/5
104/104 [==============================] - 38s 365ms/step - loss: 1.5957 - sparse_categorical_accuracy: 0.6040 - val_loss: 2.2697 - val_sparse_categorical_accuracy: 0.6079
Epoch 5/5
104/104 [==============================] - 38s 364ms/step - loss: 1.3869 - sparse_categorical_accuracy: 0.6098 - val_loss: 1.5683 - val_sparse_categorical_accuracy: 0.6095
VGG16:
Epoch 1/5
104/104 [==============================] - 66s 504ms/step - loss: 4.8665 - sparse_categorical_accuracy: 0.4279 - val_loss: 4.0974 - val_sparse_categorical_accuracy: 0.1968
Epoch 2/5
104/104 [==============================] - 38s 364ms/step - loss: 3.0005 - sparse_categorical_accuracy: 0.5737 - val_loss: 4.3169 - val_sparse_categorical_accuracy: 0.6082
Epoch 3/5
104/104 [==============================] - 38s 367ms/step - loss: 2.0286 - sparse_categorical_accuracy: 0.6045 - val_loss: 1.9484 - val_sparse_categorical_accuracy: 0.6094
Epoch 4/5
104/104 [==============================] - 38s 367ms/step - loss: 1.5965 - sparse_categorical_accuracy: 0.6039 - val_loss: 2.2603 - val_sparse_categorical_accuracy: 0.1660
Epoch 5/5
104/104 [==============================] - 38s 365ms/step - loss: 1.3764 - sparse_categorical_accuracy: 0.6169 - val_loss: 1.2986 - val_sparse_categorical_accuracy: 0.6091
ResNet:
Epoch 1/5
104/104 [==============================] - 82s 588ms/step - loss: 4.6610 - sparse_categorical_accuracy: 0.4901 - val_loss: 7.3894 - val_sparse_categorical_accuracy: 0.5772
Epoch 2/5
104/104 [==============================] - 50s 483ms/step - loss: 3.0079 - sparse_categorical_accuracy: 0.5809 - val_loss: 2.2722 - val_sparse_categorical_accuracy: 0.6222
Epoch 3/5
104/104 [==============================] - 51s 488ms/step - loss: 2.0805 - sparse_categorical_accuracy: 0.6077 - val_loss: 1.6716 - val_sparse_categorical_accuracy: 0.6275
Epoch 4/5
104/104 [==============================] - 51s 486ms/step - loss: 1.5703 - sparse_categorical_accuracy: 0.6273 - val_loss: 1.4332 - val_sparse_categorical_accuracy: 0.6254
Epoch 5/5
104/104 [==============================] - 50s 484ms/step - loss: 1.2718 - sparse_categorical_accuracy: 0.6487 - val_loss: 1.2736 - val_sparse_categorical_accuracy: 0.6356
DenseNet121:
Epoch 1/5
104/104 [==============================] - 59s 445ms/step - loss: 4.6823 - sparse_categorical_accuracy: 0.4783 - val_loss: 18.6198 - val_sparse_categorical_accuracy: 0.1026
Epoch 2/5
104/104 [==============================] - 38s 367ms/step - loss: 3.0006 - sparse_categorical_accuracy: 0.5755 - val_loss: 2.7586 - val_sparse_categorical_accuracy: 0.6076
Epoch 3/5
104/104 [==============================] - 39s 375ms/step - loss: 2.0429 - sparse_categorical_accuracy: 0.5958 - val_loss: 3.7157 - val_sparse_categorical_accuracy: 0.6079
Epoch 4/5
104/104 [==============================] - 38s 368ms/step - loss: 1.6029 - sparse_categorical_accuracy: 0.6073 - val_loss: 3.2974 - val_sparse_categorical_accuracy: 0.6079
Epoch 5/5
104/104 [==============================] - 38s 370ms/step - loss: 1.3865 - sparse_categorical_accuracy: 0.6135 - val_loss: 1.5344 - val_sparse_categorical_accuracy: 0.6084
Our best simple model, built on ResNet50, reaches a validation accuracy of 0.63. Here is its structure:
with strategy.scope():
    # strategy is the tf.distribute strategy (e.g. TPUStrategy) defined earlier.
    # Preprocess inputs the way ResNet50 expects.
    img_adjust_layer = tf.keras.layers.Lambda(
        tf.keras.applications.resnet50.preprocess_input,
        input_shape=[512, 512, 3])

    # ImageNet-pretrained ResNet50 backbone, frozen as a feature extractor.
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)
    base_model.trainable = False

    model_res = tf.keras.Sequential([
        tf.keras.layers.BatchNormalization(renorm=True),
        img_adjust_layer,
        base_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(rate=0.25),
        tf.keras.layers.Dense(256, activation='relu',
                              kernel_regularizer=tf.keras.regularizers.l2(0.01)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(rate=0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(rate=0.25),
        tf.keras.layers.Dense(5, activation='softmax')  # 4 diseases + healthy
    ])

    model_res.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
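For reference, training this model for 5 epochs on the tf.data pipelines would look roughly like the line below (the train_ds/val_ds names are assumptions):

history = model_res.fit(train_ds, validation_data=val_ds, epochs=5)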
We tried a grid search for hyperparameter tuning, but the session crashed with full RAM while running it. So we kept our simple model for now; the loss and accuracy plot for ResNet is shown below.
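One workaround we may try is a small manual grid that clears the Keras session between trials to free memory. A rough sketch is below; build_resnet is a hypothetical helper that rebuilds the model above with the given settings, and the grid values are illustrative:

import gc
import tensorflow as tf

results = {}
for lr in [1e-3, 1e-4]:
    for dropout in [0.2, 0.4]:
        tf.keras.backend.clear_session()  # drop the previous trial's graph to free memory
        gc.collect()
        model = build_resnet(learning_rate=lr, dropout_rate=dropout)  # hypothetical helper
        history = model.fit(train_ds, validation_data=val_ds, epochs=5)
        results[(lr, dropout)] = max(history.history['val_sparse_categorical_accuracy'])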
Clearly, 5 epochs are not enough for this training: training for 8 epochs already brings the validation accuracy to 0.6483. Our next step will be to finalize our model by finding the best set of hyperparameters, including the number of epochs.
We've also checked what this model actually predicts because, as we discussed in part one, more than 60% of the training dataset belongs to class 3. We want to make sure our model is not simply predicting class 3 for everything, which would also end up around 60% accuracy.
Here is the result:
So there are a couple hundred predictions in the other classes; the model is not predicting class 3 for everything.
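For reference, the per-class prediction counts can be tallied with a few lines like these (the validation dataset name is an assumption):

import numpy as np

preds = np.argmax(model_res.predict(val_ds), axis=1)   # predicted class index per image
classes, counts = np.unique(preds, return_counts=True)
print(dict(zip(classes, counts)))                       # counts for classes 0-4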
Next Steps:
Because we've only trained our models for 5 epochs without any hyperparameter tuning, we will continue building up our ResNet model to improve accuracy and reduce loss, though RAM crashes have been a problem while grid searching.
Additionally, as Dan pointed out, the images are not consistent. The data includes close-up pictures of a single leaf, pictures of the whole plant, and images of small seedlings on the ground. They may need to be classified by different models based on their image gradients.
Another option for possibly improving performance is to use ImageDataGenerator to zoom, shear, or rotate our images, since these aspects of the images should not matter either, and see whether it improves the generalization accuracy of our model.
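A rough sketch of such a generator, with illustrative ranges rather than tuned values:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=30,    # rotate up to 30 degrees
    shear_range=0.2,      # shear transform intensity
    zoom_range=0.2,       # zoom in/out by up to 20%
    horizontal_flip=True,
    vertical_flip=True)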
Kaggle Submission:
Links:
References:
TFRecords: Learn to Use TensorFlow #1 Helpful File Format
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
ImageNet: VGGNet, ResNet, Inception, and Xception with Keras
Authors:
Yue Wang, Tianqi Tang.
DATA 2040