Dealing with class imbalances
Disclaimer: Throughout this blog post I'll use hard negatives and hard examples interchangeably since the former can be generalized, and the latter is more widely known.
About two months ago, Kaggle launched the Kuzushiji Recognition challenge. Japan has millions of books and documents written in Kuzushiji, yet only 0.01% of modern Japanese people know how to read it. The challenge asks us to build an Optical Character Recognition (OCR) model that maps Kuzushiji characters to modern Japanese characters. It is therefore divided into two parts:
- Detection: the model should be able to locate where all of the written characters are inside a given image with high accuracy.
- Classification: create a one-to-one correspondence model in which each Kuzushiji character is mapped to its corresponding modern character.
The main problem I want to address in this post is the class imbalance present in the training dataset. The post focuses on object detection, but each method mentioned here can easily be carried over to classification.
For detection, the background class can be as much as 100,000 times larger than the foreground class in sliding-window models [1], while for region-based networks the background-to-foreground ratio is around 70:1. For classification, the number of instances of a Kuzushiji character ranges from 25,000 down to only 1, as shown in the following figure.
Now, the most important takeaway is that many real-world machine learning problems have to deal with imbalance: tumor detection, medical diagnosis, fraud detection, spam filtering, OCR, stock movement prediction. Every one of these challenges requires considerable knowledge of how to predict, and thus deal with, rare cases. Training a neural network (or any other type of model, for that matter) will not yield good results if we do not take the class-imbalanced nature of the dataset into account.
This is where hard negative mining comes in: it focuses on reducing the model's false positive rate, thereby correcting the model's bias toward one class over the other. An instance is considered "hard" if its loss exceeds a given threshold; conversely, an instance is "easy" if its loss is below that threshold.
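To make the threshold rule concrete, here is a minimal sketch. It assumes binary cross-entropy as the per-example loss and an arbitrary threshold of 0.5; both choices, and the helper names `binary_cross_entropy` and `split_hard_easy`, are illustrative assumptions rather than part of any specific method. Examples whose loss equals the threshold are counted as easy here, which is also just a convention.

```python
import numpy as np

def binary_cross_entropy(p, y, eps=1e-12):
    """Per-example binary cross-entropy; p is the predicted probability of the positive class."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    y = np.asarray(y, dtype=float)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def split_hard_easy(losses, threshold):
    """Return indices of hard (loss > threshold) and easy (loss <= threshold) examples."""
    losses = np.asarray(losses, dtype=float)
    hard = np.flatnonzero(losses > threshold)
    easy = np.flatnonzero(losses <= threshold)
    return hard, easy

# Four negatives (label 0); the last two are confidently misclassified as positive.
probs = [0.05, 0.10, 0.80, 0.95]
labels = [0, 0, 0, 0]
losses = binary_cross_entropy(probs, labels)
hard, easy = split_hard_easy(losses, threshold=0.5)
print("hard:", hard.tolist(), "easy:", easy.tolist())  # hard: [2, 3] easy: [0, 1]
```

The two confidently wrong negatives are exactly the false positives that hard negative mining is meant to surface.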
Hard negative mining has been applied in three areas of research, namely SVM optimization, boosted trees, and shallow and deep neural networks. After explaining bootstrapping, I will focus on hard mining techniques for neural networks, since neural networks have proven very successful ever since ImageNet.
Bootstrapping
The oldest trick in the book for reducing false positives is called bootstrapping. It was first introduced by Sung and Poggio [2] over 20 years ago to train a face detection model. One of the main problems when training a detector is deciding how many negative instances are needed to form a representative set that yields good results.
For example, deciding how many "face" (positive class) instances should be in the dataset is fairly easy, but to get a good representation of the "non-face" (negative class) instances, we would have to include every other object in the universe. Obviously, this is neither efficient nor feasible, since the dataset would become intractable.
As a consequence, every dataset is limited to a small, tractable representation of the negative class. Since our model is trained on this dataset, it is more likely to produce false positives on newly presented data.
The fewer negative instances the dataset contains, the more false positives the model produces. So how can we overcome this tradeoff? Enter bootstrapping, as presented by Sung and Poggio [2]:
1. Start with a small and possibly non-representative set of negative class (non-face) examples in the training set.
2. Train the model with the current training set.
3. Run the model on a sequence of images that don't contain the positive class (faces).
4. Collect all the instances that the model wrongly classified as part of the positive class and add them to the training set.
5. Return to step 2.
This cycle can be repeated until it converges, although convergence is not guaranteed. Shrivastava et al. [1] recommend running it once, while the original authors [2] ran it twice for their face detector.
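The loop can be sketched as follows. This is a minimal illustration, not Sung and Poggio's implementation: the nearest-centroid classifier (`train_centroid_classifier`, `predict`) is a hypothetical toy stand-in for a real detector, `bootstrap` is a hypothetical helper, and `negative_pool` stands for samples drawn from images known to contain no faces.

```python
import numpy as np

# Hypothetical toy detector: a nearest-centroid classifier standing in for a real model.
def train_centroid_classifier(X_pos, X_neg):
    """'Train' by computing the mean feature vector of each class."""
    return X_pos.mean(axis=0), X_neg.mean(axis=0)

def predict(model, X):
    """Return 1 (positive) where a sample is closer to the positive centroid, else 0."""
    pos_c, neg_c = model
    d_pos = np.linalg.norm(X - pos_c, axis=1)
    d_neg = np.linalg.norm(X - neg_c, axis=1)
    return (d_pos < d_neg).astype(int)

def bootstrap(X_pos, X_neg, negative_pool, rounds=2):
    """Bootstrapping loop in the spirit of Sung and Poggio.

    X_neg starts as a small, possibly non-representative negative set.
    negative_pool contains no positives, so every sample the model flags
    as positive is a false positive and is moved into the training set.
    """
    model = train_centroid_classifier(X_pos, X_neg)
    for _ in range(rounds):
        preds = predict(model, negative_pool)
        false_pos = negative_pool[preds == 1]
        if len(false_pos) == 0:
            break  # nothing left to mine
        X_neg = np.vstack([X_neg, false_pos])             # add false positives
        negative_pool = negative_pool[preds == 0]         # avoid re-adding them
        model = train_centroid_classifier(X_pos, X_neg)   # retrain on the grown set
    return model, X_neg

# Toy 1-D data (one feature per row).
X_pos = np.array([[5.0], [6.0]])
X_neg = np.array([[0.0], [1.0]])
pool = np.array([[-1.0], [2.0], [3.5], [4.0]])

initial = train_centroid_classifier(X_pos, X_neg)
model, X_neg_final = bootstrap(X_pos, X_neg, pool)
print("false positives before:", predict(initial, pool).sum())  # 2
print("false positives after:", predict(model, pool).sum())     # 1
print("negatives in training set:", len(X_neg_final))          # 4
```

Removing mined samples from the pool is a design choice of this sketch so that the same false positive is not appended twice across rounds.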
As a side note, since bootstrapping reduces the number of false positives, it may increase the number of false negatives. How much depends on how well the dataset represents the positive class.
Online Hard Example Mining
Bootstrapping has two inconveniences: i) we have to retrain the model for as many cycles as we deem necessary, and ii) between two cycles we have to freeze the model in order to append more data to our working set. This may not be a problem for SVMs, shallow networks, and boosted trees, but deep networks already take a long time to train. Freezing the model, adding data to the working set, and then resuming the training of a very deep network is a luxury we cannot afford due to time constraints.
Online Hard Example Mining (OHEM) was introduced by Shrivastava et al. [1] as a way to fetch hard examples while training the model, thus eliminating the "freezing phase" of bootstrapping. In their paper they proposed the following method to obtain hard examples in object detection:
- Calculate the loss of each RoI in the forward pass.
- Sort the RoIs by loss, so that the ones on which the model performs worst come first.
- Deduplicate: highly overlapping RoIs tend to have correlated losses, so apply non-maximum suppression using each RoI's loss as its score and a relaxed IoU threshold of 0.7.
- Keep the top k remaining RoIs and back-propagate only their losses through the model.
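The selection steps can be sketched in NumPy as follows. This is a framework-free sketch under stated assumptions: `iou` and `ohem_select` are hypothetical helpers, boxes are assumed to be in (x1, y1, x2, y2) format, and the per-RoI losses are assumed to have been computed already in the forward pass. In a deep learning framework, only the losses at the returned indices would be averaged and back-propagated.

```python
import numpy as np

def iou(box, boxes):
    """IoU between one box and an array of boxes, all as (x1, y1, x2, y2)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def ohem_select(rois, losses, k, iou_thresh=0.7):
    """Pick up to k hard RoIs: NMS scored by loss, keeping the top k survivors."""
    rois = np.asarray(rois, dtype=float)
    losses = np.asarray(losses, dtype=float)
    order = np.argsort(-losses, kind="stable")  # hardest (highest loss) first
    keep = []
    while order.size > 0 and len(keep) < k:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Drop lower-loss RoIs that overlap the kept one too much (correlated losses).
        order = rest[iou(rois[i], rois[rest]) <= iou_thresh]
    return keep

rois = [[0, 0, 10, 10], [0, 1, 10, 11], [20, 20, 30, 30], [40, 40, 50, 50]]
losses = [2.0, 1.9, 0.5, 1.0]
keep = ohem_select(rois, losses, k=2)
hard_loss = np.asarray(losses)[keep].mean()  # only this term would be back-propagated
print(keep, hard_loss)  # [0, 3] 1.5
```

RoI 1 has the second-highest loss, but it overlaps RoI 0 with an IoU of about 0.82, so it is suppressed as a near-duplicate and RoI 3 is selected instead.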
Focal loss
In the modern era of object detection, models fall into two categories:
- Two-stage detectors. The first stage generates candidate regions that should contain all instances of the positive classes (person, dog, car, etc.) while filtering out most of the negative class (background). The second stage then classifies each candidate into one of the foreground classes or the background. Examples of these models are R-CNN and Faster R-CNN.
- One-stage detectors. Without generating candidate regions, these models attempt to separate the foreground from the background in a "single shot". YOLO and SSD are the best-known one-stage detectors.
Each approach sits at one end of the speed/accuracy tradeoff: the former is more accurate, the latter faster. This may seem like an acceptable tradeoff depending on our needs, but in reality we would rather have a model that is both fast and accurate.
Two-stage detectors address the problem of extreme class imbalance by using cascades and other heuristic methods to generate candidate regions, so the second stage only has to process 1-2k locations. In contrast, one-stage detectors have to process ~100k locations, the great majority of which belong to the negative class. Lin et al. [3] argue that this imbalance is the reason why one-stage detectors have a 10-20% lower AP score than their two-stage counterparts.
Instead of relying on hard mining techniques like those described above, Lin et al. designed a loss function that down-weights easily classified examples. This keeps the "unnecessary" losses of easy examples from accumulating and eventually dominating the gradient, while still preserving information about hard examples; in other words, they found a way to separate easy examples from hard ones through the loss function itself.
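The Focal Loss, as they call it, is a weighted cross-entropy that is dynamically scaled by the modulating factor $(1 - p_t)^{\gamma}$, which goes to zero as the model's confidence in the correct class increases:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

where $p_t$ is the predicted probability of the ground-truth class, $\alpha_t$ is a class-balancing weight, and $\gamma \ge 0$ is the focusing parameter; with $\gamma = 0$ it reduces to the weighted cross-entropy.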
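A minimal NumPy sketch of the binary focal loss from [3], FL(p_t) = -α_t (1 - p_t)^γ log(p_t), follows. The function name `focal_loss` and the binary (sigmoid-probability) setting are assumptions for illustration; the defaults γ = 2 and α = 0.25 follow the settings reported in [3].

```python
import numpy as np

def focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-12):
    """Per-example binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).

    p: predicted probability of the positive class; y: labels in {0, 1}.
    With gamma = 0 this reduces to alpha-weighted cross-entropy.
    """
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    y = np.asarray(y)
    p_t = np.where(y == 1, p, 1.0 - p)              # probability of the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)

p = np.array([0.9, 0.1])  # one easy and one hard positive example
y = np.array([1, 1])
ce = -np.log(p)
fl = focal_loss(p, y, alpha=1.0, gamma=2.0)
print(fl / ce)  # approximately [0.01, 0.81]
```

With γ = 2, the easy example (p_t = 0.9) keeps only (1 - 0.9)^2 = 1% of its cross-entropy loss, while the hard example (p_t = 0.1) keeps 81% of it, which is the down-weighting of easy examples described above.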
Conclusions
Because class imbalance is so common in machine learning problems, researchers have come up with various methods for reducing a model's bias toward the dominant class over the rare one. In this blog post we have covered three of the most common practices, which have proven both efficient and capable of achieving state-of-the-art results:
- Bootstrapping: although it has been around for more than 20 years, it is still a good go-to method for training SVMs and other shallow detectors.
- Online Hard Example Mining: presented as a solution that eliminates the "freezing phase" of bootstrapping, it allows deep models to focus on hard examples while also reducing the reliance on heuristic methods.
- Focal loss: a loss function that prevents the accumulated losses of easy examples from overwhelming the gradient. By down-weighting easy examples, it focuses training on the hard ones, forcing the model to learn the patterns it still gets wrong.
Although most object detectors used nowadays are, thankfully, pretrained, it is still worth knowing these techniques, as they are applied across a wide range of fields, from medical diagnosis to fraud detection. Moreover, class imbalance is not going anywhere for the foreseeable future, so we had better be prepared to handle it as well as possible.
If you want to know more about ways of dealing with class imbalance, see the Further Reading section of this post.
Further reading
References
[1] A. Shrivastava, A. Gupta, and R. Girshick. Training Region-Based Object Detectors with Online Hard Example Mining. In CVPR, 2016.
[2] K.-K. Sung and T. Poggio. Learning and Example Selection for Object and Pattern Detection. MIT A.I. Memo No. 1521, 1994.
[3] T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal Loss for Dense Object Detection. arXiv:1708.02002, 2017.