はじめに
Pytorchとは
Pytorchとは、ディープラーニング用の動的フレームワークです。
Pytorchは比較的新しいフレームワークですが、動的でデバッグがしやすい上に、そこまでパフォーマンスが悪くないので、結構注目されており、Redditなどを見ていても実装が結構あがっています。
動的フレームワークでは、Chainerが日本で人気ですが、海外の人気をみるとPytorchのほうが高く、コミュニティもしっかりしている印象です。
Pytorchの導入に関しては、以前DLHacksで発表した資料を参考にしていただければ幸いです。(ニーズがあれば記事化するかも)
DataLoader
基本的に論文の実装の再現をする際は、下記のようなステップで実装するのが一般的かと思います。
– DataLoader
– モデル
– 損失関数
– 訓練モジュール
– (ハイパーパラメータチューニング)
DataLoaderは、訓練・テストデータのロード・前処理をするためのモジュールです。結構実装が面倒だったり、適当にやりすぎると、メモリ周りできついことになります。
Pytorchではデフォルトでdataloaderを用意しているのですが、それに加えて、ライブラリを作っていくれています。
- torchvision:画像周りのデータローダ、前処理、有名モデル(densenet, alex, resnet, vgg等)
- torchtext(WIP):テキスト系のデータローダ、埋め込み周りやpaddingあたりもやってくれます
- torchaudio:音声データ周りのデータローダ
今回は、pytorchのデフォルトのdataloaderとtorchvisionのモジュールについて簡単に解説します。
torchtextについてはこちらで解説しています。
pytorchのデータローダ
実装手順
pytorchのデフォルトのものを使うことで下記3ステップで実装できます。
- DataSetの作成
DataSetのサブクラスでラップする - Dataの前処理
Transformで前処理を定義する - DataLoader
DataLoaderでDatasetをバッチで取り出せるようにする
1.DataSet
torch.utils.data.Dataset
を継承して、使います。
このクラスは下記のソースコードの通り、ただの抽象クラス
でほぼほぼ何もしていないです。
get
と len
をoverrideして使います。
1 2 3 4 5 6 7 8 9 10 11 12 |
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError |
(引用:pytorchソースコード)
1.Dataset利用例:FaceLandmarks
Datasetの利用例をpytorchチュートリアルより抜粋します。
上述の通り、 init
で必要なデータを持つようにし、あとは len
と get
をoverrideするだけです。
次の節で少し説明しますが、 get
内で transform
を使ってデータを前処理するようにするところだけ注意してください。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
class FaceLandmarksDataset(Dataset): """Face Landmarks dataset.""" def __init__(self, csv_file, root_dir, transform=None): """ Args: csv_file (string): Path to the csv file with annotations. root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.landmarks_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0]) image = io.imread(img_name) landmarks = self.landmarks_frame.ix[idx, 1:].as_matrix().astype('float') landmarks = landmarks.reshape(-1, 2) sample = {'image': image, 'landmarks': landmarks} if self.transform: sample = self.transform(sample) return sample |
2. Transform
前処理のために Transform
オブジェクトを利用します。
FaceLandmarks
のデータセットのコードでは、 init
で transform
を引数にしています。
この transform
は Transform
オブジェクトを指しており、各オブジェクトは前処理に関して記述します。
Transform
オブジェクトは、最低でもcall
を定義する必要があり、大抵 init
でいろいろ設定します。例として、ToTensor
を見てみます。
1 2 3 4 5 6 7 8 9 10 11 12 |
class ToTensor(object): """Convert ndarrays in sample to Tensors.""" def __call__(self, sample): image, landmarks = sample['image'], sample['landmarks'] # swap color axis because # numpy image: H x W x C # torch image: C X H X W image = image.transpose((2, 0, 1)) return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)} |
(引用:pytorchチュートリアル)
2. Transforms
復数の前処理を使うために必要なのが、 torchvision.transforms.Compose
です。
使い方は単純で下記のようになります。
1 2 3 4 5 6 7 |
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv', root_dir='faces/', transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor() ])) |
(引用:pytorchチュートリアル)
単純なのもそのはずで、ソースコードをみると、Compose
はただのリストで、 たった7行です。
なぜこれが pytorch
本体ではなく、 torchvision
においているかは謎です。
1 2 3 4 5 6 7 8 |
class Compose(object): def __init__(self, transforms): self.transforms = transforms def __call__(self, img): for t in self.transforms: img = t(img) return img |
3. DataLoader
torch.utils.data.Datalodaer
を使えば、Dataset
を渡すことでミニバッチを返すIterableなオブジェクトにしてくれます。
ソースコードを少し見ましたが、マルチスレッド周りとかいろいろやってくれています。
使い方は下記です。
1 2 3 4 5 6 |
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4) for i_batch, sample_batched in enumerate(dataloader): print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size()) |
これらは、torchのモデルに依存せずnumpy等で返してくれるので、基本的にPytorch以外でも使えそうです。
ただし、上述の抽象クラスのtorch.util.data.Dataset
だけは使うので、pytorchは入れるかDatasetクラスを自分で書いておく必要があるかもしれません。
torchvision
pytorchが用意してくれている画像周りのDataLoaderです。
基本的にpytorchのDataset・Dataloaderを使っていて、拡張しています。
下記のようなモジュールがあります。
- datasets: pytorchのDatasetで有名なデータセット簡単に使えるようにしています。
— MNIST and FashionMNIST
— COCO (Captioning and Detection)
— LSUN Classification
— ImageFolder
— Imagenet-12
— CIFAR10 and CIFAR100
— STL10
— SVHN
— PhotoTour - model:有名なモデルが実装されています。学習済みを引数に
pretraind=True
を渡せば使えます。
— alexnet
— densenet
— inception
— resnet
— squeezenet
— vgg - Transforms:画像でよく使われる前処理は用意してくれています。
— Compose:復数のTransformのリスト化
— Scale:大きさ変更
— CenterCrop:真ん中でクロッピング
— RandomCrop:ランダムにクロッピング
— RandomHorizontalFlip:ある確率でFlip
— Normalize:正規化
— ToTensor:PIL.ImageやnumpyをTensor化
感想
基本的にシンプルにまとまっており、自分が使いたいデータセットに対して、簡単にDataLoaderを作れるような感じでした。
Numpyオブジェクトで返してくれるので、別フレームワークでも簡単に使えるかもしれないです。
上述のpytorchのdataloaderと全くの別物であるtorchtextについてはこちらで解説しています。