FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のPython実装をRubyで作り直してみる(ステップ9/ステップ10)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ9とステップ10。

  • ステップ9:関数をより便利に使う

これまでSquareクラスやExpクラスは明示的に宣言して使用する必要があった。これは面倒なのでWapperを作って簡単に関数を接続できるようにする。

  • test09.rb
def square(x)
  return Square.new().call(x)
end

def exp(x)
  return Exp.new().call(x)
end

これによりニューラルネットワークを数珠つなぎのように表現できるようになった。

  x = Variable.new([0.5])
  y = square(exp(square(x)))    # 数珠つなぎのように関数の連結を表現する。

  y.grad = [1.0]
  y.backward()
  puts(x.grad)

さらに、backword()の前にy.gradの初期値1.0を設定しなくても良いようにする。

class Variable
    ...
  def backward()
    if @grad == nil then
      @grad = @data.clone.fill(1.0)
    end

それにしてもNumPyには便利な関数がたくさんあるようで、np.ones_like()などという超絶便利な関数はRubyで存在しなかったので、@dataと同じ形の変数をcloneし、fillで全部1.0を埋め込んでしまうという無理やりな作戦を使った。とりあえずこれで代用できた。

これに加えて、いくつか型の制限を設けて誤ったコードの実行を防ぐようにする。

class Variable
  def initialize(data)
    if data != nil then
      if not data.is_a?(Array) then
        raise TypeError, data.class.to_s + " is not supported."
      end
    end
...

Array型以外のデータを変数として使用しようとするとエラーを出すようにした。

これにより、誤ってArray以外のデータを入力するとエラーが出るようにする。

  x = Variable.new([0.5])
  y = square(exp(square(x)))
  y.backward()
  puts(x.grad)
Traceback (most recent call last):
    2: from ./step09.rb:116:in `<main>'
    1: from ./step09.rb:116:in `new'
./step09.rb:7:in `initialize': Float is not supported. (TypeError)
  • ステップ10:テストを作る

Rubyのテスト環境の構築方法はあまり詳しくないのだが、昔の資料を取り出してきて同じようにテスト環境を構築した。assert_equalによって想定する値と一致するかどうかをチェックする。

  • test10.rb
#!/usr/bin/env ruby

require 'test/unit'

...
    
class TestDezero < Test::Unit::TestCase
  def test_backward
    x = Variable.new([3.0])
    y = square(x)
    y.backward()
    assert_equal [6.0], x.grad
  end

  def test_gradient_check
    x = Variable.new([rand])
    y = square(x)
    y.backward()
    num_grad = numerical_diff(method(:square), x)
    flg = x.grad.zip(num_grad).map{|x,y|
      (x - y).abs / x.abs <= 1e-08
    }.all?
    assert_equal flg, true
  end
end

やはりNumPyには便利な機能があって、「ほとんど近しい」ということを示すためのnp.allclose()という関数が存在しているのだがこんなものRubyには存在しないので、無理やり差分の絶対値の比率が1e-08よりも小さくなることを確認するコードに変換した。これで自動チェックが行われるようになった。

./step10.rb 
Loaded suite ./step10
Started
..
Finished in 0.0026976 seconds.
---------------------------------------------------------------------------------------
2 tests, 2 assertions, 0 failures, 0 errors, 0 pendings, 0 omissions, 0 notifications
100% passed
---------------------------------------------------------------------------------------
741.40 tests/s, 741.40 assertions/s