Merge pull request #80 from torchkge-team/develop
Adding TorusE and dataloaders and fixing some bugs.
armand33 authored Oct 21, 2019
2 parents 298a7c0 + 8f2d9a4 commit 3bb8ccb
Showing 18 changed files with 673 additions and 359 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.10.6
current_version = 0.11.0
commit = True
tag = True

4 changes: 2 additions & 2 deletions setup.py
@@ -11,7 +11,7 @@
with open('HISTORY.rst') as history_file:
history = history_file.read()

requirements = ['torch==1.2.0', 'torchvision', 'tqdm==4.32.2']
requirements = ['torch==1.2.0', 'torchvision', 'tqdm==4.32.2', 'pandas']

setup_requirements = []

@@ -39,6 +39,6 @@
test_suite='tests',
tests_require=test_requirements,
url='https://github.com/torchkge-team/torchkge',
version='0.10.6',
version='0.11.0',
zip_safe=False,
)
2 changes: 1 addition & 1 deletion torchkge/__init__.py
@@ -4,7 +4,7 @@

__author__ = """Armand Boschin"""
__email__ = '[email protected]'
__version__ = '0.10.6'
__version__ = '0.11.0'

from .data import KnowledgeGraph

168 changes: 168 additions & 0 deletions torchkge/data/DataLoader.py
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
[email protected]
This module's code is freely adapted from scikit-learn's sklearn.datasets.base module.
"""

from .KnowledgeGraph import KnowledgeGraph

from os import environ, makedirs, remove
from os.path import exists, expanduser, join
from pandas import read_csv, concat, merge, DataFrame
from urllib.request import urlretrieve

import shutil
import tarfile
import zipfile


def get_data_home(data_home=None):
if data_home is None:
data_home = environ.get('TORCHKGE_DATA',
join('~', 'torchkge_data'))
data_home = expanduser(data_home)
if not exists(data_home):
makedirs(data_home)
return data_home


def clear_data_home(data_home=None):
data_home = get_data_home(data_home)
shutil.rmtree(data_home)


def load_fb13():
data_home = get_data_home()
data_path = data_home + '/FB13'
if not exists(data_path):
makedirs(data_path, exist_ok=True)
urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB13.zip",
data_home + '/FB13.zip')
with zipfile.ZipFile(data_home + '/FB13.zip', 'r') as zip_ref:
zip_ref.extractall(data_home)
remove(data_home + '/FB13.zip')
shutil.rmtree(data_home + '/__MACOSX')

df1 = read_csv(data_path + '/train2id.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df2 = read_csv(data_path + '/valid2id.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df3 = read_csv(data_path + '/test2id.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df = concat([df1, df2, df3])
kg = KnowledgeGraph(df)

return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))


def load_fb15k():
data_home = get_data_home()
data_path = data_home + '/FB15k'
if not exists(data_path):
makedirs(data_path, exist_ok=True)
urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB15k.zip",
data_home + '/FB15k.zip')
with zipfile.ZipFile(data_home + '/FB15k.zip', 'r') as zip_ref:
zip_ref.extractall(data_home)
remove(data_home + '/FB15k.zip')
shutil.rmtree(data_home + '/__MACOSX')

df1 = read_csv(data_path + '/freebase_mtr100_mte100-train.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df2 = read_csv(data_path + '/freebase_mtr100_mte100-valid.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df3 = read_csv(data_path + '/freebase_mtr100_mte100-test.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df = concat([df1, df2, df3])
kg = KnowledgeGraph(df)

return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))


def load_fb15k237():
data_home = get_data_home()
data_path = data_home + '/FB15k237'
if not exists(data_path):
makedirs(data_path, exist_ok=True)
urlretrieve("https://graphs.telecom-paristech.fr/datasets/FB15k237.zip",
data_home + '/FB15k237.zip')
with zipfile.ZipFile(data_home + '/FB15k237.zip', 'r') as zip_ref:
zip_ref.extractall(data_home)
remove(data_home + '/FB15k237.zip')
shutil.rmtree(data_home + '/__MACOSX')

df1 = read_csv(data_path + '/train.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df2 = read_csv(data_path + '/valid.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df3 = read_csv(data_path + '/test.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df = concat([df1, df2, df3])
kg = KnowledgeGraph(df)

return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))


def load_wn18():
data_home = get_data_home()
data_path = data_home + '/WN18'
if not exists(data_path):
makedirs(data_path, exist_ok=True)
urlretrieve("https://graphs.telecom-paristech.fr/datasets/WN18.zip",
data_home + '/WN18.zip')
with zipfile.ZipFile(data_home + '/WN18.zip', 'r') as zip_ref:
zip_ref.extractall(data_home)
remove(data_home + '/WN18.zip')
shutil.rmtree(data_home + '/__MACOSX')

df1 = read_csv(data_path + '/wordnet-mlj12-train.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df2 = read_csv(data_path + '/wordnet-mlj12-valid.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df3 = read_csv(data_path + '/wordnet-mlj12-test.txt',
sep='\t', header=None, names=['from', 'rel', 'to'])
df = concat([df1, df2, df3])
kg = KnowledgeGraph(df)

return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))


def load_wikidatasets(which, limit_=None):
assert which in ['humans', 'companies', 'animals', 'countries', 'films']

if limit_ is None:
limit_ = 0

data_home = get_data_home() + '/WikiDataSets'
data_path = data_home + '/' + which
if not exists(data_path):
makedirs(data_path, exist_ok=True)
urlretrieve("https://graphs.telecom-paristech.fr/WikiDataSets/{}.tar.gz".format(which),
data_home + '/{}.tar.gz'.format(which))

with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf:
tf.extractall(data_home)
remove(data_home + '/{}.tar.gz'.format(which))

df = read_csv(data_path + '/edges.txt', sep='\t', header=1,
names=['from', 'to', 'rel'])

a = df.groupby('from').count()['rel']
b = df.groupby('to').count()['rel']

# Filter out nodes with too few facts
tmp = merge(right=DataFrame(a).reset_index(),
left=DataFrame(b).reset_index(),
how='outer', right_on='from', left_on='to').fillna(0)

tmp['rel'] = tmp['rel_x'] + tmp['rel_y']
tmp = tmp.drop(['from', 'rel_x', 'rel_y'], axis=1)

tmp = tmp.loc[tmp['rel'] >= limit_]
df_bis = df.loc[df['from'].isin(tmp['to']) | df['to'].isin(tmp['to'])]

kg = KnowledgeGraph(df_bis)
kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True)

return kg_train, kg_val, kg_test
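
Editor's note: the new loaders cache each dataset under the data home returned by get_data_home (~/torchkge_data by default, overridable via the TORCHKGE_DATA environment variable) and return one KnowledgeGraph per split. A minimal usage sketch, assuming network access to graphs.telecom-paristech.fr on the first call; the limit_=5 value is a hypothetical illustration:

from torchkge.data import load_fb15k, load_wikidatasets

# First call downloads and extracts the archive; later calls reuse the cache.
kg_train, kg_val, kg_test = load_fb15k()
print(kg_train.n_facts, kg_val.n_facts, kg_test.n_facts)

# WikiDataSets loaders can filter out entities involved in fewer than limit_ facts.
kg_train, kg_val, kg_test = load_wikidatasets('humans', limit_=5)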
32 changes: 16 additions & 16 deletions torchkge/data/KnowledgeGraph.py
@@ -54,14 +54,14 @@ class KnowledgeGraph(Dataset):
Number of distinct entities in the data set.
n_rel: int
Number of distinct relations in the data set.
n_sample: int
n_facts: int
Number of facts in the data set. A fact is a triplet (h, r, t).
head_idx: torch.Tensor, dtype = long, shape = (n_sample)
List of the int key of heads for each sample (fact).
tail_idx: torch.Tensor, dtype = long, shape = (n_sample)
List of the int key of tails for each sample (facts).
relations: torch.Tensor, dtype = long, shape = (n_sample)
List of the int key of relations for each sample (facts).
head_idx: torch.Tensor, dtype = long, shape = (n_facts)
List of the int key of heads for each fact.
tail_idx: torch.Tensor, dtype = long, shape = (n_facts)
List of the int key of tails for each fact.
relations: torch.Tensor, dtype = long, shape = (n_facts)
List of the int key of relations for each fact.
"""

@@ -84,14 +84,14 @@ def __init__(self, df=None, kg=None,
if df is not None:
assert kg is None
self.df = df
self.n_sample = len(df)
self.n_facts = len(df)
self.head_idx = tensor(df['from'].map(self.ent2ix).values).long()
self.tail_idx = tensor(df['to'].map(self.ent2ix).values).long()
self.relations = tensor(df['rel'].map(self.rel2ix).values).long()
else:
assert kg is not None
self.df = kg['df']
self.n_sample = kg['heads'].shape[0]
self.n_facts = kg['heads'].shape[0]
self.head_idx = kg['heads']
self.tail_idx = kg['tails']
self.relations = kg['relations']
@@ -107,7 +107,7 @@ def __init__(self, df=None, kg=None,
self.dict_of_tails = dict_of_tails

def __len__(self):
return self.n_sample
return self.n_facts

def __getitem__(self, item):
return self.head_idx[item].item(), self.tail_idx[item].item(), self.relations[item].item()
@@ -139,9 +139,9 @@ def split_kg(self, share=0.8, sizes=None, validation=False):
if sizes is not None:
try:
if len(sizes) == 3:
assert (sizes[0] + sizes[1] + sizes[2] == self.n_sample)
assert (sizes[0] + sizes[1] + sizes[2] == self.n_facts)
elif len(sizes) == 2:
assert (sizes[0] + sizes[1] == self.n_sample)
assert (sizes[0] + sizes[1] == self.n_facts)
else:
raise SizeMismatchError('Tuple `sizes` should be of length 2 or 3.')
except AssertionError:
@@ -156,10 +156,10 @@ def split_kg(self, share=0.8, sizes=None, validation=False):
mask_te = ~(mask_tr | mask_val)
else:
mask_tr = cat([tensor([1 for _ in range(sizes[0])]),
tensor([0 for _ in range(sizes[1] + sizes[2])])]).byte()
tensor([0 for _ in range(sizes[1] + sizes[2])])]).bool()
mask_val = cat([tensor([0 for _ in range(sizes[0])]),
tensor([1 for _ in range(sizes[1])]),
tensor([0 for _ in range(sizes[2])])]).byte()
tensor([0 for _ in range(sizes[2])])]).bool()
mask_te = ~(mask_tr | mask_val)

return KnowledgeGraph(
@@ -182,7 +182,7 @@ def split_kg(self, share=0.8, sizes=None, validation=False):
mask_tr = (empty(self.head_idx.shape).uniform_() < share)
else:
mask_tr = cat([tensor([1 for _ in range(sizes[0])]),
tensor([0 for _ in range(sizes[1])])]).byte()
tensor([0 for _ in range(sizes[1])])]).bool()
return KnowledgeGraph(
kg={'heads': self.head_idx[mask_tr], 'tails': self.tail_idx[mask_tr],
'relations': self.relations[mask_tr], 'df': self.df},
@@ -198,7 +198,7 @@ def evaluate_dicts(self):
fact in the entire knowledge graph.
"""
for i in tqdm(range(self.n_sample)):
for i in tqdm(range(self.n_facts)):
self.dict_of_heads[(self.tail_idx[i].item(),
self.relations[i].item())].add(self.head_idx[i].item())
self.dict_of_tails[(self.head_idx[i].item(),
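
Editor's note: throughout KnowledgeGraph, the split masks move from uint8 (.byte()) to boolean (.bool()) tensors, matching PyTorch's deprecation of uint8 mask indexing since version 1.2 (the version pinned in setup.py above). A standalone sketch of the behaviour the new masks rely on:

import torch

head_idx = torch.tensor([0, 1, 2, 3])
mask_tr = torch.tensor([1, 1, 0, 0]).bool()

# Boolean masks index directly and support logical complement,
# without the deprecation warnings uint8 masks trigger.
print(head_idx[mask_tr])   # tensor([0, 1])
print(head_idx[~mask_tr])  # tensor([2, 3])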
6 changes: 6 additions & 0 deletions torchkge/data/__init__.py
@@ -1,2 +1,8 @@
from .KnowledgeGraph import SmallKG
from .KnowledgeGraph import KnowledgeGraph

from .DataLoader import load_fb13
from .DataLoader import load_fb15k
from .DataLoader import load_fb15k237
from .DataLoader import load_wn18
from .DataLoader import load_wikidatasets
8 changes: 4 additions & 4 deletions torchkge/evaluation/LinkPrediction.py
@@ -56,10 +56,10 @@ def __init__(self, model, knowledge_graph):
self.model = model
self.kg = knowledge_graph

self.rank_true_heads = empty(size=(knowledge_graph.n_sample,)).long()
self.rank_true_tails = empty(size=(knowledge_graph.n_sample,)).long()
self.filt_rank_true_heads = empty(size=(knowledge_graph.n_sample,)).long()
self.filt_rank_true_tails = empty(size=(knowledge_graph.n_sample,)).long()
self.rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long()
self.rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long()
self.filt_rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long()
self.filt_rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long()

self.evaluated = False
self.k_max = 10
22 changes: 12 additions & 10 deletions torchkge/evaluation/TripletClassification.py
@@ -65,17 +65,17 @@ def get_scores(self, heads, tails, relations, batch_size):
Parameters
----------
heads: torch.Tensor, dtype = long, shape = n_sample
heads: torch.Tensor, dtype = long, shape = n_facts
List of heads indices.
tails: torch.Tensor, dtype = long, shape = n_sample
tails: torch.Tensor, dtype = long, shape = n_facts
List of tails indices.
relations: torch.Tensor, dtype = long, shape = n_sample
relations: torch.Tensor, dtype = long, shape = n_facts
List of relation indices.
batch_size: int
Returns
-------
scores: torch.Tensor, dtype = float, shape = n_sample
scores: torch.Tensor, dtype = float, shape = n_facts
List of scores of each triplet.
"""
scores = []
@@ -111,9 +111,11 @@ def evaluate(self, batch_size):
self.thresholds = zeros(self.kg_val.n_rel)

for i in range(self.kg_val.n_rel):
mask = (r_idx == i).byte()
assert mask.sum() > 0
self.thresholds[i] = neg_scores[mask == 1].max()
mask = (r_idx == i).bool()
if mask.sum() > 0:
self.thresholds[i] = neg_scores[mask].max()
else:
self.thresholds[i] = neg_scores.max()

self.evaluated = True
self.thresholds.detach_()
@@ -144,7 +146,7 @@ def accuracy(self, batch_size):
if self.use_cuda:
self.thresholds = self.thresholds.cuda()

scores = (scores > self.thresholds[r_idx]).byte()
neg_scores = (neg_scores < self.thresholds[r_idx]).byte()
scores = (scores > self.thresholds[r_idx])
neg_scores = (neg_scores < self.thresholds[r_idx])

return (scores.sum().item() + neg_scores.sum().item()) / (2 * self.kg_test.n_sample)
return (scores.sum().item() + neg_scores.sum().item()) / (2 * self.kg_test.n_facts)
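
Editor's note: the evaluate change above replaces a hard assertion with a fallback — a relation that never appears among the validation triples now gets the global maximum negative score as its threshold instead of crashing. A self-contained sketch of that per-relation thresholding logic, on hypothetical toy tensors:

import torch

n_rel = 3
r_idx = torch.tensor([0, 0, 2])             # relation of each validation triple
neg_scores = torch.tensor([0.4, 0.6, 0.2])  # scores of their negative versions

thresholds = torch.zeros(n_rel)
for i in range(n_rel):
    mask = (r_idx == i)
    if mask.sum() > 0:
        thresholds[i] = neg_scores[mask].max()
    else:
        # Relation 1 has no validation triple: fall back to the global max.
        thresholds[i] = neg_scores.max()

print(thresholds)  # tensor([0.6000, 0.6000, 0.2000])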