Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #255

Merged
merged 6 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/history.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
=======
History
=======
0.17.5 (2022-09-18)
0.17.7 (2023-04-05)
-------------------
* Fix bug in TransH implementation
* Adding additional pretrained models

0.17.6 (2023-03-31)
-------------------
Expand Down
18 changes: 18 additions & 0 deletions docs/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ TransE model

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_transe

RESCAL Model
=============
.. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}

+-----------+-----------+-----------+----------+--------------------+
| Model | Dataset | Dimension | Test MRR | Filtered Test MRR |
+===========+===========+===========+==========+====================+
| RESCAL | FB15k237 | 200 | 0.180 | 0.305 |
+-----------+-----------+-----------+----------+--------------------+
| RESCAL | WN18RR | 150 | 0.273 | 0.424 |
+-----------+-----------+-----------+----------+--------------------+
| RESCAL | Yago3-10 | 200 | 0.124 | 0.308 |
+-----------+-----------+-----------+----------+--------------------+

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_rescal

ComplEx Model
=============
.. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}
Expand All @@ -55,6 +71,8 @@ ComplEx Model
+-----------+-----------+-----------+----------+--------------------+
| ComplEx | WDV5 | 200 | 0.283 | 0.371 |
+-----------+-----------+-----------+----------+--------------------+
| ComplEx | Yago3-10 | 200 | 0.164 | 0.421 |
+-----------+-----------+-----------+----------+--------------------+

.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_complex

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.17.6
current_version = 0.17.7
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@
setup_requires=setup_requirements,
tests_require=test_requirements,
test_suite='tests',
version='0.17.6',
version='0.17.7',
zip_safe=False,
)
6 changes: 3 additions & 3 deletions torchkge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__author__ = """Armand Boschin"""
__email__ = '[email protected]'
__version__ = '0.17.6'
__version__ = '0.17.7'

from torchkge.exceptions import NotYetEvaluatedError
from torchkge.utils import MarginLoss, LogisticLoss
Expand All @@ -13,5 +13,5 @@
from .evaluation import LinkPredictionEvaluator
from .evaluation import TripletClassificationEvaluator
from .models import ConvKBModel
from .models import RESCALModel, DistMultModel
from .models import TransEModel, TransHModel, TransRModel, TransDModel
from .models import RESCALModel, DistMultModel, HolEModel, ComplExModel, AnalogyModel
from .models import TransEModel, TransHModel, TransRModel, TransDModel, TorusEModel
1 change: 1 addition & 0 deletions torchkge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
from .losses import MarginLoss, LogisticLoss, BinaryCrossEntropyLoss
from .modeling import init_embedding, get_true_targets, load_embeddings, filter_scores
from .operations import get_rank, get_mask, get_bernoulli_probs
from .pretrained_models import load_pretrained_transe, load_pretrained_rescal, load_pretrained_complex
from .training import Trainer, TrainDataLoader
88 changes: 66 additions & 22 deletions torchkge/utils/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
"""

from ..exceptions import NoPreTrainedVersionError
from ..models import TransEModel, ComplExModel
from ..models import TransEModel, ComplExModel, RESCALModel
from ..utils import load_embeddings


def load_pretrained_transe(dataset, emb_dim, data_home=None):
def load_pretrained_transe(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of TransE model.

Parameters
----------
dataset: str
emb_dim: int
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
Expand All @@ -26,16 +26,18 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None):
model: `TorchKGE.model.translation.TransEModel`
Pretrained version of TransE model.
"""
dims = {'fb15k': 100, 'wn18rr': 100, 'fb15k237': 150, 'wdv5': 150, 'yago310': 200}
try:
assert (dataset in {'fb15k', 'wn18rr'} and emb_dim == 100) \
or (dataset == 'fb15k237' and emb_dim == 150) \
or (dataset == 'wdv5' and emb_dim == 150) \
or (dataset == 'yago310' and emb_dim == 200)

except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for '
'{} in dimension {}'.format(dataset,
emb_dim))
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of TransE for {}'.format(dataset))

state_dict = load_embeddings('transe', emb_dim, dataset, data_home)
model = TransEModel(emb_dim,
Expand All @@ -47,13 +49,13 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None):
return model


def load_pretrained_complex(dataset, emb_dim, data_home=None):
def load_pretrained_complex(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of ComplEx model.

Parameters
----------
dataset: str
emb_dim: int
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
Expand All @@ -64,15 +66,18 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None):
model: `TorchKGE.model.translation.ComplExModel`
Pretrained version of ComplEx model.
"""
dims = {'wn18rr': 200, 'fb15k237': 200, 'wdv5': 200, 'yago310': 200}
try:
assert (dataset == 'wn18rr' and emb_dim == 200) \
or (dataset == 'fb15k237' and emb_dim == 200) \
or (dataset == 'wdv5' and emb_dim == 200)

except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for '
'{} in dimension {}'.format(dataset,
emb_dim))
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of ComplEx for {}'.format(dataset))

state_dict = load_embeddings('complex', emb_dim, dataset, data_home)
model = ComplExModel(emb_dim,
Expand All @@ -81,3 +86,42 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None):
model.load_state_dict(state_dict)

return model


def load_pretrained_rescal(dataset, emb_dim=None, data_home=None):
"""Load a pretrained version of RESCAL model.

Parameters
----------
dataset: str
emb_dim: int (opt, default None)
Embedding dimension
data_home: str (opt, default None)
Path to the `torchkge_data` directory (containing data folders). Useful
for pre-trained model loading.

Returns
-------
model: `TorchKGE.model.translation.RESCALModel`
Pretrained version of RESCAL model.
"""
dims = {'wn18rr': 200, 'fb15k237': 200, 'yago310': 200}
try:
if emb_dim is None:
emb_dim = dims[dataset]
else:
try:
assert dims[dataset] == emb_dim
except AssertionError:
raise NoPreTrainedVersionError('No pre-trained version of RESCAL for '
'{} in dimension {}'.format(dataset, emb_dim))
except KeyError:
raise NoPreTrainedVersionError('No pre-trained version of RESCAL for {}'.format(dataset))

state_dict = load_embeddings('rescal', emb_dim, dataset, data_home)
model = RESCALModel(emb_dim,
n_entities=state_dict['ent_emb.weight'].shape[0],
n_relations=state_dict['rel_mat.weight'].shape[0])
model.load_state_dict(state_dict)

return model