Source code for textflint.generation_layer.generator.absa_generator
r"""
Generator for ABSA Task
============================================

Defines :class:`ABSAGenerator`, a ``Generator`` subclass that, besides
running transformations, prepares the extra text handed to every
transformation call (it is built for the ``AddDiff`` transformation).
"""
# The generator class is the only name exported by ``from ... import *``.
__all__ = ['ABSAGenerator']
from tqdm import tqdm
from .generator import Generator
from ...common.utils.logger import logger
from ...input_layer.component.sample import ABSASample
from ..transformation.transformation import Transformation
from ...common.utils.load import absa_dict_loader
from ...common.utils.install import download_if_needed
from ...common.settings import TASK_TRANSFORMATION_PATH, \
ALLOWED_TRANSFORMATIONS, TASK_SUBPOPULATION_PATH, \
ALLOWED_SUBPOPULATIONS, ABSA_TRAIN_RESTAURANT_PATH, \
ABSA_TRAIN_LAPTOP_PATH
# Per-kind lookup table: for "transformation" and "subpopulation" it pairs
# the settings entry holding the task path with the one holding the
# allowed method names.
Flint = dict(
    transformation=dict(
        task_path=TASK_TRANSFORMATION_PATH,
        allowed_methods=ALLOWED_TRANSFORMATIONS,
    ),
    subpopulation=dict(
        task_path=TASK_SUBPOPULATION_PATH,
        allowed_methods=ALLOWED_SUBPOPULATIONS,
    ),
)
class ABSAGenerator(Generator):
    r"""
    Generator for the ABSA task.

    In addition to the base ``Generator`` behaviour, it prepares the extra
    text consumed by the ``AddDiff`` transformation. The training corpus the
    extra text is mined from is chosen by ``dataset_config``
    (``'restaurant'`` or ``'laptop'``).

    """
    def __init__(
        self,
        task='ABSA',
        fields='sentence',
        max_trans=1,
        trans_methods=None,
        trans_config=None,
        return_unk=True,
        sub_methods=None,
        sub_config=None,
        attack_methods=None,
        validate_methods=None,
        dataset_config='restaurant',
        **kwargs
    ):
        r"""
        Every argument except ``dataset_config`` is forwarded unchanged to
        ``Generator.__init__``.

        :param trans_methods: transformation method names; may be ``None``,
            a single name or an iterable of names
        :param dataset_config: ``'restaurant'`` or ``'laptop'`` selects the
            training corpus used to build the AddDiff extra text. ``None``
            disables AddDiff by removing it from
            ``ALLOWED_TRANSFORMATIONS['ABSA']``. Any other value leaves
            ``extra_text`` empty.

        """
        self.dataset_config = dataset_config
        self.nlp = Transformation.processor.nlp
        self.extra_text = []
        super().__init__(
            task=task,
            max_trans=max_trans,
            fields=fields,
            trans_methods=trans_methods,
            trans_config=trans_config,
            return_unk=return_unk,
            sub_methods=sub_methods,
            sub_config=sub_config,
            attack_methods=attack_methods,
            validate_methods=validate_methods,
            **kwargs
        )
        self.transform_methods = trans_methods

        if self.dataset_config is None:
            logger.info(
                '******No config of dataset is available for AddDiff!******'
            )
            # Without a corpus AddDiff cannot work, so it is withdrawn from
            # the shared (module-level) list of allowed ABSA transformations.
            if 'AddDiff' in ALLOWED_TRANSFORMATIONS['ABSA']:
                ALLOWED_TRANSFORMATIONS['ABSA'].remove('AddDiff')
            return

        # The corpus is only downloaded and parsed when AddDiff is requested.
        if not self._uses_add_diff(trans_methods):
            return

        if self.dataset_config == 'restaurant':
            corpus_path = ABSA_TRAIN_RESTAURANT_PATH
        elif self.dataset_config == 'laptop':
            corpus_path = ABSA_TRAIN_LAPTOP_PATH
        else:
            # Unknown dataset name: keep the empty extra text.
            return

        self.examples = absa_dict_loader(download_if_needed(corpus_path))
        self.extra_text = self.get_extra_text()

    @staticmethod
    def _uses_add_diff(trans_methods):
        r"""
        Tell whether any requested transformation name contains 'AddDiff'.

        :param trans_methods: ``None``, a single method name or an iterable
            of method names
        :return bool: True if an AddDiff transformation is requested

        """
        if not trans_methods:
            # Covers the default ``None`` as well as empty containers.
            return False
        if isinstance(trans_methods, str):
            # A bare name would otherwise be iterated character by character.
            trans_methods = [trans_methods]
        return any(
            isinstance(method, str) and 'AddDiff' in method
            for method in trans_methods
        )

    @staticmethod
    def get_extra_sentence(term_list, term_id, phrases):
        r"""
        Get the extra sentence from phrases text.

        A phrase is kept when its text contains the target term and every
        opinion word of that term, and contains none of the other terms of
        ``term_list``.

        :param dict term_list: term list; each entry provides ``'term'``
            and ``'opinion_words'``
        :param str term_id: id of the target term inside ``term_list``
        :param list phrases: phrase list; each phrase is an iterable of
            tokens exposing ``text_with_ws``
        :return list: matching phrase texts, shortest first

        """
        term = term_list[term_id]['term']
        opinions = term_list[term_id]['opinion_words']
        other_terms = [
            term_list[other_id]['term']
            for other_id in term_list
            if other_id != term_id
        ]

        extra_sentence = []
        for phrase in phrases:
            phrase_text = ''.join(
                token.text_with_ws for token in phrase).strip()
            if not all(opinion in phrase_text for opinion in opinions):
                continue
            if term not in phrase_text:
                continue
            # Drop phrases that also talk about another term.
            if any(other in phrase_text for other in other_terms):
                continue
            extra_sentence.append(phrase_text)

        # sorted() is stable, so equally long phrases keep their input order.
        return sorted(extra_sentence, key=len)

    def get_extra_text(self):
        r"""
        Get extra text from training dataset.

        Reads ``self.examples`` and, for every term that has at least one
        candidate phrase, records the shortest one under the term polarity.

        :return: dict with keys ``'positive'``, ``'negative'`` and
            ``'neutral'``; each value is a list of
            ``(lower-cased term, [lower-cased phrase])`` tuples

        """
        extra_text = {'positive': [], 'negative': [], 'neutral': []}
        logger.info('******Prepare extra {0} corpus for AddDiff!******'
                    .format(self.dataset_config))

        for text_id in tqdm(self.examples):
            text_sample = ABSASample(self.examples[text_id])
            term_list = text_sample.term_list
            text_doc = self.nlp(text_sample.sentence.text)

            # Candidate phrases: the subtree of each token, unless the
            # subtree is just a single token.
            phrase_list = []
            for token in text_doc:
                subtree = list(token.subtree)
                if len(subtree) > 1:
                    phrase_list.append(subtree)

            for term_id in term_list:
                term = term_list[term_id]['term']
                term_polarity = term_list[term_id]['polarity']
                candidates = self.get_extra_sentence(
                    term_list, term_id, phrase_list)
                if not candidates:
                    continue
                # Terms with any other polarity label are ignored.
                if term_polarity in extra_text:
                    # candidates are sorted by length: [0] is the shortest.
                    extra_text[term_polarity].append(
                        (term.lower(), [candidates[0].lower()]))

        return extra_text

    def generate_by_transformations(self, dataset, **kwargs):
        r"""
        Generate samples by a list of transformation methods.

        This is a generator function: it yields one result per
        transformation object.

        :param dataset: the input dataset
        :return: (original samples, new samples, generated function string)

        """
        self.prepare(dataset)

        for trans_obj in self._get_flint_objs(
            self.transform_methods,
            TASK_TRANSFORMATION_PATH,
            ALLOWED_TRANSFORMATIONS
        ):
            # initialize current index of dataset
            dataset.init_iter()
            logger.info('******Start {0}!******'.format(trans_obj))
            generated_samples = dataset.new_dataset()
            original_samples = dataset.new_dataset()

            for sample in tqdm(dataset):
                # default return list of samples
                trans_rst = trans_obj.transform(
                    sample,
                    n=self.max_trans,
                    field=self.fields,
                    extra_text=self.extra_text)
                # Only samples that produced a transformation are reported.
                if trans_rst:
                    generated_samples.extend(trans_rst)
                    original_samples.append(sample)

            yield original_samples, generated_samples, repr(trans_obj)
            logger.info('******Finish {0}!******'.format(trans_obj))