Skip to content

Commit

Permalink
Merge pull request #191 from MacOS/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
armand33 authored Mar 1, 2021
2 parents a339b5c + a051913 commit a80743d
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions torchkge/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@


def get_data_home(data_home=None):
"""Returns the path to the data directory. The path is created if
it does not exist.
If data_home is none, the data is downloaded into the home directory of
of the user.
Parameters
----------
data_home: string
The path to the data set.
"""
if data_home is None:
data_home = environ.get('TORCHKGE_DATA',
join('~', 'torchkge_data'))
Expand All @@ -21,11 +32,31 @@ def get_data_home(data_home=None):


def clear_data_home(data_home=None):
"""Deletes the directory data_home
Parameters
----------
data_home: string
The path to the directory that should be removed.
"""
data_home = get_data_home(data_home)
shutil.rmtree(data_home)


def get_n_batches(n, b_size):
"""Returns the number of bachtes. Let n be the number of samples in the data set,
let batch_size be the number of samples per batch, then the number of batches is given by
n
n_batches = ---------
batch_size
Parameters
----------
n: int
Size of the data set.
b_size: int
Number of samples per batch.
"""
n_batch = n // b_size
if n % b_size > 0:
n_batch += 1
Expand Down

0 comments on commit a80743d

Please sign in to comment.