# Source code for textflint.generation_layer.transformation.UT.ocr

r"""
OCR transformation that applies OCR error simulation to textual input.
======================================================================

"""

__all__ = ['Ocr']

import random

from ...transformation import WordSubstitute


class OcrRules:
    r"""
    Lookup table of characters that are treated as OCR look-alikes.

    ``rules`` maps a character to the list of characters it may be replaced
    with. Every pair listed in the seed table is also registered in the
    opposite direction.

    """

    def __init__(self):
        self.rules = self.get_rules()

    def predict(self, data):
        r"""
        Return the replacement candidates registered for ``data``.

        :param str data: the character to look up.
        :return list: candidate characters; raises ``KeyError`` when
            ``data`` has no entry in the table.
        """
        candidates = self.rules[data]
        return candidates

    # TODO: Read from file
    @classmethod
    def get_rules(cls):
        r"""
        Build the look-alike table.

        :return dict: character -> list of look-alike characters, holding
            the seed entries first and the reverse entries after them.
        """
        forward = {
            '0': ['8', '9', 'o', 'O', 'D'],
            '1': ['4', '7', 'l', 'I'],
            '2': ['z', 'Z'],
            '5': ['8'],
            '6': ['b'],
            '8': ['s', 'S', '@', '&'],
            '9': ['g'],
            'o': ['u'],
            'r': ['k'],
            'C': ['G'],
            'O': ['D', 'U'],
            'E': ['B']
        }

        # Start from the seed entries, keeping their order.
        table = {char: list(alikes) for char, alikes in forward.items()}

        # Register each pair in the reverse direction as well, without
        # introducing duplicates.
        for char, alikes in forward.items():
            for alike in alikes:
                back_refs = table.setdefault(alike, [])
                if char not in back_refs:
                    back_refs.append(char)

        return table


class Ocr(WordSubstitute):
    r"""
    Transformation that simulates OCR errors by replacing characters with
    randomly chosen look-alike characters.

    """

    def __init__(
        self,
        min_char=1,
        trans_min=1,
        trans_max=10,
        trans_p=0.2,
        stop_words=None,
        **kwargs
    ):
        r"""
        :param int min_char: If word less than this value, do not draw word
            for augmentation
        :param int trans_min: Minimum number of character will be augmented.
        :param int trans_max: Maximum number of character will be augmented.
            If None is passed, number of augmentation is calculated via
            trans_p. If calculated result from trans_p is smaller than
            trans_max, will use calculated result from trans_p.
            Otherwise, using trans_max.
        :param float trans_p: Percentage of character (per token) will be
            augmented.
        :param list stop_words: List of words which will be skipped from
            augment operation.
        :param kwargs: accepted for call compatibility; not forwarded to
            the parent constructor.
        """
        super().__init__(
            min_char=min_char,
            trans_min=trans_min,
            trans_max=trans_max,
            trans_p=trans_p,
            stop_words=stop_words)
        self.rules = self.get_rules()

    def __repr__(self):
        return 'Ocr'

    def _has_candidates(self, char):
        r"""
        Tell whether ``char`` has at least one replacement in the rule table.

        :param str char: single character to check.
        :return bool: True if the character can be replaced.
        """
        return char in self.rules.rules \
            and len(self.rules.predict(char)) > 0

    def skip_aug(self, tokens, mask, **kwargs):
        r"""
        Select the indices of tokens that can be transformed.

        :param list tokens: tokens of the input.
        :param mask: passed through unchanged to ``pre_skip_aug``.
        :param kwargs:
        :return list token_idxes: indices, taken from the result of
            ``pre_skip_aug``, whose token holds at least one character
            with a replacement rule.
        """
        remain_idxes = self.pre_skip_aug(tokens, mask)

        return [
            idx for idx in remain_idxes
            if any(self._has_candidates(char) for char in tokens[idx])
        ]

    def _get_candidates(self, word, n=3, **kwargs):
        r"""
        Get a list of transformed tokens.

        :param str word: token word to transform.
        :param int n: number of transformed tokens to generate.
        :param kwargs:
        :return list replaced_tokens: replaced tokens list
        """
        replaced_tokens = []
        chars = self.token2chars(word)
        valid_chars_idxes = [
            idx for idx, char in enumerate(chars)
            if self._has_candidates(char)
        ]

        if not valid_chars_idxes:
            return []

        # putback sampling: the same position may be drawn several times
        replace_char_idxes = [
            random.choice(valid_chars_idxes) for _ in range(n)]
        # how many replacements are requested for each drawn position
        replace_idx_dic = {
            idx: replace_char_idxes.count(idx)
            for idx in set(replace_char_idxes)
        }

        for replace_idx, sample_num in replace_idx_dic.items():
            cand_chars = self.sample_num(
                self.rules.predict(chars[replace_idx]), sample_num)

            # each candidate yields one token that differs from ``word``
            # at exactly one position
            for cand_char in cand_chars:
                replaced_tokens.append(
                    self.chars2token(
                        chars[:replace_idx]
                        + [cand_char]
                        + chars[replace_idx + 1:]))

        return replaced_tokens

    @classmethod
    def get_rules(cls):
        r"""
        Create the rule table used to look up replacement characters.

        :return OcrRules: a new rule table instance.
        """
        return OcrRules()