728x90

대각행렬 추출하기 (gather와 expand를 활용하여)

개인적으로 gather를 이해할 때 조금 CNN 컨볼루젼 커널을 생각하면 코딩을 하면 편하다고 생각한다.

! 본 글은 axis=1의 관점으로만 해석했기 때문에 gather를 이해하기에 최적화된 방법은 아닙니다.

  • 2D tensor 대각선 요소

앞서 봤던 axis개념으로 접근을 하고 gather를 이해하려면 세로를 차례대로 가로의 indecies의 요소들의 위치를 뽑는다.

따라서 예시에서 4와 -2를 뽑기 위해선 axis=1을 기준으로는 순차적으로 index = 0, index =1 번을 뽑아야 구현된다.

그리고 gather라는 매서드는 indecies를 tensor로 받아야하기 때문에 다음과 같은 코드로 구현할 수 있다.

import torch

A = torch.Tensor([[4, 5],[1, -2]])
output = A
output = torch.gather(output, 1, torch.tensor([[0],[1]]))
print(output)

Out으로 [[4],[-2]]가 나온다면 정상이다. 이의 형태를 바꾸기 위해선 view나 reshape 양식을 사용하면 된다.

그렇다면 3차원의 tensor일 때 어떻게 해야할까.

  • 3D tensor 대각선 요소

이번에도 axis=1을 기준으로 확인한다면 3x3이라면 indecies를 [0,1,2]로 지정해줘야한다.

근데 문제는 3D일때는 채널이라는 것이 존재하는데, axis=2라는 관점이 존재하기 때문이다. 3x3의 행렬이 여러개 겹쳐 있을 수 있다는 뜻이다. gather를 컨볼루션 연산의 커널로 연상하듯 채널도 이미지 데이터로 이해하면 편하다.

이미지 픽셀 데이터가 가로x세로로 존재하며 RGB값으로 3개의 데이터 값이 합쳐져야 이미지 하나의 데이터로 표현된다.

 

따라서 [0,1,2]의 벡터를 채널의 갯수만큼 늘려줘야하는데, 이 때 쓰는 매서드가 expand이다.

(expand는 추후에 제대로 설명을 다시 하겠다. 간단히만 설명하면 원하는 차원 크기로 텐서를 반복하여 생성한다.)

여기선 어떤 데이터에 따라 유동적으로 변하는 Function을 만들기 위해 axis = 0,1의 shape에 따라 [0], [0,1], [0,1,2]를 유동적으로 생성한 후, 채널의 갯수에 맞춰 유동적으로 expand를 사용하여 변환시켜준다.

import torch

def get_diag_element_3D(A):
    C, H, W = A.size()
    if H > 2 :
        H = 3
    tmp = [[[x for x in range(H)]]]
    indecies = torch.tensor(tmp).expand(C,-1,H)    
    output = A
    output = torch.gather(output, 1, indecies)
    output = output.squeeze()
    return output

'AI > Pytorch' 카테고리의 다른 글

Pytorch - Ones와 Zeros  (0) 2022.09.29
Pytorch - Indexing(1)  (0) 2022.09.27

+ Recent posts