# Source code for textflint.generation_layer.transformation.RE.insert_clause

r"""
InsertClause class: transformation that inserts an entity description clause
"""
__all__ = ["InsertClause"]
import json
import random
import urllib.parse
import urllib.request

import wikidata
from wikidata.client import Client

from ....common.utils.error import FlintError
from ...transformation import Transformation
from ....input_layer.component.sample.re_sample import RESample
from ....common.settings import WIKIDATA_STATEMENTS_no_zero, \
    CLAUSE_HEAD_no_zero, scbase


class InsertClause(Transformation):
    r"""
    Add extra entity-related clause to text

    For the head and the tail entity of a relation-extraction sample, the
    entity name is searched on Wikidata and, when a match is found, a short
    clause built from one of the entity's statements (or from its
    description) is inserted right after the entity mention.

    """
    def __init__(
        self,
        **kwargs
    ):
        # Keyword arguments are accepted but not forwarded to the base class.
        super().__init__()

    def __repr__(self):
        return 'InsertClause'

    def search_list(self, query):
        r"""Retrieve entity id from wikidata

        :param string query: name of query entity; it is appended verbatim
            to the search URL ``scbase``, so it must already be
            percent-encoded
        :return list: information of query entity (the ``'search'`` field
            of the decoded JSON response)
        :raises FlintError: if opening the URL fails with an ``OSError``
        """
        search = scbase + query
        try:
            response = urllib.request.urlopen(search)
        except OSError as err:
            # Chain the original error so the real cause stays visible.
            raise FlintError('Time out to access Wikidata, '
                             'plz check your network!') from err
        # Close the HTTP response deterministically instead of leaking it.
        with response:
            data = json.loads(response.read().decode())
        return data['search']

    def _get_clause(self, entid):
        r"""
        Obtain entity-related clause

        The candidate statements are tried in random order; the first one
        the entity has a value for is turned into a clause. If none
        matches, the clause is built from the entity description.

        :param string entid: index of entity
        :return string: entity clause
        """
        assert (isinstance(entid, str)), \
            f"the type of 'entid' should be str, got {type(entid)} instead"
        client = Client()
        entity = client.get(entid)
        # Shuffle a private copy: shuffling the settings list in place
        # would mutate module-level state shared with every other user.
        statements = list(WIKIDATA_STATEMENTS_no_zero)
        random.shuffle(statements)
        for sta in statements:
            values = entity.getlist(client.get(sta))
            if not values:
                continue
            first = values[0]
            # Entity values are rendered through their label, anything
            # else through its plain string form.
            if isinstance(first, wikidata.entity.Entity):
                return CLAUSE_HEAD_no_zero[sta] + str(first.label)
            return CLAUSE_HEAD_no_zero[sta] + str(first)
        return CLAUSE_HEAD_no_zero['description'] + str(entity.description)

    def get_clause(self, query):
        r"""
        obtain entity description

        :param string query: name of query entity (percent-encoded, see
            :meth:`search_list`)
        :return string: entity description, or an empty string when the
            search returns no result
        """
        assert (isinstance(query, str)), \
            f"the type of 'query' should be str, got {type(query)} instead"
        scrsl = self.search_list(query)
        if not scrsl:
            return ''
        # Only the first search hit is used.
        return self._get_clause(scrsl[0]['id'])

    @staticmethod
    def _encode_query(entity_name):
        r"""
        Percent-encode an entity name for use in the search URL

        :param string entity_name: raw entity text
        :return string: encoded text (space becomes ``%20``; reserved and
            non-ASCII characters are escaped as well)
        """
        return urllib.parse.quote(entity_name, safe='')

    @staticmethod
    def _insert_after(tokens, anchor_span, other_span, clause):
        r"""
        Insert clause tokens right after the anchor entity

        :param list tokens: current token list (not modified)
        :param list anchor_span: ``[start, end]`` of the entity the clause
            is attached to
        :param list other_span: ``[start, end]`` of the other entity;
            shifted in place when it lies entirely after the anchor
        :param list clause: clause tokens to insert
        :return list: token list with the clause inserted, or ``tokens``
            itself when the clause is empty
        """
        if not clause:
            return tokens
        if anchor_span[1] < other_span[0]:
            # The other entity follows the insertion point.
            other_span[0] += len(clause)
            other_span[1] += len(clause)
        cut = anchor_span[1] + 1
        return tokens[:cut] + clause + tokens[cut:]

    def _transform(self, sample, n=1, field='x', **kwargs):
        r"""
        Transform text string according to its entities

        :param RESample sample: sample input
        :param int n: number of generated samples (no more than one)
        :param str field: unused by this transformation
        :return list: transformed sample list

        """
        assert (isinstance(sample, RESample)), \
            f"the type of 'sample' should be RESample, got " \
            f"{type(sample)} instead"
        assert (isinstance(n, int)), f"the type of 'n' should be int, " \
                                     f"got {type(n)} instead"

        sh, st, oh, ot = sample.get_en()
        text, relation = sample.get_sent()
        head_entity_span = [sh, st]
        tail_entity_span = [oh, ot]
        # Spans are inclusive, hence the "+ 1" on the end index.
        head_entity_name = ' '.join(text[sh:st + 1])
        tail_entity_name = ' '.join(text[oh:ot + 1])

        # Clause for the head entity; the tail span moves if it follows.
        clauseadd = self.get_clause(
            self._encode_query(head_entity_name)).split()
        new_text = self._insert_after(
            text, head_entity_span, tail_entity_span, clauseadd)

        # Clause for the tail entity, located through its updated span.
        clauseadd = self.get_clause(
            self._encode_query(tail_entity_name)).split()
        new_text = self._insert_after(
            new_text, tail_entity_span, head_entity_span, clauseadd)

        # Sanity check: both spans must still point at the entities.
        assert head_entity_name == " ".join(
            new_text[head_entity_span[0]:head_entity_span[1] + 1])
        assert tail_entity_name == " ".join(
            new_text[tail_entity_span[0]:tail_entity_span[1] + 1])

        trans_sample = {
            'x': new_text,
            'subj': head_entity_span,
            'obj': tail_entity_span,
            'y': relation,
        }
        trans_samples = sample.replace_sample_fields(trans_sample)
        return [trans_samples]