Fastai, WandB, timmによるビジョンモデルのパラメータ探索とファインチューニングの方法

1 minute read

Published:

CNNやVision Transformerなどの画像認識モデル(ビジョンモデル)の学習では、探索すべきハイパーパラメータがいくつかあります。こういったハイパラチューニングや、その他の性能を最大まで引き出すために必要なプロセスについては、例えば The best vision models for fine-tuningConvNext論文 が非常に参考になります。この記事では、上記を踏まえて、Fastai、WandB (Weights and Biases)、timm (Python Image Models) を活用して画像認識タスク解決のためのモデル選定・パラメータチューニング・ファインチューニングの方法についてまとめます。

1. ファインチューニングとその方法

ファインチューニングについて簡単に説明します。詳細は 深層学習(岡谷貴之)転移学習(松井孝太、熊谷 亘) が参考なります。

ファインチューニングは、大規模なデータセットで事前学習したモデルを新しい(あるいは、独自の)タスクに適用するための手法です。事前学習したデータセットをソースドメイン、新しいタスクのデータセットをターゲットドメインとここで呼ぶことにします。ソースドメインの大規模な学習によってビジョンモデルは画像の特徴抽出能力を獲得するため、ターゲットドメインのデータセットを用いてさらにモデルのパラメータを更新することで高い性能を発揮する場合があります。

おおよそ以下の手順で行います。

  1. ソースドメインのデータセットで事前学習したモデルを読み込む:多くの場合は timm (PyTorch Image Models) などのライブラリを使って、事前学習済みのモデルを読み込みます。
  2. モデルの一部を凍結する:ソースドメインのデータセットで学習したモデルの一部、例えばhead部分(識別用のDense層)以外の特徴抽出用Conv層などを凍結します。凍結とは、その層のパラメータを更新しないようにすることです。
  3. ターゲットドメインのデータセットで学習する:ターゲットドメインのデータセットを用いて、モデルのパラメータを更新します。このとき、凍結した層以外のパラメータを更新します。
  4. モデルを解凍する:学習が進んできた段階で、凍結した層を解凍してパラメータを更新するようにします。モデルの凍結は、追加したhead層の初期パラメータの影響で特徴抽出層のパラメータが大きく変化してしまうことを防ぐために行います。
  5. モデルの全体を学習する:最終的には、モデル全体をターゲットドメインのデータセットで学習します。

どの層を凍結するか、凍結したあとどの層を解凍するかはバリエーションがあります。Conv層を凍結したら解凍せずそのまま最後まで学習する場合もあります。この記事では、head以外を凍結して学習、その後解凍して学習する方法を前提として説明します。

2. Fastai、WandB、timmを使ったファインチューニング

ファインチューニングの具体的な方法、つまり、どの層をいつまで凍結するか、どの層を解凍するか、勾配降下法・学習率・学習スケジューリング手法をどのように設定するか、などは非常に重要かつ難しい問題です。いろいろなモデルアーキテクチャが提案されますが、それらに対して画一的な学習方法で十分な性能を引き出すことは難しい場合が多く、論文で説明されている学習手法(学習スケジューリング、勾配降下法の選択など)の実装を中心にパラメータチューニングを行う必要があります。とはいえ、何をしたら良いのかは発見的に見つかる場合も多く、試行錯誤が必要なため特に私も含む初学者にとっては難しい問題です。その中で、The best vision models for fine-tuning の実装はビジョンモデルの学習における一つのベストプラクティスとして非常に参考になります。ぜひ記事に目を通し、公開されている実装をもとに学習をトレースしてみることをお勧めします。この記事では、トレースした結果をもとに、私のタスクドメインに合わせた学習方法とその実装についてまとめます。

2.1. timm (PyTorch Image Models)

timmは、PyTorchで使える画像認識モデルのライブラリです。ResNet、EfficientNet、ViT、ConvNextなど性能の高いモデルを簡単に利用できます。どんな性能の高いモデルがあるのかについては、たとえば ImageNet Benchmark (Image Classification) | Papers With Code を見てみてください。timmは、事前学習済みのモデルも提供しており、これを使うことでファインチューニングを簡単に行うことができます。Githubリポジトリ を見るとモデルの構成もわかりやすくコードに書かれているので、モデルの理解にも役立ちます。timmから使えるモデルは一つのモデルでも似たような名前のものがたくさんありますが、特に研究などで使う場合はそれぞれがどのようなバリエーションなのかコードで確認する必要があります。論文実装とは異なるバージョンがあるモデルも多いです。

timmで使えるモデルはfastaiでラップされています。なので細かな実装は次節で説明します。

2.2. Fastaiによるファインチューニング

Fastaiは、PyTorchなどの機能を簡単に使えるようにラップした高レベルな機械学習フレームワークです。細かな実装(自前のトランスフォームなど)をするにはやはりPyTorchの知識があるよいかと思います。Fastaiのよいところは、学習率スケジューリングやデータ拡張などの実装と実行が非常に簡単で、またそれらがきちんと論文の時代の流れに沿っているのでデフォルトの設定でもいい感じに学習ができることです。下に示すコードは、The best vision models for fine-tuning 記事中で示されている実装 の大部分を参考にしており、これと異なる点はデータのロード方法 get_dataset()関数、追加のトランスフォーム AddGaussianNoiseクラス、監視する評価メトリクス accuracy, Recall(), Precision(), RocAucBinary()くらいです。ハイパーパラメータチューニングで得た最適な値を config_defaults に設定しています。

Fastaiでファインチューニングするコードの例を以下に示します。

import wandb
import argparse
import torchvision as tv
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from fastai.callback.tracker import SaveModelCallback

WANDB_PROJECT = 'drop-collision-detection'

config_defaults = SimpleNamespace(
    batch_size=64,
    epochs=10,
    num_experiments=1,
    learning_rate=0.02909,
    model_name="convnext_small_in22k",
    pool="concat",
    seed=42,
    wandb_project=WANDB_PROJECT,
    split_func="default",
)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=config_defaults.batch_size)
    parser.add_argument('--epochs', type=int, default=config_defaults.epochs)
    parser.add_argument('--num_experiments', type=int, default=config_defaults.num_experiments)
    parser.add_argument('--learning_rate', type=float, default=config_defaults.learning_rate)
    parser.add_argument('--model_name', type=str, default=config_defaults.model_name)
    parser.add_argument('--split_func', type=str, default=config_defaults.split_func)
    parser.add_argument('--pool', type=str, default=config_defaults.pool)
    parser.add_argument('--seed', type=int, default=config_defaults.seed)
    parser.add_argument('--wandb_project', type=str, default=WANDB_PROJECT)
    return parser.parse_args()

def get_gpu_mem(device=0):
    "Memory usage in GB"
    gpu_mem = torch.cuda.memory_stats_as_nested_dict(device=device)
    return (gpu_mem["reserved_bytes"]["small_pool"]["peak"] + gpu_mem["reserved_bytes"]["large_pool"]["peak"])*1024**-3

class AddGaussianNoise(RandTransform):
    def __init__(self, mean=0., std=100., **kwargs):
        self.mean = mean
        self.std = std
        super().__init__(**kwargs)
        
    def encodes(self, x: TensorImage):
        return x + torch.randn(x.size(), device= x.device) * self.std + self.mean

def get_dataset(batch_size, seed, *args, **kwargs):
    dataset_path = "./data/"
    train_path = Path(dataset_path)
    dls = ImageDataLoaders.from_folder(
        path=train_path,
        train="train",
        valid="test",
        item_tfms=Resize(224),
        batch_tfms=[*aug_transforms(size=224, max_zoom=1.0, max_warp=0.0, mult=1.0 ), AddGaussianNoise(mean=0., std=7., p=0.75)],
        bs=batch_size,
        val_bs=batch_size,
        seed=seed
    )
    metrics = [accuracy, Recall(), Precision(), RocAucBinary()]
    
    return dls, metrics

def train(config=config_defaults):
    with wandb.init(project=config.wandb_project, config=config) as run:
        run.name = f"{config.model_name}_bs_{config.batch_size}_lr_{config.learning_rate}_pool_{config.pool}"        
        config = wandb.config
        dls, metrics = get_dataset(config.batch_size, config.seed)    
        learn = vision_learner(
                dls, config.model_name, metrics=metrics, concat_pool=(config.pool=="concat"),
                cbs=[WandbCallback(log=None, log_preds=False), SaveModelCallback(every_epoch=True)]).to_fp16()
        ti = time.perf_counter()
        learn.fine_tune(config.epochs, config.learning_rate)
        
        wandb.summary["GPU_mem"] = get_gpu_mem(learn.dls.device)
        wandb.summary["model_family"] = config.model_name.split('_')[0]
        wandb.summary["fit_time"] = time.perf_counter() - ti

if __name__ == "__main__":
    args = parse_args()
    train(config=args)

2.3. WandB (Weights and Biases)によるハイパーパラメータチューニングと学習のトラッキング

WandBは学習の進行状況やメトリクスをクラウドでトラッキングできるサービスで、異なるモデルの比較はもちろん、ハイパーパラメータの範囲と探索手法を指定してハイパーパラメータチューニングを行うこともできます。vision_learner にWandBのコールバックを渡せば自動で情報を記録してくれるので非常に楽です。次のような settings.yaml を書いて探索範囲と手法を宣言します。この場合は,ベイズ最適化でValidation lossを最小化するために batch_size, learning_rate, model_name のパラメータを掃引 (sweep) します。

method: bayes
metric:
  goal: minimize
  name: valid_loss
parameters:
  batch_size:
    values:
      - 16
      - 32
      - 64
  learning_rate:
    distribution: uniform
    max: 0.05
    min: 1e-05
  model_name:
    values:
      - convnext_small_in22k
      - convnextv2_tiny.fcmae_ft_in22k_in1k
program: finetune.py

学習したモデルの性能比較を以下に可視化しています。

https://api.wandb.ai/links/dai-personal/b9g3fq8r

3. まとめ

Fastai、WandB、timmを使ってビジョンモデルのファインチューニングを行う方法についてまとめました。私自身は機械学習のコミュニティに属してはおらず一人で勉強しているため、間違っている点やこうしたほうがよいなどのアドバイスがあればぜひ連絡をいただきたいです。また、この記事が他の方の学習に少しでも役立てば幸いです。