@@ -19,11 +19,21 @@

 import numpy as np
 import torch
-from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch import (
+    _softmax_backward_data,
+    nn,
+)
+from torch.nn import (
+    CrossEntropyLoss,
+    LayerNorm,
+)

 from ...activations import ACT2FN
-from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -34,6 +44,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_deberta_v2 import DebertaV2Config
+from .jit_tracing import traceable


 logger = logging.get_logger(__name__)
@@ -55,7 +66,10 @@ class ContextPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
-        self.dropout = StableDropout(config.pooler_dropout)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.pooler_dropout)
+        else:
+            self.dropout = StableDropout(config.pooler_dropout)
         self.config = config

     def forward(self, hidden_states):
@@ -73,6 +87,7 @@ def output_dim(self):
         return self.config.hidden_size


+@traceable
 # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
 class XSoftmax(torch.autograd.Function):
     """
@@ -144,6 +159,7 @@ def get_mask(input, local_context):
     return mask, dropout


+@traceable
 # Copied from transformers.models.deberta.modeling_deberta.XDropout
 class XDropout(torch.autograd.Function):
     """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
@@ -167,6 +183,11 @@ def backward(ctx, grad_output):
             return grad_output, None


+class TorchNNDropout(torch.nn.Dropout):
+    def __init__(self, drop_prob):
+        super().__init__(drop_prob)
+
+
 # Copied from transformers.models.deberta.modeling_deberta.StableDropout
 class StableDropout(torch.nn.Module):
     """
@@ -223,7 +244,10 @@ def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
@@ -291,7 +315,10 @@ def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config

     def forward(self, hidden_states, input_tensor):
@@ -346,7 +373,10 @@ def __init__(self, config):
             config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
         )
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config

     def forward(self, hidden_states, residual_states, input_mask):
@@ -584,16 +614,21 @@ def __init__(self, config):
             self.pos_ebd_size = self.max_relative_positions
             if self.position_buckets > 0:
                 self.pos_ebd_size = self.position_buckets
-
-            self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+            if config.ort:
+                self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob)
+            else:
+                self.pos_dropout = StableDropout(config.hidden_dropout_prob)

             if not self.share_att_key:
                 if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
                     self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
                 if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
                     self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

-        self.dropout = StableDropout(config.attention_probs_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.attention_probs_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.attention_probs_dropout_prob)

     def transpose_for_scores(self, x, attention_heads):
         new_x_shape = x.size()[:-1] + (attention_heads, -1)
@@ -816,7 +851,10 @@ def __init__(self, config):
         if self.embedding_size != config.hidden_size:
             self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -1247,7 +1285,10 @@ def __init__(self, config):
         self.classifier = torch.nn.Linear(output_dim, num_labels)
         drop_out = getattr(config, "cls_dropout", None)
         drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
-        self.dropout = StableDropout(drop_out)
+        if config.ort:
+            self.dropout = TorchNNDropout(drop_out)
+        else:
+            self.dropout = StableDropout(drop_out)

         self.init_weights()
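
Every hunk after the imports applies the same pattern: each StableDropout(...) construction becomes conditional on config.ort, substituting a thin torch.nn.Dropout subclass (TorchNNDropout) when the model is built for ONNX Runtime, and the custom autograd functions XSoftmax and XDropout are decorated with @traceable. A minimal, self-contained sketch of the dropout selection follows; it uses a stand-in config object and plain nn.Dropout in place of StableDropout, since ort is not a standard DebertaV2Config field and StableDropout lives in modeling_deberta_v2.py.

from types import SimpleNamespace

import torch


class TorchNNDropout(torch.nn.Dropout):
    # Thin wrapper over the stock nn.Dropout, as introduced by the patch.
    def __init__(self, drop_prob):
        super().__init__(drop_prob)


def pick_dropout(config, drop_prob, stable_dropout_cls=torch.nn.Dropout):
    # Mirrors the repeated `if config.ort:` blocks in the diff: ORT builds get
    # the plain nn.Dropout wrapper, everything else keeps DeBERTa's
    # StableDropout (stubbed with nn.Dropout here to keep the sketch runnable).
    # Note the patch itself reads `config.ort` directly; getattr is used only
    # so this stand-in config could omit the flag.
    if getattr(config, "ort", False):
        return TorchNNDropout(drop_prob)
    return stable_dropout_cls(drop_prob)


config = SimpleNamespace(ort=True, hidden_dropout_prob=0.1)
dropout = pick_dropout(config, config.hidden_dropout_prob)
print(type(dropout).__name__)  # -> TorchNNDropout

Because the patched __init__ methods access config.ort directly, the attribute has to exist on the config (for example, setting config.ort = True or config.ort = False before constructing the model); otherwise the attribute lookup fails.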