I understand that flattening removes all of the dimensions except for one. For example, I understand `torch.flatten()`:
```
> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
> torch.flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
```

However, I don't understand `nn.Flatten`. In particular, I don't understand this snippet from the docs:
```
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])
```

I expected the output to have size `[160]`, because 32 * 5 = 160.
Q1. So why is the output size `[32, 288]`?
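A minimal sketch that traces the doc snippet one layer at a time may help pin down where 288 comes from. It relies only on the standard PyTorch signature `nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)` and on `nn.Flatten()` defaulting to `start_dim=1, end_dim=-1`; the variable names are illustrative.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 1, 5, 5)       # (N=32 samples, C_in=1, H=5, W=5)
conv = nn.Conv2d(1, 32, 5, 1, 1)   # in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=1
flatten = nn.Flatten()             # defaults: start_dim=1, end_dim=-1, so dim 0 (batch) is kept

y = conv(x)
print(y.shape)                     # torch.Size([32, 32, 3, 3])

# Spatial output size from the Conv2d formula with dilation=1:
# floor((H_in + 2*padding - kernel_size) / stride) + 1 = (5 + 2 - 5) // 1 + 1 = 3
h_out = (5 + 2 * 1 - 5) // 1 + 1
print(h_out)                       # 3

z = flatten(y)
print(z.shape)                     # torch.Size([32, 288]), because 32 channels * 3 * 3 = 288
```

Under this reading, the first 32 is the batch size, which `nn.Flatten()` leaves alone, and 288 is the 32 output channels times the 3 x 3 spatial map produced by a 5 x 5 kernel on a 5 x 5 input padded by 1.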
Q2. I also don't understand the shape information given in the docs:
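The shape rule for `nn.Flatten` can be read as: multiply together the sizes of dims `start_dim` through `end_dim` (inclusive) into one dim, and leave every other dim unchanged. The sketch below encodes that rule in a hypothetical helper, `flatten_shape` (not part of PyTorch), and checks it against the real module; it assumes the dims passed in are in range.

```python
import math
import torch
import torch.nn as nn

def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: predict nn.Flatten's output shape (assumes valid dims)."""
    ndim = len(shape)
    # Negative dims count from the end, as in PyTorch (e.g. -1 is the last dim).
    start = start_dim % ndim
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])  # product of the flattened dims
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

print(flatten_shape((32, 32, 3, 3)))                                 # (32, 288)
print(tuple(nn.Flatten()(torch.zeros(32, 32, 3, 3)).shape))          # (32, 288)

print(flatten_shape((2, 3, 4, 5), start_dim=1, end_dim=2))           # (2, 12, 5)
print(tuple(nn.Flatten(1, 2)(torch.zeros(2, 3, 4, 5)).shape))        # (2, 12, 5)
```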
Q3. I also don't understand the meaning of the parameters:
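A short sketch of how the two parameters change the result, using a zero tensor with the same shape as the conv output in the doc snippet. It also contrasts `nn.Flatten` (default `start_dim=1`) with `torch.flatten` (default `start_dim=0`), which may explain why the two behave differently in the examples above.

```python
import torch
import torch.nn as nn

y = torch.zeros(32, 32, 3, 3)  # same shape as the Conv2d output in the doc snippet

default = nn.Flatten()(y)                    # start_dim=1, end_dim=-1: keep batch dim
all_dims = nn.Flatten(start_dim=0)(y)        # merge every dim
spatial = nn.Flatten(start_dim=2)(y)         # merge only H and W
first_two = nn.Flatten(0, 1)(y)              # merge only batch and channel dims
fn_default = torch.flatten(y)                # torch.flatten defaults to start_dim=0
fn_start1 = torch.flatten(y, start_dim=1)    # matches nn.Flatten()

print(default.shape)     # torch.Size([32, 288])
print(all_dims.shape)    # torch.Size([9216])
print(spatial.shape)     # torch.Size([32, 32, 9])
print(first_two.shape)   # torch.Size([1024, 3, 3])
print(fn_default.shape)  # torch.Size([9216])
print(fn_start1.shape)   # torch.Size([32, 288])
```

Keeping dim 0 by default is a reasonable fit for `nn.Flatten` as a layer inside `nn.Sequential`, since dim 0 there is usually the batch.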

