@@ -536,13 +536,70 @@ def _sharded_to_all(path: str, weight: mx.array):
             return None
         return -1, segments
 
-    sharded_to_all_linear_in_place = partial(
+    _base_sharded_to_all_in_place = partial(
         shard_inplace,
         sharding=_sharded_to_all,  # type: ignore
         group=group,
         weights=shard_weights,
     )
 
+    _base_all_to_sharded_in_place = all_to_sharded_linear_in_place
+
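+    # Quantized modules pack their weights into uint32 words, so the generic
+    # in-place sharding helpers cannot split them along the packed axis
+    # without breaking quantization-group boundaries; the helper below
+    # shards those modules group-by-group instead.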
+    def _quantized_moe_shard_inplace(
+        module: nn.Module,
+        sharding: Literal["all-to-sharded", "sharded-to-all"],
+        weights: list[float] | None = None,
+    ) -> None:
+        N = group.size()
+        r = group.rank()
+        gs = module.group_size  # pyright: ignore[reportAttributeAccessIssue]
+        bits = module.bits  # pyright: ignore[reportAttributeAccessIssue]
+        params = module.parameters()
+        scales = params["scales"]
+
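+        # All-to-sharded splits the output dimension (the second-to-last
+        # axis), which is never packed, so weight, scales, and biases can all
+        # be split with the same group-aligned sizes.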
+        if sharding == "all-to-sharded":
+            dim = params["weight"].shape[max(params["weight"].ndim - 2, 0)]
+            sizes = compute_shard_sizes(dim, N, gs, weights)
+            result: dict[str, Any] = {}
+            for key, param in params.items():
+                if not isinstance(param, mx.array):
+                    result[key] = param
+                    continue
+                axis = max(param.ndim - 2, 0)
+                indices = [sum(sizes[:i]) for i in range(1, len(sizes))]
+                result[key] = mx.contiguous(mx.split(param, indices, axis=axis)[r])
+        else:
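+            # Sharded-to-all splits the packed input dimension (last axis),
+            # so split by quantization groups rather than raw columns: a
+            # group of gs values at `bits` bits packs into gs * bits // 32
+            # uint32 words. For example, with gs=64 and bits=4 each group
+            # occupies 64 * 4 // 32 = 8 weight columns, so a shard of k
+            # groups takes k columns of scales/biases and 8 * k of weight.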
+            num_groups = scales.shape[-1]
+            group_counts = compute_shard_sizes(num_groups, N, 1, weights)
+            weight_ppg = gs * bits // 32
+            result = {}
+            for key, param in params.items():
+                if not isinstance(param, mx.array):
+                    result[key] = param
+                    continue
+                if key == "weight":
+                    s = [gc * weight_ppg for gc in group_counts]
+                elif key in ("scales", "biases"):
+                    s = list(group_counts)
+                else:
+                    result[key] = param
+                    continue
+                indices = [sum(s[:i]) for i in range(1, len(s))]
+                result[key] = mx.contiguous(mx.split(param, indices, axis=-1)[r])
+        module.update(result)
+
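+    # Shadow the generic helpers: route quantized modules (group_size and
+    # bits set, with a "scales" parameter) to the quantization-aware path,
+    # and everything else to the original implementations captured above.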
+    def all_to_sharded_linear_in_place(module: nn.Module, **kwargs: Any) -> None:
+        if (
+            getattr(module, "group_size", 0) > 0
+            and getattr(module, "bits", 0) > 0
+            and "scales" in module.parameters()
+        ):
+            _quantized_moe_shard_inplace(
+                module, "all-to-sharded", weights=kwargs.get("weights")
+            )
+        else:
+            _base_all_to_sharded_in_place(module, **kwargs)
+
+    def sharded_to_all_linear_in_place(module: nn.Module, **kwargs: Any) -> None:
+        if (
+            getattr(module, "group_size", 0) > 0
+            and getattr(module, "bits", 0) > 0
+            and "scales" in module.parameters()
+        ):
+            _quantized_moe_shard_inplace(
+                module, "sharded-to-all", weights=kwargs.get("weights")
+            )
+        else:
+            _base_sharded_to_all_in_place(module, **kwargs)
+
     if isinstance(model, (LlamaModel, Ministral3Model)):
         tensor_parallel_sharding_strategy = LlamaShardingStrategy(
             group,
@@ -778,16 +835,20 @@ def shard_model(
                 layer.self_attn.k_proj.weight.shape[0] // head_dim
             )
 
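+            # Pass the quantization group size (1 when unquantized) as the
+            # shard unit so shard boundaries land on whole groups.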
+            mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
             layer.mlp.gate_proj = self.all_to_sharded_linear(
                 layer.mlp.gate_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("gate", intermediate),
             )
             layer.mlp.down_proj = self.sharded_to_all_linear(
                 layer.mlp.down_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("down", intermediate),
             )
             layer.mlp.up_proj = self.all_to_sharded_linear(
                 layer.mlp.up_proj,
+                unit=mlp_unit,
                 weights=self._greedy_weights_for("up", intermediate),
             )
             mx.eval(layer)
@@ -890,16 +951,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
             # Shard the MLP
             if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
 
@@ -1037,16 +1102,20 @@ def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
 
             if isinstance(layer.mlp, Glm4MoeLiteMLP):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
 
@@ -1516,16 +1585,20 @@ def shard_model(
             # Shard the MLP
             else:
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
 
@@ -1622,16 +1695,20 @@ def shard_model(
 
             else:
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
 
@@ -1792,16 +1869,20 @@ def shard_model(
 
             if isinstance(layer.mlp, Step35MLP):
                 intermediate = layer.mlp.gate_proj.weight.shape[0]
+                mlp_unit = getattr(layer.mlp.gate_proj, "group_size", 1)
                 layer.mlp.gate_proj = self.all_to_sharded_linear(
                     layer.mlp.gate_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("gate", intermediate),
                 )
                 layer.mlp.up_proj = self.all_to_sharded_linear(
                     layer.mlp.up_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("up", intermediate),
                 )
                 layer.mlp.down_proj = self.sharded_to_all_linear(
                     layer.mlp.down_proj,
+                    unit=mlp_unit,
                     weights=self._greedy_weights_for("down", intermediate),
                 )
             else: