Skip to content

Commit

Permalink
Explicit seed setting (#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
henryre authored Sep 6, 2019
1 parent af69c18 commit a9c28a2
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 25 deletions.
1 change: 0 additions & 1 deletion docs/packages/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ General machine learning utilities shared across Snorkel.
filter_labels
preds_to_probs
probs_to_preds
set_seed
to_int_label_array
7 changes: 5 additions & 2 deletions snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import random
from collections import Counter
from itertools import chain, permutations
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
Expand All @@ -13,7 +14,7 @@
from snorkel.labeling.model.graph_utils import get_clique_tree
from snorkel.labeling.model.logger import Logger
from snorkel.types import Config
from snorkel.utils import probs_to_preds, set_seed
from snorkel.utils import probs_to_preds
from snorkel.utils.config_utils import merge_config
from snorkel.utils.lr_schedulers import LRSchedulerConfig
from snorkel.utils.optimizers import OptimizerConfig
Expand Down Expand Up @@ -841,7 +842,9 @@ def fit(
TrainConfig(), kwargs # type:ignore
)
# Update base config so that it includes all parameters
set_seed(self.train_config.seed)
random.seed(self.train_config.seed)
np.random.seed(self.train_config.seed)
torch.manual_seed(self.train_config.seed)

L_shift = L_train + 1 # convert to {0, 1, ..., k}
if L_shift.max() > self.cardinality:
Expand Down
1 change: 0 additions & 1 deletion snorkel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
filter_labels,
preds_to_probs,
probs_to_preds,
set_seed,
to_int_label_array,
)
9 changes: 0 additions & 9 deletions snorkel/utils/core.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import hashlib
import random
from typing import Dict, List

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy, and PyTorch random number generators.

    The same ``seed`` value is handed to each library's seeding function,
    in the order: ``random``, ``numpy.random``, ``torch``.

    Parameters
    ----------
    seed
        Value passed unchanged to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def _hash(i: int) -> int:
Expand Down
6 changes: 4 additions & 2 deletions test/classification/test_classifier_convergence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest
from typing import List

Expand All @@ -16,7 +17,6 @@
Task,
Trainer,
)
from snorkel.utils import set_seed

N_TRAIN = 1000
N_VALID = 300
Expand All @@ -26,7 +26,9 @@ class ClassifierConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

@pytest.mark.complex
def test_convergence(self):
Expand Down
6 changes: 4 additions & 2 deletions test/classification/test_multitask_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import random
import tempfile
import unittest

Expand All @@ -14,7 +15,6 @@
Operation,
Task,
)
from snorkel.utils import set_seed

NUM_EXAMPLES = 10
BATCH_SIZE = 2
Expand All @@ -28,7 +28,9 @@ def setUpClass(cls):
cls.dataloader = create_dataloader("task1")

def setUp(self):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

def test_onetask_model(self):
model = MultitaskClassifier(tasks=[self.task1])
Expand Down
7 changes: 5 additions & 2 deletions test/classification/training/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import random
import unittest

import numpy as np
import torch

from snorkel.classification import DictDataLoader, DictDataset
from snorkel.classification.training.schedulers import (
SequentialScheduler,
ShuffledScheduler,
)
from snorkel.utils import set_seed

dataset1 = DictDataset(
"d1",
Expand Down Expand Up @@ -37,7 +38,9 @@ def test_sequential(self):
self.assertEqual(data, sorted(data))

def test_shuffled(self):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
scheduler = ShuffledScheduler()
data = []
for (batch, dl) in scheduler.get_batches(dataloaders):
Expand Down
7 changes: 5 additions & 2 deletions test/labeling/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import random
import unittest

import numpy as np
import pandas as pd
import pytest
import torch

from snorkel.labeling import (
LabelingFunction,
Expand All @@ -12,7 +14,6 @@
)
from snorkel.preprocess import preprocessor
from snorkel.types import DataPoint
from snorkel.utils import set_seed


def create_data(n: int) -> pd.DataFrame:
Expand Down Expand Up @@ -61,7 +62,9 @@ class LabelingConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Create raw data
cls.N_TRAIN = 1500
Expand Down
6 changes: 4 additions & 2 deletions test/slicing/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest
from typing import List

Expand All @@ -23,7 +24,6 @@
slicing_function,
)
from snorkel.types import DataPoint
from snorkel.utils import set_seed


# Define SFs specifying points inside a circle
Expand Down Expand Up @@ -55,7 +55,9 @@ class SlicingConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Create raw data
cls.N_TRAIN = 1500
Expand Down
6 changes: 4 additions & 2 deletions test/slicing/test_slice_combiner.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import random
import unittest

import numpy as np
import torch

from snorkel.slicing import SliceCombinerModule
from snorkel.utils import set_seed


class SliceCombinerTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

def test_forward_shape(self):
"""Test that the reweight representation shape matches expected feature size."""
Expand Down

0 comments on commit a9c28a2

Please sign in to comment.