-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
220 lines (188 loc) · 9.28 KB
/
Copy pathmodel.py
File metadata and controls
220 lines (188 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Equilibrium Associative Memory (EAM)
====================================
A small, runnable seed of the "GPU-native, founded-by-physicists" architecture.
It instantiates four of the design moves from the thesis as concrete, dense,
GPU-friendly operations:
1. Equilibrium / fixed-point computation .... one weight-tied block solved to a
fixed point (a Deep Equilibrium model). Depth is implicit, memory is O(1) in
the number of solver steps when grad_tail=1 (Jacobian-Free Backprop).
2. Associative memory == attention ......... the attention term is a modern
Hopfield retrieval. softmax(QK^T / sqrt(d)) V *is* the Hopfield update rule;
the block exposes the Hopfield energy as a diagnostic.
3. Vector-symbolic representation .......... position is encoded by HRR
circular-convolution *binding* (a role vector bound into each token) rather
than additive positional encoding. This is the VSA / holographic primitive.
4. Non-autoregressive generation ........... trained as a masked denoiser;
decoded by parallel iterative unmasking (MaskGIT-style), not token-by-token.
Everything is dense matmul + FFT + softmax, i.e. exactly what a GPU wants.
This is deliberately a seed you can extend, not a frontier system. To be honest
about its scope: it is a DEQ-transformer with HRR binding and parallel denoising.
The point is that the "transformer" falls out as the frozen-feedforward,
autoregressive special case of this more general object.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def circular_bind(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Bind two vectors with circular convolution along the last axis.

    This is the Holographic Reduced Representation (HRR) binding operator,
    evaluated in the frequency domain: the real FFTs of ``a`` and ``b`` are
    multiplied element-wise and transformed back. The output length along the
    last axis is pinned to ``a.shape[-1]`` so odd lengths round-trip correctly.

    Args:
        a: Real tensor; its last axis is the one being convolved.
        b: Real tensor whose last axis is convolved with that of ``a``; the
            two spectra are multiplied element-wise, so they must broadcast.

    Returns:
        Tensor with last-axis length ``a.shape[-1]``.
    """
    width = a.shape[-1]
    # Convolution theorem: circular convolution == product of spectra.
    spectrum = torch.fft.rfft(a, dim=-1) * torch.fft.rfft(b, dim=-1)
    return torch.fft.irfft(spectrum, n=width, dim=-1)
class AssocBlock(nn.Module):
    """One weight-tied update f(z, x), meant to be iterated toward a fixed point.

    The update is: inject the input (``z + x``), add a multi-head
    self-attention residual (content-addressable retrieval, i.e. the modern
    Hopfield update), add a per-token MLP residual, then apply a final
    LayerNorm. The injection ``x`` is supplied by the caller on every call so
    that it can be held fixed while ``z`` is iterated.

    Args:
        dim: Model width; must be a positive multiple of ``n_heads``.
        n_heads: Number of attention heads.
        mlp_mult: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability applied to both residual branches.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, n_heads: int, mlp_mult: int = 4, dropout: float = 0.0):
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # after which a bad config would only surface as an opaque view() error.
        if n_heads < 1 or dim % n_heads != 0:
            raise ValueError(
                f"dim must be divisible by n_heads (got dim={dim}, n_heads={n_heads})"
            )
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # Inverse temperature (beta) of the per-head retrieval; used by the
        # energy diagnostic. It equals scaled_dot_product_attention's default.
        self.scale = self.head_dim ** -0.5
        self.ln_attn = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.ln_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_mult * dim),
            nn.GELU(),
            nn.Linear(mlp_mult * dim, dim),
        )
        # Terminal norm: bounds the block output. Without it the block is a bare
        # residual carry f(z, x) = z + g(z, x), whose only fixed point needs
        # g -> 0; nothing drives that, so z can drift and the relative residual
        # can decay like 1/t (false "convergence"). Normalising keeps the state
        # bounded; it does not by itself prove the map is a contraction.
        self.ln_out = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (B, N, dim) -> (B, n_heads, N, head_dim)."""
        B, N, _ = t.shape
        return t.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

    def _attn(self, h: torch.Tensor) -> torch.Tensor:
        """Multi-head self-attention over the pre-normalised state ``h`` (B, N, dim)."""
        B, N, D = h.shape
        q, k, v = self.qkv(self.ln_attn(h)).chunk(3, dim=-1)
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)
        out = F.scaled_dot_product_attention(q, k, v)  # modern Hopfield retrieval
        out = out.transpose(1, 2).reshape(B, N, D)  # merge heads back to (B, N, dim)
        return self.proj(out)

    def forward(self, z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Apply one update step.

        Args:
            z: Current state, same shape as ``x``.
            x: Input injection; added to ``z`` before the residual branches.

        Returns:
            The LayerNorm-ed updated state, same shape as ``z``.
        """
        h = z + x                                    # input injection (held fixed)
        h = h + self.drop(self._attn(h))             # associative retrieval term
        h = h + self.drop(self.mlp(self.ln_mlp(h)))  # local energy term
        return self.ln_out(h)                        # bound the state

    @torch.no_grad()
    def hopfield_energy(self, h: torch.Tensor) -> torch.Tensor:
        """Diagnostic: modern Hopfield energy of the retrieval term.

        For every head, E = -lse(beta * Q K^T) + 0.5 * ||Q||^2 with
        beta = head_dim ** -0.5, summed over tokens and heads and averaged over
        the batch. Queries and keys are split into heads exactly as in
        ``_attn`` so the similarities scored here are the ones the attention
        actually uses; for ``n_heads == 1`` this is the single full-width
        dot-product.

        Args:
            h: State of shape (B, N, dim).

        Returns:
            Scalar tensor (batch mean). Computed under ``no_grad``.
        """
        q, k, _ = self.qkv(self.ln_attn(h)).chunk(3, dim=-1)
        q, k = self._split_heads(q), self._split_heads(k)            # (B, H, N, hd)
        logits = torch.matmul(q, k.transpose(-1, -2)) * self.scale   # (B, H, N, N)
        lse = torch.logsumexp(logits, dim=-1)                        # (B, H, N)
        e = -lse.sum(dim=(-1, -2)) + 0.5 * (q * q).sum(dim=(-1, -2, -3))
        return e.mean()
class EquilibriumAssocMemory(nn.Module):
    """Masked-denoising sequence model whose hidden state is iterated to a fixed point.

    Tokens are embedded and given positional information (HRR binding with
    fixed random role vectors, or a learned additive table), then a single
    weight-tied ``AssocBlock`` is applied repeatedly with damping. The final
    state is normalised and projected to vocabulary logits with a head whose
    weight is tied to the embedding matrix.

    Args:
        vocab: Vocabulary size (the mask token, if used, must be one of these ids).
        dim: Model width.
        n_heads: Attention heads in the block.
        max_len: Longest supported sequence.
        solver_steps: Total number of fixed-point iterations per forward pass.
        grad_tail: How many of the final iterations carry gradient
            (1 == pure Jacobian-free backprop). Clamped to [0, solver_steps].
        damping: Mixing weight of the new block output in each iteration.
        use_binding: Encode position by circular-convolution binding instead
            of an additive learned embedding.
        mlp_mult: MLP hidden-width multiplier passed to the block.
        dropout: Dropout probability passed to the block.
    """

    def __init__(
        self,
        vocab: int,
        dim: int = 256,
        n_heads: int = 4,
        max_len: int = 64,
        solver_steps: int = 16,
        grad_tail: int = 2,
        damping: float = 0.8,
        use_binding: bool = True,
        mlp_mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.solver_steps = solver_steps
        self.grad_tail = grad_tail  # steps with grad: 1 == pure JFB (O(1) mem)
        self.damping = damping
        self.use_binding = use_binding
        self.embed = nn.Embedding(vocab, dim)
        if use_binding:
            # fixed random role vectors; position is encoded by binding, not addition
            self.register_buffer("roles", torch.randn(max_len, dim) / math.sqrt(dim))
        else:
            self.pos = nn.Parameter(torch.randn(1, max_len, dim) * 0.02)
        self.block = AssocBlock(dim, n_heads, mlp_mult, dropout)
        self.ln_out = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.embed.weight  # weight tying

    def inject(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` (B, N) and attach positional information.

        Raises:
            ValueError: If N exceeds ``max_len`` (there are no role vectors or
                positional rows beyond that length).
        """
        n = tokens.shape[1]
        if n > self.max_len:
            raise ValueError(f"sequence length {n} exceeds max_len={self.max_len}")
        e = self.embed(tokens)
        if self.use_binding:
            return circular_bind(e, self.roles[:n].unsqueeze(0))
        return e + self.pos[:, :n]

    def solve(self, x: torch.Tensor, return_diag: bool = False):
        """Damped Picard iteration of the block, starting from z = 0.

        Exactly ``solver_steps`` iterations are run. The first ones run under
        ``no_grad``; only the last ``grad_tail`` (clamped to the range
        [0, solver_steps]) carry gradient, so training memory does not grow
        with the number of solver steps.

        Args:
            x: Input injection, held fixed for the whole solve.
            return_diag: Also return per-step diagnostics.

        Returns:
            ``z``, or ``(z, diag)`` when ``return_diag`` is true. ``diag`` has
            two lists with one float per step: ``"residual"`` is
            ||z_{t+1} - z_t|| / ||z_{t+1}||, and ``"energy"`` is
            ``block.hopfield_energy`` evaluated on z_{t+1} (the state alone,
            without the injection ``x`` added, so it is a proxy).
        """
        z = torch.zeros_like(x)
        residuals, energies = [], []
        # Clamp so a grad_tail outside [0, solver_steps] cannot change the
        # total number of iterations.
        n_grad = max(0, min(self.grad_tail, self.solver_steps))
        n_nograd = max(0, self.solver_steps - n_grad)

        def step(z_prev):
            z_new = self.damping * self.block(z_prev, x) + (1 - self.damping) * z_prev
            if return_diag:
                with torch.no_grad():  # diagnostics never need a graph
                    r = (z_new - z_prev).norm() / (z_new.norm() + 1e-8)
                residuals.append(r.item())
                energies.append(self.block.hopfield_energy(z_new).item())
            return z_new

        with torch.no_grad():
            for _ in range(n_nograd):
                z = step(z)
        # z was produced under no_grad, so it is already cut from the graph.
        for _ in range(n_grad):
            z = step(z)
        if return_diag:
            return z, {"residual": residuals, "energy": energies}
        return z

    def forward(self, tokens: torch.Tensor, return_diag: bool = False):
        """Return vocabulary logits for ``tokens`` (B, N), plus diagnostics if asked."""
        x = self.inject(tokens)
        if return_diag:
            z, diag = self.solve(x, return_diag=True)
            return self.head(self.ln_out(z)), diag
        return self.head(self.ln_out(self.solve(x)))

    def _denoise_logits(self, tokens: torch.Tensor, mask_id) -> torch.Tensor:
        """Run the model and forbid predicting the mask token itself."""
        logits = self.forward(tokens)
        if logits.shape[-1] > 1:
            banned = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
            banned[mask_id] = True
            logits = logits.masked_fill(banned, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(self, tokens, mask_id, n_steps=8, temperature=1.0):
        """Non-autoregressive parallel decoding (MaskGIT-style cosine schedule).

        Starting from a sequence containing ``mask_id`` positions, each round
        predicts all masked positions in parallel, commits the most confident
        ones and leaves the rest masked for the next round. The mask token is
        never emitted as a prediction. The module is put in eval mode for the
        duration of the call and its previous train/eval mode is restored
        afterwards.

        Args:
            tokens: Integer tensor (B, N); positions equal to ``mask_id`` are filled in.
            mask_id: Id of the mask token.
            n_steps: Number of unmasking rounds.
            temperature: Softmax temperature used for the confidence scores
                (predictions themselves are the arg-max). Must be > 0.

        Returns:
            A new tensor shaped like ``tokens`` with masked positions filled;
            the input is not modified.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0 (got {temperature})")
        out = tokens.clone()
        mask = out == mask_id
        if not mask.any():
            return out  # nothing to decode; skip the forward passes entirely
        total = mask.sum(dim=1).tolist()  # initial masked count per row

        was_training = self.training
        self.eval()
        try:
            for t in range(n_steps):
                if not mask.any():
                    break
                logits = self._denoise_logits(out, mask_id)
                probs = F.softmax(logits / temperature, dim=-1)
                conf, pred = probs.max(dim=-1)
                ratio = math.cos(0.5 * math.pi * (t + 1) / n_steps)  # fraction still masked
                for b in range(out.shape[0]):
                    idx = torch.nonzero(mask[b], as_tuple=False).squeeze(1)
                    if idx.numel() == 0:
                        continue
                    order = torch.argsort(conf[b, idx], descending=True)
                    keep_masked = int(total[b] * ratio)
                    # Always commit at least one token so every round makes progress.
                    n_unmask = max(1, idx.numel() - keep_masked)
                    chosen = idx[order[:n_unmask]]
                    out[b, chosen] = pred[b, chosen]
                    mask[b, chosen] = False
            if mask.any():  # fill any leftover in one shot (e.g. n_steps == 0)
                logits = self._denoise_logits(out, mask_id)
                out[mask] = logits.argmax(-1)[mask]
        finally:
            self.train(was_training)
        return out