Preface
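A tensor is a multi-dimensional matrix whose elements all share a single data type. The PyTorch documentation lists 10 tensor types that have both a CPU and a GPU tensor class; the table below shows them along with the complex types, which have no dedicated tensor class.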
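Because every element of a tensor shares one dtype, the dtype and its per-element storage size can be read directly from the tensor. The minimal sketch below uses the standard `torch.finfo` / `torch.iinfo` helpers and `Tensor.element_size()` to confirm the bit widths in the data type table. It also shows the default dtypes that the examples in this post rely on, assuming the default has not been changed with `torch.set_default_dtype`. The variable names are illustrative only.

```python
import torch

# Dtypes inferred from Python literals (assumes the default dtype is unchanged)
float_default = torch.tensor([1.0]).dtype
int_default = torch.tensor([1]).dtype
print(float_default, int_default)  # torch.float32 torch.int64

# Bytes occupied by one element of each dtype
sizes = {dt: torch.empty(0, dtype=dt).element_size()
         for dt in (torch.float64, torch.float16, torch.bfloat16, torch.uint8, torch.bool)}
print(sizes)

# Bit widths and value ranges from torch.finfo / torch.iinfo
print(torch.finfo(torch.float16).bits, torch.finfo(torch.float16).max)
print(torch.finfo(torch.bfloat16).bits, torch.finfo(torch.bfloat16).max)
print(torch.iinfo(torch.uint8).min, torch.iinfo(torch.uint8).max)  # 0 255
print(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)    # -128 127
```

Comparing the two 16-bit formats shows their trade-off: both use 16 bits, but bfloat16 reaches a far larger maximum value than float16 at the cost of fewer significand bits.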
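| Data type | dtype | CPU tensor | GPU tensor |
|---|---|---|---|
| 32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor |
| 64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor |
| 16-bit floating point [1] | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor |
| 16-bit floating point [2] | torch.bfloat16 | torch.BFloat16Tensor | torch.cuda.BFloat16Tensor |
| 32-bit complex | torch.complex32 | none | none |
| 64-bit complex | torch.complex64 | none | none |
| 128-bit complex | torch.complex128 or torch.cdouble | none | none |
| 8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor |
| 8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor |
| 16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor |
| 32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor |
| 64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor |
| Boolean | torch.bool | torch.BoolTensor | torch.cuda.BoolTensor |

[1] Also called binary16 (IEEE half precision): 1 sign, 5 exponent and 10 significand bits; useful when precision matters.
[2] Also called Brain Floating Point: 1 sign, 8 exponent and 7 significand bits; useful when range matters, since it has as many exponent bits as float32.

Conversion methods

Method 1
Call .int(), .long(), .float() or .double() directly on the tensor; these convert it to torch.int32, torch.int64, torch.float32 and torch.float64 respectively.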
>>> import torch
>>> a = torch.tensor([1.3, 1.5, 1.7])
>>> a
tensor([1.3000, 1.5000, 1.7000])
>>> a.dtype
torch.float32
>>> a.int()
tensor([1, 1, 1], dtype=torch.int32)
>>> a.long()
tensor([1, 1, 1])
>>> a.float()
tensor([1.3000, 1.5000, 1.7000])
>>> a.double()
tensor([1.3000, 1.5000, 1.7000], dtype=torch.float64)
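The same pattern covers the other types in the table: `.half()`, `.bfloat16()`, `.byte()`, `.char()`, `.short()` and `.bool()` are all methods of `torch.Tensor`. The sketch below also shows two behaviors worth knowing: float-to-integer conversion truncates toward zero (so -1.7 becomes -1, not -2), and these methods return a converted tensor instead of changing the original in place. `Tensor.to(dtype)` is included as an equivalent general-purpose alternative; it is not one of the three methods described in this post. Variable names are illustrative.

```python
import torch

a = torch.tensor([-1.7, -0.5, 0.0, 0.5, 1.7])

# Float -> integer conversion truncates toward zero
print(a.int().tolist())                    # [-1, 0, 0, 0, 1]

# Shorthand methods for the remaining dtypes
print(a.half().dtype, a.bfloat16().dtype)  # torch.float16 torch.bfloat16
print(a.short().dtype, a.char().dtype)     # torch.int16 torch.int8
print(a.bool().tolist())                   # zero -> False, nonzero -> True

# The conversion returns a new tensor; `a` keeps its original dtype
b = a.long()
print(a.dtype, b.dtype)                    # torch.float32 torch.int64

# Tensor.to(dtype) performs the same conversion
c = a.to(torch.int64)
print(torch.equal(b, c))                   # True
```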
Method 2
Use the tensor's .type(dtype) method, passing a dtype from the second column of the data type table, such as torch.uint8:
>>> import torch
>>> a = torch.tensor([-1, 1, 256])
>>> a
tensor([ -1, 1, 256])
>>> a.dtype
torch.int64
>>> a.type(torch.uint8)
tensor([255, 1, 0], dtype=torch.uint8)
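The `[255, 1, 0]` result above comes from integer wrap-around: converting an int64 value to uint8 keeps only its lowest 8 bits, i.e. the value modulo 256, so -1 becomes 255 and 256 becomes 0. The sketch below checks this against Python's `%` operator and shows two other ways to call `.type()`: with no argument it returns the tensor's type name as a string, and it also accepts a type string from the CPU tensor column, such as `'torch.FloatTensor'`. Variable names are illustrative.

```python
import torch

a = torch.tensor([-1, 1, 256])    # int64 by default
u = a.type(torch.uint8)

# int64 -> uint8 keeps the value modulo 256
expected = [x % 256 for x in a.tolist()]
print(u.tolist(), expected)       # [255, 1, 0] [255, 1, 0]

# With no argument, .type() returns the type name as a string
print(a.type())                   # torch.LongTensor

# A tensor-type string from the table is also accepted
f = a.type('torch.FloatTensor')
print(f.dtype)                    # torch.float32
```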
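Method 3
Use the .type_as(tensor) method: x.type_as(y) returns x converted to the data type of y. In the example below, the float tensor b is converted to the integer type of a: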
>>> import torch
>>> a = torch.tensor([1, 2, 3])
>>> a.dtype
torch.int64
>>> b = torch.tensor([1.1, 2.2, 3.3])
>>> b.dtype
torch.float32
>>> b = b.type_as(a)
>>> b.dtype
torch.int64
>>> b
tensor([1, 2, 3])
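`type_as` works in either direction and is convenient when one tensor should match another's dtype without hard-coding it. A minimal sketch, with illustrative names `idx` and `weights`: an integer tensor is converted to a float64 tensor's type, and the result is checked against the equivalent `Tensor.to(other.dtype)` call.

```python
import torch

idx = torch.tensor([1, 2, 3])     # int64
weights = torch.tensor([0.5, 0.25, 0.125], dtype=torch.float64)

# Convert the integer tensor to the dtype of `weights`
idx_f = idx.type_as(weights)
print(idx_f.dtype)                # torch.float64
print((idx_f * weights).tolist()) # [0.5, 0.5, 0.375]

# Equivalent formulation with Tensor.to
print(torch.equal(idx_f, idx.to(weights.dtype)))  # True
```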
References
https://pytorch.org/docs/stable/tensors.html