15 changes: 15 additions & 0 deletions docs/basic_usage/data_preparation.md
@@ -94,13 +94,28 @@ This format is useful when you have pre-formatted prompts that were used during
To use pre-formatted datasets, add the `--is-preformatted` flag to your training command. Note that the `--chat-template` parameter is still needed and should match the template used in your pre-formatted text: it is used to identify the user/assistant tokens, determine the assistant spans, and generate the corresponding loss mask (see the sketch after the command below).

```bash
# Online training with pre-formatted data
torchrun --standalone --nproc_per_node 8 \
scripts/train_eagle3.py \
--is-preformatted \
--train-data-path ./your_preformatted_dataset.jsonl \
# ... other arguments
```
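
To illustrate the role of the template, below is a minimal, hypothetical sketch of how assistant spans might be located in pre-formatted text to build a loss mask. The marker strings assume a llama3-style template, and the helper shown here is for illustration only; the actual span detection lives in the training code and may differ.

```python
# Illustration only: locate assistant responses in pre-formatted text so that
# only those character ranges contribute to the loss. The marker strings below
# assume a llama3-style template; adjust them for other templates.
ASSISTANT_START = "<|start_header_id|>assistant<|end_header_id|>"
TURN_END = "<|eot_id|>"

def assistant_char_spans(text: str) -> list[tuple[int, int]]:
    """Return (start, end) character offsets of assistant responses."""
    spans, cursor = [], 0
    while True:
        start = text.find(ASSISTANT_START, cursor)
        if start == -1:
            break
        start += len(ASSISTANT_START)
        end = text.find(TURN_END, start)
        if end == -1:
            end = len(text)  # unterminated final turn
        spans.append((start, end))
        cursor = end
    return spans

text = (
    "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>"
)
print(assistant_char_spans(text))  # one span covering "\n\nHello!"
```

Tokens whose offsets fall inside these spans are the ones that should receive loss; everything else (system and user turns, template tokens) is masked out.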

For offline training, you can also use `--is-preformatted` when generating hidden states:

```bash
# Generate hidden states from pre-formatted data
torchrun --nproc_per_node=8 \
scripts/prepare_hidden_states.py \
--target-model-path meta-llama/Llama-3.1-8B-Instruct \
--data-path ./your_preformatted_dataset.jsonl \
--output-path ./cache/hidden_states \
--chat-template llama3 \
--is-preformatted \
--max-length 2048
```

Once you have the `jsonl` file ready, you can proceed with online training or generate hidden states for offline training. See the Training guide for more details.
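
For reference, a record can be appended to the `jsonl` file from Python as sketched below. The single `text` field and the example content are assumptions made for illustration; follow the pre-formatted record format described earlier on this page.

```python
# Illustration only: the record schema (a single "text" field) is an
# assumption -- use the pre-formatted format described earlier on this page.
import json

record = {
    "text": (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "What is EAGLE-3?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "A speculative decoding method for faster LLM inference.<|eot_id|>"
    )
}

with open("your_preformatted_dataset.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```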


19 changes: 18 additions & 1 deletion scripts/prepare_hidden_states.py
@@ -19,6 +19,17 @@
--batch-size 32 \
--num-samples 1000 \
--output-path ./cache/hidden_states

For pre-formatted data (with chat template already applied), add --is-preformatted:
torchrun --nproc_per_node=8 \
scripts/prepare_hidden_states.py \
--target-model-path meta-llama/Llama-3.1-8B-Instruct \
--enable-aux-hidden-states \
--data-path ./cache/dataset/preformatted_data.jsonl \
--output-path ./cache/hidden_states \
--chat-template llama3 \
--is-preformatted \
--max-length 2048
"""

import argparse
@@ -73,6 +84,11 @@ def parse_args():
data_group.add_argument("--data-path", type=str, required=True)
data_group.add_argument("--max-length", type=int, default=2048)
data_group.add_argument("--chat-template", type=str, default="llama3")
data_group.add_argument(
"--is-preformatted",
action="store_true",
help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
)
data_group.add_argument("--num-samples", type=int, default=None)
data_group.add_argument("--build-dataset-num-proc", type=int, default=8)

@@ -558,7 +574,7 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(
args.target_model_path, trust_remote_code=True
)
cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}"
cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}"
cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

# Preprocess on complete, un-sharded dataset
@@ -572,6 +588,7 @@
cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
cache_key=cache_key,
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
processor=processor,
num_proc=args.build_dataset_num_proc,
)