|
| 1 | +"""Cubby blocks: HybridBlock (MoE + attention + memory + GLU) |
| 2 | +and MinGRUBlock (pure MinGRU + GLU).""" |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +import torch.nn.functional as F |
| 10 | + |
| 11 | +from nn_primitives import RMSNorm |
| 12 | +from layers import MinGRULayer, GLUChannelMix |
| 13 | +from moe import MoEMinGRULayer |
| 14 | +from attention import LocalCausalAttention |
| 15 | +from episodic_memory import EpisodicMemory |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from train_torch import TrainConfig |
| 19 | + |
class HybridBlock(nn.Module):
    """MinGRU (or MoE-MinGRU) + optional local attention + optional
    hippocampal memory injection + GLU.

    Architecture per block::

        x = x + Mixer(RMSNorm(x))        # MinGRU or MoE-MinGRU
        x = x + Attn(RMSNorm(x))         # local attention (if enabled)
        x = x + MemGate(hippo.read(x))   # memory injection (if enabled)
        x = x + GLU(RMSNorm(x))          # channel mixing

    When ``cfg.enable_residual_scale`` is truthy, each of the four branch
    outputs is multiplied by its own learned scalar (initialised to 1.0)
    before it is added to the residual stream.

    Memory injection: the block's mean hidden state queries the episodic
    memory; retrieved memories are projected through a learned gate and
    broadcast to all positions. This doesn't affect gradients on the
    memory itself — only the gate is trainable.
    """

    def __init__(self, cfg: "TrainConfig", layer_idx: int,
                 memory: "EpisodicMemory | None" = None):
        """Build the block for layer ``layer_idx`` from ``cfg``.

        Args:
            cfg: Training configuration. Required fields are read as
                attributes; optional ones via ``getattr`` with defaults.
            layer_idx: Index of this block; selects whether attention
                (``layer_idx % cfg.attn_every_n == 0``) and memory
                (``layer_idx % cfg.mem_every_n == 0``) are active.
            memory: Shared episodic memory, or ``None`` to disable
                memory injection for this block.

        Raises:
            ValueError: If group routing is enabled and
                ``cfg.grouping_strategy`` or ``cfg.group_routing_variant``
                holds an unrecognised value.
        """
        super().__init__()
        d = cfg.d_model
        self.layer_idx = layer_idx

        # Sequence mixer. The norm is created and registered before the
        # mixer so submodule order (and construction order) is unchanged.
        self.rms_mix = RMSNorm(d)
        mixer = self._build_mixer(cfg, d)

        # v4 Group-Routing: wrap the per-token mixer with a GroupedMoEBlock.
        # Tokens are pooled into G < S groups, the inner MoE runs per-group
        # (the inner MoE doesn't know "S" became "G" — it just sees a shorter
        # sequence dim), then outputs are scattered back to per-token.
        # Trade-off: syntactic coherence + cheaper routing for loss of
        # per-token routing freedom. Opt-in via cfg.enable_group_routing;
        # default off so existing configs are unaffected until switched on.
        if getattr(cfg, "enable_group_routing", False) and cfg.enable_moe:
            mixer, variant = self._wrap_group_routing(cfg, d, mixer)
            self.uses_group_routing = True
        else:
            variant = None
            self.uses_group_routing = False
        self.mix = mixer
        self.group_routing_variant = variant

        # Local attention (on selected layers).
        # Always register placeholders (nn.Identity) so the module's
        # _modules dict has a stable shape across all block instances.
        # Dynamo guards on attribute location — mixing __dict__-stored
        # `None` and _modules-stored Module triggers per-layer recompiles
        # and blows the cache_size_limit.
        self.has_attn = bool(cfg.enable_attention
                             and layer_idx % cfg.attn_every_n == 0)
        if self.has_attn:
            self.rms_attn = RMSNorm(d)
            self.attn = LocalCausalAttention(
                d, cfg.attn_n_heads, cfg.attn_window,
                grad_checkpoint=getattr(cfg, "attn_grad_checkpoint", False),
            )
        else:
            self.rms_attn = nn.Identity()
            self.attn = nn.Identity()

        # Hippocampal memory injection (on selected layers). Same dict-
        # location stability concern as attention above.
        self.has_memory = bool(memory is not None and cfg.enable_memory
                               and layer_idx % cfg.mem_every_n == 0)
        if self.has_memory:
            self.memory = memory
            self.mem_gate = nn.Linear(d, d, bias=False)
            nn.init.zeros_(self.mem_gate.weight)  # start as no-op
        else:
            self.memory = None  # external state, not a Module — kept in __dict__
            self.mem_gate = nn.Identity()

        # FFN
        self.rms_ffn = RMSNorm(d)
        self.ffn = GLUChannelMix(d, cfg.d_ffn)

        # Learned residual scaling (ZAYA1-8B / OpenMythos 2026). One scalar
        # per residual addition, initialised to 1.0, so at step 0 the block
        # output matches the un-scaled baseline. Per-stream gates let the
        # optimiser dampen norm growth through depth without touching layer
        # weights. Total overhead per HybridBlock: 4 params.
        self.enable_residual_scale = bool(
            getattr(cfg, "enable_residual_scale", False))
        if self.enable_residual_scale:
            self.alpha_mix = nn.Parameter(torch.ones(1))
            self.alpha_attn = nn.Parameter(torch.ones(1))
            self.alpha_mem = nn.Parameter(torch.ones(1))
            self.alpha_ffn = nn.Parameter(torch.ones(1))

    @staticmethod
    def _build_mixer(cfg: "TrainConfig", d: int) -> nn.Module:
        """Return the sequence mixer: MoE-MinGRU if ``cfg.enable_moe`` else MinGRU."""
        if cfg.enable_moe:
            return MoEMinGRULayer(
                d, cfg.moe_n_experts, cfg.moe_top_k,
                enable_hypergrad=cfg.enable_hypergrad,
                hypergrad_scale_init=cfg.hypergrad_scale_init,
                gate_noise_std=getattr(cfg, "moe_gate_noise_std", 0.0),
                gate_init_std=getattr(cfg, "moe_gate_init_std", 0.0),
                decay_bias_stagger=getattr(cfg, "moe_decay_bias_stagger", False),
                decay_bias_lo=getattr(cfg, "moe_decay_bias_lo", -1.0),
                decay_bias_hi=getattr(cfg, "moe_decay_bias_hi", 2.0),
                enable_hebbian_growth=getattr(cfg, "enable_hebbian_growth", False),
                hebbian_n_components=getattr(cfg, "hebbian_n_components", 8),
                hebbian_max_components=getattr(cfg, "hebbian_max_components", 64),
                hebbian_grow_threshold=getattr(cfg, "hebbian_grow_threshold", 0.35),
                hebbian_lr=getattr(cfg, "hebbian_lr", 5e-4),
                hebbian_lateral_beta=getattr(cfg, "hebbian_lateral_beta", 0.05),
                hebbian_grow_cooldown=getattr(cfg, "hebbian_grow_cooldown", 100),
                hebbian_ema=getattr(cfg, "hebbian_ema", 0.01),
                max_experts=getattr(cfg, "moe_max_experts", 32),
                n_shared=getattr(cfg, "moe_shared_experts", 0),
            )
        return MinGRULayer(
            d,
            enable_hypergrad=cfg.enable_hypergrad,
            hypergrad_scale_init=cfg.hypergrad_scale_init,
        )

    @staticmethod
    def _wrap_group_routing(cfg: "TrainConfig", d: int,
                            inner: nn.Module) -> "tuple[nn.Module, str]":
        """Wrap ``inner`` in a grouped-MoE block; return ``(block, variant)``.

        Raises:
            ValueError: On an unknown ``cfg.grouping_strategy`` (checked
                first) or ``cfg.group_routing_variant``.
        """
        # Imported lazily so group_routing is only needed when the feature
        # is switched on.
        from group_routing import (
            FixedSizeGrouping, LearnedGrouping, HebbianGrouping,
            SupervisedSVCGrouping, GroupedMoEBlock, GroupedMoEBlockBias,
        )
        strategy = getattr(cfg, "grouping_strategy", "fixed_size")
        if strategy == "fixed_size":
            grouping_fn = FixedSizeGrouping(
                group_size=getattr(cfg, "group_size", 4))
        elif strategy == "learned":
            grouping_fn = LearnedGrouping(
                d_model=d,
                n_groups=getattr(cfg, "n_groups", 16),
                temperature=getattr(cfg, "group_temperature", 1.0))
        elif strategy == "hebbian":
            # Caller threads the basis per forward via grouping_fn.set_basis(W).
            grouping_fn = HebbianGrouping(
                sig_dim=getattr(cfg, "hebbian_n_components", 8))
        elif strategy == "svc":
            grouping_fn = SupervisedSVCGrouping(
                fallback_group_size=getattr(cfg, "group_size", 4))
        else:
            raise ValueError(
                f"unknown cfg.grouping_strategy={strategy!r}; expected one of "
                f"'fixed_size', 'learned', 'hebbian', 'svc'")

        # Dispatch on the scatter-back variant. "bias" adds a per-token
        # D→D linear projection to the per-group MoE output so each
        # token in a group gets per-position variation in the residual
        # stream — fixes the autoregressive echo-collapse observed
        # with pure replication. "replicated" is legacy/ablation only.
        variant = getattr(cfg, "group_routing_variant", "bias")
        if variant == "bias":
            block_cls = GroupedMoEBlockBias
        elif variant == "replicated":
            block_cls = GroupedMoEBlock
        else:
            raise ValueError(
                f"unknown cfg.group_routing_variant={variant!r}; expected "
                f"'bias' (recommended) or 'replicated' (legacy)")
        block = block_cls(d_model=d, inner_moe=inner, grouping_fn=grouping_fn)
        return block, variant

    def _memory_readout(self, x: torch.Tensor) -> "torch.Tensor | None":
        """Return the gated memory injection for ``x``, or ``None`` to skip.

        Skipped when memory is disabled for this block, absent, or empty
        (``memory.size == 0``). Otherwise the result has two leading
        singleton dims so it broadcasts over the first two dims of ``x``.
        """
        if not (self.has_memory and self.memory is not None
                and self.memory.size > 0):
            return None
        x_mean = x.mean(dim=1)
        # NOTE(review): only index 0 along the first dim is used as the
        # query, and the result is added to every element of x. Presumably
        # x is (batch, seq, d_model), in which case a batch > 1 shares
        # sample 0's retrieval — confirm this is intended.
        retrieved = self.memory.read(x_mean[0])
        return self.mem_gate(retrieved).unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor, surprise_gain: float = 0.0) -> torch.Tensor:
        """Apply mixer, optional attention, optional memory, and GLU residuals.

        Args:
            x: Input hidden states; reduced with ``mean(dim=1)`` for the
                memory query, so at least 2-D (presumably
                ``(batch, seq, d_model)`` — TODO confirm).
            surprise_gain: Forwarded to the sequence mixer unchanged.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # With scaling off the branch output is added as-is (no multiply by
        # a constant), so the unscaled path performs exactly the same ops
        # as a block without residual scaling.
        scaled = self.enable_residual_scale

        branch = self.mix(self.rms_mix(x), surprise_gain=surprise_gain)
        x = x + (self.alpha_mix * branch if scaled else branch)

        if self.has_attn:
            branch = self.attn(self.rms_attn(x))
            x = x + (self.alpha_attn * branch if scaled else branch)

        mem_inject = self._memory_readout(x)
        if mem_inject is not None:
            x = x + (self.alpha_mem * mem_inject if scaled else mem_inject)

        branch = self.ffn(self.rms_ffn(x))
        x = x + (self.alpha_ffn * branch if scaled else branch)
        return x
| 198 | + |
| 199 | + |
class MinGRUBlock(nn.Module):
    """Pure-MinGRU residual block: MinGRU mixer, then GLU channel mix.

    No MoE routing and no attention. Each sub-layer is applied to an
    RMS-normalised copy of the stream and added back residually.
    """

    def __init__(self, cfg: TrainConfig):
        """Create the two norms, the MinGRU mixer and the GLU feed-forward."""
        super().__init__()
        width = cfg.d_model
        self.rms_mix = RMSNorm(width)
        self.rms_ffn = RMSNorm(width)
        self.mix = MinGRULayer(
            width,
            enable_hypergrad=cfg.enable_hypergrad,
            hypergrad_scale_init=cfg.hypergrad_scale_init,
        )
        self.ffn = GLUChannelMix(width, cfg.d_ffn)

    def forward(self, x: torch.Tensor, surprise_gain: float = 0.0) -> torch.Tensor:
        """Return ``x`` after the mixer and feed-forward residual updates.

        ``surprise_gain`` is passed through to the mixer unchanged.
        """
        mixed = self.mix(self.rms_mix(x), surprise_gain=surprise_gain)
        hidden = x + mixed
        return hidden + self.ffn(self.rms_ffn(hidden))
| 217 | + |
| 218 | + |
0 commit comments