FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のDezeroをRubyで作り直してみる(ステップ39/ステップ40)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。今回はステップ39とステップ40。

  • ステップ39:和を求める関数を実装する。Sumクラスの作成。Sumクラスは各要素の値をすべて足し込む。この時にaxiskeepdimsオプションを追加している。
class Sum < Function
  def initialize(axis, keepdims)
    @axis = axis
    @keepdims = keepdims
  end

  def forward(x)
    @x_shape = x.shape
    y = PyCall::eval("#{x}.sum(axis=#{@axis}, keepdims=#{@keepdims}")
    return y
  end

  def backward(gy)
    gy = reshape_sum_backward(gy, @x_shape, @axis, @keepdims)
    gx = broadcast_to(gy, @x_shape)
    return gx
  end
end

def sum(x, axis=nil, keepdims=false)
  return Sum.new(axis, keepdims).call(x)
end
begin
  x = Variable.new(np.array([1, 2, 3, 4, 5, 6]))
  y = sum(x)
  y.backward()
  puts y
  puts x.grad
end

begin
  x = Variable.new(np.array([[1, 2, 3], [4, 5, 6]]))
  y = sum(x)
  y.backward()
  puts y
  puts x.grad
end

実行結果は以下のようになった。とりあえず上手く行ったようだ。

variable(21)
variable([1 1 1 1 1 1])

variable(21)
variable([[1 1 1]

今回の実装に当たり、reshape_sum_backward()を実装し直す必要があった。もう面倒なのでPyCallを使いまくってPythonの機能をフル活用する。

def reshape_sum_backward(gy, x_shape, axis, keepdims)
  ndim = PyCall::len(x_shape)
  tupled_axis = axis
  if axis == nil then
    tupled_axis = nil
  elsif not PyCall::eval("hasattr(#{axis}, 'len')") then
    tupled_axis = [axis]
  end

  shape = []
  if not (ndim == 0 or tupled_axis == nil or keepdims) then
    actual_axis = tupled_axis.each{|a|
      a >= 0 ? a : a + ndim
    }
    shape = [gy.shape]
    for a in PyCall::eval("sorted(#{actual_axis})") do
      shape.insert(a, 1)
    end
  else
    shape = gy.shape
  end

  gy = gy.reshape(shape)  # reshape
  return gy
end
  • ステップ40:今度はブロードキャストを行う関数を定義する。
class BroadcastTo < Function
  def initialize(shape)
    @shape = shape
  end

  def forward(x)
    np = Numpy
    @x_shape = x.shape
    y = np.broadcast_to(x, @shape)
    return y
  end

  def backward(gy)
    gx = Kernel.send(:sum_to, gy, @x_shape)
    gx = sum_to(gy, @x_shape)
    return gx
  end
end


def broadcast_to(x, shape)
  if x.shape == shape then
    return as_variable(x)
  end
  return BroadcastTo.new(shape).call(x)
end
def util_sum_to(x, shape)
  ndim = PyCall::eval("len(#{shape})")
  lead = x.ndim - ndim
  lead_axis = PyCall.eval("tuple(range(#{lead}))")

  axis = PyCall.eval("tuple([i + #{lead} for i, sx in enumerate(#{shape}) if sx == 1])")
  y = x.sum(lead_axis + axis, keepdims:true)
  if lead > 0 then
    y = y.squeeze(lead_axis)
  end
  return y
end
begin
  x0 = Variable.new(np.array([1, 2, 3]))
  x1 = Variable.new(np.array([10]))
  y = x0 + x1
  puts y

  y.backward
  puts x1.grad
end
variable([11 12 13])
variable([3])