Computer >> 컴퓨터 >  >> 프로그램 작성 >> Python

PyTorch에서 두 텐서를 어떻게 비교합니까?

<시간/>

PyTorch에서 두 텐서를 요소별로 비교하기 위해 torch.eq()를 사용합니다. 방법. 해당 요소를 비교하고 "True" 를 반환합니다. 두 요소가 같으면 "False"를 반환합니다. . 차원이 같거나 다른 두 텐서를 비교할 수 있지만 두 텐서의 크기는 단일 차원이 아닌 경우 일치해야 합니다.

단계

  • 필요한 라이브러리를 가져옵니다. 다음 모든 Python 예제에서 필수 Python 라이브러리는 torch입니다. . 이미 설치했는지 확인하십시오.

  • PyTorch 텐서를 만들고 인쇄합니다.

  • torch.eq(input1, input2) 계산 . "True"의 텐서를 반환합니다. 및/또는 "거짓" . 텐서를 요소별로 비교하고 해당 요소가 같으면 True를 반환하고 그렇지 않으면 False를 반환합니다.

  • 반환된 텐서를 인쇄합니다.

예시 1

다음 Python 프로그램은 두 개의 1차원 텐서 요소를 비교하는 방법을 보여줍니다.

# import necessary library
import torch

# Create two tensors
T1 = torch.Tensor([2.4,5.4,-3.44,-5.43,43.5])
T2 = torch.Tensor([2.4,5.5,-3.44,-5.43, 43])

# print above created tensors
print("T1:", T1)
print("T2:", T2)

# Compare tensors T1 and T2 element-wise
print(torch.eq(T1, T2))

출력

T1: tensor([ 2.4000, 5.4000, -3.4400, -5.4300, 43.5000])
T2: tensor([ 2.4000, 5.5000, -3.4400, -5.4300, 43.0000])
tensor([ True, False, True, True, False])

예시 2

다음 Python 프로그램은 두 개의 2차원 텐서 요소를 비교하는 방법을 보여줍니다.

# import necessary library
import torch

# create two 4x3 2D tensors
T1 = torch.Tensor([[2,3,-32],
                  [43,4,-53],
                  [4,37,-4],
                  [3,75,34]])
T2 = torch.Tensor([[2,3,-32],
                  [4,4,-53],
                  [4,37,4],
                  [3,-75,34]])

# print above created tensors
print("T1:", T1)
print("T2:", T2)

# Conpare tensors T1 and T2 element-wise
print(torch.eq(T1, T2))

출력

T1: tensor([[ 2., 3., -32.],
            [ 43., 4., -53.],
            [ 4., 37., -4.],
            [ 3., 75., 34.]])
T2: tensor([[ 2., 3., -32.],
            [ 4., 4., -53.],
            [ 4., 37., 4.],
            [ 3., -75., 34.]])
tensor([[ True, True, True],
         [False, True, True],
         [ True, True, False],
         [ True, False, True]])

예시 3

다음 Python 프로그램은 1차원 텐서를 2차원 텐서와 요소별로 비교하는 방법을 보여줍니다.

# import necessary library
import torch

# Create two tensors
T1 = torch.Tensor([2.4,5.4,-3.44,-5.43,43.5])
T2 = torch.Tensor([[2.4,5.5,-3.44,-5.43, 7],
                  [1.0,5.4,3.88,4.0,5.78]])

# Print above created tensors
print("T1:", T1)
print("T2:", T2)

# Compare the tensors T1 and T2 element-wise
print(torch.eq(T1, T2))

출력

T1: tensor([ 2.4000, 5.4000, -3.4400, -5.4300, 43.5000])
T2: tensor([[ 2.4000, 5.5000, -3.4400, -5.4300, 7.0000],
            [ 1.0000, 5.4000, 3.8800, 4.0000, 5.7800]])
tensor([[ True, False, True, True, False],
         [False, True, False, False, False]])