TensorRT官方文档说,/usr/src/tensorrt/samples/python/network_api_pytorch_mnist 下有示例代码,但实际上该目录下根本没有。这里提供一份示例代码,供参考。
这个范例的具体位置是:/usr/local/lib/python3.5/site-packages/tensorrt/examples/pytorch_to_trt
#!/usr/bin/python
# Guarded imports for the TensorRT pytorch_to_trt MNIST sample.
# NOTE: indentation was restored here; the scraped copy had all try/except
# bodies at column 0, which is not valid Python.
import os
from random import randint

import numpy as np

# pycuda provides device memory management; importing pycuda.autoinit
# creates a CUDA context as an import side effect.
try:
    import pycuda.driver as cuda
    import pycuda.gpuarray as gpuarray
    import pycuda.autoinit
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module({})
Please make sure you have pycuda and the example dependencies installed.
sudo apt-get install python(3)-pycuda
pip install tensorrt[examples]""".format(err))

# Pillow is used for image handling elsewhere in the sample.
try:
    from PIL import Image
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have Pillow installed.
For installation instructions, see:
http://pillow.readthedocs.io/en/stable/installation.html""".format(err))

# Sample-local module (presumably defines/trains the MNIST model) —
# not part of any package index; ships with the example.
import mnist

try:
    import torch
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have PyTorch installed.
For installation instructions, see:
http://pytorch.org/""".format(err))

# TensorRT must be imported after any frameworks in the case where
# the framework has incorrect dependencies setup and is not updated
# to use the versions of libraries that TensorRT imports.
try:
    import tensorrt as trt
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have the TensorRT Library installed
and accessible in your LD_LIBRARY_PATH""".format(err))
# Console logger at INFO verbosity, using the legacy TensorRT 2.x/3.x
# `trt.infer` Python API.
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
# Presumably the number of inference iterations run later in the script —
# not used in this excerpt; confirm against the full sample.
ITERATIONS = 10
# Binding names for the network's input and output tensors.
INPUT_LAYERS = ["data"]
OUTPUT_LAYERS = ['prob']
# MNIST images are 28x28; the classifier produces 10 class scores.
INPUT_H = 28
INPUT_W = 28
OUTPUT_SIZE = 10
def create_pytorch_engine(max_batch_size, builder, dt, model):
    """Build a TensorRT engine mirroring the PyTorch MNIST LeNet.

    Layer weights are pulled from ``model``, a PyTorch state dict keyed
    'conv1.weight', 'conv1.bias', 'conv2.*', 'fc1.*', 'fc2.*'. Each tensor
    is moved to CPU and flattened, because the legacy ``trt.infer`` API
    takes 1-D weight arrays.

    Args:
        max_batch_size: maximum batch size the built engine will support.
        builder: a ``trt.infer`` engine builder.
        dt: TensorRT data type for the input tensor.
        model: PyTorch state dict holding the trained weights.

    Returns:
        The built CUDA engine.
    """
    network = builder.create_network()

    # Input: single-channel 28x28 image.
    data = network.add_input(INPUT_LAYERS[0], dt, (1, INPUT_H, INPUT_W))
    assert(data)

    # conv1: 20 feature maps, 5x5 kernel, stride 1.
    conv1_w = model['conv1.weight'].cpu().numpy().reshape(-1)
    conv1_b = model['conv1.bias'].cpu().numpy().reshape(-1)
    conv1 = network.add_convolution(data, 20, (5,5), conv1_w, conv1_b)
    assert(conv1)
    conv1.set_stride((1,1))

    # pool1: 2x2 max pooling, stride 2.
    pool1 = network.add_pooling(conv1.get_output(0), trt.infer.PoolingType.MAX, (2,2))
    assert(pool1)
    pool1.set_stride((2,2))

    # conv2: 50 feature maps, 5x5 kernel, stride 1.
    conv2_w = model['conv2.weight'].cpu().numpy().reshape(-1)
    conv2_b = model['conv2.bias'].cpu().numpy().reshape(-1)
    conv2 = network.add_convolution(pool1.get_output(0), 50, (5,5), conv2_w, conv2_b)
    assert(conv2)
    conv2.set_stride((1,1))

    # pool2: 2x2 max pooling, stride 2.
    pool2 = network.add_pooling(conv2.get_output(0), trt.infer.PoolingType.MAX, (2,2))
    assert(pool2)
    pool2.set_stride((2,2))

    # fc1: 500-unit fully connected layer, followed by ReLU.
    fc1_w = model['fc1.weight'].cpu().numpy().reshape(-1)
    fc1_b = model['fc1.bias'].cpu().numpy().reshape(-1)
    fc1 = network.add_fully_connected(pool2.get_output(0), 500, fc1_w, fc1_b)
    assert(fc1)

    relu1 = network.add_activation(fc1.get_output(0), trt.infer.ActivationType.RELU)
    assert(relu1)

    # fc2: final 10-way classifier.
    fc2_w = model['fc2.weight'].cpu().numpy().reshape(-1)
    fc2_b = model['fc2.bias'].cpu().numpy().reshape(-1)
    fc2 = network.add_fully_connected(relu1.get_output(0), OUTPUT_SIZE, fc2_w, fc2_b)
    assert(fc2)

    # Using log_softmax in training, cutting out log softmax here since
    # there is no log softmax in TRT.
    fc2.get_output(0).set_name(OUTPUT_LAYERS[0])
    network.mark_output(fc2.get_output(0))

    builder.set_max_batch_size(max_batch_size)
    # NOTE(review): the scraped copy of this file was cut off mid-statement
    # at `builder.set_max_workspace_size(1`. The lines below are restored
    # from NVIDIA's original pytorch_to_trt sample — confirm against the
    # installed example before relying on them.
    builder.set_max_workspace_size(1 << 20)

    engine = builder.build_cuda_engine(network)
    network.destroy()
    return engine
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?