Preserve truncated tokens in BFD packing #4632
Conversation
```python
if isinstance(column, pa.ChunkedArray):
    column = column.combine_chunks()
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
    column = pc.list_slice(column, 0, seq_length)
```
This was truncating all list columns; we don't want to truncate anymore.
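For illustration, a minimal sketch of the behavior being removed (toy values; assumes a PyArrow version that ships `pc.list_slice`):

```python
import pyarrow as pa
import pyarrow.compute as pc

col = pa.array([[1, 2, 3, 4, 5], [6, 7]], type=pa.list_(pa.int64()))
# Old behavior: anything past seq_length was silently dropped.
print(pc.list_slice(col, 0, 4).to_pylist())  # [[1, 2, 3, 4], [6, 7]]
```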
```python
if isinstance(column, pa.ChunkedArray):
    column = column.combine_chunks()
```
In PyArrow, a ChunkedArray is one logical column made of multiple smaller arrays ("chunks") under the hood.
Combining chunks gives one contiguous array (leaving it chunked would mean list offsets restart in each piece). This lets the code operate on a single contiguous chunk, which keeps offsets consistent and avoids chunk-boundary surprises.
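A quick illustration with toy values:

```python
import pyarrow as pa

chunked = pa.chunked_array([[[1, 2], [3]], [[4, 5, 6]]], type=pa.list_(pa.int64()))
print(chunked.num_chunks)  # 2: list offsets restart at 0 in each chunk

flat = chunked.combine_chunks()  # one contiguous array with a single offsets buffer
print(flat.offsets.to_pylist())  # [0, 2, 3, 6]: global offsets across all rows
```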
Looking forward to this PR!
Thanks for improving this function so no tokens are discarded.
However, I think the new implementation might introduce a correctness issue for the usual use case with multiple list columns.
If I understand correctly, only the first list column is actually split into <= seq_length fragments. All other list columns are simply duplicated at the row level and later re-wrapped using the packed offsets. This would cause misalignment between columns.
A minimal repro: with

```python
examples = {
    'input_ids_1': [
        [1, 2, 3, 4, 5],
        [6, 7],
        [8, 9, 10, 11],
        [12],
    ],
    'input_ids_2': [
        [10, 20, 30, 40, 50],
        [60, 70],
        [80, 90, 100, 110],
        [120],
    ],
}
```

the packed output with `seq_length = 4` becomes:

```python
{
    'input_ids_1': [
        [1, 2, 3, 4],
        [8, 9, 10, 11],
        [6, 7, 5, 12],
    ],
    'input_ids_2': [
        [10, 20, 30, 40],
        [50, 80, 90, 100],
        [110, 60, 70, 10],
    ],
    'seq_lengths': [
        [4],
        [4],
        [2, 1, 1],
    ],
}
```

Note the misalignment: `[6, 7, 5, 12]` in `input_ids_1` should pair with `[60, 70, 50, 120]` in `input_ids_2`, but we get `[110, 60, 70, 10]` instead.
trl/data_utils.py (outdated)
```python
frag_slices.append((row_start + split_start, frag_len))
expanded_indices.append(row_idx)

# Rebuild list column with fragments and duplicate non-list columns accordingly.
```
> duplicate non-list columns accordingly

Does this function support non-list columns?
No, but in practice, we only get list columns:
trl/trl/trainer/sft_trainer.py, lines 1088 to 1102 in d401a42:

```python
columns = ["input_ids"]
if "completion_mask" in get_dataset_column_names(dataset):
    columns.append("completion_mask")
if "assistant_masks" in get_dataset_column_names(dataset):
    columns.append("assistant_masks")
dataset = dataset.select_columns(columns)
# Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before
# packing as well to avoid correlations between sequences packed together.
if args.shuffle_dataset:
    dataset = dataset.shuffle(seed=args.seed)
# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
```
I added a check in the function to ensure all columns are lists:
Lines 650 to 654 in ad82b13:

```python
for idx, column in enumerate(examples.columns):
    if isinstance(column, pa.ChunkedArray):
        column = column.combine_chunks()
    if not (pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)):
        raise TypeError("pack_dataset(bfd) requires all columns to be list-like.")
```
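For example, a non-list column now fails fast (hypothetical call; the exact `pack_dataset` signature is inferred from the call site quoted above):

```python
from datasets import Dataset
from trl.data_utils import pack_dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "label": [0]})  # "label" is a scalar column
pack_dataset(ds, 4, "bfd")  # raises TypeError: all columns must be list-like
```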
```python
seq_length = 4
expected_output = {
    "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
    "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
```
In practice, we never call packing with the attention mask, and it doesn't add anything to test it here, so I suggest we remove it.
```python
def test_with_overlong_two_columns(self):
    examples = {
        "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
        "col2": [[-1, 2, -3, 4, -5, -6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
    }
    dataset = Dataset.from_dict(examples)
    seq_length = 4
    expected_output = {
        "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
        "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, -6]],
        "seq_lengths": [[4], [4], [3], [3], [2]],
```
Following the finding above, this new test ensures we have consistent packing across columns.
This PR updates the `bfd` packing strategy so that tokens beyond `seq_length` are not discarded. Instead, truncated fragments are re-queued and packed like any other sequence, preventing unnecessary token loss.
Closes #4554
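As a minimal, illustrative sketch of the idea (a pure-Python stand-in, not the PR's Arrow-based implementation):

```python
def bfd_pack(sequences, seq_length):
    # Split overlong sequences into <= seq_length fragments instead of truncating:
    # tail fragments are re-queued and packed like any other sequence.
    fragments = []
    for seq in sequences:
        for start in range(0, len(seq), seq_length):
            fragments.append(seq[start:start + seq_length])
    fragments.sort(key=len, reverse=True)  # "decreasing" in best-fit-decreasing

    bins = []  # each bin holds fragments totalling <= seq_length tokens
    for frag in fragments:
        # Best fit: the fullest existing bin that still has room for this fragment.
        candidates = [b for b in bins if sum(map(len, b)) + len(frag) <= seq_length]
        if candidates:
            max(candidates, key=lambda b: sum(map(len, b))).append(frag)
        else:
            bins.append([frag])
    return bins

# With the repro from the discussion:
# bfd_pack([[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]], 4)
# -> [[[1, 2, 3, 4]], [[8, 9, 10, 11]], [[6, 7], [5], [12]]]
# which matches the packed output and seq_lengths [[4], [4], [2, 1, 1]] above.
```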
Before/After
Benchmark
TLDR: a bit slower, but still very fast
Important
This PR was mostly written using Codex. Based on my tests, it works. I think I understand most of it, but I didn't write most of the code changes myself.