FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のDezeroをRubyで作り直してみる(ステップ43/ステップ44)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。今回はステップ43とステップ44。

いよいよニューラルネットワークを作成する。ここではsin関数を近似しているようなデータセットを作成し、ニューラルネットワークを作成して学習を行い、sin関数を近似できるようにする。ここで新たに作成するのはLinear関数とSigmoid関数だ。どちらもノードとして定義しても良いのだが最初はシンプルな形式で定義する。これで実装があっていることを確認したい。

def linear_simple(x, w, b=nil)
  x = as_variable(x)
  w = as_variable(w)
  t = matmul(x, w)
  if b == nil then
    return t
  end
  y = t + b
  t.data = nil  # Release t.data (ndarray) for memory efficiency
  return y
end


def sigmoid_simple(x)
  np = Numpy
  x = as_variable(x)
  y = as_variable(np.array(1.0)) / (as_variable(np.array(1.0)) + exp(-x))
  return y
end

ニューラルネットワークを作成する。

np.random.seed(0)
x = np.random.rand(100, 1)
y = np.sin(2 * np.pi * x) + np.random.rand(100, 1)
x = Variable.new(x)
y = Variable.new(y)

i = 1
h = 10
o = 1
$w1 = Variable.new(0.01 * np.random.randn(i, h))
$b1 = Variable.new(np.zeros(h))
$w2 = Variable.new(0.01 * np.random.randn(h, o))
$b2 = Variable.new(np.zeros(o))

def predict(x)
  y = linear_simple(x, $w1, $b1)
  y = sigmoid_simple(y)
  y = linear_simple(y, $w2, $b2)
  return y
end

def mean_squared_error(x0, x1)
  np = Numpy
  diff = x0 - x1
  return sum(diff ** 2) / as_variable(np.array(diff.size))
end

lr = 0.2
iters = 10000

for i in 0..(iters-1) do
  y_pred = predict(x)
  loss = mean_squared_error(y, y_pred)

  $w1.cleargrad()
  $b1.cleargrad()
  $w2.cleargrad()
  $b2.cleargrad()
  loss.backward()

  $w1.data -= lr * $w1.grad.data
  $b1.data -= lr * $b1.grad.data
  $w2.data -= lr * $w2.grad.data
  $b2.data -= lr * $b2.grad.data

  if i % 1000 == 0 then
    puts loss
  end
end

10000回回して、1000回に一回誤差を出力するようにする。結果は以下のようになった。

variable(0.8473695850105871)
variable(0.2514286285183607)
variable(0.24759485466749873)
variable(0.23786120447054818)
variable(0.21222231333102928)
variable(0.16742181117834176)
variable(0.0968193261999268)
variable(0.07849528290602333)
variable(0.07749729552991155)
variable(0.07722132399559321)

グラフを出せてはいないが、Python版Dezeroと全く同じ結果になった。これで良しとする。

問題となるのはパフォーマンスだが、Python版と比べるとかなり遅い。原因はちょっとわからないのだが、そもそもRubyの実行速度が遅いのだろうか。。。PyCallなどを多用しているので仕方がないところではある。

Python版では、レイヤクラスにParameterクラスのインスタンスをメンバとして追加すると勝手に__set_attr__が呼び出されてパラメータ一覧を更新してくれる。しかしRubyにはそんな機能が無いので、やむを得ずinstance_variable_set()メンバをオーバライドする形を取った。しかしこれでも想定通りに動いてくれないので、結局レイヤクラスにメンバを追加するときはinstance_variable_setを明示的に呼び出すというかなりダサい方式になってしまっている。

class Layer
  def initialize()
    @_params = Set.new()
  end

  def instance_variable_set(name, value)
    if value.is_a?(Parameter)
      @_params.add(name)
    end
    super
  end

  def call(*inputs)
    outputs = self.forward(*inputs)
    if not outputs.is_a?(Array) then
      outputs = [outputs]
    end
    @inputs  = inputs.map{|input| WeakRef.new(input)}
    @outputs = outputs.map{|output| WeakRef.new(output)}
    return outputs.size > 1 ? outputs : outputs[0]
  end

  def forward(x)
    raise NotImplementedError
  end

  def params()
    return @_params.map {|name| instance_variable_get(name)}
  end

  def cleargrads()
    self.params().each{|param|
      param.cleargrad()
    }
  end

  attr_accessor :_params
  end

使い方。p1 - p4をレイヤクラスに追加する場合にはinstance_variable_set()を明示的に呼び出している。あまりカッコよくはない。

begin
  layer = Layer.new()

  layer.instance_variable_set(:@p1, Parameter.new(np.array(1)))
  layer.instance_variable_set(:@p2, Parameter.new(np.array(2)))
  layer.instance_variable_set(:@p3, Variable.new(np.array(3)))
  layer.instance_variable_set(:@p4, 'test')

  puts layer._params
  puts '--------------'
  layer._params.each{|name|
    puts name, layer.instance_variable_get(name)
  }
end

実行結果。パラメータインスタンスp1p2だけなので、その2つのみが抽出されて表示される。

#<Set: {:@p1, :@p2}>
--------------
@p1
variable(1)
@p2
variable(2)

ではこれをベースにしてLinearクラスを作る。前回作ったLinearクラスと区別がつかないので、明示的に区別するためにLinearLayerクラスとした。LinearLayerクラスはLinear関数と同様の機能を持つので、メンバ変数として@w@bを追懐している。追加する方法はもちろんinstance_variable_set()を使用している。

class LinearLayer < Layer
  np = Numpy
  def initialize(out_size, nobias=false, in_size=nil)
    np = Numpy
    super()
    @in_size = in_size
    @out_size = out_size
    @dtype = np.float32

    instance_variable_set(:@w, Parameter.new(nil, name:'W'))

    if @in_size != nil then
      self._init_W()
    end

    if nobias then
      instance_variable_set(:@b, nil)
    else
      instance_variable_set(:@b, Parameter.new(np.zeros(out_size, np.float32), name:'b'))
    end

  end

  def _init_W()
    np = Numpy
    i = @in_size
    o = @out_size
    w_data = np.random.randn(i, o).astype(@dtype) * np.sqrt(1.0 / i)
    @w.data = w_data
  end

  def forward(x)
    if @w.data.is_a?(NilClass) then
      @in_size = x.shape[1]
      self._init_W()
    end

    y = linear_simple(x, @w, @b)
    return y
  end
end

forward()メソッドでは、前回作成したlinearクラスのデバッグが上手く行っていないので、代わりにlinear_simpleを使用する。実装は以下のようになっている。predict()の実装ではLinearLayersigmoid_simpleを使ってニューラルネットワークを構築している。そのうえで10000回のループを回して学習を行わせている。

  def mean_squared_error(x0, x1)
    np = Numpy
    diff = x0 - x1
    return sum(diff ** 2) / as_variable(np.array(diff.size))
  end

  np.random.seed(0)
  x = np.random.rand(100, 1)
  y = np.sin(2 * np.pi * x) + np.random.rand(100, 1)
  x = Variable.new(x)
  y = Variable.new(y)

  $l1 = LinearLayer.new(10)
  $l2 = LinearLayer.new(1)

  def predict(x)
    y = $l1.call(x)
    y = sigmoid_simple(y)
    y = $l2.call(y)
    return y
  end

  lr = 0.2
  iters = 10000

  for i in 0..(iters-1) do
    y_pred = predict(x)
    loss = mean_squared_error(y, y_pred)

    $l1.cleargrads()
    $l2.cleargrads()
    loss.backward()

    [$l1, $l2].each{|l|
      l.params().each{|p|
        p.data -= lr * p.grad.data
      }
    }

    if i % 1000 == 0 then
      puts loss
    end
  end

実行結果。誤差が削減されていることが分かる。Python版Dezeroの実行結果と一致した。ニューラルネットワークの完成だ。

variable(0.8165178492839196)
variable(0.2499028013724818)
variable(0.24609873705372717)
variable(0.23721585190665512)
variable(0.2079321578201881)
variable(0.1231191944394262)
variable(0.07888168068357643)
variable(0.07666129175490086)
variable(0.0763503210507124)
variable(0.07616987350656905)