Source code for textflint.generation_layer.transformation.POS.multi_pos_swap

r"""
SwapMultiPOS transformation for POS tagging
==========================================================
"""
__all__ = ["SwapMultiPOS"]

from ....input_layer.component.sample import POSSample
from ...transformation import WordSubstitute


[docs]class SwapMultiPOS(WordSubstitute): r""" Word Swap by swaping words that have multiple POS tags in WordNet. """
[docs] def __init__(self, treebank_tag="JJ", trans_max=2, trans_p=1, **kwargs ): r""" :param treebank_tag: words with this pos tag will be replaced :param kwargs: """ super().__init__(trans_max, trans_p, **kwargs) self.treebank_tag = treebank_tag self.check_pos() self.wordnet_candidates = self.get_candidates_from_wordnet() self.get_pos = True
def __repr__(self): return 'SwapMultiPOS' + '-' + self.treebank_tag def check_pos(self): support_pos = ["JJ", "NN", "VB", "RB"] assert self.treebank_tag in support_pos, \ "Only support replacing JJ, NN, VB and RB!"
[docs] def get_candidates_from_wordnet(self): r""" get all possible multi-pos words with pos tags same as treebank_tag. :return: a list """ noun = set( [i for i in self.processor.get_all_lemmas(pos='n') if "_" not in i]) verb = set( [i for i in self.processor.get_all_lemmas(pos='v') if "_" not in i]) adj = set( [i for i in self.processor.get_all_lemmas(pos='a') if "_" not in i]) adv = set( [i for i in self.processor.get_all_lemmas(pos='r') if "_" not in i]) candidates = [] if self.treebank_tag == "NN": candidates = list(noun & (verb | adj | adv)) elif self.treebank_tag == "VB": candidates = list(verb & (noun | adj | adv)) elif self.treebank_tag == "JJ": candidates = list(adj & (verb | noun | adv)) elif self.treebank_tag == "RB": candidates = list(adv & (verb | adj | noun)) return candidates
def _get_candidates(self, word, pos=None, n=5): r""" Returns a list containing all possible words. :param word: str, the word to replace :param pos: str, the pos of the word to replace :param n: the number of returned words :return: a candidates list """ return self.sample_num(self.wordnet_candidates, n)
[docs] def skip_aug(self, tokens, mask, pos=None): r""" Returns the index of the replaced tokens. :param tokens: list, tokenized words or word with pos tag pairs :param mask: list, the mask symbol of the tokens :param pos: list, the pos tags of the tokens :return: list, the words at these indices that can be replaced """ assert pos is not None, "POS tag must be given!" results = [] indices = self.pre_skip_aug(tokens, mask) for index in indices: if pos[index] == self.treebank_tag: results.append(index) return results