
Commit

Update torch.py
NextSentencePred is really just a linear layer; for conciseness it is better written together with the other network components inside nn.Sequential.
XihWang authored Jan 4, 2025
1 parent e6b18cc commit 9e9e64f
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions d2l/torch.py
@@ -2276,17 +2276,6 @@ def forward(self, X, pred_positions):
         mlm_Y_hat = self.mlp(masked_X)
         return mlm_Y_hat
 
-class NextSentencePred(nn.Module):
-    """The next-sentence prediction task of BERT
-    Defined in :numref:`subsec_mlm`"""
-    def __init__(self, num_inputs, **kwargs):
-        super(NextSentencePred, self).__init__(**kwargs)
-        self.output = nn.Linear(num_inputs, 2)
-
-    def forward(self, X):
-        # Shape of X: (batch_size, num_hiddens)
-        return self.output(X)
-
 class BERTModel(nn.Module):
     """The BERT model
@@ -2295,17 +2284,18 @@ class BERTModel(nn.Module):
     def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                  ffn_num_hiddens, num_heads, num_layers, dropout,
                  max_len=1000, key_size=768, query_size=768, value_size=768,
-                 hid_in_features=768, mlm_in_features=768,
-                 nsp_in_features=768):
+                 hid_in_features=768, mlm_in_features=768
+                 ):
         super(BERTModel, self).__init__()
         self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                                    dropout, max_len=max_len, key_size=key_size,
                                    query_size=query_size, value_size=value_size)
         self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
-                                    nn.Tanh())
+                                    nn.Tanh(),
+                                    nn.Linear(num_hiddens, 2))
         self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
-        self.nsp = NextSentencePred(nsp_in_features)
+

     def forward(self, tokens, segments, valid_lens=None,
                 pred_positions=None):
@@ -2315,7 +2305,7 @@ def forward(self, tokens, segments, valid_lens=None,
         else:
             mlm_Y_hat = None
         # Hidden layer of the MLP classifier for next-sentence prediction; 0 is the index of the '<cls>' token
-        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
+        nsp_Y_hat = self.hidden(encoded_X[:, 0, :])
         return encoded_X, mlm_Y_hat, nsp_Y_hat
 
 d2l.DATA_HUB['wikitext-2'] = (
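Since the removed NextSentencePred wrapped only a single nn.Linear(num_inputs, 2), folding that layer into self.hidden leaves the model's computation unchanged. The sketch below is not part of the commit; num_hiddens, batch_size, and the weight copying are illustrative assumptions used to check that the old two-module head and the new single nn.Sequential head produce identical NSP logits once their weights match.

# Illustrative equivalence check (assumed demo values, not code from d2l)
import torch
from torch import nn

num_hiddens, batch_size = 768, 4
cls_repr = torch.randn(batch_size, num_hiddens)  # stands in for encoded_X[:, 0, :]

# Old structure: hidden MLP plus a separate 2-way linear layer (NextSentencePred.output)
hidden_old = nn.Sequential(nn.Linear(num_hiddens, num_hiddens), nn.Tanh())
nsp_old = nn.Linear(num_hiddens, 2)

# New structure: one Sequential that ends in the 2-way classifier
hidden_new = nn.Sequential(nn.Linear(num_hiddens, num_hiddens), nn.Tanh(),
                           nn.Linear(num_hiddens, 2))

# Share weights so both heads define the same function, then compare outputs
hidden_new[0].load_state_dict(hidden_old[0].state_dict())
hidden_new[2].load_state_dict(nsp_old.state_dict())
with torch.no_grad():
    assert torch.allclose(nsp_old(hidden_old(cls_repr)), hidden_new(cls_repr))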
