Bug: metric_fn in HIST and IGMTF always raises ValueError for default metric #2163

Description

@warren618

metric_fn() in pytorch_hist.py (line 173) and pytorch_igmtf.py (line 166) uses == to compare self.metric against the tuple ("", "loss"):

if self.metric == ("", "loss"):
    return -self.loss_fn(pred[mask], label[mask])

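Since self.metric is always a string (default is ""), and a string never compares equal to a tuple, this condition is always False. The function therefore falls through to raise ValueError("unknown metric"), even for the two values the branch was meant to accept.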

All other models correctly use in:

# pytorch_alstm.py, pytorch_gru.py, pytorch_lstm.py, etc.
if self.metric in ("", "loss"):
    return -self.loss_fn(pred[mask], label[mask])
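
The difference between the two operators can be checked in isolation. The snippet below is a minimal sketch using plain Python values only; the variable names are illustrative and are not taken from the qlib source.

```python
# Tuple of metric names the branch is meant to accept.
accepted = ("", "loss")

# `==` compares the string against the whole tuple, so it is False
# for every possible string value, including the two accepted ones.
eq_results = [metric == accepted for metric in ("", "loss")]

# `in` tests membership, which is what the branch intends.
in_results = [metric in accepted for metric in ("", "loss")]

print(eq_results)  # [False, False]
print(in_results)  # [True, True]
```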

Impact

Any user training HIST or IGMTF with the default metric (or metric="loss") will get a crash:

ValueError: unknown metric ``

Fix

- if self.metric == ("", "loss"):
+ if self.metric in ("", "loss"):
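
A regression check for the corrected branch might look like the sketch below. ToyModel is a hypothetical stand-in written for this example, not the real HIST or IGMTF class: it uses plain Python lists and a simple mean squared error instead of tensors and masking, and only mirrors the shape of the metric_fn branch shown above.

```python
class ToyModel:
    """Hypothetical stand-in that mirrors only the metric_fn branch."""

    def __init__(self, metric=""):
        self.metric = metric

    def loss_fn(self, pred, label):
        # Mean squared error over plain lists (assumed loss for the sketch).
        return sum((p - l) ** 2 for p, l in zip(pred, label)) / len(pred)

    def metric_fn(self, pred, label):
        # Corrected check: membership test instead of equality with a tuple.
        if self.metric in ("", "loss"):
            return -self.loss_fn(pred, label)
        raise ValueError("unknown metric `%s`" % self.metric)


def check_default_metrics():
    """Return the score for each metric value the branch should accept."""
    pred, label = [1.0, 2.0], [1.0, 4.0]
    return [ToyModel(metric).metric_fn(pred, label) for metric in ("", "loss")]


def raises_for_unknown_metric():
    """Confirm that an unsupported metric name still raises ValueError."""
    try:
        ToyModel("bogus").metric_fn([1.0], [1.0])
    except ValueError:
        return True
    return False
```

With the original == comparison in place of in, check_default_metrics() would raise ValueError for both values instead of returning scores.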

I will submit a fix shortly.
