Computer >> 컴퓨터 >  >> 프로그램 작성 >> Python

PyTorch에서 텐서의 k 번째 및 상위 k 요소를 찾는 방법은 무엇입니까?

<시간/>

PyTorch는 torch.kthvalue() 메서드를 제공합니다. 텐서의 k번째 요소를 찾습니다. 오름차순으로 정렬된 텐서의 k번째 요소 값과 원래 텐서의 요소 인덱스를 반환합니다.

torch.topk() 메서드는 상위 "k" 요소를 찾는 데 사용됩니다. 텐서에서 상위 "k" 또는 가장 큰 "k" 요소를 반환합니다.

단계

  • 필요한 라이브러리를 가져옵니다. 다음 모든 Python 예제에서 필수 Python 라이브러리는 torch입니다. . 이미 설치했는지 확인하십시오.

  • PyTorch 텐서를 만들고 인쇄합니다.

  • torch.kthvalue(input, k) 계산 . 두 개의 텐서를 반환합니다. 이 두 텐서를 두 개의 새 변수 "값"에 할당 및 "색인" . 여기서 입력은 텐서이고 k는 정수입니다.

  • torch.topk(input, k) 계산 . 두 개의 텐서를 반환합니다. 첫 번째 텐서는 상위 "k" 요소의 값을 갖고 두 번째 텐서는 원래 텐서에서 이러한 요소의 인덱스를 갖습니다. 이 두 텐서를 새 변수 "값"에 할당합니다. 및 "인덱스" .

  • 텐서의 k번째 요소의 값과 인덱스, 텐서의 상위 "k" 요소의 값과 인덱스를 인쇄합니다.

예시 1

이 파이썬 프로그램은 텐서의 k번째 요소를 찾는 방법을 보여줍니다.

# Python program to find k-th element of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)

# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index)

출력

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)

예시 2

다음 Python 프로그램은 텐서의 상위 "k" 또는 가장 큰 "k" 요소를 찾는 방법을 보여줍니다.

# Python program to find to top k elements of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)

# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)
가 있는 요소

출력

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])