roboflowの物体検出モデル「RF-DETR」の環境を構築してみた

はじめに

株式会社ジャパン・インフラ・ウェイマークの川邉です。 当社はNTT西日本の子会社で、ドローン×画像解析AIを活用したインフラ点検を主に行っています。

本記事では2025年3月にRoboflow社が発表したRF-DETRという物体検出モデルの環境構築を行った際に、色々と調べた結果をまとめています。記事執筆時点の最新版である1.3.0について記載していますが、発表から半年ちょっとですでに300コミットを超えているリポジトリですので、すぐに古い情報となってしまうかもしれませんが、新しく始める時のとっかかりくらいのつもりでお読みください。

対象読者

本記事が想定する対象読者は以下の通りです。

  • 新しい物体検出モデルに興味がある人
  • 業務で使いやすい物体検出モデルを探している人

背景

高性能で使いやすい物体検出モデルとしてはUltralytics社の提供するUltralytics YOLO™が有名ですが、ライセンスの関係で商用システムの開発には若干使いづらいものとなっています。 そんな中、2025年3月にRoboflow社がRF-DETRという新しい物体検出モデルを発表しました。公式のページによると、COCO datasetに対してmAPが60を超える高性能となっており、なおかつ、ライセンスはApache 2.0なので非常に使い勝手の良いものとなっています。

ただ、出たばかりで情報も少ないことから、実際に自前で学習/推論環境を構築しながら、できること/できないことを色々と調査してみました。

前提条件

  • 執筆時点で最新版の1.3.0を使って調査しました
  • 環境構築はWSL2(Ubuntu-22.04)上で実施しています
  • 仮想環境にはPoetryを利用しましたが、このあたりは何でも良いと思います

コード

とりあえず結論からということで、最終的に動作確認した学習/開発環境のコードを以下に記載します。

pyproject.toml

[tool.poetry]
name = "rf_detr_test"
version = "0.1.0"
description = ""
authors = []

[tool.poetry.dependencies]
python = ">=3.10,<3.15"
torch = {version = "^2.9.0", extras = ["cuda"]}
torchvision = "^0.24.0"
pillow = "^12.0.0"

[tool.poetry.dev-dependencies]
pytest = "^5.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

train.py

import torch
import rfdetr
import os
import argparse

# 学習結果格納先のディレクトリのパスを返す
def train(
    dataset_dir: str,
    num_classes: int,
    epochs: int = 1000,
    patience: int = 50,
    batch_size: int = 4,
    lr: float = 1e-4,
    pretrained: bool = True,
    save_dir: str = './result',
    size: str = 'medium',
    resolution: int = 640
) -> str:

    # 結果保存先ディレクトリの作成
    result_root = os.path.join(save_dir, 'rf_detr_train')
    save_dir = os.path.join(result_root, 'result')
    if os.path.exists(save_dir):
        count = 1
        while os.path.exists(save_dir):
            save_dir = os.path.join(result_root, f'result_{count}')
            count += 1
    os.makedirs(save_dir, exist_ok=True)

    if size in ('nano', 'small', 'medium'):
        if resolution % 32 != 0:
            raise ValueError(f'モデルサイズが{size}の場合、resolutionは32の倍数を指定してください')
    elif size in ('large'):
        if resolution % 56 != 0:
            raise ValueError(f'モデルサイズが{size}の場合、resolutionは56の倍数を指定してください')

    # 利用可能なデバイスの確認
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # モデル初期化
    model_classes = {
        'medium': rfdetr.RFDETRMedium,
        'nano': rfdetr.RFDETRNano,
        'large': rfdetr.RFDETRLarge,
        'small': rfdetr.RFDETRSmall
    }
    model = model_classes[size](
        num_classes=num_classes,
        pretrained=True,
        device=device
    ) if size in model_classes else None

    if model is None:
        raise ValueError(f'invalid model size: {size}')

    model.train(
        dataset_dir=dataset_dir,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        output_dir=save_dir,
        early_stopping=True,
        patience=patience,
        resolution=resolution,
        checkpoint_interval=epochs # 余分なチェックポイントをなるべく保存させない
    )

    return save_dir

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RF-DETRを使って学習を実施するプログラム')
    parser.add_argument('--dataset', required=True, help='学習データセットの格納されているディレクトリ')
    parser.add_argument('--size', default='medium', choices=['nano', 'small', 'medium', 'large'], help='モデルサイズ')
    parser.add_argument('--num_classes', type=int, required=True, help='クラス数')
    parser.add_argument('--epochs', type=int, default=1000, help='学習エポック数')
    parser.add_argument('--patience', type=int, default=50, help='lossが改善しなかった場合に学習を完了させるエポック数')
    parser.add_argument('--batch_size', type=int, default=4, help='バッチサイズ')
    parser.add_argument('--save', default='result', help='結果格納先のディレクトリ')
    parser.add_argument('--resolution', type=int, default=640, help='学習時の画像サイズ')
    parser.add_argument('--no-pretrained', action='store_true', help='事前学習済みのデータセットを利用しない')
    args = parser.parse_args()

    train(
        dataset_dir = args.dataset,
        num_classes = args.num_classes,
        epochs = args.epochs,
        patience = args.patience,
        batch_size = args.batch_size,
        lr = 1e-4,
        pretrained = not args.no_pretrained,
        save_dir = args.save,
        size = args.size,
        resolution = args.resolution
    )

predict.py

import torch
from PIL import Image
import rfdetr
import argparse
import os

def predict(
    checkpoint_path: str,
    image_dir: str,
    num_classes: int,
    size: str = 'medium',
    resolution: int = 640,
    threshold: float = 0.5
):
    if size in ('nano', 'small', 'medium'):
        if resolution % 32 != 0:
            raise ValueError(f'モデルサイズが{size}の場合、resolutionは32の倍数を指定してください')
    elif size in ('large'):
        if resolution % 56 != 0:
            raise ValueError(f'モデルサイズが{size}の場合、resolutionは56の倍数を指定してください')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 推論時もnum_classesの指定は必要
    # https://github.com/roboflow/rf-detr/issues/51
    
    model_classes = {
        'medium': rfdetr.RFDETRMedium,
        'nano': rfdetr.RFDETRNano,
        'large': rfdetr.RFDETRLarge,
        'small': rfdetr.RFDETRSmall
    }

    model = model_classes[size](
        device=device,
        num_classes=num_classes,
        pretrained=False,
        pretrain_weights=checkpoint_path
    ) if size in model_classes else None

    if model is None:
        raise ValueError(f'invalid model size: {size}')

    for image_path in [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]:
        print(f'load image: {image_path}')

        # 推論対象画像の読み込み
        image = Image.open(image_path).convert("RGB")

        # 推論実行
        detections = model.predict(
            image,
            threshold=threshold,
            resolution=resolution
        )

        # 推論結果の表示
        for xyxy, class_id, score in zip(detections.xyxy, detections.class_id, detections.confidence):
            print(xyxy, class_id, score)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RF-DETRを学習したモデルを使って推論を実施するプログラム')
    parser.add_argument('--model', '-m', required=True, help='学習した重みファイルのパス')
    parser.add_argument('--image', '-i', required=True, help='画像ファイルの格納されているディレクトリ')
    parser.add_argument('--size', '-s', default='medium', choices=['nano', 'small', 'medium', 'large'], help='モデルサイズ')
    parser.add_argument('--resolution', '-r', type=int, default=640, help='推論時の画像サイズ')
    parser.add_argument('--threshold', '-t', type=float, default=0.5, help='推論時の閾値')
    parser.add_argument('--num_classes', '-n', type=int, required=True, help='クラス数')
    args = parser.parse_args()

    predict(
        checkpoint_path = args.model,
        image_dir = args.image,
        size = args.size,
        num_classes = args.num_classes,
        resolution = args.resolution,
        threshold = args.threshold
    )

構築方法

Poetryを使いましたので、pyproject.tomlの存在するディレクトリで

poetry install

を実行すれば、仮想環境上に必要なモジュールなどがインストールされます。 ただし、上のpyproject.tomlだけでは肝心のRF-DETRがインストールされません。私の環境ではrfdetrだけはpoetry addでインストールしようとするとエラーが発生したため、

poetry run pip install rfdetr==1.3.0

で別にインストールしています。実行環境によっては、pyproject.toml内で定義してもちゃんとインストールされるのではないかと思います。

学習データのディレクトリ構成

RF-DETRの学習に関するREADMEでは以下のように記載されています。

dataset/
├── train/
│   ├── _annotations.coco.json
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ... (other image files)
├── valid/
│   ├── _annotations.coco.json
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ... (other image files)
└── test/
    ├── _annotations.coco.json
    ├── image1.jpg
    ├── image2.jpg
    └── ... (other image files)

なお、RF-DETRの内部ではCOCO JSON形式のファイルの処理にpycocotools.COCO.loadResを利用しています。このモジュールでは 2.0.9 以降はCOCO JSONのトップレベルキーに info がないとエラーが発生するようになっているようです。単純に学習データの定義だけであれば不要な項目かもしれませんが、安全のためにつけておくと良いかと思います。

学習方法

上記コードで学習を実施する場合は、以下のように学習用データセットのパス(--dataset)とクラス数(--num_classes)を必ず付与してください(パスは環境に応じて適宜変更してください)。

poetry run python3 rf_detr_test/train.py --dataset ~/coco_dataset/ --num_classes 2

ちなみに、実際に同じデータセットでYOLOと学習過程を比較すると

  • 学習結果が収束するまでのエポック数はYOLOよりもかなり少ない
  • 学習に必要なGPUメモリ量はYOLOよりも多い

という感じでしたので、YOLOで学習した時にGPUメモリがギリギリだった学習データセットについては、バッチ数を下げて実施したほうが良いと思います。 学習完了までのエポック数はYOLOよりもかなり少ないので、バッチ数を下げてもトータルの学習時間は短くなるはずです。

参考までに、同一の学習データセットに対してRF-DETR と YOLOv8で学習した時の学習曲線/メトリクスを記載しておきます(学習環境としてはAWSのg6.xlargeインスタンスを利用しました)。各AIモデルが標準で出力するグラフのため縦軸が一致していませんが、参考までに。

RF-DETRの学習曲線
YOLOv8の学習曲線

推論方法

以下のように推論対象の画像のディレクトリのパス(--image)、読み込むモデルのパス(--model)、クラス数(--num_classes)を指定して実行してください(パスは環境に応じて適宜変更してください)。

poetry run python3 rf_detr_test/predict.py --image ~/test_images --model ./train/checkpoint.pth  --num_classes 2

コード解説

全体的な話

  • モデルにはRFDETRNano RFDETRSmall RFDETRMedium RFDETRLarge RFDETRBase の5つのサイズが用意されています(今回のコードではRFDETRBaseは利用していませんが、使い方は基本的に同じはずです)
  • RFDETRBaseが他の4つのモデルのどこに位置するのかは良く分かりませんでした(コードにはTrain an RF-DETR Base model (29M parameters).との注釈がありました)
  • 推論/学習共にモデル作成時にnum_classes引数でクラス名を指定したほうが良いようです(参照
  • 画像サイズはresolution引数で指定します。モデルサイズがbase、largeの場合は56の倍数、nano、small、mediumの場合は32の倍数が良いようです(参考)。長方形の画像を入力した場合、内部で正方形にリサイズされます(YOLOのように余白を追加しないので、学習/推論時のアスペクト比が実画像と異なる可能性があります)。

PyTorchのバージョンの話

公式リポジトリのpyproject.tomlでは、利用するPyTorchのバージョンが

"torch>=1.13.0",
"torchvision>=0.14.0"

となっていますが、実際にPyTorch 1.13.0で動作させようとするとrfdetr/engine.pyの以下の箇所でエラーが発生します

def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler, # <-- ここ
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    batch_size: int,
    max_norm: float = 0,
    ema_m: torch.nn.Module = None,
    schedules: dict = {},
    num_training_steps_per_epoch=None,
    vit_encoder_num_layers=None,
    args=None,
    callbacks: DefaultDict[str, List[Callable]] = None,
):

これは、torch.optim.lr_scheduler.LRSchedulerがPyTorchの1.13.0には存在しないためです(参考)。初めて登場するのが2.5.0なので(参考)、少なくとも本コードについてはPyTorch2.5.0以降でないとエラーになると思います。

学習コードの話

  • model.trainという関数が用意されているので、この中で学習を実施しています
  • ユーザ自作のコールバック関数の指定はできないようです。内部的にはエポック終了時などのイベントごとに呼ばれているコールバックが存在しているようですので、そのあたりのコードを自分で変更しても良いですし、放っておいてもそのうちに対応してくれるかもしれません。

損失関数

lossの計算はrfdetr/models/lwdetr.pySetCriterionの中で実施されています。今のところクラスごとに重みを指定する機能はありませんが、Focal Lossを使っているのでクラス不均衡に対する強度はある程度確保されているようです(YOLOの方は分類損失の計算にtorch.nn.BCEWithLogitsLossを使用しているので、学習データのクラス不均衡の影響を受けます)。 また、学習対象の特性に応じて利用する損失関数の調整がある程度可能になっています。指定可能なパラメータは以下です(参考

パラメータ 初期値 内容
aux_loss True デコーダの各層から出力される補助的な予測に対しても損失を計算します。Trueにすると最終層だけではなく中間層の出力も損失に含めるので学習が早く安定します
sum_group_losses False 各グループの損失を合計して最終損失に反映します
use_varifocal_loss False Focal Lossの代わりにVarifocal Lossを利用する。IoUスコアに応じて重みづけが変わるので、Focal Lossよりも位置情報に敏感になります
use_position_supervised_loss False 位置情報に基づいた追加の教師あり損失を使うかどうかを制御します。DETR系では通常、Hungarian Matching後に座標回帰を行うため対象物が小さい場合や、隣接して存在する場合の精度が下がる傾向があると言われています。このオプションは座標に対する補助的な監督を強化することで、DETR系の弱点を補うようになるようです
ia_bce_loss True クラス分類損失としてIoU-Aware Binary Cross Entropy Lossを使用するかどうかを指定します。IoU情報を組み込んだBCE損失で、Varifocal Lossの代替として利用可能です

なお、処理の優先順位が

  1. ia_bce_loss
  2. use_position_supervised_loss
  3. use_varifocal_loss
  4. その他

となっており、なおかつ、ia_bce_lossの初期値がTrueとなっているので、ia_bce_lossの値を明示的にFalseに指定しないと、他の損失関数は利用されません。

作成されるファイル

学習を実施するとsave_dirで指定したディレクトリ配下に以下のような重みファイルが生成されます(参考)。

重みファイル名 内容
checkpoint_best_regular.pth EMA(Exponential Moving Average)を使わない純粋なモデルの中で最もスコアが良かったもの。追加学習を実施する時はこちらを利用すると良いようです。
checkpoint_best_ema.pth EMAモデルで全ての過去のモデル重みを指数移動平均で平滑化したもの
checkpoint_best_total.pth 最終的に一番良かったもの。基本的に推論時はこちらを利用すると良いようです

また、上記のファイルに加えて、以下のタイミングで学習中のチェックポイントが保存されます(rfdetr/main.pytrain関数)

if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')

lr_dropの初期値は(おそらく)100なので、checkpoint_intervalをどれだけ大きく設定しても100エポックごとにcheckpointXXXX.pthというファイルが保存されます(XXXX部分は4桁の数字)。 クラウドサービスでGPUインスタンスを借りて学習を実施する場合、あまり沢山チェックポイントファイルが作成されるとストレージが圧迫されるので困るのですが、コールバックの指定などができないので今のところはWatchDog的なもので監視しながら適宜消していくしかなさそうです。

推論コードの話

  • 学習時とモデルサイズ、クラス数(num_classes)、画像サイズ(resolution)は合わせる必要があります
  • 検出する閾値はthresholdで指定可能です
  • predictの返り値にはxyxy(BBOXのleft、top、right、bottomの座標)、confidence(信頼度スコア)、class_id(クラスID)が入っています。xyxyは入力画像のスケールに戻されていますが、float型なので描画時などは注意してください

未解決の問題

コードを読んだり試したりしている時に起きた問題のうち、未解決のものをいくつかまとめました。いずれについても公式リポジトリに修正のPRを投げた or 投げる予定なので、いずれ解決してくれるとは思います。

学習時のエラー

NVIDIA L4環境上で学習時にia_bce_lossFalseuse_position_supervised_lossTrue に設定すると以下のエラーが発生しました。

return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/rf-detr-train-vhWH2Pc8-py3.10/lib/python3.10/site-packages/rfdetr/models/lwdetr.py", line 372, in loss_labels
    cls_iou_func_targets[pos_ind] = pos_ious_func
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

BFloat16対応しているGPUだと優先的にBFloat16を利用しようとするのですが、引き続きFloat型を利用している箇所があるため型不整合が起きているものと思います。 一応、rfdetr/models/lwdetr.pyの該当箇所を修正して本エラーを抑制する事は難しくないのですが、これを解消すると別のエラーが発生します。

データ拡張の少なさ

学習時のデータ拡張に関わるコードを見ると内容が以下だけなので、YOLOに比べるとバリエーションが少ない印象です。

return T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomSelect(
        T.SquareResize(scales),
        T.Compose([
            T.RandomResize([400, 500, 600]),
            T.RandomSizeCrop(384, 600),
            T.SquareResize(scales),
        ]),
    ),
    normalize,
])

まとめ

本項ではRoboflowの最新物体検出モデルであるRF-DETRの学習/推論環境の構築についてまとめました。 できたばかりでまだ粗削りなところはありますが、性能面でもライセンス面でも今後に期待の物体検出モデルですのでもう少し弄ってみて、新しいことが分かったり有効な使い方が見つかった時には、また記事にまとめたいと思います。

執筆者

川邉 隆伸 (ジャパン・インフラ・ウェイマーク開発部所属)

画像認識系AIの開発や、それらを提供するSaaS環境の構築を行っています。

免責事項

本記事に記載された情報は、2025年11月時点での公開情報および筆者の検証・調査結果に基づくものです。

  • 記事に記載されている各プログラムやモジュールの機能、実行条件などは予告なく変更される場合があります
  • 本記事の内容を実践される際は、必ず各プログラムの最新の公式ドキュメントをご確認ください
  • 本記事の情報に基づいて行われた意思決定や実装により生じた損害について、筆者および所属組織は一切の責任を負いかねます

参考資料・出典

本記事の執筆にあたり参考としたページは以下の通りです

商標

  • Ultralytics YOLO は、Ultralytics の登録商標または商標です
  • PyTorch は、The Linux Foundation の登録商標または商標です
  • Ubuntu は、Canonical Ltd の登録商標または商標です
  • AWS は、Amazon Web Services, Inc. またはその関連会社の商標もしくは登録商標です

© NTT WEST, Inc.