
Conversation

@qgallouedec (Member) commented Dec 5, 2025

This PR updates the bfd packing strategy so that tokens beyond seq_length are not discarded.
Instead, truncated fragments are re-queued and packed like any other sequence, preventing unnecessary token loss.

Closes #4554

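For intuition, here is a minimal pure-Python sketch of best-fit-decreasing packing with fragment re-queueing. This is an illustration of the idea only, not the actual implementation (which operates on PyArrow buffers), and it may not reproduce the exact bin assignments shown in the example below:

def pack_bfd_requeue(sequences, seq_length):
    bins = []  # each bin: {"items": [...], "free": remaining capacity}
    # Best-fit decreasing: always take the longest pending sequence next.
    queue = sorted(sequences, key=len, reverse=True)
    while queue:
        seq = queue.pop(0)
        if len(seq) > seq_length:
            # Overlong sequence: keep a seq_length-sized head and re-queue
            # the tail instead of discarding it.
            queue.append(seq[seq_length:])
            queue.sort(key=len, reverse=True)
            seq = seq[:seq_length]
        # Best fit: the bin with the least remaining space that still fits.
        best = min(
            (b for b in bins if b["free"] >= len(seq)),
            key=lambda b: b["free"],
            default=None,
        )
        if best is None:
            best = {"items": [], "free": seq_length}
            bins.append(best)
        best["items"].append(seq)
        best["free"] -= len(seq)
    return [b["items"] for b in bins]

# No token is dropped; the overflow fragment [5] lands in a later bin.
print(pack_bfd_requeue([[1, 2, 3, 4, 5], [6, 7], [8, 9, 10], [11]], seq_length=4))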

Before/After

from datasets import Dataset
from trl import pack_dataset

examples = {
    "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10], [11]],
}
dataset = Dataset.from_dict(examples)
packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
print(packed_dataset[:])
- {'input_ids': [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7]], 'seq_lengths': [[4], [3, 1], [2]]}
+ {'input_ids': [[1, 2, 3, 4], [8, 9, 10, 5], [6, 7, 11]], 'seq_lengths': [[4], [3, 1], [2, 1]]}

Benchmark

TL;DR: a bit slower, but still very fast

import random
import time
from datasets import Dataset
from trl.data_utils import pack_dataset

total_tokens = 10_000_000
seq_length = 2048  # packing target
min_seq_len, max_seq_len = 1024, 3072  # arbitrary input lengths

input_ids = []
tokens_left = total_tokens
while tokens_left > 0:
    n = min(tokens_left, random.randint(min_seq_len, max_seq_len))
    tokens_left -= n
    input_ids.append(list(range(n)))

dataset = Dataset.from_dict({"input_ids": input_ids})

start = time.perf_counter()
packed = pack_dataset(dataset, seq_length=seq_length)
elapsed = time.perf_counter() - start

print(f"Packed {total_tokens} tokens into {len(packed)} examples in {elapsed:.3f}s")
# Before: Packed 10000000 tokens into 4848 examples in 0.189s
# After:  Packed 10000000 tokens into 4952 examples in 0.255s

Important

This PR was mostly written using Codex. Based on my tests, it works. I think I understand most of it, but I didn't write most of the code changes myself.


@qgallouedec linked an issue Dec 5, 2025 that may be closed by this pull request
if isinstance(column, pa.ChunkedArray):
    column = column.combine_chunks()
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
    column = pc.list_slice(column, 0, seq_length)
@qgallouedec (Member, Author):

This was truncating all list columns. We don't want to truncate anymore.

Comment on lines +650 to +651
if isinstance(column, pa.ChunkedArray):
    column = column.combine_chunks()
@qgallouedec (Member, Author):

In PyArrow, a ChunkedArray is one logical column made of multiple smaller arrays ("chunks") under the hood. Combining chunks yields a single contiguous array; if the column stayed chunked, the list offsets would restart at zero in each chunk. Operating on one contiguous array keeps offsets consistent and avoids chunk-boundary surprises.
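For illustration, a small standalone sketch of what combining chunks does to list offsets (only pyarrow itself is exercised here, not the pack_dataset internals):

import pyarrow as pa

# One logical list column stored as two chunks; each chunk keeps its own
# offset buffer, starting again at 0.
chunked = pa.chunked_array([[[1, 2], [3]], [[4, 5, 6]]])
print(chunked.num_chunks)  # 2

# combine_chunks() concatenates everything into a single contiguous array,
# so the list offsets become global: [0, 2, 3, 6].
combined = chunked.combine_chunks()
print(combined.offsets)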

@jiosephlee:

Looking forward to this PR!

@albertvillanova (Member) left a comment:

Thanks for improving this function so no tokens are discarded.

However, I think the new implementation might introduce a correctness issue for the usual use case with multiple list columns.

If I understand correctly, only the first list column is actually split into <= seq_length fragments. All other list columns are simply duplicated at the row level and later re-wrapped using the packed offsets. This would cause misalignment between columns.

A minimal repro: with

examples = {
  'input_ids_1': [
    [1, 2, 3, 4, 5], 
    [6, 7], 
    [8, 9, 10, 11], 
    [12]
  ],
  'input_ids_2': [
    [10, 20, 30, 40, 50], 
    [60, 70], 
    [80, 90, 100, 110], 
    [120]
  ]
}

the packed output with seq_length = 4 becomes:

{
  'input_ids_1': [
    [1, 2, 3, 4], 
    [8, 9, 10, 11], 
    [6, 7, 5, 12]
  ],
  'input_ids_2': [
    [10, 20, 30, 40], 
    [50, 80, 90, 100], 
    [110, 60, 70, 10]
  ],
  'seq_lengths': [
    [4], 
    [4], 
    [2, 1, 1]
  ]
}
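One way to surface the misalignment (a hypothetical check, with `packed` standing for the output dict above): since input_ids_2 was constructed as 10x input_ids_1 element-wise, a correctly aligned packing must preserve that invariant, but the second pack breaks it:

# Hypothetical invariant check; `packed` is the output dict shown above.
for row1, row2 in zip(packed["input_ids_1"], packed["input_ids_2"]):
    assert all(b == 10 * a for a, b in zip(row1, row2)), (row1, row2)
# AssertionError on the second pack: ([8, 9, 10, 11], [50, 80, 90, 100])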

frag_slices.append((row_start + split_start, frag_len))
expanded_indices.append(row_idx)

# Rebuild list column with fragments and duplicate non-list columns accordingly.
Member:

"duplicate non-list columns accordingly"

Does this function support non-list columns?

@qgallouedec (Member, Author):

No, but in practice, we only get list columns:

trl/trl/trainer/sft_trainer.py

Lines 1088 to 1102 in d401a42

columns = ["input_ids"]
if "completion_mask" in get_dataset_column_names(dataset):
columns.append("completion_mask")
if "assistant_masks" in get_dataset_column_names(dataset):
columns.append("assistant_masks")
dataset = dataset.select_columns(columns)
# Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before
# packing as well to avoid correlations between sequences packed together.
if args.shuffle_dataset:
dataset = dataset.shuffle(seed=args.seed)
# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)

@qgallouedec (Member, Author):

I added a check in the function to ensure all columns are list-like.

@qgallouedec (Member, Author):

trl/trl/data_utils.py

Lines 650 to 654 in ad82b13

for idx, column in enumerate(examples.columns):
    if isinstance(column, pa.ChunkedArray):
        column = column.combine_chunks()
    if not (pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)):
        raise TypeError("pack_dataset(bfd) requires all columns to be list-like.")
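As an illustration of the intended behavior (this example is assumed, not taken from the PR's tests), a dataset with a scalar column would now fail fast instead of being silently mis-packed:

from datasets import Dataset
from trl import pack_dataset

# "label" is a scalar (non-list) column, so packing should fail fast.
ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "label": [0]})
pack_dataset(ds, seq_length=4, strategy="bfd")  # expected to raise TypeError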

seq_length = 4
expected_output = {
    "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
    "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
@qgallouedec (Member, Author):

In practice, we never call packing with the attention mask, and testing it here doesn't add anything. So I suggest we remove it.

Comment on lines +1040 to +1050
def test_with_overlong_two_columns(self):
    examples = {
        "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
        "col2": [[-1, 2, -3, 4, -5, -6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
    }
    dataset = Dataset.from_dict(examples)
    seq_length = 4
    expected_output = {
        "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
        "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, -6]],
        "seq_lengths": [[4], [4], [3], [3], [2]],
@qgallouedec (Member, Author):

Following the finding above, this new test ensures we have consistent packing across columns.


Development

Successfully merging this pull request may close these issues.

Better packing of data with best-fit decrease strategy
