Quantcast
Channel: プログラミング
Viewing all articles
Browse latest Browse all 8051

大規模言語モデルで将棋AIを作る その4(ResNetの特徴マップ) - TadaoYamaokaの開発日記

$
0
0

前回までは、ネットワーク全体をTransformerで構成したところ、ResNetと比較して精度が上がらないという結果になった。

今回は、ResNetとTransformerを組み合わせて、初めにResNetで特徴マップを作成した後、その特徴マップを座標ごとに分割しトークンとして、Transformerに入力することを試す。

ネットワーク構成

ResNetは12ブロック256フィルタとし、Transformerはヘッド数8、feedforward256、8層とする。
Transformerの埋め込みの次元はResNetのフィルタ数となるため256、トークン数は座標の数になるため81になる。

ResNet20ブロックの後半8ブロックをTransformer8層に置き換えた形になる。

実装

ResNetの特徴マップを入力とする場合、次元は(batchsize, channels, tokens)となるため、PyTorchの標準のTransformerを使う場合、channelsとtokensの次元を交換する必要がある。
Transformerを自作する場合は、次元を交換しないで、Q,K,VのLinearを1x1の畳み込み層で実装して効率化できそうだが、今回はPyTorchの標準のTransformerで実装した。

入力層と出力層は、ResNetから変更しない。

classPolicyValueNetwork(nn.Module):
    def__init__(self, blocks, channels, activation=nn.ReLU(), fcl=256):
        ...
        # Resnet blocks
        self.blocks = nn.Sequential(*[ResNetBlock(channels, activation) for _ inrange(blocks)])
        
        # Transformer
        self.pos_encoder = PositionalEncoding(channels, 81)
        transformer_layer = nn.TransformerEncoderLayer(channels, nhead=8, dim_feedforward=1024, dropout=0.1, activation="gelu", batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=8)
        ...

    defforward(self, x1, x2):
        ...
        # resnet blocks
        h = self.blocks(u1)
        
        # Transformer
        h = h.flatten(2)
        h = h.transpose(1, 2)
        h = self.pos_encoder(h)
        h = self.transformer(h)
        h = h.transpose(1, 2)
        h = h.view(h.size(0), -1, 9, 9)
        ...

結果

ResNet20ブロック256フィルタのモデルと比較した。
また、ヘッド数、feedforwardのユニット数、活性化関数は条件を変えて比較した。

データは前回と同じで、4エポック学習した。

nheadfeed forward活性化関数val losspolicy acc.value acc.
ResNet2.55470.41360.6714
Transformer8256gelu2.56270.41240.6769
Transformer4512gelu2.55600.41300.6805
Transformer8256gelu2.59330.41320.6497
Transformer8256relu2.55290.41290.6799
Transformer81024relu2.55380.41560.6803
Transformer81024gelu2.55240.41300.6798
Transformer161024gelu2.55840.41250.6752

ResNet20ブロック256フィルタより若干悪いくらいの精度になった。
feedforwardのユニット数を増やすと、ResNet20ブロック256フィルタより少し良くなった。

ヘッド数は少なくても多すぎても良くない。
活性化関数は、条件によって変わるためreluとgeluがどちらが良いとも言えない。

学習時間

4エポックの学習時間は以下の通り。
ただし、Windowsで実行しているため、Flash Attentionは有効になっていない。

nheadfeed forward活性化関数学習時間
ResNet1:37
Transformer8256gelu1:58
Transformer8512gelu2:04
Transformer4256gelu2:00
Transformer8256relu2:01
Transformer81024relu2:14
Transformer81024gelu2:14
Transformer161024gelu2:27

ResNetよりもTransformerの学習時間は長くなっている。

まとめ

ResNetの特徴マップをTransformerの入力にすることを試した。
ResNetの後半8ブロックを、Transformer8層に置き換えたところ、同じくらいの精度になった。
feed forwardのユニット数を増やすことでResNetよりも少し精度がよくなった。
しかし、学習時間は増えるため、計算量に見合う精度にはなっていない。

次は、位置エンコーダを工夫することで精度が上げられないか試したい。


Viewing all articles
Browse latest Browse all 8051