Fix validation skipped when IterableDataset exhausts before reported length #21551

Closed
avocardio wants to merge 2 commits into Lightning-AI:master from avocardio:fix/iterable-dataset-validation-on-exhaustion

@avocardio avocardio commented Feb 25, 2026

Bug

Fixes #19624

When an IterableDataset reports a length via __len__ but produces fewer batches than expected (e.g. due to drop_last=True with multiple workers, shard boundary rounding in webdataset/DALI, or worker-based splitting), StopIteration is raised before the expected batch count. This causes on_advance_end to be skipped for the final iteration, which means the end-of-epoch validation check never runs — not just for that epoch, but for all subsequent epochs as well.

Root cause

In _TrainingEpochLoop.run():

while not self.done:
    try:
        self.advance(data_fetcher)
        self.on_advance_end(data_fetcher)
    except StopIteration:
        break

When _DataFetcher.__next__() encounters a StopIteration from the underlying iterator, it sets self.done = True and then re-raises the exception. The except StopIteration: break in the training loop catches this, which skips on_advance_end() — where the validation check (_should_check_val_fx) lives.
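The control flow described above can be reproduced without Lightning. The following is a minimal sketch, not the library's code: `ToyFetcher` and `run_epoch` are hypothetical stand-ins, and the assumption is that the end-of-epoch check inside `on_advance_end()` only fires once the fetcher reports `done`.

```python
class ToyFetcher:
    """Hypothetical, simplified stand-in for a data fetcher."""

    def __init__(self, batches, reported_length):
        self._iterator = iter(batches)
        self.length = reported_length  # what __len__ claims
        self.fetched = 0
        self.done = False

    def __next__(self):
        try:
            batch = next(self._iterator)
        except StopIteration:
            self.done = True  # mark exhaustion, then re-raise
            raise
        self.fetched += 1
        # With a reported length, "last batch" is inferred from the count.
        self.done = self.fetched >= self.length
        return batch


def run_epoch(fetcher):
    """Toy epoch loop; returns the list of events that happened."""
    events = []
    while not fetcher.done:
        try:
            batch = next(fetcher)  # plays the role of advance()
            events.append(f"train:{batch}")
            if fetcher.done:  # plays the role of the check in on_advance_end()
                events.append("validate")
        except StopIteration:
            break  # the end-of-epoch check is never reached on this path
    return events


# Accurate length: the 4th batch is recognised as the last one.
accurate = run_epoch(ToyFetcher(range(4), reported_length=4))
# Over-reported length: the loop exits through StopIteration instead.
overreported = run_epoch(ToyFetcher(range(4), reported_length=5))
print(accurate)
print(overreported)
```

With an accurate length the event list ends in `validate`; with the over-reported length it contains only the four training events.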

Fix

After the while loop exits, check whether the data fetcher was exhausted (data_fetcher.done) and, if so, run the end-of-epoch validation check. Only validation is triggered; the per-batch operations (LR scheduler updates, the _batches_that_stepped increment, logger saves) are left out, since no batch was processed in that final iteration.
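The shape of this change can be sketched on a toy loop. This is an illustration under stated assumptions, not the actual patch: `ToyFetcher`, `run_epoch_with_post_loop_check`, and the `exhausted_early` flag are hypothetical names, and the flag is one possible way to avoid validating twice when the reported length is accurate.

```python
class ToyFetcher:
    """Hypothetical stand-in: trusts a reported length, flags exhaustion."""

    def __init__(self, batches, reported_length):
        self._iterator = iter(batches)
        self.length = reported_length
        self.fetched = 0
        self.done = False

    def __next__(self):
        try:
            batch = next(self._iterator)
        except StopIteration:
            self.done = True
            raise
        self.fetched += 1
        self.done = self.fetched >= self.length
        return batch


def run_epoch_with_post_loop_check(fetcher):
    """Toy epoch loop with a validation check after the loop exits."""
    events = []
    exhausted_early = False  # assumption: set only on the StopIteration path
    while not fetcher.done:
        try:
            batch = next(fetcher)
            events.append(f"train:{batch}")
            if fetcher.done:  # regular end-of-epoch check
                events.append("validate")
        except StopIteration:
            exhausted_early = True
            break
    # Post-loop check: run only the validation step that was skipped,
    # without any per-batch bookkeeping.
    if exhausted_early and fetcher.done:
        events.append("validate")
    return events


accurate = run_epoch_with_post_loop_check(ToyFetcher(range(4), 4))
overreported = run_epoch_with_post_loop_check(ToyFetcher(range(4), 5))
print(accurate.count("validate"), overreported.count("validate"))
```

In this sketch both the accurate and the over-reported case end with exactly one validation event per epoch.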

How to reproduce

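import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset


class ShortIterableDataset(IterableDataset):
    """Reports length=10 but only yields 8 samples."""

    def __iter__(self):
        for _ in range(8):
            yield torch.randn(32)

    def __len__(self):
        return 10


# Assumption: the original snippet leaves `model` undefined; BoringModel is
# used here as a stand-in for any LightningModule with a validation loop.
model = BoringModel()
trainer = Trainer(max_epochs=2, num_sanity_val_steps=0)
# len(DataLoader) reports 5 batches, but only 4 are produced.
trainer.fit(model, DataLoader(ShortIterableDataset(), batch_size=2))
# Validation never runs (0 times instead of 2)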

This is common with webdataset/DALI streaming datasets where __len__ is an estimate based on total samples / batch_size but the actual number of batches produced can vary due to shard boundaries, worker splitting, or drop_last.
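To make the mismatch concrete, here is a small arithmetic sketch. It assumes each worker batches its own shard independently, and the 4/3/3 split of 10 samples across three workers is a made-up example; `reported_batches` and `actual_batches` are hypothetical helper names, not library functions.

```python
import math


def reported_batches(num_samples, batch_size, drop_last):
    """Batch count estimated from the dataset length alone."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def actual_batches(samples_per_worker, batch_size, drop_last):
    """Batch count when every worker batches its own shard separately."""
    total = 0
    for n in samples_per_worker:
        if drop_last:
            total += n // batch_size  # each worker drops its own remainder
        else:
            total += math.ceil(n / batch_size)
    return total


# 10 samples, batch_size=2, drop_last=True, hypothetical shards of 4/3/3.
expected = reported_batches(10, batch_size=2, drop_last=True)
produced = actual_batches([4, 3, 3], batch_size=2, drop_last=True)
print(expected, produced)
```

Here the estimate is 5 batches while the workers together produce 4, because two of the three shards each lose a leftover sample.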

Test plan

  • Added test_iterable_dataset_validation_on_exhaustion regression test that creates a ShortIterableDataset reporting __len__() = 10 but yielding only 8 samples (4 batches vs expected 5), and verifies validation runs for both epochs
  • Verified fix in production: two 40-epoch multi-GPU training runs on GH200 480GB servers with webdataset IterableDatasets, confirming validation fires reliably every epoch after applying this fix (previously validation was permanently skipped after epoch 0)

📚 Documentation preview 📚: https://pytorch-lightning--21551.org.readthedocs.build/en/21551/

…length

When an IterableDataset reports a length via __len__ but produces fewer
batches (due to shard boundaries, rounding, or drop_last=True with
multiple workers), StopIteration is raised in _DataFetcher.__next__
before fetched >= length. This StopIteration propagates to the training
epoch loop's run() method, where `except StopIteration: break` exits
the loop — skipping on_advance_end() and the validation check it
contains.

The fix adds a post-loop validation check: when the data fetcher is
done (StopIteration was caught) and validation should run at the epoch
boundary, we set is_last_batch=True and run the validation check that
was skipped.

Fixes Lightning-AI#19624
@github-actions github-actions Bot added the pl Generic label for PyTorch Lightning package label Feb 25, 2026
@avocardio

Author

Closing to resubmit with a cleaner description.

@avocardio avocardio closed this Feb 25, 2026
Labels

pl Generic label for PyTorch Lightning package

Development

Successfully merging this pull request may close these issues.

IterableDataset with CORRECT length causes validation loop to be skipped

1 participant