FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

ゼロから作るDeep Learning ③ のDezeroをRubyで作り直してみる(ステップ41/ステップ42)

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)

ゼロから作るDeep Learning ③のDezero実装、勉強のためRubyでの再実装に挑戦している。今回はステップ41とステップ42。

  • ステップ41:ベクトルの内積と行列積を実装する。どちらもNumpyレベルでは実装されているので、forward()backward()の作り方について考える。forward()の実装はそのままだが、backward()つまり逆伝搬の扱い方がやはり難しい。式はとりあえず眺めるだけで、何となく内積によって求められるということが分かる。y = x*Wという行列積に対してその微分は、xWの両方に対して出力する必要があるため、MatMulクラスのforward()backward()は以下のようになるらしい。細かい式については良く分からんかった!行列の微分なんて学生の時以来なので忘れてしまっている。
class MatMul < Function
  def forward(x, w)
    y = x.dot(w)
    return y
  end

  def backward(gy)
    x = @inputs[0]
    w = @inputs[1]
    gx = matmul(gy, w.T)
    gW = matmul(x.T, gy)
    return gx, gW
  end
end

def matmul(x, w)
  return MatMul.new().call(x, w)
end

テスト。以下のような行列を作って計算し、その形状を確認する。

require 'numpy'
np = Numpy

begin
  x = Variable.new(np.random.randn(2, 3))
  w = Variable.new(np.random.randn(3, 4))
  y = matmul(x, w)
  y.backward()

  puts x.grad.shape.to_s
  puts w.grad.shape.to_s
end
(2, 3)
(3, 4)

一応想定通りの形状になった。これで実装は良いものとする。

  • ステップ42:線形回帰を実装する。やっと機械学習っぽくなってきた。ここでは与えられたデータセットに対して線形回帰の式を導出する。Pythonの例ではグラフも出力することができるが、ここでは面倒なので省略。

線形回帰のデータセットは以下のようにして作り上げる。ここでは100個のデータを用意している、のかな?

np.random.seed(0)
x = np.random.rand(100, 1)
y = 5 + 2 * x + np.random.rand(100, 1)

で、線形回帰の作り方だが、現在の予測値に対して誤差を計算し、それを縮めていくという方式を取るわけだ。誤差の計算方法としてmean_squared_error()関数を用意して、さらに現在の値を計算するためにpredict()を用意する。

def predict(x)
  y = matmul(x, $w) + $b
  return y
end

def mean_squared_error(x0, x1)
  diff = x0 - x1
  return sum(diff ** 2) / diff.size
end

行列積をつかttえ線形回帰を実装していくわけで、100回繰り返しながら誤差を小さくしていき、最終的にWbがどのような値になっていくかを観察する。

lr = 0.1
iters = 100

for i in 0..(iters-1) do
  y_pred = predict(x)
  loss = mean_squared_error(y, y_pred)

  $w.cleargrad()
  $b.cleargrad()
  loss.backward()

  $w.data -= lr * $w.grad.data
  $b.data -= lr * $b.grad.data

  puts $w, $b, loss
end
...
variable([[2.11807369]])
variable([5.46608905])
variable([0.07908607])

できた!最終的にW = 2.11807369b = 5.46608905、誤差は0.07908607となった。グラフは表示できないけど、これは正しい計算だ。傾き2に対して切片5をベースにサンプルをプロットしているので、大体あっていると言える。