Source code for textflint.generation_layer.transformation.UT.keyboard

r"""
KeyboardTransformation  Class
============================================

"""

__all__ = ["Keyboard"]

import os
import re
import random

from ....common.settings import KEYBOARD_RULES
from ...transformation import WordSubstitute
from ....common.utils.install import download_if_needed
from ....common.utils.load import json_loader


class KeyboardRules:
    def __init__(
        self,
        special_char=True,
        numeric=True,
        upper_case=True,
        lang="en",
        rules_path=None
    ):
        self.special_char = special_char
        self.numeric = numeric
        self.upper_case = upper_case
        self.lang = lang
        self.rules = self.get_rules(
            model_path=rules_path,
            special_char=special_char,
            numeric=numeric,
            upper_case=upper_case,
            lang=lang)

    def predict(self, data):
        return self.rules[data]

    # TODO: Extending to 2 keyboard distance
    @classmethod
    def get_rules(
            cls,
            model_path,
            special_char=True,
            numeric=True,
            upper_case=True,
            lang="en"):
        model_path = download_if_needed(model_path)
        if not os.path.exists(model_path):
            raise ValueError(
                'The model_path does not exist. Please check "{}"'.format(
                    model_path))

        mapping = json_loader(model_path)

        result = {}

        for key, values in mapping.items():
            # Skip records if key is numeric while include_numeric is false
            if not numeric and re.match("^[0-9]*$", key):
                continue
            # skip record if key is special character while include_spec is
            # false
            if not special_char and not re.match("^[a-z0-9]*$", key):
                continue

            result[key] = []
            result[key.upper()] = []

            for value in values:
                # Skip record if value is numeric while include_numeric is
                # false
                if not numeric and re.match("^[0-9]*$", value):
                    continue

                # skip record if value is special character while include_spec
                # is false
                if not special_char and not re.match("^[a-z0-9]*$", value):
                    continue

                result[key].append(value)

                if upper_case:
                    result[key].append(value.upper())
                    result[key.upper()].append(value)
                    result[key.upper()].append(value.upper())

        clean_result = {}
        for key, values in result.items():
            # clear empty mapping
            if len(values) == 0:
                continue

            # de-duplicate
            values = [v for v in values if v != key]
            values = sorted(list(set(values)))

            clean_result[key] = values

        return clean_result


[docs]class Keyboard(WordSubstitute): r""" Transformation that simulate typo error by random values. https://arxiv.org/pdf/1711.02173.pdf For example, people may type i as o incorrectly.\ One keyboard distance is leveraged to replace character by possible keyboard error. """
[docs] def __init__( self, min_char=4, trans_min=1, trans_max=10, trans_p=0.2, stop_words=None, rules_path=None, include_special_char=True, include_numeric=True, include_upper_case=True, lang="en", **kwargs ): r""" :param int min_char: If word less than this value, do not draw word for augmentation :param int trans_min: Minimum number of character will be augmented. :param int trans_max: Maximum number of character will be augmented. If None is passed, number of augmentation is calculated via aup_char_p.If calculated result from aug_p is smaller than aug_max, will use calculated result from aup_char_p. Otherwise, using aug_max. :param float trans_p: Percentage of character (per token) will be augmented. :param list stop_words: List of words which will be skipped from augment operation. :param str rules_path: Loading customize model from file system :param bool include_special_char: Include special character :param bool include_numeric: If True, numeric character may be included in augmented data. :param bool include_upper_case: If True, upper case character may be included in augmented data. :param str lang: Indicate built-in language model. Default value is 'en'. Possible values are 'en' and 'th'. If custom model is used (passing model_path), this value will be ignored. """ super().__init__( min_char=min_char, trans_min=trans_min, trans_max=trans_max, trans_p=trans_p, stop_words=stop_words) self.include_special_char = include_special_char self.include_numeric = include_numeric self.include_upper_case = include_upper_case self.include_lower_case = True self.lang = lang if rules_path is None: if lang != 'en': raise ValueError( 'Only support en now. You may provide the keyboard mapping ' 'such that we can support "{}"'.format(lang)) self.rules_path = KEYBOARD_RULES else: self.rules_path = rules_path self.rules = self.get_rules( include_special_char, include_numeric, include_upper_case, lang, self.rules_path)
def __repr__(self): return 'Keyboard' def skip_aug(self, tokens, mask, pos=None): return self.pre_skip_aug(tokens, mask) def _get_candidates(self, word, n=5, **kwargs): r""" Get a list of transformed tokens. Default one word replace one char :param str word: token word to transform. :param int n: number of transformed tokens to generate. :param kwargs: :return list replaced_tokens: replaced tokens list """ replaced_tokens = [] chars = self.token2chars(word) valid_chars_idxes = [ idx for idx in range( len(chars)) if chars[idx] in self.rules.rules and len( self.rules.predict( chars[idx])) > 0] if not valid_chars_idxes: return [] # putback sampling replace_char_idxes = [ random.sample( valid_chars_idxes, 1)[0] for i in range(n)] replace_idx_dic = {} for idx in set(replace_char_idxes): replace_idx_dic[idx] = replace_char_idxes.count(idx) for replace_idx in replace_idx_dic: sample_num = replace_idx_dic[replace_idx] cand_chars = self.sample_num( self.rules.predict( chars[replace_idx]), sample_num) for cand_char in cand_chars: replaced_tokens.append( self.chars2token( chars[:replace_idx] + [cand_char] + chars[ replace_idx + 1:])) return replaced_tokens @classmethod def get_rules( cls, special_char=True, numeric=True, upper_case=True, lang="en", rules_path=None ): return KeyboardRules( special_char=special_char, numeric=numeric, upper_case=upper_case, lang=lang, rules_path=rules_path )