欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

KBQA-Bert学习记录-构建BERT-CRF模型

发布时间:2023/12/31 编程问答 41 豆豆
生活随笔 收集整理的这篇文章主要介绍了 KBQA-Bert学习记录-构建BERT-CRF模型 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

目录

1.__init__方法

2.forward方法


将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。

1.__init__方法

这里面主要是bert的参数的定义及导入,还有bert模型的导入。

MODEL_NAME = 'bert-base-chinese-model.bin' CONFIG_NAME = 'bert-base-chinese-config.json' VOB_NAME = 'bert-base-chinese-vocab.txt'class BertCrf(nn.Module):def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:self.batch_first = batch_first# 模型配置文件、模型预训练参数文件判断if not os.path.exists(config_name):raise ValueError("未找到模型配置文件 '{}'".format(config_name))else:self.config_name = config_nameif model_name is not None:if not os.path.exists(model_name):raise ValueError("未找到模型预训练参数文件 '{}'".format(model_name))else:self.model_name = model_nameelse:self.model_name = Noneif num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()# 配置bert的config文件self.bert_config = BertConfig.from_pretrained(self.config_name)self.bert_config.num_labels = num_tagsself.model_kwargs = {'config': self.bert_config}# 如果模型不存在if self.model_name is not None:self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)else:self.bertModel = BertForTokenClassification(self.bert_config)self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)

2.forward方法

输出的结果,经过处理后,输入CRF函数,返回loss即可。

def forward(self, input_ids: torch.Tensor,tags: torch.Tensor = None,attention_mask: Optional[torch.ByteTensor] = None,token_type_ids=torch.Tensor,decode:bool = True,reduction: str = 'mean')->List:emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]# 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]new_emissions = emissions[:, 1:-1]new_mask = attention_mask[:, 2:].bool()# tags为None, 是预测过程,不能求lossif tags is None:loss = Nonepasselse:new_tags = tags[:, 1:-1]loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)if decode:tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)return [loss, tag_list]return [loss]

总结

以上是生活随笔为你收集整理的KBQA-Bert学习记录-构建BERT-CRF模型的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。