2121
2222
2323def string_to_config (s ):
24- if s is None :
24+ if s == " None" :
2525 return None
2626 elif s == "float8_rowwise" :
2727 return Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ())
@@ -41,7 +41,7 @@ def string_to_config(s):
4141 raise AssertionError (f"unsupported { s } " )
4242
4343
44- def quantize_model_and_save (model_id , quant_config , output_dir = "results" ):
44+ def quantize_model_and_save (model_id , quant_config , output_dir ):
4545 """Quantize the model and save it to the output directory."""
4646 print ("Quantizing model with config: " , quant_config )
4747 if quant_config is None :
@@ -60,27 +60,6 @@ def quantize_model_and_save(model_id, quant_config, output_dir="results"):
6060 return quantized_model , tokenizer
6161
6262
63- def run_lm_eval (model_dir , tasks_list = ["hellaswag" ], device = "cuda:0" , batch_size = 8 ):
64- """Run the lm_eval command using subprocess."""
65- tasks_str = "," .join (tasks_list )
66- command = [
67- "lm_eval" ,
68- "--model" ,
69- "hf" ,
70- "--model_args" ,
71- f"pretrained={ model_dir } " ,
72- "--tasks" ,
73- f"{ tasks_str } " ,
74- "--device" ,
75- f"{ device } " ,
76- "--batch_size" ,
77- f"{ batch_size } " ,
78- "--output_path" ,
79- f"{ model_dir } /lm_eval_outputs/" ,
80- ]
81- subprocess .run (command , check = True )
82-
83-
8463def get_size_of_dir (model_output_dir ):
8564 # get dir size from shell, to skip complexity of dealing with tensor
8665 # subclasses
@@ -94,43 +73,23 @@ def get_size_of_dir(model_output_dir):
9473def run (
9574 model_id : str ,
9675 quant_recipe_name : str | None ,
97- tasks ,
98- device ,
99- batch_size ,
10076 model_output_dir ,
10177):
10278 print (f"\n Running { model_id = } with { quant_recipe_name = } \n " )
103- model_name = model_id .split ("/" )[- 1 ]
104- model_output_dir = (
105- f"benchmarks/data/quantized_model/{ model_name } -{ quant_recipe_name } "
106- )
10779 quant_config = string_to_config (quant_recipe_name )
10880 quantized_model , tokenizer = quantize_model_and_save (
10981 model_id , quant_config = quant_config , output_dir = model_output_dir
11082 )
11183 print (quantized_model )
112-
84+ print ( f"saved { model_id = } , { quant_recipe_name = } to { model_output_dir = } " )
11385 model_size = get_size_of_dir (model_output_dir ) / 1e9
11486 print (f"checkpoint size: { model_size } GB" )
11587
116- run_lm_eval (
117- model_output_dir , tasks_list = tasks , device = device , batch_size = batch_size
118- )
119- print ("done\n " )
120-
12188
12289if __name__ == "__main__" :
123- try :
124- import lm_eval # noqa: F401
125- except :
126- print (
127- "lm_eval is required to run this script. Please install it using pip install lm-eval."
128- )
129- exit (0 )
130-
13190 # Set up argument parser
13291 parser = argparse .ArgumentParser (
133- description = "Quantize a model and evaluate its throughput ."
92+ description = "Load a model from HuggingFace, quantize it and save it to disk ."
13493 )
13594 parser .add_argument (
13695 "--model_id" ,
@@ -141,26 +100,12 @@ def run(
141100 parser .add_argument (
142101 "--quant_recipe_name" ,
143102 type = str ,
144- default = None ,
145- help = "The quantization recipe to use." ,
146- )
147- parser .add_argument (
148- "--tasks" ,
149- nargs = "+" ,
150- type = str ,
151- default = ["wikitext" ],
152- help = "List of lm-eluther tasks to evaluate usage: --tasks task1 task2" ,
153- )
154- parser .add_argument (
155- "--device" , type = str , default = "cuda:0" , help = "Device to run the model on."
156- )
157- parser .add_argument (
158- "--batch_size" , type = str , default = "auto" , help = "Batch size for lm_eval."
103+ help = "The quantization recipe to use, 'None' means no quantization" ,
159104 )
160105 parser .add_argument (
161106 "--output_dir" ,
162107 type = str ,
163- default = "quantized_models " ,
108+ default = "benchmarks/data/quantized_model/test " ,
164109 help = "Output directory for quantized model." ,
165110 )
166111 args = parser .parse_args ()
@@ -169,8 +114,5 @@ def run(
169114 run (
170115 model_id = args .model_id ,
171116 quant_recipe_name = args .quant_recipe_name ,
172- tasks = args .tasks ,
173- device = args .device ,
174- batch_size = args .batch_size ,
175117 model_output_dir = args .output_dir ,
176118 )
0 commit comments