Skip to content

Commit edf14e0

Browse files
split files
1 parent b117ac3 commit edf14e0

11 files changed

Lines changed: 3883 additions & 1736 deletions

File tree

model/cubby/attention.py

Lines changed: 126 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,126 @@
1+
"""Local sliding-window causal self-attention."""
2+
from __future__ import annotations
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
8+
class LocalCausalAttention(nn.Module):
    """Sliding-window causal self-attention.

    Each position ``p`` attends to the ``window`` most recent positions
    ``[p - W + 1 .. p]`` (fewer near the start of the sequence).

    When the sequence length ``S`` is at most ``window`` a single causal
    SDPA call covers everything.  Otherwise Q is processed in chunks of
    ``W`` positions; each chunk attends to its own positions plus the
    ``W - 1`` positions before it, i.e. at most ``2W - 1`` keys, so every
    query sees a full window regardless of where it lands in its chunk.
    The per-chunk score block is therefore ``B * H * W * (2W - 1)``
    instead of ``B * H * S**2`` and no ``S x S`` mask is ever built.  At
    S=32768, W=512 that per-call block is roughly 2000x smaller than the
    full score matrix; the number of chunks, ``ceil(S / W)``, is the
    linear-in-S factor.

    ``grad_checkpoint=True`` wraps the layer in
    ``torch.utils.checkpoint.checkpoint`` while training, trading a
    recomputed forward pass during backprop for lower activation memory.

    Every chunk goes through ``F.scaled_dot_product_attention``, which
    picks a fused kernel when one is applicable.
    NOTE(review): the FlashAttention backend does not accept an explicit
    ``attn_mask``, so the masked (non-first) chunks presumably run on a
    different backend (e.g. memory-efficient or math) — confirm on the
    target hardware if kernel choice matters.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window: Attention window ``W`` in tokens; must be >= 1.
        grad_checkpoint: Enable activation checkpointing in training mode.

    Raises:
        ValueError: If ``n_heads`` does not evenly divide ``d_model`` or
            ``window`` is not a positive integer.
    """

    def __init__(self, d_model: int, n_heads: int = 4, window: int = 128,
                 grad_checkpoint: bool = False):
        super().__init__()
        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if n_heads < 1 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})")
        # A negative window would make the chunk loop empty and hand back an
        # uninitialised buffer; zero would only fail deep inside forward().
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.window = window
        self.grad_checkpoint = bool(grad_checkpoint)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply windowed attention to ``x`` of shape ``(B, S, d_model)``."""
        if self.training and self.grad_checkpoint:
            # Non-reentrant checkpointing is the variant recommended by the
            # PyTorch documentation.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(self._forward_impl, x, use_reentrant=False)
        return self._forward_impl(x)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Project to Q/K/V, run (chunked) attention, project back."""
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)                     # each (B, S, H, Dh)
        q = q.transpose(1, 2).contiguous()              # (B, H, S, Dh)
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        if S <= self.window:
            # The window covers the whole sequence: plain causal attention.
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            out = self._chunked_sliding_window(q, k, v)

        out = out.transpose(1, 2).contiguous().reshape(B, S, D)
        return self.out_proj(out)

    def _chunked_sliding_window(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    ) -> torch.Tensor:
        """Sliding-window attention via Q chunking.

        ``q``, ``k``, ``v`` are ``(B, H, S, Dh)``.  Q is split into chunks
        of ``W`` positions.  The chunk starting at ``c_start`` reads keys
        and values from ``[max(0, c_start - W + 1) : c_end]``, the union
        of every chunk position's window.

        * First chunk (``c_start == 0``): the lower window bound is never
          active because every row index is below ``W``, so plain causal
          attention is exact.
        * Later chunks: ``c_start`` is a positive multiple of ``W``, so
          the key slice always starts ``W - 1`` positions before the
          chunk.  Chunk-local row ``i`` may see key column ``col`` iff
          ``i <= col <= i + W - 1``.  That band is identical for every
          such chunk, so it is built once; a shorter final chunk uses the
          top-left corner of the same band.

        Returns:
            Tensor shaped like ``q`` holding the attention output.
        """
        W = self.window
        S = q.shape[2]
        out = torch.empty_like(q)

        first_end = min(W, S)
        out[:, :, :first_end] = F.scaled_dot_product_attention(
            q[:, :, :first_end], k[:, :, :first_end], v[:, :, :first_end],
            is_causal=True)
        if first_end == S:
            return out

        # Boolean band mask (True = may attend), shared by all later chunks.
        row = torch.arange(W, device=q.device).unsqueeze(1)          # (W, 1)
        col = torch.arange(2 * W - 1, device=q.device).unsqueeze(0)  # (1, 2W-1)
        band = (col >= row) & (col <= row + W - 1)

        for c_start in range(W, S, W):
            c_end = min(c_start + W, S)
            c_len = c_end - c_start
            k_start = c_start - W + 1       # earliest key any row in the chunk needs
            kv_len = c_end - k_start        # == c_len + W - 1
            if c_len == W:
                mask = band
            else:
                # Short tail chunk: same band, fewer rows and columns.
                mask = band[:c_len, :kv_len].contiguous()
            out[:, :, c_start:c_end] = F.scaled_dot_product_attention(
                q[:, :, c_start:c_end],
                k[:, :, k_start:c_end],
                v[:, :, k_start:c_end],
                attn_mask=mask)
        return out
125+
126+

model/cubby/blocks.py

Lines changed: 218 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,218 @@
1+
"""Cubby blocks: HybridBlock (MoE + attention + memory + GLU)
2+
and MinGRUBlock (pure MinGRU + GLU)."""
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import torch
8+
import torch.nn as nn
9+
import torch.nn.functional as F
10+
11+
from nn_primitives import RMSNorm
12+
from layers import MinGRULayer, GLUChannelMix
13+
from moe import MoEMinGRULayer
14+
from attention import LocalCausalAttention
15+
from episodic_memory import EpisodicMemory
16+
17+
if TYPE_CHECKING:
18+
from train_torch import TrainConfig
19+
20+
class HybridBlock(nn.Module):
    """Residual block: sequence mixer, optional local attention, optional
    episodic-memory injection, then a GLU channel mix.

    Data flow per block::

        x = x + Mixer(RMSNorm(x))        # MinGRU, MoE-MinGRU or grouped MoE
        x = x + Attn(RMSNorm(x))         # only when has_attn
        x = x + MemGate(memory.read(q))  # only when has_memory and memory.size > 0
        x = x + GLU(RMSNorm(x))

    Memory injection: ``q`` is the sequence-mean hidden state of the
    *first batch element* (``x.mean(dim=1)[0]``).  The retrieved vector is
    passed through ``mem_gate`` (zero-initialised, so it starts as a
    no-op) and broadcast over every batch element and position.
    NOTE(review): using only batch element 0 as the query looks
    intentional for batch size 1 — confirm it is what is wanted for
    larger batches.

    With ``cfg.enable_residual_scale`` each residual branch is multiplied
    by its own learned scalar (``alpha_mix``, ``alpha_attn``,
    ``alpha_mem``, ``alpha_ffn``), all initialised to 1.0.
    """

    def __init__(self, cfg: "TrainConfig", layer_idx: int,
                 memory: "EpisodicMemory | None" = None):
        super().__init__()
        width = cfg.d_model
        self.layer_idx = layer_idx

        # --- Sequence mixer -------------------------------------------------
        self.rms_mix = RMSNorm(width)
        mixer = self._make_mixer(cfg, width)

        # Group routing is opt-in and only meaningful on top of the MoE
        # mixer; when the flag is absent the block is built as before.
        routed = bool(getattr(cfg, "enable_group_routing", False)
                      and cfg.enable_moe)
        variant = None
        if routed:
            mixer, variant = self._wrap_group_routing(cfg, width, mixer)
        self.mix = mixer
        self.uses_group_routing = routed
        self.group_routing_variant = variant

        # --- Local attention (selected layers) -------------------------------
        # nn.Identity placeholders keep the set of registered submodules the
        # same on every block; the original author notes that mixing a plain
        # `None` attribute with a registered Module causes per-layer Dynamo
        # recompiles.
        self.has_attn = bool(cfg.enable_attention
                             and layer_idx % cfg.attn_every_n == 0)
        if not self.has_attn:
            self.rms_attn = nn.Identity()
            self.attn = nn.Identity()
        else:
            self.rms_attn = RMSNorm(width)
            self.attn = LocalCausalAttention(
                width, cfg.attn_n_heads, cfg.attn_window,
                grad_checkpoint=getattr(cfg, "attn_grad_checkpoint", False),
            )

        # --- Episodic memory injection (selected layers) ----------------------
        self.has_memory = bool(memory is not None and cfg.enable_memory
                               and layer_idx % cfg.mem_every_n == 0)
        if not self.has_memory:
            self.memory = None
            self.mem_gate = nn.Identity()
        else:
            self.memory = memory
            self.mem_gate = nn.Linear(width, width, bias=False)
            nn.init.zeros_(self.mem_gate.weight)  # injection starts as a no-op

        # --- Channel mixing ----------------------------------------------------
        self.rms_ffn = RMSNorm(width)
        self.ffn = GLUChannelMix(width, cfg.d_ffn)

        # --- Learned residual scaling ------------------------------------------
        # One scalar per residual stream, initialised to 1.0.  All four are
        # created whenever the feature is on, including for streams this
        # particular block does not use.
        self.enable_residual_scale = bool(
            getattr(cfg, "enable_residual_scale", False))
        if self.enable_residual_scale:
            for stream in ("mix", "attn", "mem", "ffn"):
                setattr(self, f"alpha_{stream}", nn.Parameter(torch.ones(1)))

    @staticmethod
    def _make_mixer(cfg, width):
        """Build the plain MinGRU mixer or, with ``cfg.enable_moe``, the MoE one."""
        if not cfg.enable_moe:
            return MinGRULayer(
                width,
                enable_hypergrad=cfg.enable_hypergrad,
                hypergrad_scale_init=cfg.hypergrad_scale_init,
            )

        # (constructor keyword, cfg attribute, fallback when cfg lacks it)
        optional_settings = (
            ("gate_noise_std", "moe_gate_noise_std", 0.0),
            ("gate_init_std", "moe_gate_init_std", 0.0),
            ("decay_bias_stagger", "moe_decay_bias_stagger", False),
            ("decay_bias_lo", "moe_decay_bias_lo", -1.0),
            ("decay_bias_hi", "moe_decay_bias_hi", 2.0),
            ("enable_hebbian_growth", "enable_hebbian_growth", False),
            ("hebbian_n_components", "hebbian_n_components", 8),
            ("hebbian_max_components", "hebbian_max_components", 64),
            ("hebbian_grow_threshold", "hebbian_grow_threshold", 0.35),
            ("hebbian_lr", "hebbian_lr", 5e-4),
            ("hebbian_lateral_beta", "hebbian_lateral_beta", 0.05),
            ("hebbian_grow_cooldown", "hebbian_grow_cooldown", 100),
            ("hebbian_ema", "hebbian_ema", 0.01),
            ("max_experts", "moe_max_experts", 32),
            ("n_shared", "moe_shared_experts", 0),
        )
        extras = {
            keyword: getattr(cfg, attr, fallback)
            for keyword, attr, fallback in optional_settings
        }
        return MoEMinGRULayer(
            width, cfg.moe_n_experts, cfg.moe_top_k,
            enable_hypergrad=cfg.enable_hypergrad,
            hypergrad_scale_init=cfg.hypergrad_scale_init,
            **extras,
        )

    @staticmethod
    def _wrap_group_routing(cfg, width, inner):
        """Wrap ``inner`` in a grouped-MoE block; return ``(block, variant)``.

        The grouping strategy comes from ``cfg.grouping_strategy`` (default
        ``"fixed_size"``) and the scatter-back variant from
        ``cfg.group_routing_variant`` (default ``"bias"``).  Unknown values
        of either raise ``ValueError``.
        """
        from group_routing import (
            FixedSizeGrouping, LearnedGrouping, HebbianGrouping,
            SupervisedSVCGrouping, GroupedMoEBlock, GroupedMoEBlockBias,
        )

        strategy = getattr(cfg, "grouping_strategy", "fixed_size")
        if strategy == "fixed_size":
            grouper = FixedSizeGrouping(
                group_size=getattr(cfg, "group_size", 4))
        elif strategy == "learned":
            grouper = LearnedGrouping(
                d_model=width,
                n_groups=getattr(cfg, "n_groups", 16),
                temperature=getattr(cfg, "group_temperature", 1.0))
        elif strategy == "hebbian":
            # Per the original comment, the caller supplies the basis on each
            # forward via grouping_fn.set_basis(W).
            grouper = HebbianGrouping(
                sig_dim=getattr(cfg, "hebbian_n_components", 8))
        elif strategy == "svc":
            grouper = SupervisedSVCGrouping(
                fallback_group_size=getattr(cfg, "group_size", 4))
        else:
            raise ValueError(
                f"unknown cfg.grouping_strategy={strategy!r}; expected one of "
                f"'fixed_size', 'learned', 'hebbian', 'svc'")

        # "bias" is the recommended variant, "replicated" the legacy one
        # (wording taken from the error message below).
        variant = getattr(cfg, "group_routing_variant", "bias")
        if variant == "replicated":
            block_cls = GroupedMoEBlock
        elif variant == "bias":
            block_cls = GroupedMoEBlockBias
        else:
            raise ValueError(
                f"unknown cfg.group_routing_variant={variant!r}; expected "
                f"'bias' (recommended) or 'replicated' (legacy)")

        wrapped = block_cls(
            d_model=width, inner_moe=inner, grouping_fn=grouper)
        return wrapped, variant

    def forward(self, x: torch.Tensor, surprise_gain: float = 0.0) -> torch.Tensor:
        """Run the block on ``x``; ``surprise_gain`` is forwarded to the mixer."""
        scaled = self.enable_residual_scale

        branch = self.mix(self.rms_mix(x), surprise_gain=surprise_gain)
        x = x + (self.alpha_mix * branch if scaled else branch)

        if self.has_attn:
            branch = self.attn(self.rms_attn(x))
            x = x + (self.alpha_attn * branch if scaled else branch)

        if self.has_memory and self.memory is not None and self.memory.size > 0:
            # Query with the sequence-mean of batch element 0 only.
            query = x.mean(dim=1)[0]
            recalled = self.mem_gate(self.memory.read(query))
            branch = recalled.unsqueeze(0).unsqueeze(0)  # broadcast over batch and positions
            x = x + (self.alpha_mem * branch if scaled else branch)

        branch = self.ffn(self.rms_ffn(x))
        x = x + (self.alpha_ffn * branch if scaled else branch)
        return x
198+
199+
200+
class MinGRUBlock(nn.Module):
    """Pure-MinGRU residual block (no MoE, no attention).

    Two pre-norm residual branches: a ``MinGRULayer`` sequence mixer
    followed by a ``GLUChannelMix`` feed-forward.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        width = cfg.d_model
        # Both norms are created before the mixer and FFN, matching the
        # original registration order.
        self.rms_mix = RMSNorm(width)
        self.rms_ffn = RMSNorm(width)
        self.mix = MinGRULayer(
            width,
            enable_hypergrad=cfg.enable_hypergrad,
            hypergrad_scale_init=cfg.hypergrad_scale_init,
        )
        self.ffn = GLUChannelMix(width, cfg.d_ffn)

    def forward(self, x: torch.Tensor, surprise_gain: float = 0.0) -> torch.Tensor:
        """Apply mixer then FFN; ``surprise_gain`` is forwarded to the mixer."""
        after_mix = x + self.mix(self.rms_mix(x), surprise_gain=surprise_gain)
        return after_mix + self.ffn(self.rms_ffn(after_mix))
217+
218+

0 commit comments

Comments
 (0)