Source code for textflint.generation_layer.subpopulation.subpopulation

r"""
SubPopulation Abstract Class
============================================

"""

import math
from tqdm import tqdm
from abc import ABC, abstractmethod

from ...common.preprocess import EnProcessor


[docs]class SubPopulation(ABC): r""" An abstract class for extracting subset of examples. """ text_processor = EnProcessor() def __init__( self, intervals=None, **kwargs ): self.intervals = intervals def __repr__(self): return "SubPopulation"
[docs] def score(self, sample, field, **kwargs): r""" Score the sample :param sample: data sample :param str|list field: field str :param kwargs: :return int: score for sample """ if not isinstance(field, list): fields = [field] else: fields = field return self._score(sample, fields, **kwargs)
@abstractmethod def _score(self, samle, fields, **kwargs): r""" Score the sample :param sample: data sample :param list fields: list of field str :param kwargs: :return int: score for sample """ raise NotImplementedError
[docs] def get_slice(self, scores, dataset): r""" Pick up samples based on scores :param list scores: list of int :param dataset: Dataset :return: subset samples """ indexes = sorted(range(len(scores)), key=lambda x: scores[x]) sort_samples = [dataset[idx] for idx in indexes] left_bound = self.normalize_bound(self.intervals[0], len(dataset)) right_bound = self.normalize_bound(self.intervals[1], len(dataset)) sub_samples = sort_samples[left_bound:right_bound] return sub_samples
[docs] def slice_population(self, dataset, fields, **kwargs): r""" Extract a subset of samples. :param dataset: Dataset :param list fields: field str list :param kwargs: :return: Subset Dataset """ scores = [] for sample in tqdm(dataset): scores.append(self.score(sample, fields)) sub_samples = self.get_slice(scores, dataset) new_dataset = dataset.new_dataset() new_dataset.extend(sub_samples) return new_dataset
[docs] @staticmethod def normalize_bound(limit, size): r""" Normalize the bound of slice :param str|float|int limit: left_bound or right_bound for intervals can be percentile like 10%, 20% can be float between 0 and 1 like 0.3 can be int index like 50 :param size: the size of samples :return int : bound """ if isinstance(limit, str) and limit.endswith("%"): limit = float(limit.replace("%", "")) / 100 return math.floor(limit * size) elif isinstance(limit, float): return math.floor(limit * size) elif isinstance(limit, int): return limit else: raise NotImplementedError