Suppose I have two tensors S and T defined as:
import torch
S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))
We can think of each of these as a batch of tensors of shape (2, 1); here the batch size is 3.
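A quick sketch of this batch view, assuming a standard PyTorch install: indexing along the first dimension yields one (2, 1) tensor per batch element, and the first dimension's size is the batch size.

```python
import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

# The first dimension indexes the batch; each slice S[i] is a (2, 1) tensor.
batch_size = S.shape[0]
print(batch_size)   # 3
print(S[0].shape)   # torch.Size([2, 1])
```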
I want to concatenate every possible pairing of a batch element of S with a batch element of T. Concatenating a single pair along the first dimension produces a tensor of shape (4, 1), and there are 3 * 3 = 9 pairings, so the resulting tensor C must have shape (3, 3, 4, 1), with C[i, j] equal to torch.cat((S[i], T[j])).
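The shape arithmetic can be checked with a small sketch (the names pair and expected_shape are illustrative only): one torch.cat of two (2, 1) slices along dim 0 gives (4, 1), and a 3 x 3 grid of such results gives (3, 3, 4, 1).

```python
import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

# Concatenating one batch element of S with one of T along dim 0
# stacks the two (2, 1) tensors into a single (4, 1) tensor.
pair = torch.cat((S[0], T[1]), dim=0)
print(pair.shape)  # torch.Size([4, 1])

# One such (4, 1) tensor per (i, j) pairing gives a 3 x 3 grid of them.
expected_shape = (S.shape[0], T.shape[0]) + tuple(pair.shape)
print(expected_shape)  # (3, 3, 4, 1)
```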
One solution is to do the following:
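# Preallocate the output: one (4, 1) slot for every (i, j) pairing.
C = torch.empty((S.shape[0], T.shape[0], 4, 1))
for i in range(S.shape[0]):
    for j in range(T.shape[0]):
        # Stack batch i of S on top of batch j of T along dim 0.
        C[i, j, :, :] = torch.cat((S[i, :, :], T[j, :, :]))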
But the nested loop runs S.shape[0] * T.shape[0] Python-level iterations, so it does not scale well to large batch sizes. Is there a vectorized PyTorch operation that does this?
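One possible vectorized approach, offered as a sketch rather than a definitive answer: insert a singleton axis with unsqueeze, broadcast both tensors to a common (n, m, 2, 1) shape with Tensor.expand, then call torch.cat once along dim 2. The helper name pairwise_cat is hypothetical; the loop is included only to check that both methods agree.

```python
import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

def pairwise_cat(S, T):
    # Hypothetical helper. S varies along the first axis and T along the
    # second, so entry (i, j) pairs S[i] with T[j].
    n, m = S.shape[0], T.shape[0]
    S_exp = S.unsqueeze(1).expand(n, m, *S.shape[1:])  # (n, m, 2, 1), a view
    T_exp = T.unsqueeze(0).expand(n, m, *T.shape[1:])  # (n, m, 2, 1), a view
    # Concatenate along the per-pair first dimension (dim 2 of the 4-D tensors).
    return torch.cat((S_exp, T_exp), dim=2)            # (n, m, 4, 1)

C_vec = pairwise_cat(S, T)

# Reference: the nested loop from the question.
C_loop = torch.empty((3, 3, 4, 1))
for i in range(S.shape[0]):
    for j in range(T.shape[0]):
        C_loop[i, j] = torch.cat((S[i], T[j]))

print(C_vec.shape)                 # torch.Size([3, 3, 4, 1])
print(torch.equal(C_vec, C_loop))  # True
```

Because expand returns views without copying data, the only new allocation is the output of the single torch.cat call.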