
Commit 710d6eb

amend
1 parent b63f5ef commit 710d6eb

File tree

4 files changed: +187 -45 lines changed


benchmarks/test_llm.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import importlib.util
+
+import pytest
+import torch
+from tensordict import set_list_to_stack, TensorDict
+from torchrl.data.llm import History
+from torchrl.modules.llm.policies.common import ChatHistory
+from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+
+_has_transformers = importlib.import_module("transformers") is not None
+
+
+@pytest.fixture(scope="module")
+def transformers_wrapper():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    with torch.device(device):
+        model = TransformersWrapper(
+            model="Qwen/Qwen2.5-0.5B",
+            tokenizer="Qwen/Qwen2.5-0.5B",
+            pad_model_input=False,
+            generate=False,
+        )
+    return model
+
+
+@pytest.mark.skipif(not _has_transformers, reason="transformers not installed")
+class TestWrappers:
+    @pytest.mark.parametrize("packing", [True, False])
+    @set_list_to_stack(True)
+    def test_packing(self, benchmark, transformers_wrapper, packing: bool):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        with torch.device(device):
+            transformers_wrapper = TransformersWrapper(
+                model=transformers_wrapper.model,
+                tokenizer=transformers_wrapper.tokenizer,
+                pad_model_input=not packing,
+                generate=False,
+                pad_output=False,
+            )
+        data = TensorDict(
+            {
+                "history": ChatHistory(
+                    full=History(
+                        role=[
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                            ["user", "assistant"],
+                        ],
+                        content=[
+                            [
+                                "Lorem ipsum dolor sit amet",
+                                "consectetur adipiscing elit",
+                            ],
+                            [
+                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
+                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
+                            ],
+                            [
+                                "Lorem ipsum dolor sit amet",
+                                "consectetur adipiscing elit",
+                            ],
+                            [
+                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
+                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
+                            ],
+                        ],
+                        batch_size=(4, 2),
+                        device=device,
+                    ),
+                    batch_size=(4,),
+                    device=device,
+                )
+            },
+            batch_size=(4,),
+            device=device,
+        ).to_lazystack()
+
+        def setup():
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        benchmark.pedantic(
+            transformers_wrapper,
+            (data,),
+            rounds=10,
+            warmup_rounds=3,
+            setup=setup,
+        )
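For orientation, `benchmark.pedantic` above comes from pytest-benchmark. A rough sketch of what that call measures (not pytest-benchmark's actual implementation, which also handles calibration, statistics, and reporting):

```python
import time


def pedantic_sketch(fn, args, *, rounds, warmup_rounds, setup):
    timings = []
    for i in range(warmup_rounds + rounds):
        setup()  # here: empty the CUDA cache before every round
        start = time.perf_counter()
        fn(*args)  # one forward pass through the wrapper (packed or padded)
        elapsed = time.perf_counter() - start
        if i >= warmup_rounds:  # warmup rounds are run but not recorded
            timings.append(elapsed)
    return timings
```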

test/llm/test_wrapper.py

Lines changed: 4 additions & 7 deletions
@@ -355,7 +355,6 @@ def test_history_input_mode(
         data = data[0]

         # Run wrapper
-        print(f"{data=}")
         result = wrapper(data)
         check_output_shapes(result, pad_output, requested_log_probs=not generate)

@@ -1841,7 +1840,7 @@ def test_transformers_custom_masking(


 @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
-@pytest.mark.parametrize("pad_output", [True, False])
+@pytest.mark.parametrize("pad_output", [False, True])
 class TestPacking:
     def test_packing_history(
         self, transformers_instance, sample_history_assistant, pad_output
@@ -1871,8 +1870,8 @@ def test_packing_history(
             {"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,)
         ).to_lazystack(0)

-        result_packed = wrapper_packed(td)
         result_padded = wrapped_padded(td)
+        result_packed = wrapper_packed(td)
         assert_close(result_packed["log_probs"], result_padded["log_probs"])

     def test_packing_text(self, transformers_instance, sample_text, pad_output):
@@ -1895,9 +1894,7 @@ def test_packing_text(self, transformers_instance, sample_text, pad_output):
             pad_output=pad_output,
             pad_model_input=True,
         )
-        td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,)).to_lazystack(
-            0
-        )
+        td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,))
         result_packed = wrapper_packed(td)
         result_padded = wrapped_padded(td)
         assert_close(result_packed["log_probs"], result_padded["log_probs"])
@@ -1931,8 +1928,8 @@ def test_packing_tokens(
             },
             batch_size=(2,),
         ).to_lazystack(0)
-        result_packed = wrapper_packed(td)
         result_padded = wrapped_padded(td)
+        result_packed = wrapper_packed(td)
         assert_close(result_packed["log_probs"], result_padded["log_probs"])


torchrl/data/llm/history.py

Lines changed: 4 additions & 2 deletions
@@ -519,8 +519,10 @@ class History(TensorClass["nocast"]):
         :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
     """

-    role: str
-    content: str | ContentBase
+    role: str | list[str] | list[list[str]]
+    content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
+        list[ContentBase]
+    ]
     is_complete: bool = True
     tool_calls: list[dict] | None = None
     tool_responses: list[str] | None = None
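For illustration, a minimal sketch of what the widened annotations permit, mirroring the new benchmark above (toy roles and contents; assumes list-to-stack mode, as the benchmark enables via the `set_list_to_stack(True)` decorator):

```python
from tensordict import set_list_to_stack

from torchrl.data.llm import History

# Nested lists build a (2, 2) batch of conversations directly, as the new
# benchmark does with a (4, 2) batch; previously the annotations only
# advertised single-string role/content.
with set_list_to_stack(True):
    history = History(
        role=[["user", "assistant"], ["user", "assistant"]],
        content=[["Hello", "Hi there!"], ["How are you?", "Fine, thanks."]],
        batch_size=(2, 2),
    )
```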

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 83 additions & 36 deletions
@@ -39,8 +39,18 @@
 class TransformersWrapper(LLMWrapperBase):
     """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.

-    This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input modalities
-    (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
+    Packing vs Padding:
+        - Packing (`pad_model_input=False`):
+            * More memory efficient for variable-length sequences.
+            * Not all models support packed input (requires custom attention masks and position ids).
+            * May be less compatible with some HuggingFace models or custom architectures.
+        - Padding (`pad_model_input=True`):
+            * Universally supported by all models.
+            * Wastes memory for short sequences in a batch.
+            * Simpler, but less efficient for highly variable-length data.
+        - If unsure, use padding for maximum compatibility. Use packing for large batches of variable-length data and when your model supports it.
+
+    Additional error handling is provided for empty and overlong sequences.

     Args:
         model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap.
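To make the trade-off concrete, a toy illustration (not the wrapper's internal representation) of what `pad_model_input` toggles for two prompts of lengths 3 and 2:

```python
import torch

seq_a = torch.tensor([11, 12, 13])
seq_b = torch.tensor([21, 22])

# Padding: one rectangular (batch, max_len) tensor plus an attention mask;
# the shorter row carries a wasted pad slot.
padded_ids = torch.tensor([[11, 12, 13],
                           [21, 22,  0]])
padding_mask = torch.tensor([[1, 1, 1],
                             [1, 1, 0]])

# Packing: one flat (1, total_len) tensor with no pad tokens; per-sequence
# lengths, a block-diagonal attention mask, and restarted position ids keep
# the two prompts from attending to each other.
packed_ids = torch.cat([seq_a, seq_b]).unsqueeze(0)  # shape (1, 5)
lengths = torch.tensor([3, 2])
```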
@@ -2038,9 +2048,26 @@ def _pack_sequences(
         )

     def _model_forward_with_padded_sequences(
-        self, tokens_full_padded, attention_mask_full_padded, pad_val, logits_only=False, **kwargs
+        self,
+        tokens_full_padded: torch.Tensor,
+        attention_mask_full_padded: torch.Tensor,
+        *,
+        pad_val: float | int | torch.Tensor | None = None,
+        logits_only: bool = False,
+        **kwargs,
     ):
         """Forward pass with padded sequences."""
+        # Error handling for empty sequences
+        if tokens_full_padded.numel() == 0:
+            raise ValueError(
+                "Input contains empty sequences. Packing/padding requires at least one token per sequence."
+            )
+        # Error handling for overlong sequences
+        max_len = getattr(self.model.config, "max_position_embeddings", None)
+        if max_len is not None and tokens_full_padded.shape[-1] > max_len:
+            raise ValueError(
+                f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
+            )
         tokens_out_struct = self.model(
             tokens_full_padded, attention_mask_full_padded, **kwargs
         )
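The two guards added above are easy to exercise in isolation; a standalone sketch with a hypothetical toy config (not the wrapper itself):

```python
import torch


class ToyConfig:
    max_position_embeddings = 8  # hypothetical limit for the illustration


def check_padded_input(tokens: torch.Tensor, config: ToyConfig) -> None:
    # Same logic as the guards above, outside the wrapper.
    if tokens.numel() == 0:
        raise ValueError("Input contains empty sequences.")
    max_len = getattr(config, "max_position_embeddings", None)
    if max_len is not None and tokens.shape[-1] > max_len:
        raise ValueError(
            f"Input sequence length ({tokens.shape[-1]}) exceeds "
            f"max_position_embeddings ({max_len})."
        )


check_padded_input(torch.ones(2, 5, dtype=torch.long), ToyConfig())  # passes
# check_padded_input(torch.ones(2, 16, dtype=torch.long), ToyConfig())  # raises
```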
@@ -2057,35 +2084,51 @@ def _model_forward_with_padded_sequences(
         return log_probs_full_padded, logits_full_padded

     def _model_forward_with_packed_sequences(
-        self, flat_input_ids, block_diag_attention_mask, pad: bool = True, logits_only=False, **kwargs
+        self,
+        flat_input_ids: torch.Tensor,
+        block_diag_attention_mask: torch.Tensor,
+        *,
+        pad: bool = True,
+        logits_only: bool = False,
+        **kwargs,
     ):
         """Pack sequences into a single tensor and forward them through the model.

         Args:
-            input_ids: NestedTensor of shape (batch_size, -1)
-            attention_mask: NestedTensor of shape (batch_size, -1)
+            flat_input_ids (NestedTensor): NestedTensor of shape (batch_size, -1)
+            block_diag_attention_mask (NestedTensor): NestedTensor of shape (batch_size, -1)

         Returns:
-            logits: NestedTensor of shape (batch_size, -1, vocab_size)
+            pad (bool): Whether to pad the output tensors.
+            logits_only (bool): Whether to return only logits.
+            kwargs (dict): Additional keyword arguments to pass to the model.

         """
+        # Error handling for empty sequences
+        if flat_input_ids.numel() == 0:
+            raise ValueError(
+                "Input contains empty sequences. Packing requires at least one token per sequence."
+            )
+        # Error handling for overlong sequences
+        # Note: Skipping this check for nested tensors due to symbolic representation issues
+        # The model will handle sequence length limits internally
+        max_len = getattr(self.model.config, "max_position_embeddings", None)
+        if max_len is not None and not hasattr(flat_input_ids, "size"):
+            # Only check for regular tensors, not nested tensors
+            actual_size = flat_input_ids.shape[-1]
+            if actual_size > max_len:
+                raise ValueError(
+                    f"Input sequence length ({actual_size}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
+                )
         (
             flat_input_ids,
             block_diag_attention_mask,
             packing_metadata,
         ) = self._pack_sequences(flat_input_ids, block_diag_attention_mask)
-        # check shapes: [B, L] for input_ids, [B, L, L] for attention_mask
-        if flat_input_ids.shape != block_diag_attention_mask.shape[:2]:
-            raise ValueError(
-                f"Input ids shape {flat_input_ids.shape=} does not match attention mask shape {block_diag_attention_mask.shape[:2]=}"
-            )
-        if flat_input_ids.shape[1] != block_diag_attention_mask.shape[2]:
-            raise ValueError(
-                f"Input ids shape {flat_input_ids.shape[1]=} does not match attention mask shape {block_diag_attention_mask.shape[2]=}"
-            )
+
         outputs = self.model(
             input_ids=flat_input_ids,
-            attention_mask=block_diag_attention_mask,
+            attention_mask=block_diag_attention_mask.unsqueeze(0),
             position_ids=packing_metadata["position_ids"],
             use_cache=False, # Disable KV cache for packing
             **kwargs,
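For reference, the kind of block-diagonal causal mask that packing relies on can be sketched on a toy batch (illustrative only; `repeat_interleave_causal` below builds the same `same_sequence` comparison):

```python
import torch

lengths = torch.tensor([3, 2])  # two packed sequences, total length 5
total = int(lengths.sum())
seq_ids = torch.arange(len(lengths)).repeat_interleave(lengths)  # [0, 0, 0, 1, 1]
same_sequence = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(0)     # (5, 5) block structure
causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
block_diag_causal = same_sequence & causal
# The forward pass above then adds a batch dimension,
# i.e. mask.unsqueeze(0) -> shape (1, 5, 5).
```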
@@ -2113,30 +2156,34 @@ def _unpack_outputs(
             logits_only=logits_only,
         )
         # check shapes: [1, L] for log_probs, [1, L, vocab_size] for logits
-        if log_probs.shape != logits.shape[:2]:
-            raise ValueError(
-                f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
-            )
-        if log_probs.ndim != 2:
-            raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
-        if logits.ndim != 3:
-            raise ValueError(f"Logits shape {logits.shape=} is not 3D")
-        sequence_lengths = packing_metadata["sequence_lengths"]
-        if log_probs.shape[1] != sequence_lengths.sum():
-            raise ValueError(
-                f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
+        if logits_only:
+            log_probs = None
+        else:
+            if log_probs.shape != logits.shape[:2]:
+                raise ValueError(
+                    f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
+                )
+            if log_probs.ndim != 2:
+                raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
+            if logits.ndim != 3:
+                raise ValueError(f"Logits shape {logits.shape=} is not 3D")
+            sequence_lengths = packing_metadata["sequence_lengths"]
+            if log_probs.shape[1] != sequence_lengths.sum():
+                raise ValueError(
+                    f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
+                )
+
+            log_probs = log_probs.squeeze(0)
+            nested_logprobs = torch.nested.nested_tensor_from_jagged(
+                log_probs,
+                lengths=sequence_lengths,
            )

         logits = logits.squeeze(0)
         nested_logits = torch.nested.nested_tensor_from_jagged(
             logits, # Remove batch dim: (total_length, vocab_size)
             lengths=sequence_lengths,
         )
-        log_probs = log_probs.squeeze(0)
-        nested_logprobs = torch.nested.nested_tensor_from_jagged(
-            log_probs,
-            lengths=sequence_lengths,
-        )

         if pad:
             return nested_logprobs.to_padded_tensor(
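The jagged round-trip used above can be seen on a toy tensor (illustrative values only):

```python
import torch

lengths = torch.tensor([3, 2])
# Flat per-token logits for a packed batch: (total_length, vocab_size).
logits_flat = torch.randn(int(lengths.sum()), 7)

# Split the flat tensor back into one ragged entry per sequence...
nested = torch.nested.nested_tensor_from_jagged(logits_flat, lengths=lengths)
# ...or into a right-padded dense tensor of shape (2, 3, 7).
padded = nested.to_padded_tensor(0.0)
```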
@@ -2173,7 +2220,7 @@ def repeat_interleave_causal(self, sequence_lengths: torch.Tensor) -> torch.Tens
         seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
         position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)

-        positions = torch.arange(total_length, device=sequence_lengths.device)
+        positions = torch.arange(int(total_length), device=sequence_lengths.device)

         same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(
             0
@@ -2193,7 +2240,7 @@ def _create_packed_position_ids(
         No cuda syncs.
         """
         if total_length is None:
-            total_length = sequence_lengths.sum()
+            total_length = int(sequence_lengths.sum().item())

         # Create global position IDs: [0, 1, 2, 3, 4]
         global_positions = torch.arange(total_length, device=sequence_lengths.device)