您当前的位置: 首页 >  Python

静静喜欢大白

暂无认证

  • 2浏览

    0关注

    521博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Python-torch.max()

静静喜欢大白 发布时间:2021-03-16 17:59:24 ,浏览量:2

转载

目录

1. torch.max(input, dim) 函数

2.准确率的计算

在分类问题中,通常需要使用max()函数对softmax函数的输出值进行操作,求出预测值索引。下面讲解一下torch.max()函数的输入及输出值都是什么。

1. torch.max(input, dim) 函数

output = torch.max(input, dim)

输入

  • input是softmax函数输出的一个tensor
  • dim是max函数索引的维度0/10是每列的最大值,1是每行的最大值(对于2维矩阵,-1是最后一个维度,在这里相当于1,按行)

输出

  • 函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引(计算acc时只需要第二个tensoe,所以取【1】)

在多分类任务中我们并不需要知道各类别的预测概率,所以第一个tensor对分类任务没有帮助,而第二个tensor包含了最大概率的索引,所以在实际使用中我们仅获取第二个tensor即可。

我们通过一个实例可以更容易理解这个函数的用法。

import torch
a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
print(a)

输出:

tensor([[ 1,  5, 62, 54],
        [ 2,  6,  2,  6],
        [ 2, 65,  2,  6]])

索引每行的最大值:

torch.max(a, 1)

输出:

torch.return_types.max(
values=tensor([62,  6, 65]),
indices=tensor([2, 3, 1]))

在计算准确率时第一个tensor values是不需要的,所以我们只需提取第二个tensor,并将tensor格式的数据转换成array格式。

torch.max(a, 1)[1].numpy()

输出:

array([2, 3, 1], dtype=int64)

*注:在有的地方我们会看到torch.max(a, 1).data.numpy()的写法,这是因为在早期的pytorch的版本中,variable变量和tenosr是不一样的数据格式,variable可以进行反向传播,tensor不可以,需要将variable转变成tensor再转变成numpy。现在的版本已经将variable和tenosr合并,所以只用torch.max(a,1).numpy()就可以了。

2.准确率的计算
pred_y = torch.max(predict, 1)[1].numpy()
y_label = torch.max(label, 1)[1].data.numpy()
accuracy = (pred_y == y_label).sum() / len(y_label)

关注
打赏
1510642601
查看更多评论
立即登录/注册

微信扫码登录

0.0371s