Source code for textflint.generation_layer.transformation.word_substitute

r"""
WordSubstitute Base Class
============================================

Defines :class:`WordSubstitute`, the base class for transformations that
replace individual words of a sample field with substitute candidates.
"""
__all__ = ["WordSubstitute"]
import string
from abc import abstractmethod

from ..transformation import Transformation
from ...common.settings import STOP_WORDS, ORIGIN
from ...common.utils.list_op import trade_off_sub_words


class WordSubstitute(Transformation):
    r"""
    Word replace transformation to implement normal word replace functions.

    Subclasses implement :meth:`skip_aug` (which token indices may be
    substituted) and :meth:`_get_candidates` (replacement candidates for a
    single word); this base class selects how many words to replace and
    builds the transformed samples.
    """

    def __init__(
        self,
        trans_min=1,
        trans_max=10,
        trans_p=0.1,
        stop_words=None,
        **kwargs
    ):
        r"""
        :param int trans_min: Minimum number of words to transform.
        :param int|None trans_max: Maximum number of words to transform.
            The count is ``int(trans_p * len(tokens))``, raised to
            ``trans_min`` when smaller and capped at ``trans_max`` when
            larger. If None, no upper cap is applied.
        :param float trans_p: Fraction of words to transform.
        :param list stop_words: Words that are never substituted. The
            default ``STOP_WORDS`` list is used when this is None or empty.
        :param kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.trans_min = trans_min
        self.trans_max = trans_max
        self.trans_p = trans_p
        self.stop_words = STOP_WORDS if not stop_words else stop_words
        # When True, ``_transform`` fetches POS tags and forwards them to
        # ``skip_aug`` / ``_get_candidates``. Defaults to False to avoid a
        # pointless POS tagging pass for substitutions that ignore it.
        self.get_pos = False

    def _transform(self, sample, field='x', n=1, **kwargs):
        r"""
        Transform text string according to field.

        :param sample: input data, normally one data component.
        :param str field: indicate which field to apply transformation
        :param int n: number of generated samples
        :return list: transformed sample list; empty when no word can be
            substituted.
        """
        tokens = sample.get_words(field)
        tokens_mask = sample.get_mask(field)
        pos_info = sample.get_pos(field) if self.get_pos else None

        legal_indices = self.skip_aug(tokens, tokens_mask, pos=pos_info)
        if not legal_indices:
            return []

        # Up to n candidates for each legal word that has any candidates.
        sub_words, sub_indices = self._get_substitute_words(
            tokens, legal_indices, pos=pos_info, n=n)

        # Reduce to trans_num positions and at most n replacement sets.
        # NOTE(review): exact selection policy lives in
        # list_op.trade_off_sub_words — confirm there.
        trans_num = self.get_trans_cnt(len(tokens))
        sub_words, sub_indices = trade_off_sub_words(
            sub_words, sub_indices, trans_num, n)

        if not sub_words:
            return []

        # One output sample per replacement set, all sharing sub_indices.
        return [
            sample.replace_field_at_indices(field, sub_indices, words)
            for words in sub_words
        ]

    def _get_substitute_words(self, words, legal_indices, pos=None, n=5):
        r"""
        Collect substitute candidates for every legal word.

        Words for which ``_get_candidates`` returns no candidates are
        dropped, so the two returned lists stay aligned.

        :param list words: all words
        :param list legal_indices: indices which have not been skipped
        :param None|list pos: None or list of pos tags
        :param int n: max candidates for each word to be substituted
        :return tuple: ``(candidates_list, candidates_indices)`` where
            ``candidates_list[k]`` holds the candidates for
            ``words[candidates_indices[k]]``.
        """
        candidates_list = []
        candidates_indices = []

        for index in legal_indices:
            word_pos = pos[index] if self.get_pos else None
            # Query once and reuse the result: a second call doubles the
            # cost and, for randomized candidate generators, could store a
            # different (even empty) list than the one checked here.
            candidates = self._get_candidates(words[index], pos=word_pos, n=n)
            if candidates:
                candidates_indices.append(index)
                candidates_list.append(candidates)

        return candidates_list, candidates_indices

    @abstractmethod
    def _get_candidates(self, word, pos=None, n=5, **kwargs):
        r"""
        Returns a list containing all possible substitutes for ``word``.

        :param str word: the word to be substituted
        :param str pos: the pos tag, or None when ``self.get_pos`` is False
        :param int n: max number of candidates to return
        :return list: candidates list; falsy when there are none
        """
        raise NotImplementedError

    @abstractmethod
    def skip_aug(self, tokens, mask, pos=None):
        r"""
        Returns the indices of the tokens that may be replaced.

        :param list tokens: tokenized words or word with pos tag pairs
        :param list mask: per-token mask from ``sample.get_mask``
        :param None|list pos: None or list of pos tags
        :return list: the indices of the tokens that may be replaced
        """
        raise NotImplementedError

    def is_stop_words(self, token):
        r"""
        Judge whether the input word belongs to the stop words vocab.

        :param str token: the input word to be judged
        :return bool: is a stop word or not
        """
        return self.stop_words is not None and token in self.stop_words

    def pre_skip_aug(self, tokens, mask):
        r"""
        Skip punctuation-only tokens, stop words and masked tokens.

        :param list tokens: the list of tokens
        :param list mask: the list of mask Indicates whether each word is
            allowed to be substituted. ORIGIN is allowed, while TASK_MASK
            and MODIFIED_MASK is not.
        :return list: List of possible substituted token index.
        """
        assert len(tokens) == len(mask)
        results = []

        for token_idx, (token, token_mask) in enumerate(zip(tokens, mask)):
            # Skip tokens made only of punctuation characters. A plain
            # ``token in string.punctuation`` is a substring test: it would
            # skip "()" but keep "..." or "--". Empty tokens are skipped
            # too (``all`` of nothing is True), as before.
            if all(char in string.punctuation for char in token):
                continue
            if self.is_stop_words(token):
                continue
            if token_mask != ORIGIN:
                continue
            results.append(token_idx)

        return results

    def get_trans_cnt(self, size):
        r"""
        Get the num of words/chars transformation.

        Computes ``int(trans_p * size)``; returns ``trans_min`` if that is
        smaller (this takes precedence even over ``trans_max``), otherwise
        caps it at ``trans_max`` unless ``trans_max`` is None.

        :param int size: the size of target sentence
        :return int: number of words to apply transformation.
        """
        cnt = int(self.trans_p * size)
        if cnt < self.trans_min:
            return self.trans_min
        if self.trans_max is not None and cnt > self.trans_max:
            return self.trans_max
        return cnt

    @staticmethod
    def token2chars(word):
        r"""Split ``word`` into a list of its characters."""
        return list(word)

    @staticmethod
    def chars2token(chars):
        r"""Join a list of characters back into a single string."""
        assert isinstance(chars, list)
        return ''.join(chars)