
Commit be097ba

Don't reject uneven placements in placement

1 parent a80c32f · commit be097ba

5 files changed: 47 additions & 28 deletions

nix/mlx.nix

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ let
       owner = "rltakashige";
       repo = "mlx-jaccl-fix-small-recv";
       rev = uvLockMlxRev;
-      hash = "sha256-GosFIWxIB48Egb1MqJrR3xhsUsQeWdRk5rV93USY6wQ=";
+      hash = "sha256-jMkUrQP7G6ceQ2ARlo/yTg+hM67dPNoDQl8MdbIKkgk=";
     };

   patches = [

src/exo/master/placement.py

Lines changed: 6 additions & 20 deletions
@@ -128,26 +128,12 @@ def place_instance(
     if len(cycles_with_sufficient_memory) == 0:
         raise ValueError("No cycles found with sufficient memory")

-    if command.sharding == Sharding.Tensor:
-        if not command.model_card.supports_tensor:
-            raise ValueError(
-                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
-            )
-        # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
-        kv_heads = command.model_card.num_key_value_heads
-        cycles_with_sufficient_memory = [
-            cycle
-            for cycle in cycles_with_sufficient_memory
-            if command.model_card.hidden_size % len(cycle) == 0
-            and (kv_heads is None or kv_heads % len(cycle) == 0)
-        ]
-        if not cycles_with_sufficient_memory:
-            raise ValueError(
-                f"No tensor sharding found for model with "
-                f"hidden_size={command.model_card.hidden_size}"
-                f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}"
-                f" across candidate cycles"
-            )
+    if command.sharding == Sharding.Tensor and not command.model_card.supports_tensor:
+        raise ValueError(
+            f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
+        )
+
+    # Uneven tensor sharding handles arbitrary world sizes — no divisibility check needed
     if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
         "mlx-community/DeepSeek-V3.1-8bit"
    ):
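The deleted block only accepted cycles whose length divides hidden_size (and num_key_value_heads, when set), which rejected otherwise viable placements, e.g. three hosts serving a model whose dimensions are powers of two. With uneven tensor sharding, each rank can take a different slice, so any world size covers the full dimension. A minimal sketch of the idea; the helper name and rounding policy are illustrative, not from this commit:

```python
# Hypothetical sketch: split `dim` rows across ranks proportionally to
# `weights`, rounding down to multiples of `unit`; the last rank absorbs
# the remainder, so the shards always cover the dimension exactly.
def uneven_split(dim: int, weights: list[float], unit: int = 1) -> list[int]:
    total = sum(weights)
    sizes = [int(dim * w / total) // unit * unit for w in weights]
    sizes[-1] += dim - sum(sizes)
    return sizes

# hidden_size=3584 is not divisible by a 3-node cycle, but an uneven
# split still covers every row:
print(uneven_split(3584, [1.0, 1.0, 1.0]))  # [1194, 1194, 1196]
```

In the actual code path the per-rank proportions presumably come from _greedy_weights_for in auto_parallel.py (see below), so the weights need not be uniform.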

src/exo/shared/types/worker/shards.py

Lines changed: 10 additions & 1 deletion
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import TypeAlias, final

-from pydantic import Field
+from pydantic import Field, field_validator

 from exo.shared.models.model_cards import ModelCard
 from exo.utils.pydantic_ext import TaggedModel
@@ -91,6 +91,15 @@ class TensorShardMetadata(BaseShardMetadata):
     shard_weights: list[float] | None = None
     shard_mode: TensorShardMode = TensorShardMode.Constant

+    @field_validator("shard_mode", mode="before")
+    @classmethod
+    def _coerce_shard_mode(cls, v: object) -> TensorShardMode:
+        if isinstance(v, str):
+            return TensorShardMode(v)
+        if isinstance(v, TensorShardMode):
+            return v
+        raise ValueError(f"expected TensorShardMode or str, got {type(v).__name__}")
+

 ShardMetadata: TypeAlias = (
     PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata
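The before-mode validator lets shard_mode arrive either as a TensorShardMode or as its raw string form (e.g. from deserialized JSON) while rejecting anything else. A standalone sketch of the pattern; the enum values "constant" and "greedy" are assumptions, since the diff only shows the Constant member:

```python
from enum import Enum

from pydantic import BaseModel, field_validator


class TensorShardMode(Enum):
    Constant = "constant"  # assumed value; only the member name appears in the diff
    Greedy = "greedy"      # hypothetical second member, for illustration


class ShardConfig(BaseModel):
    shard_mode: TensorShardMode = TensorShardMode.Constant

    @field_validator("shard_mode", mode="before")
    @classmethod
    def _coerce_shard_mode(cls, v: object) -> TensorShardMode:
        # Accept the raw wire format (a string) or an already-built enum.
        if isinstance(v, str):
            return TensorShardMode(v)
        if isinstance(v, TensorShardMode):
            return v
        raise ValueError(f"expected TensorShardMode or str, got {type(v).__name__}")


print(ShardConfig(shard_mode="greedy").shard_mode)  # TensorShardMode.Greedy
```

Pydantic v2 will already coerce a string matching an enum value during standard validation; an explicit before-validator pins that behavior down and gives a clearer error for unexpected input types.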

src/exo/worker/engines/mlx/auto_parallel.py

Lines changed: 24 additions & 0 deletions
@@ -778,16 +778,20 @@ def shard_model(
                 layer.self_attn.k_proj.weight.shape[0] // head_dim
             )

+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
             mx.eval(layer)
@@ -890,16 +894,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
         # Shard the MLP
         if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
@@ -1037,16 +1045,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:

         if isinstance(layer.mlp, Glm4MoeLiteMLP):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )

@@ -1516,16 +1528,20 @@ def shard_model(
         # Shard the MLP
         else:
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )

@@ -1622,16 +1638,20 @@ def shard_model(

         else:
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )

@@ -1792,16 +1812,20 @@ def shard_model(

         if isinstance(layer.mlp, Step35MLP):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
         else:
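All six MLP sharding sites get the same two-line change: read the layer's quantization group size (getattr falls back to 1 for unquantized layers, which have no group_size attribute) and pass it as the rounding unit for the uneven split, so shard boundaries land on whole quantization groups. A rough sketch of why the alignment matters; aligned_shard_sizes is a hypothetical helper, not exo's API:

```python
# Hypothetical sketch: shard sizes in multiples of `unit` so no
# quantization group is split across ranks.
def aligned_shard_sizes(rows: int, world_size: int, unit: int) -> list[int]:
    assert rows % unit == 0, "row count must be a whole number of groups"
    base = (rows // world_size) // unit * unit
    sizes = [base] * world_size
    sizes[-1] += rows - base * world_size  # last rank takes the remainder
    return sizes

# MLX quantized layers commonly use group_size=64. For an intermediate
# size of 18944 over 3 ranks, a naive 18944 // 3 = 6314 boundary would
# cut through a group; unit-aligned shards stay on group boundaries:
print(aligned_shard_sizes(18944, 3, 64))  # [6272, 6272, 6400]
```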

uv.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default.
