[ENH] Add MultiLoss and multi-target support to ptf-v2#2312
[ENH] Add MultiLoss and multi-target support to ptf-v2#2312andersendsa wants to merge 6 commits into
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2312 +/- ##
=======================================
Coverage ? 87.12%
=======================================
Files ? 167
Lines ? 9763
Branches ? 0
=======================================
Hits ? 8506
Misses ? 1257
Partials ? 0
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
1d7c22b to
14ec491
Compare
phoeenniixx
left a comment
There was a problem hiding this comment.
There are a lot of unrelated changes like to the notebooks and test.yml? Why? Please merge with main to solve the issues in the CI
@phoeenniixx Those other changes were already done when i checked out the repository |
I meant on your branch. you dont need to make these changes. Look at changed files - they have a lot of files which have unrelated changes to the PR. Please revert those |
Fixes the BaseModelV2 `to_prediction` and `to_quantiles` functions to return multi-target list of tensors appropriately with `MultiLoss`. Updates `log_metrics` in `BaseModelV2` to appropriately tag multi-target metrics logging iteratively to prevent dimension mismatch crashes.
a31dd12 to
f69aa82
Compare
|
Hi @phoeenniixx i have done all the changes and the pr is ready for review |
Reference Issues/PRs
#2308
What does this implement/fix? Explain your changes.
Fixes the BaseModelV2
to_predictionandto_quantilesfunctions to return multi-target list of tensors appropriately withMultiLoss. Updateslog_metricsinBaseModelV2to appropriately tag multi-target metrics logging iteratively to prevent dimension mismatch crashes.Did you add any tests for the change?
tests/test_models/test_base_model_v2.py:
test_multi_loss_predictions: Verifies that to_prediction and to_quantiles correctly process predictions utilizing MultiLoss with different metrics like MAE and QuantileLoss, checking that it correctly outputs lists of tensors with the appropriate shapes.
test_multi_loss_log_metrics: Verifies that log_metrics successfully logs the evaluations for each target without a dimension crash, and correctly maps the prefix tags sequentially.
PR checklist
pre-commit install.To run hooks independent of commit, execute
pre-commit run --all-files