SupportNet

SupportNet: a novel incremental learning framework through deep learning and support data

This repository contains the implementation of SupportNet, which addresses the catastrophic forgetting problem efficiently and effectively. A plain, well-trained deep learning model often cannot learn new knowledge without forgetting previously learned knowledge, a failure known as catastrophic forgetting. Here we propose a novel method, SupportNet, to solve this problem in the class incremental learning scenario. SupportNet combines the strengths of deep learning and the support vector machine (SVM): the SVM identifies the support data among the old data, and these support data are fed to the deep learning model together with the new data for further training, so that the model can review the essential information of the old data while learning the new information. Two powerful consolidation regularizers are applied to stabilize the learned representation and ensure the robustness of the learned model. We validate our method both theoretically and empirically with comprehensive experiments on various tasks, which show that SupportNet drastically outperforms the state-of-the-art incremental learning methods and even reaches performance similar to that of a deep learning model trained from scratch on both old and new data.

Paper

https://arxiv.org/abs/1806.02942

Datasets

  1. MNIST
  2. CIFAR-10 and CIFAR-100
  3. The EC dataset: http://www.cbrc.kaust.edu.sa/DEEPre/dataset.html (This file contains the original sequence data and the labels. The pickle files, which are preprocessed from the original sequence data and are the feature files ready for use in the scripts, can be provided on request. We are sorry that we cannot fully release them yet, since the paper has not been officially published. The feature files will be released after the paper is published.)
  4. The HeLa dataset: http://murphylab.web.cmu.edu/data/2DHeLa
  5. The BreakHis dataset: https://web.inf.ufpr.br/vri/breast-cancer-database/
  6. Tiny ImageNet dataset: https://tiny-imagenet.herokuapp.com/

Prerequisites

  1. Tensorflow (https://www.tensorflow.org/)
  2. TFLearn (tflearn.org/)
  3. CUDA (https://developer.nvidia.com/cuda-downloads)
  4. cuDNN (https://developer.nvidia.com/cudnn)
  5. sklearn (scikit-learn.org/)
  6. numpy (www.numpy.org/)
  7. Jupyter notebook (jupyter.org/)

Source Code and Experimental Records

For EC number dataset

The code is in the folder src_ec. The whole program can be run by executing main.sh, which uses supportnet.py, the complete implementation of SupportNet. icarl_level_1.py is our implementation of iCaRL on this specific dataset. The other files are temporary test files or libraries.

The experimental results were recorded in level_1_result.md.

For CIFAR-10, CIFAR-100, HeLa and BreakHis Datasets

The code and the results are in the folder src_image_data. Everything is written in Jupyter notebooks, so all code and results are recorded together.

Incremental Learning

Illustration of class incremental learning. After we train a base model using all the available data at a certain time point (e.g., classes $1-N_1$), new data belonging to new classes may continuously appear (e.g., classes $N_2-N_3$, classes $N_4-N_5$, etc) and we need to equip the model with the ability to handle the new classes.
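The scenario above can be sketched as a split of the label space into class batches that arrive one after another. This is only a minimal illustration: the function name `split_into_class_batches`, the `classes_per_batch` parameter, and the toy labels are assumptions made for this example and are not taken from the repository code.

```python
import numpy as np


def split_into_class_batches(labels, classes_per_batch):
    """Group sample indices by the class batch their class belongs to.

    Classes are taken in sorted order and cut into consecutive groups of
    `classes_per_batch` classes (the last group may be smaller).
    """
    classes = np.unique(labels)  # sorted unique class ids
    batches = []
    for start in range(0, len(classes), classes_per_batch):
        batch_classes = classes[start:start + classes_per_batch]
        # Indices of all samples whose label falls in this class batch.
        idx = np.flatnonzero(np.isin(labels, batch_classes))
        batches.append((batch_classes, idx))
    return batches


# Toy labels for six classes (hypothetical data).
labels = np.array([0, 1, 2, 3, 0, 1, 2, 3, 4, 5])
batches = split_into_class_batches(labels, classes_per_batch=2)

for step, (batch_classes, idx) in enumerate(batches):
    print(f"step {step}: classes {batch_classes.tolist()}, samples {idx.tolist()}")
```

The first batch plays the role of the base model's training data; each later batch would be presented to the model as a new incremental step.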

Catastrophic Forgetting

The confusion matrices obtained by incrementally training a deep learning model in the class incremental learning scenario using different methods. (A) Random guess, (B) fine-tune (only fine-tune the model with the newest data), (C) iCaRL, (D) SupportNet. (B) illustrates the problem of catastrophic forgetting: if we only use the newest data to further train the model, the model can no longer handle the old classes.
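To make the forgetting pattern in (B) concrete, the sketch below builds a confusion matrix for a hypothetical model that, after fine-tuning on the newest classes only, predicts nothing but those classes. The predictions are invented toy values for illustration, not results from the paper, and the helper `confusion_matrix` here is a small NumPy function written for this example.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Count matrix with true classes as rows and predicted classes as columns."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when index pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


# Classes 0 and 1 are old; classes 2 and 3 are the newest batch.
y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
# Hypothetical fine-tuned model: every old sample is assigned to a new class.
y_pred = np.array([2, 3, 2, 2, 2, 2, 3, 3])

cm = confusion_matrix(y_true, y_pred, num_classes=4)
print(cm)

# The top-left block (old classes predicted as old classes) is empty,
# which is the signature of catastrophic forgetting.
old_block = cm[:2, :2]
print("correct or confused-within-old predictions:", int(old_block.sum()))
```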

Main framework

Overview of our framework. The basic idea is to incrementally train a deep learning model efficiently using the new class data and the support data of the old classes. We divide the deep learning model into two parts, the mapping function (all the layers before the last layer) and the softmax layer (the last layer). Using the learned representation produced by the mapping function, we train an SVM, with which we can find the support vector index and thus the support data of old classes. To stabilize the learned representation of old data, we apply two novel consolidation regularizers to the network.
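The support data selection step can be sketched with scikit-learn as follows. This is a simplified illustration under stated assumptions, not the code in supportnet.py: the linear kernel, `C=1.0`, the function names `select_support_data` and `build_training_set`, and the synthetic 2-D features are all choices made for this example. In the toy data the raw inputs and the learned features are the same array, whereas in the framework the features would come from the mapping function. The consolidation regularizers and any cap on the support data size are not shown.

```python
import numpy as np
from sklearn.svm import SVC


def select_support_data(features, labels):
    """Return indices of the samples that an SVM picks as support vectors.

    `features` stands in for the output of the mapping function (all layers
    before the last one). Kernel and C are illustrative assumptions.
    """
    svm = SVC(kernel="linear", C=1.0)
    svm.fit(features, labels)
    return svm.support_  # indices into `features` of the support vectors


def build_training_set(old_x, old_y, support_idx, new_x, new_y):
    """Concatenate the support data of old classes with the new-class data."""
    x = np.concatenate([old_x[support_idx], new_x], axis=0)
    y = np.concatenate([old_y[support_idx], new_y], axis=0)
    return x, y


# Toy stand-in for learned representations of two old classes.
rng = np.random.default_rng(0)
old_x = np.concatenate([rng.normal(-3.0, 1.0, (100, 2)),
                        rng.normal(3.0, 1.0, (100, 2))])
old_y = np.array([0] * 100 + [1] * 100)

support_idx = select_support_data(old_x, old_y)

# Hypothetical new class that arrives in the next incremental step.
new_x = rng.normal(0.0, 1.0, (50, 2)) + np.array([8.0, -8.0])
new_y = np.full(50, 2)

train_x, train_y = build_training_set(old_x, old_y, support_idx, new_x, new_y)
print(f"kept {len(support_idx)} of {len(old_x)} old samples as support data")
```

Only the samples indexed by `support_idx` need to be stored from the old classes; the rest of the old data is not used in the next training step.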

Main result

Main results. (A)-(E): Performance comparison between SupportNet and five competing methods on the five datasets in terms of accuracy. For SupportNet and iCaRL, we set the support data (exemplar) size to 2000 for MNIST, CIFAR-10 and the enzyme data, 80 for the HeLa dataset, and 1600 for the breast tumor dataset. (F): The accuracy deviation of SupportNet from the "All Data" method with respect to the size of the support data. The x-axis shows the support data size. The y-axis is the test accuracy deviation of SupportNet from the "All Data" method after incrementally learning all the classes of the HeLa subcellular structure dataset.

Notice that for the MNIST data, we reach almost the same performance as using all the data during each incremental learning iteration. That is, with our framework, we only need 2000 data points to reach the same performance level as using 50,000 data points on that dataset.

More results (please refer to the manuscript for even more results)

In this section, we investigate the composition of SupportNet's performance on MNIST shown in the main result figure (A). That figure only shows the overall performance of the different methods on all the test data, averaging the performance on the old test data and the new test data, which hides how each method performs on the old data. To avoid that, we further check the performance of the different methods on the old data and the new data separately; the results are shown in the figure above.

As shown in (B), iCaRL maintains its performance on the oldest class batch very well, but it is unable to maintain its performance on the intermediate class batches. GEM (A) outperforms iCaRL on the intermediate class batches, but it cannot maintain its performance on the oldest class batch. VCL (C) further outperforms GEM on the intermediate class batches, but it suffers from the same problem as GEM: it is unable to preserve the performance on the oldest class batch.

In contrast, both VCL with K-center Coreset and SupportNet maintain their performance on the old classes almost perfectly, for the intermediate class batches as well as the oldest class batch. However, the two algorithms make different trade-offs. VCL with K-center Coreset maintains the performance on old classes almost exactly, at the cost of the newest model being unable to achieve optimal performance on the newest class. SupportNet allows slight performance degradation on the old classes, while it achieves optimal performance on the newest class batch.
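The per-batch breakdown described above can be computed by grouping test samples by the class batch of their true label. The sketch below is a minimal NumPy version; the function name `accuracy_per_class_batch` and the toy labels and predictions are assumptions made for this example and are not results from the paper.

```python
import numpy as np


def accuracy_per_class_batch(y_true, y_pred, class_batches):
    """Accuracy computed separately on the test samples of each class batch.

    Every batch is assumed to contain at least one test sample; an empty
    batch would yield NaN.
    """
    accuracies = []
    for batch_classes in class_batches:
        mask = np.isin(y_true, batch_classes)
        accuracies.append(float(np.mean(y_pred[mask] == y_true[mask])))
    return accuracies


# Hypothetical test labels and predictions after two incremental steps.
y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
y_pred = np.array([0, 1, 1, 1, 2, 2, 3, 3])
class_batches = [[0, 1], [2, 3]]  # old batch first, newest batch last

per_batch = accuracy_per_class_batch(y_true, y_pred, class_batches)
overall = float(np.mean(y_pred == y_true))

print("per-batch accuracy:", per_batch)
print("overall accuracy:", overall)
```

Reporting `per_batch` next to `overall` exposes degradation on old class batches that a single averaged accuracy would hide.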

Performance of SupportNet with less support data. The experimental setting is the same as in the main result figure, except that we use less support data. 'SupportNet_500' means that we use only 500 data points as support data. As shown in the figure, even SupportNet with 500 support data points outperforms iCaRL with 2000 exemplars, which further demonstrates the effectiveness of our support data selection strategy.

To further evaluate SupportNet in a class incremental learning setting with more classes, we tested it on the Tiny ImageNet dataset (https://tiny-imagenet.herokuapp.com/) and compared it with iCaRL. The setting of Tiny ImageNet is similar to that of ImageNet, but its data size is much smaller: Tiny ImageNet has 200 classes, and each class has only 500 training images and 50 test images, which makes it even harder than ImageNet. The performance of SupportNet and iCaRL on this dataset is shown in the figure. As illustrated there, SupportNet outperforms iCaRL significantly on this dataset. Furthermore, as suggested by the red line, which shows the performance difference between SupportNet and iCaRL, SupportNet's advantage grows as more class batches are learned. This demonstrates the effectiveness of SupportNet in combating catastrophic forgetting.