-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathanalyze_phase_a.py
More file actions
145 lines (119 loc) · 4.67 KB
/
Copy pathanalyze_phase_a.py
File metadata and controls
145 lines (119 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""analyze_phase_a.py — qualitative + quantitative comparison of Phase A noise
sweep against cond B and cond C.
For each `condition_a_prime_<suffix>.json` in results/, compute:
- mean prediction length (chars)
- fraction of degenerate outputs (loop / single-char collapse / empty)
- exact match rate
- 5 random sample predictions for visual inspection
Compare against cond B (TTT memory, gibberish baseline) and cond C (vanilla, fluent baseline).
"""
from __future__ import annotations
import argparse
import json
import random
import re
import statistics
from pathlib import Path
def is_loop(text: str, min_chunk: int = 4, min_consecutive: int = 3,
            max_chunk: int = 20) -> bool:
    """Return True if some chunk repeats back-to-back enough times to count as a loop.

    A loop is a chunk of between ``min_chunk`` and ``max_chunk`` characters that
    appears CONSECUTIVELY ``min_consecutive`` or more times anywhere in the text.
    This is tighter than "appears at least N times", so legitimate phrasing
    repetition (e.g. 'X is 50000 元 ... X is 50000 元 ...') isn't flagged.

    Args:
        text: The string to scan.
        min_chunk: Smallest chunk length (chars) considered.
        min_consecutive: Number of back-to-back occurrences required. Values below 2
            behave like 2: a single occurrence is never a loop.
        max_chunk: Largest chunk length (chars) considered (previously fixed at 20).
            If it is smaller than ``min_chunk``, nothing is checked and the result is False.

    Returns:
        True if such a repeated chunk exists, else False.
    """
    n = len(text)
    if n < min_chunk * min_consecutive:
        return False
    # At least one repetition is needed. This also avoids dividing by zero
    # when min_consecutive is 0.
    repeats = max(min_consecutive, 2)
    # A chunk longer than n // repeats cannot fit `repeats` times.
    for size in range(min_chunk, min(max_chunk, n // repeats) + 1):
        for start in range(n - size * repeats + 1):
            chunk = text[start:start + size]
            # Compare the chunk with the next (repeats - 1) adjacent windows.
            # The start bound guarantees every window is full-length.
            if all(
                text[start + size * k:start + size * (k + 1)] == chunk
                for k in range(1, repeats)
            ):
                return True
    return False
def is_single_char_collapse(text: str, min_run: int = 6) -> bool:
    """Return True if any single character occurs ``min_run`` or more times in a row.

    The regex is matched with ``re.DOTALL`` so that '.' also matches newlines.
    A collapse into a run of blank lines is therefore caught just like a run of
    spaces or letters; previously newline runs were silently missed.

    Args:
        text: The string to scan. An empty string is never a collapse.
        min_run: Minimum run length. Values <= 1 mean any non-empty string qualifies.
    """
    if not text:
        return False
    if min_run <= 1:
        # Every non-empty string contains a run of length >= 1. Short-circuit here,
        # because "{0,}"/"{-1,}" quantifiers would build a nonsensical pattern.
        return True
    pattern = r"(.)\1{" + str(min_run - 1) + r",}"
    return re.search(pattern, text, re.DOTALL) is not None
def is_degenerate(text: str) -> bool:
    """Classify a model output as degenerate.

    The text (``None`` is treated as "") is whitespace-stripped first. It is
    degenerate if nothing is left, if it contains a single-character run (see
    ``is_single_char_collapse``), or if it contains a repeated-chunk loop (see
    ``is_loop``). The checks short-circuit in that order.
    """
    cleaned = (text or "").strip()
    return (not cleaned) or is_single_char_collapse(cleaned) or is_loop(cleaned)
def load_predictions(path: Path) -> list[dict]:
    """Load a results JSON file and flatten its per-record predictions.

    The file must hold a list of records, each carrying a "predictions" list.
    The prediction dicts of all records are concatenated in file order. A missing
    "predictions" key raises KeyError.

    The file is decoded as UTF-8 explicitly: JSON is UTF-8 by specification and
    predictions may contain non-ASCII (e.g. CJK) text. Without an explicit encoding,
    ``Path.read_text()`` would fall back to the platform locale encoding.
    """
    records = json.loads(path.read_text(encoding="utf-8"))
    return [pred for record in records for pred in record["predictions"]]
def stats(preds: list[dict]) -> dict:
    """Summarise a flat list of prediction dicts.

    Each dict is read for "predicted" (missing or None is treated as "") and "gold".

    Returns a dict with:
      n               -- number of predictions
      em              -- fraction whose whitespace-stripped prediction equals "gold"
                         exactly (gold itself is compared as-is, not stripped)
      degenerate_frac -- fraction flagged by ``is_degenerate``
      mean_pred_len   -- mean raw (unstripped) prediction length in characters
      median_pred_len -- median raw prediction length in characters

    For an empty list, ``n`` is 0 and every other value is NaN. Previously those
    keys were omitted, which made any caller that formats them raise KeyError.
    """
    n = len(preds)
    if n == 0:
        nan = float("nan")
        return {
            "n": 0,
            "em": nan,
            "degenerate_frac": nan,
            "mean_pred_len": nan,
            "median_pred_len": nan,
        }
    # Normalise each prediction once instead of once per metric.
    texts = [p.get("predicted") or "" for p in preds]
    em = sum(text.strip() == p["gold"] for text, p in zip(texts, preds)) / n
    lengths = [len(text) for text in texts]
    deg = sum(is_degenerate(text) for text in texts) / n
    return {
        "n": n,
        "em": em,
        "degenerate_frac": deg,
        "mean_pred_len": statistics.mean(lengths),
        "median_pred_len": statistics.median(lengths),
    }
def sample_dump(preds: list[dict], k: int = 5, seed: int = 0) -> list[str]:
    """Format up to ``k`` randomly chosen predictions for visual inspection.

    Selection uses a private ``random.Random(seed)``. It picks exactly the same
    items as ``random.seed(seed); random.sample(...)``, but no longer reseeds
    the module-level RNG as a side effect.

    Each entry is a three-line string with the repr of the question, the gold
    answer, and the prediction truncated to its first 120 characters. A missing
    or None "predicted" shows as ''.
    """
    rng = random.Random(seed)
    chosen = rng.sample(preds, min(k, len(preds)))
    lines = []
    for p in chosen:
        text = (p.get("predicted") or "")[:120]
        lines.append(f" Q: {p['question']!r}\n gold: {p['gold']!r}\n pred: {text!r}")
    return lines
def _scale_label(suffix: str) -> str:
    """Turn a filename suffix like '0p5' into '0.5'; fall back to the raw suffix.

    The fallback keeps an unexpected file name (e.g. '0p5_seed1') from aborting
    the whole report with ValueError.
    """
    try:
        return str(float(suffix.replace("p", ".")))
    except ValueError:
        return suffix


def _collect_targets(rd: Path) -> list[tuple[str, Path]]:
    """List (label, path) pairs to analyse in ``rd``.

    Order: the cond B and cond C baselines (if present), then the Phase A sweep
    files, then the inference-time-scaled cond B files, each group sorted by path.
    """
    targets: list[tuple[str, Path]] = []
    # Baselines
    for filename, label in (
        ("condition_b.json", "B (TTT memory)"),
        ("condition_c.json", "C (no memory, vanilla)"),
    ):
        path = rd / filename
        if path.exists():
            targets.append((label, path))
    # Phase A sweep, then inference-time TTT scaling (scaled cond B)
    for prefix, template in (
        ("condition_a_prime_", "A' (noise {})"),
        ("condition_b_scaled_", "B-scaled (α={})"),
    ):
        for f in sorted(rd.glob(f"{prefix}*.json")):
            targets.append((template.format(_scale_label(f.stem[len(prefix):])), f))
    return targets


def main():
    """CLI entry point: print per-condition stats, sample predictions, and a summary table.

    Conditions whose results file contains no predictions are reported and
    skipped. Previously they crashed the run with a KeyError on the missing metrics.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--results-dir", type=Path, default=Path("results"),
                    help="directory containing the condition_*.json result files")
    args = ap.parse_args()

    rows: list[dict] = []
    for label, path in _collect_targets(args.results_dir):
        preds = load_predictions(path)
        print(f"\n=== {label} ===")
        if not preds:
            print(" (no predictions; skipped)")
            continue
        s = stats(preds)
        s["label"] = label
        rows.append(s)
        print(f" n={s['n']} EM={s['em']:.4f} degenerate={s['degenerate_frac']:.3f} "
              f"mean_len={s['mean_pred_len']:.1f} median_len={s['median_pred_len']}")
        print(" --- 5 random samples ---")
        for line in sample_dump(preds, k=5, seed=42):
            print(line)

    print("\n\n=== SUMMARY ===")
    print(f"{'cond':<28} {'n':>5} {'EM':>7} {'deg%':>7} {'meanLen':>8}")
    for r in rows:
        print(f"{r['label']:<28} {r['n']:>5} {r['em']:>7.4f} "
              f"{r['degenerate_frac']*100:>6.1f}% {r['mean_pred_len']:>8.1f}")
# Script entry point: run the analysis only when executed directly, not when imported.
if __name__ == "__main__":
    main()