Skip to content

Commit

Permalink
Merge pull request #62 from maciejkula/goodbooks
Browse files Browse the repository at this point in the history
Add goodbooks dataset
  • Loading branch information
maciejkula authored Sep 28, 2017
2 parents 7b1868b + 3c73799 commit dc68cfe
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
Changelog
=========

unreleased (unreleased)
-----------------------

Added
~~~~~

* Goodbooks dataset.

v0.1.2 (2017-09-10)
-------------------

Expand Down
1 change: 1 addition & 0 deletions docs/datasets/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Datasets

Synthetic <synthetic>
Movielens <movielens>
Goodbooks <goodbooks>
7 changes: 7 additions & 0 deletions docs/datasets/goodbooks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Goodbooks
=========


.. automodule:: spotlight.datasets.goodbooks
   :members:
   :undoc-members:
48 changes: 48 additions & 0 deletions spotlight/datasets/goodbooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Utilities for fetching the Goodbooks-10K dataset [1]_.

References
----------

.. [1] https://github.com/zygmuntz/goodbooks-10k
"""

import h5py

import numpy as np

from spotlight.datasets import _transport
from spotlight.interactions import Interactions


def _get_dataset():
    """
    Fetch the Goodbooks-10K HDF5 file and extract its rating columns.

    The file is obtained through ``_transport.get_data`` and opened
    read-only.

    Returns
    -------

    tuple
        Columns 0, 1 and 2 of the ``ratings`` table, followed by an
        ``int32`` ``arange`` holding one entry per ratings row.
        NOTE(review): the columns are presumably user ids, item ids and
        ratings, since the tuple is unpacked straight into
        ``Interactions`` -- confirm against the dataset layout.
    """

    url = ('https://github.com/zygmuntz/goodbooks-10k/'
           'releases/download/v1.0/goodbooks-10k.hdf5')
    local_file = _transport.get_data(url, 'goodbooks', 'goodbooks.hdf5')

    with h5py.File(local_file, 'r') as archive:
        ratings = archive['ratings']
        # Materialise every column while the file is still open.
        columns = tuple(ratings[:, idx] for idx in range(3))
        row_index = np.arange(len(ratings), dtype=np.int32)

    return columns + (row_index,)


def get_goodbooks_dataset():
    """
    Download and return the goodbooks-10K dataset [2]_.

    Returns
    -------

    Interactions: :class:`spotlight.interactions.Interactions`
        instance of the interactions class, built from the arrays
        produced by ``_get_dataset``

    References
    ----------

    .. [2] https://github.com/zygmuntz/goodbooks-10k
    """

    columns = _get_dataset()

    return Interactions(*columns)
11 changes: 11 additions & 0 deletions spotlight/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,14 @@ def __init__(self,
self.num_items = sequences.max() + 1
else:
self.num_items = num_items

def __repr__(self):
    """
    Return a short summary giving both dimensions of ``self.sequences``.
    """

    # ``sequences`` must be two-dimensional for this unpacking to work.
    rows, columns = self.sequences.shape

    template = ('<Sequence interactions dataset ({num_sequences} '
                'sequences x {sequence_length} sequence length)>')

    return template.format(num_sequences=rows,
                           sequence_length=columns)
6 changes: 4 additions & 2 deletions spotlight/sequence/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ def _get_negative_prediction(self, shape, user_representation):
def _get_multiple_negative_predictions(self, shape, user_representation,
                                       n=5):
    """
    Compute ``n`` stacked sets of negative predictions in one call.

    Parameters
    ----------

    shape: tuple
        ``(batch_size, sliding_window)`` pair describing a single set
        of predictions.
    user_representation: tensor
        Representation handed to ``_get_negative_prediction``; it is
        tiled ``n`` times along its first dimension and may have any
        number of trailing dimensions.
    n: int, optional
        Number of prediction sets to produce.

    Returns
    -------

    tensor
        The result of ``_get_negative_prediction`` viewed as
        ``(n, batch_size, sliding_window)``.
    """

    batch_size, sliding_window = shape

    # Repeat only the leading dimension and leave every other one
    # untouched, so the call works whatever the rank of the
    # representation (a fixed ``repeat(n, 1, 1)`` only fits 3-d input).
    size = (n,) + (1,) * (user_representation.dim() - 1)
    negative_prediction = self._get_negative_prediction(
        (n * batch_size, sliding_window),
        user_representation.repeat(*size))

    return negative_prediction.view(n, batch_size, sliding_window)

Expand Down Expand Up @@ -314,7 +315,8 @@ def predict(self, sequences, item_ids=None):
item_var = Variable(gpu(item_ids, self._use_cuda))

_, sequence_representations = self._net.user_representation(sequence_var)
out = self._net(sequence_representations.repeat(len(item_var), 1),
size = (len(item_var),) + sequence_representations.size()[1:]
out = self._net(sequence_representations.expand(*size),
item_var)

return cpu(out.data).numpy().flatten()

0 comments on commit dc68cfe

Please sign in to comment.