Source code for textflint.input_layer.component.sample.dp_sample

r"""
DP Sample Class
============================================

"""

from .sample import Sample
from ..field import ListField, TextField
from ....common.utils.list_op import get_align_seq, normalize_scope

__all__ = ["DPSample"]


[docs]class DPSample(Sample): r""" DP Sample class to hold the data info and provide atomic operations. """ def __repr__(self): return 'DPSample' def check_data(self, data): assert 'word' in data and isinstance(data['word'], list), \ "Word should be in data, and the type of word should be list" assert 'postag' in data and isinstance(data['postag'], list), \ "Postag should be in data, and the type of postag should be list" assert 'head' in data and isinstance(data['head'], list), \ "Head should be in data, and the type of head should be list" assert 'deprel' in data and isinstance(data['deprel'], list), \ "Deprel should be in data, and the type of deprel should be list"
[docs] def load(self, data): r""" Convert data dict to DPSample and get matched brackets. :param dict data: contains 'word', 'postag', 'head', 'deprel' keys. """ assert data['deprel'][-1] == 'punct', \ "The sentence should end with a punctuation." words = data['word'] self.x = TextField(words) self.postag = ListField(data['postag']) self.head = ListField(data['head']) self.deprel = ListField(data['deprel']) try: left_bras = [] match_bracket = [] for i, word in enumerate(words): if word == '-LRB-': left_bras.append(i + 1) if word == '-RRB-': match_bracket.append((left_bras[-1], i + 1)) left_bras.pop(-1) except IndexError: raise TypeError('Missing matched brackets.') else: if left_bras: raise TypeError('Missing matched brackets.') else: self.brackets = match_bracket if not self.is_legal(): raise ValueError("The postag data is not aligned with words.")
def dump(self): if not self.is_legal(): raise ValueError("The postag data is not aligned with words.") return {'word': self.x.words, 'postag': self.postag.field_value, 'head': self.head.field_value, 'deprel': self.deprel.field_value, 'sample_id': self.sample_id }
[docs] def insert_field_after_indices(self, field, indices, items): r""" Insert items of multi given scopes before indices of field value at the same time. :param str field: Only value 'x' supported. :param list indices: shape:indices_num :param list items: shape: indices_num, correspond to indices :return ~DPSample: The sentence with words added. """ assert isinstance(indices, list) assert isinstance(items, list) assert len(indices) == len(items) sample = self for index, new_item in enumerate(items): sample = sample.insert_field_after_index( field, indices[index], new_item) return sample
[docs] def insert_field_after_index(self, field, ins_index, new_item): r""" Insert given data after the given index. :param str field: Only value 'x' supported. :param int ins_index: The index where the word will be inserted after. :param str new_item: The word to be inserted. :return ~DPSample: The sentence with one word added. """ assert field == 'x' sample = self.clone(self) sample = super( DPSample, sample).insert_field_after_indices( field, [ins_index], [new_item]) fields_to_values = {'postag': 'UNK', 'head': '0', 'deprel': 'unk'} for _field, value in fields_to_values.items(): field_obj = getattr(sample, _field) insert_values = get_align_seq([new_item], value)[0] rep_obj = field_obj.insert_after_index(ins_index, insert_values) setattr(sample, _field, rep_obj) head_obj = sample.get_value('head') rep_obj = sample.head bias_len = len(new_item) if isinstance(new_item, list) else 1 for i, head in enumerate(head_obj): head_id = int(head) if head_id > ins_index + 1: rep_obj = rep_obj.replace_at_index(i, str(head_id + bias_len)) setattr(sample, 'head', rep_obj) return sample
[docs] def insert_field_before_indices(self, field, indices, items): r""" Insert items of multi given scopes before indices of field value at the same time. :param str field: Only value 'x' supported. :param list indices: shape:indices_num :param list items: shape: indices_num, correspond to indices :return ~DPSample: The sentence with words added. """ assert isinstance(indices, list) assert isinstance(items, list) assert len(indices) == len(items) sample = self.clone(self) for index, new_item in enumerate(items): sample = sample.insert_field_before_index( field, indices[index], new_item) return sample
[docs] def insert_field_before_index(self, field, ins_index, new_item): r""" Insert given data before the given position. :param str field: Only value 'x' supported. :param int ins_index: The index where the word will be inserted after. :param str new_item: The word to be inserted. :return ~DPSample: The sentence with one word added. """ assert field == 'x' sample = self.clone(self) sample = super( DPSample, sample).insert_field_before_indices( field, [ins_index], [new_item]) fields_to_values = {'postag': 'UNK', 'head': '0', 'deprel': 'unk'} for _field, value in fields_to_values.items(): field_obj = getattr(sample, _field) insert_values = get_align_seq([new_item], value)[0] rep_obj = field_obj.insert_before_index(ins_index, insert_values) setattr(sample, _field, rep_obj) head_obj = sample.get_value('head') rep_obj = sample.head bias_len = len(new_item) if isinstance(new_item, list) else 1 for i, head in enumerate(head_obj): head_id = int(head) if head_id > ins_index: rep_obj = rep_obj.replace_at_index(i, str(head_id + bias_len)) setattr(sample, 'head', rep_obj) return sample
[docs] def delete_field_at_indices(self, field, indices): r""" Delete items of given scopes of field value. :param str field: Only value 'x' supported. :param list indices: shape:indices_num each index can be int indicate replace single item or their list like [1, 2, 3], can be list like (0,3) indicate replace items from 0 to 3(not included), can be slice which would be convert to list. :return ~DPSample: The sentence with words deleted. """ assert isinstance(indices, list) sample = self.clone(self) for index in indices: sample = sample.delete_field_at_index(field, index) return sample
[docs] def delete_field_at_index(self, field, del_index): r""" Delete data at the given position. :param str field: Only value 'x' supported. :param int|list|slice del_index: can be int indicate replace single item or their list like [1, 2, 3], can be list like (0,3) indicate replace items from 0 to 3(not included), can be slice which would be convert to list. :return ~DPSample: The sentence with one word deleted. """ assert field == 'x' del_scope = normalize_scope(del_index) del_len = del_scope[1] - del_scope[0] sample = self.clone(self) sample = super(DPSample, sample)\ .delete_field_at_indices(field, [del_scope]) sample = super(DPSample, sample)\ .delete_field_at_indices('postag', [del_scope]) sample = super(DPSample, sample)\ .delete_field_at_indices('head', [del_scope]) sample = super(DPSample, sample)\ .delete_field_at_indices('deprel', [del_scope]) head_obj = sample.get_value('head') rep_obj = sample.head for i, head in enumerate(head_obj): head_id = int(head) # keep words dependency pointer which in del_scope may cause error # assume won't delete item with dependency pointers if head_id >= del_scope[1]: rep_obj = rep_obj.replace_at_index(i, str(head_id - del_len)) setattr(sample, 'head', rep_obj) return sample