
Commit 48f1299

update quantize+upload scripts to use safetensors (#3139)
* update release scripts to use safetensors
* awq
* int8-int4
* add transformers version check
1 parent 01f0a97 commit 48f1299
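
In essence, the commit derives a safe_serialization flag from the installed transformers major version and threads it through every push_to_hub / save_pretrained call, which were previously hard-coded to safe_serialization=False. A minimal standalone sketch of that pattern (the model id and output directory below are illustrative, not taken from the release script):

import transformers
from transformers import AutoModelForCausalLM

# Mirror the gate introduced by this commit: safetensors only on transformers >= 5.
safe_serialization = str(transformers.__version__) >= "5"

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # hypothetical model id
model.save_pretrained("quantized-out", safe_serialization=safe_serialization)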

File tree: 1 file changed (+56, -12 lines)

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 56 additions & 12 deletions
@@ -8,9 +8,14 @@
 from typing import List
 
 import torch
+import transformers
 from huggingface_hub import ModelCard, get_token, whoami
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
+_transformers_version = str(transformers.__version__)
+if _transformers_version >= "5":
+    from transformers.quantizers.auto import get_hf_quantizer
+
 from torchao._models._eval import TransformerEvalWrapper
 from torchao.prototype.awq import (
     AWQConfig,
@@ -28,6 +33,9 @@
     PerRow,
     quantize_,
 )
+from torchao.quantization.quant_api import _is_linear
+
+safe_serialization = _transformers_version >= "5"
 
 
 def _get_username():
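
Note that the gate above is a plain string comparison; it distinguishes 4.x from 5.x correctly but would sort a hypothetical "10.x" before "5". If that ever matters, an equivalent check can be written numerically — a sketch under the assumption that packaging (already a transformers dependency) is importable:

import transformers
from packaging import version

# Same intent as `_transformers_version >= "5"`, but with numeric version semantics.
safe_serialization = version.parse(transformers.__version__).major >= 5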
@@ -43,7 +51,10 @@ def _untie_weights_and_save_locally(model_id):
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    from transformers.modeling_utils import find_tied_parameters
+    if _transformers_version >= "5":
+        from accelerate.utils.modeling import find_tied_parameters
+    else:
+        from transformers.modeling_utils import find_tied_parameters
 
     if getattr(
         untied_model.config.get_text_config(decoder=True), "tie_word_embeddings"
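
For context, find_tied_parameters (from either import path) reports groups of parameters that share storage, which is what the untie-and-save helper acts on. A small illustrative sketch on a toy module; the example return value shown is based on accelerate's documented behavior, not something asserted by this commit:

import torch.nn as nn

try:
    # transformers >= 5: the helper lives in accelerate, per the hunk above
    from accelerate.utils.modeling import find_tied_parameters
except ImportError:
    from transformers.modeling_utils import find_tied_parameters

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the two weights

toy = nn.ModuleDict({"embed": embed, "lm_head": lm_head})
print(find_tied_parameters(toy))  # e.g. [["embed.weight", "lm_head.weight"]]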
@@ -117,7 +128,7 @@ def _untie_weights_and_save_locally(model_id):
 USER_ID = "YOUR_USER_ID"
 MODEL_NAME = model_id.split("/")[-1]
 save_to = f"{{USER_ID}}/{{MODEL_NAME}}-{quant}"
-quantized_model.push_to_hub(save_to, safe_serialization=False)
+quantized_model.push_to_hub(save_to, safe_serialization={safe_serialization})
 tokenizer.push_to_hub(save_to)
 
 # Manual Testing
@@ -719,11 +730,18 @@ def quantize_and_upload(
             int4_packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
         )
-        quant_config = AWQConfig(base_config, step="prepare")
-        quantize_(
-            model,
-            quant_config,
-        )
+
+        def filter_fn_skip_lmhead(module, fqn):
+            if fqn == "lm_head":
+                return False
+            return _is_linear(module, fqn)
+
+        awq_config = AWQConfig(base_config, step="prepare")
+        if safe_serialization:
+            quantize_(model, awq_config, filter_fn=filter_fn_skip_lmhead)
+        else:
+            quantize_(model, awq_config)
+
         TransformerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
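
The filter_fn mechanics above are easiest to see on a toy model: quantize_ only converts modules for which the filter returns True. A hedged sketch using an int8 weight-only config purely for illustration (assumes a recent torchao that exposes Int8WeightOnlyConfig; the release script itself uses the AWQ flow shown in this hunk):

import torch.nn as nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.lm_head = nn.Linear(16, 16)

    def forward(self, x):
        return self.lm_head(self.proj(x))

def skip_lm_head(module, fqn):
    # same shape as filter_fn_skip_lmhead: quantize linears except "lm_head"
    return isinstance(module, nn.Linear) and fqn != "lm_head"

m = Toy()
quantize_(m, Int8WeightOnlyConfig(), filter_fn=skip_lm_head)
print(m.proj.weight)     # quantized: prints as a torchao tensor subclass
print(m.lm_head.weight)  # skipped by the filter: still a plain float Parameter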
@@ -732,12 +750,33 @@ def quantize_and_upload(
             tasks=tasks,
             limit=calibration_limit,
         )
-        quant_config = AWQConfig(base_config, step="convert")
-        quantize_(model, quant_config)
+        awq_config = AWQConfig(base_config, step="convert")
+        if safe_serialization:
+            quantize_(model, awq_config, filter_fn=filter_fn_skip_lmhead)
+        else:
+            quantize_(model, awq_config)
 
         quantized_model = model
         quant_config = AWQConfig(base_config, step="prepare_for_loading")
-        quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+        if safe_serialization:
+            quantization_config = TorchAoConfig(quant_config).to_dict()
+            quantized_model.config.quantization_config = quantization_config
+
+            hf_quantizer, _, _, _ = get_hf_quantizer(
+                config=quantized_model.config,
+                quantization_config=None,
+                dtype=torch.bfloat16,
+                device_map="cuda:0",
+                weights_only=True,
+                user_agent={
+                    "file_type": "model",
+                    "framework": "pytorch",
+                    "from_auto_class": False,
+                },
+            )
+            quantized_model.hf_quantizer = hf_quantizer
+        else:
+            quantized_model.config.quantization_config = TorchAoConfig(quant_config)
     elif quant == "SmoothQuant-INT8-INT8":
         model = AutoModelForCausalLM.from_pretrained(
             model_to_quantize,
@@ -804,6 +843,7 @@ def quantize_and_upload(
         model_type=quantized_model.config.model_type,
         quant=quant,
         quant_code=quant_to_quant_code[quant],
+        safe_serialization=safe_serialization,
         # server specific recipes
         server_inference_recipe=""
         if is_mobile
@@ -836,12 +876,16 @@ def quantize_and_upload(
 
     # Push to hub
     if push_to_hub:
-        quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
+        quantized_model.push_to_hub(
+            quantized_model_id, safe_serialization=safe_serialization
+        )
         tokenizer.push_to_hub(quantized_model_id)
         if populate_model_card_template:
            card.push_to_hub(quantized_model_id)
    else:
-        quantized_model.save_pretrained(quantized_model_id, safe_serialization=False)
+        quantized_model.save_pretrained(
+            quantized_model_id, safe_serialization=safe_serialization
+        )
        tokenizer.save_pretrained(quantized_model_id)
 
    # Manual Testing
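
A quick way to confirm the effect of the new flag is to inspect a locally saved checkpoint: with safe_serialization enabled, the output directory should contain .safetensors shards rather than pytorch_model.bin. A hedged sketch, where the directory name is illustrative only:

import os
from safetensors.torch import load_file

ckpt_dir = "quantized-out"  # hypothetical directory produced by save_pretrained
shards = [f for f in os.listdir(ckpt_dir) if f.endswith(".safetensors")]
assert shards, "expected at least one .safetensors shard when safe_serialization=True"
state = load_file(os.path.join(ckpt_dir, shards[0]))
print(f"{len(state)} tensors in {shards[0]}")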
