-
Notifications
You must be signed in to change notification settings - Fork 3
/
models.py
50 lines (47 loc) · 1.89 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# -*- coding: utf-8 -*-
# @Time : 2019/8/18 0:37
# @Author : kean
# @Email : ?
# @File : model.py
# @Software: PyCharm
import tensorflow as tf
from bert_base.modeling import BertModel
from layers import BLSTM_CRF
from tensorflow.contrib.layers.python.layers import initializers
def create_model(bert_config, is_training, input_ids, input_mask,
                 segment_ids, labels, num_labels, use_one_hot_embeddings,
                 dropout_rate=1.0, lstm_size=1, cell='lstm', num_layers=1):
    """Build the BERT -> (BiLSTM) -> CRF sequence-labeling graph.

    :param bert_config: BertConfig for the pretrained BERT checkpoint
    :param is_training: bool, enables dropout inside BERT and the downstream layer
    :param input_ids: int Tensor [batch_size, seq_length], token-id representation
        of the input; padding positions are assumed to hold id 0 (used below to
        recover true sequence lengths)
    :param input_mask: int Tensor [batch_size, seq_length], 1 = real token, 0 = pad
    :param segment_ids: int Tensor [batch_size, seq_length], token-type ids
    :param labels: int Tensor [batch_size, seq_length], label-id representation
    :param num_labels: number of distinct output labels (CRF tag-set size)
    :param use_one_hot_embeddings: forwarded to BertModel's embedding lookup
    :param dropout_rate: forwarded to BLSTM_CRF (default 1.0 — presumably a keep
        probability rather than a drop probability; confirm against layers.BLSTM_CRF)
    :param lstm_size: hidden-unit count for the (optional) BiLSTM layer
    :param cell: RNN cell type for BLSTM_CRF, e.g. 'lstm'
    :param num_layers: number of stacked RNN layers in BLSTM_CRF
    :return: result of BLSTM_CRF.add_blstm_crf_layer(crf_only=True)
        (loss/prediction tensors as defined by the project's layers module)
    """
    # Run the pretrained BERT encoder over the inputs to obtain contextual
    # token embeddings.
    model = BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings
    )
    # Per-token output: [batch_size, seq_length, embedding_size]
    embedding = model.get_sequence_output()
    # Static sequence length from the graph shape (TF1-style Dimension.value).
    max_seq_length = embedding.shape[1].value
    # True (unpadded) length of each sequence: count of non-zero token ids.
    # NOTE: relies on the padding token having id 0.
    used = tf.sign(tf.abs(input_ids))
    # [batch_size] vector of actual sequence lengths.
    # Fix: `reduction_indices` is a deprecated alias — use `axis`.
    lengths = tf.reduce_sum(used, axis=1)
    # Stack the (BiLSTM-)CRF tagging layer on top of the BERT outputs.
    blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=lstm_size, cell_type=cell, num_layers=num_layers,
                          dropout_rate=dropout_rate, initializers=initializers, num_labels=num_labels,
                          seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
    # crf_only=True: skip the BiLSTM and apply only the CRF output layer.
    rst = blstm_crf.add_blstm_crf_layer(crf_only=True)
    return rst