 
 from .modules import (
     TransformerEncoder,
+    ProjectionLayer,
 )
 import logging
 logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ class RobertaEncoderConf:
     num_attention_heads: int = 12
     num_encoder_layers: int = 12
     dropout: float = 0.1
+    projection_dim: Optional[int] = None
+    projection_dropout: Optional[float] = None
     scaling: Optional[float] = None
     normalize_before: bool = False
 
@@ -40,6 +43,8 @@ def __init__(
         num_attention_heads: int,
         num_encoder_layers: int,
         dropout: float = 0.1,
+        projection_dim: Optional[int] = None,
+        projection_dropout: Optional[float] = None,
         scaling: Optional[float] = None,
         normalize_before: bool = False,
     ):
@@ -62,6 +67,10 @@ def __init__(
             return_all_layers=False,
         )
 
+        self.project = None
+        if projection_dim is not None:
+            self.project = ProjectionLayer(embed_dim=embedding_dim, projection_dim=projection_dim, dropout=projection_dropout)
+
     @classmethod
     def from_config(cls, config: RobertaEncoderConf):
         return cls(**asdict(config))
@@ -73,6 +82,10 @@ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         output = output.transpose(1, 0)
         if mask is not None:
             output = output[mask.to(torch.bool), :]
+
+        if self.project is not None:
+            output = self.project(output)
+
         return output
 
 
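The diff imports `ProjectionLayer` from `.modules` but does not show its implementation. As a point of reference, here is a minimal sketch that is consistent with the call site above (keyword arguments `embed_dim`, `projection_dim`, `dropout`); the actual module may differ, e.g. by adding an activation or layer norm.

```python
# Hypothetical sketch only -- the real ProjectionLayer lives in .modules and is not part of this diff.
from typing import Optional

from torch import Tensor
from torch.nn import Dropout, Linear, Module


class ProjectionLayer(Module):
    """Maps encoder features of size embed_dim down to projection_dim, with optional dropout."""

    def __init__(self, embed_dim: int, projection_dim: int, dropout: Optional[float] = None) -> None:
        super().__init__()
        self.projection = Linear(embed_dim, projection_dim)
        # projection_dropout is Optional in the config, so treat None as "no dropout".
        self.dropout = Dropout(dropout if dropout is not None else 0.0)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.projection(x))
```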
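Continuing from that sketch, the snippet below mirrors the conditional branch the diff adds to `forward()`: the projection is applied only when `projection_dim` is configured, in which case the last dimension of the encoder output changes from the embedding size to `projection_dim`. The tensor shapes are illustrative values, not taken from this PR.

```python
import torch

batch_size, seq_len, embed_dim, projection_dim = 2, 8, 768, 256

# Stand-in for the transformer output (before any masking), shape (batch, seq, embed_dim).
encoder_output = torch.randn(batch_size, seq_len, embed_dim)

project = ProjectionLayer(embed_dim=embed_dim, projection_dim=projection_dim, dropout=0.1)

output = encoder_output
if project is not None:  # mirrors the new `if self.project is not None:` check in forward()
    output = project(output)

print(output.shape)  # torch.Size([2, 8, 256])
```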