Computer >> 컴퓨터 >  >> 프로그램 작성 >> Python

PyTorch에서 텐서를 짜거나 빼는 방법은 무엇입니까?

<시간/>

텐서를 짜기 위해 torch.squeeze()를 사용합니다. 방법. 입력 텐서의 모든 차원을 가진 새 텐서를 반환하지만 크기 1을 제거합니다. 예를 들어 입력 텐서의 모양이 (M ☓ 1 ☓ N ☓ 1 ☓ P)이면 압축된 텐서는 모양( 남 ☓ 남 ☓ 피).

텐서를 풀기 위해 torch.unsqueeze()를 사용합니다. 방법. 특정 위치에 삽입된 크기 1의 새로운 텐서 차원을 반환합니다.

단계

  • 필요한 라이브러리를 가져옵니다. 다음 모든 Python 예제에서 필수 Python 라이브러리는 torch입니다. . 이미 설치했는지 확인하십시오.

  • 텐서를 만들고 인쇄하세요.

  • torch.squeeze(입력) 계산 . 크기 1을 압축(제거)하고 입력의 다른 모든 차원과 함께 텐서를 반환합니다. 텐서.

  • torch.unsqueeze(입력, 희미함) 계산 . 주어진 dim에 크기 1의 새로운 차원을 삽입하고 텐서를 반환합니다.

  • 압착 및/또는 압착되지 않은 텐서를 인쇄합니다.

예시 1

# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch

# Create a tensor of all one
T = torch.ones(2,1,2) # size 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# Squeeze the dimension of the tensor
squeezed_T = torch.squeeze(T) # now size 2x2
print("Squeezed_T\n:", squeezed_T )
print("Size of Squeezed_T:", squeezed_T.size())

출력

Original Tensor T:
tensor([[[1., 1.]],
         [[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T
: tensor([[1., 1.],
         [1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])

예시 2

# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch

# create a tensor
T = torch.Tensor([1,2,3]) # size 3
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# Squeeze the tensor in dimension o or column dim
unsqueezed_T = torch.unsqueeze(T, dim = 0) # now size 1x3
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of UnSqueezed T:", unsqueezed_T.size())

# Squeeze the tensor in dimension 1 or row dim
unsqueezed_T = torch.unsqueeze(T, dim = 1) # now size 3x1
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of Unsqueezed T:", unsqueezed_T.size())

출력

Original Tensor T:
   tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T
: tensor([[1., 2., 3.]])
Size of UnSqueezed T: torch.Size([1, 3])
Unsqueezed T
: tensor([[1.],
         [2.],
         [3.]])
Size of Unsqueezed T: torch.Size([3, 1])