FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のPython実装をRubyで作り直してみる(ステップ11/ステップ12)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ11とステップ12。

  • ステップ11:可変長引数のサポート

まずはcall()に可変長引数をサポートする。複数の引数を取ることができるようにするために、受け取った配列を受け取って配列を組みなおす。入力inputsに対してmapを適用してデータを取り出す。これに対してforward()を適用し、それをもう一度Variableをラップし直す。

 class Function
   def call(inputs)
     xs = inputs.map{|x| x.data}
     ys = forward(xs)
     outputs = ys.map{|y| Variable.new(y) }

     outputs.each {|output|
       output.set_creator(self)
     }
     @inputs = inputs
     @outputs = outputs
     return outputs
   end

この変更に基づいてAdd()関数を作ってみる。

 class Add < Function
   def forward(xs)
     y = xs[0][0] + xs[1][0]
     return [y]
   end
 end

forward()は複数長の配列を受け取り、その2つの要素を加算して返す。テストコードは以下のようになる。

 xs = [Variable.new([2.0]), Variable.new([3.0])]
 f = Add.new()
 ys = f.call(xs)
 y = ys[0]
 puts y.data
5.0

複数サイズの配列を受け取って、計算することができた。

  • ステップ11:可変長引数のサポート

Rubyでも可変長をサポートすることができる。Pythonと同じ表現形式かな?

 class Function
   def call(*inputs)
     xs = inputs.map{|x| x.data}
     ys = forward(xs)
     outputs = ys.map{|y| Variable.new(y) }

     outputs.each {|output|
       output.set_creator(self)
     }
     @inputs = inputs
     @outputs = outputs
     return outputs.size > 1 ? outputs : outputs[0]
   end
 class Add < Function
   def forward(xs)
     y = xs[0][0] + xs[1][0]
     return [y]
   end
 end

 x0 = Variable.new([2.0])
 x1 = Variable.new([3.0])
 f = Add.new()
 y = f.call(x0, x1)
 puts y.data

Add()クラスでは複数要素の可変長引数を1つの配列にまとめ、forward()に渡すことで計算できる。

5.0

続いてAdd()クラスの引数の受け取り方の改善を行う。

 class Add < Function
   def forward(x0, x1)
     y = x0[0] + x1[0]
     return y
   end
 end

これはあまり前回のコードと違いが無いのだが、とりあえず可変長で受け取ることができるだけ進歩している。call()も可変長引数、forward()も可変長で受け取ることができるので、以下のような表現が可能となる。

 def add(x0, x1)
   return Add.new().call(x0, x1)
 end

 x0 = Variable.new([2.0])
 x1 = Variable.new([3.0])
 y = add(x0, x1)
 puts y.data
5.0

できた。