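Preface
A tensor is a multi-dimensional matrix whose elements all share a single data type. PyTorch defines 10 tensor types that have both CPU and GPU variants; the complex dtypes listed in the table below have no dedicated tensor class.
Data types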
| Data type | dtype | CPU tensor | GPU tensor |
|---|---|---|---|
| 32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor |
| 64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor |
| 16-bit floating point (binary16: 1 sign, 5 exponent, 10 significand bits) | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor |
| 16-bit floating point (bfloat16: 1 sign, 8 exponent, 7 significand bits) | torch.bfloat16 | torch.BFloat16Tensor | torch.cuda.BFloat16Tensor |
| 32-bit complex | torch.complex32 or torch.chalf | \ | \ |
| 64-bit complex | torch.complex64 or torch.cfloat | \ | \ |
| 128-bit complex | torch.complex128 or torch.cdouble | \ | \ |
| 8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor |
| 8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor |
| 16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor |
| 32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor |
| 64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor |
| Boolean | torch.bool | torch.BoolTensor | torch.cuda.BoolTensor |
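
To make the table more concrete, here is a minimal sketch that checks a few of its entries from Python. It assumes a stock PyTorch install whose global default dtype has not been changed; the variable names are illustrative only.

```python
import torch

# Alias names refer to the very same dtype object as the full names.
float_alias_ok = torch.float is torch.float32
long_alias_ok = torch.long is torch.int64
half_alias_ok = torch.half is torch.float16

# Dtypes that torch.tensor infers from plain Python literals
# (assuming torch.set_default_dtype has not been called).
float_default = torch.tensor([1.5]).dtype    # torch.float32
int_default = torch.tensor([1]).dtype        # torch.int64
bool_default = torch.tensor([True]).dtype    # torch.bool

# element_size() reports bytes per element, matching the bit widths in the table.
sizes = {
    dt: torch.empty(1, dtype=dt).element_size()
    for dt in (torch.uint8, torch.float16, torch.bfloat16,
               torch.float32, torch.float64, torch.int64)
}
print(sizes)
```

Both 16-bit formats occupy two bytes per element; they differ only in how those bits are split between exponent and significand.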
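Conversion methods
Method 1
Call .int(), .long(), .float(), or .double() directly on the tensor. Each call returns a new tensor of the requested type and leaves the original unchanged.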
>>> import torch
>>> a = torch.tensor([1.3, 1.5, 1.7])
>>> a
tensor([1.3000, 1.5000, 1.7000])
>>> a.dtype
torch.float32
>>> a.int()
tensor([1, 1, 1], dtype=torch.int32)
>>> a.long()
tensor([1, 1, 1])
>>> a.float()
tensor([1.3000, 1.5000, 1.7000])
>>> a.double()
tensor([1.3000, 1.5000, 1.7000], dtype=torch.float64)
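
The output of a.int() above shows that converting floats to integers drops the fractional part rather than rounding. The following sketch illustrates this with negative values as well, and shows the equivalent .to(dtype) spelling; the input values are chosen only for illustration.

```python
import torch

a = torch.tensor([-1.7, -0.5, 1.3, 1.7])

as_int = a.int()               # same result as a.to(torch.int32)
as_long = a.to(torch.int64)    # same result as a.long()

# Float-to-integer conversion truncates toward zero; it does not round.
truncated = as_int.tolist()

# Round explicitly first if rounding is what you want.
rounded = a.round().int().tolist()

# The conversion returns a new tensor; `a` keeps its original dtype.
original_dtype = a.dtype

print(truncated, rounded, original_dtype)
```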
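Method 2
Use the tensor's .type(dtype) method. For the dtype argument, pass a value from the second column of the data type table, such as torch.uint8. Values outside the target type's range are not rejected: converting to torch.uint8 wraps them modulo 256, as the example shows.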
>>> import torch
>>> a = torch.tensor([-1, 1, 256])
>>> a
tensor([ -1,   1, 256])
>>> a.dtype
torch.int64
>>> a.type(torch.uint8)
tensor([255,   1,   0], dtype=torch.uint8)
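
Because the conversion above silently turns -1 into 255 and 256 into 0, it can be worth checking the value range before narrowing an integer type. The sketch below is one possible guard, not the only approach; it uses torch.iinfo to look up the limits of the target dtype, and the variable names are illustrative.

```python
import torch

a = torch.tensor([-1, 1, 256])

# With no argument, .type() returns the tensor type as a string.
type_name = a.type()

# With a dtype argument it converts; out-of-range values wrap modulo 256.
wrapped = a.type(torch.uint8)

# Guard against silent wraparound by checking the target range first.
info = torch.iinfo(torch.uint8)          # info.min == 0, info.max == 255
fits = bool(((a >= info.min) & (a <= info.max)).all())

# Alternatively, saturate out-of-range values before converting.
clamped = a.clamp(info.min, info.max).type(torch.uint8)

print(type_name, wrapped.tolist(), fits, clamped.tolist())
```

Whether to raise an error, clamp, or accept the wraparound depends on what the data represents; for image pixels, clamping is a common choice.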
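Method 3
Use the .type_as(tensor) method to convert the calling tensor to the same data type as the tensor passed as the argument. In the example below, b is converted to the type of a.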
>>> import torch
>>> a = torch.tensor([1, 2, 3])
>>> a.dtype
torch.int64
>>> b = torch.tensor([1.1, 2.2, 3.3])
>>> b.dtype
torch.float32
>>> b = b.type_as(a)
>>> b.dtype
torch.int64
>>> b
tensor([1, 2, 3])
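
A short sketch of how .type_as relates to the other methods, assuming both tensors live on the CPU. Here .to(a.dtype) is shown as a dtype-only alternative, and the last check relies on the documented behavior that .type_as is a no-op when the type already matches.

```python
import torch

a = torch.tensor([1, 2, 3])            # int64
b = torch.tensor([1.1, 2.2, 3.3])      # float32

c = b.type_as(a)     # new int64 tensor; b itself is unchanged
d = b.to(a.dtype)    # dtype-only alternative giving the same values here

# When the type already matches, no copy is made: the result shares storage.
shares_storage = a.type_as(a).data_ptr() == a.data_ptr()

print(c, d, b.dtype, shares_storage)
```

.type_as is convenient when a helper function receives a reference tensor and should produce results of a matching type without naming the dtype explicitly.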
References
https://pytorch.org/docs/stable/tensors.html
