以前 ResNet で分類問題を計算したが、その時に Normalize と unsqueeze という謎の関数があった。本を読み進めるとこの関数の謎が解けてきた。
まず、PyTorch の学習・出力は「ミニバッチ」単位で行うというお約束がある。ミニバッチ数を N とすると、N × C × H × W (C はカラーチャンネル数)という形式でなければならない。画像データは当然ミニバッチではないため、そのように見せるために N=1 を第0次元目に挿入する必要があり、その操作が torch.unsqueese(img_t, 0) となる。
画像は通常 [0, 256) の8ビット整数として扱われるが、32ビット浮動小数点[0, 1) に変換した方がニューラルネットでは処理しやすいので、平均0、標準偏差1の分布になるようにする処理が Normalize である。(なお、画像を8ビット固定小数点として解釈しなおして活性化関数もそのつもりで再定義することは可能である)このとき、68% の入力が [-1, 1] に入ることが期待される。
PyTorch では Normalize 関数に平均と引数を与える必要がある(自動では計算されない)。カラー画像では3チャンネルあるため、これらの値も長さ3のベクトルとなる。
コメントを残す