|
| 1 | +"""MVTec AD data access + anomaly-detection metrics for the QC perception gate. |
| 2 | +
|
| 3 | +Reusable across phases: the Phase 0 baseline harness and the Phase 1/2 patch |
| 4 | +method both load data and compute metrics through here. No asserted targets — |
| 5 | +this module only measures. |
| 6 | +
|
| 7 | +Dataset: `TheoM55/mvtec_anomaly_detection` (ungated HF mirror of MVTec AD). |
| 8 | +Layout (from `metadata.csv`): |
| 9 | + columns = path, label, split, object, defect, mask_path |
| 10 | + path = "train/<object>/good/NNN.png" | "test/<object>/<defect>/NNN.png" |
| 11 | + label = 0 (good) | 1..8 (defect-type id, per object) |
| 12 | + split = "train" | "test" |
| 13 | + object = one of the 15 MVTec categories |
| 14 | + defect = "good" | "<defect_name>" |
| 15 | + mask_path = "" for good, else path to the ground-truth defect mask |
| 16 | +""" |
| 17 | + |
| 18 | +from __future__ import annotations |
| 19 | + |
| 20 | +import csv |
| 21 | +import os |
| 22 | +from dataclasses import dataclass, field |
| 23 | + |
| 24 | +import numpy as np |
| 25 | + |
| 26 | +MVTEC_REPO = "TheoM55/mvtec_anomaly_detection" |
| 27 | + |
| 28 | +MVTEC_CATEGORIES = [ |
| 29 | + "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", |
| 30 | + "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", |
| 31 | + "zipper", |
| 32 | +] |
| 33 | + |
| 34 | + |
| 35 | +@dataclass |
| 36 | +class Sample: |
| 37 | + """One MVTec image record (paths only; pixels loaded on demand).""" |
| 38 | + |
| 39 | + image_path: str |
| 40 | + label: int # 0 = good, 1 = anomalous (binary) |
| 41 | + defect: str # "good" or defect-type name |
| 42 | + mask_path: str | None # ground-truth mask path, or None for good |
| 43 | + |
| 44 | + |
| 45 | +@dataclass |
| 46 | +class CategoryData: |
| 47 | + """Train (good-only) + test (good+defect) split for one MVTec category.""" |
| 48 | + |
| 49 | + category: str |
| 50 | + train: list[Sample] = field(default_factory=list) # all good |
| 51 | + test: list[Sample] = field(default_factory=list) # good + defects |
| 52 | + |
| 53 | + @property |
| 54 | + def n_test_good(self) -> int: |
| 55 | + return sum(1 for s in self.test if s.label == 0) |
| 56 | + |
| 57 | + @property |
| 58 | + def n_test_defect(self) -> int: |
| 59 | + return sum(1 for s in self.test if s.label == 1) |
| 60 | + |
| 61 | + |
| 62 | +def resolve_dataset_root() -> str: |
| 63 | + """Return the local snapshot dir of the MVTec mirror (downloads metadata if |
| 64 | + needed; assumes the image/mask files were already fetched via |
| 65 | + snapshot_download). |
| 66 | + """ |
| 67 | + from huggingface_hub import hf_hub_download |
| 68 | + |
| 69 | + meta = hf_hub_download(MVTEC_REPO, "metadata.csv", repo_type="dataset") |
| 70 | + return os.path.dirname(meta) |
| 71 | + |
| 72 | + |
| 73 | +def load_category(category: str, root: str | None = None) -> CategoryData: |
| 74 | + """Load train/test sample lists for one MVTec category from metadata.csv.""" |
| 75 | + root = root or resolve_dataset_root() |
| 76 | + meta_csv = os.path.join(root, "metadata.csv") |
| 77 | + data = CategoryData(category=category) |
| 78 | + with open(meta_csv, newline="") as f: |
| 79 | + for row in csv.DictReader(f): |
| 80 | + if row["object"] != category: |
| 81 | + continue |
| 82 | + rel = row["path"].replace("\\", "/") |
| 83 | + img_path = os.path.join(root, "images", rel) |
| 84 | + label = int(row["label"]) |
| 85 | + binary = 0 if label == 0 else 1 |
| 86 | + mask_rel = (row.get("mask_path") or "").replace("\\", "/") |
| 87 | + mask_path = os.path.join(root, "masks", mask_rel) if mask_rel else None |
| 88 | + sample = Sample( |
| 89 | + image_path=img_path, |
| 90 | + label=binary, |
| 91 | + defect=row.get("defect", "good"), |
| 92 | + mask_path=mask_path, |
| 93 | + ) |
| 94 | + if row["split"] == "train": |
| 95 | + data.train.append(sample) |
| 96 | + else: |
| 97 | + data.test.append(sample) |
| 98 | + return data |
| 99 | + |
| 100 | + |
| 101 | +def load_image(path: str) -> np.ndarray: |
| 102 | + """Load an image as an (H, W, 3) uint8 BGR array (cv2-style, brain-native).""" |
| 103 | + from PIL import Image |
| 104 | + |
| 105 | + img = Image.open(path).convert("RGB") |
| 106 | + rgb = np.asarray(img, dtype=np.uint8) # (H, W, 3) RGB |
| 107 | + return rgb[:, :, ::-1].copy() # -> BGR |
| 108 | + |
| 109 | + |
| 110 | +def load_mask(path: str | None, shape: tuple[int, int]) -> np.ndarray: |
| 111 | + """Load a binary defect mask as (H, W) uint8 in {0,1}; zeros if no mask.""" |
| 112 | + if not path or not os.path.exists(path): |
| 113 | + return np.zeros(shape, dtype=np.uint8) |
| 114 | + from PIL import Image |
| 115 | + |
| 116 | + m = Image.open(path).convert("L") |
| 117 | + arr = np.asarray(m, dtype=np.uint8) |
| 118 | + return (arr > 0).astype(np.uint8) |
| 119 | + |
| 120 | + |
| 121 | +# ── Metrics (numpy-only, no sklearn dependency) ────────────────────────────── |
| 122 | + |
| 123 | + |
| 124 | +def auroc(scores: np.ndarray, labels: np.ndarray) -> float: |
| 125 | + """Area under ROC via the Mann-Whitney U statistic. |
| 126 | +
|
| 127 | + scores: higher = more anomalous. labels: 1 = anomalous, 0 = normal. |
| 128 | + Ties handled with average ranks. Returns 0.5 if a class is absent. |
| 129 | + """ |
| 130 | + scores = np.asarray(scores, dtype=np.float64) |
| 131 | + labels = np.asarray(labels, dtype=np.int64) |
| 132 | + n_pos = int((labels == 1).sum()) |
| 133 | + n_neg = int((labels == 0).sum()) |
| 134 | + if n_pos == 0 or n_neg == 0: |
| 135 | + return 0.5 |
| 136 | + order = np.argsort(scores, kind="mergesort") |
| 137 | + ranks = np.empty(len(scores), dtype=np.float64) |
| 138 | + sorted_scores = scores[order] |
| 139 | + i = 0 |
| 140 | + while i < len(scores): |
| 141 | + j = i |
| 142 | + while j + 1 < len(scores) and sorted_scores[j + 1] == sorted_scores[i]: |
| 143 | + j += 1 |
| 144 | + avg_rank = (i + j) / 2.0 + 1.0 # 1-based average rank for the tie block |
| 145 | + ranks[order[i:j + 1]] = avg_rank |
| 146 | + i = j + 1 |
| 147 | + sum_ranks_pos = ranks[labels == 1].sum() |
| 148 | + u_pos = sum_ranks_pos - n_pos * (n_pos + 1) / 2.0 |
| 149 | + return float(u_pos / (n_pos * n_neg)) |
| 150 | + |
| 151 | + |
| 152 | +def best_f1(scores: np.ndarray, labels: np.ndarray) -> dict: |
| 153 | + """Threshold-swept best F1 (ORACLE threshold on this set — optimistic). |
| 154 | +
|
| 155 | + Returns the threshold maximizing F1 plus precision/recall/F1 there. |
| 156 | + Phase 3 replaces this with a proper validation-split calibration. |
| 157 | + """ |
| 158 | + scores = np.asarray(scores, dtype=np.float64) |
| 159 | + labels = np.asarray(labels, dtype=np.int64) |
| 160 | + if len(scores) == 0: |
| 161 | + return {"threshold": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0} |
| 162 | + cands = np.unique(scores) |
| 163 | + best = {"threshold": float(cands[0]), "precision": 0.0, "recall": 0.0, "f1": -1.0} |
| 164 | + for thr in cands: |
| 165 | + pred = (scores >= thr).astype(np.int64) |
| 166 | + tp = int(((pred == 1) & (labels == 1)).sum()) |
| 167 | + fp = int(((pred == 1) & (labels == 0)).sum()) |
| 168 | + fn = int(((pred == 0) & (labels == 1)).sum()) |
| 169 | + prec = tp / (tp + fp) if (tp + fp) else 0.0 |
| 170 | + rec = tp / (tp + fn) if (tp + fn) else 0.0 |
| 171 | + f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0 |
| 172 | + if f1 > best["f1"]: |
| 173 | + best = {"threshold": float(thr), "precision": prec, "recall": rec, "f1": f1} |
| 174 | + return best |
| 175 | + |
| 176 | + |
| 177 | +def pixel_auroc(score_maps: list[np.ndarray], masks: list[np.ndarray], |
| 178 | + max_pixels: int = 2_000_000) -> float: |
| 179 | + """Pixel-level AUROC over all defect images (subsampled for tractability).""" |
| 180 | + flat_scores: list[np.ndarray] = [] |
| 181 | + flat_labels: list[np.ndarray] = [] |
| 182 | + for sm, mk in zip(score_maps, masks): |
| 183 | + if sm is None or mk is None: |
| 184 | + continue |
| 185 | + flat_scores.append(sm.ravel().astype(np.float64)) |
| 186 | + flat_labels.append(mk.ravel().astype(np.int64)) |
| 187 | + if not flat_scores: |
| 188 | + return float("nan") |
| 189 | + s = np.concatenate(flat_scores) |
| 190 | + y = np.concatenate(flat_labels) |
| 191 | + if y.sum() == 0 or y.sum() == len(y): |
| 192 | + return float("nan") |
| 193 | + if len(s) > max_pixels: |
| 194 | + rng = np.random.default_rng(0) |
| 195 | + idx = rng.choice(len(s), size=max_pixels, replace=False) |
| 196 | + s, y = s[idx], y[idx] |
| 197 | + return auroc(s, y) |
0 commit comments