r"""
Swapping words by Mask Language Model
==========================================================
"""
__all__ = ['MLMSuggestion']
import torch
from copy import copy
from collections import defaultdict
from .back_trans import BackTrans
from ....common import device as default_device
from ...transformation.word_substitute import WordSubstitute
from ....common.settings import BERT_MODEL_NAME
from ....common.utils.list_op import trade_off_sub_words
[docs]class MLMSuggestion(WordSubstitute):
r"""
Transforms an input by replacing its tokens with words of mask language
predicted.
To accelerate transformation for long text, input single sentence to
language model rather than whole text.
"""
[docs] def __init__(
self,
masked_model=None,
device=None,
accrue_threshold=1,
max_sent_size=100,
trans_min=1,
trans_max=10,
trans_p=0.2,
stop_words=None,
**kwargs
):
r"""
:param str masked_model: masked language model to predicate candidates
:param str device: indicate utilize cpu or which gpu device to run
neural network
:param int accrue_threshold: threshold of Bert results to pick
:param max_sent_size: max_sent_size
:param int trans_min: Minimum number of character will be augmented.
:param int trans_max: Maximum number of character will be augmented.
If None is passed, number of augmentation is calculated via aup_char_p.
If calculated result from aug_p is smaller than aug_max, will use
calculated result from aup_char_p. Otherwise, using aug_max.
:param float trans_p: Percentage of character (per token) will be
augmented.
:param list stop_words: List of words which will be skipped from augment
operation.
"""
super().__init__(
trans_min=trans_min,
trans_max=trans_max,
trans_p=trans_p,
stop_words=stop_words
)
self.device = BackTrans.get_device(device) if device else default_device
self.max_sent_size = max_sent_size
self.get_pos = True
self.accrue_threshold = accrue_threshold
self.masked_model = masked_model if masked_model else BERT_MODEL_NAME
self.tokenizer = None
self.model = None
self.pos_allowed_token_id = None
def __repr__(self):
return 'MLMSuggestion'
[docs] def get_model(self):
r"""
Loads masked language model to predict candidates.
"""
from transformers import BertTokenizer, BertForMaskedLM
self.tokenizer = BertTokenizer.from_pretrained(
self.masked_model, do_lower_case=False)
self.model = BertForMaskedLM.from_pretrained(self.masked_model)
self.model.to(self.device)
self.model.eval()
[docs] def pre_calculate_allowed_tokens(self):
r"""
Precalculate meaningful tokens, filter tokens which is not an alphabetic
string.
Pre filter would accelerate procedure of verifying pos tags of
candidates.
"""
pos_to_token_id_dict = defaultdict(list)
bert_tokens = self.tokenizer.convert_ids_to_tokens(
list(range(self.tokenizer.vocab_size)))
bert_token_pos = self.processor.get_pos(bert_tokens)
for index, pos in enumerate(bert_token_pos):
word = bert_tokens[index]
if not word.isalpha() or len(word) == 1:
continue
pos_to_token_id_dict[pos[1][:2]].append(index)
for pos, indices in pos_to_token_id_dict.items():
pos_to_token_id_dict[pos] \
= torch.tensor(indices, dtype=torch.long, device=self.device)
return pos_to_token_id_dict
def _transform(self, sample, field='x', n=1, **kwargs):
r"""
Transform text string according field.
:param Sample sample: input data, normally one data component.
:param str field: indicate which field to apply transformation
:param int n: number of generated samples
:param kwargs:
:return list trans_samples: transformed sample list.
"""
if not self.pos_allowed_token_id:
self.get_model()
self.pos_allowed_token_id = self.pre_calculate_allowed_tokens()
tokens = sample.get_words(field)
tokens_mask = sample.get_mask(field)
# accelerate computation for long text
sentences_tokens = sample.get_sentences(field) if len(
tokens) > self.max_sent_size else [tokens]
# return up to (len(sub_indices) * n) candidates
pos_info = sample.get_pos(field)
legal_indices = self.skip_aug(tokens, tokens_mask)
if not legal_indices:
return []
sub_words, sub_indices = self._get_substitute_words(
tokens, legal_indices, sentences_tokens, pos=pos_info, n=n)
# select property candidates
sub_words, sub_indices = trade_off_sub_words(sub_words, sub_indices,
n=n)
if not sub_words:
return []
trans_samples = []
for i in range(len(sub_words)):
single_sub_words = sub_words[i]
trans_samples.append(
sample.replace_field_at_indices(
field, sub_indices, single_sub_words))
return trans_samples
def _get_substitute_words(self, words, legal_indices, sentences_tokens,
pos=None, n=5):
r"""
Returns a list containing all possible words .
Overwrite _get_substitute_words of super class.
To accelerate transformation for long text, input single sentence to
language model rather than whole text.
:param list words: all words
:param list legal_indices: indices which has not been skipped
:param list sentences_tokens: list of tokens of each sentence
:param list|None pos: None or list of pos tags
:param int n:max candidates for each word to be substituted
:return list candidates_list: list of candidates list
:return list candidates_indices: list of candidates_indices list
"""
sub_indices, sub_sentences, sub_sent_indices = \
self._get_relate_sub_info(words, sentences_tokens, legal_indices)
assert len(sub_sentences) == len(sub_indices)
candidates_list = []
candidates_indices = []
mask_word_pos_list = []
mask_indices = []
batch_tokens_tensor = torch.tensor(
[], dtype=torch.long, device=self.device)
batch_size = 0
for i, mask_word_index in enumerate(sub_indices):
mask_original_word_pos = pos[mask_word_index][:2]
if mask_original_word_pos not in self.pos_allowed_token_id:
continue
candidates_indices.append(mask_word_index)
mask_indices.append(sub_sent_indices[i] + 1)
mask_word_pos_list.append(mask_original_word_pos)
# process each legal words to get maximum transformed samples
sent_tokens = ['[CLS]'] + \
copy(sentences_tokens[sub_sentences[i]]) + ['[SEP]']
sent_tokens[sub_sent_indices[i] + 1] = '[MASK]'
# padding sentence to do batch predict
sent_tokens = sent_tokens + ['[PAD]'] * \
(self.max_sent_size - len(sent_tokens))
indexed_tokens = self.tokenizer.convert_tokens_to_ids(
sent_tokens[:self.max_sent_size])
tokens_tensor = torch.tensor(
[indexed_tokens], dtype=torch.long, device=self.device)
batch_tokens_tensor = torch.cat(
(batch_tokens_tensor, tokens_tensor))
batch_size += 1
if batch_size > 0:
segments_tensors = torch.zeros(
batch_size,
self.max_sent_size,
dtype=torch.int64,
device=self.device)
candidates_list = self._get_candidates(
batch_tokens_tensor,
segments_tensors,
mask_indices,
mask_word_pos_list,
n=n)
return candidates_list, candidates_indices
def _get_relate_sub_info(self, words, sentences_tokens, legal_indices):
sentences_indices = []
idx = 0
for sentence_tokens in sentences_tokens:
sentences_indices.append(
list(range(idx, idx + len(sentence_tokens))))
idx += len(sentence_tokens)
trans_num = self.get_trans_cnt(len(words))
sub_indices = sorted(self.sample_num(legal_indices, trans_num))
sub_sentences = []
sub_sent_indices = []
for sub_index in sub_indices:
for idx, sentence_indices in enumerate(sentences_indices):
if sub_index in sentence_indices:
sub_sentences.append(idx)
sub_sent_indices.append(sentence_indices.index(sub_index))
return sub_indices, sub_sentences, sub_sent_indices
def _get_candidates(self, batch_tokens_tensor, segments_tensors,
mask_indices,mask_word_pos_list, n=5):
with torch.no_grad():
output = self.model(batch_tokens_tensor, segments_tensors)
candidates_list = []
for i, tup in enumerate(zip(mask_indices, mask_word_pos_list)):
mask_index, mask_original_word_pos = tup
predict_tensor = output[0][i, mask_index]
allowed_token_id = self.pos_allowed_token_id[mask_original_word_pos]
pos_allowed_predict = predict_tensor.gather(0, allowed_token_id)
prob_values, topk_index = pos_allowed_predict.topk(
min(pos_allowed_predict.shape[0], n))
original_vocab_index = allowed_token_id.gather(0, topk_index)
replace_words = self.tokenizer.convert_ids_to_tokens(
original_vocab_index)
candidates_list.append(replace_words)
return candidates_list
def skip_aug(self, tokens, mask, pos=None):
return self.pre_skip_aug(tokens, mask)