Computer >> 컴퓨터 >  >> 프로그램 작성 >> Python

PyTorch에서 텐서의 요소를 정렬하는 방법은 무엇입니까?

<시간/>

PyTorch에서 텐서의 요소를 정렬하려면 torch.sort() 메서드를 사용할 수 있습니다. 이 메서드는 두 개의 텐서를 반환합니다. 첫 번째 텐서는 요소 값을 정렬한 텐서이고 두 번째 텐서는 원래 텐서에 있는 요소 인덱스의 텐서입니다. 2D 텐서를 행과 열로 계산할 수 있습니다.

단계

  • 필요한 라이브러리를 가져옵니다. 다음 모든 Python 예제에서 필수 Python 라이브러리는 torch입니다. . 이미 설치했는지 확인하십시오.

  • PyTorch 텐서를 만들고 인쇄합니다.

  • 위에서 만든 텐서의 요소를 정렬하려면 torch.sort(input, dim)을 계산합니다. . 이 값을 새 변수 "v"에 할당 .여기에 입력 입력 텐서이며 dim 요소가 정렬되는 차원입니다. 요소를 행 방향으로 정렬하려면 dim을 1로 설정하고 요소를 열 방향으로 정렬하려면 dim 0으로 설정됩니다.

  • 정렬된 값이 있는 Tensor는 v[0]으로 액세스할 수 있습니다. 정렬된 요소의 인덱스 텐서는 v[1] .

  • 정렬된 값으로 Tensor를 출력하고 정렬된 값의 인덱스로 Tensor를 출력합니다.

예시 1

다음 Python 프로그램은 1Dtensor의 요소를 정렬하는 방법을 보여줍니다.

# Python program to sort elements of a tensor
# import necessary library
import torch

# Create a tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# sort the tensor T
# it sorts the tensor in ascending order
v = torch.sort(T)

# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])

# print indices of sorted value
print("Indices of sorted value:\n", v[1])

출력

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
Tensor with sorted value:
   tensor([-4.3300, -0.4330, 2.3340, 4.4330, 4.4430, 5.0000])
Indices of sorted value:
   tensor([2, 3, 0, 1, 5, 4])

예시 2

다음 Python 프로그램은 2Dtensor의 요소를 정렬하는 방법을 보여줍니다.

# Python program to sort elements of a 2-D tensor
# import the library
import torch

# Create a 2-D tensor
T = torch.Tensor([[2,3,-32],
                  [43,4,-53],
                  [4,37,-4],
                  [3,-75,34]])
print("Original Tensor:\n", T)

# sort tensor T
# it sorts the tensor in ascending order
v = torch.sort(T)

# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])

# print indices of sorted value
print("Indices of sorted value:\n", v[1])
print("Sort tensor Column-wise")
v = torch.sort(T, 0)

# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])

# print indices of sorted value
print("Indices of sorted value:\n", v[1])
print("Sort tensor Row-wise")
v = torch.sort(T, 1)

# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])

# print indices of sorted value
print("Indices of sorted value:\n", v[1])

출력

Original Tensor:
tensor([[ 2., 3., -32.],
        [ 43., 4., -53.],
        [ 4., 37., -4.],
        [ 3., -75., 34.]])
Tensor with sorted value:
tensor([[-32., 2., 3.],
         [-53., 4., 43.],
         [ -4., 4., 37.],
         [-75., 3., 34.]])
Indices of sorted value:
tensor([[2, 0, 1],
         [2, 1, 0],
         [2, 0, 1],
         [1, 0, 2]])
Sort tensor Column-wise
Tensor with sorted value:
tensor([[ 2., -75., -53.],
         [ 3., 3., -32.],
         [ 4., 4., -4.],
         [ 43., 37., 34.]])
Indices of sorted value:
tensor([[0, 3, 1],
         [3, 0, 0],
         [2, 1, 2],
         [1, 2, 3]])
Sort tensor Row-wise
Tensor with sorted value:
tensor([[-32., 2., 3.],
         [-53., 4., 43.],
         [ -4., 4., 37.],
         [-75., 3., 34.]])
Indices of sorted value:
tensor([[2, 0, 1],
         [2, 1, 0],
         [2, 0, 1],
         [1, 0, 2]])