diff --git a/.gitignore b/.gitignore index 34de4e774..8b13b59e5 100644 --- a/.gitignore +++ b/.gitignore @@ -139,7 +139,7 @@ data/**/*.txt data/**/*.gz data/**/*.np* data/**/*.npy -checkpoints/ +./checkpoints/ .vscode/ *.pt *.ckpt diff --git a/README.md b/README.md index c96692c60..737bbec66 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ Or use the 20B tokenizer (for which only a single Vocab file is needed): (alternatively, you can provide any tokenizer file that can be loaded by Hugging Face's tokenizers library with the `Tokenizer.from_pretrained()` command) -You can now pretokenize your data using `tools/preprocess_data.py`, the arguments for which are detailed below: +You can now pretokenize your data using `tools/datasets/preprocess_data.py`, the arguments for which are detailed below: ``` usage: preprocess_data.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--num-docs NUM_DOCS] --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} [--vocab-file VOCAB_FILE] [--merge-file MERGE_FILE] [--append-eod] [--ftfy] --output-prefix OUTPUT_PREFIX @@ -206,7 +206,7 @@ runtime: For example: ```bash -python tools/preprocess_data.py \ +python tools/datasets/preprocess_data.py \ --input ./data/mydataset.jsonl.zst \ --output-prefix ./data/mydataset \ --vocab ./data/gpt2-vocab.json \ @@ -322,7 +322,7 @@ python ./tools/convert_sequential_to_hf.py --input_dir /path/to/model/global_st Then to upload a model to [the Hugging Face Hub](https://huggingface.co/), run: ```bash huggingface-cli login -python ./tools/upload.py +python ./tools/checkpoints/upload.py ``` and input the requested information, including HF hub user token. diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index c50e7ff01..9c6c9d087 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = d3e481c + Default = 70284f1 current git hash of repository @@ -1906,6 +1906,14 @@ Args for deepspeed config +- **load_universal**: bool + + Default = False + + Flag for whether the checkpoint to be loaded is a universal checkpoint. + + + ## NeoXArgsDeepspeedRunner Args for deepspeed runner (deepspeed.launcher.runner). 
diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 141a6e4e9..96915584e 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -40,10 +40,28 @@ def get_params_for_weight_decay_optimization(module, neox_args): ) or ( neox_args.weight_decay == 0.0 ): # also include all parameters here if no weight decay is being done - no_weight_decay_params["params"].extend( - [p for p in list(module_._parameters.values()) if p is not None] - ) + # no_weight_decay_params["params"].extend( + # [p for p in list(module_._parameters.values()) if p is not None] + # ) + params = [] + for n, p in module_._parameters.items(): + if p is not None: + p.module_name = f"{module_._get_name()}.{n}" + params.append(p) + no_weight_decay_params["params"].extend(params) else: + wd_params = [] + nwd_params = [] + for n, p in module_._parameters.items(): + if p is not None: + p.module_name = f"{module_._get_name()}.{n}" + if n != "bias": + wd_params.append(p) + else: + nwd_params.append(p) + weight_decay_params["params"].extend(wd_params) + no_weight_decay_params["params"].extend(nwd_params) + """ weight_decay_params["params"].extend( [ p @@ -58,6 +76,8 @@ def get_params_for_weight_decay_optimization(module, neox_args): if p is not None and n == "bias" ] ) + """ + if neox_args.weight_decay == 0.0: # only return a single param group # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once. diff --git a/megatron/neox_arguments/deepspeed_args.py b/megatron/neox_arguments/deepspeed_args.py index 15b35e411..4871648c1 100644 --- a/megatron/neox_arguments/deepspeed_args.py +++ b/megatron/neox_arguments/deepspeed_args.py @@ -277,6 +277,9 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): autotuning: dict = None """Dictionary as described in DeepSpeed autotuning documentation: https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning""" + load_universal: bool = False + """Flag for whether the checkpoint to be loaded is a universal checkpoint.""" + @dataclass class NeoXArgsDeepspeedRunner(NeoXArgsTemplate): diff --git a/megatron/training.py b/megatron/training.py index 96a94a1d0..1321197d8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -610,6 +610,53 @@ def get_learning_rate_scheduler(optimizer, neox_args): return lr_scheduler +from collections import OrderedDict +import json + + +def log_bit16_groups(optimizer, param_names, zero_stage): + + """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the + optimizer. the names are exactly as in state_dict. The order is absolutely important, since + the saved data is just flattened data with no identifiers and requires reconstruction in the + same order it was saved. + We can't rely on self.module.named_parameters() to get the saved tensors, as some params + will be missing and others unsaved and then it'd be impossible to reconstruct state_dict + from the flattened weights. + optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions. 
+ """ + param_group_shapes = [] + cnt = 0 + numel = 0 + + # zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups - + # if we don't use it, we get parameters ordered incorrectly + if hasattr(optimizer, "round_robin_bit16_groups"): + bit16_groups = optimizer.round_robin_bit16_groups + else: + bit16_groups = ( + optimizer.bit16_groups if zero_stage == 2 else optimizer.fp16_groups + ) + + for bit16_group in bit16_groups: + param_shapes = OrderedDict() + for param in bit16_group: + cnt += 1 + numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() + shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape + if param not in param_names: + raise ValueError(f"failed to find optimizer param in named params") + name = param_names[param] + param_shapes[name] = shape + + # uncomment to debug zero_to_fp32.py problems + # if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})") + param_group_shapes.append(param_shapes) + # if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params") + + return param_group_shapes + + def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): """Setup model and optimizer.""" model = get_model(neox_args=neox_args, use_cache=use_cache) @@ -637,6 +684,11 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + zero_stage = neox_args.zero_optimization["stage"] + # bit16_groups = log_bit16_groups(optimizer, model.param_names, zero_stage) + bit16_groups = model._get_zero_param_shapes() + with open(f"zero{zero_stage}.json", mode="w") as jfile: + json.dump(bit16_groups, jfile) model.total_params = get_total_params(model.module) print_rank_0(f' > total params: {"{:,}".format(model.total_params)}') diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 000000000..aabea6c5b --- /dev/null +++ b/tools/README.md @@ -0,0 +1,15 @@ +# GPT-NeoX Auxiliary Tools + +This directory contains a number of auxiliary tools that are useful for working with GPT-NeoX but not part of the main training code. + +## Bash + +This directory contains some simple, frequently used bash commands to make working on multiple machines easier. + +## Checkpoints + +This directory contains tools for manipulating and converting checkpoints, including changing the parallelism settings of a pretrained model, converting between GPT-NeoX and the transformers library, and updating checkpoints trained with Version 1.x of this library to be compatible with Version 2.x. + +## Datasets + +This directory contains tools for downloading and preprocessing datasets to the format expected by the GPT-NeoX library.
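Earlier in this diff, the `megatron/training.py` hunk dumps the optimizer's per-group parameter shapes to `zero{stage}.json`. As a rough illustration only (not part of the patch), such a dump could be inspected with a few lines of Python. This sketch assumes the file holds a list of per-param-group `{name: shape}` mappings, matching what `log_bit16_groups` returns; whether `model._get_zero_param_shapes()` serializes to the same layout is an assumption.

```python
# Sketch: summarize a zero{stage}.json dump written by setup_model_and_optimizer.
# Assumes a list of {parameter_name: [dim, ...]} mappings, one per optimizer param group.
import json
from math import prod


def summarize_param_shapes(path="zero1.json"):  # file name depends on the ZeRO stage
    with open(path) as f:
        groups = json.load(f)
    total = 0
    for i, group in enumerate(groups):
        # each shape is serialized as a list of ints; prod([]) == 1 covers scalars
        group_numel = sum(prod(shape) for shape in group.values())
        total += group_numel
        print(f"param group {i}: {len(group)} tensors, {group_numel:,} elements")
    print(f"total: {total:,} elements in {len(groups)} param groups")


if __name__ == "__main__":
    summarize_param_shapes()
```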
diff --git a/tools/kill.sh b/tools/bash/kill.sh old mode 100755 new mode 100644 similarity index 100% rename from tools/kill.sh rename to tools/bash/kill.sh diff --git a/tools/killall.sh b/tools/bash/killall.sh old mode 100755 new mode 100644 similarity index 100% rename from tools/killall.sh rename to tools/bash/killall.sh diff --git a/tools/sync.sh b/tools/bash/sync.sh old mode 100755 new mode 100644 similarity index 100% rename from tools/sync.sh rename to tools/bash/sync.sh diff --git a/tools/sync_cmd.sh b/tools/bash/sync_cmd.sh similarity index 100% rename from tools/sync_cmd.sh rename to tools/bash/sync_cmd.sh diff --git a/tools/syncdir.sh b/tools/bash/syncdir.sh old mode 100755 new mode 100644 similarity index 100% rename from tools/syncdir.sh rename to tools/bash/syncdir.sh diff --git a/tools/checkpoints/README.md b/tools/checkpoints/README.md new file mode 100644 index 000000000..093fa2a35 --- /dev/null +++ b/tools/checkpoints/README.md @@ -0,0 +1,50 @@ +# GPT-NeoX Checkpoint Manipulation Tools + +## Checkpoint Conversion + +The format in which DeepSpeed checkpoints are saved depends on the model and pipeline parallelism settings of the training run, which makes it difficult to run a model on a cluster with a different number or type of GPUs. We have adapted a set of scripts developed by [BigScience](https://github.com/bigscience-workshop/Megatron-DeepSpeed/tree/main/tools/convert_checkpoint) to make this easier. + +### DeeperSpeed to universal + +To convert your checkpoint to the universal checkpoint format, run the `ds_to_universal.py` script with a command along these lines: + +```bash +CURR_CKPT="/path/to/your/old/checkpoint" +NEW_CKPT="/path/where/you/want/the/new/checkpoint" +CFG="/path/to/model/config/file" + +python3 tools/checkpoints/ds_to_universal.py \ + --input_folder $CURR_CKPT \ + --output_folder $NEW_CKPT \ + --config $CFG +``` + +To then run the model from your new checkpoint, add these lines to a new config and run your model as you normally would. + +```json +{ + "load": "/path/where/you/want/the/new/checkpoint", + "load_universal": true +} +``` + +### DeeperSpeed to DeeperSpeed Reshaping + +To reshape a DeeperSpeed checkpoint to _reduce_ its parallelism settings, you can use the `deepspeed_to_deepspeed.py` script. It cannot re-shard a model to increase the amount of tensor or pipeline parallelism, but to decrease the amount of parallelism you can run the script with a command like the one below.
+ + ```bash + CURR_CKPT="/path/to/your/old/checkpoint" + NEW_CKPT="/path/where/you/want/the/new/checkpoint" + CFG="/path/to/model/config/file" + TP=1 # Tensor (model) parallelism setting for the new checkpoint, must be less than or equal to the model's original tensor parallelism + DP=1 # Data parallelism setting for the new checkpoint + PP=1 # Pipeline parallelism setting for the new checkpoint, must be less than or equal to the model's original pipeline parallelism + + python3 tools/checkpoints/deepspeed_to_deepspeed.py \ + --input_folder $CURR_CKPT \ + --output_folder $NEW_CKPT \ + --config $CFG \ + --target_tp $TP \ + --target_dp $DP \ + --target_pp $PP +``` diff --git a/tools/convert_sequential_to_hf.py b/tools/checkpoints/convert_sequential_to_hf.py similarity index 100% rename from tools/convert_sequential_to_hf.py rename to tools/checkpoints/convert_sequential_to_hf.py diff --git a/tools/checkpoints/convert_v1.0_to_hf.py b/tools/checkpoints/convert_v1.0_to_hf.py new file mode 100644 index 000000000..905bdfa16 --- /dev/null +++ b/tools/checkpoints/convert_v1.0_to_hf.py @@ -0,0 +1,334 @@ +# Copyright (c) 2023, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +import yaml +import argparse +from tqdm import tqdm +from typing import List + +import torch +from transformers import GPTNeoXConfig, GPTNeoXForCausalLM + + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) +from megatron.tokenizer import build_tokenizer + + +""" +A script for converting saved NeoX Checkpoints to Huggingface (HF) compatible GPT-NeoX type models. + +Note that this script does not support all NeoX features. +Please investigate carefully whether your model is compatible with all architectures supported by the GPTNeoXForCausalLM class in HF. + +(e.g. position embeddings such as ALiBi may not be supported by Huggingface's GPT-NeoX architecture). +""" + + +def load_partitions( + input_checkpoint_path, mp_partitions, layer_idx +) -> List[torch.Tensor]: + """Returns a list containing all weights in a given layer from a model (across MP partitions)""" + + loaded_tp_ranks = [ + torch.load( + os.path.join( + input_checkpoint_path, + f"layer_{layer_idx:02}-model_{i:02}-model_states.pt", + ), + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + for i in range(mp_partitions) + ] + + return loaded_tp_ranks + + +def get_key(loaded_config, key, default=None): + """ + Search for a given key in a NeoX yaml. Normalizes underscores -> hyphens + """ + key = key.replace("_", "-") + try: + return loaded_config[key] + except KeyError: + key = key.replace("-", "_") + try: + return loaded_config[key] + except KeyError: + return default + + +def create_config(neox_config): + """Take in a loaded yaml from NeoX and assign relevant values to HF config. + Returns: GPTNeoXConfig() object + """ + + class TokenizerArgs: + # kinda hacky.
+ # this is to get something with the same interface as is used in build_tokenizer() + # without diving into loading a neox_args object or using argparse etc. + def __init__(self, neox_config): + self.make_vocab_size_divisible_by = get_key( + neox_config, "make-vocab-size-divisible-by", default=128 + ) + self.model_parallel_size = get_key(neox_config, "model-parallel-size") + self.vocab_file = get_key(neox_config, "vocab-file") + self.merge_file = get_key(neox_config, "merge-file") + self.tokenizer_type = get_key(neox_config, "tokenizer-type") + + self.rank = 0 + + args = TokenizerArgs(neox_config) + tokenizer = build_tokenizer(args) + try: # GPT2TokenizerFast raises NotImplementedError + pad_token = tokenizer.pad + except: + pad_token = ( + 1 # pad defaulting to 1. follows convention from GPT-NeoX-20b tokenizer + ) + + # TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default + use_tied_lns = get_key(neox_config, "gpt-j-tied", False) + + if use_tied_lns: + raise NotImplementedError( + """ERROR: Huggingface Transformers does not yet support a single shared layernorm + per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals. + See https://github.com/EleutherAI/gpt-neox/pull/481 for further details.""" + ) + + # set all config values. + hf_config = GPTNeoXConfig( + vocab_size=args.padded_vocab_size, + hidden_size=get_key(neox_config, "hidden-size"), + num_hidden_layers=get_key(neox_config, "num-layers"), + num_attention_heads=get_key(neox_config, "num-attention-heads"), + intermediate_size=(get_key(neox_config, "hidden-size") * 4), + hidden_act=get_key(neox_config, "activation", default="gelu"), + rotary_pct=get_key(neox_config, "rotary-pct", default=1.0), + rotary_emb_base=get_key(neox_config, "rotary-emb-base", default=10000), + max_position_embeddings=get_key(neox_config, "max-position-embeddings"), + initializer_range=get_key(neox_config, "init-method-std", 0.02), + layer_norm_eps=get_key(neox_config, "layernorm-epsilon", 1e-5), + use_cache=True, + bos_token_id=tokenizer.eod, + eos_token_id=tokenizer.eod, + tie_word_embeddings=(not get_key(neox_config, "no-weight-tying", False)), + use_parallel_residual=get_key(neox_config, "gpt-j-residual", False), + ) + return hf_config + + +def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): + """convert a NeoX checkpoint to a HF model format. + should perform model-parallel merging correctly + but only supports features allowed by HF GPT-NeoX implementation (e.g. rotary embeddings) + """ + + hf_config = GPTNeoXConfig() + + hf_config = create_config(loaded_config) + + hf_model = GPTNeoXForCausalLM(hf_config) + + # save model in fp16/bf16 if Deepspeed fp16 or bf16 mixed precision was used in config, else 32 bit weights + fp16 = get_key(loaded_config, "fp16") + if fp16: + try: + # this conditional is quite messy because there were a number of ways to specify bf16 or fp16 training + # in DeeperSpeed v1.0 . 
+ if (fp16.get("fp16", None) or fp16["enabled"]) and not (fp16.get("type", None) == "bfloat16"): + hf_model.half() + print("Saving weights in fp16 precision...") + elif fp16.get("type", None) == "bfloat16": + hf_model.to(dtype=torch.bfloat16) + print("Saving weights in bf16 precision...") + except: + print("Model not trained in fp16 / bf16 mixed precision, saving weights in fp32...") + + mp_partitions = get_key(loaded_config, "model-parallel-size") + + ### Embedding layer ### + loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, 0) + hf_model.gpt_neox.embed_in.load_state_dict( + { + "weight": torch.cat( + [t["word_embeddings.weight"] for t in loaded_tp_ranks], dim=0 + ) + } + ) + + assert ( + hf_config.vocab_size == hf_model.gpt_neox.embed_in.weight.shape[0] + ), f"ERROR: calculated vocab size {hf_config.vocab_size} != embed param size {hf_model.gpt_neox.embed_in.shape[0]}" + ### End Embedding Layer ### + + for layer_i in tqdm(range(get_key(loaded_config, "num-layers"))): + + # get layer from hf model + hf_layer = hf_model.gpt_neox.layers[layer_i] + + # + 2 bc of embed layer and a dummy _pre_transformer_block + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, layer_i + 2 + ) + + state_dict = {} + for key in [ + "attention.dense.weight", + "mlp.dense_4h_to_h.weight", + ]: + state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=1) + + # average layernorm stats over mp ranks + for key in [ + "input_layernorm.weight", + "input_layernorm.bias", + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + ]: + state_dict[key] = (sum([t[key] for t in loaded_tp_ranks])) / len( + loaded_tp_ranks + ) + + # LinearWithTPMerge + for key in [ + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "attention.query_key_value.weight", + "attention.query_key_value.bias", + ]: + state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=0) + + # LinearWithTPSplitBias + for key in [ + "mlp.dense_4h_to_h.bias", + "attention.dense.bias", + ]: + state_dict[key] = sum([t[key] for t in loaded_tp_ranks]) + + # Just take one + state_dict["attention.rotary_emb.inv_freq"] = loaded_tp_ranks[0][ + "attention.rotary_emb.inv_freq" + ] + state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"] + state_dict["attention.masked_bias"] = hf_layer.state_dict()[ + "attention.masked_bias" + ] + + # load state_dict into layer + hf_layer.load_state_dict(state_dict) + + # Load final layer norm + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3 + ) + + hf_model.gpt_neox.final_layer_norm.load_state_dict( + { + "weight": (sum([t["norm.weight"] for t in loaded_tp_ranks])) + / len(loaded_tp_ranks), + "bias": (sum([t["norm.bias"] for t in loaded_tp_ranks])) + / len(loaded_tp_ranks), + } + ) + del loaded_tp_ranks + + # Load output embedding + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4 + ) + + hf_model.embed_out.load_state_dict( + { + "weight": torch.cat( + [t["final_linear.weight"] for t in loaded_tp_ranks], dim=0 + ), + } + ) + + del loaded_tp_ranks + + return hf_model + + +if __name__ == "__main__": + + # before running script: + # `pip install --upgrade transformers` + # `huggingface-cli login` + # + from huggingface_hub import create_repo, HfApi + + parser = argparse.ArgumentParser( + description="Merge MP partitions and convert to HF Model." 
+ ) + parser.add_argument( + "--input_dir", + type=str, + help="Path to NeoX checkpoint, e.g. /path/to/model/global_step143000", + ) + parser.add_argument( + "--config_file", + type=str, + help="Path to config file for the input NeoX checkpoint.", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Output dir, where to save the HF Model, tokenizer, and configs", + ) + parser.add_argument( + "--upload", + action="store_true", + help="Set to true in order to upload to the HF Hub directly.", + ) + args = parser.parse_args() + + with open(args.config_file) as f: + loaded_config = yaml.full_load(f) + + hf_model = convert(args.input_dir, loaded_config, args.output_dir) + + hf_model.save_pretrained(args.output_dir) + + # save tokenizer to directory as well, for easy loading of model as a HF model + tokenizer_type = get_key(loaded_config, "tokenizer-type") + + if tokenizer_type == "HFTokenizer": + print(f"saving tokenizer from file {get_key(loaded_config, 'vocab-file')}") + from transformers import PreTrainedTokenizerFast + + tokenizer = PreTrainedTokenizerFast( + tokenizer_file=get_key(loaded_config, "vocab-file") + ) + print("loaded tokenizer: ", tokenizer) + tokenizer.save_pretrained(args.output_dir) + print("tokenizer saved!") + + if args.upload: + repo_name = input("Provide a repository name for the HF Hub: ") + create_repo(repo_name, repo_type="model", private=False, use_auth_token=True) + + api = HfApi() + api.upload_folder( + folder_path=args.output_dir, + repo_id=repo_name, + repo_type="model", + ) diff --git a/tools/checkpoints/deepspeed_to_deepspeed.py b/tools/checkpoints/deepspeed_to_deepspeed.py new file mode 100644 index 000000000..2eb0adc7a --- /dev/null +++ b/tools/checkpoints/deepspeed_to_deepspeed.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +import argparse +import os +from pathlib import Path +import sys +import torch +import yaml + + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + +from megatron.neox_arguments import NeoXArgs +from deepspeed.checkpoint.deepspeed_checkpoint import ( + ARGS_KEY, + CHECKPOINT_INFO_KEY, +) + +from deepspeed.checkpoint import ( + NeoxCheckpoint, + get_model_ckpt_name_for_rank, + get_zero_ckpt_name_for_rank, + get_layer_ckpt_name_for_rank, +) + +CHECKPOINT_FILE_SUFFIX = "_model_states.pt" +MP_WORLD_SIZE = "mp_world_size" +WORD_EMBEDDINGS_KEY = "word_embeddings.weight" +FINAL_LINEAR_KEY = "final_linear.weight" +ORIGINAL_VOCAB_SIZE = "original_vocab_size" +PADDED_VOCAB_SIZE = "padded_vocab_size" + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_folder", + default=None, + type=str, + help="Input DeepSpeed Checkpoint folder", + ) + parser.add_argument( + "--output_folder", + default=None, + type=str, + help="Output Megatron checkpoint folder", + ) + parser.add_argument("--config", type=str, help="Path to yml config") + parser.add_argument("--target_tp", default=None, type=int, help="Target TP degree") + parser.add_argument("--target_pp", default=None, type=int, help="Target PP degree") + parser.add_argument("--target_dp", default=None, type=int, help="Target DP degree") + parser.add_argument( + "--iteration", + default=None, + type=int, + help="Which checkpoint to load, defaults to what is in latest if None", + ) + + args = parser.parse_args() + print(f"args = {args}") + return args + + +def _vocab_size_with_padding(orig_vocab_size, divisible_by, tp_size): + """Pad 
vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + multiple = divisible_by * tp_size + while (after % multiple) != 0: + after += 1 + + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) + return after + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + +def _create_transformer_layer_checkpoint( + ds_checkpoint, base_folder, tp_index, pp_index +): + sd_list = ds_checkpoint.get_transformer_state(tp_index, pp_index) + layer_id_list = ds_checkpoint.get_pp_transformer_map(pp_index) + assert len(sd_list) == len(layer_id_list) + for sd, layer_id in zip(sd_list, layer_id_list): + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, layer_id=layer_id, tp_rank=tp_index + ) + _save_checkpoint(ckpt_path, sd) + + +def _strip_vocab_padding(ds_checkpoint, padded_vocab_tensor, neox_args): + target_args = ds_checkpoint.get_args() + target_args["model_parallel_size"] = ds_checkpoint.tp_degree + + padded_vocab_size = _vocab_size_with_padding( + neox_args.tokenizer.vocab_size, + target_args["make_vocab_size_divisible_by"], + ds_checkpoint.tp_degree, + ) + padded_layer_size = padded_vocab_size // ds_checkpoint.tp_degree + assert padded_vocab_size <= padded_vocab_tensor.numel() + target_args[PADDED_VOCAB_SIZE] = padded_vocab_size + unpadded_vocab_tensor = torch.narrow(padded_vocab_tensor, 0, 0, padded_layer_size) + return unpadded_vocab_tensor.clone() + + +def _create_embedding_layer_checkpoint(ds_checkpoint, base_folder, tp_index, args): + sd = ds_checkpoint.get_embedding_state(tp_index) + if ds_checkpoint.is_change_tp_degree(): + print(f"TP index: {tp_index}, embeddings shape {sd[WORD_EMBEDDINGS_KEY].shape}") + sd[WORD_EMBEDDINGS_KEY] = _strip_vocab_padding( + ds_checkpoint, sd[WORD_EMBEDDINGS_KEY], args + ) + layer_id = ds_checkpoint.get_embedding_layer_id() + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, tp_rank=tp_index, layer_id=layer_id + ) + _save_checkpoint(ckpt_path, sd) + + +def _create_final_norm_layer_checkpoint(ds_checkpoint, base_folder, tp_index, args): + sd = ds_checkpoint.get_final_norm_state(tp_index) + layer_id = ds_checkpoint.get_final_norm_layer_id() + if ds_checkpoint.is_change_tp_degree(): + sd[FINAL_LINEAR_KEY] = _strip_vocab_padding( + ds_checkpoint, sd[FINAL_LINEAR_KEY], args + ) + ckpt_path = get_layer_ckpt_name_for_rank( + base_folder=base_folder, tp_rank=tp_index, layer_id=layer_id + ) + _save_checkpoint(ckpt_path, sd) + + +def _create_2d_parallel_checkpoint(ds_checkpoint, base_folder, tp_index, pp_index): + sd = ds_checkpoint.get_2d_parallel_state(tp_index=tp_index, pp_index=pp_index) + ckpt_info = ds_checkpoint.get_checkpoint_info() + sd[MP_WORLD_SIZE] = ds_checkpoint.tp_degree + file_id = pp_index * ds_checkpoint.tp_degree + tp_index + ckpt_path = get_model_ckpt_name_for_rank(base_folder, f"{file_id:02d}") + + # Adjust specific fields + sd[ARGS_KEY] = ds_checkpoint.get_args() + # sd[ARGS_KEY][PADDED_VOCAB_SIZE] = ckpt_info[PADDED_VOCAB_SIZE] + sd[ARGS_KEY]["model_parallel_size"] = ds_checkpoint.tp_degree + sd[ARGS_KEY]["pipe_parallel_size"] = ds_checkpoint.pp_degree + if CHECKPOINT_INFO_KEY not in sd: + sd[CHECKPOINT_INFO_KEY] = {} + sd[CHECKPOINT_INFO_KEY][PADDED_VOCAB_SIZE] = sd[ARGS_KEY][PADDED_VOCAB_SIZE] + _save_checkpoint(ckpt_path, sd) + + +def 
_create_zero_checkpoint(ds_checkpoint, base_folder, dp_index, pp_index, tp_index): + _2d_rank = (pp_index * ds_checkpoint.tp_degree) + tp_index + sd = ds_checkpoint.get_zero_checkpoint_state( + pp_index=pp_index, tp_index=tp_index, dp_index=dp_index + ) + ckpt_path = get_zero_ckpt_name_for_rank( + base_folder=base_folder, dp_rank=dp_index, mp_rank=_2d_rank + ) + _save_checkpoint(ckpt_path, sd) + + +def _create_latest_file(base_folder, file_name, latest_tag): + file_path = os.path.join(base_folder, file_name) + os.makedirs(base_folder, exist_ok=True) + with open(file_path, "w") as f: + f.write(str(latest_tag)) + + +def get_folder(args): + folder = Path(args.input_folder) + if args.iteration is None: + with open(folder / "latest") as latest_file: + tag = latest_file.read().strip() + else: + tag = f"global_step{args.iteration}" + return folder / tag + + +def main(): + print(f"Convert DeepSpeed Checkpoint to DeepSpeed Checkpoint") + + args = parse_arguments() + print( + f"Converting DeepSpeed checkpoint in {args.input_folder} to DeepSpeed checkpoint in {args.output_folder}" + ) + + neox_args = NeoXArgs.from_ymls([args.config]) + neox_args.build_tokenizer() + + ckpt_folder = get_folder(args) + + ds_checkpoint = NeoxCheckpoint( + ckpt_folder, args.target_tp, args.target_pp, args.target_dp + ) + iteration = ds_checkpoint.get_iteration() + latest_tag = f"global_step{iteration}" + _create_latest_file( + args.output_folder, "latest_checkpointed_iteration.txt", iteration + ) + _create_latest_file(args.output_folder, "latest", latest_tag) + base_folder = os.path.join(args.output_folder, latest_tag) + + for i in range(ds_checkpoint.tp_degree): + _create_embedding_layer_checkpoint(ds_checkpoint, base_folder, i, neox_args) + _create_final_norm_layer_checkpoint(ds_checkpoint, base_folder, i, neox_args) + + for j in range(ds_checkpoint.pp_degree): + _create_transformer_layer_checkpoint(ds_checkpoint, base_folder, i, j) + _create_2d_parallel_checkpoint(ds_checkpoint, base_folder, i, j) + + for i in range(ds_checkpoint.dp_degree): + for j in range(ds_checkpoint.pp_degree): + for k in range(ds_checkpoint.tp_degree): + _create_zero_checkpoint(ds_checkpoint, base_folder, i, j, k) + + +if __name__ == "__main__": + main() diff --git a/tools/checkpoints/ds_to_universal.py b/tools/checkpoints/ds_to_universal.py new file mode 100644 index 000000000..de1af0c33 --- /dev/null +++ b/tools/checkpoints/ds_to_universal.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python + +from collections import OrderedDict +from copy import deepcopy +from email.policy import default +from functools import partial +from pathlib import Path +from pprint import pprint +import argparse +import glob +import itertools +import logging +import multiprocessing +import os +import re +import shutil +import sys +import torch +import tqdm + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + +from megatron.neox_arguments import NeoXArgs +from deepspeed.checkpoint import NeoxCheckpoint + +MODEL_KEY = "model" +ARGS_KEY = "args" +LANGUGAGE_MODEL_KEY = "language_model" +EMBEDDING_KEY = "embedding" +ENCODER_KEY = "encoder" +WORD_EMBEDDINGS_FOR_HEAD_KEY = "word_embeddings_for_head" +WORD_EMBEDDINGS_KEY = "word_embeddings" +FINAL_LAYER_NORM_KEY = "final_layernorm" +CHECKPOINT_VERSION_KEY = "checkpoint_version" +CHECKPOINT_VERSION_VALUE = 3.0 +ITERATION_KEY = "iteration" + + +def parse_arguments(): + parser = argparse.ArgumentParser() + 
parser.add_argument( + "--input_folder", type=str, help="Input DeepSpeed Checkpoint folder" + ) + parser.add_argument( + "--output_folder", type=str, help="Output Megatron checkpoint folder" + ) + parser.add_argument("--config", type=str) + parser.add_argument("--target_tp", default=None, type=int, help="Target TP degree") + parser.add_argument("--target_pp", default=None, type=int, help="Target PP degree") + parser.add_argument("--target_dp", default=None, type=int, help="Target PP degree") + parser.add_argument( + "--iteration", default=None, type=int, help="Checkpoint iteration" + ) + parser.add_argument( + "--num_extract_workers", + default=4, + type=int, + help="How many parallel processes to extract zero shards", + ) + parser.add_argument( + "--num_merge_workers", + default=2, + type=int, + help="How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))", + ) + parser.add_argument( + "--for_release", + action="store_true", + help="Convert for release purpose, reset some (progress) counters.", + ) + args = parser.parse_args() + print(f"args = {args}") + return args + + +def get_folder(args): + folder = Path(args.input_folder) + if args.iteration is None: + with open(folder / "latest") as latest_file: + tag = latest_file.read().strip() + else: + tag = f"global_step{args.iteration}" + return folder / tag + + +def _convert_ds_transformer_state(sd_list): + new_sd = OrderedDict() + for i, sd in enumerate(sd_list): + for key, value in sd.items(): + new_key = f"layers.{i}.{key}" + new_sd[new_key] = value + + return new_sd + + +def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): + path_list = [] + iter_folder = f"iter_{iteration:07d}" + for i in range(0, tp_degree): + path_list.append([]) + for j in range(0, pp_degree): + rank_folder = ( + f"mp_rank_{i:02d}" if pp_degree == 1 else f"mp_rank_{i:02d}_{j:03d}" + ) + ckpt_path = os.path.join(rank_folder, "model_optim_rng.pt") + path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) + + return path_list + + +def _create_megatron_dict(): + language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}} + megatron_dict = { + MODEL_KEY: {LANGUGAGE_MODEL_KEY: language_model_dict}, + CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE, + } + return megatron_dict + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + +def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D): + pp_index, tp_index, dp_index = indices_3D + sd = ds_checkpoint.get_zero_checkpoint_state( + pp_index=pp_index, tp_index=tp_index, dp_index=dp_index + ) + + pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}") + + optim_sd = sd["optimizer_state_dict"] + param_slice_mappings = optim_sd["param_slice_mappings"] + + # dict + state_groups = optim_sd["base_optimizer_state"]["state"] + # list + fp32_groups = optim_sd["single_partition_of_fp32_groups"] + param_groups_cnt = len(state_groups) + + for param_group_id in range(param_groups_cnt): + + flat_state = dict( + exp_avg=state_groups[param_group_id]["exp_avg"], + exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], + fp32=fp32_groups[param_group_id], + ) + + for name, fragment_mapping in param_slice_mappings[param_group_id].items(): + if "word_embeddings.weight" in name and pp_index > 0: + # Skip tied weights that are replicated in first and last pp stages + continue + + # print( + # f"{param_group_id} {name} => 
{fragment_mapping.start}:{fragment_mapping.numel}" + # ) + for state_key in flat_state.keys(): + dump_param_fragment( + dir, + tp_index, + dp_index, + state_key, + flat_state[state_key], + name, + fragment_mapping.start, + fragment_mapping.numel, + ) + + +cnt = 0 + + +def dump_param_fragment( + dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel +): + + global cnt # temp hack + + param_base_path = os.path.join(dir, param_name, str(tp_index)) + os.makedirs(param_base_path, exist_ok=True) + + cnt += 1 + counter = f"{dp_index:0>2d}" + + path = os.path.join(param_base_path, f"{state_name}.{counter}") + + # print(f"{param_name}: {offset}: {numel} => {path}") + + t = state_flat_tensor.narrow(0, offset, numel) + _save_checkpoint(path, t) + + +def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): + slices = [] + print("###############", f"\n{state}", "\n################") + for tp_index in range(tp_degree): + prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") + paths = sorted(list(glob.glob(f"{prefix_path}.0*"))) + print(f"{paths=}") + shards = [torch.load(p) for p in paths] + print([p.shape for p in shards]) + print(f"{slice_shape=}") + slice = torch.cat(shards, dim=0).reshape(slice_shape) + slices.append(slice) + + return slices + + +ORIGINAL_VOCAB_SIZE = "original_vocab_size" + + +WEIGHTS_TO_AVERAGE_PATTERNS = [ + r"tied_modules.embed.word_embeddings.norm.weight", + r"tied_modules.embed.word_embeddings.norm.bias", + r"\d+.input_layernorm.weight", + r"\d+.input_layernorm.bias", + r"\d+.post_attention_layernorm.weight", + r"\d+.post_attention_layernorm.bias", + r"\d+.self_attention.dense.bias", + r"\d+.mlp.dense_4h_to_h.bias", + r"\d+.weight", + r"\d+.bias", +] + +WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [ + "dense_4h_to_h.weight", + "self_attention.dense.weight", +] + + +def _get_vocab_divisibility_padding_tensor(padded_vocab_tensor, neox_args): + # checkpoint_info = ds_checkpoint.get_checkpoint_info() + if padded_vocab_tensor.shape[0] > neox_args.tokenizer.vocab_size: + return padded_vocab_tensor[-1] + else: + return torch.zeros(padded_vocab_tensor.shape[1]) + + +def merge_tp_slices( + ds_checkpoint, neox_args, dir, slice_dir, tp_degree, name_and_shape +): + name, shape = name_and_shape + slice_base_path = os.path.join(slice_dir, name) + param_base_path = os.path.join(dir, name) + + for state in ("fp32", "exp_avg", "exp_avg_sq"): + slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) + final_path = os.path.join(param_base_path, f"{state}.pt") + + # print(f"Expected shape: {shape}") + # print(f"Fragment sizes:", list(frag.shape for frag in slices)) + ckpt_dict = {} + if any(re.match(pattern, name) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + param = sum(slices) / len(slices) + else: + cat_dim = ( + 1 + if any(text in name for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) + else 0 + ) + # print(f"CAT DIM: {cat_dim}") + param = torch.cat(slices, dim=cat_dim) + ckpt_dict["cat_dim"] = cat_dim + + if "word_embeddings.weight" in name: + # print(f"Before {param.shape=}") + # strip padding + # param = _strip_vocab_padding(ds_checkpoint, param) + ckpt_dict[ + "vocab_divisibility_padding_tensor" + ] = _get_vocab_divisibility_padding_tensor(param, neox_args) + # print(f"After {param.shape=}") + + # print(f"Final shape: {param.shape}") + ckpt_dict["param"] = param + _save_checkpoint(final_path, ckpt_dict) + + +def _get_chunks(l, n): + for i in range(0, len(l), n): + yield l[i : i + n] + + +def _do_parallel_work(do_work, 
work_chunks, num_workers): + pool = multiprocessing.Pool(num_workers) + for batch in tqdm.tqdm(work_chunks): + pool.map(do_work, batch) + pool.close() + pool.join() + + +def _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir): + _3d_range_list = list( + itertools.product( + range(ds_checkpoint.pp_degree), + range(ds_checkpoint.tp_degree), + range(ds_checkpoint.dp_degree), + ) + ) + # pprint(_3d_range_list) + work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) + # pprint(work_chunks) + + do_work = partial(extract_zero_shards, temp_dir, slice_shapes, ds_checkpoint) + _do_parallel_work(do_work, work_chunks, args.num_extract_workers) + + +def _merge_tp_slice_files(args, ds_checkpoint, neox_args, slice_shapes, temp_dir): + work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers)) + # pprint(work_chunks) + zero_output_folder = os.path.join(args.output_folder, "zero") + do_work = partial( + merge_tp_slices, + ds_checkpoint, + neox_args, + zero_output_folder, + temp_dir, + ds_checkpoint.tp_degree, + ) + _do_parallel_work(do_work, work_chunks, args.num_merge_workers) + + +def main(): + print(f"Convert DeepSpeed Checkpoint to Universal Checkpoint") + + args = parse_arguments() + print( + f"Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}" + ) + + input_folder = get_folder(args) + + ds_checkpoint = NeoxCheckpoint( + input_folder, args.target_tp, args.target_pp, args.target_dp + ) # , 1, 2) # args.target_tp, args.target_pp) + neox_args = NeoXArgs.from_ymls([args.config]) + neox_args.build_tokenizer() + + iteration = ds_checkpoint.get_iteration() + # _create_latest_file(args.output_folder, iteration) + print( + f"DP degree: {ds_checkpoint.original_dp_degree} ---> {ds_checkpoint.dp_degree}" + ) + print( + f"TP degree: {ds_checkpoint.original_tp_degree} ---> {ds_checkpoint.tp_degree}" + ) + print( + f"PP degree: {ds_checkpoint.original_pp_degree} ---> {ds_checkpoint.pp_degree}" + ) + checkpoint_paths = _create_checkpoint_paths( + args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree + ) + + slice_shapes = [] + for mp_rank_file in ds_checkpoint.mp_rank_files: + mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu")) + slice_shapes += mp_sd["param_shapes"] + + # fix back to normal flat dict, merge duplicates for tp>1 + slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items()) + temp_dir = os.path.join(args.output_folder, "tmp") + + # print(slice_shapes) + + print("*** 1. Extracting ZeRO fragments") + _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir) + + print("*** 2. 
Merging slices") + _merge_tp_slice_files(args, ds_checkpoint, neox_args, slice_shapes, temp_dir) + + shutil.rmtree(temp_dir, ignore_errors=True) + + # Copy mp* files into output folder + for f in glob.glob(os.path.join(args.input_folder, "mp*")): + shutil.copy2(f, args.output_folder) + + # Update latest to output folder + checkpoint_root_folder, step_folder = os.path.split(args.output_folder) + latest_file = os.path.join(checkpoint_root_folder, "latest_universal") + with open(latest_file, "w") as f: + f.write(step_folder) + + print("*** Done!") + + +if __name__ == "__main__": + main() diff --git a/tools/inspect_checkpoints.py b/tools/checkpoints/inspect_checkpoints.py similarity index 100% rename from tools/inspect_checkpoints.py rename to tools/checkpoints/inspect_checkpoints.py diff --git a/tools/checkpoints/inspect_ds_checkpoint.py b/tools/checkpoints/inspect_ds_checkpoint.py new file mode 100644 index 000000000..f7383df83 --- /dev/null +++ b/tools/checkpoints/inspect_ds_checkpoint.py @@ -0,0 +1,128 @@ +import sys +from pathlib import Path + +# insert megatron's root dir into sys.path +root_repo_path = str(Path(__file__).resolve().parents[2]) +if root_repo_path not in sys.path: + sys.path.insert(0, root_repo_path) + +import argparse + +from deepspeed.checkpoint import NeoxCheckpoint + + +def list_files(file_list, tag): + print(f"Listing files: {tag}") + for i, file in enumerate(file_list): + print(f"{i+1}: {file}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--folder", default=None, type=str, help="DeepSpeed Checkpoint folder" + ) + parser.add_argument("--target_tp", default=None, type=int, help="Target TP degree") + parser.add_argument("--target_pp", default=None, type=int, help="Target PP degree") + parser.add_argument("--target_dp", default=None, type=int, help="Target DP degree") + parser.add_argument( + "--iteration", default=None, type=int, help="Which iteration to load" + ) + args = parser.parse_args() + print(f"args = {args}") + return args + + +def show_input_files(ds_checkpoint): + list_files(ds_checkpoint.file_list, "all") + list_files(ds_checkpoint.zero_files, "zero") + list_files(ds_checkpoint.layer_files, "layer") + list_files(ds_checkpoint.mp_rank_files, "mp rank") + + +def show_simple_state(ds_checkpoint): + print(f"layer keys = {ds_checkpoint.layer_keys}") + print(f"layer count = {ds_checkpoint.layer_count}") + + print( + f"tp_degree_count = {ds_checkpoint.original_tp_degree} ------> {ds_checkpoint.tp_degree}" + ) + print( + f"pp_degree_count = {ds_checkpoint.original_pp_degree} ------> {ds_checkpoint.pp_degree}" + ) + print( + f"dp_degree_count = {ds_checkpoint.original_dp_degree} ------> {ds_checkpoint.dp_degree}" + ) + ds_checkpoint.old_2d_map.print_data("old 2d map ==>") + ds_checkpoint.new_2d_map.print_data("new 2d map ==>") + + +def show_mappings(ds_checkpoint): + ds_checkpoint.show_pp_transformer_map() + ds_checkpoint.show_transformer_file_map() + ds_checkpoint.show_tp_embedding_map() + ds_checkpoint.show_tp_final_norm_map() + ds_checkpoint.show_2d_mapping() + + +def show_state_summary(tag, sd): + summary = {k: v.shape for k, v in sd.items()} + print(f"{tag} = {summary}") + + +def show_embedding_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + sd = ds_checkpoint.get_embedding_state(i) + show_state_summary(f"embedding[{i}]", sd) + + +def show_final_norm_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + sd = ds_checkpoint.get_final_norm_state(i) + 
show_state_summary(f"final_norm[{i}]", sd) + + +def show_transformer_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + for j in range(0, ds_checkpoint.pp_degree): + state_list = ds_checkpoint.get_transformer_state(tp_index=i, pp_index=j) + print(f"tp_pp_rank[{i},{j}] = ") + for k, sd in enumerate(state_list): + show_state_summary(f" block[{k}]", sd) + print("") + + +def get_folder(args): + folder = Path(args.folder) + if args.iteration is None: + with open(folder / "latest") as latest_file: + tag = latest_file.read().strip() + else: + tag = f"global_step{args.iteration}" + return folder / tag + + +def main(): + print(f"Inspecting DeepSpeed Checkpoint") + args = parse_arguments() + + ckpt_folder = get_folder(args) + + ds_checkpoint = NeoxCheckpoint( + ckpt_folder, args.target_tp, args.target_pp, args.target_dp + ) + ds_checkpoint.validate_files() + + show_simple_state(ds_checkpoint) + show_input_files(ds_checkpoint) + show_simple_state(ds_checkpoint) + show_mappings(ds_checkpoint) + show_embedding_states(ds_checkpoint) + show_final_norm_states(ds_checkpoint) + show_transformer_states(ds_checkpoint) + checkpoint_args = ds_checkpoint.get_args() + print(f"checkpoint args = {checkpoint_args}") + + +if __name__ == "__main__": + main() diff --git a/tools/merge20b.py b/tools/checkpoints/merge20b.py similarity index 100% rename from tools/merge20b.py rename to tools/checkpoints/merge20b.py diff --git a/tools/merge_mp_partitions.py b/tools/checkpoints/merge_mp_partitions.py similarity index 100% rename from tools/merge_mp_partitions.py rename to tools/checkpoints/merge_mp_partitions.py diff --git a/tools/upload.py b/tools/checkpoints/upload.py similarity index 100% rename from tools/upload.py rename to tools/checkpoints/upload.py diff --git a/tools/corpora.py b/tools/datasets/corpora.py similarity index 100% rename from tools/corpora.py rename to tools/datasets/corpora.py diff --git a/tools/merge_datasets.py b/tools/datasets/merge_datasets.py similarity index 100% rename from tools/merge_datasets.py rename to tools/datasets/merge_datasets.py diff --git a/tools/preprocess_data.py b/tools/datasets/preprocess_data.py similarity index 100% rename from tools/preprocess_data.py rename to tools/datasets/preprocess_data.py
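The inspection script above (`tools/checkpoints/inspect_ds_checkpoint.py`) is driven entirely by the `NeoxCheckpoint` interface. For quick interactive checks, the same calls it makes can be used directly; the following is a minimal sketch, assuming a DeeperSpeed build that exposes `NeoxCheckpoint` from `deepspeed.checkpoint` (stock DeepSpeed ships `DeepSpeedCheckpoint` instead) and using a placeholder checkpoint path.

```python
# Sketch: report how a checkpoint's parallelism degrees would be remapped,
# using only calls that tools/checkpoints/inspect_ds_checkpoint.py also makes.
from pathlib import Path

from deepspeed.checkpoint import NeoxCheckpoint  # provided by DeeperSpeed


def show_degrees(step_folder, target_tp=1, target_pp=1, target_dp=1):
    ckpt = NeoxCheckpoint(Path(step_folder), target_tp, target_pp, target_dp)
    ckpt.validate_files()
    print(f"iteration: {ckpt.get_iteration()}")
    print(f"TP: {ckpt.original_tp_degree} -> {ckpt.tp_degree}")
    print(f"PP: {ckpt.original_pp_degree} -> {ckpt.pp_degree}")
    print(f"DP: {ckpt.original_dp_degree} -> {ckpt.dp_degree}")


if __name__ == "__main__":
    show_degrees("/path/to/checkpoints/global_step143000")  # placeholder path
```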