BERT
2018年10月谷歌AI团队发布BERT模型,在11种NLP任务测试中刷新了最佳成绩,一时风头无两。自然语言处理领域近两年最受关注,并且进展迅速的当属机器阅读理解,其中斯坦福大学于2016年提出的SQuAD数据集对于推动Machine Comprehension的发展起到了巨大的作用。SQuAD 1.0发布时,Google一直没有出手,微软曾长期占据榜首位置,阿里巴巴也曾短暂登顶。2018年1月3日微软亚洲研究院提交的R-NET模型在EM值(Exact Match表示预测答案和真实答案完全匹配)上以82.650的最高分领先,并率先超越人类分数82.304。而当谷歌一出手,便知有没有,目前SQuAD排行榜上已经被BERT霸屏,排行前列的模型几乎全部基于BERT。关于通用语言模型的介绍,可以参考另一篇翻译的博客,以及张俊林老师的介绍,参考链接附在本文末尾。
源码分析
谷歌已开放源码:
其中create_pretraining_data.py用于创建训练数据,run_pretraining.py用于进行预训练。此外,谷歌还提供了二阶段fine tunning的训练代码,run_classifier.py用于句子分类任务,run_squad.py用于机器阅读理解任务,可直接使用。而基于BERT的语言模型可直接对预训练模型进行改造后获得,参考链接:
作者主要对get_masked_lm_output函数进行了改造,具体地,计算masked lm loss时不使用masked_lm_weights,参考代码如下:
|
|
Python中可直接构造输入,然后利用Tensorflow高级API来获得结果:
estimator.predict的预测结果在model_fn_builder中指定:
BERT作为语言模型时,一个不便之处是需要逐个计算每个token的prob,然后计算句子的ppl。
ppl: 自然语言处理领域(NLP)中,衡量语言模型好坏的指标。根据每个词来估计一句话出现的概率,并用句子长度作normalize,ppl值越小,表示该句子越合理。
结果解析,ppl计算代码:
模型训练、导出和部署
由于预训练模型中masked lm loss节点并未命名,所以添加name后需要启动很短暂的预训练,同时将模型导出。get_masked_lm_output函数参考bert-as-language-model中的代码进行相应改造。Tensorflow版本升级后,使用estimator接受输入,原来我们最爱的placeholder找不到了,而在部署模型时,仍需要使用placeholder接受输入,可在run_pretraining.py导出模型时添加如下代码:
本文部署模型使用golang语言,基于tfgo实现模型的加载和tensorflow对应节点的计算。
参考run_pretraining.py导出模型时的代码,golang程序中需要构造7个输入,而masked_lm_weights和next_sentence_labels对于语言模型没有影响,可按自己喜爱构造。以下面的例子说明一下输入的构造标准:
Golang程序计算句子ppl: