
Commit 7bf58b3

Don't reject uneven placements in placement
1 parent a80c32f commit 7bf58b3

5 files changed

Lines changed: 105 additions & 29 deletions

nix/mlx.nix

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ let
     owner = "rltakashige";
     repo = "mlx-jaccl-fix-small-recv";
     rev = uvLockMlxRev;
-    hash = "sha256-GosFIWxIB48Egb1MqJrR3xhsUsQeWdRk5rV93USY6wQ=";
+    hash = "sha256-e4PrRvaPdUX0rCJ9az8BuGDtsu33oTP4V3kZGLueRDA=";
   };

   patches = [

src/exo/master/placement.py

Lines changed: 6 additions & 20 deletions
@@ -128,26 +128,12 @@ def place_instance(
     if len(cycles_with_sufficient_memory) == 0:
         raise ValueError("No cycles found with sufficient memory")

-    if command.sharding == Sharding.Tensor:
-        if not command.model_card.supports_tensor:
-            raise ValueError(
-                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
-            )
-        # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
-        kv_heads = command.model_card.num_key_value_heads
-        cycles_with_sufficient_memory = [
-            cycle
-            for cycle in cycles_with_sufficient_memory
-            if command.model_card.hidden_size % len(cycle) == 0
-            and (kv_heads is None or kv_heads % len(cycle) == 0)
-        ]
-        if not cycles_with_sufficient_memory:
-            raise ValueError(
-                f"No tensor sharding found for model with "
-                f"hidden_size={command.model_card.hidden_size}"
-                f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}"
-                f" across candidate cycles"
-            )
+    if command.sharding == Sharding.Tensor and not command.model_card.supports_tensor:
+        raise ValueError(
+            f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
+        )
+
+    # Uneven tensor sharding handles arbitrary world sizes — no divisibility check needed
     if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
         "mlx-community/DeepSeek-V3.1-8bit"
     ):
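
The deleted block was the only thing rejecting "uneven" cycles: it filtered out any cycle whose length did not evenly divide hidden_size (or num_key_value_heads). With uneven tensor sharding, per-rank shard sizes may differ instead, so the divisibility filter can go. A minimal sketch of the idea, with an illustrative helper name rather than the repo's API:

    # Hypothetical sketch: split a dimension across a world size that does not
    # divide it evenly, instead of rejecting the placement outright.
    def uneven_shard_sizes(dim: int, world_size: int) -> list[int]:
        base, rem = divmod(dim, world_size)
        # the first `rem` ranks each take one extra element
        return [base + (1 if rank < rem else 0) for rank in range(world_size)]

    # hidden_size=5120 on a 3-node cycle: 5120 % 3 != 0, but the shards
    # [1707, 1707, 1706] still cover the full dimension.
    print(uneven_shard_sizes(5120, 3))  # [1707, 1707, 1706]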

src/exo/shared/types/worker/shards.py

Lines changed: 10 additions & 1 deletion
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import TypeAlias, final

-from pydantic import Field
+from pydantic import Field, field_validator

 from exo.shared.models.model_cards import ModelCard
 from exo.utils.pydantic_ext import TaggedModel
@@ -91,6 +91,15 @@ class TensorShardMetadata(BaseShardMetadata):
     shard_weights: list[float] | None = None
     shard_mode: TensorShardMode = TensorShardMode.Constant

+    @field_validator("shard_mode", mode="before")
+    @classmethod
+    def _coerce_shard_mode(cls, v: object) -> TensorShardMode:
+        if isinstance(v, str):
+            return TensorShardMode(v)
+        if isinstance(v, TensorShardMode):
+            return v
+        raise ValueError(f"expected TensorShardMode or str, got {type(v).__name__}")
+

 ShardMetadata: TypeAlias = (
     PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata
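
The mode="before" validator runs ahead of pydantic's normal parsing, so metadata deserialized with shard_mode as a plain string is coerced back into the enum. A self-contained sketch of the same pattern, assuming a plain pydantic BaseModel and illustrative enum members in place of the repo's TaggedModel base and actual TensorShardMode values:

    from enum import Enum

    from pydantic import BaseModel, field_validator

    class TensorShardMode(str, Enum):  # member values assumed for illustration
        Constant = "Constant"
        Greedy = "Greedy"

    class ShardConfig(BaseModel):  # hypothetical stand-in for TensorShardMetadata
        shard_mode: TensorShardMode = TensorShardMode.Constant

        @field_validator("shard_mode", mode="before")
        @classmethod
        def _coerce(cls, v: object) -> TensorShardMode:
            if isinstance(v, str):
                return TensorShardMode(v)
            if isinstance(v, TensorShardMode):
                return v
            raise ValueError(f"expected TensorShardMode or str, got {type(v).__name__}")

    # a payload carrying the mode as a plain string now validates cleanly:
    print(ShardConfig.model_validate({"shard_mode": "Greedy"}).shard_mode)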

src/exo/worker/engines/mlx/auto_parallel.py

Lines changed: 82 additions & 1 deletion
@@ -536,13 +536,70 @@ def _sharded_to_all(path: str, weight: mx.array):
             return None
         return -1, segments

-    sharded_to_all_linear_in_place = partial(
+    _base_sharded_to_all_in_place = partial(
         shard_inplace,
         sharding=_sharded_to_all,  # type: ignore
         group=group,
         weights=shard_weights,
     )

+    _base_all_to_sharded_in_place = all_to_sharded_linear_in_place
+
+    def _quantized_moe_shard_inplace(
+        module: nn.Module,
+        sharding: Literal["all-to-sharded", "sharded-to-all"],
+        weights: list[float] | None = None,
+    ) -> None:
+        N = group.size()
+        r = group.rank()
+        gs = module.group_size  # pyright: ignore[reportAttributeAccessIssue]
+        bits = module.bits  # pyright: ignore[reportAttributeAccessIssue]
+        params = module.parameters()
+        scales = params["scales"]
+
+        if sharding == "all-to-sharded":
+            dim = params["weight"].shape[max(params["weight"].ndim - 2, 0)]
+            sizes = compute_shard_sizes(dim, N, gs, weights)
+            result: dict[str, Any] = {}
+            for key, param in params.items():
+                if not isinstance(param, mx.array):
+                    result[key] = param
+                    continue
+                axis = max(param.ndim - 2, 0)
+                indices = [sum(sizes[:i]) for i in range(1, len(sizes))]
+                result[key] = mx.contiguous(mx.split(param, indices, axis=axis)[r])
+        else:
+            num_groups = scales.shape[-1]
+            group_counts = compute_shard_sizes(num_groups, N, 1, weights)
+            weight_ppg = gs * bits // 32
+            result = {}
+            for key, param in params.items():
+                if not isinstance(param, mx.array):
+                    result[key] = param
+                    continue
+                if key == "weight":
+                    s = [gc * weight_ppg for gc in group_counts]
+                elif key in ("scales", "biases"):
+                    s = list(group_counts)
+                else:
+                    result[key] = param
+                    continue
+                indices = [sum(s[:i]) for i in range(1, len(s))]
+                result[key] = mx.contiguous(mx.split(param, indices, axis=-1)[r])
+        module.update(result)
+
+    def all_to_sharded_linear_in_place(module: nn.Module, **kwargs: Any) -> None:
+        if getattr(module, "group_size", 0) > 0 and getattr(module, "bits", 0) > 0 and "scales" in module.parameters():
+            _quantized_moe_shard_inplace(module, "all-to-sharded", weights=kwargs.get("weights"))
+        else:
+            _base_all_to_sharded_in_place(module, **kwargs)
+
+    def sharded_to_all_linear_in_place(module: nn.Module, **kwargs: Any) -> None:
+        if getattr(module, "group_size", 0) > 0 and getattr(module, "bits", 0) > 0 and "scales" in module.parameters():
+            _quantized_moe_shard_inplace(module, "sharded-to-all", weights=kwargs.get("weights"))
+        else:
+            _base_sharded_to_all_in_place(module, **kwargs)
+
     if isinstance(model, (LlamaModel, Ministral3Model)):
         tensor_parallel_sharding_strategy = LlamaShardingStrategy(
             group,
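
The sharded-to-all branch splits along the packed input dimension, and weight_ppg = gs * bits // 32 converts a count of quantization groups into a count of packed uint32 words. A quick worked check of that arithmetic, assuming MLX's affine quantization layout (each uint32 packs 32 // bits values; each scale entry covers group_size weights):

    # Worked check: a group of `gs` values at `bits` bits each occupies
    # gs * bits / 32 packed uint32 words along the quantized axis.
    def packed_words_per_group(group_size: int, bits: int) -> int:
        return group_size * bits // 32

    # 4-bit quantization with group_size=64 packs each scale group into 8
    # uint32 words, so splitting scales as [34, 33, 33] groups implies
    # splitting the packed weight as [272, 264, 264] words, keeping the
    # weight, scales, and biases shards aligned on every rank.
    print(packed_words_per_group(64, 4))  # 8
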
@@ -778,16 +835,20 @@ def shard_model(
                     layer.self_attn.k_proj.weight.shape[0] // head_dim
                 )

+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
                 mx.eval(layer)
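
Threading unit=mlp_unit into the shard-size computation makes every shard boundary land on a whole quantization group, even when the split across ranks is uneven; for unquantized projections the getattr default of 1 imposes no constraint. A hedged sketch of the alignment rule (helper name and sizes are illustrative, not the repo's compute_shard_sizes):

    # Illustrative: shard an intermediate dimension unevenly while keeping
    # every shard size a multiple of `unit` (the quantization group size).
    def group_aligned_shard_sizes(dim: int, world_size: int, unit: int) -> list[int]:
        assert dim % unit == 0, "dim must be a whole number of quantization groups"
        base, rem = divmod(dim // unit, world_size)
        return [(base + (1 if rank < rem else 0)) * unit for rank in range(world_size)]

    # intermediate=14336 over 3 ranks with group_size=64 gives the uneven
    # sizes [4800, 4800, 4736], each a multiple of 64.
    print(group_aligned_shard_sizes(14336, 3, 64))

The same four-line change repeats below for each model family's MLP sharding path.
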
@@ -890,16 +951,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
             # Shard the MLP
             if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )

@@ -1037,16 +1102,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:

             if isinstance(layer.mlp, Glm4MoeLiteMLP):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )

@@ -1516,16 +1585,20 @@ def shard_model(
             # Shard the MLP
             else:
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )

@@ -1622,16 +1695,20 @@ def shard_model(

             else:
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )

@@ -1792,16 +1869,20 @@ def shard_model(

             if isinstance(layer.mlp, Step35MLP):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
             else:

uv.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default.
