
# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--num_epochs", type=int, default=300, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="ClothCoParse", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.005, help="learning rate (used by the SGD optimizer)")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=512, help="size of image height")
parser.add_argument("--img_width", type=int, default=512, help="size of image width")
parser.add_argument("--evaluate_interval", type=int, default=50, help="interval between sampling of images from generators")
parser.add_argument("--checkpoint_interval", type=int, default=50, help="interval between model checkpoints")
parser.add_argument("--HPC_run", type=int, default=0, help="if 1, sets to true if running on HPC: default is 0 which reads to False")
parser.add_argument("--remove_background", type=int, default=0, help="if 1, removes the image background: default is 0 which reads to False")
parser.add_argument("--redirect_std_to_file", type=int, default=0, help="set all console output to file: default is 0 which reads to False")
parser.add_argument("--train_percentage", type=float, default=0.8, help="percentage of samples used in training, the rest used for testing")
parser.add_argument("--experiment_name", type=str, default=None, help="name of the folder inside saved_models")
parser.add_argument("--print_freq", type=int, default=100, help="progress print out freq")

opt = parser.parse_args()
opt.train_shuffle = True  # shuffle the training set each epoch

if platform.system() == 'Windows':
    # DataLoader worker processes are problematic on Windows; load in the
    # main process instead.
    opt.n_cpu = 0


# Debugging overrides — uncomment for quick, low-sample runs.
# NOTE(review): the batch-size override below had been left active while all
# of its sibling debug overrides were commented out, which silently clobbered
# the --batch_size command-line flag. It is now commented out like the rest
# so the flag (default 8) is honoured again.
# opt.batch_size = 2
# opt.num_epochs = 11
# opt.print_freq = 10
# opt.checkpoint_interval = 10
# opt.train_percentage = 0.80  # 0.02 # to be used for debugging with low number of samples
# opt.epoch = 0
# opt.experiment_name = None  # 'ClothCoParse-mask_rcnn-Mar-26-at-21-2'
# opt.sample_interval = 5
5854
def sample_images(data_loader_test, model, device):
    """Run the model on one test batch and return predicted masks and labels.

    Parameters
    ----------
    data_loader_test : iterable
        Yields ``(images, targets)`` batches; only the first batch is used.
    model : detection model (e.g. torchvision Mask R-CNN)
        Must return, per input image, a dict containing at least a
        ``'masks'`` tensor of shape (N, 1, H, W) and a ``'labels'`` tensor
        of shape (N,).  # assumes torchvision detection output format — TODO confirm
    device : torch.device
        Device the images are moved to before inference.

    Returns
    -------
    (masks, labels)
        CPU tensors for the first image of the batch: masks squeezed to
        (N, H, W) and labels of shape (N,).
    """
    images, targets = next(iter(data_loader_test))  # grab one batch
    images = [image.to(device) for image in images]

    model.eval()  # evaluation mode: freezes batch-norm stats / disables dropout
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        predictions = model(images)
        # Keep the predictions for the first image; drop the masks' channel dim.
        masks = predictions[0]['masks'].cpu().squeeze(1)
        labels = predictions[0]['labels'].cpu()
    model.train()  # restore training mode for the caller

    # Previously the predictions were computed and then discarded (the old
    # TODO: "Do something with the predictions, ie display / save").
    # Returning them lets the caller display or save them.
    return masks, labels
6867
6968
7069
71- if platform .system ()== 'Windows' :
72- opt .n_cpu = 0
73-
74- print (opt )
7570
7671# sanity check
7772if opt .epoch != 0 and opt .experiment_name is None :
@@ -85,6 +80,13 @@ def sample_images(data_loader_test, model, device):
8580 os .makedirs ("images/%s" % opt .experiment_name , exist_ok = True )
8681 os .makedirs ("saved_models/%s" % opt .experiment_name , exist_ok = True )
8782
83+ if opt .redirect_std_to_file :
84+ out_file_name = "saved_models/%s" % opt .experiment_name
85+ print ('Output sent to ' , out_file_name )
86+ sys .stdout = open (out_file_name + '.txt' , 'w' )
87+
88+ print (opt )
89+
8890
8991# train on the GPU or on the CPU, if a GPU is not available
9092device = torch .device ('cuda' ) if torch .cuda .is_available () else torch .device ('cpu' )
@@ -103,7 +105,7 @@ def sample_images(data_loader_test, model, device):
103105
104106# construct an optimizer
105107params = [p for p in model .parameters () if p .requires_grad ]
106- optimizer = torch .optim .SGD (params , lr = 0.005 , momentum = 0.9 , weight_decay = 0.0005 )
108+ optimizer = torch .optim .SGD (params , lr = opt . lr , momentum = 0.9 , weight_decay = 0.0005 )
107109# and a learning rate scheduler
108110lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer ,
109111 step_size = 3 ,
@@ -118,14 +120,13 @@ def sample_images(data_loader_test, model, device):
118120 lr_scheduler .step ()
119121 # evaluate on the test dataset
120122
121- if epoch % opt .sample_interval == 0 :
122- # evaluate(model, data_loader, device=device) # used for debugging
123+ if epoch % opt .evaluate_interval == 0 :
123124 evaluate (model , data_loader_test , device = device )
124125
125126
126127 if opt .checkpoint_interval != - 1 and epoch % opt .checkpoint_interval == 0 :
127128 # Save model checkpoints
128- print ('Saving model' )
129+ print ('Saving model ... ' )
129130 torch .save (model .state_dict (), "saved_models/%s/maskrcnn_%d.pth" % (opt .experiment_name , epoch ))
130131
131132
0 commit comments