@@ -778,16 +778,20 @@ def shard_model(
                 layer.self_attn.k_proj.weight.shape[0] // head_dim
             )
 
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
             mx.eval(layer)
@@ -890,16 +894,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
         # Shard the MLP
         if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
 
@@ -1037,16 +1045,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
 
         if isinstance(layer.mlp, Glm4MoeLiteMLP):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
 
@@ -1516,16 +1528,20 @@ def shard_model(
         # Shard the MLP
         else:
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
 
@@ -1622,16 +1638,20 @@ def shard_model(
 
         else:
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
 
@@ -1792,16 +1812,20 @@ def shard_model(
 
         if isinstance(layer.mlp, Step35MLP):
             intermediate = layer.mlp.gate_proj.weight.shape[0]
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
         else:
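Every hunk applies the same pattern: read the quantization `group_size` off the (possibly quantized) `gate_proj` and pass it to the sharding helpers as `unit`, with unquantized layers falling back to `unit=1`. The apparent intent is that shard boundaries land on whole quantization groups, so each rank's slice of the weight stays aligned with its quantization scales. A minimal sketch of that constraint follows; `split_rows` is an illustrative helper, not the actual implementation of `all_to_sharded_linear` or `sharded_to_all_linear` in this PR:

```python
def split_rows(n_rows: int, n_shards: int, unit: int = 1) -> list[int]:
    """Split n_rows into contiguous shard sizes that are all multiples
    of `unit` (e.g. a QuantizedLinear group_size), so each shard's
    quantization scales stay aligned with its weight rows."""
    if n_rows % unit != 0:
        raise ValueError(f"n_rows={n_rows} not divisible by unit={unit}")
    groups = n_rows // unit          # distribute whole groups, not raw rows
    base, rem = divmod(groups, n_shards)
    return [(base + (i < rem)) * unit for i in range(n_shards)]


# With unit=1 (the unquantized fallback) this is an ordinary row split;
# with unit=64 every shard size is a multiple of the group size.
print(split_rows(18432, 5, unit=1))    # [3687, 3687, 3686, 3686, 3686]
print(split_rows(18432, 5, unit=64))   # [3712, 3712, 3712, 3648, 3648]
```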