How can I efficiently retrieve the indices of the maximum values in a Torch tensor?

Updated: 2024-10-11 17:19:32

Question

Assume to have a torch tensor, for example of the following shape:

x = torch.rand(20, 1, 120, 120)

What I would like now is to get the indices of the maximum values of each 120x120 matrix. To simplify the problem I would first use x.squeeze() to work with shape [20, 120, 120]. I would then like to get a torch tensor which is a list of indices with shape [20, 2].

How can I do this fast?

Answer

If I get you correctly, you don't want the values, but the indices. Unfortunately there is no out-of-the-box solution. There is an argmax() function, but I cannot see how to get it to do exactly what you want.
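To see the issue concretely (this example is my addition, not part of the original answer): argmax() without a dim argument returns a single index into the flattened tensor, and with a dim argument it reduces only that one axis — neither directly yields (row, column) pairs:

```python
import torch

x = torch.tensor([[0.1, 0.9],
                  [0.3, 0.2]])

# argmax over the flattened tensor: a single scalar index
print(x.argmax())        # tensor(1) -> position 1 in the flattened [0.1, 0.9, 0.3, 0.2]

# argmax along one dim reduces only that axis
print(x.argmax(dim=1))   # tensor([1, 0]): column of the max in each row
```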

So here is a small workaround; the efficiency should also be okay since we're just dividing tensors:

```python
import torch

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)

m = x.view(n, -1).argmax(1)
# argmax() only returns the index into the flattened matrix block,
# so we have to compute the row and column indices ourselves
# using // (floor division) and %
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)
```

n represents your first dimension, and d the last two dimensions. I use smaller numbers here to show the result, but of course this will also work for n=20 and d=120:

```python
n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)
```

Here is the output for n=4 and d=4:

```
tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],

        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],

        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],

        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])
```

I hope this is what you wanted to get! :)
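As a sanity check (my addition, not part of the original answer), the recovered (row, col) pairs can be compared against the per-matrix maxima:

```python
import torch

torch.manual_seed(0)
n, d = 4, 4
x = torch.rand(n, 1, d, d)

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

# the value at each recovered (row, col) pair should equal the matrix maximum
for i in range(n):
    r, c = indices[i]
    assert x[i, 0, r, c] == x[i, 0].max()
print("indices match the per-matrix maxima")
```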

Here is a slightly modified version which might be minimally faster (not much, I guess :), but it is a bit simpler and prettier:

Instead of this, as before:

```python
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
```

The necessary reshaping is already done on the argmax values:

```python
m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)
```
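Both variants should produce identical indices; a quick equivalence check (my addition, not from the original answer):

```python
import torch

torch.manual_seed(1)
n, d = 20, 120
x = torch.rand(n, 1, d, d)

# original variant: reshape each part, then concatenate
m1 = x.view(n, -1).argmax(1)
v1 = torch.cat(((m1 // d).view(-1, 1), (m1 % d).view(-1, 1)), dim=1)

# simplified variant: reshape the argmax values once
m2 = x.view(n, -1).argmax(1).view(-1, 1)
v2 = torch.cat((m2 // d, m2 % d), dim=1)

assert torch.equal(v1, v2)
print(v1.shape)  # torch.Size([20, 2])
```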

But as mentioned in the comments, I don't think it is possible to get much more out of it.

One thing you could do, if it is really important to you to squeeze out the last possible bit of performance, is to implement the above function as a low-level extension (e.g. in C++) for PyTorch.

This would give you just one function you can call for it and would avoid slow Python code.

pytorch/tutorials/advanced/cpp_extension.html

Published: 2023-11-29 08:49:47
Link: https://www.elefans.com/category/jswz/34/1645920.html