Merge pull request #107 from torchkge-team/develop
Minor fix and comment additions.
armand33 authored Jan 10, 2020
2 parents 1acc509 + 141c01b commit 1e04980
Showing 7 changed files with 145 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.12.0
current_version = 0.12.1
commit = True
tag = True

2 changes: 1 addition & 1 deletion setup.py
@@ -36,6 +36,6 @@
setup_requires=setup_requirements,
tests_require=test_requirements,
test_suite='tests',
version='0.12.0',
version='0.12.1',
zip_safe=False,
)
2 changes: 1 addition & 1 deletion torchkge/__init__.py
@@ -4,7 +4,7 @@

__author__ = """Armand Boschin"""
__email__ = '[email protected]'
__version__ = '0.12.0'
__version__ = '0.12.1'

from .data import KnowledgeGraph

2 changes: 1 addition & 1 deletion torchkge/evaluation/LinkPrediction.py
100644 → 100755
@@ -77,7 +77,7 @@ def evaluate(self, batch_size, k_max):
"""
self.k_max = k_max
use_cuda = self.model.entity_embeddings.weight.is_cuda
use_cuda = next(self.model.parameters()).is_cuda
dataloader = DataLoader(self.kg, batch_size=batch_size, pin_memory=use_cuda)

if use_cuda:
2 changes: 1 addition & 1 deletion torchkge/evaluation/TripletClassification.py
@@ -55,7 +55,7 @@ def __init__(self, model, kg_val, kg_test):
# assert kg_val.n_rel == kg_test.n_rel
# assert set(kg_test.relations.unique().tolist()).issubset(set(kg_val.relations.unique().tolist()))
self.kg_test = kg_test
self.use_cuda = self.model.entity_embeddings.weight.is_cuda
self.use_cuda = next(self.model.parameters()).is_cuda

self.evaluated = False
self.thresholds = None
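Both evaluators now read the device from the model's parameters instead of from a hard-coded entity_embeddings attribute. A minimal sketch (illustrative, not part of this commit) of why next(model.parameters()).is_cuda is the more general check:

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # No `entity_embeddings` attribute: the old check would raise AttributeError here.
        self.ent_emb = nn.Embedding(100, 16)

model = ToyModel()
use_cuda = next(model.parameters()).is_cuda  # works for any module that exposes parameters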
2 changes: 1 addition & 1 deletion torchkge/models/BilinearModels.py
100644 → 100755
@@ -159,7 +159,7 @@ def evaluation_helper(self, h_idx, t_idx, r_idx):
return h_emb, t_emb, candidates, r_mat

def compute_ranks(self, e_emb, candidates, r_mat, e_idx, r_idx, true_idx, dictionary, heads=1):
"""
"""Compute the ranks and the filtered ranks of true entities when doing link prediction.
Parameters
----------
149 changes: 139 additions & 10 deletions torchkge/models/interfaces.py
100644 → 100755
@@ -67,15 +67,123 @@ def forward(self, heads, tails, negative_heads, negative_tails, relations):
self.scoring_function(negative_heads, negative_tails, relations)

def scoring_function(self, heads_idx, tails_idx, rels_idx):
pass
"""Compute the scoring function for the triplets given as argument.
Parameters
----------
heads_idx: `torch.Tensor`, dtype: `torch.long`, shape: (batch_size)
Integer keys of the current batch's heads.
tails_idx: `torch.Tensor`, dtype: `torch.long`, shape: (batch_size)
Integer keys of the current batch's tails.
rels_idx: `torch.Tensor`, dtype: `torch.long`, shape: (batch_size)
Integer keys of the current batch's relations.
Returns
-------
score: `torch.Tensor`, dtype: `torch.float`, shape: (batch_size)
Score of the triplets: the opposite of the dissimilarity between h + r and t.
"""
raise NotImplementedError
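The method stays abstract here; a hedged sketch of how a TransE-style subclass might implement it (ToyTransE and its attribute names are illustrative, not the library's own TransE model):

import torch
import torch.nn as nn

class ToyTransE(nn.Module):
    def __init__(self, n_entities, n_relations, emb_dim):
        super().__init__()
        self.entity_embeddings = nn.Embedding(n_entities, emb_dim)
        self.relation_embeddings = nn.Embedding(n_relations, emb_dim)

    def scoring_function(self, heads_idx, tails_idx, rels_idx):
        h = self.entity_embeddings(heads_idx)
        t = self.entity_embeddings(tails_idx)
        r = self.relation_embeddings(rels_idx)
        # Opposite of the L2 dissimilarity between h + r and t, matching the docstring above.
        return -torch.norm(h + r - t, p=2, dim=1)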

def normalize_parameters(self):
pass
"""Normalize the parameters of the model using the L2 norm.
"""
raise NotImplementedError
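Continuing the hedged ToyTransE sketch above, one common way to implement normalize_parameters is to rescale each entity embedding to unit L2 norm in place:

import torch.nn.functional as F

class ToyTransENormalized(ToyTransE):
    def normalize_parameters(self):
        # Renormalize each entity embedding row to unit L2 norm, in place.
        self.entity_embeddings.weight.data = F.normalize(
            self.entity_embeddings.weight.data, p=2, dim=1)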

def evaluation_helper(self, h_idx, t_idx, r_idx):
"""Project current entities and candidates into relation-specific sub-spaces.
Parameters
----------
h_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current head entities.
t_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current tail entities.
r_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current relations.
Returns
-------
proj_h_emb: `torch.Tensor`, shape: (b_size, rel_emb_dim), dtype: `torch.float`
Tensor containing embeddings of current head entities projected in relation space.
proj_t_emb: `torch.Tensor`, shape: (b_size, rel_emb_dim), dtype: `torch.float`
Tensor containing embeddings of current tail entities projected in relation space.
proj_candidates: `torch.Tensor`, shape: (b_size, rel_emb_dim, n_entities), dtype: `torch.float`
Tensor containing all entities projected into each relation space (the spaces of the
relations in the current batch).
r_emb: `torch.Tensor`, shape: (b_size, rel_emb_dim), dtype: `torch.float`
Tensor containing current relations embeddings.
"""
raise NotImplementedError
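For a model whose entities live in a single shared space, a hedged sketch of an evaluation_helper body (again extending the illustrative ToyTransE) could treat the projection as the identity and simply replicate all entity embeddings once per batch sample:

def evaluation_helper(self, h_idx, t_idx, r_idx):
    b_size = h_idx.shape[0]
    h_emb = self.entity_embeddings(h_idx)                      # (b_size, emb_dim)
    t_emb = self.entity_embeddings(t_idx)                      # (b_size, emb_dim)
    r_emb = self.relation_embeddings(r_idx)                    # (b_size, emb_dim)
    all_emb = self.entity_embeddings.weight.t()                # (emb_dim, n_entities)
    candidates = all_emb.unsqueeze(0).expand(b_size, -1, -1)   # (b_size, emb_dim, n_entities)
    return h_emb, t_emb, candidates, r_emb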

def compute_ranks(self, proj_e_emb, proj_candidates, r_emb, e_idx, r_idx, true_idx, dictionary, heads=1):
"""Compute the ranks and the filtered ranks of true entities when doing link prediction.
Parameters
----------
proj_e_emb: `torch.Tensor`, shape: (batch_size, rel_emb_dim), dtype: `torch.float`
Tensor containing current projected embeddings of entities.
proj_candidates: `torch.Tensor`, shape: (b_size, rel_emb_dim, n_entities), dtype: `torch.float`
Tensor containing projected embeddings of all entities.
r_emb: `torch.Tensor`, shape: (batch_size, ent_emb_dim), dtype: `torch.float`
Tensor containing current embeddings of relations.
e_idx: `torch.Tensor`, shape: (batch_size), dtype: `torch.long`
Tensor containing the indices of entities.
r_idx: `torch.Tensor`, shape: (batch_size), dtype: `torch.long`
Tensor containing the indices of relations.
true_idx: `torch.Tensor`, shape: (batch_size), dtype: `torch.long`
Tensor containing the true entity for each sample.
dictionary: defaultdict
Dictionary with (entity, relation) keys (int, int) and, as values, the list of all
possible entities (ints) for each pair.
heads: int
1 or -1 (1 if the entities are heads, -1 if they are tails). The dissimilarity is
computed between heads * entities + relations and heads * targets.
Returns
-------
rank_true_entities: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true entities when ranking against all entities,
based on the computation of d(head + relation, tail).
filtered_rank_true_entities: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true entities when ranking only against false
entities, based on the computation of d(head + relation, tail).
"""
raise NotImplementedError
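The raw rank counts how many candidates score strictly better than the true entity, while the filtered rank first discards the other entities known to be true for the same (entity, relation) pair. A small self-contained sketch (toy_ranks and its arguments are illustrative, not the library's code):

import torch

def toy_ranks(scores, true_idx, known_true_idx):
    # scores: (n_entities,) dissimilarities d(head + relation, candidate); lower is better.
    raw_rank = (scores < scores[true_idx]).sum().item() + 1
    filt_scores = scores.clone()
    others = torch.tensor([i for i in known_true_idx if i != true_idx], dtype=torch.long)
    if others.numel() > 0:
        filt_scores[others] = float('inf')  # drop the other true entities (filtered setting)
    filtered_rank = (filt_scores < filt_scores[true_idx]).sum().item() + 1
    return raw_rank, filtered_rank

scores = torch.tensor([0.3, 0.1, 0.7, 0.2])
print(toy_ranks(scores, true_idx=3, known_true_idx=[1, 3]))  # (2, 1)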

def evaluate_candidates(self, h_idx, t_idx, r_idx, kg):
"""Compute the head and tail ranks and filtered ranks of the current batch.
Parameters
----------
h_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current head entities.
t_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current tail entities.
r_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current relations.
kg: `torchkge.data.KnowledgeGraph.KnowledgeGraph`
Knowledge graph on which the model was trained.
Returns
-------
rank_true_tails: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true tails when ranking against all entities,
based on the computation of d(head + relation, tail).
filt_rank_true_tails: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true tails when ranking only against false
entities, based on the computation of d(head + relation, tail).
rank_true_heads: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true heads when ranking against all entities,
based on the computation of d(head + relation, tail).
filt_rank_true_heads: `torch.Tensor`, shape: (b_size), dtype: `torch.int`
Tensor containing the rank of the true heads when ranking only against false
entities, based on the computation of d(head + relation, tail).
"""
proj_h_emb, proj_t_emb, candidates, r_emb = self.evaluation_helper(h_idx, t_idx, r_idx)

# run evaluation_helper both ways: (head, rel) -> tail and (rel, tail) -> head
@@ -131,12 +239,20 @@ def __init__(self, ent_emb_dim, n_entities, n_relations, dissimilarity_type):
else:
self.dissimilarity = None

def scoring_function(self, heads_idx, tails_idx, rels_idx):
raise NotImplementedError

def normalize_parameters(self):
raise NotImplementedError

def evaluation_helper(self, h_idx, t_idx, r_idx):
raise NotImplementedError

def recover_project_normalize(self, ent_idx, rel_idx, normalize_):
pass
raise NotImplementedError

def compute_ranks(self, proj_e_emb, proj_candidates,
r_emb, e_idx, r_idx, true_idx, dictionary, heads=1):
"""
def compute_ranks(self, proj_e_emb, proj_candidates, r_emb, e_idx, r_idx, true_idx, dictionary, heads=1):
"""Compute the ranks and the filtered ranks of true entities when doing link prediction.
Parameters
----------
@@ -194,10 +310,22 @@ def compute_ranks(self, proj_e_emb, proj_candidates,

return rank_true_entities, filtered_rank_true_entities

def evaluation_helper(self, h_idx, t_idx, r_idx):
return None, None, None, None

def recover_candidates(self, h_idx, b_size):
"""Prepare candidates for link prediction evaluation.
Parameters
----------
h_idx: `torch.Tensor`, shape: (b_size), dtype: `torch.long`
Tensor containing indices of current head entities.
b_size: int
Batch size.
Returns
-------
candidates: `torch.Tensor`, shape: (b_size, ent_emb_dim, number_entities), dtype: `torch.float`
Tensor containing all entity embeddings, replicated once per sample in the batch.
"""
all_idx = arange(0, self.number_entities).long()
if h_idx.is_cuda:
all_idx = all_idx.cuda()
@@ -209,7 +337,8 @@ def recover_candidates(self, h_idx, b_size):
self.number_entities))
return candidates

def projection_helper(self, h_idx, t_idx, b_size, candidates, rel_emb_dim):
@staticmethod
def projection_helper(h_idx, t_idx, b_size, candidates, rel_emb_dim):
mask = h_idx.view(b_size, 1, 1).expand(b_size, rel_emb_dim, 1)
proj_h_emb = candidates.gather(dim=2, index=mask).view(b_size, rel_emb_dim)

