PyTorchでDataLoaderを使って、CIFAR-10をミニバッチで取り出してみる
torch.utils.dataで、画像データのDataLoaderを作成する方法をまとめました。
CIFAR-10データセットを例として使用しています。
CIFAR-10データセット
約8000万枚の画像がある80 Million Tiny Imagesから約6万枚の画像を抽出してラベル付けした有名なデータセットです。
データとして、32x32カラー画像が納められています。
https://www.cs.toronto.edu/~kriz/cifar.html
画像データのDataLoader作成までに使うモジュール
画像データをミニバッチ学習で利用するのに普通必要なモジュールです。
モジュール | 概要 |
torchvisions.transforms | 画像の前処理を定義する。Datasetから利用する。 |
torch.utils.data.Dataset | データの取り出し方を定義する。DataLoaderから利用する。 |
torch.utils.data.DataLoader | ミニバッチ学習でのデータの取り出し方を定義する。 |
自分の場合、モジュールがいくつかに別れているのが理由で、理解するのに苦労しました。
特にgithubやコンペなどで公開されているコードをみると、全体の分量もある一方で、データの読み込み部分はデータによってある程度作り込まないといけないので、意外と理解するのが難しいと思っています。
分からない場合は、一度必要なモジュールを整理してみて、自分で組み合わせてデータの表示までをやってみるのが良いです。
データの利用までの方法
DataLoaderを作成して、ミニバッチで取り出したデータを描写するところまでやって行きたいと思います。
※各クラスのコードは主に『PyTorchによる発展ディープラーニング』のサポートリポジトリ(https://github.com/YutaroOgawa/pytorch_advanced)を参考にさせていただいています。
必要なライブラリのインポート
説明に必要なライブラリをインポートしておきます。
# パッケージのimport、いらないものがあったら省いてください。 import numpy as np from PIL import Image from tqdm import tqdm import matplotlib.pyplot as plt %matplotlib inline import torch import torch.utils.data as data import torchvision from torchvision import transforms
データを解凍する
まずwgetなどでデータセットをダウンロードして、解凍しておきます。
ls cifar-10-batches-py/ >batches.meta data_batch_2 data_batch_4 readme.html >data_batch_1 data_batch_3 data_batch_5 test_batch
5つに分けてpickleされているので、データを取り出します。
簡単のため、data_batch_1 だけ取り出してみます。
def unpickle(file): import pickle with open(file, 'rb') as fo: dict = pickle.load(fo, encoding='bytes') return dict dict = unpickle("cifar-10-batches-py/data_batch_1") dict.keys() #dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
前処理の定義
とりあえずTorchテンソルへの変換と、色を標準化をしておきます。
データセットが32×32になっているので、リサイズや切り出しはしないこととしました。
#refered: https://github.com/YutaroOgawa/pytorch_advanced/blob/master/1_image_classification/1-3_transfer_learning.ipynb # 入力画像の前処理のクラス class Transform(): """ 画像をTorchテンソルに変換し、色を標準化する。 ----------。 mean : (R, G, B) 色チャネルのmean。 std : (R, G, B) 色チャネルの標準偏差。 """ def __init__(self, mean, std): self.base_img_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean, std) ]) def __call__(self, img): return self.img_transform(img)
データセットを作る
データが取り出せるようになったので、データセットを作って行きます。データセットはtorch.utils.data.Dataset(torch.utils.data — PyTorch master documentation)を実装します。
Datasetは__getitem__をオーバーライドする必要があります。(torch.utils.data.dataset — PyTorch master documentation)
オーバーライドしない場合、NotImplementedErrorが投げられるようになっています。
__init__で、__getitem__で返すためのデータをセットします。
今回は、綺麗じゃないかもしれませんが、img_list・label_list・filename_listにあらかじめ分解しておきます。
__getitem__で、index番目のデータに必要な前処理をしたのち、呼び出し側にデータを返します。
class Dataset(data.Dataset): """ CIFAR-10 のDatasetクラス。 ---------- data_dict : 辞書 CIFAR-10のデータの辞書。 transform : object 前処理クラスのインスタンス """ def __init__(self, data_dict, transform=None): #self.data_dict = data_dict # そのままdictを渡す場合 self.img_list = data_dict[b'data'] # 画像のリスト self.label_list = data_dict[b'labels'] # ラベルのリスト self.filename_list = data_dict[b'filenames'] # ファイル名のリスト self.transform = transform # 前処理クラスのインスタンス def __len__(self): '''画像の枚数を返す''' return len(self.label_list) def __getitem__(self, index): ''' Tensor形式の前処理をした画像データ・ラベル・ファイル名を取得する ''' # index番目の画像をロード img = self.img_list[index].reshape([3, 32, 32]) # [色RGB][高さ][幅] img_new = img.transpose((1, 2, 0)) # [高さ][幅][色RGB] img_transformed = self.transform(img_new)#, self.phase) # torch.Size([3, 224, 224]) # index番目のlabelをロード label = self.label_list[index] #.decode('utf-8') #label = int(label) # index番目のファイル名をロード filename = self.filename_list[index].decode('utf-8') return img_transformed, label, filename
動作確認をしておきます。
mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) dataset = Dataset(data_dict=dict, transform=Transform(mean=mean, std=std)) index = 0 print(dataset.__getitem__(index)[0].size()) print(dataset.__getitem__(index)[1]) #torch.Size([3, 32, 32]) #6 #leptodactylus_pentadactylus_s_000004.png
DataLoader
最後にDataLoaderを作成すれば、データをミニバッチ学習などに利用することができます。
DataLoader(torch.utils.data — PyTorch master documentation)は色々オプションを指定可能ですが、今回は普通によく使うオプションのみ設定しておきます。
# ミニバッチのサイズ batch_size = 32 # DataLoader dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=True)
Dataを取り出してみる
早速、DataLoaderを使ってデータを取り出してみます。
batch_iterator = iter(dataloader) # イテレータに変換する inputs, labels, filename = next(batch_iterator) #最初のミニバッチを取り出す #pltで描写するために転置する imgs = inputs imgs_new = [img.numpy().transpose((1, 2, 0)) for img in imgs] #描写 #refered: http://umashika5555.hatenablog.com/entry/2017/09/24/235813 plt.figure() h_length = 8 w_length = 4 for h in range(h_length): for w in range(w_length): i = h*w_length+w+1 fig_loc = str(h_length) + str(w_length) + str(i) plt.subplot(h_length,w_length,i) #グラフプロット plt.imshow(imgs_new[i-1]) plt.show()
プロット結果は次のようになって、ちゃんと32枚、画像が取り出されていることが分かります。