PyTorch実践入門 | マイナビブックス (mynavi.jp)を見ながら ResNet101 を触ってみる。WSL 上に PyTorch と Pillow は別途インストールしておいた。今は要らないが、WSL 版 CUDA もインストールしておく。
from torchvision import models
resnet = models.reset101(pretrained = True)
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
from PIL import Image
img = Image.open("sample.jpg")
img_t = preprocess(img)
import torch
batch_t = torch.unsqueeze(img_t, 0)
resnet.eval()
out = resnet(batch_t)
labels = list()
with open("imagenet_classes.txt") as f:
labels = [s.strip() for s in f.readlines()]
_, index = torch.max(out, 1)
certain = torch.nn.functional.softmax(out, dim=1)[0]
print(labels[index[0]], centain[index[0]].items())
教科書ではゴールデンレトリバーの写真を喰わしていたが、ラベルにはない、人物の写真を使ってみた。結果は「軍服」が最もマッチした結果(certain = 0.96)となった。人物の場合、着ている服にマッチしようとするようだ。
ちなみにネットからダウンロードしたゴールデンレトリバーの写真を喰わせたところ、きちんと犬種を認識した(certain = 0.88)。教科書より尤度が落ちるようだが、一応期待する結果となった。
プログラム上、transforms.Normalize() の引数や torch.unsqueeze() という謎の関数があるが、読み進めていく内に明らかになるのかな?
コメントを残す