# Source code for textflint.input_layer.model.flint_model.torch_model

"""
PyTorch Model Wrapper
--------------------------
"""


import torch
from torch.nn import CrossEntropyLoss


from .flint_model import FlintModel
from ....common.utils import device


class TorchModel(FlintModel):
    r"""
    Loads a PyTorch model (`nn.Module`) and tokenizer.

    Subclasses are expected to implement :meth:`__call__` and
    :meth:`encode`; both raise ``NotImplementedError`` here.
    """

    def __init__(
        self,
        model,
        tokenizer,
        task='SA',
        batch_size=32
    ):
        """
        :param torch.nn.Module model: target PyTorch model
        :param tokenizer: tokenizer whose output can be packed as a tensor
            and passed to the model. No type requirement, but most have
            `tokenizer` method that accepts list of strings.
        :param str task: task name
        :param int batch_size: batch size to apply evaluation
        :raises TypeError: if ``model`` is not a ``torch.nn.Module``
        """
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"PyTorch model must be torch.nn.Module, got type {type(model)}"
            )

        super().__init__(model, tokenizer, task=task, batch_size=batch_size)
        # Move the wrapped module onto the project-wide default device.
        self.model = model.to(device)

    def __call__(self, *inputs):
        """
        Run the wrapped model on ``inputs``; must be provided by subclasses.

        :raises NotImplementedError: always, in this base wrapper
        """
        raise NotImplementedError()

    def get_model_grad(self, text_inputs, loss_fn=CrossEntropyLoss()):
        r"""
        Get gradient of loss with respect to input tokens.

        The model's own ``argmax`` prediction is used as the target label
        for ``loss_fn``. While the gradient is computed the model is put in
        train mode, the input embedding weight is forced to require grad and
        a backward hook is attached to the embedding layer. On exit -- even
        when an exception is raised -- the hook is removed, the weight's
        original ``requires_grad`` flag is restored and the model is put in
        eval mode.

        :param str|[str] text_inputs: input string or input string list
        :param torch.nn.Module loss_fn: loss function. Default is
            `torch.nn.CrossEntropyLoss`
        :return: Dict with keys ``"ids"`` (token ids of the first encoded
            entry, as a list) and ``"gradient"`` (numpy array).
        :raises AttributeError: if the model has no `get_input_embeddings`
        :raises ValueError: if ``loss_fn`` is not a ``torch.nn.Module``
        """
        if not hasattr(self.model, "get_input_embeddings"):
            raise AttributeError(
                f"{type(self.model)} must have method `get_input_embeddings` "
                f"that returns `torch.nn.Embedding` object that represents "
                f"input embedding layer"
            )
        if not isinstance(loss_fn, torch.nn.Module):
            raise ValueError("Loss function must be of type `torch.nn.Module`.")

        embedding_layer = self.model.get_input_embeddings()
        original_state = embedding_layer.weight.requires_grad

        emb_grads = []

        def grad_hook(module, grad_in, grad_out):
            # Capture the gradient flowing out of the embedding layer.
            emb_grads.append(grad_out[0])

        emb_hook = None
        try:
            self.model.train()
            embedding_layer.weight.requires_grad = True
            # NOTE(review): newer PyTorch releases deprecate
            # `register_backward_hook` in favour of
            # `register_full_backward_hook`; confirm the supported torch
            # version before switching.
            emb_hook = embedding_layer.register_backward_hook(grad_hook)

            self.model.zero_grad()
            model_device = next(self.model.parameters()).device
            inputs_ids = self.encode(text_inputs)
            ids = [
                torch.tensor(entry).to(model_device) for entry in inputs_ids
            ]

            # NOTE(review): the forward pass receives the raw `text_inputs`,
            # while the encoded `ids` are only used for the returned dict.
            # Kept as-is; confirm whether the model should be fed `ids`.
            predictions = self.model(text_inputs)

            # Use the model's own prediction as the label for the loss.
            labels = predictions.argmax(dim=1)
            loss = loss_fn(predictions, labels)
            loss.backward()

            # grad w.r.t. word embeddings: swap the first two dims of the
            # captured gradient and keep index 0.
            grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()

            return {"ids": ids[0].tolist(), "gradient": grad}
        finally:
            # Always undo the temporary changes so a failure above does not
            # leak the hook or leave the model in train mode.
            embedding_layer.weight.requires_grad = original_state
            if emb_hook is not None:
                emb_hook.remove()
            self.model.eval()

    def encode(self, inputs):
        r"""
        Tokenize inputs and convert it to ids.

        :param inputs: model original input
        :return: list of inputs ids
        :raises NotImplementedError: always, in this base wrapper
        """
        raise NotImplementedError()