# Source code for textflint.generation_layer.subpopulation.UT.lm

r"""
Extract samples with high perplexity or low perplexity
======================================================

"""
__all__ = ['LMSubPopulation']
import torch
import math
from ..subpopulation import SubPopulation


class LMSubPopulation(SubPopulation):
    r"""
    Filter samples based on text perplexity

    Each sample is scored by a pretrained GPT-2 language model; the score
    of a sample is the sum, over the requested fields, of
    ``exp(language-modeling loss)`` of that field.

    Example::

        sample 1: "I love textflint", score: 6.7
        sample 2: "I love TextFlinet", score: 6.34
    """

    def __init__(
        self,
        intervals=("0%", "20%"),
        device='cpu',
        max_sent_size=512
    ):
        r"""
        :param intervals: pair of bounds forwarded to ``SubPopulation``
            (presumably the score range of samples to keep -- confirm
            against ``SubPopulation``). Must not be ``None``.
        :param str device: torch device the model and inputs are placed on
        :param int max_sent_size: maximum number of token ids per field
            that are fed to the model; longer fields are truncated
        :raises ValueError: if ``intervals`` is ``None``
        """
        if intervals is None:
            raise ValueError(
                'Intervals should be initialized for LMSubPopulation')
        # The default is an immutable tuple so it cannot be shared and
        # mutated across instances; the parent still receives a fresh list.
        super().__init__(intervals=list(intervals))
        # Tokenizer and model are loaded lazily on the first `_score` call.
        self.tokenizer = None
        self.model = None
        self.device = device
        self.max_sent_size = max_sent_size

    def __repr__(self):
        return "-".join([
            "LMSubPopulation",
            str(self.intervals[0]),
            str(self.intervals[1]),
        ])

    def load_model(self):
        r"""
        Load the pretrained ``gpt2`` tokenizer and language model and move
        the model to ``self.device``.

        """
        # Imported here so that `transformers` is only required when the
        # model is actually needed.
        from transformers import GPT2Tokenizer, GPT2LMHeadModel
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.to(self.device)

    def _score(self, sample, fields, **kwargs):
        r"""
        Calculate the score based on text perplexity

        :param sample: data sample
        :param list fields: list of field str
        :param kwargs: unused
        :return float: sum of the per-field perplexities (``0`` when no
            field could be scored)
        """
        if self.model is None:
            self.load_model()

        perplexity = 0
        for field in fields:
            tokens = sample.get_words(field)
            # NOTE(review): words are looked up directly in the tokenizer
            # vocabulary instead of being BPE-encoded -- confirm this is
            # the intended scoring input.
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(
                tokens)[:self.max_sent_size]
            # The loss is a next-token prediction loss, so it needs at
            # least two tokens; an empty or single-token field has no
            # defined perplexity and is skipped instead of crashing or
            # poisoning the score with NaN.
            if len(indexed_tokens) < 2:
                continue
            tokens_tensor = torch.tensor(
                [indexed_tokens], dtype=torch.long, device=self.device)
            with torch.no_grad():
                # Passing the inputs as labels makes the model return the
                # language-modeling loss for this sequence.
                output = self.model(tokens_tensor, labels=tokens_tensor)
            perplexity += math.exp(output.loss.item())

        return perplexity