iTAC_Technical_Documents

アイタックソリューションズ株式会社

ブログ名

第4回 TensorRT化について

前回までのモデルはPythonを使用していました。しかし、PythonはCやC++に比べると遅いので、書き換えることで高速化することが出来ます。

ここで便利なのが以前紹介したTensorRTです。このC++ライブラリは高速推論可能でJetsonに適した形に最適化してくれるので、高速化、小型化の両方の効果を見込めます。

ただTesnorRTに関しては文献が多くないので事前に調査を行いました。TensorRTはJetson、TESLA、DRIVEを対象としていますが、GeForceでも動作するようです。以下のテストではRTX2070を使用しています。

TensorRT化のテスト

TenosrRTは学習済みモデルを高速化するライブラリで、学習をする機能はついていません。そのため、他のライブラリで学習したモデルをTensorRT化するのが基本です。

f:id:iTD_GRP:20191202224740j:plain

TensorRT化ですが幾つか方法があり、二分すると、TenosrRT上で自分で定義する方法と、他のライブラリ(Caffe、TensorFlow、PyTorchなど)からファイルを読み込む方法です。

今回はKeras(TensorFlow)のファイルを読み込む方法でTensorRT化を行います。公式が出しているサンプルコードを使いまわした雑なテストになってしまいましたが、独自のモデルを読み込んで、推論時間の比較まで行いました。

KerasからTensorRT化

テスト環境: Windows10 RTX 2070 TensorRT 5.1 CUDA Toolkit 10.0 cuDNN 7.5 Visual Studio 2017

作業の流れとしては以下のようになります。

f:id:iTD_GRP:20191202224758j:plain

  1. サンプルコードを動かす

    サンプルコードが正常に動くことを確認します。

    1. TensorRTのサンプルを入手

      https://developer.nvidia.com/tensorrtからTensorRT 5.1 GA For Windowsのzipファイルをダウンロードして解凍してください。

    2. sampleUffMNISTを開く

      Visual Studio 2017で開いてください。2019では上手く行きませんでした。

    3. 実行

      ローカルWindowsデバッカーを押して実行すると、プロンプトに結果が出力されます。

      動かない時

      windowsではgetenvを使うとエラーが出るようです。command.hを f:id:iTD_GRP:20191202224824p:plain

      getopt.cを

      f:id:iTD_GRP:20191202224843p:plain

のように変えると治りました。

  1. Kerasnのモデルをpbファイルに変換

    TensorFlowのgraph_utilモジュールを使ってpbファイルに変換します。モデルはchannel firstになっている必要があります。以下では学習済みモデルをロードして、pbファイルに変換しています。

   python:creat_pb.py

   import keras
   import pickle
   import tensorflow as tf
   from keras import backend as K
   
   # GPUのメモリの設定
   config = tf.ConfigProto()
   config.gpu_options.allow_growth = True
   sess = tf.Session(config=config)
   K.set_session(sess)
   
   # channels_first
   K.set_image_data_format('channels_first')
   
   # modelの作成
   model = ****
   model.load_weights("*****.hdf5")
   
   # layer名の確認
   # outputの名前の取得
   for op in sess.graph.get_operations():
           if op.name.split('_')[0] not in ['Assign', 'Placeholder', 'IsVariableInitialized', 'init']:
               #print(op.name)
               last_name = op.name
   output_names = [last_name]
   
   # pbファイルの作成
   frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), output_names)
   frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
   with open('*****.pb', "wb") as ofile:
       ofile.write(frozen_graph.SerializeToString())
  1. pbファイルをuffファイルに変換

    uff converter を使ってpbファイルをuffファイルに変換します。この時、TensorRTでuffファイルを読み込むときに指定する必要があるinputとoutputのnameが出力されるので確認しておきます。カスタムレイヤーの置き換えの設定がある場合、この手順の際に行います。

  2. TensorRTで読み込み、推論

    例えば以下のように、uffファイルを読み込み推論を実行してみます。

   cpp:test.cpp

   #include <cassert>
   #include <chrono>
   #include <cublas_v2.h>
   #include <cudnn.h>
   #include <iostream>
   #include <sstream>
   #include <string.h>
   #include <time.h>
   #include <unordered_map>
   #include <vector>
   
   #include<iostream>
   #include<fstream>
   
   
   #include "BatchStreamPPM.h"
   #include "NvUffParser.h"
   #include "logger.h"
   #include "common.h"
   #include "argsParser.h"
   #include "NvInferPlugin.h"
   #include "EntropyCalibrator.h"
   
   using namespace nvuffparser;
   using namespace nvinfer1;
   
   
   int main() {
   
    // Importing A TensorFlow Model Using The C++ UFF Parser API
    // Create the builder and network
    IBuilder* builder = createInferBuilder(gLogger.getTRTLogger());
    INetworkDefinition* network = builder->createNetwork();
   
    // Create the UFF parser
    std::cout << "Create the UFF parser" << std::endl;
    IUffParser* parser = createUffParser();
   
    // Declare the network inputs and outputs to the UFF parser
    parser->registerInput("input_1", DimsCHW(3, 300, 300), UffInputOrder::kNCHW);;
    parser->registerOutput("fc7_3/Sigmoid");
    //parser->registerOutput("concat_predictions");
   
    // Parse the imported model to populate the network
    std::cout << "importing uff" << std::endl;
    auto uffFile = "*****.uff";
    parser->parse(uffFile, *network, nvinfer1::DataType::kFLOAT);
   
    // Building An Engine In C++
    // Build the engine using the builder objec
    std::cout << "building an engine" << std::endl;
    int maxBatchSize = 100;
    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(1 << 20);
    ICudaEngine* engine = builder->buildCudaEngine(*network);
   
    // Dispense with the network, builder, and parser if using one
    parser->destroy();
    network->destroy();
    builder->destroy();
   
    // Serializing A Model In C++
    // Run the builder as a prior offline step and then serialize
    std::cout << "serializing a model" << std::endl;
    IHostMemory *serializedModel{ nullptr };
    serializedModel = engine->serialize();
    // store model to disk
    // <…>
    //serializedModel->destroy();
   
    // Create a runtime object to deserialize
    IRuntime* runtime = createInferRuntime(gLogger);
    ICudaEngine* engine_runtime = runtime->deserializeCudaEngine(serializedModel->data(), serializedModel->size(), nullptr);
   
    runtime->destroy();
    serializedModel->destroy();
   
    // Performing Inference In C++
    // Create some space to store intermediate activation values. Since the engine holds the network definition and trained parameters, additional space is necessary. These are held in an execution context
    IExecutionContext *context = engine->createExecutionContext();
   
    // Use the input and output blob names to get the corresponding input and output index
    const char* INPUT_BLOB_NAME = "input_1";
    const char* OUTPUT_BLOB_NAME = "fc7_3/Sigmoid";
    int inputIndex = engine_runtime->getBindingIndex(INPUT_BLOB_NAME);
    int outputIndex = engine_runtime->getBindingIndex(OUTPUT_BLOB_NAME);
   
    // Using these indices, set up a buffer array pointing to the input and output buffers on the GPU
    std::cout << "reading PGMfile" << std::endl;
   
   
    int64_t eltCountIn = 300 * 300 * 3;
    // Batch size
    const int N = 2;
    size_t memSizeIn = N * eltCountIn * sizeof(float);
    std::vector<std::string> imageList = {
           "*****.ppm",
           "*****.ppm",
    };
    std::vector<samplesCommon::PPM<INPUT_C, INPUT_H, INPUT_W>> ppms(N);
   
    // 画像の読み込み
    assert(ppms.size() <= imageList.size());
    for (int i = 0; i < N; ++i)
    {
        readPPMFile(imageList[i], ppms[i]);
    }
   
    //vector<float> data(N * INPUT_C * INPUT_H * INPUT_W);
    float* data = new float[N * eltCountIn];
   
    // 正規化
    for (int i = 0; i < N; ++i)
    {
        for (int j = 0; j < eltCountIn; ++j)
        {
            data[i*eltCountIn + j] = 1.0f - float(ppms[i].buffer[j]) / 255.0f;
            //data[(i + 1)*j] = float(ppms[i].buffer[j]);
        }
    }
   
    void* deviceMemIn;
    cudaMalloc(&deviceMemIn, memSizeIn);
    //cudaMemcpy(deviceMemIn, &data[0], memSizeIn, cudaMemcpyHostToDevice);
    cudaMemcpy(deviceMemIn, data, memSizeIn, cudaMemcpyHostToDevice);
   
   
    int64_t eltCountOut = 3 * 5 * 5;
    size_t memSizeOut = N * eltCountOut * sizeof(float);
    // ここは配列にしていますが、多次元配列の方が良いかもしれません。
    float* outputs = new float[N*eltCountOut];
    void* deviceMemOut;
    cudaMalloc(&deviceMemOut, memSizeOut);
   
    void* buffers[2];
    void* inputbuffer = deviceMemIn;
    void* outputBuffer = deviceMemOut;
    buffers[inputIndex] = inputbuffer;
    buffers[outputIndex] = outputBuffer;
   
    std::cout << "doing inference" << std::endl;
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    // 時間に関する処理
    auto t_start = std::chrono::high_resolution_clock::now();
    // 推論の実行
    context->enqueue(N, buffers, stream, nullptr);
    // 時間に関する処理
    auto t_end = std::chrono::high_resolution_clock::now();
    float total = std::chrono::duration<float, std::milli>(t_end - t_start).count();
    std::cout << " Num batches  " << N << std::endl;
    std::cout << "Time taken for inference is " << total << " ms." << std::endl;
   
    // outputの取得
    std::cout << "getting output" << std::endl;
    cudaMemcpy(outputs, buffers[outputIndex], memSizeOut, cudaMemcpyDeviceToHost);
        // outputをテキストファイルに出力
    ofstream outputfile2("output.txt");
    for (int i = 0; i < N*eltCountOut; i++)
    {
        outputfile2 << outputs[i] << "," << std::endl;
    }
    outputfile2.close();
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
   
    return 0;
   }

結果

Keras、TensorRTそれぞれで同じ画像を読み込んで、推論して出力が一致するか確認しました。TensorRTの出力は分かりやすいようにコピーしてPythonでnumpyのアレイにしています。

TensorRT(←)Keras(→) f:id:iTD_GRP:20191202224909p:plain

ほぼ同じ出力が得られていることが確認できました。

推論の速さを比較してみると、20枚推論するのに、TenosrRTは0.4msほど、Kerasは110msほどで、高速化されていることが確認できました。

注意点

  • TensorRTはPython APIC++ APIがあります。Linuxではどちらも対応していますが、WindowsではC++ APIしか対応していません。
  • Windows版のTensorRTをダウンロードした場合、uff converterは付属していません。そのためLinux版のTensorRTをダウンロードして、その中のwhlファイルからインストールする必要があります。
  • TensorRTはすべてのモデルのレイヤーに対応している訳ではありません。もし対応していないレイヤーが存在した場合、uff converter を使った際に警告が出ます。対処法としては、対応しているレイヤーだけを使う方法と対応していないレイヤーをカスタムレイヤーに置き換える方法があります。カスタムレイヤーにはTensorRTがあらかじめ定義しているレイヤーと自分で定義するレイヤーがあります。自分で定義することであらゆるレイヤーに対応できますが、時間と知識が必要になると思います。

課題

今回のテストでTensorRTの基本的な使い方と、その性能を確認することが出来ました。しかし、実用的なモデルをTensorRT化するには課題が残りました。特にカスタムレイヤーの定義は大きな課題です。この問題を解決できれば対応できるモデルは格段に増えるはずです。

使用したモデル

import keras.backend as K
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import ZeroPadding2D
from keras.layers import Dropout
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import GlobalAveragePooling2D
from keras.layers import Input
from keras.layers import AveragePooling2D
from keras.layers import MaxPooling2D
from keras.layers.merge import concatenate
from keras.layers import Reshape
from keras.models import Model



def OED():
    input0 = Input(shape=(300,300,3))
    # Inception
    # Conv 1
    conv1_1 = Conv2D(16, (1,1), activation='relu', padding='same', name='conv1_1')(input0)
    conv1_2 = Conv2D(32, (3,3), activation='relu', padding='same', name='conv1_2')(input0)
    conv1_3 = Conv2D(16, (5,5), activation='relu', padding='same', name='conv1_3')(input0)
    conv1_4 = Conv2D(8, (7,7), activation='relu', padding='same', name='conv1_4')(input0)
    conc1 = concatenate([conv1_1, conv1_2, conv1_3, conv1_4], axis=-1, name='conc1')
    pool1 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool1')(conc1)

    # Conv 2
    conv2_1 = Conv2D(32, (1,1), activation='relu', padding='same', name='conv2_1')(pool1)
    conv2_2 = Conv2D(64, (3,3), activation='relu', padding='same', name='conv2_2')(pool1)
    conv2_3 = Conv2D(32, (5,5), activation='relu', padding='same', name='conv2_3')(pool1)
    conv2_4 = Conv2D( 16, (7,7), activation='relu', padding='same', name='conv2_4')(pool1)
    resi2 = Conv2D(16, (2,2), strides=(2,2), activation='relu', padding='same', name='resi2')(input0)
    conc2 = concatenate([conv2_1, conv2_2, conv2_3, conv2_4, resi2], axis=-1, name='conc2')
    pool2 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool2')(conc2)
    
    # Conv 3
    conv3_1 = Conv2D( 64, (1,1), activation='relu', padding='same', name='conv3_1')(pool2)
    conv3_2 = Conv2D(128, (3,3), activation='relu', padding='same', name='conv3_2')(pool2)
    conv3_3 = Conv2D( 64, (5,5), activation='relu', padding='same', name='conv3_3')(pool2)
    conv3_4 = Conv2D( 32, (7,7), activation='relu', padding='same', name='conv3_4')(pool2)
    resi3_1 = Conv2D(16, (4,4), strides=(4,4), activation='relu', padding='same', name='resi3_1')(input0)
    resi3_2 = Conv2D(32, (2,2), strides=(2,2), activation='relu', padding='same', name='resi3_2')(pool1)
    conc3 = concatenate([conv3_1, conv3_2, conv3_3, conv3_4, resi3_1, resi3_2], axis=-1, name='conc3')
    pool3 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool3')(conc3)
    
    # Conv 4
    conv4_1 = Conv2D(128, (1,1), activation='relu', padding='same', name='conv4_1')(pool3)
    conv4_2 = Conv2D(256, (3,3), activation='relu', padding='same', name='conv4_2')(pool3)
    conv4_3 = Conv2D( 96, (5,5), activation='relu', padding='same', name='conv4_3')(pool3)
    conv4_4 = Conv2D( 64, (7,7), activation='relu', padding='same', name='conv4_4')(pool3)
    resi4_1 = Conv2D(16, (8,8), strides=(8,8), activation='relu', padding='same', name='resi4_1')(input0)
    resi4_2 = Conv2D(24, (4,4), strides=(4,4), activation='relu', padding='same', name='resi4_2')(pool1)
    resi4_3 = Conv2D(32, (2,2), strides=(2,2), activation='relu', padding='same', name='resi4_3')(pool2)
    conc4 = concatenate([conv4_1, conv4_2, conv4_3, conv4_4, resi4_1, resi4_2, resi4_3], axis=-1, name='conc4')
    pool4 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool4')(conc4)
    
    # Conv 5
    conv5_1 = Conv2D(192, (1,1), activation='relu', padding='same', name='conv5_1')(pool4)
    conv5_2 = Conv2D(384, (3,3), activation='relu', padding='same', name='conv5_2')(pool4)
    conv5_3 = Conv2D(128, (5,5), activation='relu', padding='same', name='conv5_3')(pool4)
    resi5_1 = Conv2D( 8, (16,16), strides=(16,16), activation='relu', padding='same', name='resi5_1')(input0)
    resi5_2 = Conv2D(16, ( 8, 8), strides=(8,8), activation='relu', padding='same', name='resi5_2')(pool1)
    resi5_3 = Conv2D(24, ( 4, 4), strides=(4,4), activation='relu', padding='same', name='resi5_3')(pool2)
    resi5_4 = Conv2D(32, ( 2, 2), strides=(2,2), activation='relu', padding='same', name='resi5_4')(pool3)
    conc5 = concatenate([conv5_1, conv5_2, conv5_3, resi5_1, resi5_2, resi5_3, resi5_4], axis=-1, name='conc5')
    pool5 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool5')(conc5)
    
    # Conv 6
    conv6_1 = Conv2D(512, (1,1), activation='relu', padding='same', name='conv6_1')(pool5)
    conv6_2 = Conv2D(512, (3,3), activation='relu', padding='same', name='conv6_2')(pool5)
    conc6 = concatenate([conv6_1, conv6_2], axis=-1, name='conc6')
    pool6 = MaxPooling2D((2,2), strides=(2,2), padding='same', name='pool6')(conc6)
    
    # FC 7
    fc7_1 = Conv2D(1024, (1,1), activation='relu', name='fc7_1')(pool6)
    drop7_1 = Dropout(0.2, name='drop7_1')(fc7_1)
    fc7_2 = Conv2D(512, (1,1), activation='relu', name='fc7_2')(drop7_1)
    #fc7_2 = Conv2D(512, (1,1), activation='relu', name='fc7_2')(fc7_1)
    fc7_3 = Conv2D(3, (1,1), activation='sigmoid', name='fc7_3')(fc7_2)
    
    model = Model(input0, fc7_3)
    return model

if __name__ == '__main__':
    model = OED()
    model.summary()

参考文献

[1] NVIDIA TenosrRT

https://developer.nvidia.com/tensorrt

[2] Interface 2019年8月号 AI研究モダン計測制御


次の記事へ

前の記事へ 目次に戻る