대각행렬 추출하기 (gather와 expand를 활용하여)
개인적으로 gather를 이해할 때 조금 CNN 컨볼루젼 커널을 생각하면 코딩을 하면 편하다고 생각한다.
! 본 글은 axis=1의 관점으로만 해석했기 때문에 gather를 이해하기에 최적화된 방법은 아닙니다.
- 2D tensor 대각선 요소
앞서 봤던 axis개념으로 접근을 하고 gather를 이해하려면 세로를 차례대로 가로의 indecies의 요소들의 위치를 뽑는다.
따라서 예시에서 4와 -2를 뽑기 위해선 axis=1을 기준으로는 순차적으로 index = 0, index =1 번을 뽑아야 구현된다.
그리고 gather라는 매서드는 indecies를 tensor로 받아야하기 때문에 다음과 같은 코드로 구현할 수 있다.
import torch
A = torch.Tensor([[4, 5],[1, -2]])
output = A
output = torch.gather(output, 1, torch.tensor([[0],[1]]))
print(output)
Out으로 [[4],[-2]]가 나온다면 정상이다. 이의 형태를 바꾸기 위해선 view나 reshape 양식을 사용하면 된다.
그렇다면 3차원의 tensor일 때 어떻게 해야할까.
- 3D tensor 대각선 요소
이번에도 axis=1을 기준으로 확인한다면 3x3이라면 indecies를 [0,1,2]로 지정해줘야한다.
근데 문제는 3D일때는 채널이라는 것이 존재하는데, axis=2라는 관점이 존재하기 때문이다. 3x3의 행렬이 여러개 겹쳐 있을 수 있다는 뜻이다. gather를 컨볼루션 연산의 커널로 연상하듯 채널도 이미지 데이터로 이해하면 편하다.
이미지 픽셀 데이터가 가로x세로로 존재하며 RGB값으로 3개의 데이터 값이 합쳐져야 이미지 하나의 데이터로 표현된다.
따라서 [0,1,2]의 벡터를 채널의 갯수만큼 늘려줘야하는데, 이 때 쓰는 매서드가 expand이다.
(expand는 추후에 제대로 설명을 다시 하겠다. 간단히만 설명하면 원하는 차원 크기로 텐서를 반복하여 생성한다.)
여기선 어떤 데이터에 따라 유동적으로 변하는 Function을 만들기 위해 axis = 0,1의 shape에 따라 [0], [0,1], [0,1,2]를 유동적으로 생성한 후, 채널의 갯수에 맞춰 유동적으로 expand를 사용하여 변환시켜준다.
import torch
def get_diag_element_3D(A):
C, H, W = A.size()
if H > 2 :
H = 3
tmp = [[[x for x in range(H)]]]
indecies = torch.tensor(tmp).expand(C,-1,H)
output = A
output = torch.gather(output, 1, indecies)
output = output.squeeze()
return output
'AI > Pytorch' 카테고리의 다른 글
Pytorch - Ones와 Zeros (0) | 2022.09.29 |
---|---|
Pytorch - Indexing(1) (0) | 2022.09.27 |