
Commit 0b2532a

add ort to debertav2 model config (#12)
* add ort config for debertav2 model
* remove prints
* remove old commented code
* fix run style error
* add flake ignore comment
* trial to fix blackify format error
1 parent 8e5b0db commit 0b2532a

File tree

3 files changed (+105 -12 lines)

examples/pytorch/text-classification/run_glue.py

Lines changed: 1 addition & 0 deletions
@@ -301,6 +301,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ort=True if training_args.ort else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
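
For reference, a minimal sketch of where the new keyword is meant to end up. This assumes the fork defines (or at least tolerates) an ort field on the DeBERTa-v2 config; upstream PretrainedConfig stores unrecognized constructor kwargs as plain attributes, so a directly built config illustrates the idea. Names below are illustrative only, not part of the commit.

# Illustrative sketch, not part of this commit: the ort keyword becomes a plain
# attribute on the config, which the DeBERTa-v2 modeling changes below branch on.
from transformers import DebertaV2Config

config = DebertaV2Config(ort=True)  # fork-specific flag, not an upstream option
print(config.ort)  # True -> the model picks TorchNNDropout instead of StableDropout
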
src/transformers/models/deberta_v2/jit_tracing.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# flake8: noqa
+# coding=utf-8
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Logging util @Author: [email protected]
+"""
+
+""" Utils for torch jit tracing customer operators/functions"""
+import os
+
+import torch
+
+
+def traceable(cls):
+    class _Function(object):
+        @staticmethod
+        def apply(*args):
+            if torch.onnx.is_in_onnx_export():
+                return cls.forward(_Function, *args)
+            else:
+                return cls.apply(*args)
+
+        @staticmethod
+        def save_for_backward(*args):
+            pass
+
+    return _Function
+
+
+class TraceMode:
+    """Trace context used when tracing modules contains customer operators/Functions"""
+
+    def __enter__(self):
+        os.environ["JIT_TRACE"] = "True"
+        return self
+
+    def __exit__(self, exp_value, exp_type, trace):
+        del os.environ["JIT_TRACE"]
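
A hedged usage sketch of the two helpers this new file adds (assuming the module lands at src/transformers/models/deberta_v2/jit_tracing.py, as the "from .jit_tracing import traceable" import in modeling_deberta_v2.py below implies). traceable wraps a custom torch.autograd.Function so that, during ONNX export, apply() falls through to forward() directly instead of going through the autograd machinery the exporter cannot trace; TraceMode toggles a JIT_TRACE environment variable around a tracing run.

# Illustrative sketch, not part of this commit.
import torch

from transformers.models.deberta_v2.jit_tracing import TraceMode, traceable  # assumed path


@traceable
class Square(torch.autograd.Function):  # toy Function standing in for XSoftmax / XDropout
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # the decorator's no-op stub handles this during export
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.ones(3, requires_grad=True)
y = Square.apply(x)  # outside torch.onnx.export this is the normal autograd path

with TraceMode():  # sets os.environ["JIT_TRACE"] = "True" for the duration of the block
    pass  # e.g. run torch.jit.trace / torch.onnx.export here
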

src/transformers/models/deberta_v2/modeling_deberta_v2.py

Lines changed: 53 additions & 12 deletions
@@ -19,11 +19,21 @@
 
 import numpy as np
 import torch
-from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch import (
+    _softmax_backward_data,
+    nn,
+)
+from torch.nn import (
+    CrossEntropyLoss,
+    LayerNorm,
+)
 
 from ...activations import ACT2FN
-from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,

@@ -34,6 +44,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_deberta_v2 import DebertaV2Config
+from .jit_tracing import traceable
 
 
 logger = logging.get_logger(__name__)

@@ -55,7 +66,10 @@ class ContextPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
-        self.dropout = StableDropout(config.pooler_dropout)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.pooler_dropout)
+        else:
+            self.dropout = StableDropout(config.pooler_dropout)
         self.config = config
 
     def forward(self, hidden_states):

@@ -73,6 +87,7 @@ def output_dim(self):
         return self.config.hidden_size
 
 
+@traceable
 # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
 class XSoftmax(torch.autograd.Function):
     """

@@ -144,6 +159,7 @@ def get_mask(input, local_context):
     return mask, dropout
 
 
+@traceable
 # Copied from transformers.models.deberta.modeling_deberta.XDropout
 class XDropout(torch.autograd.Function):
     """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""

@@ -167,6 +183,11 @@ def backward(ctx, grad_output):
         return grad_output, None
 
 
+class TorchNNDropout(torch.nn.Dropout):
+    def __init__(self, drop_prob):
+        super().__init__(drop_prob)
+
+
 # Copied from transformers.models.deberta.modeling_deberta.StableDropout
 class StableDropout(torch.nn.Module):
     """

@@ -223,7 +244,10 @@ def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)

@@ -291,7 +315,10 @@ def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config
 
     def forward(self, hidden_states, input_tensor):

@@ -346,7 +373,10 @@ def __init__(self, config):
             config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
         )
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config
 
     def forward(self, hidden_states, residual_states, input_mask):

@@ -584,16 +614,21 @@ def __init__(self, config):
             self.pos_ebd_size = self.max_relative_positions
             if self.position_buckets > 0:
                 self.pos_ebd_size = self.position_buckets
-
-            self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+            if config.ort:
+                self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob)
+            else:
+                self.pos_dropout = StableDropout(config.hidden_dropout_prob)
 
             if not self.share_att_key:
                 if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
                     self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
                 if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
                     self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
 
-        self.dropout = StableDropout(config.attention_probs_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.attention_probs_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.attention_probs_dropout_prob)
 
     def transpose_for_scores(self, x, attention_heads):
         new_x_shape = x.size()[:-1] + (attention_heads, -1)

@@ -816,7 +851,10 @@ def __init__(self, config):
         if self.embedding_size != config.hidden_size:
             self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
-        self.dropout = StableDropout(config.hidden_dropout_prob)
+        if config.ort:
+            self.dropout = TorchNNDropout(config.hidden_dropout_prob)
+        else:
+            self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized

@@ -1247,7 +1285,10 @@ def __init__(self, config):
         self.classifier = torch.nn.Linear(output_dim, num_labels)
         drop_out = getattr(config, "cls_dropout", None)
         drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
-        self.dropout = StableDropout(drop_out)
+        if config.ort:
+            self.dropout = TorchNNDropout(drop_out)
+        else:
+            self.dropout = StableDropout(drop_out)
 
         self.init_weights()
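
Taken together, the hunks above swap every StableDropout for TorchNNDropout (a thin torch.nn.Dropout subclass) whenever config.ort is truthy, since the custom dropout's autograd Function does not trace cleanly for ONNX Runtime. A minimal sketch of exercising the flag, with illustrative sizes and assuming the fork's classes otherwise behave like upstream transformers:

# Illustrative sketch, not part of this commit: with config.ort set, the model is
# built with plain torch.nn.Dropout modules in place of StableDropout.
import torch

from transformers import DebertaV2Config, DebertaV2ForSequenceClassification

config = DebertaV2Config(
    hidden_size=32,          # tiny sizes just to keep the example cheap
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
config.ort = True  # fork-specific flag read by the "if config.ort:" branches above

model = DebertaV2ForSequenceClassification(config)
assert isinstance(model.dropout, torch.nn.Dropout)  # TorchNNDropout subclasses nn.Dropout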
