FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のPython実装をRubyで作り直してみる(ステップ15/ステップ16)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ15とステップ16。

  • ステップ15とステップ16は同じ内容で理論編と実践編となっている。複雑な計算グラフを取り扱うために、ネットワークのツリーをDFSで追うのではなくBFSで追いかけるように変更する。

generationメンバ変数を追加して、generationの値の大きい順にbackwardを適用することでBFSでネットワークを逆方向に辿れるようにする。

class Variable
...
  def backward()
    if @grad == nil then
      @grad = @data.clone.fill(1.0)
    end
    funcs = []
    seen_set = Set.new

    def add_func(f, funcs, seen_set)
      if not seen_set.include?(f) then
        funcs.push(f)
        seen_set.add(f)
        funcs.sort!{|a| a.generation}
      end
    end

    add_func(@creator, funcs, seen_set)
    while funcs != [] do
      f = funcs.pop
      gys = f.outputs.map{|x| x.grad}
      gxs = f.backward(*gys)
      if not gxs.is_a?(Array) then
        gxs = [gxs]
      end
      f.inputs.zip(gxs).each{|x, gx|
        if x.grad === nil then
          x.grad = gx
        else
          tmp = (x.grad + gx)
          x.grad = [tmp.is_a?(Array) ? tmp.sum : tmp]
        end
        if x.creator != nil then
          add_func(x.creator, funcs, seen_set)
        end
      }
    end
  end

Rubyset型を使用して関数の登場順番を管理している。seen_set型を使ってすでに同じノードを処理したかどうかを見ている。generationの順番によって適用順序を切り替えているのは、funcs.sort!{|a| a.generation}によって制御している。

x = Variable.new([2.0])
a = square(x)
y = add(square(a), square(a))
y.backward()

puts(y.data)
puts(x.grad)
32.0
64.0

問題無く動作したようだ。