I do not understand how to evaluate this expression:
x.view(*(x.shape[:-2]), -1).mean(-1)
if x.shape == (N, C, H, W).
What does the asterisk * stand for? And what is mean(-1)?
What is *?
For .view(), PyTorch expects the new shape as individual int arguments (shown in the docs as *shape). In Python, an asterisk (*) in a function call unpacks a sequence such as a list or tuple into separate positional arguments, so view() receives its arguments in the form it expects.
So, in your case, x.shape is (N, C, H, W). If you passed x.shape[:-2] without the asterisk, you would get x.view((N, C), -1), which is not what view() expects. Unpacking (N, C) with the asterisk results in the call view(N, C, -1), which is what it expects. The resulting shape is (N, C, H*W): a 3D tensor instead of a 4D one.
What is mean(-1)?
Look at the documentation of .mean(): the first argument is dim, so x.mean(-1) takes the mean along the last dimension. In your case, since keepdim=False by default, the output is a tensor of shape (N, C) in which each element is the mean over both spatial dimensions (H and W).
The whole expression is equivalent to applying the mean over the last dimension twice to the original 4D x:
x.mean(-1).mean(-1)
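If you want to convince yourself of the equivalence, a quick check along these lines should do. The tensor sizes and the tolerance are assumptions chosen for the example. It also compares against x.mean(dim=(-2, -1)), which reduces both spatial dimensions in one call.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)                 # illustrative (N, C, H, W)

a = x.view(*x.shape[:-2], -1).mean(-1)      # flatten H and W, then average
b = x.mean(-1).mean(-1)                     # average over W, then over H
c = x.mean(dim=(-2, -1))                    # average over H and W at once

print(a.shape)                              # torch.Size([2, 3])

# Compare up to floating-point rounding
same_ab = torch.allclose(a, b, atol=1e-6)
same_ac = torch.allclose(a, c, atol=1e-6)
print(same_ab, same_ac)
```

The two-step mean matches the flattened mean because every row being averaged has the same number of elements (W), so the mean of the row means equals the overall mean.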