diff --git a/setup.cfg b/setup.cfg index 4047bb0..d8549f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.10.6 +current_version = 0.11.0 commit = True tag = True diff --git a/setup.py b/setup.py index dc69d4f..faf3717 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open('HISTORY.rst') as history_file: history = history_file.read() -requirements = ['torch==1.2.0', 'torchvision', 'tqdm==4.32.2'] +requirements = ['torch==1.2.0', 'torchvision', 'tqdm==4.32.2', 'pandas'] setup_requirements = [] @@ -39,6 +39,6 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/torchkge-team/torchkge', - version='0.10.6', + version='0.11.0', zip_safe=False, ) diff --git a/torchkge/__init__.py b/torchkge/__init__.py index c1d56f5..a82756e 100644 --- a/torchkge/__init__.py +++ b/torchkge/__init__.py @@ -4,7 +4,7 @@ __author__ = """Armand Boschin""" __email__ = 'aboschin@enst.fr' -__version__ = '0.10.6' +__version__ = '0.11.0' from .data import KnowledgeGraph diff --git a/torchkge/data/DataLoader.py b/torchkge/data/DataLoader.py new file mode 100644 index 0000000..d8d3f8b --- /dev/null +++ b/torchkge/data/DataLoader.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +""" +Copyright TorchKGE developers +aboschin@enst.fr +This module's code is freely adapted from Scikit-Learn's sklearn.datasets.base.py code. +""" + +from .KnowledgeGraph import KnowledgeGraph + +from os import environ, makedirs, remove +from os.path import exists, expanduser, join +from pandas import read_csv, concat, merge, DataFrame +from urllib.request import urlretrieve + +import shutil +import tarfile +import zipfile + + +def get_data_home(data_home=None): + if data_home is None: + data_home = environ.get('TORCHKGE_DATA', + join('~', 'torchkge_data')) + data_home = expanduser(data_home) + if not exists(data_home): + makedirs(data_home) + return data_home + + +def clear_data_home(data_home=None): + data_home = get_data_home(data_home) + shutil.rmtree(data_home) + + +def load_fb13(): + data_home = get_data_home() + data_path = data_home + '/FB13' + if not exists(data_path): + makedirs(data_path, exist_ok=True) + urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB13.zip", + data_home + '/FB13.zip') + with zipfile.ZipFile(data_home + '/FB13.zip', 'r') as zip_ref: + zip_ref.extractall(data_home) + remove(data_home + '/FB13.zip') + shutil.rmtree(data_home + '/__MACOSX') + + df1 = read_csv(data_path + '/train2id.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df2 = read_csv(data_path + '/valid2id.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df3 = read_csv(data_path + '/test2id.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df = concat([df1, df2, df3]) + kg = KnowledgeGraph(df) + + return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) + + +def load_fb15k(): + data_home = get_data_home() + data_path = data_home + '/FB15k' + if not exists(data_path): + makedirs(data_path, exist_ok=True) + urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB15k.zip", + data_home + '/FB15k.zip') + with zipfile.ZipFile(data_home + '/FB15k.zip', 'r') as zip_ref: + zip_ref.extractall(data_home) + remove(data_home + '/FB15k.zip') + shutil.rmtree(data_home + '/__MACOSX') + + df1 = read_csv(data_path + '/freebase_mtr100_mte100-train.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df2 = read_csv(data_path + '/freebase_mtr100_mte100-valid.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df3 = read_csv(data_path + '/freebase_mtr100_mte100-test.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df = concat([df1, df2, df3]) + kg = KnowledgeGraph(df) + + return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) + + +def load_fb15k237(): + data_home = get_data_home() + data_path = data_home + '/FB15k237' + if not exists(data_path): + makedirs(data_path, exist_ok=True) + urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB15k237.zip", + data_home + '/FB15k237.zip') + with zipfile.ZipFile(data_home + '/FB15k237.zip', 'r') as zip_ref: + zip_ref.extractall(data_home) + remove(data_home + '/FB15k237.zip') + shutil.rmtree(data_home + '/__MACOSX') + + df1 = read_csv(data_path + '/train.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df2 = read_csv(data_path + '/valid.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df3 = read_csv(data_path + '/test.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df = concat([df1, df2, df3]) + kg = KnowledgeGraph(df) + + return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) + + +def load_wn18(): + data_home = get_data_home() + data_path = data_home + '/WN18' + if not exists(data_path): + makedirs(data_path, exist_ok=True) + urlretrieve("https://graphs.telecom-paristech.fr/datasets/WN18.zip", + data_home + '/WN18.zip') + with zipfile.ZipFile(data_home + '/WN18.zip', 'r') as zip_ref: + zip_ref.extractall(data_home) + remove(data_home + '/WN18.zip') + shutil.rmtree(data_home + '/__MACOSX') + + df1 = read_csv(data_path + '/wordnet-mlj12-train.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df2 = read_csv(data_path + '/wordnet-mlj12-valid.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df3 = read_csv(data_path + '/wordnet-mlj12-test.txt', + sep='\t', header=None, names=['from', 'rel', 'to']) + df = concat([df1, df2, df3]) + kg = KnowledgeGraph(df) + + return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) + + +def load_wikidatasets(which, limit_=None): + assert which in ['humans', 'companies', 'animals', 'countries', 'films'] + + if limit_ is None: + limit_ = 0 + + data_home = get_data_home() + '/WikiDataSets' + data_path = data_home + '/' + which + if not exists(data_path): + makedirs(data_path, exist_ok=True) + urlretrieve("https://graphs.telecom-paristech.fr/WikiDataSets/{}.tar.gz".format(which), + data_home + '/{}.tar.gz'.format(which)) + + with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf: + tf.extractall(data_home) + remove(data_home + '/{}.tar.gz'.format(which)) + + df = read_csv(data_path + '/edges.txt'.format(which), sep='\t', header=1, + names=['from', 'to', 'rel']) + + a = df.groupby('from').count()['rel'] + b = df.groupby('to').count()['rel'] + + # Filter out nodes with too few facts + tmp = merge(right=DataFrame(a).reset_index(), + left=DataFrame(b).reset_index(), + how='outer', right_on='from', left_on='to', ).fillna(0) + + tmp['rel'] = tmp['rel_x'] + tmp['rel_y'] + tmp = tmp.drop(['from', 'rel_x', 'rel_y'], axis=1) + + tmp = tmp.loc[tmp['rel'] >= limit_] + df_bis = df.loc[df['from'].isin(tmp['to']) | df['to'].isin(tmp['to'])] + + kg = KnowledgeGraph(df_bis) + kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True) + + return kg_train, kg_val, kg_test diff --git a/torchkge/data/KnowledgeGraph.py b/torchkge/data/KnowledgeGraph.py index 308c32f..061bd03 100644 --- a/torchkge/data/KnowledgeGraph.py +++ b/torchkge/data/KnowledgeGraph.py @@ -54,14 +54,14 @@ class KnowledgeGraph(Dataset): Number of distinct entities in the data set. n_rel: int Number of distinct entities in the data set. - n_sample: int + n_facts: int Number of samples in the data set. A sample is a fact: a triplet (h, r, l). - head_idx: torch.Tensor, dtype = long, shape = (n_sample) - List of the int key of heads for each sample (fact). - tail_idx: torch.Tensor, dtype = long, shape = (n_sample) - List of the int key of tails for each sample (facts). - relations: torch.Tensor, dtype = long, shape = (n_sample) - List of the int key of relations for each sample (facts). + head_idx: torch.Tensor, dtype = long, shape = (n_facts) + List of the int key of heads for each fact. + tail_idx: torch.Tensor, dtype = long, shape = (n_facts) + List of the int key of tails for each fact. + relations: torch.Tensor, dtype = long, shape = (n_facts) + List of the int key of relations for each fact. """ @@ -84,14 +84,14 @@ def __init__(self, df=None, kg=None, if df is not None: assert kg is None self.df = df - self.n_sample = len(df) + self.n_facts = len(df) self.head_idx = tensor(df['from'].map(self.ent2ix).values).long() self.tail_idx = tensor(df['to'].map(self.ent2ix).values).long() self.relations = tensor(df['rel'].map(self.rel2ix).values).long() else: assert kg is not None self.df = kg['df'] - self.n_sample = kg['heads'].shape[0] + self.n_facts = kg['heads'].shape[0] self.head_idx = kg['heads'] self.tail_idx = kg['tails'] self.relations = kg['relations'] @@ -107,7 +107,7 @@ def __init__(self, df=None, kg=None, self.dict_of_tails = dict_of_tails def __len__(self): - return self.n_sample + return self.n_facts def __getitem__(self, item): return self.head_idx[item].item(), self.tail_idx[item].item(), self.relations[item].item() @@ -139,9 +139,9 @@ def split_kg(self, share=0.8, sizes=None, validation=False): if sizes is not None: try: if len(sizes) == 3: - assert (sizes[0] + sizes[1] + sizes[2] == self.n_sample) + assert (sizes[0] + sizes[1] + sizes[2] == self.n_facts) elif len(sizes) == 2: - assert (sizes[0] + sizes[1] == self.n_sample) + assert (sizes[0] + sizes[1] == self.n_facts) else: raise SizeMismatchError('Tuple `sizes` should be of length 2 or 3.') except AssertionError: @@ -156,10 +156,10 @@ def split_kg(self, share=0.8, sizes=None, validation=False): mask_te = ~(mask_tr | mask_val) else: mask_tr = cat([tensor([1 for _ in range(sizes[0])]), - tensor([0 for _ in range(sizes[1] + sizes[2])])]).byte() + tensor([0 for _ in range(sizes[1] + sizes[2])])]).bool() mask_val = cat([tensor([0 for _ in range(sizes[0])]), tensor([1 for _ in range(sizes[1])]), - tensor([0 for _ in range(sizes[2])])]).byte() + tensor([0 for _ in range(sizes[2])])]).bool() mask_te = ~(mask_tr | mask_val) return KnowledgeGraph( @@ -182,7 +182,7 @@ def split_kg(self, share=0.8, sizes=None, validation=False): mask_tr = (empty(self.head_idx.shape).uniform_() < share) else: mask_tr = cat([tensor([1 for _ in range(sizes[0])]), - tensor([0 for _ in range(sizes[1])])]).byte() + tensor([0 for _ in range(sizes[1])])]).bool() return KnowledgeGraph( kg={'heads': self.head_idx[mask_tr], 'tails': self.tail_idx[mask_tr], 'relations': self.relations[mask_tr], 'df': self.df}, @@ -198,7 +198,7 @@ def evaluate_dicts(self): fact in the entire knowledge graph. """ - for i in tqdm(range(self.n_sample)): + for i in tqdm(range(self.n_facts)): self.dict_of_heads[(self.tail_idx[i].item(), self.relations[i].item())].add(self.head_idx[i].item()) self.dict_of_tails[(self.head_idx[i].item(), diff --git a/torchkge/data/__init__.py b/torchkge/data/__init__.py index c898ffe..93bf4ca 100644 --- a/torchkge/data/__init__.py +++ b/torchkge/data/__init__.py @@ -1,2 +1,8 @@ from .KnowledgeGraph import SmallKG from .KnowledgeGraph import KnowledgeGraph + +from .DataLoader import load_fb13 +from .DataLoader import load_fb15k +from .DataLoader import load_fb15k237 +from .DataLoader import load_wn18 +from .DataLoader import load_wikidatasets diff --git a/torchkge/evaluation/LinkPrediction.py b/torchkge/evaluation/LinkPrediction.py index c2e3277..c69bdb0 100644 --- a/torchkge/evaluation/LinkPrediction.py +++ b/torchkge/evaluation/LinkPrediction.py @@ -56,10 +56,10 @@ def __init__(self, model, knowledge_graph): self.model = model self.kg = knowledge_graph - self.rank_true_heads = empty(size=(knowledge_graph.n_sample,)).long() - self.rank_true_tails = empty(size=(knowledge_graph.n_sample,)).long() - self.filt_rank_true_heads = empty(size=(knowledge_graph.n_sample,)).long() - self.filt_rank_true_tails = empty(size=(knowledge_graph.n_sample,)).long() + self.rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long() + self.rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long() + self.filt_rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long() + self.filt_rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long() self.evaluated = False self.k_max = 10 diff --git a/torchkge/evaluation/TripletClassification.py b/torchkge/evaluation/TripletClassification.py index 11db390..72c7ce0 100644 --- a/torchkge/evaluation/TripletClassification.py +++ b/torchkge/evaluation/TripletClassification.py @@ -65,17 +65,17 @@ def get_scores(self, heads, tails, relations, batch_size): Parameters ---------- - heads: torch.Tensor, dtype = long, shape = n_sample + heads: torch.Tensor, dtype = long, shape = n_facts List of heads indices. - tails: torch.Tensor, dtype = long, shape = n_sample + tails: torch.Tensor, dtype = long, shape = n_facts List of tails indices. - relations: torch.Tensor, dtype = long, shape = n_sample + relations: torch.Tensor, dtype = long, shape = n_facts List of relation indices. batch_size: int Returns ------- - scores: torch.Tensor, dtype = float, shape = n_sample + scores: torch.Tensor, dtype = float, shape = n_facts List of scores of each triplet. """ scores = [] @@ -111,9 +111,11 @@ def evaluate(self, batch_size): self.thresholds = zeros(self.kg_val.n_rel) for i in range(self.kg_val.n_rel): - mask = (r_idx == i).byte() - assert mask.sum() > 0 - self.thresholds[i] = neg_scores[mask == 1].max() + mask = (r_idx == i).bool() + if mask.sum() > 0: + self.thresholds[i] = neg_scores[mask].max() + else: + self.thresholds[i] = neg_scores.max() self.evaluated = True self.thresholds.detach_() @@ -144,7 +146,7 @@ def accuracy(self, batch_size): if self.use_cuda: self.thresholds = self.thresholds.cuda() - scores = (scores > self.thresholds[r_idx]).byte() - neg_scores = (neg_scores < self.thresholds[r_idx]).byte() + scores = (scores > self.thresholds[r_idx]) + neg_scores = (neg_scores < self.thresholds[r_idx]) - return (scores.sum().item() + neg_scores.sum().item()) / (2 * self.kg_test.n_sample) + return (scores.sum().item() + neg_scores.sum().item()) / (2 * self.kg_test.n_facts) diff --git a/torchkge/models/BilinearModels.py b/torchkge/models/BilinearModels.py index 3052a89..c670389 100644 --- a/torchkge/models/BilinearModels.py +++ b/torchkge/models/BilinearModels.py @@ -4,17 +4,17 @@ aboschin@enst.fr """ -from torch import empty, matmul, tensor, diag_embed -from torch.nn import Module, Embedding, Parameter +from torch import empty, matmul, diag_embed +from torch.nn import Embedding, Parameter from torch.nn.functional import normalize from torch.nn.init import xavier_uniform_ from torchkge.models import Model -from torchkge.utils import get_rank, get_mask, get_rolling_matrix +from torchkge.utils import get_rank, get_mask, get_rolling_matrix, get_true_targets from torchkge.exceptions import WrongDimensionError -class RESCALModel(Module, Model): +class RESCALModel(Model): """Implementation of RESCAL model detailed in 2011 paper by Nickel et al..\ In the original paper, optimization is done using Alternating Least Squares (ALS). Here we use\ iterative gradient descent optimization. @@ -50,51 +50,23 @@ class RESCALModel(Module, Model): """ def __init__(self, ent_emb_dim, n_entities, n_relations): - super().__init__() - self.ent_emb_dim = ent_emb_dim - self.number_entities = n_entities - self.number_relations = n_relations + super().__init__(ent_emb_dim, n_entities, n_relations) # initialize embedding objects self.entity_embeddings = Embedding(self.number_entities, self.ent_emb_dim) self.relation_matrices = Parameter(xavier_uniform_(empty(size=(self.number_relations, self.ent_emb_dim, - self.ent_emb_dim)))) + self.ent_emb_dim))), + requires_grad=True) # fill the embedding weights with Xavier initialized values self.entity_embeddings.weight = Parameter(xavier_uniform_( - empty(size=(self.number_entities, self.ent_emb_dim)))) + empty(size=(self.number_entities, self.ent_emb_dim))), + requires_grad=True) # normalize the embeddings self.normalize_parameters() - def forward(self, heads, tails, negative_heads, negative_tails, relations): - """Forward pass on the current batch. - - Parameters - ---------- - heads: torch tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's heads - tails: torch tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's tails. - negative_heads: torch tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's negatively sampled heads. - negative_tails: torch tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's negatively sampled tails. - relations: torch tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's relations. - - Returns - ------- - golden_triplets: torch tensor, dtype = float, shape = (batch_size) - Estimation of the true value that should be 1 (by matrix factorization). - negative_triplets: torch tensor, dtype = float, shape = (batch_size) - Estimation of the true value that should be 0 (by matrix factorization). - - """ - return self.scoring_function(heads, tails, relations), \ - self.scoring_function(negative_heads, negative_tails, relations) - def scoring_function(self, heads_idx, tails_idx, rels_idx): # recover entities embeddings heads_embeddings = normalize(self.entity_embeddings(heads_idx), p=2, dim=1) @@ -147,6 +119,18 @@ def compute_product(self, heads, tails, rel_mat): return matmul(matmul(heads, rel_mat), tails).view(b_size, -1) + def get_head_tail_candidates(self, h_idx, t_idx): + b_size = h_idx.shape[0] + + candidates = self.entity_embeddings.weight.data + candidates = candidates.view(1, self.number_entities, self.ent_emb_dim) + candidates = candidates.expand(b_size, self.number_entities, self.ent_emb_dim) + + h_emb = self.entity_embeddings(h_idx) + t_emb = self.entity_embeddings(t_idx) + + return h_emb, t_emb, candidates + def evaluation_helper(self, h_idx, t_idx, r_idx): """ @@ -171,17 +155,10 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): Tensor containing all entities as candidates for each sample of the batch. """ - b_size = h_idx.shape[0] - - candidates = self.entity_embeddings.weight.data - candidates = candidates.view(1, self.number_entities, self.ent_emb_dim) - candidates = candidates.expand(b_size, self.number_entities, self.ent_emb_dim) - - h_emb = self.entity_embeddings(h_idx) - t_emb = self.entity_embeddings(t_idx) + h_emb, t_emb, candidates = self.get_head_tail_candidates(h_idx, t_idx) r_mat = self.relation_matrices[r_idx] - return h_emb, t_emb, r_mat, candidates + return h_emb, t_emb, candidates, r_mat def compute_ranks(self, e_emb, candidates, r_mat, e_idx, r_idx, true_idx, dictionary, heads=1): """ @@ -228,11 +205,9 @@ def compute_ranks(self, e_emb, candidates, r_mat, e_idx, r_idx, true_idx, dictio # filter out the true negative samples by assigning negative score filt_scores = scores.clone() for i in range(current_batch_size): - true_targets = dictionary[e_idx[i].item(), r_idx[i].item()].copy() - if len(true_targets) == 1: + true_targets = get_true_targets(dictionary, e_idx, r_idx, true_idx, i) + if true_targets is None: continue - true_targets.remove(true_idx[i].item()) - true_targets = tensor(list(true_targets)).long() filt_scores[i][true_targets] = float(-1) # from dissimilarities, extract the rank of the true entity. @@ -241,24 +216,6 @@ def compute_ranks(self, e_emb, candidates, r_mat, e_idx, r_idx, true_idx, dictio return rank_true_entities, filtered_rank_true_entities - def evaluate_candidates(self, h_idx, t_idx, r_idx, kg): - h_emb, t_emb, r_mat, candidates = self.evaluation_helper(h_idx, t_idx, r_idx) - - rank_true_tails, filt_rank_true_tails = self.compute_ranks(h_emb, - candidates, - r_mat, h_idx, r_idx, - t_idx, - kg.dict_of_tails, - heads=1) - rank_true_heads, filt_rank_true_heads = self.compute_ranks(t_emb, - candidates, - r_mat, t_idx, r_idx, - h_idx, - kg.dict_of_heads, - heads=-1) - - return rank_true_tails, filt_rank_true_tails, rank_true_heads, filt_rank_true_heads - class DistMultModel(RESCALModel): """Implementation of DistMult model detailed in 2014 paper by Yang et al.. @@ -299,7 +256,8 @@ def __init__(self, ent_emb_dim, n_entities, n_relations): del self.relation_matrices self.relation_vectors = Parameter( - xavier_uniform_(empty(size=(self.number_relations, self.ent_emb_dim)))) + xavier_uniform_(empty(size=(self.number_relations, self.ent_emb_dim))), + requires_grad=True) def scoring_function(self, heads_idx, tails_idx, rels_idx): # recover entities embeddings @@ -337,17 +295,10 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): Tensor containing all entities as candidates for each sample of the batch. """ - b_size = h_idx.shape[0] - - candidates = self.entity_embeddings.weight.data - candidates = candidates.view(1, self.number_entities, self.ent_emb_dim) - candidates = candidates.expand(b_size, self.number_entities, self.ent_emb_dim) - - h_emb = self.entity_embeddings(h_idx) - t_emb = self.entity_embeddings(t_idx) + h_emb, t_emb, candidates = self.get_head_tail_candidates(h_idx, t_idx) r_mat = diag_embed(self.relation_vectors[r_idx]) - return h_emb, t_emb, r_mat, candidates + return h_emb, t_emb, candidates, r_mat class HolEModel(RESCALModel): @@ -383,6 +334,7 @@ class HolEModel(RESCALModel): with Xavier uniform. """ + def __init__(self, ent_emb_dim, n_entities, n_relations): super().__init__(ent_emb_dim, n_entities, n_relations) @@ -426,17 +378,10 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): Tensor containing all entities as candidates for each sample of the batch. """ - b_size = h_idx.shape[0] - - candidates = self.entity_embeddings.weight.data - candidates = candidates.view(1, self.number_entities, self.ent_emb_dim) - candidates = candidates.expand(b_size, self.number_entities, self.ent_emb_dim) - - h_emb = self.entity_embeddings(h_idx) - t_emb = self.entity_embeddings(t_idx) + h_emb, t_emb, candidates = self.get_head_tail_candidates(h_idx, t_idx) r_mat = get_rolling_matrix(self.relation_vectors[r_idx]) - return h_emb, t_emb, r_mat, candidates + return h_emb, t_emb, candidates, r_mat def normalize_parameters(self): pass @@ -476,6 +421,7 @@ class ComplExModel(DistMultModel): smaller_dim: int Number of 2x2 matrices on the diagonals of relation-specific matrices. """ + def __init__(self, ent_emb_dim, n_entities, n_relations): try: assert ent_emb_dim % 2 == 0 diff --git a/torchkge/models/Models.py b/torchkge/models/Models.py index a27361b..f34b413 100644 --- a/torchkge/models/Models.py +++ b/torchkge/models/Models.py @@ -4,10 +4,46 @@ aboschin@enst.fr """ +from torch import arange +from torch.nn import Module -class Model(object): - def __init__(self): - pass +from torchkge.utils import init_embedding, l1_dissimilarity, l2_dissimilarity +from torchkge.utils import get_rank, get_true_targets + + +class Model(Module): + def __init__(self, ent_emb_dim, n_entities, n_relations): + super().__init__() + self.ent_emb_dim = ent_emb_dim + self.number_entities = n_entities + self.number_relations = n_relations + + def forward(self, heads, tails, negative_heads, negative_tails, relations): + """Forward pass on the current batch. + + Parameters + ---------- + heads: torch tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's heads + tails: torch tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's tails. + negative_heads: torch tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's negatively sampled heads. + negative_tails: torch tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's negatively sampled tails. + relations: torch tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's relations. + + Returns + ------- + golden_triplets: torch tensor, dtype = float, shape = (batch_size) + Scoring function evaluated on true triples. + negative_triplets: torch tensor, dtype = float, shape = (batch_size) + Scoring function evaluated on negatively sampled triples. + + """ + return self.scoring_function(heads, tails, relations), \ + self.scoring_function(negative_heads, negative_tails, relations) def scoring_function(self, heads_idx, tails_idx, rels_idx): pass @@ -16,4 +52,125 @@ def normalize_parameters(self): pass def evaluation_helper(self, h_idx, t_idx, r_idx): + raise NotImplementedError + + def evaluate_candidates(self, h_idx, t_idx, r_idx, kg): + proj_h_emb, proj_t_emb, candidates, r_emb = self.evaluation_helper(h_idx, t_idx, r_idx) + + # evaluation_helper both ways (head, rel) -> tail and (rel, tail) -> head + rank_true_tails, filt_rank_true_tails = self.compute_ranks(proj_h_emb, + candidates, + r_emb, h_idx, r_idx, + t_idx, + kg.dict_of_tails, + heads=1) + rank_true_heads, filt_rank_true_heads = self.compute_ranks(proj_t_emb, + candidates, + r_emb, t_idx, r_idx, + h_idx, + kg.dict_of_heads, + heads=-1) + + return rank_true_tails, filt_rank_true_tails, rank_true_heads, filt_rank_true_heads + + +class TranslationalModel(Model): + def __init__(self, ent_emb_dim, n_entities, n_relations, dissimilarity): + super().__init__(ent_emb_dim, n_entities, n_relations) + + self.entity_embeddings = init_embedding(self.number_entities, self.ent_emb_dim) + + assert dissimilarity in ['L1', 'L2', None] + if dissimilarity == 'L1': + self.dissimilarity = l1_dissimilarity + elif dissimilarity == 'L2': + self.dissimilarity = l2_dissimilarity + else: + self.dissimilarity = None + + def recover_project_normalize(self, ent_idx, normalize_): pass + + def compute_ranks(self, proj_e_emb, proj_candidates, + r_emb, e_idx, r_idx, true_idx, dictionary, heads=1): + """ + + Parameters + ---------- + proj_e_emb: torch.Tensor, shape = (batch_size, rel_emb_dim), dtype = float + Tensor containing current projected embeddings of entities. + proj_candidates: torch.Tensor, shape = (b_size, rel_emb_dim, n_entities), dtype = float + Tensor containing projected embeddings of all entities. + r_emb: torch.Tensor, shape = (batch_size, ent_emb_dim), dtype = float + Tensor containing current embeddings of relations. + e_idx: torch.Tensor, shape = (batch_size), dtype = long + Tensor containing the indices of entities. + r_idx: torch.Tensor, shape = (batch_size), dtype = long + Tensor containing the indices of relations. + true_idx: torch.Tensor, shape = (batch_size), dtype = long + Tensor containing the true entity for each sample. + dictionary: default dict + Dictionary of keys (int, int) and values list of ints giving all possible entities for + the (entity, relation) pair. + heads: integer + 1 ou -1 (must be 1 if entities are heads and -1 if entities are tails). \ + We test dissimilarity between heads * entities + relations and heads * targets. + + + Returns + ------- + rank_true_entities: torch.Tensor, shape = (b_size), dtype = int + Tensor containing the rank of the true entities when ranking any entity based on \ + computation of d(hear+relation, tail). + filtered_rank_true_entities: torch.Tensor, shape = (b_size), dtype = int + Tensor containing the rank of the true entities when ranking only true false entities \ + based on computation of d(hear+relation, tail). + + """ + current_batch_size, embedding_dimension = proj_e_emb.shape + + # tmp_sum is either heads + r_emb or r_emb - tails (expand does not use extra memory) + tmp_sum = (heads * proj_e_emb + r_emb).view((current_batch_size, embedding_dimension, 1)) + tmp_sum = tmp_sum.expand((current_batch_size, embedding_dimension, self.number_entities)) + + # compute either dissimilarity(heads + relation, proj_candidates) or + # dissimilarity(-proj_candidates, relation - tails) + dissimilarities = self.dissimilarity(tmp_sum, heads * proj_candidates) + + # filter out the true negative samples by assigning infinite dissimilarity + filt_dissimilarities = dissimilarities.clone() + for i in range(current_batch_size): + true_targets = get_true_targets(dictionary, e_idx, r_idx, true_idx, i) + if true_targets is None: + continue + filt_dissimilarities[i][true_targets] = float('Inf') + + # from dissimilarities, extract the rank of the true entity. + rank_true_entities = get_rank(-dissimilarities, true_idx) + filtered_rank_true_entities = get_rank(-filt_dissimilarities, true_idx) + + return rank_true_entities, filtered_rank_true_entities + + def evaluation_helper(self, h_idx, t_idx, r_idx): + return None, None, None, None + + def recover_candidates(self, h_idx, b_size): + all_idx = arange(0, self.number_entities).long() + if h_idx.is_cuda: + all_idx = all_idx.cuda() + candidates = self.entity_embeddings(all_idx).transpose(0, 1) + candidates = candidates.view((1, + self.ent_emb_dim, + self.number_entities)).expand((b_size, + self.ent_emb_dim, + self.number_entities)) + return candidates + + def projection_helper(self, h_idx, t_idx, b_size, candidates, rel_emb_dim): + mask = h_idx.view(b_size, 1, 1).expand(b_size, rel_emb_dim, 1) + proj_h_emb = candidates.gather(dim=2, index=mask).view(b_size, rel_emb_dim) + + mask = t_idx.view(b_size, 1, 1).expand(b_size, rel_emb_dim, 1) + proj_t_emb = candidates.gather(dim=2, index=mask).view(b_size, rel_emb_dim) + + return proj_h_emb, proj_t_emb diff --git a/torchkge/models/TranslationModels.py b/torchkge/models/TranslationModels.py index f93d46a..985cc35 100644 --- a/torchkge/models/TranslationModels.py +++ b/torchkge/models/TranslationModels.py @@ -5,18 +5,19 @@ """ from torch import empty, matmul, eye, arange, tensor -from torch.nn import Module, Parameter +from torch.nn import Parameter from torch.nn.functional import normalize from torch.nn.init import xavier_uniform_ from torch.cuda import empty_cache -from torchkge.models import Model -from torchkge.utils import get_rank, init_embedding, l2_dissimilarity +from torchkge.models import TranslationalModel +from torchkge.utils import init_embedding +from torchkge.utils import l1_torus_dissimilarity, l2_torus_dissimilarity, el2_torus_dissimilarity from tqdm import tqdm -class TransEModel(Module, Model): +class TransEModel(TranslationalModel): """Implementation of TransE model detailed in 2013 paper by Bordes et al.. References @@ -34,8 +35,8 @@ class TransEModel(Module, Model): Number of entities in the current data set. n_relations: int Number of relations in the current data set. - dissimilarity: function - Used to compute dissimilarities (e.g. L1 or L2 dissimilarities). + dissimilarity: String + Either 'L1' or 'L2'. Attributes ---------- @@ -56,58 +57,22 @@ class TransEModel(Module, Model): """ - def __init__(self, ent_emb_dim, n_entities, n_relations, dissimilarity, rel_emb_dim=None): - super().__init__() + def __init__(self, ent_emb_dim, n_entities, n_relations, dissimilarity): + try: + assert dissimilarity in ['L1', 'L2', None] + except AssertionError: + raise AssertionError("Dissimilarity variable can either be 'L1' or 'L2'.") - self.ent_emb_dim = ent_emb_dim - self.number_entities = n_entities - self.number_relations = n_relations - self.dissimilarity = dissimilarity - - if rel_emb_dim is None: - rel_emb_dim = ent_emb_dim - else: - self.rel_emb_dim = rel_emb_dim + super().__init__(ent_emb_dim, n_entities, n_relations, dissimilarity) # initialize embeddings - self.entity_embeddings = init_embedding(self.number_entities, self.ent_emb_dim) - self.relation_embeddings = init_embedding(self.number_relations, rel_emb_dim) + self.relation_embeddings = init_embedding(self.number_relations, self.ent_emb_dim) # normalize parameters - self.entity_embeddings.weight.data = normalize(self.entity_embeddings.weight.data, - p=2, dim=1) self.relation_embeddings.weight.data = normalize(self.relation_embeddings.weight.data, p=2, dim=1) - - def forward(self, heads, tails, negative_heads, negative_tails, relations): - """Forward pass on the current batch. - - Parameters - ---------- - heads: torch.Tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's heads - tails: torch.Tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's tails. - negative_heads: torch.Tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's negatively sampled heads. - negative_tails: torch.Tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's negatively sampled tails. - relations: torch.Tensor, dtype = long, shape = (batch_size) - Integer keys of the current batch's relations. - - Returns - ------- - golden_triplets: torch.Tensor, dtype = float, shape = (batch_size) - Score function: opposite of dissimilarities between h+r and t for golden triplets. - negative_triplets: torch.Tensor, dtype = float, shape = (batch_size) - Score function: opposite of dissimilarities between h+r and t for negatively - sampled triplets. - - """ - golden_triplets = self.scoring_function(heads, tails, relations) - negative_triplets = self.scoring_function(negative_heads, negative_tails, relations) - - return golden_triplets, negative_triplets + self.entity_embeddings.weight.data = normalize(self.entity_embeddings.weight.data, + p=2, dim=1) def scoring_function(self, heads_idx, tails_idx, rels_idx): """Compute the scoring function for the triplets given as argument. @@ -210,87 +175,6 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): return proj_h_emb, proj_t_emb, proj_candidates, r_emb - def compute_ranks(self, proj_e_emb, proj_candidates, - r_emb, e_idx, r_idx, true_idx, dictionary, heads=1): - """ - - Parameters - ---------- - proj_e_emb: torch.Tensor, shape = (batch_size, rel_emb_dim), dtype = float - Tensor containing current projected embeddings of entities. - proj_candidates: torch.Tensor, shape = (b_size, rel_emb_dim, n_entities), dtype = float - Tensor containing projected embeddings of all entities. - r_emb: torch.Tensor, shape = (batch_size, ent_emb_dim), dtype = float - Tensor containing current embeddings of relations. - e_idx: torch.Tensor, shape = (batch_size), dtype = long - Tensor containing the indices of entities. - r_idx: torch.Tensor, shape = (batch_size), dtype = long - Tensor containing the indices of relations. - true_idx: torch.Tensor, shape = (batch_size), dtype = long - Tensor containing the true entity for each sample. - dictionary: default dict - Dictionary of keys (int, int) and values list of ints giving all possible entities for - the (entity, relation) pair. - heads: integer - 1 ou -1 (must be 1 if entities are heads and -1 if entities are tails). \ - We test dissimilarity between heads * entities + relations and heads * targets. - - - Returns - ------- - rank_true_entities: torch.Tensor, shape = (b_size), dtype = int - Tensor containing the rank of the true entities when ranking any entity based on \ - computation of d(hear+relation, tail). - filtered_rank_true_entities: torch.Tensor, shape = (b_size), dtype = int - Tensor containing the rank of the true entities when ranking only true false entities \ - based on computation of d(hear+relation, tail). - - """ - current_batch_size, embedding_dimension = proj_e_emb.shape - - # tmp_sum is either heads + r_emb or r_emb - tails (expand does not use extra memory) - tmp_sum = (heads * proj_e_emb + r_emb).view((current_batch_size, embedding_dimension, 1)) - tmp_sum = tmp_sum.expand((current_batch_size, embedding_dimension, self.number_entities)) - - # compute either dissimilarity(heads + relation, proj_candidates) or - # dissimilarity(-proj_candidates, relation - tails) - dissimilarities = self.dissimilarity(tmp_sum, heads * proj_candidates) - - # filter out the true negative samples by assigning infinite dissimilarity - filt_dissimilarities = dissimilarities.clone() - for i in range(current_batch_size): - true_targets = dictionary[e_idx[i].item(), r_idx[i].item()].copy() - if len(true_targets) == 1: - continue - true_targets.remove(true_idx[i].item()) - true_targets = tensor(list(true_targets)).long() - filt_dissimilarities[i][true_targets] = float('Inf') - - # from dissimilarities, extract the rank of the true entity. - rank_true_entities = get_rank(-dissimilarities, true_idx) - filtered_rank_true_entities = get_rank(-filt_dissimilarities, true_idx) - - return rank_true_entities, filtered_rank_true_entities - - def evaluate_candidates(self, h_idx, t_idx, r_idx, kg): - proj_h_emb, proj_t_emb, proj_candidates, r_emb = self.evaluation_helper(h_idx, t_idx, r_idx) - - # evaluation_helper both ways (head, rel) -> tail and (rel, tail) -> head - rank_true_tails, filt_rank_true_tails = self.compute_ranks(proj_h_emb, - proj_candidates, - r_emb, h_idx, r_idx, - t_idx, - kg.dict_of_tails, - heads=1) - rank_true_heads, filt_rank_true_heads = self.compute_ranks(proj_t_emb, - proj_candidates, - r_emb, t_idx, r_idx, - h_idx, - kg.dict_of_heads, - heads=-1) - - return rank_true_tails, filt_rank_true_tails, rank_true_heads, filt_rank_true_heads - class TransHModel(TransEModel): """Implementation of TransH model detailed in 2014 paper by Wang et al.. @@ -334,11 +218,9 @@ class TransHModel(TransEModel): """ def __init__(self, ent_emb_dim, n_entities, n_relations): - - super().__init__(ent_emb_dim, n_entities, n_relations, l2_dissimilarity) - - self.normal_vectors = Parameter(xavier_uniform_(empty(size=(n_relations, ent_emb_dim)))) - self.normalize_parameters() + super().__init__(ent_emb_dim, n_entities, n_relations, dissimilarity='L2') + self.normal_vectors = Parameter(xavier_uniform_(empty(size=(n_relations, ent_emb_dim))), + requires_grad=True) def scoring_function(self, heads_idx, tails_idx, rels_idx): """Compute the scoring function for the triplets given as argument. @@ -441,15 +323,7 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): b_size, _ = normal_vectors.shape # recover candidates - all_idx = arange(0, self.number_entities).long() - if h_idx.is_cuda: - all_idx = all_idx.cuda() - candidates = self.entity_embeddings(all_idx).transpose(0, 1) - candidates = candidates.view((1, - self.ent_emb_dim, - self.number_entities)).expand((b_size, - self.ent_emb_dim, - self.number_entities)) + candidates = self.recover_candidates(h_idx, b_size) # project each candidates with each normal vector normal_components = candidates * normal_vectors.view((b_size, self.ent_emb_dim, 1)) @@ -460,16 +334,13 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): assert proj_candidates.shape == (b_size, self.ent_emb_dim, self.number_entities) # recover, project and normalize entity embeddings - mask = h_idx.view(b_size, 1, 1).expand(b_size, self.ent_emb_dim, 1) - proj_h_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.ent_emb_dim) - - mask = t_idx.view(b_size, 1, 1).expand(b_size, self.ent_emb_dim, 1) - proj_t_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.ent_emb_dim) + proj_h_emb, proj_t_emb = self.projection_helper(h_idx, t_idx, b_size, proj_candidates, + self.ent_emb_dim) return proj_h_emb, proj_t_emb, proj_candidates, r_emb -class TransRModel(TransEModel): +class TransRModel(TranslationalModel): """Implementation of TransR model detailed in 2015 paper by Lin et al.. References @@ -515,11 +386,14 @@ class TransRModel(TransEModel): def __init__(self, ent_emb_dim, rel_emb_dim, n_entities, n_relations): - super().__init__(ent_emb_dim, n_entities, n_relations, l2_dissimilarity, - rel_emb_dim=rel_emb_dim) + super().__init__(ent_emb_dim, n_entities, n_relations, dissimilarity='L2') + + self.rel_emb_dim = rel_emb_dim + self.relation_embeddings = init_embedding(self.number_relations, self.rel_emb_dim) self.projection_matrices = Parameter(xavier_uniform_(empty(size=(n_relations, rel_emb_dim, - ent_emb_dim)))) + ent_emb_dim))), + requires_grad=True) self.normalize_parameters() def scoring_function(self, heads_idx, tails_idx, rels_idx): @@ -625,29 +499,18 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): b_size, _, _ = projection_matrices.shape # recover candidates - all_idx = arange(0, self.number_entities).long() - if h_idx.is_cuda: - all_idx = all_idx.cuda() - candidates = self.entity_embeddings(all_idx).transpose(0, 1) - candidates = candidates.view((1, - self.ent_emb_dim, - self.number_entities)).expand((b_size, - self.ent_emb_dim, - self.number_entities)) + candidates = self.recover_candidates(h_idx, b_size) # project each candidates with each projection matrix proj_candidates = matmul(projection_matrices, candidates) - mask = h_idx.view(b_size, 1, 1).expand(b_size, self.rel_emb_dim, 1) - proj_h_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.rel_emb_dim) - - mask = t_idx.view(b_size, 1, 1).expand(b_size, self.rel_emb_dim, 1) - proj_t_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.rel_emb_dim) + proj_h_emb, proj_t_emb = self.projection_helper(h_idx, t_idx, b_size, candidates, + self.rel_emb_dim) return proj_h_emb, proj_t_emb, proj_candidates, r_emb -class TransDModel(TransEModel): +class TransDModel(TranslationalModel): """Implementation of TransD model detailed in 2015 paper by Ji et al.. References @@ -697,11 +560,15 @@ class TransDModel(TransEModel): def __init__(self, ent_emb_dim, rel_emb_dim, n_entities, n_relations): - super().__init__(ent_emb_dim, n_entities, n_relations, l2_dissimilarity, - rel_emb_dim=rel_emb_dim) + super().__init__(ent_emb_dim, n_entities, n_relations, dissimilarity='L2') - self.ent_proj_vects = Parameter(xavier_uniform_(empty(size=(n_entities, ent_emb_dim)))) - self.rel_proj_vects = Parameter(xavier_uniform_(empty(size=(n_relations, rel_emb_dim)))) + self.rel_emb_dim = rel_emb_dim + self.relation_embeddings = init_embedding(self.number_relations, self.rel_emb_dim) + + self.ent_proj_vects = Parameter(xavier_uniform_(empty(size=(n_entities, ent_emb_dim))), + requires_grad=True) + self.rel_proj_vects = Parameter(xavier_uniform_(empty(size=(n_relations, rel_emb_dim))), + requires_grad=True) self.normalize_parameters() self.evaluated_projections = False @@ -872,10 +739,129 @@ def evaluation_helper(self, h_idx, t_idx, r_idx): r_emb = normalize(self.relation_embeddings(r_idx), p=2, dim=1) proj_candidates = self.projected_entities[r_idx] - mask = h_idx.view(b_size, 1, 1).expand(b_size, self.rel_emb_dim, 1) - proj_h_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.rel_emb_dim) - - mask = t_idx.view(b_size, 1, 1).expand(b_size, self.rel_emb_dim, 1) - proj_t_emb = proj_candidates.gather(dim=2, index=mask).view(b_size, self.rel_emb_dim) + proj_h_emb, proj_t_emb = self.projection_helper(h_idx, t_idx, b_size, proj_candidates, + self.rel_emb_dim) return proj_h_emb, proj_t_emb, proj_candidates, r_emb + + +class TorusEModel(TransEModel): + """Implementation of TorusE model detailed in 2018 paper by Ebisu and Ichise. + + References + ---------- + * Takuma Ebisu and Ryutaro Ichise + TorusE: Knowledge Graph Embedding on a Lie Group. + In Proceedings of the 32nd AAAI Conference on Artificial Intelligence + (New Orleans, LA, USA, Feb. 2018), AAAI Press, pp. 1819–1826. + https://arxiv.org/abs/1711.05435 + + Parameters + ---------- + ent_emb_dim: int + Dimension of the embedding of entities. + n_entities: int + Number of entities in the current data set. + n_relations: int + Number of relations in the current data set. + dissimilarity: function + Used to compute dissimilarities (e.g. L1 or L2 dissimilarities). + + Attributes + ---------- + ent_emb_dim: int + Dimension of the embedding of entities. + number_entities: int + Number of entities in the current data set. + number_relations: int + Number of relations in the current data set. + dissimilarity: function + Used to compute dissimilarities (e.g. L1 or L2 dissimilarities). + entity_embeddings: torch Embedding, shape = (number_entities, ent_emb_dim) + Contains the embeddings of the entities. It is initialized with Xavier uniform and then\ + normalized. + relation_embeddings: torch Embedding, shape = (number_relations, ent_emb_dim) + Contains the embeddings of the relations. It is initialized with Xavier uniform and\ + then normalized. + + """ + + def __init__(self, ent_emb_dim, n_entities, n_relations, dissimilarity): + + assert dissimilarity in ['L1', 'L2', 'eL2'] + self.dissimilarity_type = dissimilarity + + super().__init__(ent_emb_dim, n_entities, n_relations, dissimilarity=None) + + self.relation_embeddings = init_embedding(self.number_relations, self.ent_emb_dim) + + if self.dissimilarity_type == 'L1': + self.dissimilarity = l1_torus_dissimilarity + if self.dissimilarity_type == 'L2': + self.dissimilarity = l2_torus_dissimilarity + if self.dissimilarity_type == 'eL2': + self.dissimilarity = el2_torus_dissimilarity + + self.normalize_parameters() + + def scoring_function(self, heads_idx, tails_idx, rels_idx): + """Compute the scoring function for the triplets given as argument. + + Parameters + ---------- + heads_idx: torch.Tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's heads + tails_idx: torch.Tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's tails. + rels_idx: torch.Tensor, dtype = long, shape = (batch_size) + Integer keys of the current batch's relations. + + Returns + ------- + score: torch.Tensor, dtype = float, shape = (batch_size) + Score function: opposite of dissimilarities between h+r and t. + + """ + + # recover relations embeddings + rels_emb = self.relation_embeddings(rels_idx) + + # recover, project and normalize entity embeddings + h_emb = self.recover_project_normalize(heads_idx, normalize_=False) + t_emb = self.recover_project_normalize(tails_idx, normalize_=False) + + if self.dissimilarity_type == 'L1': + return 2 * self.dissimilarity(h_emb + rels_emb, t_emb) + if self.dissimilarity_type == 'L2': + return 4 * self.dissimilarity(h_emb + rels_emb, t_emb)**2 + else: + assert self.dissimilarity_type == 'eL2' + return self.dissimilarity(h_emb + rels_emb, t_emb)**2 / 4 + + def recover_project_normalize(self, ent_idx, normalize_=False): + """ + + Parameters + ---------- + ent_idx: torch.Tensor, dtype = long, shape = (batch_size) + Integer keys of entities + normalize_: bool + Whether entities embeddings should be normalized or not. + + Returns + ------- + projections: torch.Tensor, dtype = float, shape = (batch_size, ent_emb_dim) + Embedded entities normalized. + + """ + # recover entity embeddings + ent_emb = self.entity_embeddings(ent_idx) + ent_emb.data.frac_() + + return ent_emb + + def normalize_parameters(self): + """Project embeddings on torus. + """ + self.entity_embeddings.weight.data.frac_() + self.relation_embeddings.weight.data.frac_() diff --git a/torchkge/models/__init__.py b/torchkge/models/__init__.py index 4bf79f5..757a628 100644 --- a/torchkge/models/__init__.py +++ b/torchkge/models/__init__.py @@ -1,9 +1,10 @@ -from .Models import Model +from .Models import Model, TranslationalModel from .TranslationModels import TransEModel from .TranslationModels import TransHModel from .TranslationModels import TransRModel from .TranslationModels import TransDModel +from .TranslationModels import TorusEModel from .BilinearModels import RESCALModel from .BilinearModels import DistMultModel diff --git a/torchkge/sampling/NegativeSampling.py b/torchkge/sampling/NegativeSampling.py index 97699e6..2a9900c 100644 --- a/torchkge/sampling/NegativeSampling.py +++ b/torchkge/sampling/NegativeSampling.py @@ -33,31 +33,31 @@ class NegativeSampler: n_ent: int Number of entities in the entire knowledge graph. This is the same in `kg`, `kg_val`\ and `kg_test`. - n_sample: int + n_facts: int Number of triplets in `kg`. - n_sample_val: in + n_facts_val: in Number of triplets in `kg_val`. - n_sample_test: int + n_facts_test: int Number of triples in `kg_test`. """ def __init__(self, kg, kg_val=None, kg_test=None): self.kg = kg self.n_ent = kg.n_ent - self.n_sample = kg.n_sample + self.n_facts = kg.n_facts self.kg_val = kg_val self.kg_test = kg_test if kg_val is None: - self.n_sample_val = 0 + self.n_facts_val = 0 else: - self.n_sample_val = kg_val.n_sample + self.n_facts_val = kg_val.n_facts if kg_test is None: - self.n_sample_test = 0 + self.n_facts_test = 0 else: - self.n_sample_test = kg_test.n_sample + self.n_facts_test = kg_test.n_facts def corrupt_batch(self, heads, tails, relations): raise NotYetImplementedError('NegativeSampler is just an interface, please consider using ' @@ -84,18 +84,18 @@ def corrupt_kg(self, batch_size, use_cuda, which='main'): Returns ------- - neg_heads: torch.Tensor, dtype = long, shape = (n_samples) + neg_heads: torch.Tensor, dtype = long, shape = (n_facts) Tensor containing the integer key of negatively sampled heads of the relations\ in the graph designated by `which`. - neg_tails: torch.Tensor, dtype = long, shape = (n_samples) + neg_tails: torch.Tensor, dtype = long, shape = (n_facts) Tensor containing the integer key of negatively sampled tails of the relations\ in the graph designated by `which`. """ assert which in ['main', 'train', 'test', 'val'] if which == 'val': - assert self.n_sample_val > 0 + assert self.n_facts_val > 0 if which == 'test': - assert self.n_sample_test > 0 + assert self.n_facts_test > 0 if which == 'val': dataloader = DataLoader(self.kg_val, batch_size=batch_size, shuffle=False, @@ -158,11 +158,11 @@ class UniformNegativeSampler(NegativeSampler): n_ent: int Number of entities in the entire knowledge graph. This is the same in `kg`, `kg_val`\ and `kg_test`. - n_sample: int + n_facts: int Number of triplets in `kg`. - n_sample_val: in + n_facts_val: in Number of triplets in `kg_val`. - n_sample_test: int + n_facts_test: int Number of triples in `kg_test`. """ @@ -245,11 +245,11 @@ class BernoulliNegativeSampler(NegativeSampler): n_ent: int Number of entities in the entire knowledge graph. This is the same in `kg`, `kg_val`\ and `kg_test`. - n_sample: int + n_facts: int Number of triplets in `kg`. - n_sample_val: in + n_facts_val: in Number of triplets in `kg_val`. - n_sample_test: int + n_facts_test: int Number of triples in `kg_test`. bern_probs: torch.Tensor, dtype = float, shape = (kg.n_rel) Bernoulli sampling probabilities. See paper for more details. @@ -351,11 +351,11 @@ class PositionalNegativeSampler(BernoulliNegativeSampler): n_ent: int Number of entities in the entire knowledge graph. This is the same in `kg`, `kg_val`\ and `kg_test`. - n_sample: int + n_facts: int Number of triplets in `kg`. - n_sample_val: in + n_facts_val: in Number of triplets in `kg_val`. - n_sample_test: int + n_facts_test: int Number of triples in `kg_test`. bern_probs: torch.Tensor, dtype = float, shape = (kg.n_rel) Bernoulli sampling probabilities. See paper for more details. @@ -375,7 +375,7 @@ def __init__(self, kg, kg_val=None, kg_test=None): def find_possibilities(self): """For each relation of the knowledge graph (and possibly the validation graph but not the\ - test graph) find all the possible heads and tails in the sens of Wang et al., e.g. all\ + test graph) find all the possible heads and tails in the sense of Wang et al., e.g. all\ entities that occupy once this position in another triplet. Returns @@ -392,13 +392,10 @@ def find_possibilities(self): """ possible_heads, possible_tails = fill_in_dicts(self.kg) - if self.n_sample_val > 0: + if self.n_facts_val > 0: possible_heads, possible_tails = fill_in_dicts(self.kg_val, possible_heads, possible_tails) - possible_heads = dict(possible_heads) - possible_tails = dict(possible_tails) - n_poss_heads = [] n_poss_tails = [] @@ -411,7 +408,7 @@ def find_possibilities(self): n_poss_heads = tensor(n_poss_heads) n_poss_tails = tensor(n_poss_tails) - return possible_heads, possible_tails, n_poss_heads, n_poss_tails + return dict(possible_heads), dict(possible_tails), n_poss_heads, n_poss_tails def corrupt_batch(self, heads, tails, relations): """For each golden triplet, produce a corrupted one not different from any other golden\ @@ -466,14 +463,26 @@ def corrupt_batch(self, heads, tails, relations): rels = relations[mask == 1] for i in range(n_heads_corrupted): r = rels[i].item() - corr.append(self.possible_heads[r][choice_heads[i].item()]) - neg_heads[mask == 1] = tensor(corr, device=device) + choices = self.possible_heads[r] + if len(choices) == 0: + # in this case the relation r has never been used with any head + # choose one entity at random + corr.append(randint(low=0, high=self.n_ent, size=(1,)).item()) + else: + corr.append(choices[choice_heads[i].item()]) + neg_heads[mask == 1] = tensor(corr, device=device).long() corr = [] rels = relations[mask == 0] for i in range(batch_size - n_heads_corrupted): r = rels[i].item() - corr.append(self.possible_tails[r][choice_tails[i].item()]) - neg_tails[mask == 0] = tensor(corr, device=device) + choices = self.possible_tails[r] + if len(choices) == 0: + # in this case the relation r has never been used with any tail + # choose one entity at random + corr.append(randint(low=0, high=self.n_ent, size=(1,)).item()) + else: + corr.append(choices[choice_tails[i].item()]) + neg_tails[mask == 0] = tensor(corr, device=device).long() return neg_heads.long(), neg_tails.long() diff --git a/torchkge/utils/__init__.py b/torchkge/utils/__init__.py index 86ba856..e4512d9 100644 --- a/torchkge/utils/__init__.py +++ b/torchkge/utils/__init__.py @@ -1,5 +1,7 @@ from .data_preprocessing import get_dictionaries, get_bern_probs from .dissimilarities import l1_dissimilarity, l2_dissimilarity +from .dissimilarities import l1_torus_dissimilarity, l2_torus_dissimilarity, el2_torus_dissimilarity from .losses import MarginLoss, LogisticLoss, MSE -from .operations import get_rank, get_mask, get_rolling_matrix, init_embedding +from .operations import get_rank, get_mask, get_rolling_matrix +from .models_utils import init_embedding, get_true_targets from .negative_sampling import fill_in_dicts diff --git a/torchkge/utils/dissimilarities.py b/torchkge/utils/dissimilarities.py index 7c5af17..d0390ca 100644 --- a/torchkge/utils/dissimilarities.py +++ b/torchkge/utils/dissimilarities.py @@ -3,6 +3,8 @@ Copyright TorchKGE developers aboschin@enst.fr """ +from torch import abs, min, sqrt, cos +from math import pi def l1_dissimilarity(a, b): @@ -10,12 +12,12 @@ def l1_dissimilarity(a, b): Parameters ---------- - a: torch.Tensor, dtype = float, shape = (n_sample, dim) - b: torch.Tensor, dtype = float, shape = (n_sample, dim) + a: torch.Tensor, dtype = float, shape = (n_facts, dim) + b: torch.Tensor, dtype = float, shape = (n_facts, dim) Returns ------- - dist: torch.Tensor, dtype = float, shape = (n_sample) + dist: torch.Tensor, dtype = float, shape = (n_facts) Tensor of the row_wise L1 distance. """ @@ -27,13 +29,29 @@ def l2_dissimilarity(a, b): Parameters ---------- - a: torch.Tensor, dtype = float, shape = (n_sample, dim) - b: torch.Tensor, dtype = float, shape = (n_sample, dim) + a: torch.Tensor, dtype = float, shape = (n_facts, dim) + b: torch.Tensor, dtype = float, shape = (n_facts, dim) Returns ------- - dist: torch.Tensor, dtype = float, shape = (n_sample) + dist: torch.Tensor, dtype = float, shape = (n_facts) Tensor of the row_wise squared L2 distance. """ return (a-b).norm(p=2, dim=1)**2 + + +def l1_torus_dissimilarity(a, b): + a, b = a.frac(), b.frac() + return min(abs(a-b), 1 - abs(a-b)).sum(dim=1) + + +def l2_torus_dissimilarity(a, b): + a, b = a.frac(), b.frac() + return sqrt(min((a - b)**2, 1 - (a - b)**2).sum(dim=1)) + + +def el2_torus_dissimilarity(a, b): + tmp = min(a - b, 1 - (a-b)) + tmp = 2 * (1 - cos(2 * pi * tmp)) + return sqrt(tmp.sum(dim=1)) diff --git a/torchkge/utils/models_utils.py b/torchkge/utils/models_utils.py new file mode 100644 index 0000000..8b9b26b --- /dev/null +++ b/torchkge/utils/models_utils.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +""" +Copyright TorchKGE developers +aboschin@enst.fr +""" + +from torch import tensor, empty +from torch.nn import Embedding, Parameter +from torch.nn.init import xavier_uniform_ + + +def init_embedding(n_vectors, dim): + """Create a torch.nn.Embedding object with `n_vectors` samples and `dim` dimensions. + """ + entity_embeddings = Embedding(n_vectors, dim) + entity_embeddings.weight = Parameter(xavier_uniform_(empty(size=(n_vectors, dim))), + requires_grad=True) + return entity_embeddings + + +def get_true_targets(dictionary, e_idx, r_idx, true_idx, i): + true_targets = dictionary[e_idx[i].item(), r_idx[i].item()].copy() + if len(true_targets) == 1: + return None + true_targets.remove(true_idx[i].item()) + return tensor(list(true_targets)).long() + + + diff --git a/torchkge/utils/negative_sampling.py b/torchkge/utils/negative_sampling.py index ad8b95b..d951b7f 100644 --- a/torchkge/utils/negative_sampling.py +++ b/torchkge/utils/negative_sampling.py @@ -13,7 +13,7 @@ def fill_in_dicts(kg, possible_heads=None, possible_tails=None): if possible_tails is None: possible_tails = defaultdict(set) - for i in tqdm(range(kg.n_sample)): + for i in tqdm(range(kg.n_facts)): possible_heads[kg.relations[i].item()].add(kg.head_idx[i].item()) possible_tails[kg.relations[i].item()].add(kg.tail_idx[i].item()) diff --git a/torchkge/utils/operations.py b/torchkge/utils/operations.py index 60d0b0f..d94391f 100644 --- a/torchkge/utils/operations.py +++ b/torchkge/utils/operations.py @@ -4,17 +4,7 @@ aboschin@enst.fr """ -from torch import empty, bincount, cat, topk, zeros -from torch.nn import Embedding, Parameter -from torch.nn.init import xavier_uniform_ - - -def init_embedding(n_vectors, dim): - """Create a torch.nn.Embedding object with `n_vectors` samples and `dim` dimensions. - """ - entity_embeddings = Embedding(n_vectors, dim) - entity_embeddings.weight = Parameter(xavier_uniform_(empty(size=(n_vectors, dim)))) - return entity_embeddings +from torch import bincount, cat, topk, zeros def get_mask(length, start, end): @@ -29,13 +19,13 @@ def get_mask(length, start, end): Returns ------- - mask: torch.Tensor, shape=(length), dtype=byte + mask: torch.Tensor, shape=(length), dtype=bool Mask of length `length` filled with 0s except between indices `start` (included)\ and `end` (excluded). """ mask = zeros(length) mask[[i for i in range(start, end)]] = 1 - return mask.byte() + return mask.bool() def get_rolling_matrix(x): @@ -60,14 +50,14 @@ def get_rank(data, true, low_values=False): Parameters ---------- - data: torch.Tensor, dtype = float, shape = (n_sample, dimensions) - true: torch.Tensor, dtype = int, shape = (n_sample) + data: torch.Tensor, dtype = float, shape = (n_facts, dimensions) + true: torch.Tensor, dtype = int, shape = (n_facts) low_values: bool if True, best rank is the lowest score else it is the highest Returns ------- - ranks: torch.Tensor, dtype = int, shape = (n_sample) + ranks: torch.Tensor, dtype = int, shape = (n_facts) data[ranks[i]] = true[i] """ true_data = data.gather(1, true.long().view(-1, 1))