
I'm studying deep reinforcement learning and built my own example based on PyTorch's Reinforcement Learning (DQN) tutorial.

I implement the actor's strategy as follows: 1. model.eval() 2. get the best action from the model 3. self.net.train()

The question is: does going back and forth between eval() and train() modes cause any damage to the optimization process?

The model includes only Linear and BatchNorm1d layers. As far as I know, when using BatchNorm1d one must call model.eval() before using the model for inference, because the layer produces different results in eval() and train() modes.
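That difference is easy to see directly. A minimal sketch (the layer sizes here are arbitrary, not from my actual model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy net with the same layer types as described above: Linear + BatchNorm1d
net = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

x = torch.randn(8, 4)

net.train()
out_train = net(x)   # normalized with the current batch's statistics

net.eval()
out_eval = net(x)    # normalized with the running statistics

# The same input generally produces different outputs in the two modes
print(torch.allclose(out_train, out_eval))
```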

When training a classification neural network, model.eval() is called only after training is finished, but in deep reinforcement learning it is usual to act with the strategy and then continue the optimization process.

I'm wondering whether going back and forth between the modes is harmless to the optimization process.

```python
def strategy(self, state):
    # Explore or exploit
    if self.epsilon > random():
        action = choice(self.actions)
    else:
        self.net.eval()
        action = self.net(state.unsqueeze(0)).max(1)[1].detach()
        self.net.train()
    return action
```
  • eval mode just changes the behavior of layers like dropout and batch norm. For example, dropout becomes a passthrough layer, and batch norm uses the running statistics to normalize instead of the current batch statistics. Batch norm also doesn't update its running statistics in eval mode. It shouldn't have any negative effect on training. Commented Oct 18, 2019 at 10:24

2 Answers


eval() puts the model in evaluation mode.

  1. In the evaluation mode, the Dropout layer just acts as a "passthrough" layer.

  2. During training, a BatchNorm layer keeps a running estimate of its computed mean and variance. The running estimates are kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization.
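Both behaviors can be verified with a small sketch (the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1. Dropout is a passthrough in eval mode
drop = nn.Dropout(p=0.5)
drop.eval()
y = torch.randn(5)
print(torch.equal(drop(y), y))        # True: input passes through unchanged

# 2. BatchNorm updates its running estimates only in training mode
bn = nn.BatchNorm1d(3)                # default momentum = 0.1
x = torch.randn(32, 3)

bn.train()
bn(x)                                 # one forward pass in training mode
# running_mean = (1 - momentum) * old + momentum * batch_mean; old starts at 0
expected = 0.1 * x.mean(dim=0)
print(torch.allclose(bn.running_mean, expected))   # True

bn.eval()
bn(x)                                 # eval pass: running stats are NOT updated
print(torch.allclose(bn.running_mean, expected))   # True, still unchanged
```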

So, going back and forth between eval() and train() modes does not cause any damage to the optimization process.


1 Comment

So it is good practice to have something like: eval(); with torch.no_grad(): code (e.g. inference); train()? I am using something like that when I want to correctly calculate the training loss and then continue training (e.g. forgetting to set the Dropout layer to eval() mode can lead to an underestimation of the correct training loss).
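A minimal sketch of that pattern (`net` and `batch` are placeholder names, not from the question's code):

```python
import torch
import torch.nn as nn

# Hypothetical model and data, just to illustrate the eval/no_grad/train pattern
net = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
batch = torch.randn(8, 4)

net.eval()                     # Dropout/BatchNorm switch to inference behavior
with torch.no_grad():          # no autograd graph is built for this pass
    preds = net(batch)         # e.g. compute a "clean" training loss here
net.train()                    # restore training behavior before optimizing
```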

It will not harm the optimization process, but if for some reason one wants to switch back and forth between train and eval at inference time (e.g. to compare stochastic forward passes with MC Dropout against deterministic passes without dropout at test time in the same notebook), then one must switch off the track_running_stats property of the BatchNorm layer; otherwise, each stochastic forward pass will update its running estimates. Here is a small snippet illustrating the problem.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 3, 3),
                      nn.BatchNorm2d(3),
                      nn.MaxPool2d(8),
                      nn.Flatten(),
                      nn.Linear(3, 1))
model.eval()
a = torch.randn(1, 1, 10, 10)
print("INITIAL running mean: \n", model[1].running_mean)  # -> tensor([0., 0., 0.])
print("eval: ", model(a))  # -> tensor([[0.7643]], grad_fn=<AddmmBackward0>)

model.train()
print("AFTER EVAL running mean: \n", model[1].running_mean)  # -> tensor([0., 0., 0.])
print("training mode: \n", model(a))  # -> tensor([[1.4419]], grad_fn=<AddmmBackward0>)

model.eval()
print("AFTER TRAIN running mean: \n", model[1].running_mean)  # -> tensor([0.0146, 0.0131, 0.0101])
print("eval: ", model(a))  # -> tensor([[0.7858]], grad_fn=<AddmmBackward0>)

model.train()
print("\nMake track running stats constant----------\n")
model[1].track_running_stats = False
print("AFTER EVAL running mean: \n", model[1].running_mean)  # -> tensor([0.0146, 0.0131, 0.0101])
print("training mode: \n", model(a))  # -> tensor([[1.4419]], grad_fn=<AddmmBackward0>)

model.eval()
print("AFTER TRAIN running mean: \n", model[1].running_mean)  # -> tensor([0.0146, 0.0131, 0.0101])
print("eval: ", model(a))  # -> tensor([[0.7858]], grad_fn=<AddmmBackward0>)
```

