KBQA-Bert学习记录-构建BERT-CRF模型
生活随笔
收集整理的这篇文章主要介绍了
KBQA-Bert学习记录-构建BERT-CRF模型
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
目录
1.__init__方法
2.forward方法
将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。
1.__init__方法
这里面主要是bert的参数的定义及导入,还有bert模型的导入。
MODEL_NAME = 'bert-base-chinese-model.bin' CONFIG_NAME = 'bert-base-chinese-config.json' VOB_NAME = 'bert-base-chinese-vocab.txt'class BertCrf(nn.Module):def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:self.batch_first = batch_first# 模型配置文件、模型预训练参数文件判断if not os.path.exists(config_name):raise ValueError("未找到模型配置文件 '{}'".format(config_name))else:self.config_name = config_nameif model_name is not None:if not os.path.exists(model_name):raise ValueError("未找到模型预训练参数文件 '{}'".format(model_name))else:self.model_name = model_nameelse:self.model_name = Noneif num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()# 配置bert的config文件self.bert_config = BertConfig.from_pretrained(self.config_name)self.bert_config.num_labels = num_tagsself.model_kwargs = {'config': self.bert_config}# 如果模型不存在if self.model_name is not None:self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)else:self.bertModel = BertForTokenClassification(self.bert_config)self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)2.forward方法
输出的结果,经过处理后,输入CRF函数,返回loss即可。
def forward(self, input_ids: torch.Tensor,tags: torch.Tensor = None,attention_mask: Optional[torch.ByteTensor] = None,token_type_ids=torch.Tensor,decode:bool = True,reduction: str = 'mean')->List:emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]# 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]new_emissions = emissions[:, 1:-1]new_mask = attention_mask[:, 2:].bool()# tags为None, 是预测过程,不能求lossif tags is None:loss = Nonepasselse:new_tags = tags[:, 1:-1]loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)if decode:tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)return [loss, tag_list]return [loss]总结
以上是生活随笔为你收集整理的KBQA-Bert学习记录-构建BERT-CRF模型的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: python 拼音输入法_用Python
- 下一篇: 高数篇:高等数学全目录