10

This is potentially a very easy question. I just started with PyTorch Lightning and can't figure out how to get the output of my model after training.

I am interested in the predictions for both y_train and y_test as an array of some sort (a PyTorch tensor, or a NumPy array in a later step) so I can plot them next to the labels using different scripts.

    dataset = Dataset(train_tensor)
    val_dataset = Dataset(val_tensor)
    training_generator = torch.utils.data.DataLoader(dataset, **train_params)
    val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)

    mynet = Net(feature_len)

    trainer = pl.Trainer(gpus=0, max_epochs=max_epochs, logger=logger,
                         progress_bar_refresh_rate=20,
                         callbacks=[early_stop_callback],
                         num_sanity_val_steps=0)
    trainer.fit(mynet)

In my lightning module I have the functions:

    def __init__(self, random_inputs):
    def forward(self, x):
    def train_dataloader(self):
    def val_dataloader(self):
    def training_step(self, batch, batch_nb):
    def training_epoch_end(self, outputs):
    def validation_step(self, batch, batch_nb):
    def validation_epoch_end(self, outputs):
    def configure_optimizers(self):

Do I need a specific predict function or is there any already implemented way I don't see?

4 Answers

15

I disagree with these answers: the OP's question appears to be about how to use a model trained in Lightning to get predictions in general, rather than at a specific step of the training pipeline. In that case, a user shouldn't need to go anywhere near a Trainer object; Trainers are not intended for general prediction, and the answers above are therefore encouraging an anti-pattern (carrying a Trainer object around every time we want to do some prediction) for anyone who reads them in the future.

Instead of using the trainer, we can get predictions straight from the Lightning module that has been defined: if I have my (trained) instance of the Lightning module model = Net(...), then using that model to get predictions on inputs x is achieved simply by calling model(x) (as long as the forward method has been implemented/overridden on the Lightning module, which is required).

In contrast, Trainer.predict() is not the intended means of obtaining predictions from your trained model in general. The Trainer API provides methods to tune, fit and test your LightningModule as part of your training pipeline, and it looks to me like the predict method is provided for ad-hoc predictions on separate dataloaders as part of less 'standard' training steps.

The OP's question (Do I need a specific predict function or is there any already implemented way I don't see?) implies that they're not familiar with how the forward() method works in PyTorch, but asks whether there's already a method for prediction that they can't see. A full answer therefore requires a further explanation of where the forward() method fits into the prediction process:

model(x) works because Lightning modules are subclasses of torch.nn.Module, which implements the magic method __call__(); that means we can call a class instance as if it were a function. __call__() in turn calls forward(), which is why we need to override that method in our Lightning module.

NB: because forward() is only one piece of the logic run when we use model(x), it is always recommended to use model(x) instead of model.forward(x) for prediction unless you have a specific reason to deviate.
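To make the mechanism concrete, here is a minimal sketch (the layer, feature size, and input shape are illustrative assumptions, not taken from the question's code); the same call works on the question's mynet once trainer.fit(mynet) has finished:

    import torch
    import pytorch_lightning as pl

    class Net(pl.LightningModule):
        def __init__(self, feature_len):
            super().__init__()
            self.layer = torch.nn.Linear(feature_len, 1)

        def forward(self, x):
            # forward() defines what a "prediction" means for this model
            return self.layer(x)

    model = Net(feature_len=8)      # in practice: your trained instance
    x = torch.randn(4, 8)           # a batch of 4 inputs
    y_hat = model(x)                # __call__() runs hooks, then forward()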


2 Comments

It's good that you pointed out that the network can be run directly, since starting with PyTorch Lightning without ever having used plain PyTorch hides the underlying mechanisms. I would argue that it's still reasonable in some situations to use the Trainer class even for prediction, as it handles putting your model and data onto the GPU and it can call certain hooks; why reinvent the wheel? It's not an anti-pattern: rename the class to Commander and much of your argument is invalid. I still think it's good you pointed it out, but anti-pattern is too strong.
I think advice on how to get predictions from the model needs to include how to run it on a GPU, model.eval(), turning off gradients, and all the other things that Lightning has done for the user so far. Simply calling model(x) is unlikely to do what the user wants.
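As a rough sketch of what that means in practice (assuming the mynet and val_generator objects from the question, and that each batch yields the input tensor first; both assumptions may need adjusting for your Dataset), a manual prediction loop might look like this:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mynet = mynet.to(device)
    mynet.eval()                    # switch off dropout / batch-norm updates

    preds = []
    with torch.no_grad():           # no gradient tracking during inference
        for batch in val_generator:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            preds.append(mynet(x.to(device)).cpu())

    y_val_pred = torch.cat(preds)   # one tensor with all validation predictions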
7

You can try prediction in two ways:

  1. Perform batched prediction as per normal.
    test_dataset = Dataset(test_tensor)
    test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)

    mynet.eval()
    batch = next(iter(test_generator))
    with torch.no_grad():
        # unpack the batch into the model's inputs as your Dataset requires
        predictions_single_batch = mynet(**unpacked_batch)
  2. Instantiate a new Trainer object. Trainer's predict API allows you to pass an arbitrary DataLoader.
    test_dataset = Dataset(test_tensor)
    test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)

    predictor = pl.Trainer(gpus=1)
    predictions_all_batches = predictor.predict(mynet, dataloaders=test_generator)

I've noticed that in the second case, PyTorch Lightning takes care of things like moving your tensors and model onto (not off of) the GPU, in line with its ability to perform distributed prediction. It also doesn't return any gradient-attached loss values, which removes the need to write boilerplate code like with torch.no_grad().
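If, as in the question, you want those predictions as a single array, note that predict returns a list with one entry per batch. Assuming each entry is a plain tensor (which depends on what your forward/predict_step returns), they can be concatenated, e.g.:

    import torch

    # predictions_all_batches is the list returned by predictor.predict()
    y_test_pred = torch.cat(predictions_all_batches)
    y_test_pred_np = y_test_pred.cpu().numpy()   # NumPy array for plotting elsewhere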

2 Comments

An important point about this answer is that in some circumstances you do need to create a new Trainer for testing/prediction. The documentation for predict explains that accelerators which spawn new processes won't return predictions (so they won't sync if you want to gather them later), e.g. under DDP. So you can train under DDP but cannot do this kind of inference under DDP, as it's not supported.
I haven't tested this, but my understanding of the statement "True by default except when an accelerator that spawns processes is used (not supported)" is that return_predictions is not supported if we set up the Trainer with ddp_spawn instead of ddp. There might be some complication or bottleneck with mp.spawn(). I do agree that setting up a predictor with Trainer is semantically quite confusing.
4

You can use the predict method as well. Here is the example from the documentation: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

    class LitMNISTDreamer(LightningModule):

        def forward(self, z):
            imgs = self.decoder(z)
            return imgs

        def predict_step(self, batch, batch_idx: int, dataloader_idx: int = None):
            return self(batch)

    model = LitMNISTDreamer()
    trainer.predict(model, datamodule)   # assumes an existing pl.Trainer instance named `trainer`

5 Comments

The predict method seems to have been added in the meantime. I was just baffled it wasn't available before.
Yeah they seem crazy good at adding new stuff
What's the difference between using trainer.predict() and using model()? Does the first option automatically wrap the call inside eval mode and no_grad?
The trainer puts your model and input on the graphics card, limits the number of batches (if set, see trainer __init__ args), performs distributed computation and so on.
Is there any way to run "predict" as an iterator? I don't really want to load all of my data into memory.
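One possible workaround, sketched under some assumptions (the Net class, feature_len, mynet and test_generator from the question and the other answers, and a batch layout that the model's forward can consume directly), is to let predict_step write out each batch's output itself and pass return_predictions=False so nothing is accumulated in memory:

    import torch
    import pytorch_lightning as pl

    class StreamingNet(Net):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            preds = self(batch)
            # write each batch out instead of returning/accumulating it
            torch.save(preds, f"preds_batch_{batch_idx}.pt")

    streaming_net = StreamingNet(feature_len)
    streaming_net.load_state_dict(mynet.state_dict())   # reuse the trained weights

    predictor = pl.Trainer(gpus=0)
    predictor.predict(streaming_net, dataloaders=test_generator,
                      return_predictions=False)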
0

The trainer has a test function. You might want to have a look at the pytorch-lightning documentation for more details: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#testing.
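As a hedged sketch of how that could be wired up with the question's objects (it assumes mynet also implements a test_step() and that a DataLoader named test_generator exists, neither of which is shown in the question; note that test reports logged metrics rather than returning prediction tensors, as a comment below points out):

    # assumes `trainer`, `mynet` and `test_generator` already exist
    results = trainer.test(mynet, dataloaders=test_generator)
    print(results)   # a list of dicts with the metrics logged during testing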

3 Comments

Awesome, how did I not find this myself? Most likely because of all the errors I got. But I sorted everything out.
Seems to have a predict function now: github.com/PyTorchLightning/pytorch-lightning/issues/1853
I don't believe .test allows you to return a tensor (its purpose is largely to collect logs via the logging API, which doesn't currently accept lists or torch/NumPy arrays). So .predict() appears to be the way forward.
