-
Notifications
You must be signed in to change notification settings - Fork 16
Hf checkpoint conversion for distributed checkpoints #424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e6b7cab
9fa51ec
a73de85
d7d0956
8957f19
527a0d2
95cead4
b8cf4ea
652e77a
fca72dc
ace93c7
53eb907
3a4b46c
642466d
f54abc6
3fbe498
ee4e244
1b4cfe0
3a67ed9
ddbb8cc
5a36d48
bce2ae1
5da0e7f
f902152
d520095
42a7e42
03e07f5
8a9ff2f
9ae218d
36e2e25
cfbe7df
1db0a6c
993d4ff
2e3076d
bc1ca36
1b2aca5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,111 @@ | ||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.filesystem import FileSystemReader | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.state_dict_loader import _load_state_dict | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from modalities.config.config import ConfigDictType, load_app_config_dict, save_yaml_config_dict | ||||||||||||||||||||||||||||||||||||
| from modalities.utils.env import EnvOverride | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str: | ||||||||||||||||||||||||||||||||||||
| """Converts a DCP (Distributed Checkpoint) checkpoint—including | ||||||||||||||||||||||||||||||||||||
| FSDP2, PP, or TP checkpoints—to a standard PyTorch checkpoint. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files (may include FSDP2, PP, or TP). | ||||||||||||||||||||||||||||||||||||
| output_dir (str): Directory to save the converted PyTorch checkpoint. | ||||||||||||||||||||||||||||||||||||
| model_key (str): Key of the model configuration in the modalities config. | ||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| str: Path to the converted config file. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| os.makedirs(output_dir, exist_ok=True) | ||||||||||||||||||||||||||||||||||||
| torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin") | ||||||||||||||||||||||||||||||||||||
| torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| # TODO This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| # since we only want to convert the model state dict here. In future torch versions this function might | ||||||||||||||||||||||||||||||||||||
| # support converting only parts of the checkpoint. | ||||||||||||||||||||||||||||||||||||
| # (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save) | ||||||||||||||||||||||||||||||||||||
| sd: STATE_DICT_TYPE = {} | ||||||||||||||||||||||||||||||||||||
| planner = _EmptyStateDictLoadPlanner(keys=["app.model"], allow_partial_load=True) | ||||||||||||||||||||||||||||||||||||
| _load_state_dict(sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=planner, no_dist=True) | ||||||||||||||||||||||||||||||||||||
| torch.save(sd["app"]["model"], torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| return torch_config_file | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str: | ||||||||||||||||||||||||||||||||||||
| """Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading. | ||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files. | ||||||||||||||||||||||||||||||||||||
| output_dir (str): Directory to save the converted config file. | ||||||||||||||||||||||||||||||||||||
| model_key (str): Key of the model configuration in the modalities config. | ||||||||||||||||||||||||||||||||||||
| torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file. | ||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| str: Path to the converted config file. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| config_src, dcp_config = load_dcp_config(dcp_checkpoint_dir) | ||||||||||||||||||||||||||||||||||||
| config_dst: str = os.path.join(output_dir, os.path.basename(config_src)) | ||||||||||||||||||||||||||||||||||||
| if os.path.exists(config_dst): | ||||||||||||||||||||||||||||||||||||
| raise FileExistsError(f"Config file '{config_dst}' already exists.") | ||||||||||||||||||||||||||||||||||||
| torch_config: ConfigDictType = { | ||||||||||||||||||||||||||||||||||||
| "checkpointed_model": { | ||||||||||||||||||||||||||||||||||||
| "component_key": "model", | ||||||||||||||||||||||||||||||||||||
| "variant_key": "fsdp1_checkpointed", | ||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||
| "checkpoint_loading": { | ||||||||||||||||||||||||||||||||||||
| "component_key": "checkpoint_loading", | ||||||||||||||||||||||||||||||||||||
| "variant_key": "torch", | ||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||
| "device": "cpu", | ||||||||||||||||||||||||||||||||||||
| "precision": "FP32", | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| "model": { | ||||||||||||||||||||||||||||||||||||
| "instance_key": "model", | ||||||||||||||||||||||||||||||||||||
| "pass_type": "BY_REFERENCE", | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| "checkpoint_path": torch_checkpoint_file, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| if model_key not in dcp_config: | ||||||||||||||||||||||||||||||||||||
| raise KeyError( | ||||||||||||||||||||||||||||||||||||
| f"Model key '{model_key}' not found in config file '{config_src}'." | ||||||||||||||||||||||||||||||||||||
| f" Available keys: {list(dcp_config.keys())}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| torch_config["model"] = dcp_config[model_key] | ||||||||||||||||||||||||||||||||||||
BlueCrescent marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| torch_config["model"]["config"]["use_meta_device"] = False | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| torch_config["model"]["config"]["use_meta_device"] = False | |
| model_section = torch_config.get("model") | |
| if not isinstance(model_section, dict): | |
| raise TypeError( | |
| f"Expected 'model' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_section).__name__!r}." | |
| ) | |
| model_config = model_section.get("config") | |
| if not isinstance(model_config, dict): | |
| raise TypeError( | |
| f"Expected 'model.config' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_config).__name__!r}." | |
| ) | |
| model_config["use_meta_device"] = False |
Uh oh!
There was an error while loading. Please reload this page.