AI/Pytorch

Pytorch - Indexing(1)

hundredeuk2 2022. 9. 27. 16:46
728x90
  • Torch.index_select indices 지정하여 행렬에 차원을 기준으로 원하는 값을 뽑을 있다.
  • 이것을 응용하여 행렬의 대각선에 잇는 원소들이나, 세로, 가로 축의 원소들만 뽑아올 있다.
  • Ex) [[1,2],[3,4]] 에서 1과 3 추출하고 싶으면 아래와 같다.
import torch
	
Matrix = torch.Tensor([[1, 2],
	                  [3, 4]])
	
indices = torch.tensor([0])
A = torch.index_select(A, 1, indices)
output = A.view(1,2)
	
print(output)

Out : tensor([[2., 4.]])
  • 코드의 이해를 위해 사진을 참고하면 axis 1 기준으로 0번째 원소만을 뽑는다.

 

  • 따라서 2 4 뽑기 위해선
indices = torch.tensor([1])
A = torch.index_select(A, 1, indices)

마지막으로 [1,2]  뽑기 위해선

indices = torch.tensor([0])
A = torch.index_select(A, 0, indices)

한번 실습으로 [3,4] 를 뽑아보는 연습을 해보는 것이 좋다.