ResNet をいじってみる

PyTorch実践入門 | マイナビブックス (mynavi.jp)を見ながら ResNet101 を触ってみる。WSL 上に PyTorch と Pillow は別途インストールしておいた。今は要らないが、WSL 版 CUDA もインストールしておく。

from torchvision import models
resnet = models.reset101(pretrained = True)
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
from PIL import Image
img = Image.open("sample.jpg")
img_t = preprocess(img)
import torch
batch_t = torch.unsqueeze(img_t, 0)
resnet.eval()
out = resnet(batch_t)
labels = list()
with open("imagenet_classes.txt") as f:
  labels = [s.strip() for s in f.readlines()]
_, index = torch.max(out, 1)
certain = torch.nn.functional.softmax(out, dim=1)[0]
print(labels[index[0]], centain[index[0]].items())

教科書ではゴールデンレトリバーの写真を喰わしていたが、ラベルにはない、人物の写真を使ってみた。結果は「軍服」が最もマッチした結果(certain = 0.96)となった。人物の場合、着ている服にマッチしようとするようだ。

ちなみにネットからダウンロードしたゴールデンレトリバーの写真を喰わせたところ、きちんと犬種を認識した(certain = 0.88)。教科書より尤度が落ちるようだが、一応期待する結果となった。

プログラム上、transforms.Normalize() の引数や torch.unsqueeze() という謎の関数があるが、読み進めていく内に明らかになるのかな?

投稿者について
みのしす

小さいときは科学者になろうとしたのに、その時にたまたま身に着けたプログラミングで未だに飯を食っているしがないおじさんです。(年齢的にはもうすぐおじいさん)

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です