FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のPython実装をRubyで作り直してみる(ステップ19/ステップ20)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ21とステップ22。

これまでの演算ではすべて"Variable"と"Variable"同士での演算としていたが、固定値を扱うのにいちいちVariable型でラップしていては面倒なので、整数や浮動小数点の値をそのまま扱えるようにしたい。このために、Variable型の演算子をオーバーライドして使いやすくする。

class Mul < Function
  def forward(x0, x1)
    y = x0[0] * x1[0]
    return [y]
  end
  def backward(gy)
    x0 = @inputs[0].data
    y0 = @inputs[1].data
    return [gy * x1, gy * x0]
  end
end

def mul(x0, x1)
  return Mul.new().call(x0, x1)
end
class Variable
    ...
  def *(other)
    return mul(self, other)
  end

  def +(other)
    return add(self, other)
  end

こんな感じでRubyでも演算子オーバーロードが可能なので実装した。以下のように、固定値を演算の2項目として扱うことができるようになる。

x = Variable.new([2.0])
y = x + [3.0]
puts y
variable([5.0])

ここでRubyを使った場合の弱点が明らかになった。Pythonで使用している__radd__に相当する演算子が存在せず、結局2.0 * Variableのような実装は諦めた。

これ以外の演算子についてもどんどんオーバーロードしていった。

class Variable
...
  def *(other)
    return mul(self, other)
  end

  def /(other)
    return div(self, other)
  end

  def +(other)
    return add(self, other)
  end

  def -(other)
    return sub(self, other)
  end

  def -@
    return neg(self)
  end

  def **(other)
    return pow(self, other)
  end
class Div < Function
  def forward(x0, x1)
    y = x0[0] / x1[0]
    return [y]
  end
  def backward(gy)
    x0 = @inputs[0].data
    y0 = @inputs[1].data
    return [gy / x1, gy * (-x0 / x1 ** 2.0)]
  end
end

def div(x0, x1)
  return Div.new().call(x0, x1)
end


class Neg < Function
  def forward(x)
    return x.map{|x| -x}
  end
  def backward(gy)
    return gx.map{|x| -x}
  end
end

def neg(x)
  return Neg.new().call(x)
end


class Pow < Function
  def initialize(c)
    @c = c
  end
  def forward(x)
    return x.map{|x| x ** @c}
  end
  def backward(gy)
    x = @inputs[0].data
    c = @c
    gx = c * x ** (c - 1) ** gy
    return gx
  end
end

def pow(x, c)
  return Pow.new(c).call(x)
end

テストコードは以下。

x = Variable.new([2.0])
y = x + [3.0]
puts y


x = Variable.new([2.0])
y = -x
puts y


a = Variable.new([3.0])
b = Variable.new([2.0])
y = a - b
puts y

a = Variable.new([3.0])
b = Variable.new([2.0])
y = a / b
puts y


a = Variable.new([3.0])
y = a ** 3.0
puts y
variable([5.0])
variable([-2.0])
variable([1.0])
variable([1.5])
variable([27.0])

できた。想定通り。