Skip to content

Commit

Permalink
Merge pull request #38 from HazyResearch/dev_cuda_loader
Browse files Browse the repository at this point in the history
Adding GPU functionality, enabling data loaders/tuples as input
  • Loading branch information
ajratner authored Oct 12, 2018
2 parents 9e2bb03 + 754a63f commit a2bb278
Show file tree
Hide file tree
Showing 25 changed files with 630 additions and 218 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ include_trailing_comma=True
force_grid_wrap=0
combine_as_imports=True
line_length=80
known_third_party=matplotlib,networkx,nltk,numpy,pandas,scipy,setuptools,sklearn,torch,torchtext
known_third_party=GPUtil,matplotlib,networkx,nltk,numpy,pandas,scipy,setuptools,sklearn,torch,torchtext,tqdm
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Snorkel MeTaL uses a new matrix approximation approach to learn the accuracies o
This makes it significantly more scalable than our previous approaches.

## References
* **Best Reference: [_Training Complex Models with Multi-Task Weak Supervision_](https://ajratner.github.io/assets/papers/mts-draft.pdf) [Technical Report]**
* **Best Reference: [_Training Complex Models with Multi-Task Weak Supervision_](https://arxiv.org/abs/1810.02840) [Technical Report]**
* [Snorkel MeTaL: Weak Supervision for Multi-Task Learning](https://ajratner.github.io/assets/papers/deem-metal-prototype.pdf) [SIGMOD DEEM 2018]
* _[Snorkel: Rapid Training Data Creation with Weak Supervision](https://arxiv.org/abs/1711.10160) [VLDB 2018]_
* _[Data Programming: Creating Large Training Sets, Quickly](https://arxiv.org/abs/1605.07723) [NIPS 2016]_
Expand Down Expand Up @@ -109,3 +109,9 @@ This will install a few additional tools that help to ensure that any commits or
* [flake8](http://flake8.pycqa.org/en/latest/): PEP8 linting

After running `make dev` to install the necessary tools, you can run `make check` to see if any changes you've made violate the repo standards and `make fix` to fix any related to isort/black. Fixes for flake8 violations will need to be made manually.

### GPU Usage
MeTaL supports GPU usage, but does not include this in automatically-run tests; to run these tests, first install the requirements in `tests/gpu/requirements.txt`, then run:
```
nosetests tests/gpu
```
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ dependencies:
- pandas
- pytorch=0.4.1
- runipy
- scipy
- scipy
- tqdm
166 changes: 139 additions & 27 deletions metal/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.sparse import issparse
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

from metal.analysis import confusion_matrix
from metal.metrics import metric_score
from metal.utils import Checkpointer, recursive_merge_dicts
from metal.utils import Checkpointer, place_on_gpu, recursive_merge_dicts


class Classifier(nn.Module):
Expand Down Expand Up @@ -44,10 +47,15 @@ def __init__(self, k, config):
self.multitask = False
self.k = k

# Set random seed
if self.config["seed"] is None:
self.config["seed"] = np.random.randint(1e6)
self._set_seed(self.config["seed"])

# Confirm that cuda is available if config is using CUDA
if self.config["use_cuda"] and not torch.cuda.is_available():
raise ValueError("use_cuda=True but CUDA not available.")

def _set_seed(self, seed):
self.seed = seed
if torch.cuda.is_available():
Expand Down Expand Up @@ -121,21 +129,25 @@ def _create_checkpointer(self, checkpoint_config):
model_class, **checkpoint_config, verbose=self.config["verbose"]
)

def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
def _train(self, train_data, loss_fn, dev_data=None):
"""The internal training routine called by train() after initial setup
Args:
train_loader: a torch DataLoader of X (data) and Y (labels) for
the train split
train_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader of
X (data) and Y (labels) for the train split
loss_fn: the loss function to minimize (maps *data -> loss)
X_dev: the dev set model input
Y_dev: the dev set target labels
dev_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader of
X (data) and Y (labels) for the dev split
If either of X_dev or Y_dev is not provided, then no checkpointing or
If dev_data is not provided, then no checkpointing or
evaluation on the dev set will occur.
"""
train_config = self.config["train_config"]
evaluate_dev = X_dev is not None and Y_dev is not None
evaluate_dev = dev_data is not None

# Convert data to DataLoaders
train_loader = self._create_data_loader(train_data)
dev_loader = self._create_data_loader(dev_data)

# Set the optimizer
optimizer = self._set_optimizer(train_config)
Expand All @@ -150,13 +162,29 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
train_config["checkpoint_config"]
)

# Moving model to GPU
if self.config["use_cuda"]:
if self.config["verbose"]:
print("Using GPU...")
self.cuda()

# Train the model
for epoch in range(train_config["n_epochs"]):
epoch_loss = 0.0
for data in train_loader:
for batch_num, data in tqdm(
enumerate(train_loader),
total=len(train_loader),
disable=train_config["disable_prog_bar"],
):

# Moving data to GPU
if self.config["use_cuda"]:
data = place_on_gpu(data)

# Zero the parameter gradients
optimizer.zero_grad()

# import pdb; pdb.set_trace()
# Forward pass to calculate outputs
loss = loss_fn(*data)
if torch.isnan(loss):
Expand Down Expand Up @@ -187,8 +215,12 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
if evaluate_dev and (epoch % train_config["validation_freq"] == 0):
val_metric = train_config["validation_metric"]
dev_score = self.score(
X_dev, Y_dev, metric=val_metric, verbose=False
dev_loader,
metric=val_metric,
verbose=False,
print_confusion_matrix=False,
)

if train_config["checkpoint"]:
checkpointer.checkpoint(self, epoch, dev_score)

Expand Down Expand Up @@ -220,10 +252,42 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
# Print confusion matrix if applicable
if self.config["verbose"]:
print("Finished Training")
if evaluate_dev and not self.multitask:
Y_p_dev = self.predict(X_dev)
print("Confusion Matrix (Dev)")
confusion_matrix(Y_p_dev, Y_dev, pretty_print=True)
if evaluate_dev:
self.score(
dev_loader,
metric=["accuracy"],
verbose=True,
print_confusion_matrix=True,
)

def _create_dataset(self, *data):
"""Converts input data to the appropriate Dataset"""
# Make sure data is a tuple of dense tensors
data = [self._to_torch(x, dtype=torch.FloatTensor) for x in data]
return TensorDataset(*data)

def _create_data_loader(self, data, **kwargs):
"""Converts input data into a DataLoader"""
if data is None:
return None

# Set DataLoader config
# NOTE: Not applicable if data is already a DataLoader
config = {
**self.config["train_config"]["data_loader_config"],
**kwargs,
"pin_memory": self.config["use_cuda"],
}

# Return data as DataLoader
if isinstance(data, (tuple, list)):
return DataLoader(self._create_dataset(*data), **config)
elif isinstance(data, Dataset):
return DataLoader(data, **config)
elif isinstance(data, DataLoader):
return data
else:
raise ValueError("Input data type not recognized.")

def _set_optimizer(self, train_config):
optimizer_config = train_config["optimizer_config"]
Expand Down Expand Up @@ -270,32 +334,34 @@ def _set_scheduler(self, scheduler_config, optimizer):

def score(
self,
X,
Y,
data,
metric=["accuracy"],
break_ties="random",
verbose=True,
print_confusion_matrix=True,
**kwargs,
):
"""Scores the predictive performance of the Classifier on all tasks
Args:
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels in
{1,...,k}
data: a Pytorch DataLoader, Dataset, or tuple with Tensors (X,Y):
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels
in {1,...,k}
metric: A metric (string) with which to score performance or a
list of such metrics
break_ties: A tie-breaking policy (see Classifier._break_ties())
verbose: The verbosity for just this score method; it will not
update the class config
update the class config.
print_confusion_matrix: Print confusion matrix
Returns:
scores: A (float) score or a list of such scores if kwarg metric
is a list
"""
Y = self._to_numpy(Y)
Y_p = self.predict(X, break_ties=break_ties, **kwargs)
Y_p, Y = self._get_predictions(data, break_ties=break_ties, **kwargs)

# Evaluate on the specified metrics
metric_list = metric if isinstance(metric, list) else [metric]
scores = []
for metric in metric_list:
Expand All @@ -304,11 +370,52 @@ def score(
if verbose:
print(f"{metric.capitalize()}: {score:.3f}")

# Optionally print confusion matrix
if print_confusion_matrix:
confusion_matrix(Y_p, Y, pretty_print=True)

if isinstance(scores, list) and len(scores) == 1:
return scores[0]
else:
return scores

def _get_predictions(self, data, break_ties="random", **kwargs):
"""Computes predictions in batch, given a labeled dataset
Args:
data: a Pytorch DataLoader, Dataset, or tuple with Tensors (X,Y):
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels
in {1,...,k}
break_ties: How to break ties when making predictions
Returns:
Y_p: A Tensor of predictions
Y: A Tensor of labels
"""
data_loader = self._create_data_loader(data)
Y_p = []
Y = []

# Do batch evaluation by default, getting the predictions and labels
for batch_num, data in enumerate(data_loader):
Xb, Yb = data
Y.append(self._to_numpy(Yb))

# Optionally move to GPU
if self.config["use_cuda"]:
Xb = place_on_gpu(Xb)

# Append predictions and labels from DataLoader
Y_p.append(
self._to_numpy(
self.predict(Xb, break_ties=break_ties, **kwargs)
)
)
Y_p = np.hstack(Y_p)
Y = np.hstack(Y)
return Y_p, Y

def predict(self, X, break_ties="random", **kwargs):
"""Predicts hard (int) labels for an input X on all tasks
Expand All @@ -320,8 +427,7 @@ def predict(self, X, break_ties="random", **kwargs):
An n-dim np.ndarray of predictions in {1,...k}
"""
Y_p = self._to_numpy(self.predict_proba(X, **kwargs))
Y_ph = self._break_ties(Y_p, break_ties)
return Y_ph.astype(np.int)
return self._break_ties(Y_p, break_ties).astype(np.int)

def predict_proba(self, X, **kwargs):
"""Predicts soft probabilistic labels for an input X on all tasks
Expand Down Expand Up @@ -363,15 +469,18 @@ def _break_ties(self, Y_s, break_ties="random"):

@staticmethod
def _to_numpy(Z):
"""Converts a None, list, np.ndarray, or torch.Tensor to np.ndarray"""
"""Converts a None, list, np.ndarray, or torch.Tensor to np.ndarray;
also handles converting sparse input to dense."""
if Z is None:
return Z
elif issparse(Z):
return Z.toarray()
elif isinstance(Z, np.ndarray):
return Z
elif isinstance(Z, list):
return np.array(Z)
elif isinstance(Z, torch.Tensor):
return Z.numpy()
return Z.cpu().numpy()
else:
msg = (
f"Expected None, list, numpy.ndarray or torch.Tensor, "
Expand All @@ -381,9 +490,12 @@ def _to_numpy(Z):

@staticmethod
def _to_torch(Z, dtype=None):
"""Converts a None, list, np.ndarray, or torch.Tensor to torch.Tensor"""
"""Converts a None, list, np.ndarray, or torch.Tensor to torch.Tensor;
also handles converting sparse input to dense."""
if Z is None:
return None
elif issparse(Z):
Z = torch.from_numpy(Z.toarray())
elif isinstance(Z, torch.Tensor):
pass
elif isinstance(Z, list):
Expand Down
4 changes: 2 additions & 2 deletions metal/contrib/featurizers/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torchtext==0.2.3
ntlk
scikit-learn
nltk
scikit-learn
5 changes: 3 additions & 2 deletions metal/end_model/em_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
"layer_out_dims": [10, 2],
"batchnorm": False,
"dropout": 0.0,
# GPU
"use_cuda": False,
# TRAINING
"train_config": {
# Display
"print_every": 1, # Print after this many epochs
# GPU
"use_cuda": False,
"disable_prog_bar": False, # Disable progress bar each epoch
# Dataloader
"data_loader_config": {"batch_size": 32, "num_workers": 1},
# Train Loop
Expand Down
Loading

0 comments on commit a2bb278

Please sign in to comment.