Fix validation skipped when IterableDataset exhausts before reported length #21551
Closed
avocardio wants to merge 2 commits into
…length When an IterableDataset reports a length via __len__ but produces fewer batches (due to shard boundaries, rounding, or drop_last=True with multiple workers), StopIteration is raised in _DataFetcher.__next__ before fetched >= length. This StopIteration propagates to the training epoch loop's run() method, where `except StopIteration: break` exits the loop — skipping on_advance_end() and the validation check it contains. The fix adds a post-loop validation check: when the data fetcher is done (StopIteration was caught) and validation should run at the epoch boundary, we set is_last_batch=True and run the validation check that was skipped. Fixes Lightning-AI#19624
for more information, see https://pre-commit.ci
Author
Closing to resubmit with a cleaner description.
Bug
Fixes #19624
When an `IterableDataset` reports a length via `__len__` but produces fewer batches than expected (e.g. due to `drop_last=True` with multiple workers, shard boundary rounding in webdataset/DALI, or worker-based splitting), `StopIteration` is raised before the expected batch count. This causes `on_advance_end` to be skipped for the final iteration, which means the end-of-epoch validation check never runs, not just for that epoch but for all subsequent epochs as well.
Root cause
In `_TrainingEpochLoop.run()`, when `_DataFetcher.__next__()` encounters a `StopIteration` from the underlying iterator, it sets `self.done = True` and then re-raises the exception. The `except StopIteration: break` in the training loop catches this and exits the loop, which skips `on_advance_end()`, where the validation check (`_should_check_val_fx`) lives.
Fix
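To make the control flow concrete, here is a toy model of the loop, written without Lightning. It is only a sketch: `ToyFetcher`, `run_epoch`, and the `apply_fix` flag are hypothetical names invented for illustration, and the `fetched == reported_len` test merely stands in for the real last-batch logic in `on_advance_end`.

```python
class ToyFetcher:
    """Hypothetical stand-in for a data fetcher that records exhaustion."""

    def __init__(self, batches):
        self._it = iter(batches)
        self.done = False

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:
            self.done = True  # mark exhaustion, then re-raise
            raise


def run_epoch(batches, reported_len, apply_fix):
    """Return the events of one toy epoch as a list of strings."""
    fetcher = ToyFetcher(batches)
    events = []
    fetched = 0
    while fetched < reported_len:
        try:
            batch = next(fetcher)
        except StopIteration:
            break  # leaves the loop before the end-of-iteration hook
        fetched += 1
        events.append(f"train:{batch}")
        # Stand-in for on_advance_end: validate on the expected last batch.
        if fetched == reported_len:
            events.append("validate")
    # Post-loop check: only validation, no per-batch bookkeeping.
    if apply_fix and fetcher.done and "validate" not in events:
        events.append("validate")
    return events


short_without_fix = run_epoch(range(4), reported_len=5, apply_fix=False)
short_with_fix = run_epoch(range(4), reported_len=5, apply_fix=True)
full_with_fix = run_epoch(range(5), reported_len=5, apply_fix=True)
```

In this model, the short iterator never reaches the in-loop validation branch unless the post-loop check is enabled, and a full-length iterator still validates exactly once.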
After the while loop exits, check whether the data fetcher was exhausted (`data_fetcher.done`) and if so, run the end-of-epoch validation check. This only triggers validation, not the per-batch operations (LR scheduler updates, `_batches_that_stepped` increment, logger saves) that correctly should not run when no batch was processed.
How to reproduce
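A framework-free sketch of the length mismatch, assuming the numbers from the test plan (10 reported samples, 8 yielded) and a batch size of 2 inferred from them. `ShortDatasetStandIn` and `batched` are hypothetical helpers, not the code from this PR; a real reproduction would subclass `torch.utils.data.IterableDataset` and run it through a `Trainer`.

```python
import math


class ShortDatasetStandIn:
    """Hypothetical dataset: reports 10 samples but yields only 8."""

    def __len__(self):
        return 10

    def __iter__(self):
        return iter(range(8))


def batched(iterable, batch_size):
    """Group items into lists of batch_size, keeping a short final batch."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


dataset = ShortDatasetStandIn()
batch_size = 2  # assumption: inferred from "4 batches vs expected 5"

# What a loop trusting __len__ would expect versus what iteration delivers.
expected_batches = math.ceil(len(dataset) / batch_size)
actual_batches = sum(1 for _ in batched(dataset, batch_size))
```

Here `expected_batches` exceeds `actual_batches` by one, so a loop that waits for the reported count sees `StopIteration` first.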
The mismatch is common with webdataset/DALI streaming datasets, where `__len__` is an estimate based on total samples / batch_size but the actual number of batches produced can vary due to shard boundaries, worker splitting, or `drop_last`.
Test plan
`test_iterable_dataset_validation_on_exhaustion`: regression test that creates a `ShortIterableDataset` reporting `__len__() = 10` but yielding only 8 samples (4 batches vs expected 5), and verifies validation runs for both epochs.
📚 Documentation preview 📚: https://pytorch-lightning--21551.org.readthedocs.build/en/21551/