Fix WeightAveraging swapping in the un-updated average model during validation #21732
…ation Before its first update, the AveragedModel only holds the copy of the initial weights made in setup(). The validation hooks swapped it in unconditionally, so during a delayed-start warmup (e.g. EMAWeightAveraging with update_starting_at_step) validation evaluated the untrained snapshot instead of the current trained weights. Only swap the models when the average model has been updated at least once (n_averaged > 0). The swap stays balanced across validation start/end since n_averaged does not change during validation.
Pull request overview
Fixes a bug in the WeightAveraging / EMAWeightAveraging callbacks where validation could swap in the averaged model even before it had received its first update (n_averaged == 0), causing validation to run on an untrained initial-weight snapshot during delayed-start warmup.
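The failure mode described above can be reproduced in a toy setting. The sketch below is not Lightning's implementation: `ToyAveraging`, `train_step`, `weight_seen_by_validation`, and `run` are hypothetical names, a single float stands in for the model weights, and the `guarded` flag switches between the old unconditional behavior and the `n_averaged > 0` check.

```python
class ToyAveraging:
    """Toy stand-in for a delayed-start averaging callback (not Lightning's code)."""

    def __init__(self, update_starting_at_step, decay=0.5):
        self.update_starting_at_step = update_starting_at_step
        self.decay = decay
        self.current = 0.0   # the weight being trained
        self.average = 0.0   # copy of the initial weight, as made at setup
        self.n_averaged = 0  # number of averaging updates so far

    def train_step(self, step):
        self.current += 1.0  # pretend one training step moves the weight
        if step >= self.update_starting_at_step:
            if self.n_averaged == 0:
                # Assumption of this toy: the first update adopts the current weight.
                self.average = self.current
            else:
                self.average = self.decay * self.average + (1 - self.decay) * self.current
            self.n_averaged += 1

    def weight_seen_by_validation(self, guarded):
        # Guarded: fall back to the current weight until the first update.
        if guarded and self.n_averaged == 0:
            return self.current
        # Unguarded: always hand validation the average.
        return self.average


def run(num_steps, update_starting_at_step, guarded):
    """Train for num_steps, then report the weight validation would evaluate."""
    cb = ToyAveraging(update_starting_at_step)
    for step in range(num_steps):
        cb.train_step(step)
    return cb.weight_seen_by_validation(guarded)
```

With a threshold that is never reached, `run(10, 1000, guarded=False)` returns the untouched initial value `0.0`, while `run(10, 1000, guarded=True)` returns the trained value `10.0`. Once averaging has started, both variants return the same averaged weight.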
Changes:
- Guarded validation-time model swapping so it only happens after the averaged model has been updated at least once (n_averaged > 0).
- Updated SWA test expectations to reflect that validation swapping now starts only after the first averaging update occurs.
- Added a regression test ensuring validation observes current trained weights (not the frozen initial snapshot) when the averaging update threshold is never reached.
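The guarded swap and the property the regression test checks can be sketched without Lightning or PyTorch. This is an illustration under stated assumptions: `ToySwapper`, `_should_swap`, `_swap`, and `validate` are hypothetical names, plain lists stand in for model parameters, and only the hook names `on_validation_epoch_start` / `on_validation_epoch_end` and the `n_averaged > 0` condition come from the PR.

```python
class ToySwapper:
    """Toy model of the guarded validation swap (not the actual callback)."""

    def __init__(self, initial):
        self.model = list(initial)    # weights being trained
        self.average = list(initial)  # snapshot of the initial weights
        self.n_averaged = 0           # averaging updates performed so far

    def _should_swap(self):
        # The guard: only swap once the average has been updated.
        return self.n_averaged > 0

    def _swap(self):
        self.model, self.average = self.average, self.model

    def on_validation_epoch_start(self):
        if self._should_swap():
            self._swap()

    def on_validation_epoch_end(self):
        # n_averaged does not change during validation, so this check gives
        # the same answer as at the start and the swap is undone exactly once.
        if self._should_swap():
            self._swap()


def validate(swapper):
    """Return the weights a validation epoch would observe."""
    swapper.on_validation_epoch_start()
    seen = list(swapper.model)
    swapper.on_validation_epoch_end()
    return seen
```

During warmup (`n_averaged == 0`) `validate` returns the trained weights and leaves both copies untouched. After at least one update it returns the averaged weights and then restores the trained ones, which is the balance argument made for the start/end hooks.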
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| src/lightning/pytorch/callbacks/weight_averaging.py | Prevents swapping in the averaged model for validation until n_averaged > 0, avoiding evaluation on an untrained snapshot. |
| tests/tests_pytorch/callbacks/test_weight_averaging.py | Adjusts swap-count expectations and adds a regression test for “no swap before first update” behavior. |
| src/lightning/pytorch/CHANGELOG.md | Documents the fix under the unreleased “Fixed” section. |
|
I think it's better to use |
|
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## master #21732 +/- ##
=========================================
- Coverage 87% 79% -8%
=========================================
Files 270 267 -3
Lines 24005 23949 -56
=========================================
- Hits 20776 18824 -1952
- Misses 3229 5125 +1896 |
|
Hi @Martovark, great catch. Verified it from: Fixed standalone tests; should be good to land. Thanks @ATOM00blue for the work. |
What does this PR do?
Fixes #21724
closes #21754
WeightAveraging (and its EMAWeightAveraging subclass) creates the AveragedModel in setup() as a copy of the model's initial weights, with n_averaged == 0. The validation hooks on_validation_epoch_start / on_validation_epoch_end swapped this average model in unconditionally whenever it existed.
When using a delayed start (e.g. EMAWeightAveraging(update_starting_at_step=1000)) and validating during the warmup period, the average model has never been updated, so validation ran against the frozen initial (untrained) weights instead of the current trained ones. This is what the issue describes as metrics being near zero before update_starting_at_step.
This PR only swaps the models for validation once the average model has actually been updated at least once (n_averaged > 0). The swap remains balanced across the start/end hooks because n_averaged does not change during validation.
Tests
- Added test_weight_averaging_no_swap_before_first_update, which verifies that during a never-reached delayed start the parameters seen at validation are the current trained weights, not the frozen initial snapshot. It fails before this change and passes after.
- Updated SWATestCallback swap-count expectations: with a delayed update schedule, validation now only swaps once the average model has been updated.
Before submitting
PR review
Anyone in the community is welcome to review the PR.