Commit db7b6db

add projection layer to roberta encoder
1 parent aea6ad6 commit db7b6db

File tree

2 files changed: +34 -0 lines changed


torchtext/models/roberta/model.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 from .modules import (
     TransformerEncoder,
+    ProjectionLayer,
 )
 import logging
 logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ class RobertaEncoderConf:
     num_attention_heads: int = 12
     num_encoder_layers: int = 12
     dropout: float = 0.1
+    projection_dim: Optional[int] = None
+    projection_dropout: Optional[float] = None
     scaling: Optional[float] = None
     normalize_before: bool = False
 
@@ -40,6 +43,8 @@ def __init__(
         num_attention_heads: int,
         num_encoder_layers: int,
         dropout: float = 0.1,
+        projection_dim: Optional[int] = None,
+        projection_dropout: Optional[float] = None,
         scaling: Optional[float] = None,
         normalize_before: bool = False,
     ):
@@ -62,6 +67,10 @@ def __init__(
             return_all_layers=False,
         )
 
+        self.project = None
+        if projection_dim is not None:
+            self.project = ProjectionLayer(embed_dim=embedding_dim, projection_dim=projection_dim, dropout=projection_dropout)
+
     @classmethod
     def from_config(cls, config: RobertaEncoderConf):
         return cls(**asdict(config))
@@ -73,6 +82,10 @@ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         output = output.transpose(1, 0)
         if mask is not None:
             output = output[mask.to(torch.bool), :]
+
+        if self.project is not None:
+            output = self.project(output)
+
         return output
 
 
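Taken together, the model.py changes attach a projection head only when projection_dim is set; both new config fields default to None, so existing configs keep the current behavior. Below is a minimal usage sketch, not part of the commit, assuming the remaining RobertaEncoderConf fields keep their dataclass defaults and that both classes are importable from torchtext.models.roberta.model (the file path above):

import torch
from torchtext.models.roberta.model import RobertaEncoder, RobertaEncoderConf  # assumed import path

# Only the new fields are set here; everything else relies on the dataclass defaults.
config = RobertaEncoderConf(projection_dim=256, projection_dropout=0.1)
encoder = RobertaEncoder.from_config(config)

tokens = torch.randint(2, 100, (1, 16))      # toy token ids, batch of 1, length 16
mask = torch.ones(1, 16, dtype=torch.bool)   # selects every position in forward()
out = encoder(tokens, mask)                  # selected positions come back projected, e.g. shape (16, 256)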

torchtext/models/roberta/modules.py

Lines changed: 21 additions & 0 deletions
@@ -31,6 +31,27 @@ def _make_positions(self, tensor, pad_index: int):
         return torch.cumsum(masked, dim=1) * masked + pad_index
 
 
+class ProjectionLayer(Module):
+    def __init__(self,
+                 embed_dim: int,
+                 projection_dim: int,
+                 dropout: Optional[float] = None) -> None:
+        super().__init__()
+
+        self.projection_layer = nn.Linear(embed_dim, projection_dim)
+        self.norm_layer = nn.LayerNorm(projection_dim)
+        if dropout is not None:
+            self.dropout_layer = nn.Dropout(dropout)
+        else:
+            self.dropout_layer = nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.projection_layer(x)
+        x = self.norm_layer(x)
+        x = self.dropout_layer(x)
+        return x
+
+
 class ResidualMLP(Module):
     def __init__(
         self,
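ProjectionLayer itself is a small Linear -> LayerNorm -> Dropout stack, with nn.Identity standing in when no dropout is given, so its only effect on shapes is the last dimension. A standalone sketch of its behavior, assuming the class is importable from torchtext.models.roberta.modules as added in this commit:

import torch
from torchtext.models.roberta.modules import ProjectionLayer  # assumed import path

proj = ProjectionLayer(embed_dim=768, projection_dim=256, dropout=0.1)
x = torch.randn(10, 768)   # e.g. 10 selected (masked) positions of encoder output
y = proj(x)                # Linear -> LayerNorm -> Dropout
print(y.shape)             # torch.Size([10, 256])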
