TensorFlow Serving、思ったより騒がれてませんね(笑)まあ、API群というか、機械学習本体ではなく、その外部の話なので、ちょっと話題性は低いのかもしれない。
とはいえ、TensorFlow Servingを使ったチュートリアルもしっかり公開されている。今回は、そのチュートリアルを翻訳しながら、内容を進めていこう。
MNIST for ML Beginners を、TensorFlow Servingを使って実装する
MNISTは手書き文字認識んアプリケーションのことで、TensorFlowのチュートリアルにも掲載されている。
今回は、このMNISTをTensorFlowで学習したモデルをエクスポートし、C++に載せ変えてアプリケーションを動作させる方法について見てみる。
このチュートリアルを進めるにあたり、一応チュートリアルの文章を翻訳してみた。翻訳完全版は以下。
とは言えさっき1時間くらいでざっくり翻訳した文章なので、クオリティはあまり期待しないで欲しい。
TensorFlowを使ってMNISTの学習を行う
まずはエクスポートするモデルを作成する。これはTensorFlow上で行う。MNISTのモデルは既にチュートリアルで作成しているので、このモデルを実行して学習させるだけである。
学習が完了すると、今度はこのモデルをエクスポートし、外部プログラムから呼び出せる形に変換する。
saver = tf.train.Saver(sharded=True) # TensorFlowグラフの保存用オブジェクト model_exporter = exporter.Exporter(saver) # TensorFlowグラフのエクスポート用オブジェクト signature = exporter.classification_signature(input_tensor=x, scores_tensor=y) # エクスポート用オブジェクトの種類を指定 model_exporter.init(sess.graph.as_graph_def(), # グラフとシグニチャを引数としてエクスポータを初期化 default_graph_signature=signature) model_exporter.export(export_path, tf.constant(FLAGS.export_version), sess) # 実際にエクスポート
このときに、エクスポート先のファイルには、モデルの情報とテンソルの情報が含まれている。
実際に実行してみよう。
$ rm -rf /tmp/mnist_model $ bazel build //tensorflow_serving/example:mnist_export
ここからが長い。おそらく学習からエクスポートまで全部やっている (bazelでラップされているため、何がされているのか分かりにくい...あと、一度TensorFlowがフリーズしてしまった。原因不明。)
$ bazel-bin/tensorflow_serving/example/mnist_export /tmp/mnist_model Training model... ('Successfully downloaded', 'train-images-idx3-ubyte.gz', 9912422, 'bytes.') ('Extracting', '/tmp/train-images-idx3-ubyte.gz') ('Successfully downloaded', 'train-labels-idx1-ubyte.gz', 28881, 'bytes.') ('Extracting', '/tmp/train-labels-idx1-ubyte.gz') ('Successfully downloaded', 't10k-images-idx3-ubyte.gz', 1648877, 'bytes.') ('Extracting', '/tmp/t10k-images-idx3-ubyte.gz') ('Successfully downloaded', 't10k-labels-idx1-ubyte.gz', 4542, 'bytes.') ('Extracting', '/tmp/t10k-labels-idx1-ubyte.gz') training accuracy 0.9092 Done training! Exporting trained model to /tmp/mnist_model Done exporting!
C++側でTensorFlowのモデルをロードする
ここからが本題だ。C++で先程エクスポートしたTensorFlowのモデルをロードする。 ソースコードは以下のようになっている。
int main(int argc, char** argv) { ... tensorflow::SessionOptions session_options; std::unique_ptr<SessionBundle> bundle(new SessionBundle); const tensorflow::Status status = tensorflow::serving::LoadSessionBundleFromPath(session_options, bundle_path, bundle.get()); ... RunServer(FLAGS_port, std::move(bundle)); return 0; }
たぶんこれだけ見れば何となく何をやっているか想像が付くと思う。TensorFlowのセッションを格納するためのオブジェクトを用意し、インスタンス化している(スマートポインタを使っているあたりが素晴らしい)。 このモデルを実行するために、SessionBundle内には
- モデルおよびセッションの情報 : session
- バインドの情報 : meta_graph_def
が入っている。
モデル付きサーバの立ち上げおよび実行
最後に、サーバを立ち上げる。
$>bazel build //tensorflow_serving/example:mnist_inference (長々とビルド...) $>bazel-bin/tensorflow_serving/example/mnist_inference --port=9000 /tmp/mnist_model/00000001 $ bazel-bin/tensorflow_serving/example/mnist_inference --port=9000 /tmp/mnist_model/00000001 I tensorflow_serving/session_bundle/session_bundle.cc:109] Attempting to load a SessionBundle from: /tmp/mnist_model/00000001 I tensorflow_serving/session_bundle/session_bundle.cc:86] Running restore op for SessionBundle I tensorflow_serving/session_bundle/session_bundle.cc:157] Done loading SessionBundle I tensorflow_serving/example/mnist_inference.cc:163] Running...
およびクライアントプログラムからのリクエストを送信する。このプログラムはPythonで書かれているようだ。
$ bazel build //tensorflow_serving/example:mnist_client INFO: Found 1 target... Target //tensorflow_serving/example:mnist_client up-to-date: bazel-bin/tensorflow_serving/example/mnist_client INFO: Elapsed time: 0.482s, Critical Path: 0.04s $>bazel-bin/tensorflow_serving/example/mnist_client --num_tests=1000 --server=localhost:9000
、、っとここまで書いておいて、mnist_clientが、python-grpcがインストールされていないとお怒りだ。どうやら、15.04のUbuntuではgrpc-pythonがインストールできないらしい。 こりゃ、Ubuntu 14.10に逆戻りだな。。。
おまけ mnist_clientは何をしているのか?
上記のmnist_clientだが、中身は以下のようなPythonコードである。
for _ in range(num_tests): request = mnist_inference_pb2.MnistRequest() image, label = test_data_set.next_batch(1) for pixel in image[0]: request.image_data.append(pixel.item()) with cv: while result['active'] == concurrency: cv.wait() result['active'] += 1 result_future = stub.Classify.future(request, 5.0) # 5 seconds result_future.add_done_callback( lambda result_future, l=label[0]: done(result_future, l)) # pylint: disable=cell-var-from-loop
Pythonはあまり詳しくないのだが、とにかくリクエストをゴリゴリ投げていっているコードだなあ。
とかブログを書いている間に、Servingのビルドの2回目が終了しないので、今日はここまで。