Source code for textflint.generation_layer.transformation.RE.swap_age

import re
from ...transformation import Transformation
from ....input_layer.component.sample.re_sample import RESample

r"""
AgeSwap class for age-related transformation 
"""
__all__ = ["SwapAge"]


[docs]class SwapAge(Transformation): r""" Entity position swap with paraphrase(age related) """ def __init__( self, **kwargs ): super().__init__() def __repr__(self): return 'SwapAge' def _transform(self, sample, n=1, field='x', **kwargs): r""" Swap entity position through paraphrasing :param RESample sample: sample input :param int n: number of generated samples (no more than one) :return list: transformed sample list """ assert(isinstance(sample, RESample)), \ f"the type of 'sample' should be RESample, got " \ f"{type(sample)} instead" assert(isinstance(n, int)), f"the type of 'n' should be int, " \ f"got {type(n)} instead" regex = re.compile(r', (\d+) ,') words, relation = sample.get_sent() if 'age' not in relation: return [sample] ss, se, os, oe = sample.get_en() subj_type, obj_type, _ = sample.get_type() subj_words = words[ss:se+1] obj_words = words[os:oe+1] deprel, head = sample.get_dp() new_text = None trans_sample = {} if 'age' in relation and 'PERSON' in subj_type: match_obj = regex.match(' '.join(words[se + 1:se + 4])) if match_obj: if deprel[ss] == 'nsubj': # find verb words[ss:se+1] = ['[HEAD]']*len(subj_words) words[os:oe+1] = ['[TAIL]']*len(obj_words) verb_id = head[ss] for i in range(len(words) - 1, verb_id - 1, -1): if head[i] == verb_id and deprel[i] == 'punct': break #age = match_obj.group(1) new_text = words[:se + 1] + words[se + 4:i] new_text.append(',') new_text.extend('and he was [TAIL] year old ' 'at that time .'.split(' ')) new_text.extend(words[i:]) if new_text: ss, os = new_text.index('[HEAD]'), new_text.index('[TAIL]') se, oe = ss+new_text.count('[HEAD]')-1, \ os+new_text.count('[TAIL]')-1 new_text[ss:se+1] = subj_words new_text[os:oe+1] = obj_words trans_sample['x'] = new_text else: trans_sample['x'] = words trans_sample['subj'], trans_sample['obj'], trans_sample['y'] = \ [ss, se], [os, oe], relation trans_samples = sample.replace_sample_fields(trans_sample) return [trans_samples]