Source code for textflint.generation_layer.transformation.RE.swap_ent

r"""
SwapEnt class for entity swap
"""
__all__ = ["SwapEnt"]
import random
from ....common.settings import LOWFREQ, MULTI_TYPE, TYPES
from ....common.utils.install import download_if_needed
from ....common.utils.load import json_loader
from ...transformation import Transformation
from ....input_layer.component.sample import RESample


class SwapEnt(Transformation):
    r"""
    Replace entity mention with entity with same entity types

    Replacement candidates are read from a JSON resource chosen by ``type``:
    ``'lowfreq'`` loads ``LOWFREQ``, ``'multitype'`` loads ``MULTI_TYPE`` and
    ``'sametype'`` loads ``TYPES``. The loaded object is stored in
    ``self.type_dict`` and is used as a mapping from an entity type string to
    a sequence of space-separated replacement mentions.

    """
    def __init__(
        self,
        type='lowfreq',
        **kwargs
    ):
        r"""
        :param str type: which candidate resource to load, one of
            ``'lowfreq'``, ``'multitype'`` or ``'sametype'``
        :param kwargs: accepted for interface compatibility and ignored
        :raises ValueError: if ``type`` is not one of the names above

        """
        super().__init__()
        self.type = type
        if type == 'lowfreq':
            self.type_dict = json_loader(download_if_needed(LOWFREQ))
        elif type == 'multitype':
            self.type_dict = json_loader(download_if_needed(MULTI_TYPE))
        elif type == 'sametype':
            self.type_dict = json_loader(download_if_needed(TYPES))
        else:
            raise ValueError('illegal type name')

    def __repr__(self):
        return self.type

    def replace_en(self, types, index, token):
        r"""
        Replace entity with random token span

        :param str types: entity type
        :param list index: entity index [start, end], both inclusive
        :param list token: tokenized sentence
        :return Tuple(list, int): new sentence and the length difference
            (new entity word count minus old entity word count). When
            ``types`` is not a key of ``self.type_dict`` the sentence is
            returned unchanged together with 0.

        """
        assert isinstance(types, str), \
            f"the type of 'type' should be string, got " \
            f"{type(types)} instead"
        assert isinstance(index, list), \
            f"the type of 'index' should be list, got " \
            f"{type(index)} instead"
        assert isinstance(token, list), \
            f"the type of 'token' should be list, got " \
            f"{type(token)} instead"
        assert len(index) == 2, \
            f"the length of index should be two, " \
            f"got length {len(index)} instead"
        assert index[0] >= 0 and index[1] < len(token), \
            f"elements of index should not be negative or longer than " \
            f"the length of token, got input " \
            f"{index[0]}<0 or {index[1]}>= {len(token)}"

        if types not in self.type_dict:
            return token, 0

        # Split once and reuse for both the splice and the length delta.
        new_words = random.choice(self.type_dict[types]).split(" ")
        old_length = index[1] - index[0] + 1
        # Concatenation builds a new list; the caller's list is not mutated.
        token = token[:index[0]] + new_words + token[index[1] + 1:]

        return token, len(new_words) - old_length

    @staticmethod
    def _check_entity(sample, n, entity):
        r"""
        Run the argument checks shared by both transform helpers

        :param RESample sample: re_sample input
        :param int n: number of generated samples
        :param list entity: [subj start, subj end, obj start, obj end,
            subj type, obj type]

        """
        assert isinstance(n, int), \
            f"the type of 'n' should be int, got {type(n)} instead"
        assert isinstance(entity, list), \
            f"the type of 'entity' should be list, got {type(entity)} instead"
        assert isinstance(sample, RESample), \
            f"the type of 'sample' should be RESample, got " \
            f"{type(sample)} instead"
        assert len(entity) == 6, \
            f"the length of entity should be 6, got input length of " \
            f"{len(entity)} instead"
        assert entity[0] <= entity[1] and entity[2] <= entity[3], \
            f"start index of entity should not be greater than end index, " \
            f"got {entity[0]}>{entity[1]} or {entity[2]}>{entity[3]} instead"
        assert isinstance(entity[4], str) and isinstance(entity[5], str), \
            f"last two elements of entity should be string, got type " \
            f"{type(entity[4])} and {type(entity[5])} instead"

    @staticmethod
    def _build_sample(sample, token, subj, obj, relation):
        r"""
        Create a new sample from ``sample`` with updated tokens and spans

        :param RESample sample: re_sample input
        :param list token: transformed tokenized sentence
        :param list subj: subject span [start, end]
        :param list obj: object span [start, end]
        :param relation: relation label as returned by ``sample.get_sent()``
        :return: result of ``sample.replace_sample_fields``

        """
        return sample.replace_sample_fields(
            {'x': token, 'subj': subj, 'obj': obj, 'y': relation})

    def subj_and_obj_transform(self, sample, n, entity):
        r"""
        Transform both subject and object entities

        :param RESample sample: re_sample input
        :param int n: number of generated samples
        :param list entity: [subj start, subj end, obj start, obj end,
            subj type, obj type]; both types must be keys of
            ``self.type_dict``
        :return list: transformed sample list

        """
        self._check_entity(sample, n, entity)
        assert entity[4] in self.type_dict and entity[5] in self.type_dict, \
            f"both entity types should be in the type dict, type " \
            f"{entity[4]} and {entity[5]} do not satisfy this requirement"

        trans_samples = []
        for _ in range(n):
            sh, st, oh, ot, subj_type, obj_type = entity
            token, relation = sample.get_sent()

            # Replace the leftmost entity first, then shift every later
            # boundary by the length change before replacing the other one.
            if sh < oh:
                token, shift = self.replace_en(subj_type, [sh, st], token)
                st, oh, ot = st + shift, oh + shift, ot + shift
                token, shift = self.replace_en(obj_type, [oh, ot], token)
                ot = ot + shift
            elif oh < sh:
                token, shift = self.replace_en(obj_type, [oh, ot], token)
                ot, sh, st = ot + shift, sh + shift, st + shift
                token, shift = self.replace_en(subj_type, [sh, st], token)
                st = st + shift
            else:
                # Both entities start at the same token: only the subject
                # span is replaced and both end indices move together.
                token, shift = self.replace_en(subj_type, [sh, st], token)
                st, ot = st + shift, ot + shift

            trans_samples.append(
                self._build_sample(
                    sample, token, [sh, st], [oh, ot], relation))

        return trans_samples

    def single_transform(self, sample, n, entity):
        r"""
        Transform subject or object entity

        The subject is replaced when its type is a key of
        ``self.type_dict``; otherwise the object replacement is attempted,
        which leaves the tokens unchanged if the object type is unknown too.

        :param RESample sample: re_sample input
        :param int n: number of generated samples
        :param list entity: [subj start, subj end, obj start, obj end,
            subj type, obj type]; at most one of the two types may be a key
            of ``self.type_dict``
        :return list: transformed sample list

        """
        self._check_entity(sample, n, entity)
        assert (entity[4] not in self.type_dict
                or entity[5] not in self.type_dict), \
            f"at most one entity type should be in the type dict, " \
            f"{entity[4]} and {entity[5]} " \
            f"does not satisfy this requirement"

        trans_samples = []
        for _ in range(n):
            sh, st, oh, ot, subj_type, obj_type = entity
            token, relation = sample.get_sent()

            if subj_type in self.type_dict:
                token, shift = self.replace_en(subj_type, [sh, st], token)
                st = st + shift
                # The object only moves when it does not precede the subject.
                if sh < oh:
                    oh, ot = oh + shift, ot + shift
                elif sh == oh:
                    ot = ot + shift
            else:
                token, shift = self.replace_en(obj_type, [oh, ot], token)
                ot = ot + shift
                # The subject only moves when it does not precede the object.
                if sh > oh:
                    sh, st = sh + shift, st + shift
                elif sh == oh:
                    st = st + shift

            trans_samples.append(
                self._build_sample(
                    sample, token, [sh, st], [oh, ot], relation))

        return trans_samples

    def _transform(self, sample, n=1, **kwargs):
        r"""
        Transform text string according to its entities.

        Dispatches to :meth:`subj_and_obj_transform` when both entity types
        are keys of ``self.type_dict`` and to :meth:`single_transform` in
        every other case (including when neither type is known).

        :param RESample sample: re_sample input
        :param int n: number of generated samples
        :return list: transformed sample list

        """
        assert isinstance(sample, RESample), \
            f"the type of 'sample' should be RESample, " \
            f"got {type(sample)} instead"
        assert isinstance(n, int), \
            f"the type of 'n' should be int, got {type(n)} instead"

        sh, st, oh, ot = sample.get_en()
        subj_type, obj_type, _ = sample.get_type()
        entity = [sh, st, oh, ot, subj_type, obj_type]

        if subj_type in self.type_dict and obj_type in self.type_dict:
            return self.subj_and_obj_transform(sample, n, entity)

        # The two cases are exhaustive, so no further fallback is reachable.
        return self.single_transform(sample, n, entity)