FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のPython実装をRubyで作り直してみる(ステップ7/ステップ8)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ7とステップ8。

バックプロパゲーションを自動化するために、VariableおよびFunctionに変更を加える。Variableには、変数自分自身を作成したFunctionを覚えるためのメンバ変数を追加し、Functionには自分の生成した変数を記録する。

  • step07.rb
class Variable
  def initialize(data)
    @data = data
    @grad = nil
    @creator = nil
  end
  def set_creator(func)
    @creator = func
  end

  def backward()
    f = @creator
    if f != nil then
      x = f.input
      x.grad = f.backward(@grad)
      x.backward()
    end
  end

  attr_accessor :data, :grad, :creator
end
class Square < Function
  def forward(x)
    return x.map{|i| i ** 2}
  end
  def backward(gy)
    x = @input.data
    gx = x.zip(gy).map{|i0, i1| i0 * i1 * 2.0}
    return gx
  end
end

ポイントとなるのはVariablebackward()において、さらにbackward()を呼び出すことで変数から変数へと逆方向にジャンプしていくことだ。これを繰り返して最終的な始点に計算を向けていく。

A = Square.new()
B = Exp.new()
C = Square.new()

x = Variable.new([0.5])
a = A.call(x)
b = B.call(a)
y = C.call(b)

y.grad = [1.0]
y.backward()
puts(x.grad)
3.297442541400256
  • ステップ8:backward()再帰をループに置き換える

ステップ7のbackward()の実装では、バックプロパゲーションを実現するのに再帰的にbackward()を呼び出した。これでも良いのだが今後の拡張性と速度向上を見据えてループの実装に置き換える。

これは簡単で、いわゆるツリーの探索方式と考えて良い。さらに今回の場合はDFSだろうがBFSだろうがノード自身が前の値を覚えているためどっちを使っても良い。従ってキューを作ってループに変更しても特に問題ない。

  • step08.rb
class Variable
  def initialize(data)
    @data = data
    @grad = nil
    @creator = nil
  end
  def set_creator(func)
    @creator = func
  end

  def backward()
    funcs = [@creator]
    while funcs != [] do
      f = funcs.pop
      x = f.input
      y = f.output
      x.grad = f.backward(y.grad)
      if x.creator != nil then
        funcs.push(x.creator)
      end
    end
  end

  attr_accessor :data, :grad, :creator
end

同様にテストを実行し、同じ結果が得られた。

A = Square.new()
B = Exp.new()
C = Square.new()

x = Variable.new([0.5])
a = A.call(x)
b = B.call(a)
y = C.call(b)

y.grad = [1.0]
y.backward()
puts(x.grad)
3.297442541400256