Hierarchical Attention Networks for Document Classification(HAN)
HAN历史意义:
1、基于Attention的文本分类模型得到了很多关注
2、通过层次处理长文档的方式逐渐流行
3、推动了注意力机制在非Seqseq模型上的应用
前人主要忽视的问题:
1、文档中不同句子对于分类的重要性不同
2、句子中不同单词对于分类的重要性也有所不同
本文主要结构
一、Abstract
(通常框架为:任务的重要性 -> 前人缺点 -> 本文模型 -> 实验)
提出了一种针对文档分类任务的层次注意力网络,既包含了一种层次结构,又在词级别和句子级别使用两种注意力机制来选择重要的信息
二、Introduction
主要概括: 之前深度学习模型取得比较好的效果,但是没有注意到文档中不同部分对任务的重要度不同,基于此提出了层次注意力网络
具体背景:
1、文本分类是自然语言的基础任务之一,研究者也开始使用基于深度学习的文本分类模型
2、虽然深度学习的文本分类模型取得非常好的效果,但是没有注意文档的结构,并且没有注意到文档中不同部分对于分类的影响程度不一样
3、为了解决这一个问题,提出了一种层次注意力网络来学习文档的层次结构,并且使用两种注意力机制学习基于上下文结构的重要性
4、与前人的区别是使用上下文来区分句子或单词的重要性,而不仅仅使用单个句子或单个的词
三、Hierarchical Attention Networks
首先介绍了GRU网络
GRU网络图如下所示:
Attention机制指的是从大量的信息中抽取对任务重要的信息,所以能够抽取文档中重要的句子以及句子中重要的单词,结构如下所示:
Hierarchical Attantion Networks(HAN)模型主要包含四部分Word Encoder、Word Attention、Sentence Encoder、Sentence Attention
Word Encoder:
主要是输入词获取词向量矩阵,然后将词向量输入双向GRU网络中得到GRU网络的输出,部分代码片段如下:
""" 定义结构 """ if is_pretrain:self.embedding = nn.Embedding.from_pretrained(weights, freeze=False) else:self.embedding = nn.Embedding(vocab_size, embedding_size)self.word_gru = nn.GRU(input_size=embedding_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)""" 具体实现片段 """x_embedding = self.embedding(x) word_outputs, word_hidden = self.word_gru(x_embedding)Word Attention
主要实现一个词级别的attention机制,在这个里面u对应的是注意力机制的query,不同的是这里的query是个变量也根据模型进行迭代优化,key对应的是gru网络的输出,attention value就是query和key对应计算value值,然后在迭代加和,部分代码片段如下所示:
""" 定义结构 """self.word_context = nn.Parameter(torch.Tensor(2*gru_size, 1),requires_grad=True) # 论文中提到的u也就是query,因为需要更新迭代所以这里面写的是nn.Parameterself.word_dense = nn.Linear(2*gru_size,2*gru_size) # 定义一个全连接网络""" 具体实现 """""" 对应论文中的公式 """ attention_word_outputs = torch.tanh(self.word_dense(word_outputs)) weights = torch.matmul(attention_word_outputs,self.word_context) weights = F.softmax(weights,dim=1)""" 有一些部分为0,所以权重矩阵对应位置有参数也没有意义,所以做mask""" x = x.unsqueeze(2) if gpu:weights = torch.where(x!=0,weights,torch.full_like(x,0,dtype=torch.float).cuda()) else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights,dim=1).unsqueeze(1)+1e-4)Sentence Encoder
首先获取句子向量表示,主要由两部分构成一部分是词级别通过gru网络的输出,另一个部分是attention机制计算出的对应权重,把这两个部分进行加权求和得到句子向量表示,然后将句子向量还是输入到双向GRU网络中得到输出结果,代码片段如下所示:
""" 定义结构 """self.sentence_gru = nn.GRU(input_size=2*gru_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)""" 具体片段 """sentence_vector = torch.sum(word_outputs*weights,dim=1).view([-1,sentence_num,word_outputs.shape[-1]]) sentence_outputs, sentence_hidden = self.sentence_gru(sentence_vector)Sentence Attention
通过定义句子级别的attention,然后获取每个句子的权重,最后得到文档表示,具体代码片段如下所示:
""" 定义网络结构 """self.sentence_context = nn.Parameter(torch.Tensor(2*gru_size, 1),requires_grad=True) self.sentence_dense = nn.Linear(2*gru_size,2*gru_size)self.fc = nn.Linear(2*gru_size,class_num) # 最后文档表示做全连接分类""" 具体片段 """attention_sentence_outputs = torch.tanh(self.sentence_dense(sentence_outputs)) weights = torch.matmul(attention_sentence_outputs,self.sentence_context) weights = F.softmax(weights,dim=1) x = x.view(-1, sentence_num, x.shape[1]) x = torch.sum(x, dim=2).unsqueeze(2) if gpu:weights = torch.where(x!=0,weights,torch.full_like(x,0,dtype=torch.float).cuda()) else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))"""权重归一化""" weights = weights / (torch.sum(weights,dim=1).unsqueeze(1)+1e-4)""" 获得文档向量表示 """ document_vector = torch.sum(sentence_outputs*weights,dim=1)""" 最后根据文档表示对文档进行分类""" output = self.fc(document_vector)
四、Experiment
实验结果部分主要是在相同的数据集上和其它模型做对比表现出该模型的效果比较好,并且根据分布解释了“good”或“bad”在分布是不一样的
五、Related Work
相关工作主要是解释了其它论文的作者是采用什么方法,怎么实现的相当于做一个对比也是一个铺垫
六、Conclusion
关键点:
1、之前基于深度学习的文本分类模型没有关注到文档中不同部分的信息重要性不同
2、通过注意力机制可以学习到文档中各个部分对于分类的重要性
3、提出HAN Attention模型
创新点:
1、提出一种新的文本分类模型-HAN Attention模型
2、通过两种级别的注意力机制同时学习文档中重要的句子和单词
3、在几个文本分类数据集上取得比较好(State of the art)的效果
启发点:
1、模型背后的直觉是文档不同部分对于文档分类的重要性不同,而且这些部分的重要性还取决于内部的单词,而不仅仅是对这部分单独确定重要性
2、单词和句子的重要性是上下文相关的,同样的词或者句子在不同的上下文情景下重要性也不同
七、代码实现
IMDB公开数据集下载地址:http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp-2015-data.7z
""" 数据预处理部分 """from torch.utils import data import os import nltk import numpy as np import pickle from collections import Counter""" 数据集加载 """# 数据集加载 datas = open("./data/imdb/imdb-test.txt.ss",encoding="utf-8").read().splitlines() datas = [data.split(" ")[-1].split()+[data.split(" ")[2]] for data in datas] print(datas[0:1])[['i','knew','that','the','old-time','movie','makers','often','``','borrowed',"''",'or','outright','plagiarized','from','each','other',',','but','this','is','ridiculous','!','<sssss>','not','only','did','george','albert','smith','make','this','film','in','1899',',','but','and','company','made','a','nearly','identical','film','that','same','year','with','the','same','title','!!!','<sssss>','the','worst','part','about','it','is','that','neither','film','was','all','that','great','.','<sssss>','and',',','of','the','two',',','the','smith','one','is','slightly','less','well','made','.','<sssss>','like','all','movies','of','the','1890s',',','this','one','is','incredibly','brief','and','almost','completely','uninteresting','to','audiences','in','the','21st','century','.','<sssss>','only','film','historians','and','crazy','people','like','me','would','watch','this','brief','film','-lrb-','i',"'m",'a','history','teacher','and','film','lover','--','that',"'s",'my','excuse','for','watching','them','both','-rrb-','.','5']]# 根据长度排序,保证训练时每个batch的长度一致datas = sorted(datas,key = lambda x:len(x),reverse=True) labels = [int(data[-1])-1 for data in datas] datas = [data[0:-1] for data in datas]print(labels[0:5]) print (datas[-5:])[7, 9, 9, 8, 9] [['one', 'of', 'the', 'best', 'movie', 'musicals', 'ever', 'made', '.', '<sssss>', 'the', 'singing', 'and', 'dancing', 'are', 'excellent', '.'], ['john', 'goodman', 'is', 'excellent', 'in', 'this', 'entertaining', 'portrayal', 'of', 'babe', 'ruth', "'s", 'life', '.'], ['how', 'to', 'this', 'movie', ':', 'disjointed', 'silly', 'unfulfilling', 'story', 'waste', 'of', 'time'], ['simply', 'a', 'classic', '.', '<sssss>', 'scenario', 'and', 'acting', 'are', 'excellent', '.'], ['there', 'were', 'tng', 'tv', 'episodes', 'with', 'a', 'better', 'story', '.']]# 构建 word2idmin_count = 5 word_freq = {} for data in datas:for word in data:word_freq[word] = word_freq.get(word,0)+1word2id = {"<pad>":0,"<unk>":1} for word in word_freq:if word_freq[word]<min_count:continueelse:word2id[word] = len(word2id)print(word2id){'<pad>': 0,'<unk>': 1,'i': 2,'only': 3,'just': 4,'got': 5,'around': 6,'to': 7,'watching': 8,'the': 9,'movie': 10,'today': 11,'.': 12,'<sssss>': 13,'when': 14,'it': 15,'came': 16,'out': 17,'in': 18,'movies': 19,',': 20,'heard': 21,'so': 22,'many': 23,'bad': 24,'things': 25,'about': 26,'...': 27,'how': 28,'fake': 29,'looked': 30,'long': 31,'winded': 32,'and': 33,'boring': 34,'was': 35,'stupid': 36,"n't": 37,'all': 38,'that': 39,'great': 40,'etc.': 41,'list': 42,'goes': 43,.........# 分句 for i,data in enumerate(datas):datas[i] = " ".join(data).split("<sssss>")for j,sentence in enumerate(datas[i]):datas[i][j] = sentence.split()# 将数据转化为id max_sentence_length = 100 # 句子必须一样的长度 batch_size = 64 # 每个batch size,每个文档的句子一样多 for i,document in enumerate(datas):for j,sentence in enumerate(document):for k,word in enumerate(sentence):datas[i][j][k] = word2id.get(word,word2id["<unk>"])datas[i][j] = datas[i][j][0:max_sentence_length] + \[word2id["<pad>"]]*(max_sentence_length-len(datas[i][j])) for i in range(0,len(datas),batch_size):max_data_length = max([len(x) for x in datas[i:i+batch_size]])for j in range(i,min(i+batch_size,len(datas))):datas[j] = datas[j] + [[word2id["<pad>"]]*max_sentence_length]*(max_data_length-len(datas[j]))"""得到最终输入模型的数据-datas"""
""" 模型构建部分 """# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:46 # @Author : TaoWang # @Description :from torch.nn import functional as F import torch.nn as nn import numpy as np import torchclass HAN_Model(nn.Module):def __init__(self,vocab_size,embedding_size,gru_size,class_num,is_pretrain=False,weights=None):""":param vocab_size::param embedding_size::param gru_size::param class_num::param is_pretrain::param weights:"""super(HAN_Model, self).__init__()if is_pretrain:self.embedding = nn.Embedding.from_pretrained(weights, freeze=False)else:self.embedding = nn.Embedding(vocab_size, embedding_size)self.word_gru = nn.GRU(input_size=embedding_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)self.word_context = nn.Parameter(torch.Tensor(2*gru_size, 1), requires_grad=True)self.word_dense = nn.Linear(2*gru_size, 2*gru_size)self.sentence_gru = nn.GRU(input_size=2*gru_size,hidden_size=gru_size,num_layers=1,bidirectional=True,batch_first=True)self.sentence_context = nn.Parameter(torch.Tensor(2*gru_size, 1), requires_grad=True)self.sentence_dense = nn.Linear(2*gru_size, 2*gru_size)self.fc = nn.Linear(2*gru_size, class_num)def forward(self, x, gpu=False):""":param x::param gpu::return:"""sentence_num = x.shape[1]sentence_length = x.shape[2]x = x.view([-1, sentence_length])x_embedding = self.embedding(x)word_outputs, word_hidden = self.word_gru(x_embedding)attention_word_outputs = torch.tanh(self.word_dense(word_outputs))weights = torch.matmul(attention_word_outputs, self.word_context)weights = F.softmax(weights, dim=1)x = x.unsqueeze(2)if gpu:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float).cuda())else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights, dim=1).unsqueeze(1) + 1e-4)sentence_vector = torch.sum(word_outputs * weights, dim=1).view([-1, sentence_num, word_outputs.shape[-1]])sentence_outputs, sentence_hidden = self.sentence_gru(sentence_vector)attention_sentence_outputs = torch.tanh(self.sentence_dense(sentence_outputs))weights = torch.matmul(attention_sentence_outputs, self.sentence_context)weights = F.softmax(weights, dim=1)x = x.view(-1, sentence_num, x.shape[1])x = torch.sum(x, dim=2).unsqueeze(2)if gpu:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))else:weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))weights = weights/(torch.sum(weights, dim=1).unsqueeze(1) + 1e-4)document_vector = torch.sum(sentence_outputs * weights, dim=1)output = self.fc(document_vector)return outputif __name__ == "__main__":han_model = HAN_Model(vocab_size=30000, embedding_size=200, gru_size=50, class_num=4)x = torch.Tensor(np.zeros([64, 50, 100])).long()x[0][0][0:10] = 1output = han_model(x)print(output)
""" 模型训练部分 """# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:43 # @Author : TaoWang # @Description : 模型训练过程import torch import torch.autograd as autograd import torch.nn as nn import torch.optim as optim from model import HAN_Model from data import IMDB_Data import numpy as np from tqdm import tqdm import config as argumentparserconfig = argumentparser.ArgumentParser() torch.manual_seed(config.seed)if config.cuda and torch.cuda.is_available(): # 是否使用gputorch.cuda.set_device(config.gpu)# 导入训练集 training_set = IMDB_Data("imdb-train.txt.ss",min_count=config.min_count,max_sentence_length = config.max_sentence_length,batch_size=config.batch_size,is_pretrain=False) training_iter = torch.utils.data.DataLoader(dataset=training_set,batch_size=config.batch_size,shuffle=False,num_workers=0)# 导入测试集 test_set = IMDB_Data("imdb-test.txt.ss",min_count=config.min_count,word2id=training_set.word2id,max_sentence_length = config.max_sentence_length,batch_size=config.batch_size) test_iter = torch.utils.data.DataLoader(dataset=test_set,batch_size=config.batch_size,shuffle=False,num_workers=0)model = HAN_Model(vocab_size=len(training_set.word2id),embedding_size=config.embedding_size,gru_size = config.gru_size,class_num=config.class_num,weights=training_set.weight,is_pretrain=False)if config.cuda and torch.cuda.is_available(): # 如果使用gpu,将模型送进gpumodel.cuda()criterion = nn.CrossEntropyLoss() # 这里会做softmax optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) loss = -1def get_test_result(data_iter,data_set):# 生成测试结果model.eval()true_sample_num = 0for data, label in data_iter:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()if config.cuda and torch.cuda.is_available():out = model(data, gpu=True)else:out = model(data)true_sample_num += np.sum((torch.argmax(out, 1) == label).cpu().numpy())acc = true_sample_num / data_set.__len__()return accfor epoch in range(config.epoch):model.train()process_bar = tqdm(training_iter)for data, label in process_bar:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()label = torch.autograd.Variable(label).squeeze()if config.cuda and torch.cuda.is_available():out = model(data,gpu=True)else:out = model(data)loss_now = criterion(out, autograd.Variable(label.long()))if loss == -1:loss = loss_now.data.item()else:loss = 0.95*loss+0.05*loss_now.data.item()process_bar.set_postfix(loss=loss_now.data.item())process_bar.update()optimizer.zero_grad()loss_now.backward()optimizer.step()test_acc = get_test_result(test_iter, test_set)print("The test acc is: %.5f" % test_acc) """ 配置文件-相关配置参数"""# -*- coding: utf-8 -*-# @Time : 2020/11/9 下午9:43 # @Author : TaoWang # @Description :import argparsedef ArgumentParser():parser = argparse.ArgumentParser()parser.add_argument('--embed_size', type=int, default=10, help="embedding size of word embedding")parser.add_argument("--epoch", type=int, default=200, help="epoch of training")parser.add_argument("--cuda", type=bool, default=True, help="whether use gpu")parser.add_argument("--gpu", type=int, default=2, help="gpu num")parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate during training")parser.add_argument("--batch_size", type=int, default=64, help="batch size during training")parser.add_argument("--seed", type=int, default=0, help="seed of random")parser.add_argument("--min_count", type=int, default=5, help="min count of words")parser.add_argument("--max_sentence_length", type=int, default=100, help="max sentence length")parser.add_argument("--embedding_size", type=int, default=200, help="word embedding size")parser.add_argument("--gru_size", type=int, default=50, help="gru size")parser.add_argument("--class_num", type=int, default=10, help="class num")return parser.parse_args()
总结
以上是生活随笔为你收集整理的Hierarchical Attention Networks for Document Classification(HAN)的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: Bag of Tricks for Ef
- 下一篇: SGM:Sequence Generat