Cassava Leaf Disease Classification 2/3
During these weeks, we've explored and implemented TFRecord loading, data augmentation, and several pre-trained models in order to compare their accuracy and decide which one to use.
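The images and labels are read from TFRecord files with tf.data. Below is a minimal sketch of what such a pipeline can look like; the feature names ('image', 'target') and the 512x512 image size are assumptions for illustration and may differ from the exact competition schema.

import tensorflow as tf

# Assumed schema: a JPEG-encoded image and an integer disease label (0-4)
FEATURE_SPEC = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'target': tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_jpeg(example['image'], channels=3)
    image = tf.image.resize(image, [512, 512])
    return image, example['target']

def load_dataset(filenames, batch_size=32):
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)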
We've used the following data augmentation methods:
tf.image.stateless_random_brightness
tf.image.stateless_random_contrast
tf.image.stateless_random_flip_left_right
tf.image.stateless_random_flip_up_down
tf.image.stateless_random_hue
tf.image.stateless_random_saturation

These functions can be found in the tf.image documentation. They are applied in sequence to change the brightness, contrast, horizontal and vertical orientation, hue, and saturation of our data. These are the kinds of variations you would expect from pictures of the same plant taken by different devices, under different sunlight, or from different angles, so they should not affect the disease prediction. Therefore, these functions are used to expand the training dataset.
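For reference, here is a rough sketch of how these stateless ops can be chained inside a tf.data map function. The seeds and augmentation strengths below are placeholders, not our exact settings.

def augment(image, label, seed):
    # 'seed' is a shape-[2] integer tensor; in practice a fresh seed per image
    # (e.g. via tf.random.experimental.stateless_split) keeps augmentations varied.
    image = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seed)
    image = tf.image.stateless_random_contrast(image, lower=0.8, upper=1.2, seed=seed)
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    image = tf.image.stateless_random_flip_up_down(image, seed=seed)
    image = tf.image.stateless_random_hue(image, max_delta=0.1, seed=seed)
    image = tf.image.stateless_random_saturation(image, lower=0.8, upper=1.2, seed=seed)
    return image, label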
We compared the validation accuracy of models built on InceptionV3, VGG16, ResNet50, DenseNet121, EfficientNetB2, and MobileNet over 5 epochs. ResNet gives the best accuracy so far, and EfficientNet converges the fastest, reaching 60% accuracy in the first epoch.
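Each backbone was plugged into the same kind of classification head so the comparison is fair. A simplified, hypothetical sketch of that comparison loop is below; the head is smaller than our actual model, and train_ds / val_ds stand in for our tf.data pipelines.

backbones = {
    'InceptionV3': tf.keras.applications.InceptionV3,
    'VGG16': tf.keras.applications.VGG16,
    'ResNet50': tf.keras.applications.ResNet50,
    'DenseNet121': tf.keras.applications.DenseNet121,
    'EfficientNetB2': tf.keras.applications.EfficientNetB2,
    'MobileNet': tf.keras.applications.MobileNet,
}

def build_comparison_model(backbone_fn):
    # Note: each backbone also has its own preprocess_input (omitted here for brevity)
    base = backbone_fn(weights='imagenet', include_top=False,
                       input_shape=(512, 512, 3), pooling='avg')
    base.trainable = False  # only train the new classification head
    model = tf.keras.Sequential([
        base,
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(5, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    return model

for name, backbone_fn in backbones.items():
    print(name)
    model = build_comparison_model(backbone_fn)
    model.fit(train_ds, validation_data=val_ds, epochs=5)

The per-epoch training logs are shown below.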
InceptionV3:
Epoch 1/5
104/104 [==============================] - 59s 441ms/step - loss: 4.6581 - sparse_categorical_accuracy: 0.4789 - val_loss: 4.4103 - val_sparse_categorical_accuracy: 0.5305
Epoch 2/5
104/104 [==============================] - 38s 368ms/step - loss: 2.9648 - sparse_categorical_accuracy: 0.5875 - val_loss: 3.0086 - val_sparse_categorical_accuracy: 0.6079
Epoch 3/5
104/104 [==============================] - 38s 368ms/step - loss: 2.0176 - sparse_categorical_accuracy: 0.5958 - val_loss: 1.6765 - val_sparse_categorical_accuracy: 0.5989
Epoch 4/5
104/104 [==============================] - 38s 365ms/step - loss: 1.5957 - sparse_categorical_accuracy: 0.6040 - val_loss: 2.2697 - val_sparse_categorical_accuracy: 0.6079
Epoch 5/5
104/104 [==============================] - 38s 364ms/step - loss: 1.3869 - sparse_categorical_accuracy: 0.6098 - val_loss: 1.5683 - val_sparse_categorical_accuracy: 0.6095

VGG16:
Epoch 1/5
104/104 [==============================] - 66s 504ms/step - loss: 4.8665 - sparse_categorical_accuracy: 0.4279 - val_loss: 4.0974 - val_sparse_categorical_accuracy: 0.1968
Epoch 2/5
104/104 [==============================] - 38s 364ms/step - loss: 3.0005 - sparse_categorical_accuracy: 0.5737 - val_loss: 4.3169 - val_sparse_categorical_accuracy: 0.6082
Epoch 3/5
104/104 [==============================] - 38s 367ms/step - loss: 2.0286 - sparse_categorical_accuracy: 0.6045 - val_loss: 1.9484 - val_sparse_categorical_accuracy: 0.6094
Epoch 4/5
104/104 [==============================] - 38s 367ms/step - loss: 1.5965 - sparse_categorical_accuracy: 0.6039 - val_loss: 2.2603 - val_sparse_categorical_accuracy: 0.1660
Epoch 5/5
104/104 [==============================] - 38s 365ms/step - loss: 1.3764 - sparse_categorical_accuracy: 0.6169 - val_loss: 1.2986 - val_sparse_categorical_accuracy: 0.6091

ResNet:
Epoch 1/5
104/104 [==============================] - 82s 588ms/step - loss: 4.6610 - sparse_categorical_accuracy: 0.4901 - val_loss: 7.3894 - val_sparse_categorical_accuracy: 0.5772
Epoch 2/5
104/104 [==============================] - 50s 483ms/step - loss: 3.0079 - sparse_categorical_accuracy: 0.5809 - val_loss: 2.2722 - val_sparse_categorical_accuracy: 0.6222
Epoch 3/5
104/104 [==============================] - 51s 488ms/step - loss: 2.0805 - sparse_categorical_accuracy: 0.6077 - val_loss: 1.6716 - val_sparse_categorical_accuracy: 0.6275
Epoch 4/5
104/104 [==============================] - 51s 486ms/step - loss: 1.5703 - sparse_categorical_accuracy: 0.6273 - val_loss: 1.4332 - val_sparse_categorical_accuracy: 0.6254
Epoch 5/5
104/104 [==============================] - 50s 484ms/step - loss: 1.2718 - sparse_categorical_accuracy: 0.6487 - val_loss: 1.2736 - val_sparse_categorical_accuracy: 0.6356

DenseNet121:
Epoch 1/5
104/104 [==============================] - 59s 445ms/step - loss: 4.6823 - sparse_categorical_accuracy: 0.4783 - val_loss: 18.6198 - val_sparse_categorical_accuracy: 0.1026
Epoch 2/5
104/104 [==============================] - 38s 367ms/step - loss: 3.0006 - sparse_categorical_accuracy: 0.5755 - val_loss: 2.7586 - val_sparse_categorical_accuracy: 0.6076
Epoch 3/5
104/104 [==============================] - 39s 375ms/step - loss: 2.0429 - sparse_categorical_accuracy: 0.5958 - val_loss: 3.7157 - val_sparse_categorical_accuracy: 0.6079
Epoch 4/5
104/104 [==============================] - 38s 368ms/step - loss: 1.6029 - sparse_categorical_accuracy: 0.6073 - val_loss: 3.2974 - val_sparse_categorical_accuracy: 0.6079
Epoch 5/5
104/104 [==============================] - 38s 370ms/step - loss: 1.3865 - sparse_categorical_accuracy: 0.6135 - val_loss: 1.5344 - val_sparse_categorical_accuracy: 0.6084

Our best simple model has an accuracy of 0.63 on the validation set with ResNet. Here is its structure:
with strategy.scope():
    # Apply ResNet50's preprocessing inside the model itself
    img_adjust_layer = tf.keras.layers.Lambda(
        tf.keras.applications.resnet50.preprocess_input,
        input_shape=[512, 512, 3])

    # Pre-trained backbone, frozen so that only the new head is trained
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)
    base_model.trainable = False

    model_res = tf.keras.Sequential([
        tf.keras.layers.BatchNormalization(renorm=True),
        img_adjust_layer,
        base_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(rate=0.25),
        tf.keras.layers.Dense(256, activation='relu',
                              kernel_regularizer=tf.keras.regularizers.l2(0.01)),
        tf.keras.layers.BatchNormalization(),
        # Add a dropout rate of 0.2
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(rate=0.25),
        tf.keras.layers.Dense(5, activation='softmax')
    ])

    model_res.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])

We tried to do a grid search for hyperparameter tuning, but the session crashed with full RAM while running it, so we kept this simple model for now. The loss and accuracy plot for ResNet is shown below.
Clearly, 5 epochs are not enough for this training: training for 8 epochs already brings the validation accuracy to 0.6483. Our next step will be to finalize our model by finding the best set of hyperparameters, including the number of epochs.
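One way to let the data pick the number of epochs, instead of fixing it by hand, is early stopping. A minimal sketch is below; the patience value and the train_ds / val_ds names are assumptions.

# Stop once validation accuracy hasn't improved for 3 epochs,
# and roll back to the weights from the best epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    patience=3,
    restore_best_weights=True)

history = model_res.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,                 # upper bound; early stopping decides when to quit
    callbacks=[early_stop])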
We've also checked the predictions of this model because, as we discussed in part one, more than 60% of the training dataset belongs to category 3. We want to make sure our model is not simply predicting class 3 for everything, which would also end up at about 60% accuracy.
Here is the result:
So a couple hundred images are predicted as other classes; the model is not predicting class 3 for everything.
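For reference, this check amounts to counting the predicted classes on the validation set. A minimal sketch, with val_ds standing in for our validation pipeline:

import numpy as np

# Predict on the validation set and count how often each class appears
probs = model_res.predict(val_ds)
pred_classes = np.argmax(probs, axis=1)
classes, counts = np.unique(pred_classes, return_counts=True)
print(dict(zip(classes.tolist(), counts.tolist())))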
Next Steps:
Because we've only trained our models for 5 epochs without any hyperparameter tuning, we will continue building up our ResNet model to improve accuracy and minimize loss, though running out of RAM has been a problem while grid searching.
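Keras Tuner is not part of our current notebook, but a random search that builds and discards one model per trial may be gentler on RAM than a full grid search. A rough sketch under that assumption (hyperparameter ranges are placeholders, and train_ds / val_ds stand in for our pipelines):

import keras_tuner as kt

def build_model(hp):
    base = tf.keras.applications.ResNet50(weights='imagenet', include_top=False,
                                          input_shape=(512, 512, 3), pooling='avg')
    base.trainable = False
    model = tf.keras.Sequential([
        base,
        tf.keras.layers.Dense(hp.Choice('units', [128, 256, 512]), activation='relu'),
        tf.keras.layers.Dropout(hp.Float('dropout', 0.1, 0.4)),
        tf.keras.layers.Dense(5, activation='softmax'),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(hp.Choice('lr', [1e-2, 1e-3, 1e-4])),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
    return model

tuner = kt.RandomSearch(
    build_model,
    objective=kt.Objective('val_sparse_categorical_accuracy', direction='max'),
    max_trials=10)
tuner.search(train_ds, validation_data=val_ds, epochs=5)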
Additionally, as Dan pointed out, the images are not consistent. The data includes close-up pictures of a single leaf, pictures of the whole plant, and images of a small seedling on the ground. Maybe these need to be routed to different models based on their image gradients.
Another option that could possibly improve performance is to use ImageDataGenerator to zoom, shear, or rotate our images, since these aspects should not affect the disease label either, and see whether it improves the generalization of our model.
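A hedged sketch of what that could look like; the ranges below are placeholders, not tuned values.

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Geometric augmentations that should not change the disease label
datagen = ImageDataGenerator(
    rotation_range=20,      # rotate up to 20 degrees
    zoom_range=0.15,        # zoom in/out by up to 15%
    shear_range=0.1,        # shear transform
    horizontal_flip=True,
    vertical_flip=True)

# Example usage with an in-memory array of images and labels:
# model_res.fit(datagen.flow(x_train, y_train, batch_size=32),
#               validation_data=val_ds, epochs=10)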
Kaggle Submission:
Links:
References:
TFRecords: Learn to Use TensorFlow # 1 Helpful File Format
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
ImageNet: VGGNet, ResNet, Inception, and Xception with Keras
Authors:
Yue Wang, Tianqi Tang.
DATA 2040