Source code for textflint.input_layer.component.sample.cws_sample

"""
CWS Sample Class
============================================
"""

from .sample import Sample
from ..field.cn_text_field import CnTextField
from ..field import ListField
from ....common.settings import ORIGIN, MODIFIED_MASK
from ....common.utils.list_op import *

__all__ = ['CWSSample']


[docs]class CWSSample(Sample): r""" Our segmentation rules are based on ctb6. the input x can be a list or a sentence the input y is segmentation label include:B,M,E,S the y also can automatic generation,if you want automatic generation you must input an empty list and x must each word in x is separated by a space or split into each element of the list Note that punctuation should be separated into a single word Example:: 1. input {'x':'小明好想送Jo圣诞礼物', 'y' = ['B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E']} 2. input {'x':['小明','好想送Jo圣诞礼物'], 'y' = ['B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E']} 3. input {'x':'小明 好想 送 Jo 圣诞 礼物', 'y' = []} 4. input {'x':['小明', '好想', '送', 'Jo', '圣诞', '礼物'], 'y' = []} """
[docs] def __init__(self, data, origin=None, sample_id=None): r""" :param dict data: The dict obj that contains data info :param int sample_id: the id of sample :param bool origin: if the sample is origin """ super().__init__(data, origin=origin, sample_id=sample_id)
def __repr__(self): return 'CWSSample' @staticmethod def is_legal(): return True
[docs] def check_data(self, data): r""" Check the whether the data legitimate but we don't check that the label is correct if the data is not legal but acceptable format, change the format of data :param dict data: The dict obj that contains data info """ assert 'x' in data and 'y' in data assert isinstance(data['y'], list), \ f"The type of data[y] me be list,not {0}".format(type(data['y'])) assert isinstance(data['x'], (str, list)), \ f"The type of data[y] me be list or str, not {0}"\ .format(type(data['x'])) if not data['y']: if isinstance(data['x'], str): data['x'] = data['x'].strip().split(' ') data['y'] = [] sentence = '' for x in data['x']: x.replace(' ', '') sentence += x if len(x) == 1: data['y'] += ['S'] elif len(x) > 1: data['y'] += ['B'] + ['M'] * (len(x) - 2) + ['E'] data['x'] = sentence else: sentence = [] for i in data['x']: sentence += i.replace(' ', '') data['x'] = sentence cws_tag = ['B', 'M', 'E', 'S'] assert len(data['x']) == len(data['y']) for tag in data['y']: assert tag in cws_tag
[docs] def load(self, data): r""" Convert data dict which contains essential information to CWSSample. :param dict data: The dict obj that contains data info """ self.x = CnTextField(data['x']) assert isinstance(data['y'], list) self.y = ListField(data['y'])
def dump(self): assert len( self.x.mask) == len( self.x.field_value) == len( self.y.field_value) return { 'x': self.x.field_value, 'y': self.y.field_value, 'sample_id': self.sample_id} @property def mask(self): return self.x.mask @property def pos_tags(self): return self.x.pos_tags() @property def ner(self): return self.x.ner()
[docs] def get_words(self): r""" Get the words from the sentence. :return list: the words in sentence """ start = 0 words = [] while start < len(self.x.field_value): # find the word if self.y.field_value[start] == 'B': end = start + 1 while self.y.field_value[end] != 'E': end += 1 elif self.y.field_value[start] == 'S': end = start else: raise ValueError(f"the label is not right") words.append(self.x.field_value[start:end + 1]) start = end + 1 return words
[docs] def replace_at_ranges(self, indices, new_items, y_new_items=None): r""" Replace words at indices and set their mask to MODIFIED_MASK. :param list indices: The list of the pos need to be changed. :param list new_items: The list of the item need to be changed. :param list y_new_items: The list of the mask info need to be changed. :return: replaced CWSSample object. """ indices, items, mask, y_new_items = self.check( indices, new_items, y_new_items) if len(indices): cws = self.clone(self) new_mask = unequal_replace_at_scopes(self.mask, indices, mask) new_field = unequal_replace_at_scopes(self.x.token, indices, items) x = self.x.new_field(new_field, mask=new_mask) setattr(cws, 'x', x) if y_new_items: y = unequal_replace_at_scopes( self.y.field_value, indices, y_new_items) setattr(cws, 'y', ListField(y)) return cws else: return self
[docs] def update(self, x, y): r""" Replace words at indices and set their mask to MODIFIED_MASK. :param str x: the new sentence. :param list y: the new labels. :return: new CWSSample object. """ cws = self.clone(self) setattr(cws, 'x', x) setattr(cws, 'y', ListField(y)) return cws
[docs] def check(self, indices, new_items, y_new_items=None): r""" Check whether the position of change is legal. :param list indices: The list of the pos need to be changed. :param list new_items: The list of the item need to be changed. :param list y_new_items: The list of the mask info need to be changed. :return three list: legal position, change items, change labels. """ assert len(indices) == len(new_items) legal_indices = [] legal_items = [] mask_change = [] legal_y = [] mask = self.mask for i in range(len(indices)): flag = True if isinstance(indices[i], list): for j in range(indices[i][0], indices[i][1]): if mask[j] != ORIGIN: flag = False break else: if mask[indices[i]] != ORIGIN: flag = False if flag: legal_indices.append(indices[i]) legal_items.append(new_items[i]) if y_new_items: legal_y.append(y_new_items[i]) if isinstance(new_items[i], list): change = [] for k in new_items[i]: change += [MODIFIED_MASK] * len(k) mask_change.append(change) else: mask_change.append([MODIFIED_MASK] * len(new_items[i])) return legal_indices, legal_items, mask_change, legal_y
[docs] @staticmethod def get_labels(words): r""" Get the label of the word. :param str words: The word you want to get labels. :return list: the label of the words. """ assert isinstance(words, str), \ "The type of words must be str, not {0}".format(type(words)) if len(words) == 1: return ['S'] return ['B'] + ['M'] * (len(words) - 2) + ['E']