torch.ge() and torch.masked_select()
1. torch.ge(input, other, *, out=None) → Tensor
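Computes input ≥ other element-wise. The result is a bool tensor whose shape is the broadcast shape of input and other; when other is a scalar, this is simply the shape of input.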
Parameters
input (Tensor) – the tensor to compare
other (Tensor or float) – the tensor or value to compare
Keyword Arguments
out (Tensor, optional) – the output tensor.
Returns
A boolean tensor that is True where input is greater than or equal to other and False elsewhere.
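A minimal sketch of the broadcasting and out behaviour described above. The tensors a and b are arbitrary illustrative values, not taken from the original post:

```python
import torch

# Tensor-vs-tensor comparison: shapes (2, 3) and (3,) broadcast to (2, 3)
a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
b = torch.tensor([2.0, 2.0, 7.0])
result = torch.ge(a, b)
print(result)  # row 0: [False, True, False], row 1: [True, True, False]

# Tensor-vs-scalar comparison: the result has the same shape as a
print(torch.ge(a, 3.0))

# Writing into a preallocated bool tensor through the out keyword
out = torch.empty(2, 3, dtype=torch.bool)
torch.ge(a, b, out=out)
print(out.dtype, out.shape)  # torch.bool torch.Size([2, 3])
```

Note that b is compared against every row of a, which is why the second argument may be a smaller tensor as long as it broadcasts.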
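2. torch.masked_select()
Used together with torch.ge: it picks out the elements of input that satisfy the ge condition, i.e. the positions where the mask is True.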
Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask, which is a BoolTensor. The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.
Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with
Keyword Arguments
out (Tensor, optional) – the output tensor.
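A small sketch of the broadcasting rule for the mask, assuming an illustrative 3×4 tensor built with torch.arange. A mask of shape (4,) is broadcast across all rows, and with a same-shaped mask the result matches ordinary boolean indexing:

```python
import torch

# x = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# A mask of shape (4,) broadcasts against x of shape (3, 4):
# columns 0 and 2 are selected in every row
col_mask = torch.tensor([True, False, True, False])
picked = torch.masked_select(x, col_mask)
print(picked)  # 1-D result, elements taken in row-major order

# With a mask of exactly x's shape, masked_select matches x[mask]
full_mask = x.ge(5)
print(torch.equal(torch.masked_select(x, full_mask), x[full_mask]))  # True
```

Whatever the mask's shape, the output is always flattened to 1-D, so the original row/column positions are not preserved.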
import torch
x = torch.randn([3, 4])
print(x)
# Compare every element of x with 0.5:
# True where the element is >= 0.5, False otherwise
mask = x.ge(0.5)
print(mask)
print(torch.masked_select(x, mask))
'''
tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],
        [-2.0099,  0.6249, -0.5382,  1.4458],
        [ 0.0684,  0.4118,  0.1011, -0.5684]])
tensor([[ True,  True, False, False],
        [False,  True, False,  True],
        [False, False, False, False]])
tensor([1.2001, 1.2968, 0.6249, 1.4458])
'''
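A short sketch of two related patterns, assuming a seeded random tensor so reruns are repeatable (the seed value is arbitrary): the >= operator builds the same mask as x.ge(0.5), and bool masks can be combined before calling torch.masked_select:

```python
import torch

torch.manual_seed(0)  # arbitrary fixed seed so reruns produce the same x
x = torch.randn(3, 4)

# The >= operator produces the same mask as x.ge(0.5)
mask = x >= 0.5
print(torch.equal(mask, x.ge(0.5)))  # True

# Bool masks combine with & (and), | (or) and ~ (not):
# keep only elements in the half-open range [0.5, 1.0)
band = x.ge(0.5) & x.lt(1.0)
print(torch.masked_select(x, band))

# Summing a bool tensor counts its True entries, which equals
# the number of elements masked_select returns
print(int(mask.sum()), torch.masked_select(x, mask).numel())
```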