ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。計算グラフが表示できるようになって、いよいよ難しそうなところに突入していく。今回はステップ27とステップ28。
class Sin < Function def forward(x) np = Numpy y = np.sin(x) return y end def backward(gy) np = Numpy x = @inputs[0].data gx = gy * np.cos(x) return gx end end def sin(x) return Sin.new().call(x) end
テイラー展開を適用した場合のsin関数は以下のような実装となる。
def factorial(number) (1..number).inject(1,:*) end def my_sin(x, threshold=0.0001) y = Variable.new(0.0) for i in (0..100000) do c = ((-1) ** i) / factorial(2 * i + 1).to_f t = (x ** (2 * i + 1)) * c y = y + t if t.data.abs < threshold then break end end return y end
テストを通してみると、どちらも同じ答えを返しており問題ないことが分かる。
begin puts("=== Test of Sin Function ===") x = Variable.new(np.array(np.pi/4)) y = sin(x) y.backward() puts(y.data) puts(x.grad) puts("=== End Test of Sin Function ===") end begin puts("=== Test of MySin Function ===") x = Variable.new(np.array(np.pi/4)) y = my_sin(x, 1e-150) y.backward() puts(y.data) puts(x.grad) plot_dot_graph(y, verbose=false, to_file='sin.png') puts("=== End Test of My Sin Function ===") end
=== Test of Sin Function === 0.7071067811865475 0.7071067811865476 === End Test of Sin Function === === Test of MySin Function === 0.7071067811865475 0.7071067811865476 === End Test of My Sin Function ===
グラフも出力できた。threshold=0.0001
の時とthreshold=1e-150
の時でグラフを比較する。
threshold=0.0001
の場合
threshold=1e-150
の場合
- ステップ28:関数の最適化の一手法として、勾配降下法を実装する。
まずはローゼンブロック関数から。以下のようにRubyで表現できる。
def rosenbrock(x0, x1) y = ((x1 - x0 ** 2) ** 2) * 100 + (x0 - 1) ** 2 return y end
これを勾配降下法を使って最小値を求めてみる。勾配降下法の実装は以下で表現できる。
begin x0 = Variable.new(np.array(0.0)) x1 = Variable.new(np.array(2.0)) lr = 0.001 iters = 50000 for i in 0..iters do puts [x0.to_s, x1.to_s].to_s y = rosenbrock(x0, x1) x0.cleargrad() x1.cleargrad() y.backward() x0.data -= lr * x0.grad x1.data -= lr * x1.grad end end
演算結果をトレースしてみると、最終的に限りなく[x0, x1] = [1.0, 1.0]
に近づくことができた。次回はこれをニュートン法で実現しようという訳だ。
... ["variable(0.9999999993696386)", "variable(0.9999999987367547)"] ["variable(0.9999999993698904)", "variable(0.9999999987372592)"] ["variable(0.999999999370142)", "variable(0.9999999987377635)"] ["variable(0.9999999993703935)", "variable(0.9999999987382675)"] ["variable(0.9999999993706449)", "variable(0.9999999987387714)"] ["variable(0.9999999993708962)", "variable(0.9999999987392751)"] ["variable(0.9999999993711475)", "variable(0.9999999987397786)"] ["variable(0.9999999993713986)", "variable(0.9999999987402819)"] ["variable(0.9999999993716496)", "variable(0.9999999987407849)"] ["variable(0.9999999993719006)", "variable(0.9999999987412878)"] ["variable(0.9999999993721514)", "variable(0.9999999987417905)"] ["variable(0.9999999993724021)", "variable(0.999999998742293)"]