FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

疎行列ベクトル積 (SpMV) に関する調査

疎行列ベクトル積について調べなくてはならなくなったので、自分の備忘録として残しておく。

行列ベクトル積というのは普通に y=Axのことを指すが、行列Aが疎である場合、つまり多くの要素が0である場合には、メモリを節約するために0出ない要素のみを記憶し、その位置をインデックスで保持する。 これを疎行列ベクトル積と呼び、SpMV (Sparse Matrix-Vector Multiplication)と呼ばれている。

例えば、以下の 4\times 4の行列と4列のベクトルの行列積を取る場合、オレンジの部分のみに値が入っており、それ以外の部分は0であるとする。 各要素に入っている値はグローバルなインデックスとする。

この場合、各行列の要素をすべて端に寄せ、その位置を覚える。

  • ptr : 各列の先頭要素が何番目のインデックスから始まるかを示す。
  • idx : 各列において、それぞれの要素が何番目の行列の場所に入っているのかを示す。

これをC言語で書くとこんな感じになる。

void spmv(int r, const double* val, const int* idx, const double* x,
          const int* ptr, double* y)
{
  for (int i = 0; i < r; i++)
  {
    int k;
    for (k = ptr[i]; k < ptr[i+1]; k++) {
      y[i] += val[k]*x[idx[k]];
    }
  }
}

これのテストをしたいのだが、RISC-Vのテストベンチマークセットに便利なScalaのコードを見つけたので、これを真似する。 というか、これを実行しようとしたらなぜかScalaがランタイムエラーを出したので、仕方がないのでRubyで書き換えた。

github.com

#!/usr/bin/env ruby

m = ARGV[0].to_i
n = ARGV[1].to_i
approx_nnz = ARGV[2].to_i

pnnz = approx_nnz.to_f/(m*n)
idx = Array[]
p = [0]

m.times {|i|
  n.times {|j|
    if rand() < pnnz then
      idx.push (j)
    end
  }
  p.push (idx.size)
}

nnz = idx.size
v = Array.new(n) { rand(1000) }
d = Array.new(nnz) { rand(1000) }

def printVec(t, name, data)
  printf("const %s %s[%d] = {", t, name, data.length)
  data.each_with_index {|d, index|
    print "  " + d.to_s
    puts "," if index != data.size-1
  }
  print("};\n\n")
end

def spmv(p, d, idx, v)
  y = Array.new
  for i in 0..(p.length-1) do
    yi = 0
    limit = 0
    if i == p.length-1 then
      limit = idx.size
    else
      limit = p[i+1]
    end
    for k in p[i]..(limit-1) do
      yi = yi + d[k]*v[idx[k]]
    end
    y[i] = yi
  end
  return y
end

printf("#define R %d\n", m)
printf("#define C %d\n", n)
printf("#define NNZ %d\n", nnz)
printVec("double", "val", d)
printVec("uint64_t", "idx", idx)
printVec("double", "x", v)
printVec("uint64_t", "ptr", p)
printVec("double", "verify_data", spmv(p, d, idx, v))
$ ./spmv_gendata.rb 4 4 8
#define R 4
#define C 4
#define NNZ 9
const double val[9] = {  236,
  776,
  140,
  252,
  760,
  5,
  829,
  723,
  383};

const int idx[9] = {  0,
  3,
  0,
  1,
  3,
  0,
  1,
  2,
  3};

const double x[4] = {  568,
  605,
  16,
  209};

const int ptr[5] = {  0,
  1,
  1,
  4,
  8};

const double verify_data[5] = {  134048,
  0,
  394164,
  674793,
  80047};