Skip to content

Commit 157da8b

Browse files
committed
minor change
1 parent db7b6db commit 157da8b

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torchtext/models/roberta/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
return_all_layers=False,
6868
)
6969

70-
self.project = None
70+
self.project = nn.Identity()
7171
if projection_dim is not None:
7272
self.project = ProjectionLayer(embed_dim=embedding_dim, projection_dim=projection_dim, dropout=projection_dropout)
7373

@@ -83,8 +83,7 @@ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
8383
if mask is not None:
8484
output = output[mask.to(torch.bool), :]
8585

86-
if self.project is not None:
87-
output = self.project(output)
86+
output = self.project(output)
8887

8988
return output
9089

0 commit comments

Comments
 (0)