ゼロから作るDeep Learning ❸ ―フレームワーク編
- 作者:斎藤 康毅
- 発売日: 2020/04/20
- メディア: 単行本(ソフトカバー)
ゼロから作るDeep Learning ③を買った。DezeroのPython実装をRubyに移植する形で独自に勉強している。次はステップ9とステップ10。
- ステップ9:関数をより便利に使う
これまでSquare
クラスやExp
クラスは明示的に宣言して使用する必要があった。これは面倒なのでWapperを作って簡単に関数を接続できるようにする。
test09.rb
def square(x) return Square.new().call(x) end def exp(x) return Exp.new().call(x) end
これによりニューラルネットワークを数珠つなぎのように表現できるようになった。
x = Variable.new([0.5]) y = square(exp(square(x))) # 数珠つなぎのように関数の連結を表現する。 y.grad = [1.0] y.backward() puts(x.grad)
さらに、backword()
の前にy.grad
の初期値1.0を設定しなくても良いようにする。
class Variable ... def backward() if @grad == nil then @grad = @data.clone.fill(1.0) end
それにしてもNumPyには便利な関数がたくさんあるようで、np.ones_like()
などという超絶便利な関数はRubyで存在しなかったので、@data
と同じ形の変数をclone
し、fill
で全部1.0を埋め込んでしまうという無理やりな作戦を使った。とりあえずこれで代用できた。
これに加えて、いくつか型の制限を設けて誤ったコードの実行を防ぐようにする。
class Variable def initialize(data) if data != nil then if not data.is_a?(Array) then raise TypeError, data.class.to_s + " is not supported." end end ...
Array
型以外のデータを変数として使用しようとするとエラーを出すようにした。
これにより、誤ってArray以外のデータを入力するとエラーが出るようにする。
x = Variable.new([0.5]) y = square(exp(square(x))) y.backward() puts(x.grad)
Traceback (most recent call last): 2: from ./step09.rb:116:in `<main>' 1: from ./step09.rb:116:in `new' ./step09.rb:7:in `initialize': Float is not supported. (TypeError)
- ステップ10:テストを作る
Rubyのテスト環境の構築方法はあまり詳しくないのだが、昔の資料を取り出してきて同じようにテスト環境を構築した。assert_equal
によって想定する値と一致するかどうかをチェックする。
test10.rb
#!/usr/bin/env ruby require 'test/unit' ... class TestDezero < Test::Unit::TestCase def test_backward x = Variable.new([3.0]) y = square(x) y.backward() assert_equal [6.0], x.grad end def test_gradient_check x = Variable.new([rand]) y = square(x) y.backward() num_grad = numerical_diff(method(:square), x) flg = x.grad.zip(num_grad).map{|x,y| (x - y).abs / x.abs <= 1e-08 }.all? assert_equal flg, true end end
やはりNumPyには便利な機能があって、「ほとんど近しい」ということを示すためのnp.allclose()
という関数が存在しているのだがこんなものRubyには存在しないので、無理やり差分の絶対値の比率が1e-08よりも小さくなることを確認するコードに変換した。これで自動チェックが行われるようになった。
./step10.rb Loaded suite ./step10 Started .. Finished in 0.0026976 seconds. --------------------------------------------------------------------------------------- 2 tests, 2 assertions, 0 failures, 0 errors, 0 pendings, 0 omissions, 0 notifications 100% passed --------------------------------------------------------------------------------------- 741.40 tests/s, 741.40 assertions/s