ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ15とステップ16。
- ステップ15とステップ16は同じ内容で理論編と実践編となっている。複雑な計算グラフを取り扱うために、ネットワークのツリーをDFSで追うのではなくBFSで追いかけるように変更する。
generation
メンバ変数を追加して、generation
の値の大きい順にbackward
を適用することでBFSでネットワークを逆方向に辿れるようにする。
class Variable ... def backward() if @grad == nil then @grad = @data.clone.fill(1.0) end funcs = [] seen_set = Set.new def add_func(f, funcs, seen_set) if not seen_set.include?(f) then funcs.push(f) seen_set.add(f) funcs.sort!{|a| a.generation} end end add_func(@creator, funcs, seen_set) while funcs != [] do f = funcs.pop gys = f.outputs.map{|x| x.grad} gxs = f.backward(*gys) if not gxs.is_a?(Array) then gxs = [gxs] end f.inputs.zip(gxs).each{|x, gx| if x.grad === nil then x.grad = gx else tmp = (x.grad + gx) x.grad = [tmp.is_a?(Array) ? tmp.sum : tmp] end if x.creator != nil then add_func(x.creator, funcs, seen_set) end } end end
Rubyのset
型を使用して関数の登場順番を管理している。seen_set
型を使ってすでに同じノードを処理したかどうかを見ている。generation
の順番によって適用順序を切り替えているのは、funcs.sort!{|a| a.generation}
によって制御している。
x = Variable.new([2.0]) a = square(x) y = add(square(a), square(a)) y.backward() puts(y.data) puts(x.grad)
32.0 64.0
問題無く動作したようだ。