AI/Pytorch
Pytorch - Indexing(1)
hundredeuk2
2022. 9. 27. 16:46
728x90
- Torch.index_select로 indices를 지정하여 행렬에 차원을 기준으로 원하는 값을 뽑을 수 있다.
- 이것을 응용하여 행렬의 대각선에 잇는 원소들이나, 세로, 가로 축의 원소들만 뽑아올 수 있다.
- Ex) [[1,2],[3,4]] 에서 1과 3을 추출하고 싶으면 아래와 같다.
import torch
Matrix = torch.Tensor([[1, 2],
[3, 4]])
indices = torch.tensor([0])
A = torch.index_select(A, 1, indices)
output = A.view(1,2)
print(output)
Out : tensor([[2., 4.]])
- 코드의 이해를 위해 사진을 참고하면 axis 1을 기준으로 0번째 원소만을 뽑는다.

- 따라서 2와 4를 뽑기 위해선
indices = torch.tensor([1])
A = torch.index_select(A, 1, indices)
마지막으로 [1,2] 를 뽑기 위해선
indices = torch.tensor([0])
A = torch.index_select(A, 0, indices)
한번 실습으로 [3,4] 를 뽑아보는 연습을 해보는 것이 좋다.