Chinese Relation Extraction by BiGRU with Character and Sentence Attentions之代码理解
代码链接为 https://github.com/crownpku/Information-Extraction-Chinese/tree/master/RE_BGRU_2ATT 。
1. initial.py
我们首先来分析initial文件。该文件代码量为300多行,难度不大。
核心的变量如下所示:
- vec,数据类型为list of list,其中list元素为单个词的词嵌入表示,并按照词典的顺序进行排列。后续转换为np.array类型。vec对应的就是word embedding矩阵。
- word2id,数据类型为字典,记录每个词对应的id,其中key为word,value为id。
- relation2id,数据类型为字典,key为关系(中文),value为id。
- train_sen,数据类型为字典,key为实体对(entity1, entity2),value是list of list of list(本质上为4层嵌套,最后一层是每个句子中对应词),其中数据格式为{entity pair:[[[label1-sentence 1],[label1-sentence 2]…],[[label2-sentence 1],[label2-sentence 2]…]},其中label1-sentence 1指的是[[wordid-1, rel_e1-1, rel_e2-1], [wordid-2, rel_e1-2, rel_e2-2], …]
- train_ans,数据类型为字典,key也为实体对,value是list of list,其中list元素为label的one-hot表示。
- train_y,是list of list,每个list是12个label构成的one-hot向量。
- train_word,是由np.array封装的两层嵌套list。外层对应实体对,内层对应该实体对下每个label-sentence的词嵌入id序列。len(train_word)表示的是有多少个实体对,即对应每个实体对、每个label-sentence中的每个词。
train_word.shape = (967,),即len(train_word) = 967
train_word[0]
[[101, 897, 1044, 1072, 839, 47, 1841, 414, 252, 213, 1072, 1120, 16115, 265, 16115, 1072, 1120, 19, 1072, 169, 1, 1392, 1392, 16115, 252, 213, 1, 108, 99, 1230, 92, 2, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116, 16116]] - feed_dict[m.input_word] = total_word
train_step(temp_word, temp_pos1, temp_pos2, temp_y, settings.big_num),所以temp_word等价于word_batch,
temp_input = temp_order[i * settings.big_num:(i + 1) * settings.big_num] for k in temp_input: temp_word.append(train_word[k]) #随机选big_num个实体对的train_word即为temp_word for i in range(len(word_batch)): # 每个batch里面有多少个实体对 total_shape.append(total_num) # 【0, 3, 5, 10, 20, 30】 total_num += len(word_batch[i]) for word in word_batch[i]: total_word.append(word) #label-sentence的word index,total_word是二维的,第二维维度为70 feed_dict[m.input_word] = total_word - big_num,指的是每个batch里面用多少个实体对做训练。
- total_shape,指的是在该batch中,每个元素为前一个数加上每个实体对对应有几个label-sentence(即前缀和)。它的索引代表的是第几个实体对。
- total_num,指的是在该batch中,所有实体对的label-sentence的总个数。它是total_shape的最后一项。
1.1 init函数
下图为vec.txt的部分数据,其中第一行第一个数为单词的个数,第二个是词嵌入的维度。后面的行最左边的是词,后面的是词对应的词向量或者词嵌入。

首先只读取第一行,得到词嵌入的维度。后续通过while循环,得到词嵌入表示。其中strip默认去除字符串首尾的空白符,split默认按空白符进行分隔。
f = open('./origin_data/vec.txt', encoding='utf-8') content = f.readline() content = content.strip().split() dim = int(content[1]) #embedding dimension of each word,在这里指的是100 while True: content = f.readline() if content == '': break content = content.strip().split() word2id[content[0]] = len(word2id) #插入字典,这是由于python语言默认是从0开始的,代表区间为[0, len(word2id) - 1],所以len(word2id)不在id的顺序列中。 content = content[1:] #取向量 content = [(float)(i) for i in content] #类型转换 vec.append(content) #每个词的词向量表示 f.close() 添加UNK和BLANK两个词向量,对应的物理含义是未知和填充。
word2id['UNK'] = len(word2id) #添加UNK,unknown word2id['BLANK'] = len(word2id) #BLANK指的是padding vec.append(np.random.normal(size=dim, loc=0, scale=0.05)) #正态分布初始化UNK vec.append(np.random.normal(size=dim, loc=0, scale=0.05)) #正态分布初始化BLANK vec = np.array(vec, dtype=np.float32) #数据类型转换 读取relation2id:

relation2id = {} f = open('./origin_data/relation2id.txt', 'r', encoding='utf-8') while True: content = f.readline() if content == '': break content = content.strip().split() relation2id[content[0]] = int(content[1]) f.close() 设置句子的最大长度和词距离实体位置的向量
# length of sentence is 70 fixlen = 70 # max length of position embedding is 60 (-60~+60) maxlen = 60 train_sen = {} # {entity pair:[[[label1-sentence 1],[label1-sentence 2]...],[[label2-sentence 1],[label2-sentence 2]...]} train_ans = {} # {entity pair:[[label1],[label2],...]} the label is one-hot vector 读取训练数据,即得到train_sen和train_ans。

# find the index of x in y, if x not in y, return -1 def find_index(x, y): flag = -1 for i in range(len(y)): if x != y[i]: continue else: return i return flag print('reading train data...') f = open('./origin_data/train.txt', 'r', encoding='utf-8') while True: content = f.readline() if content == '': break content = content.strip().split() # get entity name en1 = content[0] en2 = content[1] relation = 0 if content[2] not in relation2id: relation = relation2id['NA'] #这里不如使用assert语句 else: relation = relation2id[content[2]] #得到数字化后的relation # put the same entity pair sentences into a dict tup = (en1, en2) label_tag = 0 if tup not in train_sen: #如果实体对不在train_sen的key中 train_sen[tup] = [] train_sen[tup].append([]) #构建空的list of list即为{(entity1, entity2):[[]]} y_id = relation label_tag = 0 label = [0 for i in range(len(relation2id))] label[y_id] = 1 #这两步本质上是one-hot表示。 train_ans[tup] = [] train_ans[tup].append(label) #在train_ans中添加one-hot label else:#如果实体对在train_sen的key中 y_id = relation label_tag = 0 label = [0 for i in range(len(relation2id))] label[y_id] = 1 temp = find_index(label, train_ans[tup]) #寻找label在train_ans[tup]中位置,由于是复合元素,所以不能直接使用index函数。 if temp == -1: #如果没有寻找到 train_ans[tup].append(label) #添加label label_tag = len(train_ans[tup]) - 1 #label_tag表示添加的label所处list中的index train_sen[tup].append([]) #train_sen再添加一个[] else: label_tag = temp sentence = content[3] en1pos = 0 en2pos = 0 #For Chinese en1pos = sentence.find(en1) if en1pos == -1: en1pos = 0 en2pos = sentence.find(en2) if en2pos == -1: en2post = 0 output = [] #output也为list of list,其中list元素为三元组,(word, rel_e1, rel_e2),其中word为词嵌入表示,后两者为词距离实体位置的嵌入。 #Embeding the position for i in range(fixlen): #句子的最大长度,从而实现了短的句子填充,长的句子截断。 word = word2id['BLANK'] #类似于初始化一个列表,让所有元素为0 rel_e1 = pos_embed(i - en1pos) rel_e2 = pos_embed(i - en2pos) output.append([word, rel_e1, rel_e2]) for i in range(min(fixlen, len(sentence))): word = 0 if sentence[i] not in word2id: word = word2id['UNK'] else: word = word2id[sentence[i]] 
output[i][0] = word train_sen[tup][label_tag].append(output) #[[]].append(句子的词嵌入的vector),如果label已经存在,则在已有的列表中进行append操作。 读取test.txt,操作和之前类似,就不再赘述。注意train_x、train_y和test_x、test_y的区别。

test_sen = {} # {entity pair:[[sentence 1],[sentence 2]...]} test_ans = {} # {entity pair:[labels,...]} the labels is N-hot vector (N is the number of multi-label) f = open('./origin_data/test.txt', 'r', encoding='utf-8') while True: content = f.readline() if content == '': break content = content.strip().split() en1 = content[0] en2 = content[1] relation = 0 if content[2] not in relation2id: relation = relation2id['NA'] else: relation = relation2id[content[2]] tup = (en1, en2) if tup not in test_sen: test_sen[tup] = [] y_id = relation label_tag = 0 label = [0 for i in range(len(relation2id))] label[y_id] = 1 test_ans[tup] = label else: y_id = relation test_ans[tup][y_id] = 1 sentence = content[3] en1pos = 0 en2pos = 0 #For Chinese en1pos = sentence.find(en1) if en1pos == -1: en1pos = 0 en2pos = sentence.find(en2) if en2pos == -1: en2post = 0 output = [] for i in range(fixlen): word = word2id['BLANK'] rel_e1 = pos_embed(i - en1pos) rel_e2 = pos_embed(i - en2pos) output.append([word, rel_e1, rel_e2]) for i in range(min(fixlen, len(sentence))): word = 0 if sentence[i] not in word2id: word = word2id['UNK'] else: word = word2id[sentence[i]] output[i][0] = word test_sen[tup].append(output) train_x = [] #把字典进行列表表示 train_y = [] test_x = [] test_y = [] print('organizing train data') f = open('./data/train_q&a.txt', 'w', encoding='utf-8') temp = 0 for i in train_sen: if len(train_ans[i]) != len(train_sen[i]): #label对应的长度都一致 print('ERROR') lenth = len(train_ans[i]) #实体对中共有多少种关系 for j in range(lenth): train_x.append(train_sen[i][j]) #append的是list of list,类似于[[label1-sentence 1],[label1-sentence 2]...] 
train_y.append(train_ans[i][j]) f.write(str(temp) + '\t' + i[0] + '\t' + i[1] + '\t' + str(np.argmax(train_ans[i][j])) + '\n') # str(np.argmax(train_ans[i][j]))表示的第几个关系 temp += 1 f.close() print('organizing test data') f = open('./data/test_q&a.txt', 'w', encoding='utf-8') temp = 0 for i in test_sen: test_x.append(test_sen[i]) test_y.append(test_ans[i]) tempstr = '' for j in range(len(test_ans[i])): if test_ans[i][j] != 0: tempstr = tempstr + str(j) + '\t' f.write(str(temp) + '\t' + i[0] + '\t' + i[1] + '\t' + tempstr + '\n') temp += 1 f.close() train_x = np.array(train_x) train_y = np.array(train_y) test_x = np.array(test_x) test_y = np.array(test_y) np.save('./data/vec.npy', vec) np.save('./data/train_x.npy', train_x) np.save('./data/train_y.npy', train_y) np.save('./data/testall_x.npy', test_x) np.save('./data/testall_y.npy', test_y) def seperate(): print('reading training data') x_train = np.load('./data/train_x.npy') train_word = [] train_pos1 = [] train_pos2 = [] print('seprating train data') for i in range(len(x_train)):#每个实体对 word = [] pos1 = [] pos2 = [] for j in x_train[i]: #每个实体对下面的每个句子 temp_word = [] temp_pos1 = [] temp_pos2 = [] for k in j: #每个实体对下面的每个单词,其中共有70个单词 temp_word.append(k[0]) temp_pos1.append(k[1]) temp_pos2.append(k[2]) word.append(temp_word) pos1.append(temp_pos1) pos2.append(temp_pos2) train_word.append(word) train_pos1.append(pos1) train_pos2.append(pos2) train_word = np.array(train_word) train_pos1 = np.array(train_pos1) train_pos2 = np.array(train_pos2) np.save('./data/train_word.npy', train_word) np.save('./data/train_pos1.npy', train_pos1) np.save('./data/train_pos2.npy', train_pos2) print('seperating test all data') x_test = np.load('./data/testall_x.npy') test_word = [] #为什么元素是object类型呢?数据长度不一致 test_pos1 = [] test_pos2 = [] for i in range(len(x_test)): word = [] pos1 = [] pos2 = [] for j in x_test[i]: temp_word = [] temp_pos1 = [] temp_pos2 = [] for k in j: temp_word.append(k[0]) temp_pos1.append(k[1]) temp_pos2.append(k[2]) 
word.append(temp_word) pos1.append(temp_pos1) pos2.append(temp_pos2) test_word.append(word) test_pos1.append(pos1) test_pos2.append(pos2) test_word = np.array(test_word) test_pos1 = np.array(test_pos1) test_pos2 = np.array(test_pos2) np.save('./data/testall_word.npy', test_word) np.save('./data/testall_pos1.npy', test_pos1) np.save('./data/testall_pos2.npy', test_pos2) # get answer metric for PR curve evaluation def getans(): test_y = np.load('./data/testall_y.npy') eval_y = [] for i in test_y: #去除了unknown eval_y.append(i[1:]) # allans = np.reshape(eval_y, (-1)) np.save('./data/allans.npy', allans) def get_metadata(): #逐行写入word fwrite = open('./data/metadata.tsv', 'w', encoding='utf-8') f = open('./origin_data/vec.txt', encoding='utf-8') f.readline() while True: content = f.readline().strip() if content == '': break name = content.split()[0] fwrite.write(name + '\n') f.close() fwrite.close()