From 41b72f38dfb6c63e496a6c0c69ed84309348160a Mon Sep 17 00:00:00 2001 From: Ofir Ben Shoham Date: Thu, 4 Dec 2025 11:46:15 +0200 Subject: [PATCH 1/3] Add --is-preformatted flag to prepare_hidden_states.py Added support for preformatted input data in prepare_hidden_states.py, matching the existing flag in train_eagle3.py. This allows users to skip chat template application when their data already has the template applied. Changes: - Added --is-preformatted argument to data group - Updated cache key to include is_preformatted for proper caching - Pass is_preformatted to build_eagle3_dataset() --- scripts/prepare_hidden_states.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index 975b9753..3ad0b8e4 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -73,6 +73,11 @@ def parse_args(): data_group.add_argument("--data-path", type=str, required=True) data_group.add_argument("--max-length", type=int, default=2048) data_group.add_argument("--chat-template", type=str, default="llama3") + data_group.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) data_group.add_argument("--num-samples", type=int, default=None) data_group.add_argument("--build-dataset-num-proc", type=int, default=8) @@ -558,7 +563,7 @@ def main(): tokenizer = AutoTokenizer.from_pretrained( args.target_model_path, trust_remote_code=True ) - cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}" + cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}" cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() # Preprocess on complete, un-sharded dataset @@ -572,6 +577,7 @@ def main(): cache_dir=os.path.join(args.cache_dir, "processed_dataset"), cache_key=cache_key, is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, processor=processor, num_proc=args.build_dataset_num_proc, ) From 4aea48df48c000a989a95002dc70b21353a69d4a Mon Sep 17 00:00:00 2001 From: Ofir Ben Shoham Date: Thu, 4 Dec 2025 12:02:23 +0200 Subject: [PATCH 2/3] Update documentation for --is-preformatted flag in prepare_hidden_states.py - Updated script docstring with usage example for --is-preformatted - Updated data_preparation.md to document --is-preformatted for offline training --- docs/basic_usage/data_preparation.md | 15 +++++++++++++++ scripts/prepare_hidden_states.py | 14 ++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/basic_usage/data_preparation.md b/docs/basic_usage/data_preparation.md index b545f5fe..e19938ad 100644 --- a/docs/basic_usage/data_preparation.md +++ b/docs/basic_usage/data_preparation.md @@ -94,6 +94,7 @@ This format is useful when you have pre-formatted prompts that were used during To use pre-formatted datasets, add the `--is-preformatted` flag to your training command. Note that the `--chat-template` parameter is still needed and should match the template used in your pre-formatted text, as it is used to identify user/assistant tokens to determine the assistant spans and generate the corresponding loss mask. 
```bash +# Online training with pre-formatted data torchrun --standalone --nproc_per_node 8 \ scripts/train_eagle3.py \ --is-preformatted \ @@ -101,6 +102,20 @@ torchrun --standalone --nproc_per_node 8 \ # ... other arguments ``` +For offline training, you can also use `--is-preformatted` when generating hidden states: + +```bash +# Generate hidden states from pre-formatted data +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --data-path ./your_preformatted_dataset.jsonl \ + --output-path ./cache/hidden_states \ + --chat-template llama3 \ + --is-preformatted \ + --max-length 2048 +``` + Once you have the `jsonl` file ready, you can proceed with online training or generate hidden states for offline training. See the Training guide for more details. diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index 3ad0b8e4..b1e9739b 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -17,8 +17,18 @@ --max-length 2048 \ --tp-size 1 \ --batch-size 32 \ - --num-samples 1000 \ - --output-path ./cache/hidden_states + --num-samples 1000 + +For pre-formatted data (with chat template already applied), add --is-preformatted: +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/preformatted_data.jsonl \ + --output-path ./cache/hidden_states \ + --chat-template llama3 \ + --is-preformatted \ + --max-length 2048 """ import argparse From cd8470f78a0b8258639207cc490960ccb147758c Mon Sep 17 00:00:00 2001 From: Ofir Ben Shoham Date: Thu, 4 Dec 2025 12:09:03 +0200 Subject: [PATCH 3/3] Address code review: add --output-path to docstring example Added back the --output-path argument to the first usage example in the docstring for clarity and consistency with the pre-formatted data example. --- scripts/prepare_hidden_states.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index b1e9739b..a3de211e 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -17,7 +17,8 @@ --max-length 2048 \ --tp-size 1 \ --batch-size 32 \ - --num-samples 1000 + --num-samples 1000 \ + --output-path ./cache/hidden_states For pre-formatted data (with chat template already applied), add --is-preformatted: torchrun --nproc_per_node=8 \
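A minimal sketch (not part of the patch series) of the cache-key behavior that patch 1 changes: including `is_preformatted` in the hashed parameter string keeps preformatted and raw runs over the same data from reusing each other's processed-dataset cache. The `cache_key` helper below is hypothetical; only the f-string mirrors the patched `cache_params_string` in `prepare_hidden_states.py`.

```python
# Hedged illustration of the cache-key change in patch 1.
# `cache_key` is a hypothetical helper; the parameter string mirrors the
# patched cache_params_string, now ending with the is_preformatted flag.
import hashlib


def cache_key(data_path, max_length, chat_template, target_model_path,
              num_samples, is_preformatted):
    params = (
        f"{data_path}-{max_length}-{chat_template}-"
        f"{target_model_path}-{num_samples}-{is_preformatted}"
    )
    return hashlib.md5(params.encode()).hexdigest()


# Same data, same model, different is_preformatted -> different cache keys,
# so a preformatted run never picks up the raw run's processed dataset.
key_raw = cache_key("data.jsonl", 2048, "llama3",
                    "meta-llama/Llama-3.1-8B-Instruct", None, False)
key_pre = cache_key("data.jsonl", 2048, "llama3",
                    "meta-llama/Llama-3.1-8B-Instruct", None, True)
assert key_raw != key_pre
```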