Source code for textflint.input_layer.component.sample.ner_sample

r"""
NER Sample class to hold the necessary info and provide atomic operations.
==========================================================
"""
__all__ = ["NERSample"]
from .sample import Sample
from ..field import ListField, TextField
from ....common.utils.list_op import get_align_seq


[docs]class NERSample(Sample): r""" NER Sample class to hold the necessary info and provide atomic operations. """
[docs] def __init__( self, data, origin=None, sample_id=None, mode='BIO' ): r""" :param dict data: The dict obj that contains data info :param ~BaseSample origin: Original sample obj :param int sample_id: the id of sample :param str mode: The sequence labeling mode for NER samples. """ self.mode = mode self.text = None self.tags = None self.entities = None super().__init__(data, origin=origin, sample_id=sample_id)
def __repr__(self): return 'NERSample'
[docs] def check_data(self, data): r""" Check rare data format. :param dict data: rare data input. """ assert 'x' in data and isinstance(data['x'], (str, list)), \ r"Type of 'x' should be 'str' or 'list'" assert 'y' in data and isinstance(data['y'], list), \ r"Type of 'y' should be 'list'" assert self.mode == 'BIO' or self.mode == 'BIOES', \ 'Not support {0} type, plz ensure mode in {1}' .format( self.mode, ['BIO', 'BIOES'])
def is_legal(self): if len(self.text.words) != len(self.tags): print('here') return False return True
[docs] def load(self, data): r""" Parse data into sample field value. :param dict data: rare data input. """ self.text = TextField(data['x']) self.tags = ListField(data['y']) # set mask to prevent UT transform modify entity word. if self.mode == 'BIO': self.entities = ListField(self.find_entities_BIO( self.text.words, self.tags)) elif self.mode == 'BIOES': self.entities = ListField(self.find_entities_BIOES( self.text.words, self.tags)) if not self.is_legal(): raise ValueError('A failed transformation which leads to ' 'mismatch between input and output.')
[docs] def dump(self): r""" Convert sample info to input data json format. :return json: the dict of sentences and labels """ if not self.is_legal(): raise ValueError('A failed transformation which leads to ' 'mismatch between input and output.') return {'x': self.text.words, 'y': self.tags.field_value, 'sample_id': self.sample_id}
[docs] def delete_field_at_indices(self, field, indices): r""" Delete tokens and their NER tag. :param str field: field str :param list indices: list of int/list/slice shape:indices_num each index can be int indicate delete single item or their list like [1, 2, 3], can be list like (0,3) indicate replace items from 0 to 3(not included), can be slice which would be convert to list :return: Modified NERSample. """ assert field == 'text' sample = self.clone(self) sample = super(NERSample, sample)\ .delete_field_at_indices(field, indices) sample = super(NERSample, sample)\ .delete_field_at_indices('tags', indices) return sample
[docs] def delete_field_at_index(self, field, index): r""" Delete tokens and their NER tag. :param str field: field string, normally 'x' :param int|list|slice index: int/list/slice can be int indicate delete single item or their list like [1, 2, 3], can be list like (0,3) indicate replace items from 0 to 3(not included), can be slice which would be convert to list :return: Modified NERSample """ return self.delete_field_at_indices(field, [index])
[docs] def insert_field_before_indices(self, field, indices, items): r""" Insert tokens and ner tags.Assuming the tag of new_item is O. :param str field:field string :param list indices: list of int shape:indices_num, list like [1, 2, 3] :param list items: list of str/list shape: indices_num, correspond to indices :return: Modified NERSample """ assert field == 'text' sample = self.clone(self) sample = super(NERSample, sample)\ .insert_field_before_indices(field, indices, items) # add 'O' tag to insert token insert_tags = get_align_seq(items, 'O') sample = super(NERSample, sample)\ .insert_field_before_indices('tags', indices, insert_tags) return sample
[docs] def insert_field_before_index(self, field, ins_index, new_item): r""" Insert tokens and ner tags.Assuming the tag of new_item is O :param str field: field str :param int ins_index: indicate which index to insert items :param str/list new_item: items to insert :return: Modified NERSample """ return self.insert_field_before_indices(field, [ins_index], [new_item])
[docs] def insert_field_after_indices(self, field, indices, items): r""" Insert tokens and ner tags.Assuming the tag of new_item is O. :param str field: field string :param list indices: list of int shape:indices_num, like [1, 2, 3] :param list items: list of str/list shape: indices_num, correspond to indices :return: Modified NERSample """ assert field == 'text' sample = self.clone(self) sample = super(NERSample, sample)\ .insert_field_after_indices(field, indices, items) # add 'O' tag to insert token insert_tags = get_align_seq(items, 'O') sample = super(NERSample, sample)\ .insert_field_after_indices('tags', indices, insert_tags) return sample
[docs] def insert_field_after_index(self, field, ins_index, new_item): r""" Insert tokens and ner tags.Assuming the tag of new_item is O. :param str field: field string :param int ins_index: indicate where to apply insert :param str|list new_item: shape: indices_num, correspond to field_sub_items :return: Modified NERSample """ return self.insert_field_after_indices(field, [ins_index], [new_item])
[docs] def find_entities_BIO(self, word_seq, tag_seq): r""" find entities in a sentence with BIO labels. :param list word_seq: a list of tokens representing a sentence :param list tag_seq: a list of tags representing a tag sequence labeling the sentence :return list entity_in_seq: a list of entities found in the sequence, including the information of the start position & end position in the sentence, the category, and the entity itself. """ entity_in_seq = [] entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} temp_entity = "" for i in range(len(word_seq)): assert tag_seq[i][0] in ['B', 'I', 'O'], \ 'entity labels should be started with \'B\' or \'I\' or \'O\'.' if tag_seq[i][0] == 'B': assert tag_seq[i][1] == '-', \ 'entity labels should be like the format \'X-XXX\'.' entity['start'] = i entity['tag'] = tag_seq[i][2:] temp_entity = word_seq[i] if i == len(word_seq) - 1: entity['end'] = i entity['entity'] = temp_entity entity_in_seq.append(entity) entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} elif tag_seq[i][0] == 'I': assert temp_entity != '', \ '\'I\' label cannot be the start of the entity.' assert tag_seq[i][1] == '-', \ 'entity labels should be like the format \'X-XXX\'.' temp_entity += ' ' + word_seq[i] if i == len(word_seq) - 1: entity['end'] = i entity['entity'] = temp_entity entity_in_seq.append(entity) entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} elif tag_seq[i] == 'O': if i > 0 and not tag_seq[i - 1] == 'O': entity['end'] = i - 1 entity['entity'] = temp_entity entity_in_seq.append(entity) entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} temp_entity = '' return entity_in_seq
[docs] def find_entities_BIOES(self, word_seq, tag_seq): r""" find entities in a sentence with BIOES labels. :param list word_seq: a list of tokens representing a sentence :param list tag_seq: a list of tags representing a tag sequence labeling the sentence :return list entity_in_seq: a list of entities found in the sequence, including the information of the start position & end position in the sentence, the category, and the entity itself. """ entity_in_seq = [] entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} temp_entity = "" for i in range(len(word_seq)): assert tag_seq[i][0] in ['B', 'I', 'O', 'E', 'S'], \ 'entity labels should be started with ' \ '\'B\' or \'I\' or \'O\' or \'E\' or \'S\'.' if not tag_seq[i] == 'O': assert tag_seq[i][1] == '-', \ 'entity labels should be like the format \'X-XXX\'.' if tag_seq[i][0] == 'B': assert temp_entity == '', \ '\'B\' label must be the start of the entity.' entity['start'] = i entity['tag'] = tag_seq[i][2:] temp_entity = word_seq[i] elif tag_seq[i][0] == 'I': assert temp_entity != '', \ '\'I\' label cannot be the start of the entity.' temp_entity += ' ' + word_seq[i] elif tag_seq[i][0] == 'E': assert temp_entity != '', \ '\'E\' label cannot be the start of the entity.' temp_entity += ' ' + word_seq[i] entity['end'] = i entity['entity'] = temp_entity entity_in_seq.append(entity) entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} temp_entity = '' elif tag_seq[i][0] == 'S': assert temp_entity == '', \ '\'S\' label must be the start of the entity.' entity['start'] = i entity['end'] = i entity['entity'] = word_seq[i] entity_in_seq.append(entity) entity = {'start': 0, 'end': 0, 'entity': "", 'tag': ""} temp_entity = '' return entity_in_seq
[docs] def entities_replace(self, entities_info, candidates): r""" Replace multi entity in once time.Assume input entities with reversed sequential. :param list entities_info: list of entity_info :param list candidates: candidate entities :return: Modified NERSample """ assert len(entities_info) == len(candidates) assert isinstance(entities_info, list) assert isinstance(candidates, list) sample = self.clone(self) for i in range(len(entities_info)): entity = entities_info[i] sample = sample.entity_replace(entity['start'], entity['end'], candidates[i], entity['tag']) return sample
[docs] def entity_replace(self, start, end, entity, label): r""" Replace one entity and update entities info. :param int start: the start position of the entity to be replaced :param int end: the end position of the entity to be replaced :param str entity: the entity to be replaced with :param str label: the category of the entity :return: Modified NERSample """ assert start <= end, "start is before end!" sample = self.clone(self) entity = entity.split(" ") word_prefix = sample.text.words[:start] word_suffix = [] if end == len(sample.text.words) - 1 \ else sample.text.words[end + 1:] sample.text = TextField(word_prefix + entity + word_suffix) tag_prefix = sample.tags[:start] tag_suffix = [] if end == len(sample.tags) - 1 \ else sample.tags[end + 1:] if self.mode == 'BIO': sample.tags = ListField(tag_prefix + ["B-" + label] + ["I-" + label] * (len(entity) - 1) + tag_suffix) else: len_entity = len(entity) if len_entity == 1: substitude = ["S-" + label] elif len_entity == 2: substitude = ["B-" + label] + ["E-" + label] else: substitude = ["B-" + label] \ + ["I-" + label] * (len_entity - 2) \ + ["E-" + label] sample.tags = ListField(tag_prefix + substitude + tag_suffix) # set mask for changed entities # for i in range(start, end+1): # sample.text.set_mask(i, MODIFIED_MASK) sample.entities = ListField( sample.find_entities_BIO(sample.text.words, sample.tags)) return sample