ニューラルネットの構造決めは難しい…

PyTorch 教科書の続きです。

8章では畳み込みニューラルネットが説明され、そこでは最終的には 8 × 8 × 8→32 および 32→2 の全結合ネットワークに入力されます。512→32 ニューラルネットのパラメータ数が全パラメータ数のほとんどを占めています(17000個を超える)。

しかし、全結合層がこんなに大きい必要があるのか?と思い、もう1つ畳み込みを増やし(8→2)、再度プーリングして小さい全結合層(4 × 4 × 2→2)を用意したらどうなるか、という実験をしてみました。この場合、パラメータ数は 1800 程度となります。活性化関数も途中に非線形のものが複数あるので要らないだろうという判断で付けませんでした。

class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(8, 2, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4 * 4 * 2, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        out = F.max_pool2d(torch.tanh(self.conv3(out)), 2)
        out = out.view(-1, 32)
        out = self.fc1(out)
        return out

結果はこちら。まず、教科書オリジナルのネットワーク。

Number of parameters: 18090
1: loss:0.5651948442504664
11: loss:0.3216442522729278
21: loss:0.27913830149325597
31: loss:0.2590295629231793
41: loss:0.24230616791233137
51: loss:0.22572632390222733
61: loss:0.21183957520184243
71: loss:0.19910000701239156
81: loss:0.18463985992085402
91: loss:0.17045862604952922
101: loss:0.1581920424511858
111: loss:0.14482170625761814
121: loss:0.13526690044220846
131: loss:0.12121768934047146
141: loss:0.10912206292057493
Accuracy train: 0.9694
Accuracy val: 0.894

過学習気味になっています。ついで新しいネットワークの検証結果。

1: loss:0.6875121172066707
11: loss:0.41660272999174275
21: loss:0.3356129206289911
31: loss:0.30873911443409646
41: loss:0.2911262233166178
51: loss:0.27725260225450915
61: loss:0.26394867768902686
71: loss:0.25447895415839117
81: loss:0.24749484060296587
91: loss:0.23944083999866134
101: loss:0.23235782914480585
111: loss:0.22487093451296447
121: loss:0.22108068706313516
131: loss:0.21631880598083422
141: loss:0.2107874967490032
Accuracy train: 0.9092
Accuracy val: 0.89

若干検出率が落ちていますが、悪くない数値となりました。

投稿者について
みのしす

小さいときは科学者になろうとしたのに、その時にたまたま身に着けたプログラミングで未だに飯を食っているしがないおじさんです。(年齢的にはもうすぐおじいさん)

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です