# Source code for textflint.generation_layer.transformation.transformation

"""
Transformation Abstract Class
============================================
"""
__all__ = ["Transformation"]
import random
from abc import ABC, abstractmethod

from ...common.utils import logger
from ...common.utils.list_op import trade_off_sub_words
from ...common.utils.error import FlintError
from ...common.preprocess.en_processor import EnProcessor


class Transformation(ABC):
    r"""
    An abstract class for transforming a sequence of text to produce a list
    of potential adversarial examples.

    Subclasses implement :meth:`_transform`. Callers use :meth:`transform`,
    which handles field selection, combination of multi-field results,
    error handling and filtering of unchanged / illegal outputs.

    """
    # Shared English text processor, instantiated once when the class is
    # defined and reachable from every instance as ``self.processor``.
    processor = EnProcessor()

    def __init__(
        self,
        **kwargs
    ):
        # The base class accepts and ignores arbitrary keyword arguments.
        pass

    def __repr__(self):
        return 'Transformation'

    def transform(self, sample, n=1, field='x', **kwargs):
        r"""
        Transform data sample to a list of Sample.

        :param ~textflint.input_layer.component.sample.Sample sample: Data
            sample for augmentation.
        :param int n: Max number of unique augmented output, default is 1.
        :param str|list field: Indicate which fields to apply
            transformations. Duplicate field names are ignored; the first
            occurrence determines the order.
        :param dict **kwargs: other auxiliary params, forwarded to
            :meth:`_transform`.
        :return: list of Sample. Results whose ``is_origin`` is true or whose
            ``is_legal()`` is false are dropped. An empty list is returned
            when ``n < 1`` or when a :class:`FlintError` is raised while
            transforming.
        :raises FlintError: when any other exception is raised while
            transforming; the original exception is chained as the cause.

        """
        if n < 1:
            return []

        if isinstance(field, list):
            fields = field
        else:
            assert isinstance(field, str), \
                "The type of field must be a str or list not {0}".format(
                    type(field))
            fields = [field]
        # De-duplicate while keeping the caller's order. A plain ``set`` makes
        # the order depend on (randomized) string hashing, so the fields
        # handed to ``trade_off_sub_words`` -- and hence the output -- could
        # differ between runs.
        fields = list(dict.fromkeys(fields))

        try:
            if len(fields) == 1:
                transform_results = self._transform(
                    sample, n=n, field=fields[0], **kwargs)
            else:
                # Transform each field independently and keep only the
                # [value, mask] pair of that field from every result.
                per_field_results = []
                for cur_field in fields:
                    per_field_results.append(
                        [[trans.get_value(cur_field),
                          trans.get_mask(cur_field)]
                         for trans in self._transform(
                             sample, n=n, field=cur_field, **kwargs)])
                # NOTE(review): trade_off_sub_words presumably combines the
                # per-field candidates into at most ``n`` items, each holding
                # one [value, mask] pair per returned field -- confirm.
                trans_items, trans_fields = trade_off_sub_words(
                    per_field_results, fields, n=n)
                # Apply every combination to the original sample in a single
                # replace_fields call.
                transform_results = [
                    sample.replace_fields(
                        trans_fields,
                        [pair[0] for pair in trans_item],
                        field_masks=[pair[1] for pair in trans_item])
                    for trans_item in trans_items]
        except FlintError as e:
            # Expected textflint errors: log and yield no samples.
            logger.error(str(e))
            return []
        except Exception as e:
            logger.error(str(e))
            raise FlintError("You hit an internal error. "
                             "Please open an issue in "
                             "https://github.com/textflint/textflint"
                             " to report it.") from e

        if not transform_results:
            return []
        return [result for result in transform_results
                if not result.is_origin and result.is_legal()]

    @abstractmethod
    def _transform(self, sample, n=1, field='x', **kwargs):
        r"""
        Returns a list of all possible transformations for ``sample``.

        :param ~textflint.input_layer.component.sample.Sample sample: Data
            sample for augmentation.
        :param int n: Max number of unique augmented output, default is 1.
        :param str field: Indicate which field to apply transformations.
        :param dict **kwargs: other auxiliary params.
        :return: list of Sample

        """
        raise NotImplementedError

    @classmethod
    def sample_num(cls, x, num):
        r"""
        Randomly draw up to ``num`` elements from ``x`` without replacement.

        :param list|tuple|int x: population to sample from. A list or tuple
            is sampled directly; an int ``k`` is treated as ``range(0, k)``.
        :param int num: desired number of samples. It is capped at the
            population size; a negative value is treated as 0.
        :return: list of at most ``num`` samples, or None when ``x`` is of
            any other type.

        """
        if isinstance(x, (list, tuple)):
            return random.sample(x, max(0, min(num, len(x))))
        if isinstance(x, int):
            # max(0, ...) also covers a negative ``x`` (empty range).
            return random.sample(range(0, x), max(0, min(num, x)))
        return None