fix(collections): handle tuple/sequence states in _equal_metric_states #3415
jlaportebot wants to merge 7 commits
…list states

MetricCollection._equal_metric_states assumed all elements in list states were Tensors, calling .shape and torch.allclose on them. Metrics like MeanAveragePrecision store tuples in their list states, causing an AttributeError on comparison.

Introduce _equal_state_elements() that recursively compares state elements by type:
- Tensor: compare shape and value (allclose)
- list/tuple: recurse element-wise
- scalar/other: fall back to ==

Includes unit tests for Tensor, tuple, list, mixed, and scalar comparisons, plus an integration test with a custom Metric that stores tuples in a list state.

Fixes Lightning-AI#3335
- Check for Sequence (list, tuple) but exclude Tensor, str, bytes
- Addresses reviewer feedback to handle non-list sequences like tuples in metric states (e.g., MeanAveragePrecision)
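The dispatch described in the two commit messages above can be sketched without torch. This is a minimal illustration, not the code from the PR: `FakeTensor`, `is_plain_sequence`, and `equal_state_elements_sketch` are hypothetical names, `FakeTensor` stands in for `torch.Tensor`, and `math.isclose` stands in for `torch.allclose` (their default tolerances differ).

```python
import math
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: a shape plus flat values."""

    shape: tuple
    values: list


def is_plain_sequence(x: Any) -> bool:
    # str and bytes are Sequences too, so they are excluded explicitly,
    # as is the tensor stand-in.
    return isinstance(x, Sequence) and not isinstance(x, (FakeTensor, str, bytes))


def equal_state_elements_sketch(s1: Any, s2: Any) -> bool:
    # Type-strict guard: a list never equals a tuple, even when both are empty.
    if type(s1) is not type(s2):
        return False
    # Tensor-like: compare shape first, then values within a tolerance.
    if isinstance(s1, FakeTensor):
        return s1.shape == s2.shape and all(
            math.isclose(a, b) for a, b in zip(s1.values, s2.values)
        )
    # list/tuple: check lengths, then recurse element-wise.
    if is_plain_sequence(s1):
        return len(s1) == len(s2) and all(
            equal_state_elements_sketch(a, b) for a, b in zip(s1, s2)
        )
    # scalar/other: fall back to ==.
    return bool(s1 == s2)
```

With this sketch, two list states that each hold a `(FakeTensor, int)` tuple compare equal, while lists of different lengths or a list compared against a tuple do not.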
Force-pushed from 884ce71 to fb36441.
Codecov Report

Coverage Diff

| | master | #3415 | +/- |
|---|---|---|---|
| Coverage | 37% | 36% | -0% |
| Files | 349 | 349 | |
| Lines | 19901 | 19920 | +19 |
| Hits | 7264 | 7266 | +2 |
| Misses | 12637 | 12654 | +17 |
The link https://minesparis-psl.hal.science/hal-00464703 times out during docs linkcheck, causing CI failure. This is a pre-existing environmental issue unrelated to the PR changes.
Force-pushed from 0ccc2ba to 1c46ad5.
Pull request overview
Fixes a MetricCollection compute-group state comparison bug where list states containing tuples (e.g., from MeanAveragePrecision) could raise AttributeError during _equal_metric_states. The PR broadens state equality handling to support nested sequence structures while adding targeted regression tests.
Changes:
- Extend `_equal_metric_states` to compare non-string/bytes, non-Tensor `Sequence` states element-wise (including length checks).
- Add `MetricCollection._equal_state_elements` helper for recursive equality across tensors, tuples/lists, and scalars.
- Add a regression-focused test suite to cover tuple-in-list metric states and nested comparisons.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| src/torchmetrics/collections.py | Updates metric state equality logic to safely compare sequence states containing tuples/nested structures. |
| tests/unittests/bases/test_collections.py | Adds regression and unit tests validating recursive equality and MAP-like tuple list states integration. |
| docs/source/conf.py | Adds a timed-out domain URL to the linkcheck ignore list to stabilize docs link checking. |
- Add list variant to mismatched-length test in test_equal_state_elements_mismatched_length
- Add parametrized test_equal_state_elements_empty_and_type_guards covering empty [], (), mixed [], () and scalar type mismatches
- Add test_equal_metric_states_direct_with_tuple_state: direct call on equal + unequal MetricWithTupleState instances
- Add second mc.update() + mc.compute() assertion in test_metric_collection_with_tuple_list_states to exercise _compute_groups_create_state_ref path
---
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
- Add Args and Returns sections to _equal_state_elements docstring (s1, s2: Any; fallback == caveat for non-bool types)
- Add CHANGELOG entry under [Unreleased] Fixed for MetricCollection sequence-state fix (Lightning-AI#3415)
---
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
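One plausible reading of the "fallback == caveat for non-bool types" note is that `==` is not guaranteed to return a `bool`. The sketch below illustrates that general Python behavior with a hypothetical `ElementwiseEq` class; it is not taken from the PR and makes no claim about what the docstring actually says.

```python
class ElementwiseEq:
    """Hypothetical type whose == returns per-item results instead of a bool."""

    def __init__(self, items):
        self.items = list(items)

    def __eq__(self, other):
        return [a == b for a, b in zip(self.items, other.items)]


left = ElementwiseEq([1, 2])
right = ElementwiseEq([1, 3])

raw = left == right   # a non-empty list, which is truthy
coerced = bool(raw)   # truthiness of the container, not of its items
strict = all(raw)     # requires every item to match
```

Here `coerced` is `True` even though the second items differ, while `strict` is `False`, so a generic `==` fallback is only as reliable as the `__eq__` of the types it meets.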
    [
        pytest.param([], [], True, id="empty-list-equal"),
        pytest.param((), (), True, id="empty-tuple-equal"),
        pytest.param([], (), False, id="list-vs-tuple-type-mismatch"),
I think if they are empty, they could be considered the same, right?
Thanks for the review @Borda!
This is intentional: the function uses type-strict equality (a type check is its first guard).
Rationale:
- Consistency: a non-empty list never equals a non-empty tuple, so an empty list and an empty tuple should also differ
- Type safety: metric states have declared types; a list state receiving a tuple (or vice versa) is a semantic mismatch that should fail loudly
- Prevents silent bugs: if a metric declares one sequence type but somehow gets the other, we want the comparison to report a mismatch, not equality
The current test correctly asserts this behavior. The type check is the first guard in the function for exactly this reason.
Happy to discuss further if you see a case where value-based empty equality would be safer, but for metric state comparison I believe type-strict is the right default.
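The consistency argument can be checked against Python's built-in behavior. The snippet below is a small sketch; `same_container_type` is a hypothetical helper that mirrors the type guard described in the reply, not the function from the PR.

```python
def same_container_type(s1, s2) -> bool:
    # Hypothetical guard: strict type identity, no subclass or duck-type matching.
    return type(s1) is type(s2)


# Built-in == already treats lists and tuples as unequal, empty or not.
builtin_nonempty = [1] == (1,)
builtin_empty = [] == ()

# The type-strict guard gives the same answer for the empty case.
guard_mixed_empty = same_container_type([], ())
guard_same_empty = same_container_type([], [])
```

Both built-in comparisons evaluate to `False`, so the type-strict result for `[]` versus `()` matches what plain `==` would report.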
@Borda: replied to your review comment. The type-strict behavior is intentional (consistency + type safety for metric states). If you agree, an APPROVED review will unblock the merge.
Unfortunately not really, I am neither a maintainer nor a code owner anymore...
Summary
Fixes #3335 - MeanAveragePrecision stores tuples in list states, which caused `AttributeError: 'tuple' object has no attribute 'shape'` because the original implementation assumed all list elements were Tensors.

Changes
- `_equal_metric_states`: handle any `Sequence` type (list, tuple) except Tensor, str, bytes
- `_equal_state_elements`: recursive comparison of metric state elements

Tests
Added comprehensive test suite `TestEqualMetricStatesWithTuples`.
All tests pass locally.
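The failure mode in the summary can be reproduced without torchmetrics. This is a minimal sketch under the assumption that the old comparison accessed `.shape` on every list element; `naive_equal_list_states` is a hypothetical stand-in, not the original implementation.

```python
def naive_equal_list_states(state1, state2) -> bool:
    # Assumes every element is tensor-like and therefore has a .shape attribute.
    return all(a.shape == b.shape for a, b in zip(state1, state2))


# A list state that holds tuples instead of tensors.
tuple_state = [(0.5, 1)]

try:
    naive_equal_list_states(tuple_state, tuple_state)
    error_message = None
except AttributeError as err:
    error_message = str(err)
```

After running, `error_message` holds the same "no attribute 'shape'" text quoted in the summary.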