
ChoiceNet: Quantitative Transfer Learning

Artificial Intelligence Coursework Meta Learning Stanford CS Transfer Learning
James Braza
Artificial Intelligence and Software
Studied transfer learning fundamentals by predicting test-set accuracy given a fine-tuning dataset paired with a pre-trained model.

In autumn 2022, I took CS330: Deep Multi-Task and Meta Learning at Stanford. My partner and I delved into the foundations of transfer learning by attempting to build a network or technique that can inform which pre-trained starting point to use, given a fine-tuning dataset. Our model, which we called ChoiceNet, had the following signature:

  • Input: two-tuple
    1. Pre-trained model’s weights or training dataset
    2. Fine-tuning dataset
  • Output: test-set accuracy after fine-tuning

In other words, if you know your fine-tuning dataset and can pick from 2+ pre-trained models, is there a heuristic or network to choose the frozen base model for fine-tuning?
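To make that question concrete, here is a hypothetical usage sketch of the interface. `predict_accuracy` stands in for a trained ChoiceNet; the toy scoring rule inside it (closer dataset means predict better transfer) is purely an illustrative assumption, not our model.

```python
import numpy as np

def predict_accuracy(pretrain_data: np.ndarray, finetune_data: np.ndarray) -> float:
    # Toy stand-in for a trained ChoiceNet: score by closeness of dataset
    # means (an assumption for illustration, not the project's method).
    return float(1.0 / (1.0 + np.abs(pretrain_data.mean() - finetune_data.mean())))

def choose_base_model(candidates: dict[str, np.ndarray], finetune_data: np.ndarray) -> str:
    # Pick the pre-training source with the highest predicted accuracy.
    return max(candidates, key=lambda name: predict_accuracy(candidates[name], finetune_data))

rng = np.random.default_rng(0)
finetune = rng.normal(0.0, 1.0, size=(100, 8))
candidates = {
    "similar": rng.normal(0.1, 1.0, size=(100, 8)),
    "dissimilar": rng.normal(5.0, 1.0, size=(100, 8)),
}
best = choose_base_model(candidates, finetune)  # expect "similar"
```

Given several candidate pre-trained models, a well-calibrated ChoiceNet would let you rank them before paying the cost of actually fine-tuning each one.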

Transfer Learning Dataset

To begin experimentation, we fabricated a dataset mapping:

  • X: (transfer dataset or pre-trained weights, fine-tuning dataset)
  • Y: test-set accuracy

We used a very simple CNN as our core model being transfer learned ("TransferModel"), so we could quickly pre-train and export weights. Our fine-tuning dataset was of plant leaves. We created four categories of datapoints:

  1. Constant and similar to fine-tuning dataset (via TensorFlow plant_village dataset)
  2. Constant and dissimilar to fine-tuning dataset (via bird species dataset)
  3. Random 10-class subsets of CIFAR-100 or ImageNet
  4. Empty dataset (no pre-training = random initialization), as an experimental control
Diagram showing the hierarchy of transfer learning dataset creation.
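The assembly loop can be sketched as follows. `pretrain_finetune_eval` is a hypothetical stand-in for the real training pipeline, which would pre-train TransferModel on the source (or randomly initialize it for the control), fine-tune on the plant-leaf data, and evaluate; the dummy body just returns a placeholder accuracy.

```python
import numpy as np

def pretrain_finetune_eval(pretrain_source, finetune_data, seed: int) -> float:
    # Dummy stand-in: the real version would pre-train a small CNN on
    # `pretrain_source` (random init if None), fine-tune on `finetune_data`,
    # and return the resulting test-set accuracy.
    rng = np.random.default_rng(seed)
    return float(rng.uniform(0.5, 0.9))

# The four pre-training sources described above (names are illustrative).
sources = {
    "similar (plant_village)": "plant_village",
    "dissimilar (bird species)": "birds",
    "random CIFAR-100 subset": "cifar100_subset",
    "control (no pre-training)": None,
}
dataset = [
    ((src, "plant_leaves"), pretrain_finetune_eval(src, "plant_leaves", seed=i))
    for i, src in enumerate(sources.values())
]
# Each element is an ((X_pretrain, X_finetune), Y_accuracy) pair.
```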

Architecture Versions

ChoiceNet v1 worked as follows:

  1. Embedding the pre-trained model: flattened weights from the TransferModel’s last 2-D convolution
  2. Embedding the fine-tuning dataset: applied 256-element principal component analysis (PCA) for bulk reduction, then averaged along examples
  3. Combining both embeddings: (1) LoRA-style reduction of the embedded pre-trained model, (2) concatenate both embeddings, and (3) pass into a head of two fully-connected layers with dropout
ChoiceNet v1 architecture.
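The v1 forward pass can be sketched in plain numpy. All shapes and the random projection/head weights here are illustrative assumptions (our real implementation used trained layers and dropout); the sketch only shows how the three steps fit together.

```python
import numpy as np

rng = np.random.default_rng(42)

# 1. Embed the pre-trained model: flatten the last 2-D conv's weights.
conv_weights = rng.normal(size=(3, 3, 32, 64))        # (kh, kw, in, out) -- illustrative
model_embed = conv_weights.reshape(-1)                # (18432,)

# 2. Embed the fine-tuning dataset: PCA down to 256 components,
#    then average along the example axis.
examples = rng.normal(size=(500, 1024))               # (n_examples, features)
centered = examples - examples.mean(axis=0)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
data_embed = (centered @ vt[:256].T).mean(axis=0)     # (256,)

# 3. Combine: LoRA-style low-rank reduction of the model embedding,
#    concatenate, then a two-layer fully-connected head (dropout omitted).
reduce = rng.normal(size=(model_embed.size, 256)) / np.sqrt(model_embed.size)
combined = np.concatenate([model_embed @ reduce, data_embed])  # (512,)
w1 = rng.normal(size=(512, 64))
w2 = rng.normal(size=(64, 1))
hidden = np.maximum(combined @ w1, 0.0)               # ReLU
predicted_accuracy = float(hidden @ w2)               # scalar prediction
```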

ChoiceNet v2 addressed a few shortcomings of v1:

  1. Averaging the fine-tuning dataset along examples removed too much information
  2. Class-specific information was entirely ignored
  3. The transfer learning dataset went unused

To fix these issues, ChoiceNet v2 does the following:

  1. Embedding the pre-trained model and transfer dataset: use activations (not weights) from last 2-D conv, and average activations per-class
  2. Embedding the fine-tuning dataset: use activations from a pre-trained ResNet50 v2, average along classes, and keep the 10 classes with the highest average absolute value activation
  3. Combining both embeddings: (1) LoRA-style reduction of both embeddings, (2) sum the embeddings (instead of concatenation), and (3) pass through the same fully-connected head as ChoiceNet v1
ChoiceNet v2 architecture.
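The v2 embedding scheme can likewise be sketched in numpy. The random arrays below stand in for real activations (from TransferModel's last 2-D conv and from a pre-trained ResNet50 v2); shapes, class counts, and the projection are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(7)
n_classes, feat = 10, 256

# 1. Pre-trained model + transfer dataset: average last-conv ACTIVATIONS
#    per class (v1 used raw weights instead).
acts = rng.normal(size=(1000, feat))                  # stand-in activations
labels = rng.integers(0, n_classes, size=1000)
model_embed = np.stack([acts[labels == c].mean(axis=0) for c in range(n_classes)])

# 2. Fine-tuning dataset: average stand-in ResNet50 v2 activations per class,
#    keeping the 10 classes with the highest mean absolute activation.
ft_acts = rng.normal(size=(800, feat))
ft_labels = rng.integers(0, 38, size=800)             # e.g. 38 leaf classes (illustrative)
per_class = np.stack([ft_acts[ft_labels == c].mean(axis=0) for c in range(38)])
top10 = per_class[np.argsort(np.abs(per_class).mean(axis=1))[-10:]]

# 3. Combine: low-rank (LoRA-style) projection of each embedding, then SUM
#    rather than concatenate, before the same fully-connected head as v1.
proj = rng.normal(size=(feat, 64)) / np.sqrt(feat)
combined = model_embed @ proj + top10 @ proj          # (10, 64)
```

Summing (rather than concatenating) keeps the head's input size fixed and forces both embeddings into a shared representation space.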

Findings

As a disclaimer, this project’s core question was incredibly broad. Given just one quarter, we used only one fine-tuning dataset and one domain (supervised image classification).

Scatter plot of ChoiceNet v2’s predictions vs actual. We can observe ChoiceNet v2 underestimates performance, and a similar transfer learning dataset performs best.
  • ChoiceNet v2 decreases test-time MSE loss by 80% relative to v1, and underestimates final accuracy far less than ChoiceNet v1
  • We created several unsupervised mathematical techniques (not detailed in this article) that use distribution distance, raw pixel-value distributions, or average class correlation to predict test-set performance
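As a flavor of those unsupervised techniques, here is a hedged sketch of one plausible heuristic: comparing raw pixel-value distributions of two datasets via a sorted-sample (1-D Wasserstein-style) distance. This specific metric is an assumption for illustration, not necessarily one of the exact techniques we used.

```python
import numpy as np

def pixel_distribution_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1-D Wasserstein-style distance between flattened pixel values,
    # computed from sorted samples (assumes equal-sized arrays).
    return float(np.abs(np.sort(a.ravel()) - np.sort(b.ravel())).mean())

rng = np.random.default_rng(1)
# Stand-in "images": the real inputs would be pixel arrays from each dataset.
leaves = rng.normal(0.4, 0.1, size=(64, 8, 8))        # fine-tuning set
similar = rng.normal(0.42, 0.1, size=(64, 8, 8))      # similar pre-training set
dissimilar = rng.normal(0.8, 0.2, size=(64, 8, 8))    # dissimilar pre-training set

# A smaller distance would predict better transfer performance.
closer = pixel_distribution_distance(leaves, similar)
farther = pixel_distribution_distance(leaves, dissimilar)
```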

Source Code

jamesbraza/cs330-project

Stanford CS330 Deep Multi-Task and Meta Learning Class Project
