@@ -48,7 +48,7 @@ def _prepare_decoder_ids_for_generation(
48
48
return torch .ones ((batch_size , 1 ), dtype = torch .long , device = device ) * pad_idx
49
49
50
50
def greedy_search (
51
- self , input_ids : torch .Tensor , max_length : int , eos_idx : int , pad_idx : Optional [ int ] = None , ** model_kwargs
51
+ self , input_ids : torch .Tensor , max_length : int , eos_idx : int , pad_idx : int , ** model_kwargs
52
52
) -> torch .Tensor :
53
53
"""Greedy search decoding for text generation. Takes the most likely next token every time.
54
54
@@ -62,10 +62,11 @@ def greedy_search(
62
62
Returns:
63
63
Batch of sequences decoded by greedy search.
64
64
"""
65
- unfinished_sequences = torch .ones ((input_ids .shape [0 ], 1 ), device = input_ids .device , dtype = torch .long )
65
+ unfinished_sequences = torch .ones ((input_ids .shape [0 ]), device = input_ids .device , dtype = torch .long )
66
66
67
67
while True :
68
68
model_inputs = self .model .prepare_inputs_for_generation (input_ids , ** model_kwargs )
69
+
69
70
if self .is_huggingface_model :
70
71
model_inputs ["return_dict" ] = True
71
72
model_inputs ["output_hidden_states" ] = True
@@ -77,18 +78,16 @@ def greedy_search(
77
78
78
79
# Calculate probabilities and take the most likely next token
79
80
probs = F .log_softmax (decoder_output [:, - 1 ], dim = - 1 )
80
- _ , next_tokens = torch .topk (probs , 1 )
81
+ next_tokens = torch .argmax (probs , dim = - 1 )
81
82
82
83
# For any finished sequences, padding idx should be the last token
83
- if eos_idx is not None :
84
- if pad_idx is not None :
85
- next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences )
84
+ next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences )
86
85
87
86
# Append the next tokens to the previous tokens
88
- input_ids = torch .cat ([input_ids , next_tokens ], dim = - 1 )
87
+ input_ids = torch .cat ([input_ids , next_tokens [:, None ] ], dim = - 1 )
89
88
90
- if eos_idx is not None :
91
- unfinished_sequences = unfinished_sequences .mul ((next_tokens != eos_idx ).long () )
89
+ # Update the per-sequence unfinished flags: a sequence that just produced eos_idx becomes 0
90
+ unfinished_sequences = unfinished_sequences .mul ((next_tokens != eos_idx )) .long ()
92
91
93
92
# Stop iterating once all sequences are finished or exceed the max_length
94
93
if unfinished_sequences .max () == 0 or len (input_ids [0 ]) >= max_length :
@@ -128,8 +127,10 @@ def generate(
128
127
129
128
if self .is_encoder_decoder :
130
129
encoder = self .model .get_encoder ()
131
- model_kwargs ["encoder_outputs" ] = encoder (inputs )
130
+ encoder_model_kwargs = {"src_key_padding_mask" : inputs .eq (pad_idx )}
131
+ model_kwargs ["encoder_outputs" ] = encoder (inputs , ** encoder_model_kwargs )
132
132
inputs = self ._prepare_decoder_ids_for_generation (len (inputs ), device = inputs .device , ** model_kwargs )
133
+ model_kwargs ["encoder_padding_mask" ] = encoder_model_kwargs .pop ("src_key_padding_mask" )
133
134
134
135
if max_length is None :
135
136
# Too hard to try to figure out the exact max_seq_length for each model
0 commit comments