基于BERT+LSTM+CRF的命名实体识别

定义参数

class CommonConfig:
    """Paths shared by the components of the project.

    All values are plain class attributes, so they can be read either from
    the class itself or from an instance.
    """

    # Pretrained model identifier handed to ``from_pretrained``.
    bert_dir = "hfl/chinese-macbert-base"
    # Directory that receives checkpoints and the dumped run arguments.
    output_dir = "./checkpoint/"
    # Directory holding the dataset files.
    data_dir = "./data"

class NerConfig:
    """Hyper-parameters, label vocabulary and paths for the NER model.

    Constructing an instance has one side effect: the output directory is
    created if it does not exist yet.
    """

    def __init__(self):
        cf = CommonConfig()
        self.bert_dir = cf.bert_dir
        self.output_dir = cf.output_dir
        # makedirs(exist_ok=True) also creates missing parent directories and
        # does not fail when the directory already exists (or is created by
        # another process between a check and the creation).
        os.makedirs(self.output_dir, exist_ok=True)
        self.data_dir = cf.data_dir

        # BIO tag set; a tag's position in this list is its integer id.
        labels_list = ['O', 'B-HCCX', 'I-HCCX', 'B-HPPX', 'I-HPPX', 'B-XH', 'I-XH', 'B-MISC', 'I-MISC']
        self.num_labels = len(labels_list)
        self.label2id = {l: i for i, l in enumerate(labels_list)}
        self.id2label = {i: l for i, l in enumerate(labels_list)}

        # Sequence length, schedule and optimizer settings.
        self.max_seq_len = 512
        self.epochs = 3
        self.train_batch_size = 64
        self.dev_batch_size = 64
        self.bert_learning_rate = 3e-5
        self.crf_learning_rate = 3e-3
        self.adam_epsilon = 1e-8
        self.weight_decay = 0.01
        self.warmup_proportion = 0.01
        # A checkpoint is written every ``save_step`` optimizer steps.
        self.save_step = 500

加载参数

args = NerConfig()

# Store the run configuration next to the checkpoints. The file is opened as
# UTF-8 explicitly because ensure_ascii=False writes non-ASCII characters
# verbatim, and the platform default encoding may not be able to encode them.
with open(os.path.join(args.output_dir, "ner_args.json"), "w", encoding="utf-8") as fp:
    json.dump(vars(args), fp, ensure_ascii=False, indent=2)

定义DataLoader

class NerDataset(Dataset):
    """Dataset that tokenizes pre-split sentences and aligns their labels.

    Args:
        data: Sequence of examples; each example is a mapping with a
            ``"tokens"`` list and a parallel ``"label"`` list of label ids.
        args: Configuration object providing ``label2id`` and ``max_seq_len``.
        tokenizer: Callable tokenizer whose result exposes
            ``word_ids(batch_index=0)`` and supports item assignment.
    """

    def __init__(self, data, args, tokenizer):
        self.data = data
        self.args = args
        self.tokenizer = tokenizer
        self.label2id = args.label2id
        self.max_seq_len = args.max_seq_len

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, item):
        """Tokenize example ``item`` and return its fields as tensors.

        The tokenizer output gains a ``"labels"`` entry with one label id per
        token position, then every entry is converted to a ``torch`` tensor.
        """
        example = self.data[item]
        encoding = self.tokenizer(
            example["tokens"],
            max_length=self.max_seq_len,
            truncation=True,
            padding="max_length",
            is_split_into_words=True,
        )

        # Positions that map to no input word (word id None) receive the
        # outside tag. Its id is looked up in label2id instead of assuming it
        # is 0; 0 remains the fallback when the vocabulary has no "O" entry.
        outside_id = self.label2id.get("O", 0)
        word_labels = example["label"]
        encoding["labels"] = [
            outside_id if word_id is None else word_labels[word_id]
            for word_id in encoding.word_ids(batch_index=0)
        ]

        # Snapshot the keys so values can be replaced while looping.
        for key in list(encoding.keys()):
            encoding[key] = torch.tensor(np.array(encoding[key]))

        return encoding

定义读取数据的方法

def read_data(file, label2id):
    """Read a token-per-line NER file into a list of sentence examples.

    Each non-empty line holds a token and its tag separated by a single
    space; an empty line ends the current sentence.

    Args:
        file: Path of the data file, read as UTF-8.
        label2id: Mapping from tag string to integer label id.

    Returns:
        A list of dicts with the keys ``"id"`` (running index as a string),
        ``"tokens"``, ``"ner_tags"`` (tag strings) and ``"label"`` (tag ids).

    Raises:
        IndexError: If a non-empty line has no second space-separated field.
        KeyError: If a tag is missing from ``label2id``.
    """
    examples = []
    tokens, ner_tags, ner_labels = [], [], []

    def _flush():
        # Close the sentence being accumulated. Nothing is emitted for an
        # empty buffer, so repeated blank lines do not yield empty examples.
        nonlocal tokens, ner_tags, ner_labels
        if not tokens:
            return
        examples.append({
            "id": str(len(examples)),
            "tokens": tokens,
            "ner_tags": ner_tags,
            "label": ner_labels,
        })
        tokens, ner_tags, ner_labels = [], [], []

    # Explicit encoding: the data is Chinese text and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines):
        line = line.rstrip("\n")
        if not line:
            _flush()
            continue
        fields = line.split(" ")
        tokens.append(fields[0])
        ner_tags.append(fields[1])
        ner_labels.append(label2id[fields[1]])

    # Keep the final sentence when the file does not end with a blank line.
    _flush()
    return examples

处理数据

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(args.bert_dir)

# Parse both splits, then wrap them in datasets that tokenize on access.
train_examples = read_data("../origin/train.txt", args.label2id)
dev_examples = read_data("../origin/dev.txt", args.label2id)
train_dataset = NerDataset(train_examples, args, tokenizer=tokenizer)
dev_dataset = NerDataset(dev_examples, args, tokenizer=tokenizer)
# No separate test split is loaded here.

# Only the training batches are shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=2,
)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=args.dev_batch_size,
    shuffle=False,
    num_workers=2,
)

定义模型

class BertNer(nn.Module):
    """BERT encoder followed by a BiLSTM, a linear tag projection and a CRF.

    Args:
        args: Configuration object providing ``bert_dir``, ``max_seq_len``
            and ``num_labels``.
    """

    def __init__(self, args):
        super().__init__()
        self.bert = BertModel.from_pretrained(args.bert_dir)
        # Reuse the configuration already loaded with the encoder instead of
        # reading it from ``bert_dir`` a second time.
        self.bert_config = self.bert.config
        self.config = self.bert_config
        hidden_size = self.bert_config.hidden_size
        self.lstm_hiden = 128
        # Kept as an attribute for existing callers; forward() no longer
        # depends on it.
        self.max_seq_len = args.max_seq_len
        # No ``dropout`` argument: nn.LSTM applies dropout only between
        # stacked layers, so with a single layer it does nothing except emit
        # a warning.
        self.bilstm = nn.LSTM(hidden_size, self.lstm_hiden, 1, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(self.lstm_hiden * 2, args.num_labels)
        self.crf = CRF(args.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Decode tag sequences and, when labels are given, compute the loss.

        Args:
            input_ids: Token ids fed to the BERT encoder.
            attention_mask: Mask passed to BERT and, cast to bool, to the CRF.
            labels: Optional gold label ids passed to the CRF.

        Returns:
            A tuple ``(logits, labels, loss)``. ``logits`` is the result of
            ``crf.decode`` (decoded tags, not raw scores), ``labels`` is the
            argument passed in, and ``loss`` is the negated CRF output with
            ``reduction='token_mean'`` or ``None`` without labels.
        """
        mask = attention_mask.bool()
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = bert_output[0]
        seq_out, _ = self.bilstm(seq_out)
        # nn.Linear acts on the last dimension, so the 3-D BiLSTM output is
        # projected directly. Reshaping to a fixed max_seq_len would fail for
        # any batch whose sequence length differs from it.
        seq_out = self.linear(seq_out)
        # Decoding yields plain Python lists, so no gradient graph is needed.
        with torch.no_grad():
            logits = self.crf.decode(seq_out, mask=mask)
        loss = None
        if labels is not None:
            loss = -self.crf(seq_out, labels, mask=mask, reduction='token_mean')
        return (logits, labels, loss)

加载模型

# Build the network and move it to the selected device in one expression;
# nn.Module.to returns the module itself.
model = BertNer(args).to(device)

定义Trainer

这里只介绍主要的训练（train）和测试（test）方法；Trainer 的构造函数以及优化器、学习率调度器的构建等其余代码见后文的代码地址。

class Trainer:
    """Training and evaluation loops for the NER model.

    Only ``train`` and ``test`` are defined in this excerpt; the constructor
    is not shown. The methods read these instance attributes: ``model``,
    ``device``, ``optimizer``, ``schedule``, ``train_loader``,
    ``test_loader``, ``epochs``, ``total_step``, ``save_step``,
    ``output_dir`` and ``id2label``.
    """

    def train(self):
        """Train for ``self.epochs`` epochs and save the model weights.

        The state dict is written to ``pytorch_model_ner.bin`` in
        ``self.output_dir`` every ``self.save_step`` optimizer steps and once
        more after the last epoch. One log line is printed per epoch with the
        loss of that epoch's last batch.
        """
        checkpoint_path = os.path.join(self.output_dir, "pytorch_model_ner.bin")
        # Counts completed optimizer steps. Starting at 0 keeps the logged
        # progress from ending one past ``total_step``.
        global_step = 0
        last_loss = None
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_data in tqdm(self.train_loader):
                batch = {key: value.to(self.device) for key, value in batch_data.items()}
                _, _, loss = self.model(batch["input_ids"], batch["attention_mask"], batch["labels"])
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.schedule.step()
                global_step += 1
                # Keep only the detached value so the graph is not retained.
                last_loss = loss.detach()
                if global_step % self.save_step == 0:
                    torch.save(self.model.state_dict(), checkpoint_path)
            # An empty loader leaves no loss to report; print NaN, not crash.
            loss_value = last_loss.item() if last_loss is not None else float("nan")
            print(f"【train】{epoch}/{self.epochs} {global_step}/{self.total_step} loss:{loss_value}")

        torch.save(self.model.state_dict(), checkpoint_path)

    def test(self):
        """Evaluate the saved checkpoint on ``self.test_loader``.

        Returns:
            The value returned by ``classification_report(trues, preds)``,
            where both arguments are lists of per-sentence tag-string lists.
        """
        checkpoint_path = os.path.join(self.output_dir, "pytorch_model_ner.bin")
        # map_location lets weights saved on one device load on another,
        # e.g. a GPU checkpoint on a CPU-only machine.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        self.model.eval()
        preds = []
        trues = []
        # Evaluation needs no gradients.
        with torch.no_grad():
            for batch_data in tqdm(self.test_loader):
                batch = {key: value.to(self.device) for key, value in batch_data.items()}
                attention_mask = batch["attention_mask"]
                labels = batch["labels"]
                decoded, _, _ = self.model(batch["input_ids"], attention_mask, labels)

                lengths = attention_mask.sum(dim=1).tolist()
                gold_rows = labels.cpu().tolist()
                for path, gold, length in zip(decoded, gold_rows, lengths):
                    # Score only the real tokens: index 0 is the leading
                    # special token and index length - 1 the trailing one
                    # (assuming a BERT-style [CLS] ... [SEP] layout).
                    preds.append([self.id2label[tag] for tag in path[1:length - 1]])
                    trues.append([self.id2label[tag] for tag in gold[1:length - 1]])

        report = classification_report(trues, preds)
        return report

使用Trainer传入参数

# Total number of optimizer steps: batches per epoch times number of epochs.
total_training_steps = len(train_loader) * args.epochs

# The dev split is used both as dev_loader and as test_loader.
train = Trainer(
    model=model,
    device=device,
    train_loader=train_loader,
    dev_loader=dev_loader,
    test_loader=dev_loader,
    epochs=args.epochs,
    t_total_num=total_training_steps,
    id2label=args.id2label,
    output_dir=args.output_dir,
    optimizer_args=args,
)

训练

# Run the training loop; it prints one progress line per epoch and saves the model weights.
train.train()
【train】1/3 95/282 loss:0.32739126682281494
【train】2/3 189/282 loss:0.2356387972831726
【train】3/3 283/282 loss:0.21381179988384247

评估

# Reload the saved checkpoint, evaluate it on the test loader (the dev split
# here) and print the resulting classification report.
report = train.test()
print(report)
precision recall f1-score support
HCCX 0.83 0.84 0.83
HPPX 0.78 0.79 0.78
MISC 0.76 0.81 0.79
XH 0.71 0.79 0.75
micro avg 0.81 0.83 0.82
macro avg 0.77 0.81 0.79
weighted avg 0.81 0.83 0.82

推理预测

预测函数见完整代码中的 Predictor 类。

# Predictor lives in the project's ``predict`` module, which is not shown here.
from predict import Predictor

predictor=Predictor()
# Sample product title to run entity extraction on.
text="川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购"
# NOTE(review): judging from the sample output, the result maps each entity
# type to a list of (text, start, end) spans with an inclusive end index —
# confirm against Predictor.ner_predict.
ner_result = predictor.ner_predict(text)
print("文本>>>>>:", text)
print("实体>>>>>:", ner_result)
输出结果：
文本>>>>>: 川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购
实体>>>>>: {'HPPX': [('川珍', 0, 1)], 'HCCX': [('香菇', 5, 6), ('干货', 7, 8), ('干香菇', 13, 15)], 'MISC': [('500g', 16, 19)]}

代码地址

基于BERT+LSTM+CRF的命名实体识别

更多内容