
Commit 3f10cb1

[Feature] Packing (#3060)
1 parent 499e707 commit 3f10cb1

6 files changed: +736 -111 lines changed

6 files changed

+736
-111
lines changed

benchmarks/test_llm.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import importlib.util
+
+import pytest
+import torch
+from tensordict import set_list_to_stack, TensorDict
+from torchrl.data.llm import History
+from torchrl.modules.llm.policies.common import ChatHistory
+from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+
+_has_transformers = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.fixture(scope="module")
+def transformers_wrapper():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    with torch.device(device):
+        model = TransformersWrapper(
+            model="Qwen/Qwen2.5-0.5B",
+            tokenizer="Qwen/Qwen2.5-0.5B",
+            pad_model_input=False,
+            generate=False,
+        )
+    return model
+
+
+@pytest.mark.skipif(not _has_transformers, reason="transformers not installed")
+class TestWrappers:
+    @pytest.mark.parametrize("packing", [True, False])
+    @set_list_to_stack(True)
+    def test_packing(self, benchmark, transformers_wrapper, packing: bool):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        with torch.device(device):
+            transformers_wrapper = TransformersWrapper(
+                model=transformers_wrapper.model,
+                tokenizer=transformers_wrapper.tokenizer,
+                pad_model_input=not packing,
+                generate=False,
+                pad_output=False,
+            )
+        data = TensorDict(
+            {
+                "history": ChatHistory(
+                    full=History(
+                        role=[
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                        ],
+                        content=[
+                            [
+                                "Lorem ipsum dolor sit amet",
+                                "consectetur adipiscing elit",
+                            ],
+                            [
+                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
+                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
+                            ],
+                            [
+                                "Lorem ipsum dolor sit amet",
+                                "consectetur adipiscing elit",
+                            ],
+                            [
+                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
+                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
+                            ],
+                        ],
+                        batch_size=(4, 2),
+                        device=device,
+                    ),
+                    batch_size=(4,),
+                    device=device,
+                )
+            },
+            batch_size=(4,),
+            device=device,
+        ).to_lazystack()
+
+        def setup():
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        benchmark.pedantic(
+            transformers_wrapper,
+            (data,),
+            rounds=10,
+            warmup_rounds=3,
+            setup=setup,
+        )
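The benchmark fixture used here comes from pytest-benchmark; assuming it is installed, the packed-versus-padded comparison can be run with, for example, pytest benchmarks/test_llm.py -k test_packing, which times the wrapper with packing enabled (pad_model_input=False) and disabled.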

test/llm/test_wrapper.py

Lines changed: 111 additions & 3 deletions
@@ -12,7 +12,7 @@
 
 import pytest
 import torch
-from tensordict import lazy_stack, set_list_to_stack, TensorDict
+from tensordict import assert_close, lazy_stack, set_list_to_stack, TensorDict
 
 from tensordict.utils import _zip_strict
 from torchrl.data.llm import History
@@ -163,6 +163,22 @@ def sample_tokens(vllm_instance):
     return tokenized["input_ids"], tokenized["attention_mask"]
 
 
+@pytest.fixture
+def sample_tokens_unpadded(vllm_instance):
+    """Create sample tokens for testing."""
+    model, tokenizer = vllm_instance
+    text = [
+        "Are you happy? Say yes or no.",
+        "Explain the difference between a cat and a dog. Be very detailed.",
+    ]
+    tokenized = tokenizer(text, padding=False)
+    return torch.nested.nested_tensor(
+        [torch.tensor(t) for t in tokenized["input_ids"]], layout=torch.jagged
+    ), torch.nested.nested_tensor(
+        [torch.tensor(t) for t in tokenized["attention_mask"]], layout=torch.jagged
+    )
+
+
 def check_output_shapes(out, pad_output, requested_log_probs=False):
     if pad_output or not out.ndim:
         # We can get all tensors or they are none
@@ -1656,8 +1672,6 @@ def test_log_probs_consistency(
         vllm_lp_result = vllm_lp_wrapper(new_data.copy())
         tf_lp_result = tf_lp_wrapper(new_data.copy())
 
-        from tensordict import assert_close
-
         assert_close(
             vllm_lp_result, tf_lp_result, atol=1e-1, rtol=1e-1, intersection=True
         )
@@ -1825,6 +1839,100 @@ def test_transformers_custom_masking(
         assert hasattr(dist, "log_prob")
 
 
+@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+@pytest.mark.parametrize("pad_output", [False, True])
+class TestPacking:
+    def test_packing_history(
+        self, transformers_instance, sample_history_assistant, pad_output
+    ):
+        model, tokenizer = transformers_instance
+
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="history",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapped_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="history",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+
+        td = TensorDict(
+            {"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,)
+        ).to_lazystack(0)
+
+        result_padded = wrapped_padded(td)
+        result_packed = wrapper_packed(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+    def test_packing_text(self, transformers_instance, sample_text, pad_output):
+        model, tokenizer = transformers_instance
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapped_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+        td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,))
+        result_packed = wrapper_packed(td)
+        result_padded = wrapped_padded(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+    def test_packing_tokens(
+        self, transformers_instance, sample_tokens_unpadded, pad_output
+    ):
+        model, tokenizer = transformers_instance
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="tokens",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapped_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="tokens",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+        td = TensorDict(
+            {
+                "tokens": Tokens(full=sample_tokens_unpadded[0]),
+                "masks": Masks(all_attention_mask=sample_tokens_unpadded[1]),
+            },
+            batch_size=(2,),
+        ).to_lazystack(0)
+        result_padded = wrapped_padded(td)
+        result_packed = wrapper_packed(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
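These tests check that packing the model input (pad_model_input=False) yields the same log-probabilities as the padded path (pad_model_input=True) across the history, text, and tokens input modes. For context, a minimal sketch (not code from this commit) of the difference between padding and packing variable-length sequences:

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

# Padded: a rectangular (2, 3) batch, with one slot wasted on a pad value.
padded = pad_sequence([a, b], batch_first=True, padding_value=0)

# Packed: a single (5,) row, plus the lengths needed to recover the boundaries.
packed = torch.cat([a, b])
lengths = torch.tensor([3, 2])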

torchrl/data/llm/history.py

Lines changed: 4 additions & 2 deletions
@@ -519,8 +519,10 @@ class History(TensorClass["nocast"]):
         :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
     """
 
-    role: str
-    content: str | ContentBase
+    role: str | list[str] | list[list[str]]
+    content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
+        list[ContentBase]
+    ]
     is_complete: bool = True
     tool_calls: list[dict] | None = None
     tool_responses: list[str] | None = None
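The widened annotations reflect that role and content may also hold (nested) lists when a History is built as a batch, as in the benchmark above; a minimal sketch of such a construction:

from torchrl.data.llm import History

# Two conversations of two turns each -> batch_size=(2, 2).
history = History(
    role=[["user", "assistant"], ["user", "assistant"]],
    content=[["Hi there", "Hello!"], ["How are you?", "Fine, thanks."]],
    batch_size=(2, 2),
)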

torchrl/modules/llm/policies/common.py

Lines changed: 2 additions & 0 deletions
@@ -362,6 +362,8 @@ class LLMWrapperBase(TensorDictModuleBase):
         generate_kwargs: Additional arguments to pass to the model's generate method.
         tokenizer_kwargs: Additional arguments to pass to the tokenizer.
         pad_output: Whether to pad the output sequences to a uniform length.
+        pad_model_input: Whether to pad the model input sequences to a uniform length.
+            May not be supported by all models.
         inplace: Determines how the module should handle in-place operations.
         device: The device to use for computation.
         layout: The layout to use for the output tensors when pad_output=False.
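As exercised in the tests above, the flag is passed at construction time; a minimal sketch (model and tokenizer stand in for a loaded Hugging Face model and its tokenizer):

wrapper = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    input_mode="history",
    generate=False,
    return_log_probs=True,
    pad_output=False,
    pad_model_input=False,  # pack the sequences instead of padding the model input
)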
