# Source code for textflint.generation_layer.generator.ner_generator
r"""
NER Generator aims to apply NER data generation function
==========================================================
"""
__all__ = ["NERGenerator"]
from tqdm import tqdm
from ...common import logger
from .generator import Generator
from ...input_layer.dataset import Dataset
from ...common.settings import TASK_TRANSFORMATION_PATH, \
ALLOWED_TRANSFORMATIONS, TASK_SUBPOPULATION_PATH, ALLOWED_SUBPOPULATIONS
# Settings lookup keyed by generation mode ("transformation" or
# "subpopulation"): for each mode, the task-path constant and the
# allowed-methods constant imported from ``common.settings``.
Flint = {
    "transformation": dict(
        task_path=TASK_TRANSFORMATION_PATH,
        allowed_methods=ALLOWED_TRANSFORMATIONS,
    ),
    "subpopulation": dict(
        task_path=TASK_SUBPOPULATION_PATH,
        allowed_methods=ALLOWED_SUBPOPULATIONS,
    ),
}
class NERGenerator(Generator):
    r"""
    Generator that applies data generation methods to NER datasets.

    Configuration is delegated entirely to :class:`Generator`; this subclass
    supplies NER-oriented default arguments (``task='NER'``,
    ``fields='text'``) and overrides :meth:`generate_by_transformations` so
    that each transformation also receives the samples following the
    current one.
    """

    def __init__(
        self,
        task='NER',
        max_trans=1,
        fields='text',
        trans_methods=None,
        trans_config=None,
        return_unk=True,
        sub_methods=None,
        sub_config=None,
        attack_methods=None,
        validate_methods=None,
        **kwargs
    ):
        r"""
        Forward every argument, unchanged, to :class:`Generator`.

        The only NER-specific part of this constructor is its default
        values (``task='NER'``, ``fields='text'``).

        :param str task: task name, ``'NER'`` by default.
        :param int max_trans: presumably stored by the base class as
            ``self.max_trans``, which is passed to each transformation as
            ``n`` in :meth:`generate_by_transformations`.
        :param fields: presumably stored as ``self.fields`` and passed to
            each transformation as ``field``; ``'text'`` by default.
        :param trans_methods: forwarded to :class:`Generator`.
        :param trans_config: forwarded to :class:`Generator`.
        :param bool return_unk: forwarded to :class:`Generator`.
        :param sub_methods: forwarded to :class:`Generator`.
        :param sub_config: forwarded to :class:`Generator`.
        :param attack_methods: forwarded to :class:`Generator`.
        :param validate_methods: forwarded to :class:`Generator`.
        :param kwargs: extra keyword arguments forwarded to
            :class:`Generator`.
        """
        super().__init__(
            task=task,
            max_trans=max_trans,
            fields=fields,
            trans_methods=trans_methods,
            trans_config=trans_config,
            return_unk=return_unk,
            sub_methods=sub_methods,
            sub_config=sub_config,
            attack_methods=attack_methods,
            validate_methods=validate_methods,
            **kwargs
        )

    def generate_by_transformations(self, dataset, **kwargs):
        r"""
        Apply every configured transformation to ``dataset``, one at a time.

        For each transformation object, every sample of ``dataset`` is
        transformed together with the (up to) two samples that follow it,
        which are passed as ``concat_samples``.  The results for that
        transformation are yielded before the next transformation starts.

        :param ~textflint.input_layer.dataset.Dataset dataset: the original
            dataset ready for transformation.
        :param kwargs: not used by this method; presumably accepted to match
            the base-class signature.
        :return: a generator yielding one tuple per transformation,
            ``(original_samples, generated_samples, name)``, where
            ``original_samples`` holds each sample whose transformation
            result was non-empty, ``generated_samples`` holds every
            transformed sample, and ``name`` is ``repr()`` of the
            transformation object.
        """
        self.prepare(dataset)
        transform_objs = self._get_flint_objs(
            self.transform_methods,
            TASK_TRANSFORMATION_PATH,
            ALLOWED_TRANSFORMATIONS)

        for obj_id, trans_obj in enumerate(transform_objs):
            logger.info('******Start {0}!******'.format(trans_obj))
            generated_samples = dataset.new_dataset()
            original_samples = dataset.new_dataset()
            # NOTE(review): presumably resets the dataset's internal iterator;
            # the loop below indexes the dataset directly, so this may be
            # redundant -- kept to preserve behaviour.
            dataset.init_iter()

            for index in tqdm(range(len(dataset))):
                # Up to two following samples (fewer near the end, assuming
                # list-like slice semantics), presumably for transformations
                # that concatenate neighbouring sentences.
                concat_samples = dataset[index + 1: index + 3]
                sample = dataset[index]
                # The transformation is expected to return a list of
                # samples; falsy results (e.g. an empty list) are skipped.
                trans_rst = trans_obj.transform(sample,
                                                n=self.max_trans,
                                                field=self.fields,
                                                concat_samples=concat_samples)
                if trans_rst:
                    generated_samples.extend(trans_rst)
                    # One original is recorded per call, while trans_rst may
                    # hold several samples (n=self.max_trans is requested), so
                    # the two datasets need not have the same length.
                    original_samples.append(sample)

            yield original_samples, generated_samples, repr(trans_obj)
            # Everything below runs only when the consumer asks for the
            # next item of this generator.
            logger.info('******Finish {0}!******'.format(trans_obj))
            # Drop both references so this transformation object can be
            # garbage-collected before the next one is used.
            transform_objs[obj_id] = None
            del trans_obj