ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。今回はステップ39とステップ40。
- ステップ39:和を求める関数を実装する。
Sum
クラスの作成。Sum
クラスは各要素の値をすべて足し込む。この時にaxis
とkeepdims
オプションを追加している。
class Sum < Function def initialize(axis, keepdims) @axis = axis @keepdims = keepdims end def forward(x) @x_shape = x.shape y = PyCall::eval("#{x}.sum(axis=#{@axis}, keepdims=#{@keepdims}") return y end def backward(gy) gy = reshape_sum_backward(gy, @x_shape, @axis, @keepdims) gx = broadcast_to(gy, @x_shape) return gx end end def sum(x, axis=nil, keepdims=false) return Sum.new(axis, keepdims).call(x) end
begin x = Variable.new(np.array([1, 2, 3, 4, 5, 6])) y = sum(x) y.backward() puts y puts x.grad end begin x = Variable.new(np.array([[1, 2, 3], [4, 5, 6]])) y = sum(x) y.backward() puts y puts x.grad end
実行結果は以下のようになった。とりあえず上手く行ったようだ。
variable(21) variable([1 1 1 1 1 1]) variable(21) variable([[1 1 1]
今回の実装に当たり、reshape_sum_backward()
を実装し直す必要があった。もう面倒なのでPyCall
を使いまくってPythonの機能をフル活用する。
def reshape_sum_backward(gy, x_shape, axis, keepdims) ndim = PyCall::len(x_shape) tupled_axis = axis if axis == nil then tupled_axis = nil elsif not PyCall::eval("hasattr(#{axis}, 'len')") then tupled_axis = [axis] end shape = [] if not (ndim == 0 or tupled_axis == nil or keepdims) then actual_axis = tupled_axis.each{|a| a >= 0 ? a : a + ndim } shape = [gy.shape] for a in PyCall::eval("sorted(#{actual_axis})") do shape.insert(a, 1) end else shape = gy.shape end gy = gy.reshape(shape) # reshape return gy end
- ステップ40:今度はブロードキャストを行う関数を定義する。
class BroadcastTo < Function def initialize(shape) @shape = shape end def forward(x) np = Numpy @x_shape = x.shape y = np.broadcast_to(x, @shape) return y end def backward(gy) gx = Kernel.send(:sum_to, gy, @x_shape) gx = sum_to(gy, @x_shape) return gx end end def broadcast_to(x, shape) if x.shape == shape then return as_variable(x) end return BroadcastTo.new(shape).call(x) end
def util_sum_to(x, shape) ndim = PyCall::eval("len(#{shape})") lead = x.ndim - ndim lead_axis = PyCall.eval("tuple(range(#{lead}))") axis = PyCall.eval("tuple([i + #{lead} for i, sx in enumerate(#{shape}) if sx == 1])") y = x.sum(lead_axis + axis, keepdims:true) if lead > 0 then y = y.squeeze(lead_axis) end return y end
begin x0 = Variable.new(np.array([1, 2, 3])) x1 = Variable.new(np.array([10])) y = x0 + x1 puts y y.backward puts x1.grad end
variable([11 12 13]) variable([3])