class TransformersWrapper(LLMWrapperBase):
    """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.

-    This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input modalities
-    (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
+    Packing vs Padding:
+    - Packing (`pad_model_input=False`):
+        * More memory efficient for variable-length sequences.
+        * Not all models support packed input (requires custom attention masks and position ids).
+        * May be less compatible with some Hugging Face models or custom architectures.
+    - Padding (`pad_model_input=True`):
+        * Universally supported by all models.
+        * Wastes memory for short sequences in a batch.
+        * Simpler, but less efficient for highly variable-length data.
+    - If unsure, use padding for maximum compatibility; use packing for large batches of variable-length data when your model supports it.
+
+    Additional error handling is provided for empty and overlong sequences.

    Args:
        model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap.
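As a rough usage sketch of the two modes described above: the snippet below assumes `TransformersWrapper` is importable from `torchrl.modules.llm` and accepts a `tokenizer` keyword alongside `pad_model_input`; both assumptions go beyond what this excerpt shows.

# Minimal sketch: constructor arguments other than `model` and `pad_model_input`
# are assumptions, not taken from this excerpt.
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchrl.modules.llm import TransformersWrapper  # import path assumed

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Padding: universally supported, but every sequence is padded to the batch max.
wrapper_padded = TransformersWrapper(model, tokenizer=tokenizer, pad_model_input=True)

# Packing: more memory efficient for variable-length batches, provided the model
# accepts custom attention masks and position ids.
wrapper_packed = TransformersWrapper(model, tokenizer=tokenizer, pad_model_input=False)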
@@ -2038,9 +2048,26 @@ def _pack_sequences(
        )

    def _model_forward_with_padded_sequences(
-        self, tokens_full_padded, attention_mask_full_padded, pad_val, logits_only=False, **kwargs
+        self,
+        tokens_full_padded: torch.Tensor,
+        attention_mask_full_padded: torch.Tensor,
+        *,
+        pad_val: float | int | torch.Tensor | None = None,
+        logits_only: bool = False,
+        **kwargs,
    ):
        """Forward pass with padded sequences."""
+        # Error handling for empty sequences
+        if tokens_full_padded.numel() == 0:
+            raise ValueError(
+                "Input contains empty sequences. Packing/padding requires at least one token per sequence."
+            )
+        # Error handling for overlong sequences
+        max_len = getattr(self.model.config, "max_position_embeddings", None)
+        if max_len is not None and tokens_full_padded.shape[-1] > max_len:
+            raise ValueError(
+                f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
+            )
        tokens_out_struct = self.model(
            tokens_full_padded, attention_mask_full_padded, **kwargs
        )
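For reference, a self-contained sketch of the padded path that this method implements: right-pad variable-length sequences, build an attention mask, apply the same empty/overlong guards, and call a Hugging Face causal LM. The model choice and pad value here are illustrative, not the wrapper's own defaults.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pad_val = tokenizer.eos_token_id  # gpt2 has no pad token; reuse EOS for this toy example

# Two variable-length sequences, right-padded to the batch max length.
seqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]
tokens_full_padded = torch.nn.utils.rnn.pad_sequence(
    seqs, batch_first=True, padding_value=pad_val
)
attention_mask_full_padded = (tokens_full_padded != pad_val).long()

# Same guards as above: reject empty input and overlong sequences.
assert tokens_full_padded.numel() > 0
max_len = getattr(model.config, "max_position_embeddings", None)
assert max_len is None or tokens_full_padded.shape[-1] <= max_len

with torch.no_grad():
    out = model(tokens_full_padded, attention_mask=attention_mask_full_padded)
logits_full_padded = out.logits  # (batch_size, padded_len, vocab_size)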
@@ -2057,35 +2084,51 @@ def _model_forward_with_padded_sequences(
        return log_probs_full_padded, logits_full_padded

    def _model_forward_with_packed_sequences(
-        self, flat_input_ids, block_diag_attention_mask, pad: bool = True, logits_only=False, **kwargs
+        self,
+        flat_input_ids: torch.Tensor,
+        block_diag_attention_mask: torch.Tensor,
+        *,
+        pad: bool = True,
+        logits_only: bool = False,
+        **kwargs,
    ):
        """Pack sequences into a single tensor and forward them through the model.

        Args:
-            input_ids: NestedTensor of shape (batch_size, -1)
-            attention_mask: NestedTensor of shape (batch_size, -1)
+            flat_input_ids (NestedTensor): NestedTensor of shape (batch_size, -1)
+            block_diag_attention_mask (NestedTensor): NestedTensor of shape (batch_size, -1)
+            pad (bool): Whether to pad the output tensors.
+            logits_only (bool): Whether to return only logits.
+            kwargs (dict): Additional keyword arguments to pass to the model.

        Returns:
-            logits: NestedTensor of shape (batch_size, -1, vocab_size)

        """
+        # Error handling for empty sequences
+        if flat_input_ids.numel() == 0:
+            raise ValueError(
+                "Input contains empty sequences. Packing requires at least one token per sequence."
+            )
+        # Error handling for overlong sequences.
+        # Note: the check is skipped for nested tensors due to their symbolic shape
+        # representation; the model handles sequence length limits internally in that case.
+        max_len = getattr(self.model.config, "max_position_embeddings", None)
+        if max_len is not None and not flat_input_ids.is_nested:
+            # Only check regular tensors, not nested tensors
+            actual_size = flat_input_ids.shape[-1]
+            if actual_size > max_len:
+                raise ValueError(
+                    f"Input sequence length ({actual_size}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
+                )
        (
            flat_input_ids,
            block_diag_attention_mask,
            packing_metadata,
        ) = self._pack_sequences(flat_input_ids, block_diag_attention_mask)
-        # check shapes: [B, L] for input_ids, [B, L, L] for attention_mask
-        if flat_input_ids.shape != block_diag_attention_mask.shape[:2]:
-            raise ValueError(
-                f"Input ids shape {flat_input_ids.shape=} does not match attention mask shape {block_diag_attention_mask.shape[:2]=}"
-            )
-        if flat_input_ids.shape[1] != block_diag_attention_mask.shape[2]:
-            raise ValueError(
-                f"Input ids shape {flat_input_ids.shape[1]=} does not match attention mask shape {block_diag_attention_mask.shape[2]=}"
-            )
+
        outputs = self.model(
            input_ids=flat_input_ids,
-            attention_mask=block_diag_attention_mask,
+            attention_mask=block_diag_attention_mask.unsqueeze(0),
            position_ids=packing_metadata["position_ids"],
            use_cache=False,  # Disable KV cache for packing
            **kwargs,
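A small standalone sketch of what the packed call above receives: the sequences concatenated into a single row, plus a block-diagonal causal mask that gains a batch dimension via `unsqueeze(0)`. Here the mask is built with `torch.block_diag` purely for illustration; the wrapper builds it through `_pack_sequences` and `repeat_interleave_causal`.

import torch

# Two sequences of lengths 3 and 2 packed back to back into one row of length 5.
seqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]
lengths = [s.numel() for s in seqs]
flat_input_ids = torch.cat(seqs).unsqueeze(0)            # shape (1, 5)

# Block-diagonal causal mask: each sequence gets its own lower-triangular block,
# so tokens never attend across sequence boundaries.
blocks = [torch.tril(torch.ones(n, n, dtype=torch.long)) for n in lengths]
block_diag_attention_mask = torch.block_diag(*blocks)    # shape (5, 5)

# The forward call above adds the batch dimension before handing it to the model.
attention_mask = block_diag_attention_mask.unsqueeze(0)  # shape (1, 5, 5)
print(flat_input_ids.shape, attention_mask.shape)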
@@ -2113,30 +2156,34 @@ def _unpack_outputs(
            logits_only=logits_only,
        )
        # check shapes: [1, L] for log_probs, [1, L, vocab_size] for logits
-        if log_probs.shape != logits.shape[:2]:
-            raise ValueError(
-                f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
-            )
-        if log_probs.ndim != 2:
-            raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
-        if logits.ndim != 3:
-            raise ValueError(f"Logits shape {logits.shape=} is not 3D")
-        sequence_lengths = packing_metadata["sequence_lengths"]
-        if log_probs.shape[1] != sequence_lengths.sum():
-            raise ValueError(
-                f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
+        sequence_lengths = packing_metadata["sequence_lengths"]
+        if logits_only:
+            log_probs = None
+        else:
+            if log_probs.shape != logits.shape[:2]:
+                raise ValueError(
+                    f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
+                )
+            if log_probs.ndim != 2:
+                raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
+            if logits.ndim != 3:
+                raise ValueError(f"Logits shape {logits.shape=} is not 3D")
+            if log_probs.shape[1] != sequence_lengths.sum():
+                raise ValueError(
+                    f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
+                )
+
+            log_probs = log_probs.squeeze(0)
+            nested_logprobs = torch.nested.nested_tensor_from_jagged(
+                log_probs,
+                lengths=sequence_lengths,
            )

        logits = logits.squeeze(0)
        nested_logits = torch.nested.nested_tensor_from_jagged(
            logits,  # Remove batch dim: (total_length, vocab_size)
            lengths=sequence_lengths,
        )
-        log_probs = log_probs.squeeze(0)
-        nested_logprobs = torch.nested.nested_tensor_from_jagged(
-            log_probs,
-            lengths=sequence_lengths,
-        )

        if pad:
            return nested_logprobs.to_padded_tensor(
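To make the un-packing step above concrete, a minimal sketch of the `nested_tensor_from_jagged` / `to_padded_tensor` round trip on stand-in values (random tensors in place of the model's packed log-probs and logits); it assumes a PyTorch version that ships the jagged nested-tensor API used above.

import torch

sequence_lengths = torch.tensor([2, 3])
total_length = int(sequence_lengths.sum())
vocab_size = 7

# Stand-ins for the packed outputs after the batch dim has been squeezed away:
# per-token log-probs of shape (total_length,), logits of shape (total_length, vocab_size).
log_probs = torch.randn(total_length)
logits = torch.randn(total_length, vocab_size)

# Re-split the flat dimension into one ragged component per sequence.
nested_logprobs = torch.nested.nested_tensor_from_jagged(log_probs, lengths=sequence_lengths)
nested_logits = torch.nested.nested_tensor_from_jagged(logits, lengths=sequence_lengths)

# Optionally pad back to rectangular tensors, as the `pad=True` branch does.
padded_logprobs = nested_logprobs.to_padded_tensor(0.0)  # (2, 3)
padded_logits = nested_logits.to_padded_tensor(0.0)      # (2, 3, vocab_size)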
@@ -2173,7 +2220,7 @@ def repeat_interleave_causal(self, sequence_lengths: torch.Tensor) -> torch.Tensor:
        seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
        position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)

-        positions = torch.arange(total_length, device=sequence_lengths.device)
+        positions = torch.arange(int(total_length), device=sequence_lengths.device)

        same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(
            0
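A worked example of the mask logic in `repeat_interleave_causal` for `sequence_lengths = [2, 3]`. The final combination of the same-sequence test with a causal (lower-triangular) constraint is not visible in this hunk, so the last two lines are an assumption about how the function completes.

import torch

sequence_lengths = torch.tensor([2, 3])
total_length = int(sequence_lengths.sum())

# Which packed position belongs to which sequence: [0, 0, 1, 1, 1]
seq_ids = torch.arange(len(sequence_lengths))
position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)

positions = torch.arange(total_length)
same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(0)

# Assumed completion: restrict attention to earlier positions of the same sequence.
causal = positions.unsqueeze(1) >= positions.unsqueeze(0)
mask = same_sequence & causal
# mask.int():
# [[1, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0],
#  [0, 0, 1, 0, 0],
#  [0, 0, 1, 1, 0],
#  [0, 0, 1, 1, 1]]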
@@ -2193,7 +2240,7 @@ def _create_packed_position_ids(
        No cuda syncs.
        """
        if total_length is None:
-            total_length = sequence_lengths.sum()
+            total_length = int(sequence_lengths.sum().item())

        # Create global position IDs: [0, 1, 2, 3, 4]
        global_positions = torch.arange(total_length, device=sequence_lengths.device)
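And a worked example of the packed position ids for `sequence_lengths = [2, 3]`: global positions [0, 1, 2, 3, 4] are shifted back to per-sequence positions [0, 1, 0, 1, 2] by subtracting each sequence's start offset. The subtraction step lies beyond this hunk, so it is sketched here as an assumption about how the helper likely proceeds.

import torch

sequence_lengths = torch.tensor([2, 3])
total_length = int(sequence_lengths.sum().item())

# Create global position IDs over the packed row: [0, 1, 2, 3, 4]
global_positions = torch.arange(total_length, device=sequence_lengths.device)

# Assumed completion: start offset of each sequence, repeated per token: [0, 0, 2, 2, 2]
starts = torch.cat([torch.zeros(1, dtype=torch.long), sequence_lengths.cumsum(0)[:-1]])
start_per_token = starts.repeat_interleave(sequence_lengths)

# Per-sequence position ids: [0, 1, 0, 1, 2]
position_ids = global_positions - start_per_token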