前回のTensorFlowのサンプルを使って、学習サンプル数と学習結果についてもうちょっと調べてみた。
オーバフィッティングは起こらない?
最初にやりたかったのは、10次の多項式をモデル化し、10個のサンプルを与えるだけで、オーバフィッティングするのではないかという予想のもと、そのようなグラフを作ってみたかった。 (本当に初学者なので、オーバーフィッティングというものを馬鹿の二つ覚えしているに過ぎない...)
つまり、10個のサンプルだけで10次の多項式をフィッティングさせようとすると、ほぼ確実に誤差が0になる多項式を生成できるため、逆に複雑なグラフになってしまうのではないかということだ。
TensorFlowのコードでは、GradientDescentOptimizer、つまり最急降下法を使っているので、うまく再現できるかどうかは不安だが、とりあえずままずやってみる。
サンプル数が少ない状態で学習をこなすとどのようになるのか。
前回の日記と同じく、多項式をフィッティングさせる問題を解かしてみよう。 ただし、今回は拡張して10次の多項式をモデル化し、同様に三角関数をフィッティングさせてみる。 今回は、バイアスを±に振って、より大きな誤差がサンプルに乗るようにした。
uniformは、-1.0から1.0までの乱数となる。この式に従って、10個のサンプルから100個のサンプルまで用意し、同じ回数だけトレーニングをこなしたときにどのようになるか見てみよう。
import tensorflow as tf import numpy as np import random def training(rand_num): x_data = np.float32(np.random.random(rand_num)) # Random input y_data = np.sin(2*np.pi*x_data) + 0.5 * np.random.uniform(-1, 1) # for index in range(0, x_data.size): # print x_data[index], y_data[index] W9 = tf.Variable(0.0) W8 = tf.Variable(0.0) W7 = tf.Variable(0.0) W6 = tf.Variable(0.0) W5 = tf.Variable(0.0) W4 = tf.Variable(0.0) W3 = tf.Variable(0.0) W2 = tf.Variable(0.0) W1 = tf.Variable(0.0) W0 = tf.Variable(0.0) y4 = W9*x_data*x_data*x_data*x_data*x_data*x_data*x_data*x_data*x_data+ \ W8*x_data*x_data*x_data*x_data*x_data*x_data*x_data*x_data+ \ W7*x_data*x_data*x_data*x_data*x_data*x_data*x_data+ \ W6*x_data*x_data*x_data*x_data*x_data*x_data+ \ W5*x_data*x_data*x_data*x_data*x_data+ \ W4*x_data*x_data*x_data*x_data+ \ W3*x_data*x_data*x_data+ \ W2*x_data*x_data+ \ W1*x_data+ \ W0 loss = tf.reduce_mean(tf.square(y4 - y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) for step in range(0, 4000): sess.run(train) print sess.run(W9), sess.run(W8), sess.run(W7), sess.run(W6), sess.run(W5),\ sess.run(W4), sess.run(W3), sess.run(W2), sess.run(W1), sess.run(W0), sess.run(loss) for rand_num in range(10, 100, 10): training (rand_num)
10次多項式の作り方が汚い! いや、ホントPython初学者なんですよ。。。
という訳で、10~100まで、10きざみでサンプル数を増やしていった時の結果がこちら。
ちゃんとグラフを書けばいいのだが、Excelなので御愛嬌。黄色の点がサンプル数10個。かなりずれてるし、1.0に近づくと発散してしまう。 サンプル数を100まで増やす(茶色の点)と、ある程度フィッテイングしている。
だけど、パラメータと一緒にlossの値も表示してみた(つまり、誤差の大きさ)のだが、あまり変わらないのはどういう訳だろう?
print sess.run(W9), sess.run(W8), sess.run(W7), sess.run(W6), sess.run(W5),\ sess.run(W4), sess.run(W3), sess.run(W2), sess.run(W1), sess.run(W0), sess.run(loss) ... 0.978335 1.34685 1.54304 1.40795 0.732446 -0.659654 -2.56692 -3.4394 0.564211 0.859257 0.0144768 1.25557 1.78463 2.11603 2.02262 1.167 -0.838215 -3.97495 -5.83498 3.67262 -0.0174955 0.0203349 2.75274 2.94246 2.87709 2.34228 1.0264 -1.41436 -4.89714 -6.83561 4.77098 0.343619 0.0133519 -0.755926 0.788617 2.02863 2.6458 2.17918 0.0780315 -3.82903 -7.05471 4.09891 0.711796 0.0232179 -1.2181 0.538887 2.05961 2.97595 2.7154 0.546614 -3.86205 -7.74185 4.02293 0.247065 0.0255745 -0.496135 0.967327 2.16554 2.75367 2.22242 -0.0230082 -4.09764 -7.20815 4.19292 -0.171896 0.0243655 -2.07696 0.260569 2.18981 3.31652 3.06736 0.764891 -3.77455 -7.58748 3.7258 0.0196501 0.0243026 -1.60937 0.418332 2.13299 3.15403 2.90046 0.636502 -3.96227 -8.02923 4.42217 0.256392 0.0244194 -1.71368 0.334878 2.05888 3.07877 2.83201 0.642286 -3.70678 -7.40558 3.78548 -0.0822037 0.0250228
一番右の値が誤差になり、上からサンプル数を10から100まで増やしていったのだが、loss値が大きくなっているように見える。 ここらへんはまだ良く分かってないなあ。。。