Source code for textflint.generation_layer.transformation.RE.swap_birth

r"""
BirthSwap class for birth-related transformation
"""
from ....input_layer.component.sample import RESample
from ...transformation import Transformation

__all__ = ["SwapBirth"]


[docs]class SwapBirth(Transformation): r""" Entity position swap with paraphrase(birth related) """ def __init__( self, **kwargs ): super().__init__() def __repr__(self): return 'SwapBirth'
[docs] def generate_birth_for_root(self, idx, tokens, pos_tags, heads): r"""generate new sentence if birth is root :param int idx: the idx of word "bear" in sentence :param list tokens: tokens of the sentence :param list pos_tags: pos tagging labels of the sentence :param list heads: stanford heads :return string: transformed sentence """ new_sen = tokens[:idx - 1] v = -1 for pos in range(idx + 1, len(tokens)): if pos_tags[pos].startswith('V') and heads[pos] == idx + 1: v = pos if v < 0: return [] else: born = tokens[idx - 1:v] if tokens[v - 1] == 'and': born.pop() new_sen.extend([','] + tokens[v:-1] + [','] + born + ['.']) return new_sen
[docs] def generate_birth_for_clause(self, idx, tokens, deprels): r"""generate new sentence if birth is clause :param int idx: the idx of word "bear" in sentence :param list tokens: tokens of the sentence :param list deprels: stanford dependency relations :return string: transformed sentence """ root_id = -1 for k, r in enumerate(deprels): if r == 'ROOT': root_id = k if root_id < idx: return [] end = root_id - 1 while tokens[end] != ',': end = end - 1 born = tokens[idx:end] if idx - 1 >= 0 and tokens[idx - 1] == ',': new_sen = tokens[:idx - 1] + tokens[end + 1:] elif idx - 2 >= 0 and tokens[idx - 2] == 'who': if idx - 3 >= 0 and tokens[idx - 3] == ',': new_sen = tokens[:idx - 3] + tokens[end + 1:] else: new_sen = tokens[:idx - 2] + tokens[end + 1:] else: return [] new_sen = ["Born"] + born[1:] + [","] + new_sen return new_sen
[docs] def generate_new_sen_for_birth(self, idx, tokens, pos_tags, heads, deprels): r""" generate new sentence :param int idx: the idx of word "bear" in sentence :param list tokens: tokens of the sentence :param list pos_tags: pos tagging labels of the sentence :param list heads: stanford heads :param list deprels: stanford dependency relations :return string: transformed sentence """ assert (isinstance(idx, int)), \ f"the type of 'idx' should be int, got {type(idx)} instead" assert (isinstance(tokens, list)), \ f"the type of 'tokens' should be list, got {type(tokens)} instead" assert (isinstance(pos_tags, list)), \ f"the type of 'pos_tags' should be list, got " \ f"{type(pos_tags)} instead" assert (isinstance(heads, list)), \ f"the type of 'heads' should be list, got {type(heads)} instead" assert (isinstance(deprels, list)), \ f"the type of 'deprels' should be list, got {type(deprels)} instead" assert (len(tokens) == len(pos_tags) == len(heads) == len(deprels)), \ f"the length of the list inputs should be the same, got " \ f"{len(tokens), len(pos_tags), len(heads), len(deprels)} instead" assert (idx >= 0 and idx < len(tokens)), \ f"got invalid value of idx: {idx}" assert (isinstance(heads[0], int)), \ f"the type of each token in 'heads' " \ f"should be int, got {type(heads[0])} instead" assert (len([i for i in heads if i < 0 or i >= len(tokens)]) == 0), \ f"got invalid value of 'heads': {heads}" assert ("ROOT" in deprels), f"'ROOT' should be " \ f"included in 'deprels', got {deprels}" assert (0 in heads), f"0 should be included in 'heads', got {heads}" deprel = deprels[idx] ### if "born" is root of the sent, swap the position of born- and clause if deprel == 'ROOT': new_sen = self.generate_birth_for_root(idx, tokens, pos_tags, heads) else: new_sen = self.generate_birth_for_clause(idx, tokens, deprels) return new_sen
def _transform(self, sample, n=1, field='x', **kwargs): r""" Swap entity position through paraphrasing :param RESample sample: sample input :param int n: number of generated samples (no more than one) :return list: transformed sample list """ assert (isinstance(sample, RESample)), \ f"the type of 'sample' should be RESample, " \ f"got {type(sample)} instead" assert (isinstance(n, int)), f"the type of 'n' should be int, " \ f"got {type(n)} instead" words, relation = sample.get_sent() if 'birth' not in relation: return [sample] sh, st, oh, ot = sample.get_en() pos_tag = sample.get_pos() head_pos = [sh, st] subj = ' '.join(words[sh:st + 1]) obj = ' '.join(words[oh:ot + 1]) new_sample = {} new_sample['x'] = words if 'birth' in relation: deprels, heads = sample.get_dp() new_sen = [] words[sh:st + 1] = ["<SUBJ>"] * (st - sh + 1) words[oh:ot + 1] = ["<OBJ>"] * (ot - oh + 1) for i, word in enumerate(words): if word == "born": if head_pos[1] < i: new_sen = self.generate_new_sen_for_birth( i, words, pos_tag, heads, deprels) if new_sen: new_sh = new_sen.index("<SUBJ>") new_st = new_sh + (st - sh) new_oh = new_sen.index("<OBJ>") new_ot = new_oh + (ot - oh) sh, st, oh, ot = new_sh, new_st, new_oh, new_ot new_sen[new_sh:new_st + 1] = subj.split() new_sen[new_oh:new_ot + 1] = obj.split() new_sample['x'] = new_sen new_sample['subj'], new_sample['obj'], new_sample['y'] = \ [sh, st], [oh, ot], relation trans_samples = sample.replace_sample_fields(new_sample) return [trans_samples]