Skip to content

Commit

Permalink
Merge branch 'master' of github.com:torchkge-team/torchkge
Browse files Browse the repository at this point in the history
  • Loading branch information
armand33 committed Apr 2, 2020
2 parents 5d7238d + a2149ac commit f2a945f
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions torchkge/evaluation/LinkPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,26 @@ def mrr(self):
filt_tail_mrr = (self.filt_rank_true_tails.float()**(-1)).mean()

return (head_mrr + tail_mrr).item() / 2, (filt_head_mrr + filt_tail_mrr).item() / 2

def print_results(self, k=None):
"""
Parameters
----------
k: int or list
k (or list of k) such that hit@k will be printed.
"""
if k is None:
k = 10

if k is not None and type(k) == int:
print('Hit@{} : {} \t Filt. Hit@{} : {}'.format(k, self.hit_at_k(k=k)[0],
k, self.hit_at_k(k=k)[1]))
if k is not None and type(k) == list:
for i in k:
print('Hit@{} : {} \t Filt. Hit@{} : {}'.format(i, self.hit_at_k(k=i)[0],
i, self.hit_at_k(k=i)[1]))

print('Mean Rank : {} \t Filt. Mean Rank : {}'.format(self.mean_rank()[0], self.mean_rank()[1]))
print('MRR : {} \t Filt. MRR : {}'.format(self.mrr()[0], self.mrr()[1]))

0 comments on commit f2a945f

Please sign in to comment.