diff --git a/benchmarks/test_llm.py b/benchmarks/test_llm.py new file mode 100644 index 00000000000..030b3c45f90 --- /dev/null +++ b/benchmarks/test_llm.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib.util + +import pytest +import torch +from tensordict import set_list_to_stack, TensorDict +from torchrl.data.llm import History +from torchrl.modules.llm.policies.common import ChatHistory +from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper + +_has_transformers = importlib.import_module("transformers") is not None + + +@pytest.fixture(scope="module") +def transformers_wrapper(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with torch.device(device): + model = TransformersWrapper( + model="Qwen/Qwen2.5-0.5B", + tokenizer="Qwen/Qwen2.5-0.5B", + pad_model_input=False, + generate=False, + ) + return model + + +@pytest.mark.skipif(not _has_transformers, reason="transformers not installed") +class TestWrappers: + @pytest.mark.parametrize("packing", [True, False]) + @set_list_to_stack(True) + def test_packing(self, benchmark, transformers_wrapper, packing: bool): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with torch.device(device): + transformers_wrapper = TransformersWrapper( + model=transformers_wrapper.model, + tokenizer=transformers_wrapper.tokenizer, + pad_model_input=not packing, + generate=False, + pad_output=False, + ) + data = TensorDict( + { + "history": ChatHistory( + full=History( + role=[ + ["user", "assistant"], + ["user", "assistant"], + ["user", "assistant"], + ["user", "assistant"], + ], + content=[ + [ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + ], + [ + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua", + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat", + ], + [ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + ], + [ + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua", + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat", + ], + ], + batch_size=(4, 2), + device=device, + ), + batch_size=(4,), + device=device, + ) + }, + batch_size=(4,), + device=device, + ).to_lazystack() + + def setup(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + benchmark.pedantic( + transformers_wrapper, + (data,), + rounds=10, + warmup_rounds=3, + setup=setup, + ) diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index dfa81b0846c..1dcdec955a7 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -12,7 +12,7 @@ import pytest import torch -from tensordict import lazy_stack, set_list_to_stack, TensorDict +from tensordict import assert_close, lazy_stack, set_list_to_stack, TensorDict from tensordict.utils import _zip_strict from torchrl.data.llm import History @@ -163,6 +163,22 @@ def sample_tokens(vllm_instance): return tokenized["input_ids"], tokenized["attention_mask"] +@pytest.fixture +def sample_tokens_unpadded(vllm_instance): + """Create sample tokens for testing.""" + model, tokenizer = vllm_instance + text = [ + "Are you happy? Say yes or no.", + "Explain the difference between a cat and a dog. 
Be very detailed.", + ] + tokenized = tokenizer(text, padding=False) + return torch.nested.nested_tensor( + [torch.tensor(t) for t in tokenized["input_ids"]], layout=torch.jagged + ), torch.nested.nested_tensor( + [torch.tensor(t) for t in tokenized["attention_mask"]], layout=torch.jagged + ) + + def check_output_shapes(out, pad_output, requested_log_probs=False): if pad_output or not out.ndim: # We can get all tensors or they are none @@ -1656,8 +1672,6 @@ def test_log_probs_consistency( vllm_lp_result = vllm_lp_wrapper(new_data.copy()) tf_lp_result = tf_lp_wrapper(new_data.copy()) - from tensordict import assert_close - assert_close( vllm_lp_result, tf_lp_result, atol=1e-1, rtol=1e-1, intersection=True ) @@ -1825,6 +1839,100 @@ def test_transformers_custom_masking( assert hasattr(dist, "log_prob") +@pytest.mark.skipif(not _has_transformers, reason="transformers not available") +@pytest.mark.parametrize("pad_output", [False, True]) +class TestPacking: + def test_packing_history( + self, transformers_instance, sample_history_assistant, pad_output + ): + model, tokenizer = transformers_instance + + wrapper_packed = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="history", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=False, + ) + wrapped_padded = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="history", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=True, + ) + + td = TensorDict( + {"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,) + ).to_lazystack(0) + + result_padded = wrapped_padded(td) + result_packed = wrapper_packed(td) + assert_close(result_packed["log_probs"], result_padded["log_probs"]) + + def test_packing_text(self, transformers_instance, sample_text, pad_output): + model, tokenizer = transformers_instance + wrapper_packed = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="text", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=False, + ) + wrapped_padded = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="text", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=True, + ) + td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,)) + result_packed = wrapper_packed(td) + result_padded = wrapped_padded(td) + assert_close(result_packed["log_probs"], result_padded["log_probs"]) + + def test_packing_tokens( + self, transformers_instance, sample_tokens_unpadded, pad_output + ): + model, tokenizer = transformers_instance + wrapper_packed = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="tokens", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=False, + ) + wrapped_padded = TransformersWrapper( + model, + tokenizer=tokenizer, + input_mode="tokens", + generate=False, + return_log_probs=True, + pad_output=pad_output, + pad_model_input=True, + ) + td = TensorDict( + { + "tokens": Tokens(full=sample_tokens_unpadded[0]), + "masks": Masks(all_attention_mask=sample_tokens_unpadded[1]), + }, + batch_size=(2,), + ).to_lazystack(0) + result_padded = wrapped_padded(td) + result_packed = wrapper_packed(td) + assert_close(result_packed["log_probs"], result_padded["log_probs"]) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/llm/history.py 
b/torchrl/data/llm/history.py index dadbd75dea0..d91f627cd74 100644 --- a/torchrl/data/llm/history.py +++ b/torchrl/data/llm/history.py @@ -519,8 +519,10 @@ class History(TensorClass["nocast"]): :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data. """ - role: str - content: str | ContentBase + role: str | list[str] | list[list[str]] + content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[ + list[ContentBase] + ] is_complete: bool = True tool_calls: list[dict] | None = None tool_responses: list[str] | None = None diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py index 9ab755b66dd..63171ab23f0 100644 --- a/torchrl/modules/llm/policies/common.py +++ b/torchrl/modules/llm/policies/common.py @@ -362,6 +362,8 @@ class LLMWrapperBase(TensorDictModuleBase): generate_kwargs: Additional arguments to pass to the model's generate method. tokenizer_kwargs: Additional arguments to pass to the tokenizer. pad_output: Whether to pad the output sequences to a uniform length. + pad_model_input: Whether to pad the model input sequences to a uniform length. + May not be supported by all models. inplace: Determines how the module should handle in-place operations. device: The device to use for computation. layout: The layout to use for the output tensors when pad_output=False. diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py index ac0ebc782d6..55ef75b5f6a 100644 --- a/torchrl/modules/llm/policies/transformers_wrapper.py +++ b/torchrl/modules/llm/policies/transformers_wrapper.py @@ -8,7 +8,7 @@ from contextlib import nullcontext from copy import copy -from typing import Literal +from typing import Any, Literal import torch from tensordict import ( @@ -39,8 +39,18 @@ class TransformersWrapper(LLMWrapperBase): """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation. - This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input modalities - (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects. + Packing vs Padding: + - Packing (`pad_model_input=False`): + * More memory efficient for variable-length sequences. + * Not all models support packed input (requires custom attention masks and position ids). + * May be less compatible with some HuggingFace models or custom architectures. + - Padding (`pad_model_input=True`): + * Universally supported by all models. + * Wastes memory for short sequences in a batch. + * Simpler, but less efficient for highly variable-length data. + - If unsure, use padding for maximum compatibility. Use packing for large batches of variable-length data and when your model supports it. + + Additional error handling is provided for empty and overlong sequences. Args: model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap. @@ -65,8 +75,16 @@ class TransformersWrapper(LLMWrapperBase): return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`. generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`. tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`. 
- pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Transformers require `pad_output=True`, and the output - sequences will be padded and represented as tensors. Defaults to `False`. + pad_output (bool, optional): Whether to pad the output sequences to a uniform length. This does not impact the underlying padding + during call to the model. To use padding or packing during the model `forward` call, see `pad_model_input`. + Defaults to `False`. + pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length. + If `False`, packing will be used instead. Packing is generally more memory efficient than padding, + but this feature may not work with all models. + `pad_model_input` can only be used when `generate=False`. + This does not impact the padding of the model output - one may ask for padded output though `pad_output=True` while the model + is called with `pad_model_input=False`. + Defaults to `True`. inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`. device (torch.device | None, optional): The device to use for computation. Defaults to `None`. layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`. @@ -157,6 +175,7 @@ def __init__( generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, pad_output: bool = False, + pad_model_input: bool | None = None, inplace: Literal[True, False, "empty"] | None = None, device: torch.device | None = None, layout: torch.layout | None = None, @@ -192,6 +211,10 @@ def __init__( self.input_mode = input_mode self.attention_mask_key = attention_mask_key self.generate = generate + if pad_model_input is not None and generate: + raise ValueError("pad_model_input is not supported when generate=True.") + pad_model_input = pad_model_input if pad_model_input is not None else True + self.pad_model_input = pad_model_input # Auto-determine what to return based on input mode self.return_history = input_mode in ("history",) @@ -980,49 +1003,93 @@ def _logprobs_from_history_tokens( """Compute log-probs from history tokens.""" pad_val = self.tokenizer.pad_token_id - # unfortunately HF wants us to use padded tensors - tokens_full_padded = response_tokens.get( - "input_ids", - as_padded_tensor=True, - padding_side="left", - padding_value=pad_val, - ) - if not isinstance(tokens_full_padded, torch.Tensor): - raise ValueError( - f"Expected Tensor for tokens_full_padded, got {type(tokens_full_padded)}" - ) - attention_mask_full_padded = response_tokens.get( - "attention_mask", - as_padded_tensor=True, - padding_side="left", - padding_value=0, - ) - if not isinstance(attention_mask_full_padded, torch.Tensor): - raise ValueError( - f"Expected Tensor for attention_mask_full_padded, got {type(attention_mask_full_padded)}" - ) - if cfg is not None: kwargs = copy(self.generate_kwargs) kwargs["generation_config"] = cfg else: kwargs = self.generate_kwargs - tokens_out_struct = self.model( - tokens_full_padded, attention_mask=attention_mask_full_padded, **kwargs - ) - - ( - log_probs_full_padded, - logits_full_padded, - ) = self._compute_log_probs_from_model_output( - tokens_out_struct, - tokens_full_padded, - attention_mask_full_padded, - pad_val, - logits_only=logits_only, - ) + # non-packed forward pass + if self.pad_model_input: + # unfortunately HF wants us to use padded tensors + tokens_full_padded = response_tokens.get( + 
"input_ids", + as_padded_tensor=True, + padding_side="left", + padding_value=pad_val, + ) + if not isinstance(tokens_full_padded, torch.Tensor): + raise ValueError( + f"Expected Tensor for tokens_full_padded, got {type(tokens_full_padded)}" + ) + attention_mask_full_padded = response_tokens.get( + "attention_mask", + as_padded_tensor=True, + padding_side="left", + padding_value=0, + ) + if not isinstance(attention_mask_full_padded, torch.Tensor): + raise ValueError( + f"Expected Tensor for attention_mask_full_padded, got {type(attention_mask_full_padded)}" + ) + ( + log_probs_full_padded, + logits_full_padded, + ) = self._model_forward_with_padded_sequences( + tokens_full_padded, + attention_mask_full_padded, + pad_val=pad_val, + logits_only=logits_only, + **kwargs, + ) + else: + # unfortunately HF wants us to use padded tensors + tokens_full_unpadded = response_tokens.get( + "input_ids", + as_nested_tensor=True, + layout=torch.jagged, + ) + attention_mask_full_unpadded = response_tokens.get( + "attention_mask", + as_nested_tensor=True, + layout=torch.jagged, + ) + ( + log_probs_full_unpadded, + logits_full_unpadded, + ) = self._model_forward_with_packed_sequences( + # TODO: no padding if we don't need to + tokens_full_unpadded, + attention_mask_full_unpadded, + pad=False, + logits_only=logits_only, + **kwargs, + ) + tokens_full_padded = pad_sequence( + tokens_full_unpadded.unbind(0), + batch_first=True, + padding_value=pad_val, + padding_side="left", + ) + attention_mask_full_padded = pad_sequence( + attention_mask_full_unpadded.unbind(0), + batch_first=True, + padding_value=0, + padding_side="left", + ) + log_probs_full_padded = pad_sequence( + log_probs_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) + logits_full_padded = pad_sequence( + logits_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) # Build output TensorClass objects text_obj = Text._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) @@ -1168,34 +1235,74 @@ def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False): .to_lazystack(0) .update(dict(tokens_in)) ) - input_ids_full_padded = tokens_in.get( - "input_ids", - as_padded_tensor=True, - padding_side="left", - padding_value=self.padding_value, - ) - attention_mask_full_padded = tokens_in.get( - "attention_mask", - as_padded_tensor=True, - padding_side="left", - padding_value=0, - ) + pad_val = self.padding_value - tokens_out_struct = self.model( - input_ids_full_padded, attention_mask=attention_mask_full_padded, **kwargs - ) + if self.pad_model_input: + tokens_full_padded = tokens_in.get( + "input_ids", + as_padded_tensor=True, + padding_side="left", + padding_value=pad_val, + ) + attention_mask_full_padded = tokens_in.get( + "attention_mask", + as_padded_tensor=True, + padding_side="left", + padding_value=0, + ) - # Compute log-probs for the input tokens - ( - log_probs_full_padded, - logits_full_padded, - ) = self._compute_log_probs_from_model_output( - tokens_out_struct, - input_ids_full_padded, - attention_mask_full_padded, - self.tokenizer.pad_token_id, - logits_only=logits_only, - ) + ( + log_probs_full_padded, + logits_full_padded, + ) = self._model_forward_with_padded_sequences( + tokens_full_padded, + attention_mask_full_padded, + pad_val=pad_val, + logits_only=logits_only, + **kwargs, + ) + else: + # packed forward pass + tokens_full_unpadded = tokens_in.get( + "input_ids", + as_nested_tensor=True, + layout=torch.jagged, + ) + 
attention_mask_full_unpadded = tokens_in.get( + "attention_mask", + as_nested_tensor=True, + layout=torch.jagged, + ) + ( + log_probs_full_unpadded, + logits_full_unpadded, + ) = self._model_forward_with_packed_sequences( + tokens_full_unpadded, attention_mask_full_unpadded, pad=False, **kwargs + ) + tokens_full_padded = pad_sequence( + tokens_full_unpadded.unbind(0), + batch_first=True, + padding_value=pad_val, + padding_side="left", + ) + attention_mask_full_padded = pad_sequence( + attention_mask_full_unpadded.unbind(0), + batch_first=True, + padding_value=0, + padding_side="left", + ) + log_probs_full_padded = pad_sequence( + log_probs_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) + logits_full_padded = pad_sequence( + logits_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) # Build output TensorClass objects text_obj = Text._from_tensordict( @@ -1210,10 +1317,10 @@ def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False): TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: - tokens_obj.full = input_ids_full_padded + tokens_obj.full = tokens_full_padded else: input_ids_full_unpadded = _unpad_tensors( - input_ids_full_padded, attention_mask_full_padded, as_nested=False + tokens_full_padded, attention_mask_full_padded, as_nested=False ) tokens_obj.full = input_ids_full_unpadded tokens_obj.response = None @@ -1460,50 +1567,109 @@ def _from_transformers_logprobs_tokens( pad_val = self.tokenizer.pad_token_id - input_ids_full_padded = td.get( - self.input_key, - as_padded_tensor=True, - padding_side="left", - padding_value=pad_val, - ) - # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask - attention_mask_full_padded = td.get( - ("masks", "all_attention_mask"), - as_padded_tensor=True, - padding_side="left", - padding_value=False, - ) - if attention_mask_full_padded is None: + if cfg is not None: + kwargs = copy(self.generate_kwargs) + kwargs["generation_config"] = cfg + else: + kwargs = self.generate_kwargs + + if self.pad_model_input: + tokens_full_padded = td.get( + self.input_key, + as_padded_tensor=True, + padding_side="left", + padding_value=pad_val, + ) + # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask attention_mask_full_padded = td.get( - self.attention_mask_key, + ("masks", "all_attention_mask"), as_padded_tensor=True, padding_side="left", padding_value=False, ) if attention_mask_full_padded is None: - attention_mask_full_padded = input_ids_full_padded != pad_val + attention_mask_full_padded = td.get( + self.attention_mask_key, + as_padded_tensor=True, + padding_side="left", + padding_value=False, + ) + if attention_mask_full_padded is None: + attention_mask_full_padded = tokens_full_padded != pad_val - if cfg is not None: - kwargs = copy(self.generate_kwargs) - kwargs["generation_config"] = cfg + ( + log_probs_full_padded, + logits_full_padded, + ) = self._model_forward_with_padded_sequences( + tokens_full_padded, + attention_mask_full_padded, + pad_val=pad_val, + logits_only=logits_only, + **kwargs, + ) else: - kwargs = self.generate_kwargs - - tokens_out_struct = self.model( - input_ids_full_padded, attention_mask=attention_mask_full_padded, **kwargs - ) + # packed forward pass + # unfortunately HF wants us to use padded tensors + tokens_full_unpadded = td.get( + self.input_key, + 
as_nested_tensor=True, + layout=torch.jagged, + ) + if tokens_full_unpadded is None: + raise ValueError( + f"Expected '{self.input_key}' key for tokens input mode, but found keys: {list(td.keys())}" + ) + # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask + attention_mask_full_unpadded = td.get( + ("masks", "all_attention_mask"), + as_nested_tensor=True, + layout=torch.jagged, + ) + if attention_mask_full_unpadded is None: + attention_mask_full_unpadded = td.get( + self.attention_mask_key, + as_nested_tensor=True, + layout=torch.jagged, + ) + if attention_mask_full_unpadded is None: + # does this even work? + attention_mask_full_unpadded = tokens_full_unpadded != pad_val - # Compute log-probs for the input tokens - ( - log_probs_full_padded, - logits_full_padded, - ) = self._compute_log_probs_from_model_output( - tokens_out_struct, - input_ids_full_padded, - attention_mask_full_padded, - self.tokenizer.pad_token_id, - logits_only=logits_only, - ) + ( + log_probs_full_unpadded, + logits_full_unpadded, + ) = self._model_forward_with_packed_sequences( + # TODO: no padding if we don't need to + tokens_full_unpadded, + attention_mask_full_unpadded, + pad=False, + logits_only=logits_only, + **kwargs, + ) + tokens_full_padded = pad_sequence( + tokens_full_unpadded.unbind(0), + batch_first=True, + padding_value=pad_val, + padding_side="left", + ) + attention_mask_full_padded = pad_sequence( + attention_mask_full_unpadded.unbind(0), + batch_first=True, + padding_value=0, + padding_side="left", + ) + log_probs_full_padded = pad_sequence( + log_probs_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) + logits_full_padded = pad_sequence( + logits_full_unpadded.unbind(0), + batch_first=True, + padding_value=0.0, + padding_side="left", + ) # Build output TensorClass objects text_obj = Text._from_tensordict( @@ -1519,11 +1685,11 @@ def _from_transformers_logprobs_tokens( ) if not self.pad_output: input_ids_full_unpadded = _unpad_tensors( - input_ids_full_padded, attention_mask_full_padded, as_nested=False + tokens_full_padded, attention_mask_full_padded, as_nested=False ) tokens_obj.full = input_ids_full_unpadded else: - tokens_obj.full = input_ids_full_padded + tokens_obj.full = tokens_full_padded tokens_obj.response = None tokens_obj.padded = MetaData(self.pad_output) out.set(self.tokens_key, tokens_obj) @@ -1846,3 +2012,249 @@ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribut finally: self._in_get_dist_call = False self.out_keys.remove("logits") + + def _pack_sequences( + self, + input_ids: torch.nested.NestedTensor, + attention_mask: torch.nested.NestedTensor, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + """Pack sequences into a single tensor.""" + packed_input_ids = input_ids.values() + lengths = input_ids.lengths() + if lengths is None: + offsets = input_ids.offsets() + lengths = offsets.diff() + offsets = offsets[1:] + else: + offsets = lengths.cumsum(0) + # Create block-diagonal attention mask to prevent cross-sequence attention + attention_mask = self._create_block_diagonal_attention_mask(lengths) + # Create position IDs that restart for each sequence + position_ids = self._create_packed_position_ids( + lengths, total_length=packed_input_ids.numel() + ) + + packing_metadata = { + "sequence_lengths": lengths, + "cumulative_lengths": offsets, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + return ( + 
packed_input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + packing_metadata, + ) + + def _model_forward_with_padded_sequences( + self, + tokens_full_padded: torch.Tensor, + attention_mask_full_padded: torch.Tensor, + *, + pad_val: float | int | torch.Tensor | None = None, + logits_only: bool = False, + **kwargs, + ): + """Forward pass with padded sequences.""" + # Error handling for empty sequences + if tokens_full_padded.numel() == 0: + raise ValueError( + "Input contains empty sequences. Packing/padding requires at least one token per sequence." + ) + # Error handling for overlong sequences + max_len = getattr(self.model.config, "max_position_embeddings", None) + if max_len is not None and tokens_full_padded.shape[-1] > max_len: + raise ValueError( + f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input." + ) + tokens_out_struct = self.model( + tokens_full_padded, attention_mask=attention_mask_full_padded, **kwargs + ) + ( + log_probs_full_padded, + logits_full_padded, + ) = self._compute_log_probs_from_model_output( + tokens_out_struct, + tokens_full_padded, + attention_mask_full_padded, + pad_val, + logits_only=logits_only, + ) + return log_probs_full_padded, logits_full_padded + + def _model_forward_with_packed_sequences( + self, + flat_input_ids: torch.Tensor, + block_diag_attention_mask: torch.Tensor, + *, + pad: bool = True, + logits_only: bool = False, + **kwargs, + ): + """Pack sequences into a single tensor and forward them through the model. + + Args: + flat_input_ids (NestedTensor): NestedTensor of shape (batch_size, -1) + block_diag_attention_mask (NestedTensor): NestedTensor of shape (batch_size, -1) + pad (bool): Whether to pad the output tensors. + logits_only (bool): Whether to return only logits. + kwargs (dict): Additional keyword arguments to pass to the model. + + Returns: + A ``(log_probs, logits)`` tuple, padded if ``pad=True`` and kept as nested tensors otherwise. + """ + # Error handling for empty sequences + if flat_input_ids.numel() == 0: + raise ValueError( + "Input contains empty sequences. Packing requires at least one token per sequence." + ) + # Error handling for overlong sequences + # Note: Skipping this check for nested tensors due to symbolic representation issues + # The model will handle sequence length limits internally + max_len = getattr(self.model.config, "max_position_embeddings", None) + if max_len is not None and not flat_input_ids.is_nested: + # Only check for regular tensors, not nested tensors + actual_size = flat_input_ids.shape[-1] + if actual_size > max_len: + raise ValueError( + f"Input sequence length ({actual_size}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
+ ) + ( + flat_input_ids, + block_diag_attention_mask, + packing_metadata, + ) = self._pack_sequences(flat_input_ids, block_diag_attention_mask) + + outputs = self.model( + input_ids=flat_input_ids, + attention_mask=block_diag_attention_mask.unsqueeze(0), + position_ids=packing_metadata["position_ids"], + use_cache=False, # Disable KV cache for packing + **kwargs, + ) + log_probs, logits = self._unpack_outputs( + outputs, packing_metadata, flat_input_ids, pad=pad, logits_only=logits_only + ) + return log_probs, logits + + def _unpack_outputs( + self, + outputs, + packing_metadata: dict[str, Any], + flat_input_ids: torch.Tensor, + pad: bool = True, + logits_only: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Unpack outputs using nested tensors - zero syncs.""" + # use cross_entropy to compute log_probs + log_probs, logits = self._compute_log_probs_from_model_output( + outputs, + flat_input_ids, + torch.ones_like(flat_input_ids, dtype=torch.bool), + -100, + logits_only=logits_only, + ) + # check shapes: [1, L] for log_probs, [1, L, vocab_size] for logits + if logits_only: + log_probs = None + else: + if log_probs.shape != logits.shape[:2]: + raise ValueError( + f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}" + ) + if log_probs.ndim != 2: + raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D") + if logits.ndim != 3: + raise ValueError(f"Logits shape {logits.shape=} is not 3D") + sequence_lengths = packing_metadata["sequence_lengths"] + if log_probs.shape[1] != sequence_lengths.sum(): + raise ValueError( + f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}" + ) + + log_probs = log_probs.squeeze(0) + nested_logprobs = torch.nested.nested_tensor_from_jagged( + log_probs, + lengths=sequence_lengths, + ) + + logits = logits.squeeze(0) + nested_logits = torch.nested.nested_tensor_from_jagged( + logits, # Remove batch dim: (total_length, vocab_size) + lengths=sequence_lengths, + ) + + if pad: + return nested_logprobs.to_padded_tensor( + padding=0.0 + ), nested_logits.to_padded_tensor(padding=0.0) + return nested_logprobs, nested_logits + + def _create_block_diagonal_attention_mask( + self, sequence_lengths: torch.Tensor + ) -> torch.Tensor: + """Efficient creation of a block-diagonal attention mask. + + Zero cuda syncs, no integer involved except len(tensor) - compilable. + + Args: + sequence_lengths: Tensor of shape (batch_size,) containing the lengths of the sequences + + Returns: + attention_mask: Tensor of shape (batch_size, total_length, total_length) + where each sequence can only attend to itself. 
+ """ + seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device) + position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths) + + attention_mask = position_to_seq_id.unsqueeze( + 1 + ) == position_to_seq_id.unsqueeze(0) + return attention_mask + + def repeat_interleave_causal(self, sequence_lengths: torch.Tensor) -> torch.Tensor: + """Same as _create_block_diagonal_attention_mask, but with causal masking.""" + total_length = sequence_lengths.sum() + + seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device) + position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths) + + positions = torch.arange(int(total_length), device=sequence_lengths.device) + + same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze( + 0 + ) + causal = positions.unsqueeze(0) <= positions.unsqueeze(1) + + attention_mask = same_sequence & causal + return attention_mask + + def _create_packed_position_ids( + self, sequence_lengths: torch.Tensor, total_length: int | None = None + ) -> torch.Tensor: + """Create position IDs that restart from 0 for each sequence. + + For sequences of length [3, 2], creates: [0, 1, 2, 0, 1] + + No cuda syncs. + """ + if total_length is None: + total_length = int(sequence_lengths.sum().item()) + + # Create global position IDs: [0, 1, 2, 3, 4] + global_positions = torch.arange(total_length, device=sequence_lengths.device) + + # Create sequence start offsets repeated for each position: [0, 0, 0, 3, 3] + offsets = torch.cat( + [ + torch.zeros(1, device=sequence_lengths.device), + sequence_lengths.cumsum(0)[:-1], + ] + ) + sequence_starts = offsets.repeat_interleave(sequence_lengths) + + # Subtract to get local positions: [0, 1, 2, 0, 1] + position_ids = global_positions - sequence_starts + + return position_ids.unsqueeze(0) # (1, total_length) diff --git a/torchrl/modules/llm/policies/vllm_wrapper.py b/torchrl/modules/llm/policies/vllm_wrapper.py index 2df4386afe0..3dfe687512b 100644 --- a/torchrl/modules/llm/policies/vllm_wrapper.py +++ b/torchrl/modules/llm/policies/vllm_wrapper.py @@ -76,6 +76,8 @@ class vLLMWrapper(LLMWrapperBase): generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`. tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`. pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Defaults to `False`. + pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length. + This is not supported by vLLM. inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`. device (torch.device | None, optional): The device to use for computation. Defaults to `None`. layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`. 
@@ -167,6 +169,7 @@ def __init__( generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, pad_output: bool = False, + pad_model_input: bool | None = None, inplace: Literal[True, False, "empty"] | None = None, device: torch.device | None = None, layout: torch.layout | None = None, @@ -208,6 +211,8 @@ def __init__( self.input_mode = input_mode self.attention_mask_key = attention_mask_key self.generate = generate + if pad_model_input is not None: + raise ValueError("pad_model_input is not supported by vLLMWrapper; vLLM handles input batching internally.") # Auto-determine what to return based on input mode self.return_history = input_mode in ("history",)
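A minimal usage sketch of the new `pad_model_input` switch, mirroring `TestPacking.test_packing_text` above. The checkpoint name and prompts are taken from the tests in this diff; the import paths follow the benchmark file, and `Text` is assumed to live next to `ChatHistory` in `policies.common` — adjust if it is exported elsewhere.

from tensordict import TensorDict, assert_close, set_list_to_stack
from torchrl.modules.llm.policies.common import Text
from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper

with set_list_to_stack(True):
    # One wrapper packs the forward pass, the other uses a left-padded rectangular batch.
    packed = TransformersWrapper(
        model="Qwen/Qwen2.5-0.5B",
        tokenizer="Qwen/Qwen2.5-0.5B",
        input_mode="text",
        generate=False,
        return_log_probs=True,
        pad_output=False,
        pad_model_input=False,  # pack variable-length sequences into a single row
    )
    padded = TransformersWrapper(
        model=packed.model,
        tokenizer=packed.tokenizer,
        input_mode="text",
        generate=False,
        return_log_probs=True,
        pad_output=False,
        pad_model_input=True,  # classic padding, universally supported
    )
    td = TensorDict(
        {"text": Text(full=["Are you happy? Say yes or no.", "Be very detailed."])},
        batch_size=(2,),
    )
    # Both code paths should agree on the per-token log-probabilities.
    assert_close(packed(td)["log_probs"], padded(td)["log_probs"])

Packing pays off when sequence lengths vary widely within a batch; with roughly uniform lengths the padded path costs about the same and works with every model.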
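The packed forward pass hinges on a block-diagonal attention mask so that tokens from different sequences never attend to each other. A self-contained sketch of the construction used by `_create_block_diagonal_attention_mask`, for two packed sequences of lengths 3 and 2:

import torch

sequence_lengths = torch.tensor([3, 2])
seq_ids = torch.arange(len(sequence_lengths))
# Which sequence owns each packed position: [0, 0, 0, 1, 1]
position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)
# True exactly where two positions belong to the same sequence.
mask = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(0)
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]])

`repeat_interleave_causal` additionally ANDs this block structure with a lower-triangular (causal) mask over the packed positions.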
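Position ids must restart at zero for every packed sequence; otherwise the second sequence would see its positions shifted by the length of the first. A sketch of the arithmetic in `_create_packed_position_ids` for the same lengths [3, 2] (the offsets are kept integer-typed here):

import torch

sequence_lengths = torch.tensor([3, 2])
total_length = int(sequence_lengths.sum())
global_positions = torch.arange(total_length)  # [0, 1, 2, 3, 4]
# Start offset of the sequence owning each packed position: [0, 0, 0, 3, 3]
offsets = torch.cat([torch.zeros(1, dtype=torch.long), sequence_lengths.cumsum(0)[:-1]])
sequence_starts = offsets.repeat_interleave(sequence_lengths)
position_ids = global_positions - sequence_starts  # [0, 1, 2, 0, 1]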
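The pack/unpack plumbing in `_pack_sequences` and `_unpack_outputs` is built on jagged nested tensors: `.values()` exposes the flat packed buffer, the offsets give per-sequence lengths, and `nested_tensor_from_jagged` re-splits per-token outputs. A rough round-trip sketch, assuming a PyTorch build with jagged nested-tensor support (already required by the code above); the random tensor stands in for packed log-probs:

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)

flat = nt.values()             # packed 1D buffer: tensor([1, 2, 3, 4, 5])
lengths = nt.offsets().diff()  # per-sequence lengths: tensor([3, 2])

# ... the model runs once on flat.unsqueeze(0) with the block-diagonal mask ...

per_token_out = torch.randn(flat.numel())  # stand-in for packed per-token log-probs
repacked = torch.nested.nested_tensor_from_jagged(per_token_out, lengths=lengths)
padded = repacked.to_padded_tensor(padding=0.0)  # shape (2, 3); the short row is zero-padded

The wrapper itself calls the packed forward with `pad=False` and re-pads the unpacked results on the left, as in the `_from_transformers_logprobs_*` branches above.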