Source code for textflint.generation_layer.transformation.UT.swap_named_ent

r"""
SwapNamedEnt substitute class
==========================================================
"""

__all__ = ['SwapNamedEnt']

from ..transformation import Transformation
from ....common.settings import ENTITIES_PATH, CORENLP_ENTITY_MAP
from ....common.utils.load import json_loader
from ....common.utils.list_op import trade_off_sub_words
from ....common.utils.install import download_if_needed


class SwapNamedEnt(Transformation):
    r"""
    Swap entities with other entities of the same category.

    Named entities found in a sample field are replaced by entities drawn
    from ``self.entities_dic``, a mapping from category to a collection of
    candidate entities.

    """

    def __init__(
        self,
        entity_res=None,
        **kwargs
    ):
        r"""
        :param entity_res: entity resource. Either a non-empty dict mapping
            categories to their entities (used as-is, not copied), or a
            path to a JSON file holding such a mapping. When falsy
            (``None``, empty string, empty dict), the default resource
            referenced by ``ENTITIES_PATH`` is fetched with
            ``download_if_needed`` and loaded.
        :param kwargs: accepted for interface compatibility; not used here
            and not forwarded to the parent constructor.

        """
        super().__init__()

        if entity_res and isinstance(entity_res, dict):
            # An already-loaded mapping was supplied: nothing to read.
            self.entities_dic = entity_res
        else:
            entities_path = entity_res if entity_res \
                else download_if_needed(ENTITIES_PATH)
            self.entities_dic = json_loader(entities_path)

    def __repr__(self):
        return 'SwapNamedEnt'

    def _transform(self, sample, field='x', n=1, **kwargs):
        r"""
        Transform the given field by swapping its named entities.

        :param ~Sample sample: input data, normally one data component.
        :param str field: indicate which field to transform.
        :param int n: number of generated samples
        :param kwargs: unused.
        :return list trans_samples: transformed sample list; empty when
            no replaceable entity position remains.

        """
        entities_info = sample.get_ner(field)

        # The original entity strings are not needed for the replacement;
        # only their positions and normalized categories are.
        indices, _, categories = self.decompose_entities_info(entities_info)

        # One candidate list per entity position, then let
        # trade_off_sub_words align candidates with the indices to replace.
        candidates = self._get_random_entities(categories, n)
        candidates, indices = trade_off_sub_words(candidates, indices, n=n)

        if not indices:
            return []

        # Each aligned candidate yields one transformed sample.
        return [
            sample.unequal_replace_field_at_indices(field, indices, candidate)
            for candidate in candidates
        ]

    @staticmethod
    def decompose_entities_info(entities_info):
        r"""
        Decompose given entities and normalize entity tags through
        ``CORENLP_ENTITY_MAP`` (documented target tags:
        ['LOCATION', 'PERSON', 'ORGANIZATION']). Entities whose tag is not
        a key of ``CORENLP_ENTITY_MAP`` are dropped.

        Example::

            [('Lionel Messi', 0, 2, 'PERSON'), ('Argentina', 7, 8, 'LOCATION')]

            >> [[0, 2], [7, 8]],
               ['Lionel Messi', 'Argentina'],
               ['PERSON', 'LOCATION']

        :param list entities_info: iterable of records indexed as
            (entity text, start, end, tag), as parsed by the default
            ner component.
        :return list indices: ``[start, end]`` pair of each kept entity
        :return list entities: entity values
        :return list categories: normalized categories

        """
        indices = []
        entities = []
        categories = []

        for entity_info in entities_info:
            category = entity_info[3]
            # Skip tags that have no normalized counterpart.
            if category not in CORENLP_ENTITY_MAP:
                continue
            entities.append(entity_info[0])
            indices.append([entity_info[1], entity_info[2]])
            categories.append(CORENLP_ENTITY_MAP[category])

        return indices, entities, categories

    def _get_random_entities(self, categories, n):
        r"""
        Draw replacement entities for each of the given categories.

        ``self.sample_num`` is not defined in this class (presumably
        inherited from ``Transformation``); it is called with the
        category's entity collection and ``n``.

        :param list categories: category of each entity, in order.
        :param int n: number of entities requested per category.
        :return list rand_entities: one entry per input category, in the
            same order; an empty list for categories missing from
            ``self.entities_dic``.

        """
        return [
            self.sample_num(self.entities_dic[category], n)
            if category in self.entities_dic else []
            for category in categories
        ]