您当前的位置: 首页 >  pytorch

柳鲲鹏

暂无认证

  • 0浏览

    0关注

    4642博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Pytorch转TensorRT范例代码

柳鲲鹏 发布时间:2018-10-25 15:56:56 ,浏览量:0

  TensorRT官方文档说，/usr/src/tensorrt/samples/python/network_api_pytorch_mnist 目录下有示例代码，但实际上根本就没有。这里提供一个示例代码，供参考。

  这个范例的具体位置是：/usr/local/lib/python3.5/site-packages/tensorrt/examples/pytorch_to_trt

#!/usr/bin/python
import os
from random import randint
import numpy as np

try:
    import pycuda.driver as cuda
    import pycuda.gpuarray as gpuarray
    import pycuda.autoinit
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module({})
Please make sure you have pycuda and the example dependencies installed.
sudo apt-get install python(3)-pycuda
pip install tensorrt[examples]""".format(err))

try:
    from PIL import Image
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have Pillow installed.
For installation instructions, see:
http://pillow.readthedocs.io/en/stable/installation.html""".format(err))

import mnist

try:
    import torch
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have PyTorch installed.
For installation instructions, see:
http://pytorch.org/""".format(err))

# TensorRT must be imported after any frameworks in the case where
# the framework has incorrect dependencies setup and is not updated
# to use the versions of libraries that TensorRT imports.
try:
    import tensorrt as trt
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have the TensorRT Library installed
and accessible in your LD_LIBRARY_PATH""".format(err))

# Module-wide TensorRT logger: a console logger created at INFO severity.
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)

# NOTE(review): not referenced in the visible part of this file; presumably
# the number of inference runs -- confirm against the rest of the script.
ITERATIONS = 10
# Input tensor names; create_pytorch_engine passes element 0 to
# network.add_input().
INPUT_LAYERS = ["data"]
# Output tensor names; element 0 is assigned to the final fully connected
# layer's output before that tensor is marked as the network output.
OUTPUT_LAYERS = ['prob']
# Input height and width; the network input is declared with dimensions
# (1, INPUT_H, INPUT_W).
INPUT_H = 28
INPUT_W = 28
# Number of outputs requested from the last fully connected layer (fc2).
OUTPUT_SIZE = 10

def create_pytorch_engine(max_batch_size, builder, dt, model):
    network = builder.create_network()

    data = network.add_input(INPUT_LAYERS[0], dt, (1, INPUT_H, INPUT_W))
    assert(data)

    #-------------
    conv1_w = model['conv1.weight'].cpu().numpy().reshape(-1)
    conv1_b = model['conv1.bias'].cpu().numpy().reshape(-1)
    conv1 = network.add_convolution(data, 20, (5,5),  conv1_w, conv1_b)
    assert(conv1)
    conv1.set_stride((1,1))

    #-------------
    pool1 = network.add_pooling(conv1.get_output(0), trt.infer.PoolingType.MAX, (2,2))
    assert(pool1)
    pool1.set_stride((2,2))

    #-------------
    conv2_w = model['conv2.weight'].cpu().numpy().reshape(-1)
    conv2_b = model['conv2.bias'].cpu().numpy().reshape(-1)
    conv2 = network.add_convolution(pool1.get_output(0), 50, (5,5), conv2_w, conv2_b)
    assert(conv2)
    conv2.set_stride((1,1))

    #-------------
    pool2 = network.add_pooling(conv2.get_output(0), trt.infer.PoolingType.MAX, (2,2))
    assert(pool2)
    pool2.set_stride((2,2))

    #-------------
    fc1_w = model['fc1.weight'].cpu().numpy().reshape(-1)
    fc1_b = model['fc1.bias'].cpu().numpy().reshape(-1)
    fc1 = network.add_fully_connected(pool2.get_output(0), 500, fc1_w, fc1_b)
    assert(fc1)

    #-------------
    relu1 = network.add_activation(fc1.get_output(0), trt.infer.ActivationType.RELU)
    assert(relu1)

    #-------------
    fc2_w = model['fc2.weight'].cpu().numpy().reshape(-1)
    fc2_b = model['fc2.bias'].cpu().numpy().reshape(-1)
    fc2 = network.add_fully_connected(relu1.get_output(0), OUTPUT_SIZE, fc2_w, fc2_b)
    assert(fc2)

    #-------------
    # Using log_softmax in training, cutting out log softmax here since no log softmax in TRT
    fc2.get_output(0).set_name(OUTPUT_LAYERS[0])
    network.mark_output(fc2.get_output(0))


    builder.set_max_batch_size(max_batch_size)
    builder.set_max_workspace_size(1             
关注
打赏
1665724893
查看更多评论
0.0518s