Skip to content

Commit

Permalink
Saving all attributes of LabelModel (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
paroma authored Sep 20, 2019
1 parent 9af1c77 commit 7c400b1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
36 changes: 15 additions & 21 deletions snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import pickle
import random
from collections import Counter
from itertools import chain, permutations
Expand Down Expand Up @@ -364,7 +365,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
Parameters
----------
L
An [n,m] matrix with values in {-1,0,1,...,k-1}
An [n,m] matrix with values in {-1,0,1,...,k-1}f
Returns
-------
Expand Down Expand Up @@ -929,44 +930,37 @@ def fit(
if self.config.verbose: # pragma: no cover
logging.info("Finished Training")

def save(self, destination: str, **kwargs: Any) -> None:
def save(self, destination: str) -> None:
"""Save label model.
Parameters
----------
destination
File location for saving model
**kwargs
Arguments for torch.save
Filename for saving model
Example
-------
>>> label_model.save('./saved_label_model') # doctest: +SKIP
>>> label_model.save('./saved_label_model.pkl') # doctest: +SKIP
"""
with open(destination, "wb") as f:
torch.save(self, f, **kwargs)
f = open(destination, "wb")
pickle.dump(self.__dict__, f)
f.close()

@staticmethod
def load(source: str, **kwargs: Any) -> Any:
def load(self, source: str) -> None:
"""Load existing label model.
Parameters
----------
source
File location from where to load model
**kwargs
Arguments for torch.load
Returns
-------
LabelModel
LabelModel with appropriate loaded parameters
Filename to load model from
Example
-------
Load parameters saved in ``saved_label_model``
>>> label_model.load('./saved_label_model') # doctest: +SKIP
>>> label_model.load('./saved_label_model.pkl') # doctest: +SKIP
"""
with open(source, "rb") as f:
return torch.load(f, **kwargs)
f = open(source, "rb")
tmp_dict = pickle.load(f)
f.close()
self.__dict__.update(tmp_dict)
13 changes: 10 additions & 3 deletions test/labeling/model/test_label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,22 @@ def test_lr_scheduler(self):
label_model.fit(L, n_epochs=1, lr_scheduler="bad_scheduler")

def test_save_and_load(self):
L = np.array([[0, -1, 0], [0, 1, 0]])
L = np.array([[0, -1, 0], [0, 1, 1]])
label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L, n_epochs=1)
original_preds = label_model.predict(L)

dir_path = tempfile.mkdtemp()
save_path = dir_path + "label_model"
save_path = dir_path + "label_model.pkl"
label_model.save(save_path)
label_model.load(save_path)

label_model_new = LabelModel(cardinality=2, verbose=False)
label_model_new.load(save_path)
loaded_preds = label_model_new.predict(L)
shutil.rmtree(dir_path)

np.testing.assert_array_equal(loaded_preds, original_preds)

def test_optimizer_init(self):
L = np.array([[0, -1, 0], [0, 1, 0]])
label_model = LabelModel()
Expand Down

0 comments on commit 7c400b1

Please sign in to comment.