Helios Stage2 Pyramid Latent Downsampling: Variance Distribution Mismatch for VAE Latents #116

Description

@Yvonne-OH

Hi, thanks for open-sourcing Helios.

While adapting the Stage2 Pyramid Unified Predictor Corrector pipeline for robot-control video generation, I encountered a confusing issue regarding latent downsampling statistics.

I would like to ask whether the current implementation is intentionally designed for Gaussian-noise transport only, or whether VAE latent conditions are also expected to follow the same variance scaling assumptions.


Background

In Helios Stage2, the latent pyramid uses:

latents = F.interpolate(
    latents,
    size=(height, width),
    mode="bilinear",
    align_corners=False,
) * factor

This seems to assume:

std_after_downsample ≈ 1 / factor

so multiplying by factor restores the original variance.
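To spell out where that assumption would come from, here is a short derivation for the idealized case. It assumes the downsample is an exact mean over a block of factor × factor entries, and that those entries are i.i.d. with standard deviation σ. Both are my assumptions for illustration, not something the Helios code states. Writing f for the factor:

```latex
\bar{x} = \frac{1}{f^{2}} \sum_{i=1}^{f^{2}} x_i,
\qquad
\operatorname{Var}(\bar{x})
  = \frac{1}{f^{4}} \sum_{i=1}^{f^{2}} \operatorname{Var}(x_i)
  = \frac{\sigma^{2}}{f^{2}},
\qquad
\operatorname{std}(f \cdot \bar{x}) = f \cdot \frac{\sigma}{f} = \sigma .
```

If the entries are spatially correlated, or the resampling kernel is not a uniform block mean, the middle equality no longer holds and multiplying by f can over-compensate.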


My Confusion

This assumption appears to hold for:

F.avg_pool2d(...)

but not for:

F.interpolate(..., mode="bilinear")

especially when sequential pyramid downsampling is used.

For example, my actual pyramid is:

/3 -> /2

rather than a single /6.


Reproduction

Bilinear Sequential Downsampling

import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(1, 16, 256, 256)

def downsample(t, factor):
    h = t.shape[-2] // factor
    w = t.shape[-1] // factor

    return F.interpolate(
        t,
        size=(h, w),
        mode="bilinear",
        align_corners=True,
    )

x_3 = downsample(x, 3)
x_3_2 = downsample(x_3, 2)

print("orig std:", x.std().item())
print("/3 std:", x_3.std().item())
print("/3 -> /2 std:", x_3_2.std().item())
print("(/3 -> /2) * 6 std:", (x_3_2 * 6).std().item())

Results:

orig std:              1.000
/3 std:                0.666
/3 -> /2 std:          0.422
(/3 -> /2) * 6 std:    2.53

So:

std shrinkage is NOT 1/factor

for bilinear interpolation, at least with align_corners=True and size-based resizing as used in the script above.
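One caveat about my own reproduction: it passes align_corners=True, while the Helios snippet in Background passes align_corners=False. The sketch below is a minimal check of how much that flag matters for a single /2 step. It assumes an even spatial size and the default antialias=False, and `bilinear_half` is a hypothetical helper name, not something from the Helios code base.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 16, 256, 256)  # unit-variance noise, even spatial size


def bilinear_half(t, align_corners):
    """Bilinear /2 resize to an explicit output size (hypothetical helper)."""
    h, w = t.shape[-2] // 2, t.shape[-1] // 2
    return F.interpolate(
        t, size=(h, w), mode="bilinear", align_corners=align_corners
    )


half_false = bilinear_half(x, align_corners=False)
half_true = bilinear_half(x, align_corners=True)
pooled = F.avg_pool2d(x, kernel_size=2, stride=2)

# With align_corners=False and an exact 2x ratio, each output sample sits at
# the centre of a 2x2 input block, so the four bilinear weights are all 0.25.
same_as_pool = torch.allclose(half_false, pooled, atol=1e-5)

print("align_corners=False equals avg_pool2d:", same_as_pool)
print("align_corners=False std:", half_false.std().item())
print("align_corners=True  std:", half_true.std().item())
```

If this behaves as I expect, the /2 case with align_corners=False is numerically a 2x2 average pool, so `* 2` would restore unit variance for i.i.d. noise, while align_corners=True shrinks the std less. That would not carry over to /3 or to non-integer ratios, where the sample positions no longer fall on block centres.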


AvgPool Behavior

However, with strict average pooling:

x_3 = F.avg_pool2d(x, kernel_size=3, stride=3)
x_3_2 = F.avg_pool2d(x_3, kernel_size=2, stride=2)

I obtain:

/3 std ≈ 0.333
/3 -> /2 std ≈ 0.167

which DOES match:

std ≈ 1 / total_factor

and therefore:

x * factor

correctly restores variance.
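For completeness, a self-contained version of the average-pooling check, using the same seed and tensor shape as the bilinear reproduction. Nothing here is Helios-specific. Note that 256 is not divisible by 3, so `avg_pool2d` drops the last row and column.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 16, 256, 256)

# Strict block means: every output is the average of kernel_size**2 inputs.
x_3 = F.avg_pool2d(x, kernel_size=3, stride=3)      # 256 -> 85
x_3_2 = F.avg_pool2d(x_3, kernel_size=2, stride=2)  # 85 -> 42

print("orig std:", x.std().item())
print("/3 std:", x_3.std().item())
print("/3 -> /2 std:", x_3_2.std().item())
print("(/3 -> /2) * 6 std:", (x_3_2 * 6).std().item())
```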


Visualization Script

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(0)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

B, C, H, W = 1, 16, 256, 256

x = torch.randn(B, C, H, W, device=DEVICE)

orig_std = x.std().item()
orig_mean = x.mean().item()

print("=" * 80)
print(f"Original mean = {orig_mean:.6f}")
print(f"Original std  = {orig_std:.6f}")
print("=" * 80)

# ============================================================
# Sequential bilinear downsampling:
# /3 then /2
# ============================================================

def bilinear_downsample(x, factor):
    h = x.shape[-2] // factor
    w = x.shape[-1] // factor

    return F.interpolate(
        x,
        size=(h, w),
        mode="bilinear",
        align_corners=True,
    )

x_3 = bilinear_downsample(x, 3)
x_3_2 = bilinear_downsample(x_3, 2)

# ============================================================
# One-shot /6
# ============================================================

x_6 = bilinear_downsample(x, 6)

# ============================================================
# Compensation
# ============================================================

x_3_scaled = x_3 * 3
x_3_2_scaled = x_3_2 * 6
x_6_scaled = x_6 * 6

# ============================================================
# Std matching
# ============================================================

x_3_2_std_matched = (
    (x_3_2 - x_3_2.mean())
    / (x_3_2.std() + 1e-6)
) * orig_std + orig_mean

# ============================================================
# Print statistics
# ============================================================

def print_stats(name, t):
    print(
        f"{name:30s} | "
        f"shape={tuple(t.shape)} | "
        f"mean={t.mean().item():+.6f} | "
        f"std={t.std().item():.6f}"
    )

print_stats("Original", x)
print_stats("Bilinear /3", x_3)
print_stats("Bilinear /3 then /2", x_3_2)
print_stats("One-shot Bilinear /6", x_6)

print("-" * 80)

print_stats("(/3) * 3", x_3_scaled)
print_stats("(/3 then /2) * 6", x_3_2_scaled)
print_stats("(/6) * 6", x_6_scaled)
print_stats("(/3 then /2) std matched", x_3_2_std_matched)

# ============================================================
# Std curve
# ============================================================

names = [
    "Original",
    "/3",
    "/3->/2",
    "/6",
    "(/3)*3",
    "(/3->/2)*6",
    "std matched",
]

stds = [
    x.std().item(),
    x_3.std().item(),
    x_3_2.std().item(),
    x_6.std().item(),
    x_3_scaled.std().item(),
    x_3_2_scaled.std().item(),
    x_3_2_std_matched.std().item(),
]

plt.figure(figsize=(12, 6))

plt.bar(names, stds)

plt.axhline(
    y=orig_std,
    linestyle="--",
    linewidth=2,
    label="Original Std",
)

plt.ylabel("Standard Deviation")
plt.title("Std After Bilinear Downsampling")
plt.grid(axis="y")
plt.legend()

plt.tight_layout()
plt.show()

# ============================================================
# Distribution visualization
# ============================================================

plt.figure(figsize=(12, 7))

for name, t in [
    ("Original", x),
    ("/3", x_3),
    ("/3->/2", x_3_2),
    ("/6", x_6),
    ("std matched", x_3_2_std_matched),
]:
    plt.hist(
        t.flatten().float().cpu().numpy(),
        bins=120,
        density=True,
        alpha=0.35,
        label=f"{name}, std={t.std().item():.3f}",
    )

plt.xlabel("Value")
plt.ylabel("Density")
plt.title("Distribution After Bilinear Downsampling")
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.show()

Question

Is the current Helios Stage2 variance compensation:

* factor

intended specifically for:

pure Gaussian noise transport

rather than generic VAE latent feature maps?

I ask because, when I apply the same logic to VAE control latents, the variance is severely amplified after sequential bilinear downsampling.


Observation

For VAE latents, dynamic variance matching seems much more stable:

x = x * (target_std / current_std)

instead of:

x = x * factor

especially under multi-stage bilinear pyramid resizing.
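A minimal sketch of what I mean by dynamic variance matching. `match_std` and its `eps` argument are hypothetical names of my own, not part of Helios, and using a single global std over the whole tensor is an assumption. Per-channel or per-frame statistics might be more appropriate for real VAE latents.

```python
import torch
import torch.nn.functional as F


def match_std(latents, target_std, eps=1e-6):
    """Rescale latents so their global std equals target_std (hypothetical helper)."""
    current_std = latents.std()
    return latents * (target_std / (current_std + eps))


torch.manual_seed(0)
x = torch.randn(1, 16, 256, 256)
target_std = x.std()

# Same /3 then /2 bilinear pyramid as in the reproduction above.
x_3 = F.interpolate(x, size=(85, 85), mode="bilinear", align_corners=True)
x_3_2 = F.interpolate(x_3, size=(42, 42), mode="bilinear", align_corners=True)

fixed_scale = x_3_2 * 6                  # fixed compensation by total factor
matched = match_std(x_3_2, target_std)   # measured compensation

print("fixed  * 6 std:", fixed_scale.std().item())
print("std-matched std:", matched.std().item())
```

This only rescales. Unlike the std-matched variant in the visualization script, it does not re-centre the mean, which matches the one-line formula above.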


Additional Context

My setup uses:

  • VAE latent control videos
  • sequential pyramid downsampling (/3 then /2)
  • bilinear interpolation
  • Stage2-style multi-scale latent transport

The issue manifests as:

  • mosaic artifacts
  • checkerboard textures
  • unstable control conditioning

which appear correlated with latent variance explosion.

Thanks!
