# Source code for textflint.generation_layer.generator.re_generator

r"""
REGenerator class for generating RE samples.
"""
__all__ = ["REGenerator"]
from .generator import Generator
from tqdm import tqdm
from ...common import logger
from ...input_layer.dataset import Dataset
from ...common.settings import TASK_TRANSFORMATION_PATH, \
    ALLOWED_TRANSFORMATIONS, TASK_SUBPOPULATION_PATH, ALLOWED_SUBPOPULATIONS


class REGenerator(Generator):
    r"""
    Generator that produces transformed samples for the RE task.

    Construction is delegated unchanged to :class:`Generator`; this class
    supplies ``'RE'`` as the default task name and defines
    :meth:`generate_by_transformations`.

    """

    def __init__(
        self,
        task='RE',
        max_trans=1,
        fields='x',
        trans_methods=None,
        trans_config=None,
        return_unk=True,
        sub_methods=None,
        sub_config=None,
        attack_methods=None,
        validate_methods=None,
        **kwargs
    ):
        r"""
        Forward every argument, as given, to ``Generator.__init__``.

        :param task: task name, defaults to ``'RE'``
        :param max_trans: forwarded to the base class; later passed as ``n``
            to each transformation's ``transform`` call
        :param fields: forwarded to the base class; later passed as
            ``field`` to each transformation's ``transform`` call
        :param trans_methods: forwarded to the base class
        :param trans_config: forwarded to the base class
        :param return_unk: forwarded to the base class
        :param sub_methods: forwarded to the base class
        :param sub_config: forwarded to the base class
        :param attack_methods: forwarded to the base class
        :param validate_methods: forwarded to the base class
        :param kwargs: any extra keyword arguments, forwarded to the base
            class

        """
        super().__init__(
            task=task,
            max_trans=max_trans,
            fields=fields,
            trans_methods=trans_methods,
            trans_config=trans_config,
            return_unk=return_unk,
            sub_methods=sub_methods,
            sub_config=sub_config,
            attack_methods=attack_methods,
            validate_methods=validate_methods,
            **kwargs
        )

    def generate_by_transformations(self, dataset, **kwargs):
        r"""
        Apply each configured transformation to every sample of "dataset".

        This is a generator function: one result tuple is yielded per
        transformation object, after the whole dataset has been processed
        with it.

        :param ~textflint.input_layer.dataset.Dataset dataset: the original
            dataset ready for transformation or subpopulation
        :param kwargs: accepted for interface compatibility; not used here
        :return: yield ``(original_samples, generated_samples, name)`` where
            ``original_samples`` holds the samples for which the
            transformation returned a non-empty result,
            ``generated_samples`` holds all the transformed samples, and
            ``name`` is ``repr()`` of the transformation object.

        """
        self.prepare(dataset)
        dataset.init_iter()
        transform_objs = self._get_flint_objs(
            self.transform_methods,
            TASK_TRANSFORMATION_PATH,
            ALLOWED_TRANSFORMATIONS
        )

        for trans_obj in transform_objs:
            logger.info('******Start {0}!******'.format(trans_obj))
            generated_samples = dataset.new_dataset()
            original_samples = dataset.new_dataset()

            # initialize current index of dataset
            dataset.init_iter()

            for index in tqdm(range(len(dataset))):
                # The (up to) two samples that follow the current one are
                # handed to the transformation alongside the sample itself.
                concat_samples = dataset[index + 1: index + 3]
                sample = dataset[index]
                # default return list of samples
                trans_rst = trans_obj.transform(
                    sample,
                    n=self.max_trans,
                    field=self.fields,
                    concat_samples=concat_samples
                )
                # Keep the original only when the transformation produced
                # at least one new sample.
                if trans_rst:
                    generated_samples.extend(trans_rst)
                    original_samples.append(sample)

            yield original_samples, generated_samples, repr(trans_obj)
            # Runs only once the consumer asks for the next item.
            logger.info('******Finish {0}!******'.format(trans_obj))