Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 220 additions & 1 deletion mooncake-wheel/mooncake/structured_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Iterator, Literal, Mapping, Optional, Protocol, Sequence

import numpy as np
Expand Down Expand Up @@ -1303,3 +1303,222 @@ def _cleanup_keys(store: BundleStore, keys: Sequence[str], strict: bool) -> None
raise RuntimeError(
f"failed to remove {len(retry_errors)} Mooncake keys: {retry_errors[:3]}"
)


# ---------------------------------------------------------------------------
# Codec inference and recursive structure expansion
# ---------------------------------------------------------------------------

_INFER_MAX_STRUCT_KEYS = 64
_INFER_MAX_LIST_LEN = 8
_INFER_MAX_SAMPLE_ROWS = 128
_INFER_MAX_JSON_BYTES = 1 << 20
_INFER_MAX_DEPTH = 32

# torch is an optional dependency. When importing it fails for any reason,
# _torch is bound to None and _can_tensor rejects with "torch is not available".
try:
    import torch as _torch
except Exception:  # pragma: no cover
    _torch = None  # type: ignore[assignment]


@dataclass
class _CodecDecision:
accepted: bool
codec: str
reason: str
normalized_type: str
metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class _InferredLeaf:
path: str
values: list[Any]
decision: _CodecDecision


@dataclass
class _InferredNode:
path: str
node_type: str
children: list[Any]
lengths: list[int] | None = None


def _non_null(values: list[Any]) -> list[Any]:
return [v for v in values if v is not None]


def _is_pil_image(value: Any) -> bool:
module = getattr(value.__class__, "__module__", None) or ""
return module.startswith("PIL.") and hasattr(value, "save")


def _is_bytes_like(value: Any) -> bool:
return isinstance(value, (bytes, bytearray, memoryview))


def _is_media_list(value: Any) -> bool:
    """Report whether *value* is a non-empty list/tuple of media items.

    Every element must be either a PIL image or a bytes-like object. An empty
    sequence is rejected explicitly so the check is never vacuously true.
    """
    if not isinstance(value, (list, tuple)):
        return False
    if len(value) == 0:
        return False
    for element in value:
        if not (_is_pil_image(element) or _is_bytes_like(element)):
            return False
    return True


def _check_all(
    values: list[Any],
    predicate: Any,
    codec: str,
    normalized_type: str,
) -> _CodecDecision:
    """Build a codec decision from a per-row *predicate*.

    Args:
        values: Column values; ``None`` rows are ignored.
        predicate: Callable applied to each non-null row.
        codec: Codec name recorded in the returned decision.
        normalized_type: Type label used in the decision and its reason text.

    Returns:
        An accepted decision when at least one non-null row exists and every
        non-null row satisfies *predicate*; otherwise a rejected decision.
    """
    present = _non_null(values)
    if len(present) == 0:
        accepted, reason = False, "all rows are null"
    elif any(not predicate(row) for row in present):
        accepted, reason = False, f"not all rows are {normalized_type}"
    else:
        accepted, reason = True, f"all non-null rows are {normalized_type}"
    return _CodecDecision(accepted, codec, reason, normalized_type)


def _can_tensor(values: list[Any]) -> _CodecDecision:
    """Decide whether *values* can use the ``ragged_tensor`` codec.

    The column is accepted when torch is importable, at least one row is
    non-null, every non-null row is a ``torch.Tensor`` and all of them share
    one dtype. The shared dtype is reported under ``metadata["dtype"]``.
    """
    codec, kind = "ragged_tensor", "torch.Tensor"

    def reject(reason: str) -> _CodecDecision:
        return _CodecDecision(False, codec, reason, kind)

    if _torch is None:
        return reject("torch is not available")

    tensors = _non_null(values)
    if not tensors:
        return reject("all rows are null")
    if any(not isinstance(row, _torch.Tensor) for row in tensors):
        return reject("not all rows are Tensor")

    # Sorted so the "mixed dtypes" message is deterministic.
    dtype_names = sorted({str(row.dtype) for row in tensors})
    if len(dtype_names) != 1:
        return reject(f"mixed dtypes: {dtype_names}")

    return _CodecDecision(
        True,
        codec,
        "all non-null rows are Tensor",
        kind,
        {"dtype": dtype_names[0]},
    )
Comment thread
zxpdemonio marked this conversation as resolved.


def _can_numeric_sequence(values: list[Any]) -> _CodecDecision:
    """Decide whether *values* can use the ``typed_ragged`` codec.

    Each non-null row must be an ``np.ndarray`` or a list/tuple convertible by
    ``np.asarray`` to a numeric (non-object, ``np.number``) dtype. The common
    dtype of all rows is reported under ``metadata["dtype"]``.

    Empty list/tuple rows do not take part in dtype promotion:
    ``np.asarray([])`` always yields float64, which would otherwise turn an
    integer column into float64 just because one row is empty. If every row is
    an empty list/tuple, that placeholder dtype is still used.

    Returns:
        An accepted decision carrying the promoted dtype, or a rejected
        decision whose ``reason`` names the first failing condition.
    """
    codec, kind = "typed_ragged", "numeric sequence"

    def reject(reason: str) -> _CodecDecision:
        return _CodecDecision(False, codec, reason, kind)

    present = _non_null(values)
    if not present:
        return reject("all rows are null")

    dtypes = []
    placeholder_dtype = None
    for row in present:
        if isinstance(row, np.ndarray):
            arr = row
            # An ndarray carries an explicit dtype even when it is empty.
            dtype_is_placeholder = False
        elif isinstance(row, (list, tuple)):
            try:
                arr = np.asarray(row)
            except (ValueError, TypeError):
                return reject("row cannot be converted to ndarray")
            # The dtype of an empty Python sequence is numpy's default, not data.
            dtype_is_placeholder = arr.size == 0
        else:
            return reject("row is not array-like")
        if arr.dtype == object or not np.issubdtype(arr.dtype, np.number):
            return reject(f"non-numeric dtype: {arr.dtype}")
        if dtype_is_placeholder:
            placeholder_dtype = arr.dtype
        else:
            dtypes.append(arr.dtype)

    if not dtypes:
        # Only empty list/tuple rows were seen; keep numpy's default dtype.
        dtypes.append(placeholder_dtype)
    dtype = np.result_type(*dtypes)
    return _CodecDecision(
        True,
        codec,
        "all rows promote to common numeric dtype",
        kind,
        {"dtype": str(dtype)},
    )


def _can_numeric_scalar(values: list[Any]) -> _CodecDecision:
    """Decide whether *values* can use the ``ndarray`` codec for scalars.

    Every non-null row must be a ``bool``, ``int``, ``float`` or ``np.number``
    and ``np.result_type`` must promote them to a numeric or boolean dtype,
    which is reported under ``metadata["dtype"]``.
    """
    codec, kind = "ndarray", "numeric scalar"

    def reject(reason: str) -> _CodecDecision:
        return _CodecDecision(False, codec, reason, kind)

    scalars = _non_null(values)
    if not scalars:
        return reject("all rows are null")

    scalar_types = (bool, int, float, np.number)
    if any(not isinstance(item, scalar_types) for item in scalars):
        return reject("not all rows are numeric scalar")

    try:
        common = np.result_type(*scalars)
    except (TypeError, ValueError):
        return reject("cannot determine common dtype")

    # Guards against promotions that land outside numeric/bool dtypes.
    is_numeric = np.issubdtype(common, np.number) or np.issubdtype(common, np.bool_)
    if not is_numeric:
        return reject(f"non-numeric dtype: {common}")

    return _CodecDecision(
        True,
        codec,
        "all rows are numeric scalar",
        kind,
        {"dtype": str(common)},
    )


def _can_json(values: list[Any]) -> _CodecDecision:
    """Decide whether *values* can use the ``json_ragged`` codec.

    At most ``_INFER_MAX_SAMPLE_ROWS`` non-null rows are serialized with
    ``json.dumps``. The column is rejected when any sampled row fails to
    serialize or when the cumulative UTF-8 size of the sampled payload exceeds
    ``_INFER_MAX_JSON_BYTES``. Rows beyond the sample are not inspected.

    Returns:
        An accepted decision when every sampled row serializes within the
        byte budget, otherwise a rejected decision.
    """
    codec, kind = "json_ragged", "json"

    def reject(reason: str) -> _CodecDecision:
        return _CodecDecision(False, codec, reason, kind)

    present = _non_null(values)
    if not present:
        return reject("all rows are null")

    sampled_bytes = 0
    for row in present[:_INFER_MAX_SAMPLE_ROWS]:
        try:
            encoded = json.dumps(
                row, ensure_ascii=False, separators=(",", ":")
            ).encode("utf-8")
        except (TypeError, ValueError, RecursionError):
            # RecursionError: json.dumps recurses per nesting level, so a very
            # deeply nested row must reject this codec rather than abort
            # inference. UnicodeEncodeError is already a ValueError.
            return reject("serialization failed")
        sampled_bytes += len(encoded)
        # Checked per row so an oversized sample stops early.
        if sampled_bytes > _INFER_MAX_JSON_BYTES:
            return reject("sampled payload too large")
    return _CodecDecision(True, codec, "sampled rows pass JSON serialization", kind)


# Leaf codec probes in priority order. _choose_leaf_codec walks this tuple and
# returns the first accepted decision, so reordering entries changes which
# codec a column gets.
_CODEC_PREDICATES: tuple[Any, ...] = (
    _can_tensor,
    lambda rows: _check_all(rows, _is_media_list, "media_list_ragged", "media list"),
    _can_numeric_sequence,
    _can_numeric_scalar,
    lambda rows: _check_all(rows, _is_bytes_like, "bytes_ragged", "bytes-like"),
    lambda rows: _check_all(rows, _is_pil_image, "media_bytes", "media"),
    lambda rows: _check_all(
        rows, lambda item: isinstance(item, str), "utf8_ragged", "str"
    ),
    _can_json,
)


def _choose_leaf_codec(values: list[Any]) -> _CodecDecision:
    """Pick the first codec in ``_CODEC_PREDICATES`` that accepts *values*.

    Probes run lazily in tuple order and stop at the first accepted decision.
    When none accepts, a non-accepted ``pickle_ragged_fallback`` decision is
    returned.
    """
    verdicts = (probe(values) for probe in _CODEC_PREDICATES)
    winner = next((verdict for verdict in verdicts if verdict.accepted), None)
    if winner is not None:
        return winner
    return _CodecDecision(
        False,
        "pickle_ragged_fallback",
        "no optimized codec matched",
        "python object",
    )


def _try_expand_dict(values: list[Any]) -> list[str] | None:
    """Return the sorted union of keys when *values* is an expandable dict column.

    Expansion requires at least one non-null row, every non-null row to be a
    ``dict``, every key to be a ``str`` and no more than
    ``_INFER_MAX_STRUCT_KEYS`` distinct keys. Otherwise ``None`` is returned.
    A column of only empty dicts yields an empty key list.
    """
    rows = _non_null(values)
    if not rows:
        return None
    if any(not isinstance(row, dict) for row in rows):
        return None

    key_union = set()
    for row in rows:
        key_union.update(row.keys())

    if len(key_union) > _INFER_MAX_STRUCT_KEYS:
        return None
    if any(not isinstance(key, str) for key in key_union):
        return None
    return sorted(key_union)


def _try_expand_list(values: list[Any]) -> tuple[int, list[int]] | None:
    """Return ``(max_len, lengths)`` when *values* is an expandable list column.

    Expansion requires at least one non-null row, every non-null row to be a
    list/tuple no longer than ``_INFER_MAX_LIST_LEN``, and every item of every
    row to be a dict, list or tuple. Otherwise ``None`` is returned.

    ``lengths`` has one entry per input row; null rows are recorded as ``0``,
    so a null row and an empty list are indistinguishable in it.

    NOTE(review): when every non-null row is empty, the item check is vacuously
    true and the result is ``(0, [0, ...])``. Confirm callers expect that.
    """
    rows = _non_null(values)
    if not rows or any(not isinstance(row, (list, tuple)) for row in rows):
        return None

    longest = max(map(len, rows))
    if longest > _INFER_MAX_LIST_LEN:
        return None

    # Only lists of containers are expanded; lists of scalars stay as leaves.
    for row in rows:
        for item in row:
            if not isinstance(item, (dict, list, tuple)):
                return None

    # Every entry is now either None or a list/tuple.
    lengths = [0 if row is None else len(row) for row in values]
    return longest, lengths


def infer_structure(
    path: str,
    values: list[Any],
    leaves: list[_InferredLeaf],
    nodes: list[_InferredNode],
    *,
    _depth: int = 0,
) -> None:
    """Recursively split *values* into interior nodes and codec-tagged leaves.

    Dict expansion is tried first, then list expansion; a column that is
    neither becomes a leaf whose codec comes from ``_choose_leaf_codec``.
    Child paths are ``f"{path}.{key}"`` for dict keys and
    ``f"{path}[{index}]"`` for list positions. Rows missing a key or index
    contribute ``None`` to the child column.

    Args:
        path: Path label of the column being inspected.
        values: One entry per row.
        leaves: Output accumulator; discovered leaves are appended in place.
        nodes: Output accumulator; expanded containers are appended in place.
        _depth: Internal recursion depth counter.

    Raises:
        ValueError: If recursion goes deeper than ``_INFER_MAX_DEPTH``.

    NOTE(review): keys are joined into the path without escaping, so a key
    containing "." or "[" can yield the same path as a nested element. Confirm
    consumers of the paths tolerate this.
    """
    if _depth > _INFER_MAX_DEPTH:
        raise ValueError(f"infer_structure exceeded max depth {_INFER_MAX_DEPTH} at {path!r}")
    next_depth = _depth + 1

    struct_keys = _try_expand_dict(values)
    if struct_keys is not None:
        nodes.append(_InferredNode(path, "dict", struct_keys))
        for name in struct_keys:
            # Null rows and rows lacking the key both become None in the child.
            column = [
                row.get(name) if isinstance(row, dict) else None
                for row in values
            ]
            infer_structure(f"{path}.{name}", column, leaves, nodes, _depth=next_depth)
        return

    expanded = _try_expand_list(values)
    if expanded is None:
        # Not expandable: record the column as a leaf with its chosen codec.
        leaves.append(_InferredLeaf(path, values, _choose_leaf_codec(values)))
        return

    max_len, lengths = expanded
    nodes.append(_InferredNode(path, "list", list(range(max_len)), lengths))
    for position in range(max_len):
        # Rows shorter than position + 1 (or null) contribute None.
        column = [
            row[position]
            if isinstance(row, (list, tuple)) and position < len(row)
            else None
            for row in values
        ]
        infer_structure(f"{path}[{position}]", column, leaves, nodes, _depth=next_depth)
110 changes: 110 additions & 0 deletions mooncake-wheel/tests/test_put_get_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,115 @@ def test_zero_tensor_with_tp(self):
self.store.remove(f"{key}_tp_{rank}")


from mooncake.structured_object_store import (
_choose_leaf_codec,
infer_structure,
)


class TestCodecInference(unittest.TestCase):
    """Tests for leaf codec selection and recursive structure inference."""

    def _assert_accepted(self, values, codec):
        """Assert that *values* is accepted and mapped to *codec*."""
        decision = _choose_leaf_codec(values)
        self.assertTrue(decision.accepted)
        self.assertEqual(decision.codec, codec)

    @staticmethod
    def _infer(path, values):
        """Run infer_structure with fresh accumulators and return them."""
        leaves, nodes = [], []
        infer_structure(path, values, leaves, nodes)
        return leaves, nodes

    # -- leaf codec selection ------------------------------------------------

    def test_tensor(self):
        import torch
        self._assert_accepted([torch.tensor([1, 2]), torch.tensor([3])], "ragged_tensor")

    def test_tensor_mixed_dtype_rejected(self):
        import torch
        rows = [
            torch.tensor([1], dtype=torch.float32),
            torch.tensor([1], dtype=torch.int64),
        ]
        self.assertFalse(_choose_leaf_codec(rows).accepted)

    def test_numeric_sequence(self):
        self._assert_accepted([[1, 2, 3], [4, 5]], "typed_ragged")

    def test_bytes(self):
        self._assert_accepted([b"hello", b"world"], "bytes_ragged")

    def test_text(self):
        self._assert_accepted(["hello", "world"], "utf8_ragged")

    def test_json(self):
        self._assert_accepted([{"a": 1}, {"b": 2}], "json_ragged")

    def test_scalar(self):
        self._assert_accepted([1, 2.0, 3], "ndarray")

    def test_fallback(self):
        decision = _choose_leaf_codec([object(), object()])
        self.assertFalse(decision.accepted)
        self.assertEqual(decision.codec, "pickle_ragged_fallback")

    def test_with_nulls(self):
        self._assert_accepted(["hello", None, "world"], "utf8_ragged")

    def test_empty_values(self):
        self.assertFalse(_choose_leaf_codec([]).accepted)

    def test_all_none(self):
        self.assertFalse(_choose_leaf_codec([None, None, None]).accepted)

    # -- structure inference -------------------------------------------------

    def test_infer_flat(self):
        leaves, nodes = self._infer("root", ["a", "b", "c"])
        self.assertEqual(len(leaves), 1)
        self.assertEqual(len(nodes), 0)
        self.assertEqual(leaves[0].decision.codec, "utf8_ragged")

    def test_infer_dict(self):
        rows = [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}]
        leaves, nodes = self._infer("root", rows)
        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0].node_type, "dict")
        self.assertEqual(sorted(nodes[0].children), ["x", "y"])
        self.assertEqual(sorted(leaf.path for leaf in leaves), ["root.x", "root.y"])

    def test_infer_dict_with_none_rows(self):
        leaves, nodes = self._infer("r", [{"x": 1}, None, {"x": 3}])
        self.assertEqual(len(nodes), 1)
        self.assertEqual(len(leaves), 1)
        self.assertEqual(leaves[0].path, "r.x")

    def test_infer_nested(self):
        leaves, nodes = self._infer("r", [{"a": {"b": 1}}, {"a": {"b": 2}}])
        self.assertEqual(len(nodes), 2)
        self.assertEqual(leaves[0].path, "r.a.b")
        self.assertEqual(leaves[0].decision.codec, "ndarray")

    def test_infer_list(self):
        leaves, nodes = self._infer("r", [[{"k": 1}], [{"k": 2}]])
        by_type = {"list": [], "dict": []}
        for node in nodes:
            by_type.setdefault(node.node_type, []).append(node)
        self.assertEqual(len(by_type["list"]), 1)
        self.assertEqual(by_type["list"][0].lengths, [1, 1])
        self.assertEqual(len(by_type["dict"]), 1)
        self.assertEqual(len(leaves), 1)
        self.assertEqual(leaves[0].path, "r[0].k")

    def test_depth_limit(self):
        # 40 nested dicts is deeper than the inference depth limit.
        deep = 1
        for _ in range(40):
            deep = {"a": deep}
        with self.assertRaises(ValueError):
            infer_structure("r", [deep, deep], [], [])


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Loading