たそらぼ

日頃思ったこととかメモとか。

torchvisionのtransformsが分からなかったので調べた。

PyTorchで画像処理を始めたので、torchvisions.transformsを使った前処理について調べました。

pytorch.org

torchvisions.transformsとは

Composeを使うことでチェーンさせた前処理が簡潔にかけるようになります。また、Functionalモジュールを使うことで、関数的な使い方をすることもできます。

Transforms are common image transformations. They can be chained together using Compose. Additionally, there is the torchvision.transforms.functional module. Functional transforms give fine-grained control over the transformations. This is useful if you have to build a more complex transformation pipeline (e.g. in the case of segmentation tasks).
※公式Doc(https://pytorch.org/docs/stable/torchvision/transforms.html)の説明

Composeモジュール

個人的に、Composeがちょっと分かりにくかったので、モヤモヤした点を説明します。

前処理クラスの例

例えば、以下のチュートリアルがとても簡潔に前処理クラスを定義してくれているので見ていきます。
pytorch.org

#refered: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no gpu

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor


def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


style_img = image_loader("./data/images/neural-style/picasso.jpg")
content_img = image_loader("./data/images/neural-style/dancing.jpg")

assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"


transforms.Composeに配列で前処理を渡しておき、インスタンスにPillowで読み取ったイメージを渡すと前処理できるようです。
ただ、__call__で呼び出すときに具体的にどういうインターフェースになっているのか、いまいち分からず、モヤモヤしてしまいました。
公式Doc(torchvision.transforms — PyTorch master documentation)にも,
記事を書いた時点では__call__の仕様が明記されていなかったので、少し困りました。
(もちろん、普通に考えて画像を渡す以外にやることがないので、モヤモヤする余地がないかもしれませんが...。)

ソースコードを読む

docからComposeのソースを読むことができるので読んでみます。

#refered: https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#Compose

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

ここで、__call__(self, img)になっていることが分かるので、画像を渡せば良いことが分かりとてもスッキリしました。

Functionalモジュール

個人的にはFunctionalの方が、より生に近い感じでかけるので好みでした。

#refered: https://pytorch.org/docs/stable/torchvision/transforms.html#functional-transforms

import torchvision.transforms.functional as TF
import random

def my_segmentation_transforms(image, segmentation):
    if random.random() > 0.5:
        angle = random.randint(-30, 30)
        image = TF.rotate(image, angle)
        segmentation = TF.rotate(segmentation, angle)
    # more transforms ...
    return image, segmentation

当然ながら、コードが長くなってしまうので、この辺の好みはもうちょっとPyTorchを書き慣れてくると変わるんだろうなぁという感じです。

参考にさせていただいた資料

pytorch.org

qiita.com

qiita.com

つくりながら学ぶ! PyTorchによる発展ディープラーニング

つくりながら学ぶ! PyTorchによる発展ディープラーニング