Pytorch 中使用 sort 进行排序的基本方法

Pytorch 是非常注明的机器学习框架,其中的 torch.Tensor 是自带排序的,直接使用 torch.sort() 这个方法即可。排序可以按照升序、降序,可以选择排序的维度,等等。下面介绍一下 Pytorch 中的排序方法。

文章来自:https://hxhen.com/sort-method-in-torch-model/

一、方法原型

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

二、返回值

A tuple of (sorted_tensor, sorted_indices) is returned, 
where the sorted_indices are the indices of the elements in the original input tensor.

三、参数

  • input (Tensor) – the input tensor
    形式上与 numpy.narray 类似
  • dim (int, optional) – the dimension to sort along
    维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1
  • descending (bool, optional) – controls the sorting order (ascending or descending)
    降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase

四、实例

import torch
x = torch.randn(3,4)
x  #初始值,始终不变
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [ 0.1208, -0.4237, -1.1313,  0.9022],
        [-1.1995, -0.0699, -0.4396,  0.8043]])
sorted, indices = torch.sort(x)  #按行从小到大排序
sorted
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [-1.1313, -0.4237,  0.1208,  0.9022],
        [-1.1995, -0.4396, -0.0699,  0.8043]])
indices
tensor([[0, 1, 2, 3],
        [2, 1, 0, 3],
        [0, 2, 1, 3]])
sorted, indices = torch.sort(x, descending=True)  #按行从大到小排序 (即反序)
sorted
tensor([[ 1.3536, -0.1253, -0.6175, -0.9950],
        [ 0.9022,  0.1208, -0.4237, -1.1313],
        [ 0.8043, -0.0699, -0.4396, -1.1995]])
indices
tensor([[3, 2, 1, 0],
        [3, 0, 1, 2],
        [3, 1, 2, 0]])
sorted, indices = torch.sort(x, dim=0)  #按列从小到大排序
sorted
tensor([[-1.1995, -0.6175, -1.1313,  0.8043],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [ 0.1208, -0.0699, -0.1253,  1.3536]])
indices
tensor([[2, 0, 1, 2],
        [0, 1, 2, 1],
        [1, 2, 0, 0]])
sorted, indices = torch.sort(x, dim=0, descending=True)  #按列从大到小排序
sorted
tensor([[ 0.1208, -0.0699, -0.1253,  1.3536],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [-1.1995, -0.6175, -1.1313,  0.8043]])
indices
tensor([[1, 2, 0, 0],
        [0, 1, 2, 1],
        [2, 0, 1, 2]])

官方文档:https://pytorch.org/docs/stable/torch.html


【AD】美国洛杉矶CN2 VPS/香港CN2 VPS/日本CN2 VPS推荐,延迟低、稳定性高、免费备份_搬瓦工vps

【AD】RackNerd 推出的 KVM VPS 特价优惠,在纽约、西雅图、圣何塞和阿什本每年仅需 12.88 美元!