diff --git a/docs/history.rst b/docs/history.rst index a6b076e..dc02e64 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -1,9 +1,9 @@ ======= History ======= -0.17.5 (2022-09-18) +0.17.7 (2023-04-05) ------------------- -* Fix bug in TransH implementation +* Adding additional pretrained models 0.17.6 (2023-03-31) ------------------- diff --git a/docs/reference/utils.rst b/docs/reference/utils.rst index 07c25d7..f034a9a 100755 --- a/docs/reference/utils.rst +++ b/docs/reference/utils.rst @@ -42,6 +42,22 @@ TransE model .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_transe +RESCAL Model +============= +.. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm} + ++-----------+-----------+-----------+----------+--------------------+ +| Model | Dataset | Dimension | Test MRR | Filtered Test MRR | ++===========+===========+===========+==========+====================+ +| RESCAL | FB15k237 | 200 | 0.180 | 0.305 | ++-----------+-----------+-----------+----------+--------------------+ +| RESCAL | WN18RR | 150 | 0.273 | 0.424 | ++-----------+-----------+-----------+----------+--------------------+ +| RESCAL | Yago3-10 | 200 | 0.124 | 0.308 | ++-----------+-----------+-----------+----------+--------------------+ + +.. autofunction:: torchkge.utils.pretrained_models.load_pretrained_rescal + ComplEx Model ============= .. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm} @@ -55,6 +71,8 @@ ComplEx Model +-----------+-----------+-----------+----------+--------------------+ | ComplEx | WDV5 | 200 | 0.283 | 0.371 | +-----------+-----------+-----------+----------+--------------------+ +| ComplEx | Yago3-10 | 200 | 0.164 | 0.421 | ++-----------+-----------+-----------+----------+--------------------+ .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_complex diff --git a/setup.cfg b/setup.cfg index a091d4a..5fe0de5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.17.6 +current_version = 0.17.7 commit = True tag = True diff --git a/setup.py b/setup.py index ffb64a5..15ec68f 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,6 @@ setup_requires=setup_requirements, tests_require=test_requirements, test_suite='tests', - version='0.17.6', + version='0.17.7', zip_safe=False, ) diff --git a/torchkge/__init__.py b/torchkge/__init__.py index 21ebbec..10bf8a2 100755 --- a/torchkge/__init__.py +++ b/torchkge/__init__.py @@ -4,7 +4,7 @@ __author__ = """Armand Boschin""" __email__ = 'aboschin@enst.fr' -__version__ = '0.17.6' +__version__ = '0.17.7' from torchkge.exceptions import NotYetEvaluatedError from torchkge.utils import MarginLoss, LogisticLoss @@ -13,5 +13,5 @@ from .evaluation import LinkPredictionEvaluator from .evaluation import TripletClassificationEvaluator from .models import ConvKBModel -from .models import RESCALModel, DistMultModel -from .models import TransEModel, TransHModel, TransRModel, TransDModel +from .models import RESCALModel, DistMultModel, HolEModel, ComplExModel, AnalogyModel +from .models import TransEModel, TransHModel, TransRModel, TransDModel, TorusEModel diff --git a/torchkge/utils/__init__.py b/torchkge/utils/__init__.py index 006fbb0..b082389 100755 --- a/torchkge/utils/__init__.py +++ b/torchkge/utils/__init__.py @@ -14,4 +14,5 @@ from .losses import MarginLoss, LogisticLoss, BinaryCrossEntropyLoss from .modeling import init_embedding, get_true_targets, load_embeddings, filter_scores from .operations import get_rank, get_mask, get_bernoulli_probs +from .pretrained_models import load_pretrained_transe, load_pretrained_rescal, load_pretrained_complex from .training import Trainer, TrainDataLoader diff --git a/torchkge/utils/pretrained_models.py b/torchkge/utils/pretrained_models.py index 1614729..e67e48b 100644 --- a/torchkge/utils/pretrained_models.py +++ b/torchkge/utils/pretrained_models.py @@ -5,17 +5,17 @@ """ from ..exceptions import NoPreTrainedVersionError -from ..models import TransEModel, ComplExModel +from ..models import TransEModel, ComplExModel, RESCALModel from ..utils import load_embeddings -def load_pretrained_transe(dataset, emb_dim, data_home=None): +def load_pretrained_transe(dataset, emb_dim=None, data_home=None): """Load a pretrained version of TransE model. Parameters ---------- dataset: str - emb_dim: int + emb_dim: int (opt, default None) Embedding dimension data_home: str (opt, default None) Path to the `torchkge_data` directory (containing data folders). Useful @@ -26,16 +26,18 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None): model: `TorchKGE.model.translation.TransEModel` Pretrained version of TransE model. """ + dims = {'fb15k': 100, 'wn18rr': 100, 'fb15k237': 150, 'wdv5': 150, 'yago310': 200} try: - assert (dataset in {'fb15k', 'wn18rr'} and emb_dim == 100) \ - or (dataset == 'fb15k237' and emb_dim == 150) \ - or (dataset == 'wdv5' and emb_dim == 150) \ - or (dataset == 'yago310' and emb_dim == 200) - - except AssertionError: - raise NoPreTrainedVersionError('No pre-trained version of TransE for ' - '{} in dimension {}'.format(dataset, - emb_dim)) + if emb_dim is None: + emb_dim = dims[dataset] + else: + try: + assert dims[dataset] == emb_dim + except AssertionError: + raise NoPreTrainedVersionError('No pre-trained version of TransE for ' + '{} in dimension {}'.format(dataset, emb_dim)) + except KeyError: + raise NoPreTrainedVersionError('No pre-trained version of TransE for {}'.format(dataset)) state_dict = load_embeddings('transe', emb_dim, dataset, data_home) model = TransEModel(emb_dim, @@ -47,13 +49,13 @@ def load_pretrained_transe(dataset, emb_dim, data_home=None): return model -def load_pretrained_complex(dataset, emb_dim, data_home=None): +def load_pretrained_complex(dataset, emb_dim=None, data_home=None): """Load a pretrained version of ComplEx model. Parameters ---------- dataset: str - emb_dim: int + emb_dim: int (opt, default None) Embedding dimension data_home: str (opt, default None) Path to the `torchkge_data` directory (containing data folders). Useful @@ -64,15 +66,18 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None): model: `TorchKGE.model.translation.ComplExModel` Pretrained version of ComplEx model. """ + dims = {'wn18rr': 200, 'fb15k237': 200, 'wdv5': 200, 'yago310': 200} try: - assert (dataset == 'wn18rr' and emb_dim == 200) \ - or (dataset == 'fb15k237' and emb_dim == 200) \ - or (dataset == 'wdv5' and emb_dim == 200) - - except AssertionError: - raise NoPreTrainedVersionError('No pre-trained version of ComplEx for ' - '{} in dimension {}'.format(dataset, - emb_dim)) + if emb_dim is None: + emb_dim = dims[dataset] + else: + try: + assert dims[dataset] == emb_dim + except AssertionError: + raise NoPreTrainedVersionError('No pre-trained version of ComplEx for ' + '{} in dimension {}'.format(dataset, emb_dim)) + except KeyError: + raise NoPreTrainedVersionError('No pre-trained version of ComplEx for {}'.format(dataset)) state_dict = load_embeddings('complex', emb_dim, dataset, data_home) model = ComplExModel(emb_dim, @@ -81,3 +86,42 @@ def load_pretrained_complex(dataset, emb_dim, data_home=None): model.load_state_dict(state_dict) return model + + +def load_pretrained_rescal(dataset, emb_dim=None, data_home=None): + """Load a pretrained version of RESCAL model. + + Parameters + ---------- + dataset: str + emb_dim: int (opt, default None) + Embedding dimension + data_home: str (opt, default None) + Path to the `torchkge_data` directory (containing data folders). Useful + for pre-trained model loading. + + Returns + ------- + model: `TorchKGE.model.translation.RESCALModel` + Pretrained version of RESCAL model. + """ + dims = {'wn18rr': 200, 'fb15k237': 200, 'yago310': 200} + try: + if emb_dim is None: + emb_dim = dims[dataset] + else: + try: + assert dims[dataset] == emb_dim + except AssertionError: + raise NoPreTrainedVersionError('No pre-trained version of RESCAL for ' + '{} in dimension {}'.format(dataset, emb_dim)) + except KeyError: + raise NoPreTrainedVersionError('No pre-trained version of RESCAL for {}'.format(dataset)) + + state_dict = load_embeddings('rescal', emb_dim, dataset, data_home) + model = RESCALModel(emb_dim, + n_entities=state_dict['ent_emb.weight'].shape[0], + n_relations=state_dict['rel_mat.weight'].shape[0]) + model.load_state_dict(state_dict) + + return model