Skip to content

Commit ca068f1

Browse files
Add logic to resume ImageNet example with new learning rate
1 parent 0252bda commit ca068f1

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

imagenet/main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ def main_worker(gpu, ngpus_per_node, args):
215215
best_acc1 = best_acc1.to(args.gpu)
216216
model.load_state_dict(checkpoint['state_dict'])
217217
optimizer.load_state_dict(checkpoint['optimizer'])
218+
if args.lr:
219+
# resume with newly specified learning rate
220+
optimizer.param_groups[0]['lr'] = args.lr
218221
scheduler.load_state_dict(checkpoint['scheduler'])
219222
print("=> loaded checkpoint '{}' (epoch {})"
220223
.format(args.resume, checkpoint['epoch']))
@@ -293,8 +296,8 @@ def main_worker(gpu, ngpus_per_node, args):
293296
'arch': args.arch,
294297
'state_dict': model.state_dict(),
295298
'best_acc1': best_acc1,
296-
'optimizer' : optimizer.state_dict(),
297-
'scheduler' : scheduler.state_dict()
299+
'optimizer': optimizer.state_dict(),
300+
'scheduler': scheduler.state_dict()
298301
}, is_best)
299302

300303

0 commit comments

Comments
 (0)