Torch.unsqueeze and squeeze

torch.unsqueeze
와 torch.squeeze
는 가끔 헷갈리는 개념이다. 머리로는 이해하는 것 같지만 막상 쓰려고 할때 팍 이해가 안가는 느낌이 많이 든다. 그래서 오늘 조금 정리해보려고 한다.
Tensor 와 차원(Dimension)
Pytorch 에서 Tensor 는 숫자값을 담는 다차원 배열이라고 생각하면 편하다. N 차원 텐서들을 아래와 같은 형태로 정의할 수 있다.
zero_dim = torch.tensor(5)
one_dim = torch.tensor([1, 2, 3])
two_dim = torch.tensor([[1,2], [3,4]])
위와 같은 방법으로 N 차원을 정의 가능하고, Tensor 의 각차원에 몇개의 원소가 있는지 확인해보기 위해서는 shape
라는 속성을 확인해보면 된다.
zero_dim.shape => torch.Size([])
one_dim.shape => torch.Size([3])
two_dim.shape => torch.Size([2, 2])
- 우리가 만든 이차원 텐서
two_dim
기준으로 설명해보면 첫번째 차원(축 0) 에는 2개의 원소가 있고, 두번째 차원(축 1) 에도 2개의 원소가 있음을 뜻한다.
- 우리가 만든 이차원 텐서
Unsqueeze
unsqueeze
는 이름만 봐도 살짝 유추가능하듯이 텐서에 크기가 1인 새로운 차원을 추가하는 함수이다. dim
이라는 parameter 를 추가적으로 넣을 수 있는데, 어떤 인덱스에 크기가 1인 새로운 차원을 넣을지 우리가 정할 수 있는 인덱스 이다.
>>> x = torch.tensor([1, 2, 3, 4])
>>> x.shape
torch.Size([4])
>>> u_t = torch.unsqueeze(x, 0)
>>> u_t
tensor([[ 1, 2, 3, 4]])
>>> u_t.shape
>>> torch.Size([1, 4])
위의 예시를 보면 가장 바깥쪽(축 0) 에 차원을 추가하게 되면 1개의 행과 4개의 열을 가진 이차원 텐서로 변하는 것을 확인할 수 있다. 즉, 원래는 4개의 열을 가진 텐서에서 하나의 행과 4개의 열을 가진 텐서로 unsqueeze 되었다.
일차원으로만 하면 이해가 어렴풋이 잘 안갈수도 있으니 이차원으로 다시 시도해보자.
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> y.shape
torch.Size([2, 2])
>>> y_unsqueezed_0 = y.unsqueeze(dim=0)
>>> y_unsqueezed_0.shape
torch.Size([1, 2, 2])
>>> y_unsqueezed_1 = y.unsqueeze(dim=1)
>>> y_unsqueezed_1.shape
torch.Size([2, 1, 2])
>>> y_unsqueezed_2 = y.unsqueeze(dim=2)
>>> y_unsqueezed_2.shape
torch.Size([2, 2, 1])
위의 예시를 보면 어느 차원에 차원을 추가하게 될지에 따라서 shape
속성이 변하는 것을 확인할 수 있다. 0 차원에 추가했을때는 [[[1, 2], [3, 4]]]
가 될테고, 1 차원에 추가했을때는 [[[1,2]], [[3,4]]]
와 같은 형태가 되고, 2차원에 추가했을때에는 [[[1], [2]], [[3],[4]]]
와 같이 된다.
when to use?
언제 사용하는지는 대부분 알수도 있지만, 보통 우리가 만든 텐서를 배치처리를 하는 모델의 입력값으로 넣어야 할때 입력값으로 [배치크기, 채널, 높이, 너비]
와 같이 보통 제일 바깥쪽에 N(배치 차원)을 추가해야 하므로, t.unsqueeze(0)
을 하는 경우가 많다.
Squeeze
그렇다면 torch.squeeze
는 무엇일까? squeeze 는 차원을 줄이는 역할을 한다. squeeze 역시도 dim
이라는 파라미터를 가지고 있는데 해당 차원을 줄이는 역할을 한다. 하지만 공식문서에 적힌 설명을 보면 크기가 1인 차원에 대해서만 차원을 줄인다고 한다. 공식문서에 적힌 예를 보면 (A×1×B)
차원의 텐서에 squeeze(input, 0)
을 하면 텐서가 변하지 않고, squeeze(input, 1)
를 하는 경우에는 크기가 1인 차원이 없어져 (A×B)
차원의 텐서로 변한다고 한다.
사실 여기는 unsqueeze 를 이해했다면 이해가 잘될거라고 생각하고 별 다르게 더 짚고 넘어가지는 않겠다.
Subscribe to my newsletter
Read articles from roach directly inside your inbox. Subscribe to the newsletter, and don't miss out.
Written by

roach
roach
https://www.linkedin.com/feed/update/urn:li:activity:7092144087058825216/