Commit db26565

Make sure to include padding mask in generation (#2096)
1 parent feef6b8 · commit db26565
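Before this change, `GenerationUtils.generate` ran the encoder with no padding information, so pad tokens in a batched input were attended to as if they were real tokens, and batched generations could diverge from single-example generations. The fix builds a boolean padding mask from the inputs and threads it through to the model. A minimal sketch of the mask convention used in the diffs below (`True` marks positions to ignore, as with `src_key_padding_mask` in `torch.nn`; the pad index here is a made-up example):

```python
import torch

pad_idx = 0  # made-up pad index for illustration; T5's real pad id comes from its vocab

# A batch of two token sequences, the first padded out to the batch length.
inputs = torch.tensor([
    [11, 12, 13, pad_idx, pad_idx],
    [21, 22, 23, 24, 25],
])

# True marks padding positions, following the src_key_padding_mask convention.
src_key_padding_mask = inputs.eq(pad_idx)
print(src_key_padding_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```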

File tree: 3 files changed, +33 −24 lines

- test/integration_tests/test_generate.py (+20, −14)
- torchtext/models/t5/model.py (+2, −0)
- torchtext/prototype/generate.py (+11, −10)

test/integration_tests/test_generate.py

Lines changed: 20 additions & 14 deletions
```diff
@@ -14,41 +14,47 @@ def setUp(self) -> None:
         self.model = t5_base.get_model()
         self.model.eval()
         # Examples taken from T5 Paper and Huggingface
-        self.inputs = self.transform(
-            [
-                "summarize: studies have shown that owning a dog is good for you",
-                "translate English to German: That is good.",
-                "cola sentence: The course is jumping well.",
-                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
-                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
-            ]
-        )
+        self.inputs = [
+            "summarize: studies have shown that owning a dog is good for you",
+            "translate English to German: That is good.",
+            "cola sentence: The course is jumping well.",
+            "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
+            "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
+        ]
+        self.transformed_inputs = self.transform(self.inputs)
         torch.manual_seed(0)
 
     def test_greedy_generate_with_t5(self) -> None:
         generation_model = GenerationUtils(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
+        tokens = generation_model.generate(self.transformed_inputs, num_beams=1, max_length=30)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
-            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
-            "Das ist gut.",
+            "owning a dog is good for you, according to studies . a dog is a good companion for a variety of reasons",
+            "Das ist gut so.",
             "acceptable",
             "4.0",
             "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
         ]
 
         self.assertEqual(generated_text, expected_generated_text)
 
+        inputs = self.transform([self.inputs[0]])
+
+        tokens_for_single_example = generation_model.generate(inputs, num_beams=1, max_length=30)
+        generated_text_for_single_example = self.transform.decode(tokens_for_single_example.tolist())
+
+        self.assertEqual(generated_text[0], generated_text_for_single_example[-1])
+
     def test_generate_errors_with_incorrect_beams(self) -> None:
         generation_model = GenerationUtils(self.model, is_encoder_decoder=True)
 
         with self.assertRaises(ValueError):
-            generation_model.generate(self.inputs, num_beams=0)
+            generation_model.generate(self.transformed_inputs, num_beams=0)
 
     @patch("logging.Logger.warning")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtils(self.model)
-        generation_model.generate(self.inputs)
+        generation_model.generate(self.transformed_inputs)
         mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
```
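The new assertions capture the invariant this commit restores: generating from a padded batch must produce the same text for an example as generating that example alone. A self-contained illustration of why the mask is what makes that hold, using a plain `torch.nn.TransformerEncoderLayer` rather than torchtext's T5 (toy values throughout):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD = 0

embed = nn.Embedding(100, 16, padding_idx=PAD)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
layer.eval()

single = torch.tensor([[5, 7, 9]])            # example on its own
padded = torch.tensor([[5, 7, 9, PAD, PAD]])  # same example padded in a batch

with torch.no_grad():
    out_single = layer(embed(single))
    out_unmasked = layer(embed(padded))       # pad tokens attended like real ones
    out_masked = layer(embed(padded), src_key_padding_mask=padded.eq(PAD))

# Without the mask the shared positions diverge; with it they match.
print(torch.allclose(out_single, out_unmasked[:, :3], atol=1e-5))  # False
print(torch.allclose(out_single, out_masked[:, :3], atol=1e-5))    # True
```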

torchtext/models/t5/model.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -198,6 +198,7 @@ def prepare_inputs_for_generation(
         self,
         input_ids: Tensor,
         encoder_outputs: ENCODER_OUTPUTS_TYPE,
+        encoder_padding_mask: Optional[Tensor] = None,
         past: Optional[List[PAST_KEY_VALUES_TYPE]] = None,
         return_past_key_values: bool = True,
     ) -> Dict[str, Union[Tensor, ENCODER_OUTPUTS_TYPE, Optional[List[PAST_KEY_VALUES_TYPE]], bool]]:
@@ -209,6 +210,7 @@ def prepare_inputs_for_generation(
             "decoder_tokens": input_ids,
             "encoder_outputs": encoder_outputs,
             "past_key_values": past,
+            "encoder_padding_mask": encoder_padding_mask,
             "return_past_key_values": return_past_key_values,
         }
 
```
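The model-side change is pure plumbing: the mask becomes one more entry in the kwargs dict that `prepare_inputs_for_generation` hands back to the decoding loop. A hypothetical stand-in class (not torchtext's T5) sketching that contract:

```python
from typing import Any, Dict, Optional

import torch
from torch import Tensor

class ToySeq2Seq:
    """Illustrative stand-in mirroring the dict contract in the diff above."""

    def prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        encoder_outputs: Tensor,
        encoder_padding_mask: Optional[Tensor] = None,
        past: Optional[Any] = None,
        return_past_key_values: bool = True,
    ) -> Dict[str, Any]:
        return {
            "decoder_tokens": input_ids,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "encoder_padding_mask": encoder_padding_mask,  # the newly threaded field
            "return_past_key_values": return_past_key_values,
        }

model_inputs = ToySeq2Seq().prepare_inputs_for_generation(
    torch.tensor([[0]]),  # decoder start token
    encoder_outputs=torch.zeros(1, 5, 16),
    encoder_padding_mask=torch.zeros(1, 5, dtype=torch.bool),
)
print(sorted(model_inputs))  # the mask now travels with the encoder outputs
```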

torchtext/prototype/generate.py

Lines changed: 11 additions & 10 deletions
```diff
@@ -48,7 +48,7 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -62,10 +62,11 @@ def greedy_search(
         Returns:
             Batch of sequences decoded by greedy search.
         """
-        unfinished_sequences = torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long)
+        unfinished_sequences = torch.ones((input_ids.shape[0]), device=input_ids.device, dtype=torch.long)
 
         while True:
             model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
             if self.is_huggingface_model:
                 model_inputs["return_dict"] = True
                 model_inputs["output_hidden_states"] = True
@@ -77,18 +78,16 @@ def greedy_search(
 
             # Calculate probabilities and take the most likely next token
             probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-            _, next_tokens = torch.topk(probs, 1)
+            next_tokens = torch.argmax(probs, dim=-1)
 
             # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
 
             # Append the next tokens to the previous tokens
-            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
 
-            if eos_idx is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
+            # Update unfinished sequences count
+            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx)).long()
 
             # Stop iterating once all sequences are finished or exceed the max_length
             if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
@@ -128,8 +127,10 @@ def generate(
 
         if self.is_encoder_decoder:
             encoder = self.model.get_encoder()
-            model_kwargs["encoder_outputs"] = encoder(inputs)
+            encoder_model_kwargs = {"src_key_padding_mask": inputs.eq(pad_idx)}
+            model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_model_kwargs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
+            model_kwargs["encoder_padding_mask"] = encoder_model_kwargs.pop("src_key_padding_mask")
 
         if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
```
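The `greedy_search` changes are companion cleanups: `torch.argmax` replaces `topk` (which kept a trailing size-1 dimension), so `next_tokens` and `unfinished_sequences` are both 1-D and the new token column is re-added explicitly with `next_tokens[:, None]`; `pad_idx` becomes required because finished rows must keep emitting padding. A small sketch of that bookkeeping with made-up token ids:

```python
import torch

EOS, PAD = 1, 0  # made-up special-token ids
unfinished = torch.ones(3, dtype=torch.long)  # all three sequences still running
next_tokens = torch.tensor([5, EOS, 7])       # sequence 1 emits EOS this step

# Finished sequences keep emitting PAD on later steps.
next_tokens = next_tokens * unfinished + PAD * (1 - unfinished)
unfinished = unfinished.mul(next_tokens != EOS).long()

input_ids = torch.tensor([[2, 3], [2, 4], [2, 6]])
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)  # append as a column

print(input_ids)   # tensor([[2, 3, 5], [2, 4, 1], [2, 6, 7]])
print(unfinished)  # tensor([1, 0, 1]); row 1 will get PAD from the next step on
```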
