Commit f5a9860

disable quantization for the MTP fc projection to match FP8 model configs (#4572)
1 parent: 34a1ef6

3 files changed: 9 additions & 5 deletions


lmdeploy/pytorch/config.py

Lines changed: 4 additions & 0 deletions
@@ -389,6 +389,7 @@ def from_pretrained(
             dtype (str): user specified data type for model weights and
                 activations. Refer to `PyTorchEngineConfig` for details
             hf_overrides (dict[str, Any]): overrides for the HF config.
+            model_format (str): the quantization format of the model.
         """
         from transformers import AutoConfig

@@ -567,6 +568,7 @@ def from_config(
                     target_model: str = None,
                     dtype: str = 'auto',
                     trust_remote_code: bool = False,
+                    model_format: str = None,
                     hf_overrides: dict[str, Any] = None,
                     ):
         model = model or target_model
@@ -576,6 +578,7 @@
             is_draft_model=True,
             spec_method=method,
             block_size=target_cache_cfg.block_size,
+            model_format=model_format,
             hf_overrides=hf_overrides,
         )
         cache_config = None
@@ -590,6 +593,7 @@
                 cache_max_entry_count=target_cache_cfg.cache_max_entry_count,
                 max_prefill_token_num=target_cache_cfg.max_prefill_token_num,
                 device_type=target_cache_cfg.device_type,
+                quant_policy=target_cache_cfg.quant_policy,
                 migration_backend=target_cache_cfg.migration_backend)
         obj = cls(
             model=model,

lmdeploy/pytorch/engine/config_builder.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ def build_specdecode_config(target_model, speculative_config: SpeculativeConfig,
         target_cache_cfg=cache_config,
         dtype=engine_config.dtype,
         trust_remote_code=trust_remote_code,
+        model_format=engine_config.model_format,
         hf_overrides=engine_config.hf_overrides,
     )
     return specdecode_config
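
To make the intent of the first two files concrete: the model_format set on the engine config is now threaded through build_specdecode_config and into the draft-model config, so an FP8 target model also yields an FP8-aware draft (MTP) model config; the fourth hunk in config.py likewise forwards the target cache's quant_policy into the draft cache config. Below is a minimal, self-contained sketch of that plumbing pattern; EngineConfig, DraftModelConfig, and build_draft_config are simplified stand-ins for illustration, not the actual lmdeploy API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class EngineConfig:
    """Simplified stand-in for the engine-level config."""
    dtype: str = 'auto'
    model_format: Optional[str] = None  # e.g. 'fp8'; None means unquantized weights


@dataclass
class DraftModelConfig:
    """Simplified stand-in for the config built for the speculative (MTP) draft model."""
    model: str
    dtype: str
    model_format: Optional[str] = None


def build_draft_config(model: str, engine_config: EngineConfig) -> DraftModelConfig:
    # Forward the quantization format so the draft model is described with the
    # same weight format (e.g. FP8) as the target model.
    return DraftModelConfig(model=model,
                            dtype=engine_config.dtype,
                            model_format=engine_config.model_format)


if __name__ == '__main__':
    cfg = build_draft_config('Qwen/Qwen3.5-27B-FP8', EngineConfig(model_format='fp8'))
    print(cfg)

Under this sketch, the draft-model loader can branch on cfg.model_format the same way the target loader does, which mirrors how the real change keeps the two checkpoints' weight formats in agreement.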

lmdeploy/pytorch/models/qwen3_5_mtp.py

Lines changed: 4 additions & 5 deletions
@@ -94,24 +94,23 @@ def __init__(
             for idx in range(self.num_mtp_layers)
         })

-        quantization_config = getattr(config, 'quantization_config', None)
-
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)
         self.pre_fc_norm_hidden = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)
         self.pre_fc_norm_embedding = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)

         # shared with target model
         self.embed_tokens = None
-
+        # do not quant fc as in https://huggingface.co/Qwen/Qwen3.5-27B-FP8/blob/main/config.json#L403
+        # and https://huggingface.co/Qwen/Qwen3.5-35B-A3B-FP8/blob/main/config.json#L409
         self.fc = build_colwise_linear(
             config.hidden_size * 2,
             config.hidden_size,
             bias=False,
             dtype=dtype,
             device=device,
             is_tp=False,
-            quant_config=quantization_config,
             dp_disable_tp=True,
+            prefix=add_prefix('fc', prefix=prefix),
         )

         # build rotary embedding
@@ -200,7 +199,7 @@ def __init__(self,
         self.model = Qwen3_5MultiTokenPredictor(config.text_config,
                                                 dtype=dtype,
                                                 device=device,
-                                                prefix=add_prefix('model', prefix=prefix))
+                                                prefix=add_prefix('mtp', prefix=prefix))

         self.num_experts = getattr(config.text_config, 'num_experts', None)
         self.enable_sci_mtp = getattr(config, 'enable_sci_mtp', False)
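
The qwen3_5_mtp.py change means the fc projection, which fuses the previous hidden state with the current token embedding (hence the 2 * hidden_size input), is no longer wrapped with the model's quantization_config: as the linked config.json entries note, the released Qwen3.5 FP8 checkpoints exclude this module from quantization, so it is now built as a plain high-precision linear. A rough PyTorch sketch of the resulting layer follows; build_mtp_fc is a hypothetical helper for illustration, not lmdeploy's build_colwise_linear.

import torch
from torch import nn


def build_mtp_fc(hidden_size: int, dtype: torch.dtype = torch.bfloat16) -> nn.Linear:
    # The MTP fc projection maps the concatenated [hidden_states, embeddings]
    # vector (2 * hidden_size) back to hidden_size. No quantization config is
    # attached, so the layer stays in the high-precision compute dtype even
    # when the surrounding model weights are FP8.
    return nn.Linear(2 * hidden_size, hidden_size, bias=False, dtype=dtype)


if __name__ == '__main__':
    fc = build_mtp_fc(hidden_size=4096)
    print(fc.weight.dtype)  # torch.bfloat16

The prefix change from 'model' to 'mtp' in the second hunk presumably aligns the predictor's parameter names with the weight naming used by those checkpoints.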
