ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ21とステップ22。
これまでの演算ではすべて"Variable
"と"Variable
"同士での演算としていたが、固定値を扱うのにいちいちVariable
型でラップしていては面倒なので、整数や浮動小数点の値をそのまま扱えるようにしたい。このために、Variable
型の演算子をオーバーライドして使いやすくする。
class Mul < Function def forward(x0, x1) y = x0[0] * x1[0] return [y] end def backward(gy) x0 = @inputs[0].data y0 = @inputs[1].data return [gy * x1, gy * x0] end end def mul(x0, x1) return Mul.new().call(x0, x1) end
class Variable ... def *(other) return mul(self, other) end def +(other) return add(self, other) end
こんな感じでRubyでも演算子のオーバーロードが可能なので実装した。以下のように、固定値を演算の2項目として扱うことができるようになる。
x = Variable.new([2.0]) y = x + [3.0] puts y
variable([5.0])
ここでRubyを使った場合の弱点が明らかになった。Pythonで使用している__radd__
に相当する演算子が存在せず、結局2.0 * Variable
のような実装は諦めた。
これ以外の演算子についてもどんどんオーバーロードしていった。
class Variable ... def *(other) return mul(self, other) end def /(other) return div(self, other) end def +(other) return add(self, other) end def -(other) return sub(self, other) end def -@ return neg(self) end def **(other) return pow(self, other) end
class Div < Function def forward(x0, x1) y = x0[0] / x1[0] return [y] end def backward(gy) x0 = @inputs[0].data y0 = @inputs[1].data return [gy / x1, gy * (-x0 / x1 ** 2.0)] end end def div(x0, x1) return Div.new().call(x0, x1) end class Neg < Function def forward(x) return x.map{|x| -x} end def backward(gy) return gx.map{|x| -x} end end def neg(x) return Neg.new().call(x) end class Pow < Function def initialize(c) @c = c end def forward(x) return x.map{|x| x ** @c} end def backward(gy) x = @inputs[0].data c = @c gx = c * x ** (c - 1) ** gy return gx end end def pow(x, c) return Pow.new(c).call(x) end
テストコードは以下。
x = Variable.new([2.0]) y = x + [3.0] puts y x = Variable.new([2.0]) y = -x puts y a = Variable.new([3.0]) b = Variable.new([2.0]) y = a - b puts y a = Variable.new([3.0]) b = Variable.new([2.0]) y = a / b puts y a = Variable.new([3.0]) y = a ** 3.0 puts y
variable([5.0]) variable([-2.0]) variable([1.0]) variable([1.5]) variable([27.0])
できた。想定通り。