Tips

PytorchのDataLoaderとtorchvision

はじめに

Pytorchとは

Pytorchとは、ディープラーニング用の動的フレームワークです。
Pytorchは比較的新しいフレームワークですが、動的でデバッグがしやすい上に、そこまでパフォーマンスが悪くないので、結構注目されており、Redditなどを見ていても実装が結構あがっています。
動的フレームワークでは、Chainerが日本で人気ですが、海外の人気をみるとPytorchのほうが高く、コミュニティもしっかりしている印象です。

Pytorchの導入に関しては、以前DLHacksで発表した資料を参考にしていただければ幸いです。(ニーズがあれば記事化するかも)

DataLoader

基本的に論文の実装の再現をする際は、下記のようなステップで実装するのが一般的かと思います。
– DataLoader
– モデル
– 損失関数
– 訓練モジュール
– (ハイパーパラメータチューニング)

DataLoaderは、訓練・テストデータのロード・前処理をするためのモジュールです。結構実装が面倒だったり、適当にやりすぎると、メモリ周りできついことになります。
Pytorchではデフォルトでdataloaderを用意しているのですが、それに加えて、ライブラリを作っていくれています。

  • torchvision:画像周りのデータローダ、前処理、有名モデル(densenet, alex, resnet, vgg等)
  • torchtext(WIP):テキスト系のデータローダ、埋め込み周りやpaddingあたりもやってくれます
  • torchaudio:音声データ周りのデータローダ

今回は、pytorchのデフォルトのdataloaderとtorchvisionのモジュールについて簡単に解説します。
torchtextについてはこちらで解説しています。

pytorchのデータローダ

実装手順

pytorchのデフォルトのものを使うことで下記3ステップで実装できます。

  1. DataSetの作成
    
DataSetのサブクラスでラップする
  2. Dataの前処理

    Transformで前処理を定義する
  3. DataLoader
    
DataLoaderでDatasetをバッチで取り出せるようにする

1.DataSet

torch.utils.data.Dataset を継承して、使います。 
このクラスは下記のソースコードの通り、ただの抽象クラス
でほぼほぼ何もしていないです。
getlenをoverrideして使います。

(引用:pytorchソースコード)

1.Dataset利用例:FaceLandmarks

Datasetの利用例をpytorchチュートリアルより抜粋します。
上述の通り、 init で必要なデータを持つようにし、あとは lengetをoverrideするだけです。
次の節で少し説明しますが、 get 内で transform を使ってデータを前処理するようにするところだけ注意してください。

2. Transform

前処理のために Transform オブジェクトを利用します。
FaceLandmarks のデータセットのコードでは、 inittransform を引数にしています。
この transformTransform オブジェクトを指しており、各オブジェクトは前処理に関して記述します。
Transform オブジェクトは、最低でもcallを定義する必要があり、大抵 initでいろいろ設定します。例として、ToTensorを見てみます。

(引用:pytorchチュートリアル)

2. Transforms

復数の前処理を使うために必要なのが、 torchvision.transforms.Compose です。
使い方は単純で下記のようになります。

(引用:pytorchチュートリアル)

単純なのもそのはずで、ソースコードをみると、Composeはただのリストで、 たった7行です。
なぜこれが pytorch 本体ではなく、 torchvision においているかは謎です。

3. DataLoader

torch.utils.data.Datalodaerを使えば、Datasetを渡すことでミニバッチを返すIterableなオブジェクトにしてくれます。
ソースコードを少し見ましたが、マルチスレッド周りとかいろいろやってくれています。

使い方は下記です。

これらは、torchのモデルに依存せずnumpy等で返してくれるので、基本的にPytorch以外でも使えそうです。
ただし、上述の抽象クラスのtorch.util.data.Datasetだけは使うので、pytorchは入れるかDatasetクラスを自分で書いておく必要があるかもしれません。

torchvision

pytorchが用意してくれている画像周りのDataLoaderです。
基本的にpytorchのDataset・Dataloaderを使っていて、拡張しています。
下記のようなモジュールがあります。

  • datasets: pytorchのDatasetで有名なデータセット簡単に使えるようにしています。
    — MNIST and FashionMNIST
    — COCO (Captioning and Detection)
    — LSUN Classification
    — ImageFolder
    — Imagenet-12
    — CIFAR10 and CIFAR100
    — STL10
    — SVHN
    — PhotoTour
  • model:有名なモデルが実装されています。学習済みを引数にpretraind=Trueを渡せば使えます。
    — alexnet
    — densenet
    — inception
    — resnet
    — squeezenet
    — vgg
  • Transforms:画像でよく使われる前処理は用意してくれています。
    — Compose:復数のTransformのリスト化
    — Scale:大きさ変更
    — CenterCrop:真ん中でクロッピング
    — RandomCrop:ランダムにクロッピング
    — RandomHorizontalFlip:ある確率でFlip
    — Normalize:正規化
    — ToTensor:PIL.ImageやnumpyをTensor化

感想

基本的にシンプルにまとまっており、自分が使いたいデータセットに対して、簡単にDataLoaderを作れるような感じでした。
Numpyオブジェクトで返してくれるので、別フレームワークでも簡単に使えるかもしれないです。
上述のpytorchのdataloaderと全くの別物であるtorchtextについてはこちらで解説しています。

発表資料


曽根岡 侑也

曽根岡 侑也

東京大学 松尾豊研究室 リサーチエンジニア
東京大学大学院工学系研究科, 松尾豊研究室卒. クロードテック株式会社CEO. IPA未踏クリエータ. Paypal Battle Hack Tokyo優勝.