ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。計算グラフが表示できるようになって、いよいよ難しそうなところに突入していく。今回はステップ35とステップ36。
- ステップ35:高階微分の計算グラフ、ということで
tanh
の計算グラフを作ってみる。今回はこの頭のおかしな計算グラフを作ることが目的なのでひたすら計算グラフを作る。
class Tanh < Function def forward(x) np = Numpy y = np.tanh(x) return y end def backward(gy) y = @outputs[0].__getobj__ gx = gy * (y * y * (-1) + 1) return gx end end def tanh(x) return Tanh.new().call(x) end
これを1階微分から8階微分まで計算グラフを求めるのが今回の課題。
x = Variable.new(np.array(1.0)) y = tanh(x) x.name = 'x' y.name = 'y' y.backward() iters = 8 for i in 0..(iters-1) do gx = x.grad x.cleargrad() gx.backward() end gx = x.grad puts gx.class gx.name = 'gx' + iters.to_s plot_dot_graph(gx, false, 'tanh' + iters.to_s + '.png')
1階微分の時が以下。
8階微分の結果。これは本書のグラフに近しいものができた。すばらしい。
x = Variable.new(np.array(2.0)) y = x ** 2 y.backward() gx = x.grad x.cleargrad() z = gx ** 3 + y z.backward() puts x.grad
y
が を示しており、gx
がy
の微分を示している。このy
の微分に対してgx ** 3 + y
を計算してからその微分を求めることでz
を微分したときの値を求めることができる。
variable(100.0)
手計算と同様の結果が出力された。