r"""
Use Bert to generate words.
==========================================================
"""
__all__ = ["CnMLM"]
import torch
from transformers import BertTokenizer, BertForMaskedLM
from ..transformation import Transformation
from ....input_layer.component.field.cn_text_field import CnTextField
from ....common.preprocess.cn_processor import CnProcessor
from ....common.settings import ORIGIN, MODIFIED_MASK
class CnMLM(Transformation):
    r"""
    Use a Chinese BERT masked language model to generate replacement words.

    A single character of the sentence is replaced by two ``[MASK]`` tokens
    and BERT's predictions for those two positions are spliced back in, so a
    successful transformation makes the sentence one character longer.

    Example::

        小明喜欢看书 -> 小明喜欢看报纸

    """
    # Lazily populated caches shared by every instance.  Loading the BERT
    # weights / the LTP pipeline is expensive, so it is done at most once per
    # process instead of once per candidate position.
    _bert = None
    _ltp = None

    def __init__(self, **kwargs):
        r"""
        :param **kwargs: accepted for interface compatibility; not used here.

        """
        super().__init__()

    def __repr__(self):
        return 'CnMLM'

    def _transform(self, sample, n=1, **kwargs):
        r"""
        In this function, because there is only one deformation mode, only one
        set of outputs is output.

        :param ~textflint.CWSSample sample: the data which need be changed
        :param int n: not used; at most one sample is ever returned
        :param **kwargs: not used
        :return: a list holding the single transformed sample, or an empty
            list when the sentence could not be changed

        """
        # get sentence label and pos tag
        origin_sentence = sample.get_value('x')
        origin_label = sample.get_value('y')
        pos_tags = sample.pos_tags

        x, y, mask = self._get_transformations(
            origin_sentence, origin_label, pos_tags, sample.mask)
        if x == origin_sentence:
            return []

        return [sample.update(CnTextField(x, mask), y)]

    @classmethod
    def _load_bert(cls):
        r"""
        Return the shared ``(tokenizer, model)`` pair, loading it on first use.

        """
        if cls._bert is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
            model = BertForMaskedLM.from_pretrained('bert-base-chinese')
            model.eval()
            cls._bert = (tokenizer, model)
        return cls._bert

    def create_word(self, sentence):
        r"""
        Create the two-character word that fills the masked positions.

        :param str sentence: space separated characters in which the first
            ``[MASK]`` is immediately followed by a second ``[MASK]``
        :return: the two predicted characters concatenated, or ``''`` when
            either prediction is not a single character or the pair is
            rejected by :meth:`is_word`
        :raises ValueError: if the tokenized sentence contains no ``[MASK]``

        """
        tokenizer, model = self._load_bert()

        tokenized_text = tokenizer.tokenize('[CLS] ' + sentence)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

        # Convert inputs to PyTorch tensors; everything is in segment 0.
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([[0] * len(tokenized_text)])

        masked_index = tokenized_text.index('[MASK]')

        # Predict all tokens.  The segment ids must be passed by keyword: the
        # second positional parameter of the model's forward() is
        # ``attention_mask``, and an all-zero attention mask would flag every
        # token as padding.
        with torch.no_grad():
            predictions = model(
                tokens_tensor, token_type_ids=segments_tensors)
        logits = predictions[0][0]

        predicted = []
        for position in (masked_index, masked_index + 1):
            best_id = torch.argmax(logits[position]).item()
            predicted.append(tokenizer.convert_ids_to_tokens([best_id])[0])
        first, second = predicted
        candidate = first + second

        # Determine whether the generated words meet the requirements
        if len(first) != 1 or len(second) != 1 or self.is_word(candidate):
            return ''
        return candidate

    @staticmethod
    def _mask_char(sentence, index):
        r"""
        Space-separate ``sentence`` and put two ``[MASK]`` tokens at ``index``.

        """
        return ''.join(
            '[MASK] [MASK] ' if pos == index else char + ' '
            for pos, char in enumerate(sentence))

    def _get_transformations(self, sentence, label, pos_tags, mask):
        r"""
        Generate word function.

        :param str sentence: chinese sentence
        :param list label: Chinese word segmentation tag
        :param list pos_tags: sentence's pos tag, each item unpacking to
            ``(tag, start, end)``
        :param list mask: per-character modification mask
        :return: the (possibly changed) sentence, labels and mask

        """
        assert len(sentence) == len(label)
        # Every replacement lengthens the sentence by one character, so the
        # offsets stored in pos_tags are shifted by the number of
        # replacements made so far.
        cnt = 0
        last = len(pos_tags) - 1

        for i, (tag, start, end) in enumerate(pos_tags):
            start += cnt
            end += cnt
            # find the pos that can generate word
            # Situation 1: v + single n
            # we generate double n replace single n
            if label[start:start + 2] == ['B', 'E'] and \
                    i < last and tag == 'v' and end == start + 1 and \
                    self.check_part_pos(sentence[start + 1]):
                target = start + 1
                # Run the cheap mask check before the expensive BERT call.
                if self.check(start, end, mask):
                    change = self.create_word(
                        self._mask_char(sentence, target))
                    if change != '':
                        sentence = sentence[:target] + \
                            change + sentence[target + 1:]
                        label = label[:start] + \
                            ['S', 'B', 'E'] + label[start + 2:]
                        mask = mask[:target] + \
                            [MODIFIED_MASK] * 2 + mask[target + 1:]
                        cnt += 1
            # Situation 2: n + n + n
            # we generate double n replace single n and split one word
            # into two
            elif label[start:start + 3] == ['B', 'M', 'E'] and \
                    tag == 'n' and end - start == 2:
                # Only the last character of the word is replaced.
                target = start + 2
                if self.check(target, end, mask):
                    change = self.create_word(
                        self._mask_char(sentence, target))
                    if change != '':
                        sentence = sentence[:target] + \
                            change + sentence[target + 1:]
                        label = label[:target - 1] + \
                            ['E', 'B', 'E'] + label[target + 1:]
                        mask = mask[:target] + \
                            [MODIFIED_MASK] * 2 + mask[target + 1:]
                        cnt += 1

        return sentence, label, mask

    @staticmethod
    def is_word(sentence):
        r"""
        Decide whether a generated two-character candidate must be rejected.

        Despite the name, a ``True`` result means the candidate is discarded
        by :meth:`create_word`.  ``False`` is returned only when LTP keeps
        the candidate as one segment whose pos tag is ``'n'``.

        :param str sentence: the candidate; its first two characters are
            compared, so it needs at least two characters
        :return bool: ``True`` when both characters are identical or the
            candidate is not a single noun, ``False`` otherwise

        """
        if sentence[0] == sentence[1]:
            return True

        if CnMLM._ltp is None:
            # Imported here so that ltp is only required when it is used.
            from ltp import LTP
            CnMLM._ltp = LTP()
        ltp = CnMLM._ltp

        _, hidden = ltp.seg([sentence])
        pos = ltp.pos(hidden)[0]
        return not (len(pos) == 1 and pos[0] == 'n')

    @staticmethod
    def check(start, end, mask):
        r"""
        Check that every position in ``start..end`` (inclusive) is unmodified.

        :param int start: first index to inspect
        :param int end: last index to inspect
        :param list mask: per-character modification mask
        :return bool: ``True`` when all inspected entries equal ``ORIGIN``

        """
        return all(mask[i] == ORIGIN for i in range(start, end + 1))

    @staticmethod
    def check_part_pos(sentence):
        r"""
        Check whether the given text is tagged as a single noun.

        :param str sentence: origin word
        :return bool: ``True`` when exactly one pos tag is returned and it
            starts with ``'n'``; ``False`` otherwise or for an empty string

        """
        if sentence == "":
            return False
        processor = CnProcessor()
        pos = processor.get_pos_tag(sentence)
        return len(pos) == 1 and pos[0][0] == 'n'