未だ、人類の求めるAIには届かず、AIはまちがい続ける。
Intro
rokkiです.進捗報告会が無事終了し,一安心しております.
今回は最近研究で取り組んだ課題をまとめた記事です.
「Convolutional Neural Networkが回転不変性を表現できるか」という課題に取り組みました.
取り組むキッカケとしては,自身の研究で回転不変性*1を対処する必要が生じたためです.
データセット
今回実験に用いたデータセットの紹介をします.
Computer Visionの領域では有名なデータセットであるMNISTを用いました. また,MNISTを回転させたデータセットである,rotated MNISTも比較実験のために用いました.
データセットの準備
以下ページの"Rotated MNIST digits"からダウンロードできます.
基本情報
0~9の手書き文字画像が62,000枚. 学習用とテスト用でそれぞれ別ファイルに圧縮されている.
- train_valid_data: (12000, 28*28)
- test_data: (50000, 28*28)
統計情報
各ラベルの比率は以下の通り.両方のデータセットにおいて,1が多く,5が少ない傾向がある.
実験
実験の流れ
- 特徴抽出器の学習:2種類のデータセット*2を異なるモデルに学習させる.学習方法はmetric learningを用いる.
- 分類器の学習:学習させたモデルを特徴抽出器として活用する.これらのモデルで得た特徴量を用いて,複数の分類器に10クラス分類タスクを学習させる.
- 分類器の評価:分類器の分類性能によって,vanilla CNNの回転不変の表現能力を評価する.
実験の詳細
- 目的:vanilla CNNの回転不変の表現能力を検証する.
- データセット:MNISTとrotated-MNISTの学習用データ数,およびテスト用データ数は統一*3.実装の都合上64x64にリサイズしている.
- モデル:基本的なCNNブロック(MaxPool,CNN,BN,ReLU,CNN,BN,ReLU)を4つ重ねる.出力層はAvgPool,Linear.
- data shape:input (N, 64, 64) -> output (N, 64)
- 分類器:KNN, 決定木, Random Forest, Neural Net, AdaBoost, Naive Bayes, QDA
- metric learning:triplet lossを使用.postive, negativeサンプル数はそれぞれ1.
- その他ハイパーパラメータ:optimizerにはAdamを使用,metric learningの学習ではepoch=30, margin=0.1.
torchsummary*4によるモデルの可視化
from torchsummary import summary summary(model, [anchor, positive, negative]) ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ MaxPool2d-1 [-1, 1, 32, 32] 0 Conv2d-2 [-1, 8, 32, 32] 80 BatchNorm2d-3 [-1, 8, 32, 32] 16 ReLU-4 [-1, 8, 32, 32] 0 Conv2d-5 [-1, 8, 32, 32] 584 BatchNorm2d-6 [-1, 8, 32, 32] 16 ReLU-7 [-1, 8, 32, 32] 0 MaxPool2d-8 [-1, 8, 16, 16] 0 Conv2d-9 [-1, 16, 16, 16] 1,168 BatchNorm2d-10 [-1, 16, 16, 16] 32 ReLU-11 [-1, 16, 16, 16] 0 Conv2d-12 [-1, 16, 16, 16] 2,320 BatchNorm2d-13 [-1, 16, 16, 16] 32 ReLU-14 [-1, 16, 16, 16] 0 MaxPool2d-15 [-1, 16, 8, 8] 0 Conv2d-16 [-1, 32, 8, 8] 4,640 BatchNorm2d-17 [-1, 32, 8, 8] 64 ReLU-18 [-1, 32, 8, 8] 0 Conv2d-19 [-1, 32, 8, 8] 9,248 BatchNorm2d-20 [-1, 32, 8, 8] 64 ReLU-21 [-1, 32, 8, 8] 0 MaxPool2d-22 [-1, 32, 4, 4] 0 Conv2d-23 [-1, 64, 4, 4] 18,496 BatchNorm2d-24 [-1, 64, 4, 4] 128 ReLU-25 [-1, 64, 4, 4] 0 Conv2d-26 [-1, 64, 4, 4] 36,928 BatchNorm2d-27 [-1, 64, 4, 4] 128 ReLU-28 [-1, 64, 4, 4] 0 MaxPool2d-29 [-1, 64, 2, 2] 0 Conv2d-30 [-1, 128, 2, 2] 73,856 BatchNorm2d-31 [-1, 128, 2, 2] 256 ReLU-32 [-1, 128, 2, 2] 0 Conv2d-33 [-1, 128, 2, 2] 147,584 BatchNorm2d-34 [-1, 128, 2, 2] 256 ReLU-35 [-1, 128, 2, 2] 0 AvgPool2d-36 [-1, 128, 1, 1] 0 ... Linear-109 [-1, 64] 8,256 Linear-110 [-1, 64] 8,256 Linear-111 [-1, 64] 8,256 ================================================================ Total params: 912,456 Trainable params: 912,456 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): #### Forward/backward pass size (MB): 2.30 Params size (MB): 3.48 Estimated Total Size (MB): #### ----------------------------------------------------------------
Results
全ての分類器において,MNISTの分類性能に比べ,rotated MNISTの分類性能が低下した*5
最も高精度だったQDAの分類結果を以下に示す.
考察
検証の結果,「vanilla CNNが回転不変性を表現するのは困難」と言える*6.
また,rotated MNISTの分類結果において,0と1の分類性能が高いのは興味深い.
6と9は回転すると見分けがつかないが,こうした性質は分類結果からは得られなかった.
ハマった点
Outlo
今回は,直近に取り組んだ課題をまとめました. 修論を執筆する時期に自分で見返すのかなぁと考えています.
以下注意事項です.
*本記事には誤りが含まれている可能性があります.全て鵜呑みにせず,ご自身で追加実験検証を行ってみてください.