# Source code for textflint.generation_layer.validator.validator

r"""
Validator Class
=====================================
"""
import numpy
from abc import ABC, abstractmethod

from ...input_layer.dataset.dataset import *
__all__ = ['Validator']


class Validator(ABC):
    r"""
    An abstract class that computes the semantic similarity score between
    original text and adversarial texts.

    Subclasses implement :meth:`validate`; the :attr:`score` property then
    pairs every transformed sample with the origin sample that has the same
    ``sample_id`` and averages ``validate`` over the requested fields.

    :param ~textflint.input_layer.dataset origin_dataset:
        the dataset of origin samples
    :param ~textflint.input_layer.dataset trans_dataset:
        the dataset of transformed (trans) samples
    :param str|list fields: the name(s) of the field(s) to compare.
    :param bool need_tokens: if True, samples are read with ``get_words``
        instead of ``get_text`` before being passed to :meth:`validate`

    """

    def __init__(
        self,
        origin_dataset,
        trans_dataset,
        fields,
        need_tokens=False
    ):
        # NOTE: asserts are kept (rather than raising) so callers that
        # already expect AssertionError keep working.
        assert isinstance(origin_dataset, Dataset), \
            f"Input must be a {Dataset}"
        assert isinstance(trans_dataset, Dataset), \
            f"Input must be a {Dataset}"
        assert len(origin_dataset) and len(trans_dataset), \
            "origin_dataset and trans_dataset can not be empty."
        assert isinstance(origin_dataset[0], type(trans_dataset[0])), \
            "The type of origin sample and trans sample must be same."

        self.ori_dataset = origin_dataset
        self.trans_dataset = trans_dataset
        # Maps origin sample_id -> index in ori_dataset; built by check_data.
        self.id2loc = None
        # Lazily computed cache for the ``score`` property.
        self._score = None
        self.fields = fields if isinstance(fields, list) else [fields]
        # Set before check_data() so the instance is fully initialised by
        # the time any validation code runs.
        self.need_tokens = need_tokens
        self.check_data()

    @abstractmethod
    def validate(self, transformed_text, reference_text):
        r"""
        Calculate the score of one transformed/origin pair.

        Both arguments are whatever the sample returns for the field:
        ``get_words(field)`` when ``need_tokens`` is set, otherwise
        ``get_text(field)``.

        :param transformed_text: value taken from the transformed sample
        :param reference_text: value taken from the matching origin sample
        :return float: the score of the pair

        """
        raise NotImplementedError()

    def _check_fields(self, sample, index):
        r"""
        Ensure ``sample`` exposes every field in ``self.fields``.

        :param sample: the sample to probe with ``get_value``
        :param int index: position of the sample, used in the error message
        :raises ValueError: if ``get_value`` raises ``AttributeError``

        """
        field = None
        try:
            for field in self.fields:
                sample.get_value(field)
        except AttributeError as err:
            raise ValueError(f'The {index} sample does not '
                             f'have the attribute {field}.') from err

    def check_data(self):
        r"""
        Check whether the input data is legal and (re)build ``id2loc``.

        :raises ValueError: if two origin samples share a ``sample_id``,
            if either dataset mixes sample types, if a trans sample has no
            origin sample with the same ``sample_id``, or if a sample lacks
            one of the requested fields.

        """
        self.id2loc = {}
        first_ori = self.ori_dataset[0]

        for i, ori_data in enumerate(self.ori_dataset):
            if ori_data.sample_id in self.id2loc:
                raise ValueError('There are two origin samples have same id.')
            self.id2loc[ori_data.sample_id] = i
            if not isinstance(first_ori, type(ori_data)):
                raise ValueError('There are at least two type of '
                                 'origin sample in the dataset.')
            self._check_fields(ori_data, i)

        first_trans = self.trans_dataset[0]

        for i, trans_data in enumerate(self.trans_dataset):
            if trans_data.sample_id not in self.id2loc:
                raise ValueError('There is no origin sample '
                                 'can match trans sample')
            if not isinstance(first_trans, type(trans_data)):
                # The message names the trans dataset; it used to say
                # "origin sample", which pointed at the wrong input.
                raise ValueError('There are at least two type of '
                                 'trans sample in the dataset.')
            self._check_fields(trans_data, i)

    @property
    def score(self):
        r"""
        Calculate the score of every transformed sample.

        The result is computed on first access and cached. For each
        transformed sample, :meth:`validate` is called once per field and
        the per-field scores are averaged with ``numpy.mean``.

        :return list: one averaged score per sample in ``trans_dataset``,
            in dataset order

        """
        if self._score is None:
            self.check_data()
            scores = []

            for trans_sample in self.trans_dataset:
                ori_sample = \
                    self.ori_dataset[self.id2loc[trans_sample.sample_id]]
                # Accumulate across *all* fields of this sample; resetting
                # this inside the field loop would average only the last
                # field.
                field_scores = []

                for field in self.fields:
                    if self.need_tokens:
                        trans = trans_sample.get_words(field)
                        ori = ori_sample.get_words(field)
                    else:
                        trans = trans_sample.get_text(field)
                        ori = ori_sample.get_text(field)
                    field_scores.append(self.validate(trans, ori))

                scores.append(numpy.mean(field_scores))

            self._score = scores

        return self._score