FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

「ゼロから作る Deep Learning」のCNNの構造を読みとく(1. CNNの構造勉強中)

数学は苦手です。

「ゼロから作る Deep Learning」のCNNの章を一生懸命理解しようとしている。目標は、自分でC/C++などの言語でCNNを構築できるようになること。 それを組み込み機器などに移植するのが目標だ。

大まかなCNNの流れ

SimpleConvNet {
   Convolution (W1, b1)
   Relu ()
   Pooling ()
   Affine (W2, b2)
   Relu ()
   Affine (W3, b3)
   SoftmaxWithLoss ()
}

Convolution以外は何となく分かるが、Convolution内では何をやっている?っていうか各関数のパラメータが分からん。

Convolution 関数: Convolution (W, b, stride, pad)

Convolution (W, b, stride, pad) 実際にMNISTを動作させるときのパラメータで読み解いていくと、

  • W : W1は4つの要素を持っていた。具体的には (30, 1, 5, 5) 順番に、
    • FN : 30 はフィルタの数。FN これはバッチサイズになっているように見える。
    • C: 1 はチャネル数。色がついている場合などはRGBごとのために3に設定されたりする。ただしMNISTは白黒なので1。
    • FH : 5はフィルタの縦の長さ。畳み込みを行うフィルタの縦の長さ。
    • FW : 5はフィルタの横の長さ。畳み込みを行うフィルタの横の長さ。

従って、MNISTの場合は Convolution(30, 1, 5, 5) となる。

このとき、この畳み込み演算で生成される行列のサイズは、フィルタのサイズとパディングの数だけ減るため、以下のように計算される。

\text{out_h} = 1 + \text{int}\left(\left( H + 2\times\text{self.pad} - FH \right) / \text{stride} \right) = 1 + (28 - 5) = 24 となるわけか。

Python の関数群の理解。 "transpose" と "reshape" について

transposeはPythonの行列の形を変える。例。

>>> array = np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[9,10,11], [12,13,14], [15,16,17]], [[18,19,20], [21,22,23],
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
>>> array.transpose(0,2,1)
array([[[ 0,  3,  6],
        [ 1,  4,  7],
        [ 2,  5,  8]],

       [[ 9, 12, 15],
        [10, 13, 16],
        [11, 14, 17]],

       [[18, 21, 24],
        [19, 22, 25],
        [20, 23, 26]]])

reshapeは配列の形を変える。この時に便利なのが、配列の長さとして-1を指定すると、残りの長さの指定値から適切な長さを割り出して自動的に変換してくれるということ。例。

>>> array.reshape(3, -1)
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
       [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23, 24, 25, 26]])

これでフィルタの形を変えて、処理しやすい形に変換しているということか。多少分かってきた。

f:id:msyksphinz:20180125004459p:plain

im2col のコードについて

im2col はデータをフィルタ処理しやすいように変換する関数のこと。通常そのままフィルタをかけてもよいのだが、PythonはFor文が苦手のため、わざわざ配列の形式を変換している。

  • im2col(input_data, filter_h, filter_w, stride, pad)

ここで、 input_data は4次元配列で、(N,C,H,W)を示している。