PyTorch. Snapshot Weights Averaging.
May 5, 2020
Okay, you have a number of checkpoints from a training loop that ran for 100500 epochs. Or you’ve run several experiments with one architecture, changing only global parameters, and now you have 7 saved .pth models in your folder.
But none of these models reaches 90% accuracy: they are still 0.5–0.7% short.
So how do you reach the desired 90%? One option is to combine the models by averaging the weights of the snapshots.
Explanation.
Get a dictionary for each snapshot that maps parameter names to their values.
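snapshots_weights = {}
for snapshot_path in list_of_snapshots_paths:
    # load_net rebuilds the model and restores the weights saved at snapshot_path
    model = load_net(path=snapshot_path)
    snapshots_weights[snapshot_path] = dict(model.named_parameters())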
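The post does not show `load_net`, so here is a minimal sketch of what such a helper might look like. Everything about it is an assumption: `build_model` is a hypothetical stand-in for your own architecture, and the snapshots are assumed to be plain `state_dict` files written with `torch.save`. Two throwaway snapshots are created in a temporary folder only so the example runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; every snapshot must share it.
    return nn.Linear(4, 2)


def load_net(path):
    # Hypothetical stand-in for the load_net helper used above:
    # rebuild the architecture, then restore the saved state_dict.
    model = build_model()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model


# Create two throwaway snapshots so the sketch is self-contained.
snapshot_dir = tempfile.mkdtemp()
list_of_snapshots_paths = []
for i in range(2):
    path = os.path.join(snapshot_dir, f"snapshot_{i}.pth")
    torch.save(build_model().state_dict(), path)
    list_of_snapshots_paths.append(path)

# Same loop as in the step above: one {name: parameter} dict per snapshot.
snapshots_weights = {}
for snapshot_path in list_of_snapshots_paths:
    model = load_net(path=snapshot_path)
    snapshots_weights[snapshot_path] = dict(model.named_parameters())
```

If your checkpoints store more than the weights (optimizer state, epoch counters), the loading line would need to pick the right key out of the checkpoint first.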
Iterate over each parameter and set its averaged value in the new state_dict.
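# For each parameter name, sum its values over all N snapshots ...
custom_params += snapshot_params[name].data
# ... then write the mean into the new state_dict entry.
dict_params[name].data.copy_(custom_params / N)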
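The fragment above leaves out the surrounding loops, so here is one way the full averaging step might be written. This is a sketch, not the author's exact code: the function name `average_parameters` and the toy snapshot names `a.pth` and `b.pth` are made up for illustration, and all snapshots are assumed to have identical parameter names and shapes.

```python
import torch


def average_parameters(snapshots_weights):
    """Return {name: mean tensor} over all snapshots (same names and shapes assumed)."""
    snapshots = list(snapshots_weights.values())
    n = len(snapshots)
    averaged = {}
    for name in snapshots[0]:
        # Accumulate in a fresh tensor so no snapshot is modified in place.
        custom_params = torch.zeros_like(snapshots[0][name])
        for snapshot_params in snapshots:
            custom_params += snapshot_params[name].detach()
        averaged[name] = custom_params / n
    return averaged


# Toy snapshots with hand-picked values (hypothetical file names).
toy = {
    "a.pth": {"weight": torch.tensor([1.0, 2.0]), "bias": torch.tensor([0.0])},
    "b.pth": {"weight": torch.tensor([3.0, 6.0]), "bias": torch.tensor([1.0])},
}
averaged = average_parameters(toy)
print(averaged["weight"])  # tensor([2., 4.])
```

Starting the accumulator from `torch.zeros_like` matters: if it were initialised from the first snapshot's tensor itself, the in-place `+=` would silently overwrite that snapshot.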
Load the new state_dict into the model.
model_dict = model.state_dict()
model_dict.update(dict_params)
model.load_state_dict(model_dict)
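To see the loading step work end to end, here is a small sketch with a quick sanity check. The three `nn.Linear` "snapshots" are hypothetical stand-ins for real checkpoints. The check relies on a property that holds only for a single linear layer: averaging the weights gives the same result as averaging the outputs. For deep nonlinear networks that equality does not hold, so treat the check as a test of the plumbing, not of accuracy.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Three hypothetical snapshots of the same small architecture.
snapshots = [nn.Linear(3, 1) for _ in range(3)]
model = nn.Linear(3, 1)  # target model that receives the averaged weights

# Average every parameter across the snapshots.
names = [name for name, _ in model.named_parameters()]
dict_params = {
    name: torch.stack(
        [dict(s.named_parameters())[name].detach() for s in snapshots]
    ).mean(dim=0)
    for name in names
}

# Overwrite the matching entries of the state_dict and load it back.
model_dict = model.state_dict()
model_dict.update(dict_params)
model.load_state_dict(model_dict)

# Sanity check: compare the averaged model with the mean of the snapshot outputs.
x = torch.randn(5, 3)
with torch.no_grad():
    mean_output = torch.stack([s(x) for s in snapshots]).mean(dim=0)
    averaged_output = model(x)
```

Note that `named_parameters()` does not include buffers such as BatchNorm running statistics, so with this approach those entries keep whatever values the target model already had.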