This post explores the core concepts of few-shot learning and compares three approaches to training a few-shot classifier.

Introduction

Few-shot learning describes the situation where a classifier must generalize to new classes using only a few examples of each new class. It arises when data collection (or annotation) is costly for the classes you care about, but related data may be available. Few-shot learning comes with some specialized terminology, which we introduce next.

In standard supervised learning, the training and testing sets contain the SAME classes. The classifier is tested for its ability to discriminate between KNOWN classes.

[Image: supervised learning]

In few-shot learning, the training and testing sets contain DIFFERENT classes. A few-shot classifier is tested for its ability to discriminate between UNKNOWN classes given only a few examples. The test samples are further divided into the support set and the query set. The support set contains labeled examples (for tuning the few-shot classifier) and the query set contains unlabeled samples for evaluation. The number of classes in the support set is denoted by N and the number of examples per class by K. For example, a 5-way 1-shot task has N=5 classes with K=1 example per class.
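
To make the support/query structure concrete, here is a minimal sketch of how a single N-way K-shot task could be sampled from a labeled dataset (plain Python; `dataset` is a hypothetical list of (image, label) pairs, not part of any library used later):

import random
from collections import defaultdict

def sample_task(dataset, n_way=5, k_shot=1, n_query=10):
    # Group images by class
    by_class = defaultdict(list)
    for image, label in dataset:
        by_class[label].append(image)
    # Pick N classes, then K support and n_query query examples per class
    task_classes = random.sample(list(by_class), n_way)
    support, query = [], []
    for new_label, cls in enumerate(task_classes):
        examples = random.sample(by_class[cls], k_shot + n_query)
        support += [(img, new_label) for img in examples[:k_shot]]  # labeled
        query += [(img, new_label) for img in examples[k_shot:]]    # held out for evaluation
    return support, query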

[Image: few-shot learning]

Experiment

This experiment will test a model in a 5-way 5-shot task (i.e. 5 new classes, 5 examples each). We will use a standard benchmark in few-shot learning - the CUB-200-2011 dataset. It is a fine-grained image classification task with 200 different species of birds. The large number of classes makes it suitable for few-shot learning, as several can be withheld for testing while maintaining a difficult baseline. We next explain our basic approach to few-shot learning (prototype models) and then compare the performance of three ways to train the model.

Few-shot framework: Prototype model

We adopt a single framework for all variations of our experiment - Prototypical networks. The concept is based on the idea that there is a single prototype representation for each class. The prototype for each class is calculated as the mean of its support set in the embedding space of a neural network. Our three approaches explore three different ways to train this neural network. Classification in a prototypical model is simply nearest-neighbor classification using only the prototypes¹. The code uses models and task loaders from Easy Few-shot learning.

[Image: prototypical network]
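
To make this concrete, here is a minimal sketch of prototype classification in plain PyTorch (the function and variable names are illustrative, not taken from the library):

import torch

def classify_with_prototypes(embed, support_images, support_labels, query_images, n_way):
    # Embed both sets with the same network
    z_support = embed(support_images)              # shape (N*K, D)
    z_query = embed(query_images)                  # shape (Q, D)
    # A prototype is the mean embedding of each class's support examples
    prototypes = torch.stack([
        z_support[support_labels == c].mean(dim=0)
        for c in range(n_way)
    ])                                             # shape (N, D)
    # Classify each query by its nearest prototype (Euclidean distance)
    distances = torch.cdist(z_query, prototypes)   # shape (Q, N)
    return distances.argmin(dim=1)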

Approach #1: Pretrained model

The first approach adapts a pretrained model using transfer learning. The transfer learning process starts with a model trained on a large and general dataset (e.g. ImageNet). The idea is to rely on its learned feature maps - which should be robust - and adapt the later layers to classify your new classes (see image below).

[Image: transfer model]

This experiment uses a ResNet-18 model with pretrained ImageNet weights. For each task in the evaluation phase, the support set is used to calculate the prototypes and the query set is classified against those prototypes.

from torch import nn
from torchvision.models import resnet18, ResNet18_Weights
from easyfsl.methods import PrototypicalNetworks

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Flatten()  # expose the 512-d embedding instead of the ImageNet logits
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)
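
Evaluation then loops over randomly sampled 5-way 5-shot tasks. Below is a sketch of that loop using easyfsl's TaskSampler; the `test_set` split, the DEVICE constant (e.g. "cuda"), and the episodic collate pattern follow the library's usual usage but may differ between versions:

from torch.utils.data import DataLoader
from easyfsl.samplers import TaskSampler

# Sample 500 random 5-way 5-shot tasks from the held-out classes
sampler = TaskSampler(test_set, n_way=5, n_shot=5, n_query=10, n_tasks=500)
loader = DataLoader(test_set, batch_sampler=sampler, collate_fn=sampler.episodic_collate_fn)

correct, total = 0, 0
for support_images, support_labels, query_images, query_labels, _ in loader:
    # Prototypes are computed from the support set...
    few_shot_classifier.process_support_set(support_images.to(DEVICE), support_labels.to(DEVICE))
    # ...and the query set is scored against them
    scores = few_shot_classifier(query_images.to(DEVICE))
    correct += (scores.argmax(dim=1).cpu() == query_labels).sum().item()
    total += len(query_labels)
print(f"accuracy: {correct / total:.3f}")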

Approach #2: Classical model

This approach uses a subset of the training data to generate a pretrained model. NOTE: the test data contains classes that were not included in the training data. This approach is not always viable, since it requires prior knowledge of the few-shot task's data distribution. However, it can be powerful if you work in a specific domain and have labeled data from that domain readily available. To simulate it here, we split the 200 species in CUB into 140 for training and 60 for testing. At test time, a 5-way 5-shot task is created by randomly sampling 5 classes from the 60 test classes. This process is repeated 500 times to obtain a statistically robust measure of generalization performance.

import torch

model = resnet18()
model.fc = nn.Linear(512, 140)  # classification head over the 140 training classes
model.load_state_dict(torch.load('/content/easy-few-shot-learning/classical_model_18_acc_744.pt')) # pretrained just for you!
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)
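
For reference, the pretraining stage behind this checkpoint is just standard supervised learning over the 140 training classes. A minimal sketch of one epoch, assuming a `train_loader` over the 140-class split (the optimizer and learning rate are illustrative choices):

import torch
from torch import nn

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()
for images, labels in train_loader:
    # Standard mini-batch cross-entropy across all 140 classes
    optimizer.zero_grad()
    loss = criterion(model(images.to(DEVICE)), labels.to(DEVICE))
    loss.backward()
    optimizer.step()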

Approach #3: Episodic model

Episodic training (also called meta-learning) mirrors, during the training phase, the few-shot tasks that will be used to test the final model. Each “episode” is designed to mimic the few-shot task by subsampling classes as well as data points. The use of episodes makes the training problem more faithful to the test environment and thereby improves generalization. This strategy typically assumes prior knowledge of the N and K that will be used at test time (ref1, ref2).

Episodic training can be performed following a classic training regime if additional data is available, or starting from pretrained weights if it is not. In this case we reserve the same 140 classes as in Approach #2 for training. But instead of performing classic training, which minimizes cross-entropy across all classes, we update the weights over many training episodes (5-way 5-shot). This process is slower, as each episode carries additional overhead.

model = resnet18()
model.fc = nn.Flatten()  # same embedding head used during episodic training
model.load_state_dict(torch.load('/content/easy-few-shot-learning/episodic_model_18_acc_779.pt')) # pretrained just for you!
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)
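
For comparison, one episodic training step looks like a miniature few-shot evaluation with a backward pass: prototypes are built from the episode's support set, and the loss is cross-entropy on the query predictions. A sketch, assuming an episodic `train_loader` built with TaskSampler as in the evaluation loop above (optimizer settings are illustrative):

import torch
from torch import nn

optimizer = torch.optim.Adam(few_shot_classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

few_shot_classifier.train()
for support_images, support_labels, query_images, query_labels, _ in train_loader:
    optimizer.zero_grad()
    few_shot_classifier.process_support_set(support_images.to(DEVICE), support_labels.to(DEVICE))
    scores = few_shot_classifier(query_images.to(DEVICE))  # prototype-based scores
    loss = criterion(scores, query_labels.to(DEVICE))      # cross-entropy on the query set only
    loss.backward()
    optimizer.step()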

BONUS Approach: Pretrained model (DINOv2 backbone)

The three primary variations above can be compared directly because they all use a ResNet-18 backbone and Prototypical networks for evaluation. In this bonus approach, we replace the ResNet-18 backbone of the pretrained model (Approach #1) with DINOv2. The rest of the setup is the same.
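
A sketch of the swap, loading DINOv2 from torch.hub (the ViT-S/14 variant is an illustrative choice; DINOv2 backbones output a pooled embedding directly, so no head replacement is needed):

import torch
from easyfsl.methods import PrototypicalNetworks

# DINOv2 backbone from torch.hub; returns an embedding vector per image
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
few_shot_classifier = PrototypicalNetworks(backbone).to(DEVICE)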

Results

For the purposes of this discussion we will focus on Approaches #1-#3. The accuracy improves with each successive approach, but so do the assumptions. Namely, in moving from transfer learning to classic learning we assume knowledge of the domain of the test classes and access to labeled examples from that domain. In moving from classic learning to episodic learning we further assume knowledge of the task itself (5-way 5-shot) and the distribution of the data within it (i.e. uniformly sampled). This does not always reflect the real world - a study showed that when these assumptions are incorrect, they can significantly impact performance.

Backbone    Training approach    Accuracy
ResNet-18   transfer learning    0.684
ResNet-18   classic              0.773
ResNet-18   episodic             0.779
DINOv2      transfer learning    0.964

Conclusion

We compared three approaches to few-shot learning and discussed the conditions and assumptions necessary to use them. The code for this experiment can be found at my ML portfolio website.

Footnotes

  1. There is some similarity between Prototypical models and image retrieval tasks. Image retrieval can be viewed as a prototypical model where each image is its own prototype. And instead of returning a single class, the top-N closest prototypes are returned.