@@ -389,6 +389,7 @@ def from_pretrained(
389389 dtype (str): user specified data type for model weights and
390390 activations. Refer to `PyTorchEngineConfig` for details
391391 hf_overrides (dict[str, Any]): overrides for the HF config.
392+ model_format (str): the quantization format of the model.
392393 """
393394 from transformers import AutoConfig
394395
@@ -567,6 +568,7 @@ def from_config(
567568 target_model : str = None ,
568569 dtype : str = 'auto' ,
569570 trust_remote_code : bool = False ,
571+ model_format : str = None ,
570572 hf_overrides : dict [str , Any ] = None ,
571573 ):
572574 model = model or target_model
@@ -576,6 +578,7 @@ def from_config(
576578 is_draft_model = True ,
577579 spec_method = method ,
578580 block_size = target_cache_cfg .block_size ,
581+ model_format = model_format ,
579582 hf_overrides = hf_overrides ,
580583 )
581584 cache_config = None
@@ -590,6 +593,7 @@ def from_config(
590593 cache_max_entry_count = target_cache_cfg .cache_max_entry_count ,
591594 max_prefill_token_num = target_cache_cfg .max_prefill_token_num ,
592595 device_type = target_cache_cfg .device_type ,
596+ quant_policy = target_cache_cfg .quant_policy ,
593597 migration_backend = target_cache_cfg .migration_backend )
594598 obj = cls (
595599 model = model ,
0 commit comments