@@ -8,9 +8,14 @@
 from typing import List
 
 import torch
+import transformers
 from huggingface_hub import ModelCard, get_token, whoami
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
+_transformers_version = str(transformers.__version__)
+if _transformers_version >= "5":
+    from transformers.quantizers.auto import get_hf_quantizer
+
 from torchao._models._eval import TransformerEvalWrapper
 from torchao.prototype.awq import (
     AWQConfig,
@@ -28,6 +33,9 @@
     PerRow,
     quantize_,
 )
+from torchao.quantization.quant_api import _is_linear
+
+safe_serialization = _transformers_version >= "5"
 
 
 def _get_username():
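Note on the version gate above: `_transformers_version >= "5"` compares strings lexicographically, which happens to order 4.x before 5.x but is not a general version comparison (pre-releases such as "5.0.0.dev0" also pass it). A more robust sketch, assuming the `packaging` library (already a transformers dependency):

    import transformers
    from packaging import version

    # Parses dev/rc suffixes correctly instead of comparing raw strings.
    IS_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0")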
@@ -43,7 +51,10 @@ def _untie_weights_and_save_locally(model_id):
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    from transformers.modeling_utils import find_tied_parameters
+    if _transformers_version >= "5":
+        from accelerate.utils.modeling import find_tied_parameters
+    else:
+        from transformers.modeling_utils import find_tied_parameters
 
     if getattr(
         untied_model.config.get_text_config(decoder=True), "tie_word_embeddings"
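On v5 the script pulls `find_tied_parameters` from accelerate rather than `transformers.modeling_utils`; in both homes it returns groups of parameter names that share one underlying tensor. A self-contained sketch of the return shape (the `Tiny` module is hypothetical):

    import torch.nn as nn
    from accelerate.utils.modeling import find_tied_parameters

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.lm_head = nn.Linear(4, 10, bias=False)
            self.lm_head.weight = self.embed.weight  # tie the two weights

    print(find_tied_parameters(Tiny()))
    # e.g. [['embed.weight', 'lm_head.weight']]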
@@ -117,7 +128,7 @@ def _untie_weights_and_save_locally(model_id):
 USER_ID = "YOUR_USER_ID"
 MODEL_NAME = model_id.split("/")[-1]
 save_to = f"{{USER_ID}}/{{MODEL_NAME}}-{quant}"
-quantized_model.push_to_hub(save_to, safe_serialization=False)
+quantized_model.push_to_hub(save_to, safe_serialization={safe_serialization})
 tokenizer.push_to_hub(save_to)
 
 # Manual Testing
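In the model-card template above, doubled braces survive formatting as literal placeholders for the reader, while single braces are filled when the card is rendered, so `{safe_serialization}` now tracks the flag instead of a hard-coded `False`. A minimal sketch of the brace semantics, with illustrative values:

    template = (
        'save_to = f"{{USER_ID}}/{{MODEL_NAME}}-{quant}"\n'
        "quantized_model.push_to_hub(save_to, safe_serialization={safe_serialization})"
    )
    print(template.format(quant="AWQ-INT4", safe_serialization=True))
    # save_to = f"{USER_ID}/{MODEL_NAME}-AWQ-INT4"
    # quantized_model.push_to_hub(save_to, safe_serialization=True)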
@@ -719,11 +730,18 @@ def quantize_and_upload(
             int4_packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
         )
-        quant_config = AWQConfig(base_config, step="prepare")
-        quantize_(
-            model,
-            quant_config,
-        )
+
+        def filter_fn_skip_lmhead(module, fqn):
+            if fqn == "lm_head":
+                return False
+            return _is_linear(module, fqn)
+
+        awq_config = AWQConfig(base_config, step="prepare")
+        if safe_serialization:
+            quantize_(model, awq_config, filter_fn=filter_fn_skip_lmhead)
+        else:
+            quantize_(model, awq_config)
+
         TransformerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
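The new filter keeps `lm_head` out of quantization whenever safetensors serialization is enabled, presumably because a quantized (and often weight-tied) output head does not round-trip through safetensors cleanly; `_is_linear`, torchao's default module predicate, still admits every other linear. A self-contained sketch of the same pattern, with `Int8WeightOnlyConfig` standing in for the Int4 base config used in the hunk above and a toy model in place of the real one:

    import torch.nn as nn
    from torchao.quantization import Int8WeightOnlyConfig, quantize_
    from torchao.quantization.quant_api import _is_linear

    def filter_fn_skip_lmhead(module, fqn):
        if fqn == "lm_head":            # refuse the output projection by name
            return False
        return _is_linear(module, fqn)  # default linear check for the rest

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)
            self.lm_head = nn.Linear(16, 16)

    m = Toy()
    quantize_(m, Int8WeightOnlyConfig(), filter_fn=filter_fn_skip_lmhead)
    # m.proj now carries a quantized weight; m.lm_head is untouched.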
@@ -732,12 +750,33 @@ def quantize_and_upload(
             tasks=tasks,
             limit=calibration_limit,
         )
-        quant_config = AWQConfig(base_config, step="convert")
-        quantize_(model, quant_config)
+        awq_config = AWQConfig(base_config, step="convert")
+        if safe_serialization:
+            quantize_(model, awq_config, filter_fn=filter_fn_skip_lmhead)
+        else:
+            quantize_(model, awq_config)
 
         quantized_model = model
         quant_config = AWQConfig(base_config, step="prepare_for_loading")
-        quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+        if safe_serialization:
+            quantization_config = TorchAoConfig(quant_config).to_dict()
+            quantized_model.config.quantization_config = quantization_config
+
+            hf_quantizer, _, _, _ = get_hf_quantizer(
+                config=quantized_model.config,
+                quantization_config=None,
+                dtype=torch.bfloat16,
+                device_map="cuda:0",
+                weights_only=True,
+                user_agent={
+                    "file_type": "model",
+                    "framework": "pytorch",
+                    "from_auto_class": False,
+                },
+            )
+            quantized_model.hf_quantizer = hf_quantizer
+        else:
+            quantized_model.config.quantization_config = TorchAoConfig(quant_config)
     elif quant == "SmoothQuant-INT8-INT8":
         model = AutoModelForCausalLM.from_pretrained(
             model_to_quantize,
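On v5 the quantization config is stored as a plain dict and an `HfQuantizer` instance is attached to the model; that attached quantizer appears to be what lets `save_pretrained`/`push_to_hub` emit safetensors for torchao-quantized weights instead of requiring a pickled checkpoint. A load-side sketch for the resulting repo (the repo id is a placeholder, and the `dtype`/`device_map` kwargs mirror the `get_hf_quantizer` call above):

    import torch
    from transformers import AutoModelForCausalLM

    loaded = AutoModelForCausalLM.from_pretrained(
        "YOUR_USER_ID/MODEL_NAME-AWQ-INT4",  # hypothetical repo id
        dtype=torch.bfloat16,
        device_map="cuda:0",
    )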
@@ -804,6 +843,7 @@ def quantize_and_upload(
         model_type=quantized_model.config.model_type,
         quant=quant,
         quant_code=quant_to_quant_code[quant],
+        safe_serialization=safe_serialization,
         # server specific recipes
         server_inference_recipe=""
         if is_mobile
@@ -836,12 +876,16 @@ def quantize_and_upload(
 
     # Push to hub
     if push_to_hub:
-        quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
+        quantized_model.push_to_hub(
+            quantized_model_id, safe_serialization=safe_serialization
+        )
         tokenizer.push_to_hub(quantized_model_id)
         if populate_model_card_template:
             card.push_to_hub(quantized_model_id)
     else:
-        quantized_model.save_pretrained(quantized_model_id, safe_serialization=False)
+        quantized_model.save_pretrained(
+            quantized_model_id, safe_serialization=safe_serialization
+        )
         tokenizer.save_pretrained(quantized_model_id)
 
     # Manual Testing
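A quick way to verify which serialization path ran: `safe_serialization=True` writes `*.safetensors` shards, while `False` writes a pickled `pytorch_model.bin`. Sketch for the local `save_pretrained` branch (the directory name is a placeholder):

    import os

    out_dir = "MODEL_NAME-AWQ-INT4"  # hypothetical local output directory
    weights = [f for f in os.listdir(out_dir) if f.endswith((".safetensors", ".bin"))]
    print(weights)  # expect model*.safetensors when safe serialization ran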