疎行列ベクトル積について調べなくてはならなくなったので、自分の備忘録として残しておく。
行列ベクトル積というのは普通にのことを指すが、行列Aが疎である場合、つまり多くの要素が0である場合には、メモリを節約するために0出ない要素のみを記憶し、その位置をインデックスで保持する。 これを疎行列ベクトル積と呼び、SpMV (Sparse Matrix-Vector Multiplication)と呼ばれている。
例えば、以下のの行列と4列のベクトルの行列積を取る場合、オレンジの部分のみに値が入っており、それ以外の部分は0であるとする。 各要素に入っている値はグローバルなインデックスとする。
この場合、各行列の要素をすべて端に寄せ、その位置を覚える。
ptr
: 各列の先頭要素が何番目のインデックスから始まるかを示す。idx
: 各列において、それぞれの要素が何番目の行列の場所に入っているのかを示す。
これをC言語で書くとこんな感じになる。
void spmv(int r, const double* val, const int* idx, const double* x, const int* ptr, double* y) { for (int i = 0; i < r; i++) { int k; for (k = ptr[i]; k < ptr[i+1]; k++) { y[i] += val[k]*x[idx[k]]; } } }
これのテストをしたいのだが、RISC-Vのテストベンチマークセットに便利なScalaのコードを見つけたので、これを真似する。 というか、これを実行しようとしたらなぜかScalaがランタイムエラーを出したので、仕方がないのでRubyで書き換えた。
#!/usr/bin/env ruby m = ARGV[0].to_i n = ARGV[1].to_i approx_nnz = ARGV[2].to_i pnnz = approx_nnz.to_f/(m*n) idx = Array[] p = [0] m.times {|i| n.times {|j| if rand() < pnnz then idx.push (j) end } p.push (idx.size) } nnz = idx.size v = Array.new(n) { rand(1000) } d = Array.new(nnz) { rand(1000) } def printVec(t, name, data) printf("const %s %s[%d] = {", t, name, data.length) data.each_with_index {|d, index| print " " + d.to_s puts "," if index != data.size-1 } print("};\n\n") end def spmv(p, d, idx, v) y = Array.new for i in 0..(p.length-1) do yi = 0 limit = 0 if i == p.length-1 then limit = idx.size else limit = p[i+1] end for k in p[i]..(limit-1) do yi = yi + d[k]*v[idx[k]] end y[i] = yi end return y end printf("#define R %d\n", m) printf("#define C %d\n", n) printf("#define NNZ %d\n", nnz) printVec("double", "val", d) printVec("uint64_t", "idx", idx) printVec("double", "x", v) printVec("uint64_t", "ptr", p) printVec("double", "verify_data", spmv(p, d, idx, v))
$ ./spmv_gendata.rb 4 4 8
#define R 4 #define C 4 #define NNZ 9 const double val[9] = { 236, 776, 140, 252, 760, 5, 829, 723, 383}; const int idx[9] = { 0, 3, 0, 1, 3, 0, 1, 2, 3}; const double x[4] = { 568, 605, 16, 209}; const int ptr[5] = { 0, 1, 1, 4, 8}; const double verify_data[5] = { 134048, 0, 394164, 674793, 80047};