textflint.generation_layer.transformation.UT.mlm_suggestion

Swaps words using a masked language model.

class textflint.generation_layer.transformation.UT.mlm_suggestion.MLMSuggestion(masked_model=None, device=None, accrue_threshold=1, max_sent_size=100, trans_min=1, trans_max=10, trans_p=0.2, stop_words=None, **kwargs)[source]

Bases: textflint.generation_layer.transformation.word_substitute.WordSubstitute

Transforms an input by replacing its tokens with words predicted by a masked language model. To speed up the transformation of long texts, each sentence is passed to the language model individually rather than the whole text.
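The sentence-level processing described above can be illustrated with a small sketch. It assumes the text is already split into tokenized sentences and that the model uses a BERT-style `[MASK]` token; the helper `mask_in_sentence` is hypothetical and is not part of textflint.

```python
def mask_in_sentence(sentences, global_index, mask_token="[MASK]"):
    """Return only the sentence containing `global_index`, with that token masked.

    `sentences` is a list of token lists; `global_index` counts tokens
    across the whole text.
    """
    offset = 0
    for sent in sentences:
        if global_index < offset + len(sent):
            masked = list(sent)                      # copy, keep input intact
            masked[global_index - offset] = mask_token
            return masked
        offset += len(sent)
    raise IndexError("token index out of range")


text = [["I", "like", "tea", "."], ["It", "is", "warm", "."]]
print(mask_in_sentence(text, 6))  # ['It', 'is', '[MASK]', '.']
```

Only the four tokens of the second sentence would be sent to the model, not all eight tokens of the text.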

__init__(masked_model=None, device=None, accrue_threshold=1, max_sent_size=100, trans_min=1, trans_max=10, trans_p=0.2, stop_words=None, **kwargs)[source]
Parameters
  • masked_model (str) – masked language model used to predict candidates

  • device (str) – device on which to run the neural network: the CPU or a specific GPU

  • accrue_threshold (int) – threshold used to pick among the BERT results

  • max_sent_size (int) – maximum size of a sentence passed to the language model

  • trans_min (int) – Minimum number of words to be augmented.

  • trans_max (int) – Maximum number of words to be augmented. If None is passed, the number of augmentations is calculated from trans_p. If the result calculated from trans_p is smaller than trans_max, the calculated result is used; otherwise trans_max is used.

  • trans_p (float) – Proportion of words to be augmented.

  • stop_words (list) – List of words to be skipped during augmentation.

get_model()[source]

Loads the masked language model used to predict candidates.

pre_calculate_allowed_tokens()[source]

Precalculates the meaningful tokens by filtering out tokens that are not alphabetic strings.

Pre-filtering speeds up the verification of the candidates' POS tags.
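A minimal sketch of the pre-filtering idea, assuming the tokenizer vocabulary is a mapping from token string to id. The function `allowed_token_ids` and the toy vocabulary are hypothetical and do not reproduce textflint's implementation.

```python
def allowed_token_ids(vocab):
    """Return the sorted ids of vocabulary entries that are purely alphabetic."""
    # str.isalpha() rejects special tokens such as "[MASK]", word pieces
    # such as "##ing", digits and punctuation.
    return sorted(idx for token, idx in vocab.items() if token.isalpha())


toy_vocab = {"[PAD]": 0, "[MASK]": 1, "cat": 2, "##ing": 3,
             "2021": 4, "run": 5, ".": 6}
print(allowed_token_ids(toy_vocab))  # [2, 5]
```

Computing this set once means later candidate checks only need a membership test on token ids.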

class textflint.generation_layer.transformation.UT.mlm_suggestion.BackTrans(from_model_name=None, to_model_name=None, device=None, **kwargs)[source]

Bases: textflint.generation_layer.transformation.transformation.Transformation

Back translation with Hugging Face translation models. A sentence can be transformed into at most one sentence.

__init__(from_model_name=None, to_model_name=None, device=None, **kwargs)[source]
Parameters
  • from_model_name (str) – model that translates the original language into the target language

  • to_model_name (str) – model that translates the target language back into the original language

  • device – device on which to run the neural network: the CPU or a specific GPU
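The round trip performed by back translation can be sketched without any model. In the sketch below, `back_translate` is a hypothetical helper, the two lookup tables stand in for the models named by from_model_name and to_model_name, and treating an unchanged round trip as "no result" is an assumption rather than documented behaviour.

```python
def back_translate(sentence, forward, backward):
    """Round-trip a sentence through a pivot language.

    `forward` and `backward` are any callables mapping str -> str.
    Returns a list holding at most one paraphrase.
    """
    pivot = forward(sentence)
    restored = backward(pivot)
    # Assumption: an unchanged round trip counts as "no transformation".
    return [restored] if restored != sentence else []


# Toy lookup tables standing in for real translation models.
en_to_de = {"The film was great.": "Der Film war toll."}
de_to_en = {"Der Film war toll.": "The movie was great."}

result = back_translate(
    "The film was great.",
    forward=lambda s: en_to_de.get(s, s),
    backward=lambda s: de_to_en.get(s, s),
)
print(result)  # ['The movie was great.']
```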

static get_device(device)[source]

Gets the GPU or CPU device.

Parameters

device (str) – device string. “cpu” selects the CPU; “cuda:0” selects the GPU with index 0.

Returns

the corresponding torch device.

class textflint.generation_layer.transformation.UT.mlm_suggestion.WordSubstitute(trans_min=1, trans_max=10, trans_p=0.1, stop_words=None, **kwargs)[source]

Bases: textflint.generation_layer.transformation.transformation.Transformation

Word replacement transformation that implements the common word replacement functionality.

__init__(trans_min=1, trans_max=10, trans_p=0.1, stop_words=None, **kwargs)[source]
Parameters
  • trans_min (int) – Minimum number of words to be augmented.

  • trans_max (int) – Maximum number of words to be augmented. If None is passed, the number of augmentations is calculated from trans_p. If the result calculated from trans_p is smaller than trans_max, the calculated result is used; otherwise trans_max is used.

  • trans_p (float) – Proportion of words to be augmented.

  • stop_words (list) – List of words to be skipped during augmentation.

  • processor (EnProcessor) –

  • get_pos (bool) – whether to pass POS tags to the _get_substitute_words API.

abstract skip_aug(tokens, mask, pos=None)[source]

Returns the indices of the tokens to be replaced.

Parameters

tokens (list) – tokenized words, or pairs of a word and its POS tag

Return list

the indices of the tokens to be replaced

is_stop_words(token)[source]

Checks whether the input word belongs to the stop-word vocabulary.

Parameters

token (str) – the input word to check

Return bool

whether the word is a stop word

pre_skip_aug(tokens, mask)[source]

Skips the tokens that are in the stop-word list or the punctuation list.

Parameters
  • tokens (list) – the list of tokens

  • mask (list) – list of mask values indicating whether each word may be substituted. ORIGIN is allowed, while TASK_MASK and MODIFIED_MASK are not.

Return list

List of the indices of tokens that may be substituted.
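The filtering described for pre_skip_aug can be sketched as follows. The numeric values of ORIGIN, TASK_MASK and MODIFIED_MASK are placeholders chosen for illustration, and `candidate_indices` is a hypothetical stand-alone function rather than the library method.

```python
import string

# Placeholder mask values; the real constants live in textflint and may differ.
ORIGIN, TASK_MASK, MODIFIED_MASK = 0, 1, 2


def candidate_indices(tokens, mask, stop_words=()):
    """Indices of tokens that may still be substituted."""
    stop = {w.lower() for w in stop_words}
    results = []
    for idx, (token, flag) in enumerate(zip(tokens, mask)):
        if flag != ORIGIN:
            continue  # task-protected or already modified
        if token.lower() in stop:
            continue  # stop word
        if not token or all(ch in string.punctuation for ch in token):
            continue  # empty or punctuation-only token
        results.append(idx)
    return results


tokens = ["The", "movie", "was", "great", "!"]
mask = [ORIGIN, TASK_MASK, ORIGIN, ORIGIN, ORIGIN]
print(candidate_indices(tokens, mask, stop_words=["the", "was"]))  # [3]
```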

get_trans_cnt(size)[source]

Gets the number of words or characters to transform.

Parameters

size (int) – the size of the target sentence

Return int

number of words to transform.
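How trans_min, trans_max and trans_p interact can be sketched as below. The sketch assumes the proportional count is truncated to an integer and then clamped to the two bounds; the free function `trans_count` is hypothetical and is not the library's code.

```python
def trans_count(size, trans_min=1, trans_max=10, trans_p=0.1):
    """Number of items to transform in a sentence of `size` tokens."""
    cnt = int(trans_p * size)          # proportional count, truncated
    if cnt < trans_min:
        return trans_min               # never fewer than trans_min
    if trans_max is not None and cnt > trans_max:
        return trans_max               # never more than trans_max
    return cnt


print(trans_count(5))     # 1: 0.1 * 5 truncates to 0, raised to trans_min
print(trans_count(50))    # 5
print(trans_count(500))   # 10: 50 is capped at trans_max
```

With `trans_max=None` the upper bound is dropped and only trans_p and trans_min apply.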

textflint.generation_layer.transformation.UT.mlm_suggestion.copy(x)[source]

Shallow copy operation on arbitrary Python objects.

See the module’s __doc__ string for more info.

class textflint.generation_layer.transformation.UT.mlm_suggestion.defaultdict

Bases: dict

defaultdict(default_factory[, …]) --> dict with default factory

The default factory is called without arguments to produce a new value when a key is not present, in __getitem__ only. A defaultdict compares equal to a dict with the same items. All remaining arguments are treated the same as if they were passed to the dict constructor, including keyword arguments.

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

copy() → a shallow copy of D.
default_factory

Factory for default value called by __missing__().

textflint.generation_layer.transformation.UT.mlm_suggestion.trade_off_sub_words(sub_words, sub_indices, trans_num=None, n=1)[source]

Selects suitable candidate words so as to maximize the number of transformed results. The words with the top n substitute-word counts are selected.

Parameters
  • sub_words (list) – list of substitute words for each legal word

  • sub_indices (list) – list of the indices of each legal word

  • trans_num (int) – maximum number of words to substitute

  • n (int) –

Returns

sub_words after alignment, together with the indices of sub_words
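The description leaves the exact selection rule and the meaning of n open, so the following is only one plausible reading, not the library's algorithm: keep the positions with the most candidates (at most trans_num of them), then truncate every candidate list to a common length of at most n so that the k-th candidates together form the k-th transformed sentence. The name `trade_off_sketch` is hypothetical.

```python
def trade_off_sketch(sub_words, sub_indices, trans_num=None, n=1):
    """Pick positions to substitute and align their candidate lists."""
    assert len(sub_words) == len(sub_indices)
    # Drop positions that have no candidate at all.
    pairs = [(w, i) for w, i in zip(sub_words, sub_indices) if w]
    if not pairs:
        return [], []
    if trans_num is not None and len(pairs) > trans_num:
        # Prefer positions with the most candidates ...
        pairs = sorted(pairs, key=lambda p: len(p[0]), reverse=True)[:trans_num]
        # ... then restore sentence order.
        pairs.sort(key=lambda p: p[1])
    # Align: every kept position offers the same number of candidates.
    width = min(n, min(len(w) for w, _ in pairs))
    return [w[:width] for w, _ in pairs], [i for _, i in pairs]


words = [["good", "fine", "nice"], [], ["film", "picture"], ["very"]]
indices = [3, 4, 1, 2]
print(trade_off_sketch(words, indices, trans_num=2, n=2))
# ([['film', 'picture'], ['good', 'fine']], [1, 3])
```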