Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit seed setting #1454

Merged
merged 1 commit into from
Sep 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/packages/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ General machine learning utilities shared across Snorkel.
filter_labels
preds_to_probs
probs_to_preds
set_seed
to_int_label_array
7 changes: 5 additions & 2 deletions snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import random
from collections import Counter
from itertools import chain, permutations
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
Expand All @@ -13,7 +14,7 @@
from snorkel.labeling.model.graph_utils import get_clique_tree
from snorkel.labeling.model.logger import Logger
from snorkel.types import Config
from snorkel.utils import probs_to_preds, set_seed
from snorkel.utils import probs_to_preds
from snorkel.utils.config_utils import merge_config
from snorkel.utils.lr_schedulers import LRSchedulerConfig
from snorkel.utils.optimizers import OptimizerConfig
Expand Down Expand Up @@ -841,7 +842,9 @@ def fit(
TrainConfig(), kwargs # type:ignore
)
# Update base config so that it includes all parameters
set_seed(self.train_config.seed)
random.seed(self.train_config.seed)
np.random.seed(self.train_config.seed)
torch.manual_seed(self.train_config.seed)

L_shift = L_train + 1 # convert to {0, 1, ..., k}
if L_shift.max() > self.cardinality:
Expand Down
1 change: 0 additions & 1 deletion snorkel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
filter_labels,
preds_to_probs,
probs_to_preds,
set_seed,
to_int_label_array,
)
9 changes: 0 additions & 9 deletions snorkel/utils/core.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import hashlib
import random
from typing import Dict, List

import numpy as np
import torch


def set_seed(seed: int) -> None:
"""Set the Python, NumPy, and PyTorch random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def _hash(i: int) -> int:
Expand Down
6 changes: 4 additions & 2 deletions test/classification/test_classifier_convergence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest
from typing import List

Expand All @@ -16,7 +17,6 @@
Task,
Trainer,
)
from snorkel.utils import set_seed

N_TRAIN = 1000
N_VALID = 300
Expand All @@ -26,7 +26,9 @@ class ClassifierConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

@pytest.mark.complex
def test_convergence(self):
Expand Down
6 changes: 4 additions & 2 deletions test/classification/test_multitask_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import random
import tempfile
import unittest

Expand All @@ -14,7 +15,6 @@
Operation,
Task,
)
from snorkel.utils import set_seed

NUM_EXAMPLES = 10
BATCH_SIZE = 2
Expand All @@ -28,7 +28,9 @@ def setUpClass(cls):
cls.dataloader = create_dataloader("task1")

def setUp(self):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

def test_onetask_model(self):
model = MultitaskClassifier(tasks=[self.task1])
Expand Down
7 changes: 5 additions & 2 deletions test/classification/training/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import random
import unittest

import numpy as np
import torch

from snorkel.classification import DictDataLoader, DictDataset
from snorkel.classification.training.schedulers import (
SequentialScheduler,
ShuffledScheduler,
)
from snorkel.utils import set_seed

dataset1 = DictDataset(
"d1",
Expand Down Expand Up @@ -37,7 +38,9 @@ def test_sequential(self):
self.assertEqual(data, sorted(data))

def test_shuffled(self):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
scheduler = ShuffledScheduler()
data = []
for (batch, dl) in scheduler.get_batches(dataloaders):
Expand Down
7 changes: 5 additions & 2 deletions test/labeling/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import random
import unittest

import numpy as np
import pandas as pd
import pytest
import torch

from snorkel.labeling import (
LabelingFunction,
Expand All @@ -12,7 +14,6 @@
)
from snorkel.preprocess import preprocessor
from snorkel.types import DataPoint
from snorkel.utils import set_seed


def create_data(n: int) -> pd.DataFrame:
Expand Down Expand Up @@ -61,7 +62,9 @@ class LabelingConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Create raw data
cls.N_TRAIN = 1500
Expand Down
6 changes: 4 additions & 2 deletions test/slicing/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest
from typing import List

Expand All @@ -23,7 +24,6 @@
slicing_function,
)
from snorkel.types import DataPoint
from snorkel.utils import set_seed


# Define SFs specifying points inside a circle
Expand Down Expand Up @@ -55,7 +55,9 @@ class SlicingConvergenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure deterministic runs
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Create raw data
cls.N_TRAIN = 1500
Expand Down
6 changes: 4 additions & 2 deletions test/slicing/test_slice_combiner.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import random
import unittest

import numpy as np
import torch

from snorkel.slicing import SliceCombinerModule
from snorkel.utils import set_seed


class SliceCombinerTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
set_seed(123)
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

def test_forward_shape(self):
"""Test that the reweight representation shape matches expected feature size."""
Expand Down