Hi, thanks for open-sourcing Helios.
While adapting the Stage2 Pyramid Unified Predictor Corrector pipeline for robot-control video generation, I encountered a confusing issue regarding latent downsampling statistics.
I would like to ask whether the current implementation is intentionally designed for Gaussian-noise transport only, or whether VAE latent conditions are also expected to follow the same variance scaling assumptions.
Background
In Helios Stage2, the latent pyramid uses:
```python
latents = F.interpolate(
    latents,
    size=(height, width),
    mode="bilinear",
    align_corners=False,
) * factor
```
This seems to assume that `std_after_downsample ≈ 1 / factor`, so multiplying by `factor` restores the original variance.
My Confusion
This assumption appears to hold for strict average pooling (`F.avg_pool2d` with `kernel_size = stride = factor`), but not for `F.interpolate(..., mode="bilinear")`, especially when sequential pyramid downsampling is used.
For example, my actual pyramid is /3 followed by /2, rather than a single /6.
Reproduction
Bilinear Sequential Downsampling
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 16, 256, 256)

def downsample(t, factor):
    h = t.shape[-2] // factor
    w = t.shape[-1] // factor
    # Note: align_corners=True here, whereas the Stage2 snippet uses align_corners=False.
    return F.interpolate(
        t,
        size=(h, w),
        mode="bilinear",
        align_corners=True,
    )

x_3 = downsample(x, 3)
x_3_2 = downsample(x_3, 2)

print("orig std:", x.std().item())
print("/3 std:", x_3.std().item())
print("/3 -> /2 std:", x_3_2.std().item())
print("(/3 -> /2) * 6 std:", (x_3_2 * 6).std().item())
```
Results:
```text
orig std: 1.000
/3 std: 0.666
/3 -> /2 std: 0.422
(/3 -> /2) * 6 std: 2.53
```
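So for bilinear interpolation, the std shrinkage is NOT 1/factor: after /3 the std is 0.666 rather than 0.333, and after /3 -> /2 it is 0.422 rather than 0.167, so multiplying by 6 overshoots to 2.53.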
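A possible explanation for the ~0.666 figure, assuming i.i.d. input such as the Gaussian noise above: each bilinear output is a weighted sum of neighboring inputs, so its variance is the sum of its squared weights. Along one axis the weights are (1 - f, f) for a fractional offset f, giving (1 - f)² + f², which averages to 2/3 when f is spread roughly evenly over [0, 1). In 2-D that is a variance of about (2/3)², i.e. std ≈ 2/3, largely independent of the factor; whether the offsets really are spread evenly depends on the exact sizes and on `align_corners`. The sketch below checks this without sampling by reading the interpolation weights straight out of `F.interpolate`. `interp_matrix` and `predicted_std` are hypothetical helper names, and the prediction is only valid for i.i.d. input, not for spatially correlated VAE latents.

```python
import torch
import torch.nn.functional as F


def interp_matrix(n_in, n_out, align_corners):
    """1-D linear-interpolation matrix W (n_out x n_in) such that y = W @ x.

    Built by pushing one-hot signals through F.interpolate, so it uses
    PyTorch's own source-index and weight computation.
    """
    eye = torch.eye(n_in, dtype=torch.float64).unsqueeze(1)  # (n_in, 1, n_in)
    out = F.interpolate(eye, size=n_out, mode="linear", align_corners=align_corners)
    return out[:, 0, :].T  # (n_out, n_in)


def predicted_std(size, factors, align_corners):
    """Expected std after sequential bilinear downsampling of i.i.d. N(0, 1) data.

    `factors` is applied in order, e.g. (3, 2) for /3 then /2, using the same
    `size // factor` rule as the reproduction script. Assumes a square input.
    """
    n = size
    W = torch.eye(size, dtype=torch.float64)
    for f in factors:
        m = n // f
        W = interp_matrix(n, m, align_corners) @ W  # compose the resize steps
        n = m
    # For i.i.d. unit-variance input, each output's variance is its sum of squared weights.
    row_var = (W ** 2).sum(dim=1)
    # Bilinear applies the same 1-D weights along H and W, so the 2-D variance at (i, j)
    # is row_var[i] * row_var[j]; averaged over a square map that is mean(row_var) ** 2,
    # hence the predicted std is mean(row_var).
    return row_var.mean().item()


for ac in (True, False):
    s3 = predicted_std(256, (3,), ac)
    s32 = predicted_std(256, (3, 2), ac)
    print(f"align_corners={ac}: /3 std ~ {s3:.3f}, /3 -> /2 std ~ {s32:.3f}, "
          f"(/3 -> /2) * 6 std ~ {6 * s32:.3f}")
```

Comparing both `align_corners` settings is worthwhile because the Stage2 snippet uses `align_corners=False` while the reproduction uses `True`. With sizes that divide evenly and `align_corners=False`, an odd factor such as /3 samples exactly on pixel centers and does no averaging at all, so the std after each step depends strongly on the exact latent sizes.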
AvgPool Behavior
However, with strict average pooling:
```python
x_3 = F.avg_pool2d(x, kernel_size=3, stride=3)
x_3_2 = F.avg_pool2d(x_3, kernel_size=2, stride=2)
```
I obtain:
```text
/3 std ≈ 0.333
/3 -> /2 std ≈ 0.167
```
which DOES match `std ≈ 1 / factor` (1/3 and 1/6), so multiplying by `factor` correctly restores the variance.
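The 1/factor rule for average pooling follows directly from averaging independent samples. A short derivation for a k×k pooling window, assuming the window entries are i.i.d. with variance σ² (true for the Gaussian noise used here, not necessarily for VAE latents); the /3 outputs average disjoint blocks, so they are again i.i.d. before the /2 step:

```latex
\begin{aligned}
y &= \frac{1}{k^2}\sum_{i=1}^{k^2} x_i \\
\operatorname{Var}(y) &= \frac{1}{k^4}\sum_{i=1}^{k^2}\sum_{j=1}^{k^2}\operatorname{Cov}(x_i, x_j) \\
&= \frac{k^2\sigma^2}{k^4} = \frac{\sigma^2}{k^2} \qquad (\text{i.i.d. } x_i) \\
\operatorname{std}(y) &= \frac{\sigma}{k} \\
\operatorname{std}(y_{/3 \to /2}) &= \frac{1}{2}\cdot\frac{\sigma}{3} = \frac{\sigma}{6}
\end{aligned}
```

With σ = 1 this gives the 0.333 and 0.167 above. If neighboring entries are positively correlated (e.g. smooth regions of a latent), the off-diagonal covariance terms are positive, so Var(y) > σ²/k² and multiplying by k overshoots the original std even for average pooling.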
Visualization Script
```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(0)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

B, C, H, W = 1, 16, 256, 256
x = torch.randn(B, C, H, W, device=DEVICE)

orig_std = x.std().item()
orig_mean = x.mean().item()

print("=" * 80)
print(f"Original mean = {orig_mean:.6f}")
print(f"Original std = {orig_std:.6f}")
print("=" * 80)

# ============================================================
# Sequential bilinear downsampling:
# /3 then /2
# ============================================================
def bilinear_downsample(x, factor):
    h = x.shape[-2] // factor
    w = x.shape[-1] // factor
    return F.interpolate(
        x,
        size=(h, w),
        mode="bilinear",
        align_corners=True,
    )

x_3 = bilinear_downsample(x, 3)
x_3_2 = bilinear_downsample(x_3, 2)

# ============================================================
# One-shot /6
# ============================================================
x_6 = bilinear_downsample(x, 6)

# ============================================================
# Compensation
# ============================================================
x_3_scaled = x_3 * 3
x_3_2_scaled = x_3_2 * 6
x_6_scaled = x_6 * 6

# ============================================================
# Std matching
# ============================================================
x_3_2_std_matched = (
    (x_3_2 - x_3_2.mean())
    / (x_3_2.std() + 1e-6)
) * orig_std + orig_mean

# ============================================================
# Print statistics
# ============================================================
def print_stats(name, t):
    print(
        f"{name:30s} | "
        f"shape={tuple(t.shape)} | "
        f"mean={t.mean().item():+.6f} | "
        f"std={t.std().item():.6f}"
    )

print_stats("Original", x)
print_stats("Bilinear /3", x_3)
print_stats("Bilinear /3 then /2", x_3_2)
print_stats("One-shot Bilinear /6", x_6)
print("-" * 80)
print_stats("(/3) * 3", x_3_scaled)
print_stats("(/3 then /2) * 6", x_3_2_scaled)
print_stats("(/6) * 6", x_6_scaled)
print_stats("(/3 then /2) std matched", x_3_2_std_matched)

# ============================================================
# Std curve
# ============================================================
names = [
    "Original",
    "/3",
    "/3->/2",
    "/6",
    "(/3)*3",
    "(/3->/2)*6",
    "std matched",
]
stds = [
    x.std().item(),
    x_3.std().item(),
    x_3_2.std().item(),
    x_6.std().item(),
    x_3_scaled.std().item(),
    x_3_2_scaled.std().item(),
    x_3_2_std_matched.std().item(),
]

plt.figure(figsize=(12, 6))
plt.bar(names, stds)
plt.axhline(
    y=orig_std,
    linestyle="--",
    linewidth=2,
    label="Original Std",
)
plt.ylabel("Standard Deviation")
plt.title("Std After Bilinear Downsampling")
plt.grid(axis="y")
plt.legend()
plt.tight_layout()
plt.show()

# ============================================================
# Distribution visualization
# ============================================================
plt.figure(figsize=(12, 7))
for name, t in [
    ("Original", x),
    ("/3", x_3),
    ("/3->/2", x_3_2),
    ("/6", x_6),
    ("std matched", x_3_2_std_matched),
]:
    plt.hist(
        t.flatten().float().cpu().numpy(),
        bins=120,
        density=True,
        alpha=0.35,
        label=f"{name}, std={t.std().item():.3f}",
    )
plt.xlabel("Value")
plt.ylabel("Density")
plt.title("Distribution After Bilinear Downsampling")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
```
Question
Is the current Helios Stage2 variance compensation (`* factor`) intended specifically for pure Gaussian noise transport, rather than for generic VAE latent feature maps?
When I apply the same logic to VAE control latents, the variance becomes severely amplified after sequential bilinear downsampling.
Observation
For VAE latents, dynamic variance matching seems much more stable:
```python
x = x * (target_std / current_std)
```
than the fixed `x = x * factor` compensation, especially under multi-stage bilinear pyramid resizing.
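A minimal sketch of that dynamic matching, applied per channel. The assumptions here are mine, not from Helios: statistics are computed per (batch, channel) over H and W, `target_std` is measured on the latent before downsampling, the input is an i.i.d. stand-in rather than a real (spatially correlated) VAE latent, and `downsample_bilinear` / `match_channel_std` are hypothetical helper names. It contrasts the fixed `* 6` compensation with the matched version after /3 then /2:

```python
import torch
import torch.nn.functional as F


def downsample_bilinear(x, factor, align_corners=False):
    """Bilinear /factor resize using the same size // factor rule as the scripts above."""
    h, w = x.shape[-2] // factor, x.shape[-1] // factor
    return F.interpolate(x, size=(h, w), mode="bilinear", align_corners=align_corners)


def match_channel_std(x, target_std, eps=1e-6):
    """x * (target_std / current_std), with std measured per channel over H and W.

    Like the one-line formula, this rescales around zero, so a non-zero
    channel mean is scaled as well.
    """
    current_std = x.std(dim=(-2, -1), keepdim=True)
    return x * (target_std / (current_std + eps))


torch.manual_seed(0)
# i.i.d. stand-in for a latent: 4 channels with deliberately different scales.
scales = torch.tensor([0.5, 1.0, 2.0, 3.0]).view(1, 4, 1, 1)
latent = torch.randn(1, 4, 96, 96) * scales
target_std = latent.std(dim=(-2, -1), keepdim=True)

low = downsample_bilinear(downsample_bilinear(latent, 3), 2)  # /3 then /2
fixed = low * 6                                               # fixed compensation
matched = match_channel_std(low, target_std)                  # dynamic compensation

for name, t in [("fixed * 6", fixed), ("matched", matched)]:
    ratio = t.std(dim=(-2, -1)) / target_std.view(1, 4)
    print(name, [round(r, 3) for r in ratio.flatten().tolist()])
```

Matching per channel rather than over the whole tensor keeps channels with different scales from being pulled toward a single global std; whether that is the right granularity for Helios control latents is an open question.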
Additional Context
My setup uses:
- VAE latent control videos
- sequential pyramid downsampling (/3 then /2)
- bilinear interpolation
- Stage2-style multi-scale latent transport
The issue manifests as:
- mosaic artifacts
- checkerboard textures
- unstable control conditioning
These artifacts appear to be correlated with the latent variance explosion.
Thanks!