본문 바로가기
카테고리 없음

[pytorch] squeeze사용법: 텐서의 차원 제거하기

by 최신 마트 정보 2024. 9. 23.

squeeze함수를 이용하면 텐서의 차원중에서 차원이 1인 차원의 축을 없애준다. squeeze함수의 변수에 아무것도 입력해주지 않으면 차원이 1인 차원을 모두 제거해준다. 이점이 주의할점이다. 이번 글에서는 squeeze함수 실제 사용코드를 살펴보도록 하자. 가자.!

squeeze 사용 python 코드

아래와 같이 tensor를 정의하고 squeeze를 이용하여 텐서의 차원중 1인 모든 차원을 제거한다.

import torch
x = torch.rand(16, 1, 1,22, 256)
x = x.squeeze() #[16,1,1,222,256] -> [16,22,256]

 

 

주의할사항이 있다. squeeze의 변수에 아무것도 입력하지 않으면 모든 1인 차원을 없애준다. 만약에 batch size가 1이라면 batch size에 해당하는 차원도 없애준다. 따라서 squeeze의 변수에 dim을 잘 입력해줘야 한다.

import torch

#배치사이즈가 1인 텐서에 대하여
x = torch.rand(1, 1, 22, 256)
x = x.squeeze() # [1, 1, 22, 256] -> [22, 256] #모든 1인 차원 없앰

#배치사이즈가 1인 텐서에 대하여
x = torch.rand(1, 1, 25, 256)
x = x2.squeeze(dim=1) # [1, 1, 22, 256] -> [1, 22, 256]
#squeeze에 dim 변수를 넣어서 특정 차원만 없애준다.

 

댓글