rokkiの備忘録

地道こそ近道

未だ、人類の求めるAIには届かず、AIはまちがい続ける。

Intro

rokkiです.進捗報告会が無事終了し,一安心しております.

今回は最近研究で取り組んだ課題をまとめた記事です.

Convolutional Neural Networkが回転不変性を表現できるか」という課題に取り組みました.

取り組むキッカケとしては,自身の研究で回転不変性*1を対処する必要が生じたためです.

データセット

今回実験に用いたデータセットの紹介をします.

Computer Visionの領域では有名なデータセットであるMNISTを用いました. また,MNISTを回転させたデータセットである,rotated MNISTも比較実験のために用いました.

データセットの準備

以下ページの"Rotated MNIST digits"からダウンロードできます.

sites.google.com

基本情報

0~9の手書き文字画像が62,000枚. 学習用とテスト用でそれぞれ別ファイルに圧縮されている.

  • train_valid_data: (12000, 28*28)
  • test_data: (50000, 28*28)

f:id:roki_memo:20200711191627p:plain

統計情報

各ラベルの比率は以下の通り.両方のデータセットにおいて,1が多く,5が少ない傾向がある.

f:id:roki_memo:20200711194553p:plain

実験

実験の流れ

  • 特徴抽出器の学習:2種類のデータセット*2を異なるモデルに学習させる.学習方法はmetric learningを用いる.
  • 分類器の学習:学習させたモデルを特徴抽出器として活用する.これらのモデルで得た特徴量を用いて,複数の分類器に10クラス分類タスクを学習させる.
  • 分類器の評価:分類器の分類性能によって,vanilla CNNの回転不変の表現能力を評価する.

実験の詳細

  • 目的:vanilla CNNの回転不変の表現能力を検証する.
  • データセット:MNISTとrotated-MNISTの学習用データ数,およびテスト用データ数は統一*3.実装の都合上64x64にリサイズしている.
  • モデル:基本的なCNNブロック(MaxPool,CNN,BN,ReLU,CNN,BN,ReLU)を4つ重ねる.出力層はAvgPool,Linear.
  • data shape:input (N, 64, 64) -> output (N, 64)
  • 分類器:KNN, 決定木, Random Forest, Neural Net, AdaBoost, Naive Bayes, QDA
  • metric learning:triplet lossを使用.postive, negativeサンプル数はそれぞれ1.
  • その他ハイパーパラメータ:optimizerにはAdamを使用,metric learningの学習ではepoch=30, margin=0.1.

torchsummary*4によるモデルの可視化

from torchsummary import summary
summary(model, [anchor, positive, negative])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
         MaxPool2d-1            [-1, 1, 32, 32]               0
            Conv2d-2            [-1, 8, 32, 32]              80
       BatchNorm2d-3            [-1, 8, 32, 32]              16
              ReLU-4            [-1, 8, 32, 32]               0
            Conv2d-5            [-1, 8, 32, 32]             584
       BatchNorm2d-6            [-1, 8, 32, 32]              16
              ReLU-7            [-1, 8, 32, 32]               0
         MaxPool2d-8            [-1, 8, 16, 16]               0
            Conv2d-9           [-1, 16, 16, 16]           1,168
      BatchNorm2d-10           [-1, 16, 16, 16]              32
             ReLU-11           [-1, 16, 16, 16]               0
           Conv2d-12           [-1, 16, 16, 16]           2,320
      BatchNorm2d-13           [-1, 16, 16, 16]              32
             ReLU-14           [-1, 16, 16, 16]               0
        MaxPool2d-15             [-1, 16, 8, 8]               0
           Conv2d-16             [-1, 32, 8, 8]           4,640
      BatchNorm2d-17             [-1, 32, 8, 8]              64
             ReLU-18             [-1, 32, 8, 8]               0
           Conv2d-19             [-1, 32, 8, 8]           9,248
      BatchNorm2d-20             [-1, 32, 8, 8]              64
             ReLU-21             [-1, 32, 8, 8]               0
        MaxPool2d-22             [-1, 32, 4, 4]               0
           Conv2d-23             [-1, 64, 4, 4]          18,496
      BatchNorm2d-24             [-1, 64, 4, 4]             128
             ReLU-25             [-1, 64, 4, 4]               0
           Conv2d-26             [-1, 64, 4, 4]          36,928
      BatchNorm2d-27             [-1, 64, 4, 4]             128
             ReLU-28             [-1, 64, 4, 4]               0
        MaxPool2d-29             [-1, 64, 2, 2]               0
           Conv2d-30            [-1, 128, 2, 2]          73,856
      BatchNorm2d-31            [-1, 128, 2, 2]             256
             ReLU-32            [-1, 128, 2, 2]               0
           Conv2d-33            [-1, 128, 2, 2]         147,584
      BatchNorm2d-34            [-1, 128, 2, 2]             256
             ReLU-35            [-1, 128, 2, 2]               0
        AvgPool2d-36            [-1, 128, 1, 1]               0
                ...
          Linear-109                   [-1, 64]           8,256
          Linear-110                   [-1, 64]           8,256
          Linear-111                   [-1, 64]           8,256
================================================================
Total params: 912,456
Trainable params: 912,456
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): ####
Forward/backward pass size (MB): 2.30
Params size (MB): 3.48
Estimated Total Size (MB): ####
----------------------------------------------------------------

Results

全ての分類器において,MNISTの分類性能に比べ,rotated MNISTの分類性能が低下した*5

最も高精度だったQDAの分類結果を以下に示す.

f:id:roki_memo:20200711192003p:plain

考察

検証の結果,「vanilla CNNが回転不変性を表現するのは困難」と言える*6

また,rotated MNISTの分類結果において,0と1の分類性能が高いのは興味深い.

6と9は回転すると見分けがつかないが,こうした性質は分類結果からは得られなかった.

ハマった点

  • 何故かrotated-MNISTデータセットには転置処理が必要だった点.MNISTデータセットには不要だった*7

Outlo

今回は,直近に取り組んだ課題をまとめました. 修論を執筆する時期に自分で見返すのかなぁと考えています.

以下注意事項です.

*本記事には誤りが含まれている可能性があります.全て鵜呑みにせず,ご自身で追加実験検証を行ってみてください.

*1:回転不変性の参考文献: 参考文献1 arxiv.org

参考文献2 arxiv.org

*2:回転している画像(MNIST),回転画像(rotated-MNIST)

*3:各文字のサンプル数の統一は行っていないため,統一すれば結果が変化するかもしれない.

*4:torchsummarygithub.com

*5:Acc,Precision,Recall,f1を追記予定

*6:逆に言えば,CNNの拡張によって回転不変性を獲得できるはず

*7:モデルに渡す直前でデータは可視化しよう!!