Skip to content

Commit 8769f1b

Browse files
committed
Prepared file to run on cluster
1 parent 555c9b9 commit 8769f1b

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

mask_rcnn.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,58 +20,53 @@
2020

2121
parser = argparse.ArgumentParser()
2222
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
23-
parser.add_argument("--num_epochs", type=int, default=200, help="number of epochs of training")
23+
parser.add_argument("--num_epochs", type=int, default=300, help="number of epochs of training")
2424
parser.add_argument("--dataset_name", type=str, default="ClothCoParse", help="name of the dataset")
25-
parser.add_argument("--batch_size", type=int, default=2, help="size of the batches")
26-
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
27-
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
28-
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
29-
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
25+
parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
26+
parser.add_argument("--lr", type=float, default=0.005, help="adam: learning rate")
3027
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
3128
parser.add_argument("--img_height", type=int, default=512, help="size of image height")
3229
parser.add_argument("--img_width", type= int, default=512, help="size of image width")
33-
34-
parser.add_argument("--sample_interval", type=int, default=100, help="interval between sampling of images from generators")
30+
parser.add_argument("--evaluate_interval", type=int, default=50, help="interval between sampling of images from generators")
3531
parser.add_argument("--checkpoint_interval", type=int, default=50, help="interval between model checkpoints")
3632
parser.add_argument("--HPC_run", type=int, default=0, help="if 1, sets to true if running on HPC: default is 0 which reads to False")
3733
parser.add_argument("--remove_background", type=int, default=0, help="if 1, sets to true if: default is 1 which reads to False")
3834
parser.add_argument("--redirect_std_to_file", type =int, default=0, help="set all console output to file: default is 0 which reads to False")
3935
parser.add_argument("--train_percentage", type=float, default=0.8, help="percentage of samples used in training, the rest used for testing")
4036
parser.add_argument("--experiment_name", type=str, default=None, help="name of the folder inside saved_models")
41-
42-
37+
parser.add_argument("--print_freq", type=int, default=100, help="progress print out freq")
4338

4439
opt = parser.parse_args()
45-
40+
opt.train_shuffle = True
4641
if platform.system()=='Windows':
4742
opt.n_cpu= 0
4843

49-
opt.train_shuffle = False
44+
45+
# this used for debuging
5046
opt.batch_size = 2
51-
opt.num_epochs = 11
52-
opt.print_freq = 10
53-
opt.checkpoint_interval=10
54-
opt.train_percentage=0.80 #0.02 # to be used for debugging with low number of samples
55-
opt.epoch=0
56-
opt.experiment_name = None # 'ClothCoParse-mask_rcnn-Mar-26-at-21-2'
57-
opt.sample_interval=5
47+
# opt.num_epochs = 11
48+
# opt.print_freq = 10
49+
# opt.checkpoint_interval=10
50+
# opt.train_percentage=0.80 #0.02 # to be used for debugging with low number of samples
51+
# opt.epoch=0
52+
# opt.experiment_name = None # 'ClothCoParse-mask_rcnn-Mar-26-at-21-2'
53+
# opt.sample_interval=5
5854

5955
def sample_images(data_loader_test, model, device):
6056
images,targets = next(iter(data_loader_test)) # grab the images
6157
images = list(image.to(device) for image in images)
6258
model.eval() # setting model to evaluation mode
63-
predictions = model(images) # Returns predictions
59+
with torch.no_grad():
60+
predictions = model(images) # Returns predictions
61+
masks = predictions[0]['masks'].cpu().squeeze(1)
62+
labels = predictions[0]['labels'].cpu()
6463
model.train() # putting back the model into train status/mode
6564

6665

6766
''' TODO: Do something with the predictions, ie display / save '''
6867

6968

7069

71-
if platform.system()=='Windows':
72-
opt.n_cpu= 0
73-
74-
print(opt)
7570

7671
# sanity check
7772
if opt.epoch !=0 and opt.experiment_name is None:
@@ -85,6 +80,13 @@ def sample_images(data_loader_test, model, device):
8580
os.makedirs("images/%s" % opt.experiment_name, exist_ok=True)
8681
os.makedirs("saved_models/%s" % opt.experiment_name, exist_ok=True)
8782

83+
if opt.redirect_std_to_file:
84+
out_file_name = "saved_models/%s" % opt.experiment_name
85+
print('Output sent to ', out_file_name)
86+
sys.stdout = open(out_file_name+'.txt', 'w')
87+
88+
print(opt)
89+
8890

8991
# train on the GPU or on the CPU, if a GPU is not available
9092
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -103,7 +105,7 @@ def sample_images(data_loader_test, model, device):
103105

104106
# construct an optimizer
105107
params = [p for p in model.parameters() if p.requires_grad]
106-
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
108+
optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.9, weight_decay=0.0005)
107109
# and a learning rate scheduler
108110
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
109111
step_size=3,
@@ -118,14 +120,13 @@ def sample_images(data_loader_test, model, device):
118120
lr_scheduler.step()
119121
# evaluate on the test dataset
120122

121-
if epoch % opt.sample_interval == 0:
122-
# evaluate(model, data_loader, device=device) # used for debugging
123+
if epoch % opt.evaluate_interval == 0:
123124
evaluate(model, data_loader_test, device=device)
124125

125126

126127
if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval== 0:
127128
# Save model checkpoints
128-
print('Saving model')
129+
print('Saving model ...')
129130
torch.save(model.state_dict(), "saved_models/%s/maskrcnn_%d.pth" % (opt.experiment_name, epoch))
130131

131132

0 commit comments

Comments
 (0)