
I am following this PyTorch tutorial, and there is a line of code I don't understand in the training_step method of MnistModel, a class derived from nn.Module.

The line is out = self(images)

Can someone please explain what is happening here? Is this correct, and is it a convention to follow?

Thanks

Here's the snippet

```python
import torch.nn as nn
import torch.nn.functional as F

input_size = 28 * 28  # MNIST images are 28x28 pixels, flattened to 784 values
num_classes = 10      # digits 0-9

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        print(type(out))
        return loss
```

Answer


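Inside training_step, self refers to the MnistModel instance the method was called on, exactly as in any other method the class defines. The only unusual thing is that self is being called like a function. That works because nn.Module defines __call__, so every instance of MnistModel (a subclass of nn.Module) is itself callable. nn.Module.__call__ runs any registered hooks and then calls your forward method, so self(images) ends up computing self.forward(images). So yes, this is correct, and it is the standard PyTorch convention: call the module rather than forward directly, because calling forward directly silently skips the hooks.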
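A minimal pure-Python sketch of the same mechanism, assuming nothing from PyTorch: the classes TinyModule and Doubler are hypothetical stand-ins for nn.Module and MnistModel. Defining __call__ on the base class is what makes self(x) legal inside a subclass method.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: calling an instance delegates to forward()."""

    def __call__(self, *args, **kwargs):
        # nn.Module also runs hooks here; this sketch only delegates.
        return self.forward(*args, **kwargs)

    def forward(self, x):
        raise NotImplementedError


class Doubler(TinyModule):
    """Hypothetical subclass playing the role of MnistModel."""

    def forward(self, x):
        return 2 * x

    def step(self, x):
        # Same pattern as `out = self(images)`: call the instance itself.
        return self(x) + 1


d = Doubler()
print(d(3))       # 6
print(d.step(3))  # 7
```

Because __call__ is looked up on the class, every subclass inherits it, so any instance (including self inside a method) can be called like a function.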

out = self(images) is equivalent to out = self.__call__(images).
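To check this in PyTorch itself, the sketch below uses a plain nn.Linear(784, 10) (sizes chosen to match the MNIST snippet; any module would do) with a forward hook attached through register_forward_hook. The names record and calls are hypothetical. The hook fires for model(x) and model.__call__(x), but not for model.forward(x), even though all three return the same output.

```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)
calls = []  # hypothetical list that records each time the hook fires

def record(module, inputs, output):
    # Forward hooks are run by nn.Module.__call__ after forward() returns.
    calls.append(type(module).__name__)

handle = model.register_forward_hook(record)

x = torch.randn(4, 784)
out_call = model(x)             # nn.Module.__call__: runs forward() and the hook
out_dunder = model.__call__(x)  # the same thing, spelled out explicitly
out_forward = model.forward(x)  # bypasses __call__, so the hook does not fire

print(torch.allclose(out_call, out_dunder), torch.allclose(out_call, out_forward))  # True True
print(calls)  # ['Linear', 'Linear']
handle.remove()
```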
