FPGA開発日記

FPGAというより、コンピュータアーキテクチャかもね! カテゴリ別記事インデックス https://sites.google.com/site/fpgadevelopindex/

TensorFlow Servingのチュートリアルを翻訳してMNISTにトライ

TensorFlow Serving、思ったより騒がれてませんね(笑)まあ、API群というか、機械学習本体ではなく、その外部の話なので、ちょっと話題性は低いのかもしれない。

とはいえ、TensorFlow Servingを使ったチュートリアルもしっかり公開されている。今回は、そのチュートリアルを翻訳しながら、内容を進めていこう。

MNIST for ML Beginners を、TensorFlow Servingを使って実装する

tensorflow.github.io

MNISTは手書き文字認識んアプリケーションのことで、TensorFlowのチュートリアルにも掲載されている。

MNIST For ML Beginners

今回は、このMNISTをTensorFlowで学習したモデルをエクスポートし、C++に載せ変えてアプリケーションを動作させる方法について見てみる。

このチュートリアルを進めるにあたり、一応チュートリアルの文章を翻訳してみた。翻訳完全版は以下。

github.com

とは言えさっき1時間くらいでざっくり翻訳した文章なので、クオリティはあまり期待しないで欲しい。

TensorFlowを使ってMNISTの学習を行う

まずはエクスポートするモデルを作成する。これはTensorFlow上で行う。MNISTのモデルは既にチュートリアルで作成しているので、このモデルを実行して学習させるだけである。

学習が完了すると、今度はこのモデルをエクスポートし、外部プログラムから呼び出せる形に変換する。

saver = tf.train.Saver(sharded=True)       # TensorFlowグラフの保存用オブジェクト
model_exporter = exporter.Exporter(saver)  # TensorFlowグラフのエクスポート用オブジェクト
signature = exporter.classification_signature(input_tensor=x, scores_tensor=y)   # エクスポート用オブジェクトの種類を指定
model_exporter.init(sess.graph.as_graph_def(),   # グラフとシグニチャを引数としてエクスポータを初期化 
                    default_graph_signature=signature)
model_exporter.export(export_path, tf.constant(FLAGS.export_version), sess)  # 実際にエクスポート

このときに、エクスポート先のファイルには、モデルの情報とテンソルの情報が含まれている。

実際に実行してみよう。

$ rm -rf /tmp/mnist_model
$ bazel build //tensorflow_serving/example:mnist_export

ここからが長い。おそらく学習からエクスポートまで全部やっている (bazelでラップされているため、何がされているのか分かりにくい...あと、一度TensorFlowがフリーズしてしまった。原因不明。)

$ bazel-bin/tensorflow_serving/example/mnist_export /tmp/mnist_model
Training model...
('Successfully downloaded', 'train-images-idx3-ubyte.gz', 9912422, 'bytes.')
('Extracting', '/tmp/train-images-idx3-ubyte.gz')
('Successfully downloaded', 'train-labels-idx1-ubyte.gz', 28881, 'bytes.')
('Extracting', '/tmp/train-labels-idx1-ubyte.gz')
('Successfully downloaded', 't10k-images-idx3-ubyte.gz', 1648877, 'bytes.')
('Extracting', '/tmp/t10k-images-idx3-ubyte.gz')
('Successfully downloaded', 't10k-labels-idx1-ubyte.gz', 4542, 'bytes.')
('Extracting', '/tmp/t10k-labels-idx1-ubyte.gz')
training accuracy 0.9092
Done training!
Exporting trained model to /tmp/mnist_model
Done exporting!

C++側でTensorFlowのモデルをロードする

ここからが本題だ。C++で先程エクスポートしたTensorFlowのモデルをロードする。 ソースコードは以下のようになっている。

int main(int argc, char** argv) {
  ...

  tensorflow::SessionOptions session_options;
  std::unique_ptr<SessionBundle> bundle(new SessionBundle);
  const tensorflow::Status status =
      tensorflow::serving::LoadSessionBundleFromPath(session_options,
                                                     bundle_path, bundle.get());
  ...

  RunServer(FLAGS_port, std::move(bundle));

  return 0;
}

たぶんこれだけ見れば何となく何をやっているか想像が付くと思う。TensorFlowのセッションを格納するためのオブジェクトを用意し、インスタンス化している(スマートポインタを使っているあたりが素晴らしい)。 このモデルを実行するために、SessionBundle内には

  • モデルおよびセッションの情報 : session
  • バインドの情報 : meta_graph_def

が入っている。

モデル付きサーバの立ち上げおよび実行

最後に、サーバを立ち上げる。

$>bazel build //tensorflow_serving/example:mnist_inference
(長々とビルド...)
$>bazel-bin/tensorflow_serving/example/mnist_inference --port=9000 /tmp/mnist_model/00000001
$ bazel-bin/tensorflow_serving/example/mnist_inference --port=9000 /tmp/mnist_model/00000001
I tensorflow_serving/session_bundle/session_bundle.cc:109] Attempting to load a SessionBundle from: /tmp/mnist_model/00000001
I tensorflow_serving/session_bundle/session_bundle.cc:86] Running restore op for SessionBundle
I tensorflow_serving/session_bundle/session_bundle.cc:157] Done loading SessionBundle
I tensorflow_serving/example/mnist_inference.cc:163] Running...

およびクライアントプログラムからのリクエストを送信する。このプログラムはPythonで書かれているようだ。

$ bazel build //tensorflow_serving/example:mnist_client
INFO: Found 1 target...
Target //tensorflow_serving/example:mnist_client up-to-date:
  bazel-bin/tensorflow_serving/example/mnist_client
INFO: Elapsed time: 0.482s, Critical Path: 0.04s

$>bazel-bin/tensorflow_serving/example/mnist_client --num_tests=1000 --server=localhost:9000

、、っとここまで書いておいて、mnist_clientが、python-grpcがインストールされていないとお怒りだ。どうやら、15.04のUbuntuではgrpc-pythonがインストールできないらしい。 こりゃ、Ubuntu 14.10に逆戻りだな。。。

おまけ mnist_clientは何をしているのか?

上記のmnist_clientだが、中身は以下のようなPythonコードである。

  for _ in range(num_tests):
    request = mnist_inference_pb2.MnistRequest()
    image, label = test_data_set.next_batch(1)
    for pixel in image[0]:
      request.image_data.append(pixel.item())
    with cv:
      while result['active'] == concurrency:
        cv.wait()
      result['active'] += 1
    result_future = stub.Classify.future(request, 5.0)  # 5 seconds
    result_future.add_done_callback(
        lambda result_future, l=label[0]: done(result_future, l))  # pylint: disable=cell-var-from-loop

Pythonはあまり詳しくないのだが、とにかくリクエストをゴリゴリ投げていっているコードだなあ。

とかブログを書いている間に、Servingのビルドの2回目が終了しないので、今日はここまで。