たそらぼ

日頃思ったこととかメモとか。

PyTorchでDataLoaderを使って、CIFAR-10をミニバッチで取り出してみる

torch.utils.dataで、画像データのDataLoaderを作成する方法をまとめました。
CIFAR-10データセットを例として使用しています。

CIFAR-10データセット

約8000万枚の画像がある80 Million Tiny Imagesから約6万枚の画像を抽出してラベル付けした有名なデータセットです。
データとして、32x32カラー画像が納められています。
https://www.cs.toronto.edu/~kriz/cifar.html

画像データのDataLoader作成までに使うモジュール

画像データをミニバッチ学習で利用するのに普通必要なモジュールです。

モジュール 概要
torchvisions.transforms 画像の前処理を定義する。Datasetから利用する。
torch.utils.data.Dataset データの取り出し方を定義する。DataLoaderから利用する。
torch.utils.data.DataLoader ミニバッチ学習でのデータの取り出し方を定義する。

自分の場合、モジュールがいくつかに別れているのが理由で、理解するのに苦労しました。
特にgithubやコンペなどで公開されているコードをみると、全体の分量もある一方で、データの読み込み部分はデータによってある程度作り込まないといけないので、意外と理解するのが難しいと思っています。
分からない場合は、一度必要なモジュールを整理してみて、自分で組み合わせてデータの表示までをやってみるのが良いです。

データの利用までの方法

DataLoaderを作成して、ミニバッチで取り出したデータを描写するところまでやって行きたいと思います。
※各クラスのコードは主に『PyTorchによる発展ディープラーニング』のサポートリポジトリ(https://github.com/YutaroOgawa/pytorch_advanced)を参考にさせていただいています。

必要なライブラリのインポート

説明に必要なライブラリをインポートしておきます。

# パッケージのimport、いらないものがあったら省いてください。
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.utils.data as data
import torchvision
from torchvision import transforms

データを解凍する

まずwgetなどでデータセットをダウンロードして、解凍しておきます。

ls cifar-10-batches-py/

>batches.meta  data_batch_2  data_batch_4  readme.html
>data_batch_1  data_batch_3  data_batch_5  test_batch

5つに分けてpickleされているので、データを取り出します。
簡単のため、data_batch_1 だけ取り出してみます。

 def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

dict = unpickle("cifar-10-batches-py/data_batch_1")

dict.keys()
#dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

前処理の定義

とりあえずTorchテンソルへの変換と、色を標準化をしておきます。
データセットが32×32になっているので、リサイズや切り出しはしないこととしました。

#refered: https://github.com/YutaroOgawa/pytorch_advanced/blob/master/1_image_classification/1-3_transfer_learning.ipynb

# 入力画像の前処理のクラス
class Transform():
    """
    画像をTorchテンソルに変換し、色を標準化する。

    ----------。
    mean : (R, G, B)
        色チャネルのmean。
    std : (R, G, B)
        色チャネルの標準偏差。
    """

    def __init__(self, mean, std):
        self.base_img_transform = transforms.Compose([
            transforms.ToTensor(), 
            transforms.Normalize(mean, std)  
        ])

    def __call__(self, img):
        return self.img_transform(img)

データセットを作る

データが取り出せるようになったので、データセットを作って行きます。データセットはtorch.utils.data.Dataset(torch.utils.data — PyTorch master documentation)を実装します。
Datasetは__getitem__をオーバーライドする必要があります。(torch.utils.data.dataset — PyTorch master documentation
オーバーライドしない場合、NotImplementedErrorが投げられるようになっています。

__init__で、__getitem__で返すためのデータをセットします。
今回は、綺麗じゃないかもしれませんが、img_list・label_list・filename_listにあらかじめ分解しておきます。
__getitem__で、index番目のデータに必要な前処理をしたのち、呼び出し側にデータを返します。

class Dataset(data.Dataset):
    """
    CIFAR-10 のDatasetクラス。

    ----------
    data_dict : 辞書
        CIFAR-10のデータの辞書。
    transform : object
        前処理クラスのインスタンス
    """

    def __init__(self, data_dict, transform=None):
        #self.data_dict = data_dict # そのままdictを渡す場合

        self.img_list = data_dict[b'data']  # 画像のリスト
        self.label_list = data_dict[b'labels']  # ラベルのリスト
        self.filename_list = data_dict[b'filenames']  # ファイル名のリスト

        self.transform = transform  # 前処理クラスのインスタンス

    def __len__(self):
        '''画像の枚数を返す'''
        return len(self.label_list)

    def __getitem__(self, index):
        '''
        Tensor形式の前処理をした画像データ・ラベル・ファイル名を取得する
        '''

        # index番目の画像をロード
        img = self.img_list[index].reshape([3, 32, 32]) # [色RGB][高さ][幅]
        img_new = img.transpose((1, 2, 0)) # [高さ][幅][色RGB]
        img_transformed = self.transform(img_new)#, self.phase)  # torch.Size([3, 224, 224])

        # index番目のlabelをロード
        label = self.label_list[index] #.decode('utf-8')
        #label = int(label)

        # index番目のファイル名をロード
        filename = self.filename_list[index].decode('utf-8')
        

        return img_transformed, label, filename

動作確認をしておきます。

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
dataset = Dataset(data_dict=dict, transform=Transform(mean=mean, std=std))

index = 0
print(dataset.__getitem__(index)[0].size())
print(dataset.__getitem__(index)[1])

#torch.Size([3, 32, 32])
#6
#leptodactylus_pentadactylus_s_000004.png

DataLoader

最後にDataLoaderを作成すれば、データをミニバッチ学習などに利用することができます。
DataLoader(torch.utils.data — PyTorch master documentation)は色々オプションを指定可能ですが、今回は普通によく使うオプションのみ設定しておきます。

# ミニバッチのサイズ
batch_size = 32

# DataLoader
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True)

Dataを取り出してみる

早速、DataLoaderを使ってデータを取り出してみます。

batch_iterator = iter(dataloader)  # イテレータに変換する
inputs, labels, filename = next(batch_iterator)  #最初のミニバッチを取り出す

#pltで描写するために転置する
imgs = inputs
imgs_new = [img.numpy().transpose((1, 2, 0)) for img in imgs]

#描写
#refered: http://umashika5555.hatenablog.com/entry/2017/09/24/235813
plt.figure()
h_length = 8
w_length = 4
for h in range(h_length):
    for w in range(w_length):
        i = h*w_length+w+1
        fig_loc = str(h_length) + str(w_length) + str(i)
        plt.subplot(h_length,w_length,i)
        #グラフプロット
        plt.imshow(imgs_new[i-1])
plt.show()

プロット結果は次のようになって、ちゃんと32枚、画像が取り出されていることが分かります。

f:id:tasotasoso:20200112233834p:plain
1つ目のミニバッチで取り出したデータ

参考にさせていただいた資料

つくりながら学ぶ! PyTorchによる発展ディープラーニング

つくりながら学ぶ! PyTorchによる発展ディープラーニング

pytorch.org

umashika5555.hatenablog.com