FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のDezeroをRubyで作り直してみる(PyCallでのNumpyインポートの問題)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。ステップ47では新しいFunctionを実装しなければならないのだが、ここでRubyでPyCallを使ってNumpyを呼び出す方式で躓いてしまった。

Function.rbの中でnp.add.atを使い関数を構築しなければならないのだが、これをPyCallを使って呼び出す方法が分からない。

class GetItemGrad < Function
  def initialize(slices, in_shape)
    @slices = slices
    @in_shape = in_shape
  end

  def forward(gy)
    np = Numpy
    gx = np.zeros(@in_shape, dtype=gy.dtype)
    np.add.at(gx, @slices, gy)
  end

  def backward(ggx)
    return get_item(ggx, @slices)
  end
end

np.add.atが呼び出されない。具体的には以下のようなエラーが出力された。

/home/msyksphinz/.gem/ruby/2.5.0/gems/pycall-1.3.1/lib/pycall/pyobject_wrapper.rb:43:in `add': <class 'ValueError'>: invalid number of arguments (PyCall::PyError)

いったん切り出して状況を確認しようとしたが、これだけ絞ってもまだ問題が発生する。

#!/usr/bin/env ruby

require 'pycall/import'
include PyCall::Import
pyimport :numpy

a = numpy.array([1, 2, 3, 4])
numpy.add.at(a, [0, 1], 1)
puts a
/home/msyksphinz/.gem/ruby/2.5.0/gems/pycall-1.3.1/lib/pycall/pyobject_wrapper.rb:43:in `add': <class 'ValueError'>: invalid number of arguments (PyCall::PyError)

うーん、何故なんだろう?

a = numpy.array([1, 2, 3, 4])
PyCall::eval("numpy.add.at(#{a}, [0, 1], 1)")
puts a

これでもダメ。冷静に考えてみるとaは上書きされるんだから当たり前か。PyCallをもう少し調べてみるが、もしこれ以上は難しければRuby版はそろそろ限界かな。。。