torchvisionのtransformsが分からなかったので調べた。
PyTorchで画像処理を始めたので、torchvisions.transformsを使った前処理について調べました。
torchvisions.transformsとは
Composeを使うことでチェーンさせた前処理が簡潔にかけるようになります。また、Functionalモジュールを使うことで、関数的な使い方をすることもできます。
Transforms are common image transformations. They can be chained together using Compose. Additionally, there is the torchvision.transforms.functional module. Functional transforms give fine-grained control over the transformations. This is useful if you have to build a more complex transformation pipeline (e.g. in the case of segmentation tasks).
※公式Doc(https://pytorch.org/docs/stable/torchvision/transforms.html)の説明
Composeモジュール
個人的に、Composeがちょっと分かりにくかったので、モヤモヤした点を説明します。
前処理クラスの例
例えば、以下のチュートリアルがとても簡潔に前処理クラスを定義してくれているので見ていきます。
pytorch.org
#refered: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html # desired size of the output image imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu loader = transforms.Compose([ transforms.Resize(imsize), # scale imported image transforms.ToTensor()]) # transform it into a torch tensor def image_loader(image_name): image = Image.open(image_name) # fake batch dimension required to fit network's input dimensions image = loader(image).unsqueeze(0) return image.to(device, torch.float) style_img = image_loader("./data/images/neural-style/picasso.jpg") content_img = image_loader("./data/images/neural-style/dancing.jpg") assert style_img.size() == content_img.size(), \ "we need to import style and content images of the same size"
transforms.Composeに配列で前処理を渡しておき、インスタンスにPillowで読み取ったイメージを渡すと前処理できるようです。
ただ、__call__で呼び出すときに具体的にどういうインターフェースになっているのか、いまいち分からず、モヤモヤしてしまいました。
公式Doc(torchvision.transforms — PyTorch master documentation)にも,
記事を書いた時点では__call__の仕様が明記されていなかったので、少し困りました。
(もちろん、普通に考えて画像を渡す以外にやることがないので、モヤモヤする余地がないかもしれませんが...。)
ソースコードを読む
docからComposeのソースを読むことができるので読んでみます。
#refered: https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#Compose class Compose(object): """Composes several transforms together. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.ToTensor(), >>> ]) """ def __init__(self, transforms): self.transforms = transforms def __call__(self, img): for t in self.transforms: img = t(img) return img def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: format_string += '\n' format_string += ' {0}'.format(t) format_string += '\n)' return format_string
ここで、__call__(self, img)になっていることが分かるので、画像を渡せば良いことが分かりとてもスッキリしました。
Functionalモジュール
個人的にはFunctionalの方が、より生に近い感じでかけるので好みでした。
#refered: https://pytorch.org/docs/stable/torchvision/transforms.html#functional-transforms import torchvision.transforms.functional as TF import random def my_segmentation_transforms(image, segmentation): if random.random() > 0.5: angle = random.randint(-30, 30) image = TF.rotate(image, angle) segmentation = TF.rotate(segmentation, angle) # more transforms ... return image, segmentation
当然ながら、コードが長くなってしまうので、この辺の好みはもうちょっとPyTorchを書き慣れてくると変わるんだろうなぁという感じです。