FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

Chainerを使って関数フィッティングに挑戦する(2.多項式へのフィッティング)

Chainerで様々な関数をフィッティングさせるニューラルネットワークを作っているのだが、前回、うまいことフィッティングできず線形な関数になってしまったので、その調査をしていた。

msyksphinz.hatenablog.com

まず、現在のニューラルネットワークがどのように悪いのか、という話だが、Chainer.Functions.Linear()を使ってネットワークを作っており、これは完全な線形関数だ。

線形関数を繋げていっても、それは線形関数になってしまう。というか、xを何度も乗算して多項式を作れていない。どうすればいいんだろう?

    def forward(self, x):
        h1 = self.fc1(x)
        h2 = self.fc2(h1)
        h3 = self.fc3(h2)
        return h3

TensorFlowで多項式のフィッティングをしたとき

TensorFlowでも同様の試行をしたのだが、このときはニューラルネットワークというか、多項式を定義して、その誤差を0に近づけるように最適化していった。

msyksphinz.hatenablog.com

y4 = W3*x_data*x_data*x_data+W2*x_data*x_data + W1*x_data + W0

良く考えると、このforwardの変数は、Chainer.Variableとして定義してある。演算をオーバーロードできているのならば、普通に乗算演算子を使って定義できるのではないか?

    def forward(self, x):
        h1 = self.fc1(x)
        h2 = self.fc2(h1*x)
        h3 = self.fc3(h2*x)

あ、確かに、それっぽくなった。

f:id:msyksphinz:20160715022809p:plain

矩形関数をフィッティングしてみる

調子にのって、矩形関数をフィッティングしてみよう。以下のような関数を定義して、入力ファイルを作成した。

float function (float x)
{
  if (x < 0.0) { return -1.0; }
  else         { return  1.0; }
}
  • input_data.txt
-1.81518, -1
0.854429, 1
0.263134, 1
1.57222, 1
-1.93722, -1
0.6646, 1
-1.69719, -1
-1.02543, -1
1.25107, 1
-1.95763, -1
-0.755761, -1
1.77541, 1
-0.77935, -1
0.520007, 1
-1.23955, -1
0.12331, 1
-1.33116, -1
-1.62714, -1
-1.38443, -1

これを、以下の7次多項式にフィッティングさせる。

class Function3DimentionModel(chainer.FunctionSet):
    def __init__(self):
        super(Function3DimentionModel, self).__init__(
            fc1=chainer.functions.Linear(1, 1),
            fc2=chainer.functions.Linear(1, 1),
            fc3=chainer.functions.Linear(1, 1),
            fc4=chainer.functions.Linear(1, 1),
            fc5=chainer.functions.Linear(1, 1),
            fc6=chainer.functions.Linear(1, 1),
            fc7=chainer.functions.Linear(1, 1),
    )
    def forward(self, x):
        h1 = self.fc1(x)
        h2 = self.fc2(h1*x)
        h3 = self.fc3(h2*x)
        h4 = self.fc4(h3*x)
        h5 = self.fc5(h4*x)
        h6 = self.fc6(h5*x)
        h7 = self.fc7(h6*x)
        return h7

フィッティング結果

フーリエ級数展開みたいになる。うまくいった。

f:id:msyksphinz:20160715023535p:plain

ソースコード

github.com