Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding GPU functionality, enabling data loaders/tuples as input #38

Merged
merged 38 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
2153b73
fixed nltk typo, fixed featurizer init in contrib, added gpu support …
jdunnmon Aug 28, 2018
df50a3d
re-add pre commit hook
jdunnmon Aug 28, 2018
1f53c78
re-add pre commit hook
jdunnmon Aug 28, 2018
7bc4ec0
updated score method to handle cuda input
jdunnmon Aug 29, 2018
c3a8e59
added disable_prog_bar option in classifier
jdunnmon Aug 29, 2018
5dd17b3
Merge branch 'master' into dev_cuda_loader
jdunnmon Aug 31, 2018
a68fa17
changes to end model and classifier to allow for use of tuple or Data…
jdunnmon Sep 2, 2018
d44ff5a
typo in docstring
jdunnmon Sep 2, 2018
be8ea6a
removed precommit hook
jdunnmon Sep 2, 2018
a4031e4
re-added pre-commit hook
jdunnmon Sep 2, 2018
362b523
style fix
jdunnmon Sep 2, 2018
7d61e42
refactor to break ties
jdunnmon Sep 3, 2018
9610ba5
updated multitask classifier
jdunnmon Oct 10, 2018
4a608c7
doc update
jdunnmon Oct 10, 2018
f43d804
typo
jdunnmon Oct 10, 2018
0375ad9
tests passing
jdunnmon Oct 11, 2018
7b3d399
updated tutorials
jdunnmon Oct 11, 2018
9996c91
- Got rid of _evaluate methods, transform everything to a DataLoader
ajratner Oct 11, 2018
ab9aae7
- Separating get_predictions back out for code reuse
ajratner Oct 11, 2018
053a473
Moved "use_cuda" to global configs
ajratner Oct 12, 2018
2c010dd
Adding cuda availability check
ajratner Oct 12, 2018
0f1fa36
Merge branch 'master' into dev_cuda_loader
ajratner Oct 12, 2018
bf35050
Adding new citation
ajratner Oct 12, 2018
3327f2d
Merge branch 'dev_cuda_loader' of https://github.com/HazyResearch/met…
ajratner Oct 12, 2018
126f55d
Bug fix post-merge
ajratner Oct 12, 2018
e300967
fixed issues with cuda allocation, training on GPU
jdunnmon Oct 12, 2018
84459db
- Handle converting sparse matrices in Classifier.create_dataset
ajratner Oct 12, 2018
5d1387e
Fixed MajorityLabelVoter, added to tests
ajratner Oct 12, 2018
e2beb92
Cleaning up Basics tutorial
ajratner Oct 12, 2018
aa281c6
Fixing bug in MTClassifier.score
ajratner Oct 12, 2018
67fd6f6
added gpu test
jdunnmon Oct 12, 2018
1437541
Reverting to Y_p
ajratner Oct 12, 2018
20fc090
Cleaned up loss handling of CUDA
ajratner Oct 12, 2018
f261b33
Docstring fix
ajratner Oct 12, 2018
419d844
Cleaning up place_on_gpu function (pending GPU testing)
ajratner Oct 12, 2018
607d3bd
Install requirements + README note for GPU tests
ajratner Oct 12, 2018
5646a63
Minor addition of readme to tests/gpu
ajratner Oct 12, 2018
754a63f
Typo fix
ajratner Oct 12, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ include_trailing_comma=True
force_grid_wrap=0
combine_as_imports=True
line_length=80
known_third_party=matplotlib,networkx,nltk,numpy,pandas,scipy,setuptools,sklearn,torch,torchtext
known_third_party=matplotlib,networkx,nltk,numpy,pandas,scipy,setuptools,sklearn,torch,torchtext,tqdm
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ dependencies:
- pandas
- pytorch=0.4.1
- runipy
- scipy
- scipy
- tqdm
127 changes: 111 additions & 16 deletions metal/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from metal.analysis import confusion_matrix
from metal.metrics import metric_score
Expand Down Expand Up @@ -110,21 +112,21 @@ def train(self, *args, **kwargs):
"""
raise NotImplementedError

def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
def _train(self, train_loader, loss_fn, dev_loader=None):
"""The internal training routine called by train() after initial setup

Args:
train_loader: a torch DataLoader of X (data) and Y (labels) for
the train split
loss_fn: the loss function to minimize (maps *data -> loss)
X_dev: the dev set model input
Y_dev: the dev set target labels
dev_loader: a torch DataLoader of X (data) and Y (labels) for
the dev split

If either of X_dev or Y_dev is not provided, then no checkpointing or
If dev_loader is not provided, then no checkpointing or
evaluation on the dev set will occur.
"""
train_config = self.config["train_config"]
evaluate_dev = X_dev is not None and Y_dev is not None
evaluate_dev = dev_loader is not None

# Set the optimizer
optimizer_config = train_config["optimizer_config"]
Expand All @@ -142,10 +144,27 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
model_class, **checkpoint_config, verbose=self.config["verbose"]
)

# Moving model to GPU
if train_config["use_cuda"]:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need all these spaces

if self.config["verbose"]:
print("Using GPU...")

self.cuda()

# Train the model
for epoch in range(train_config["n_epochs"]):
epoch_loss = 0.0
for data in train_loader:
for batch, data in tqdm(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'batch' can be a lot of things--e.g., data, or a counter. In this case, it's a counter, right? So maybe, batch_num.

enumerate(train_loader),
ajratner marked this conversation as resolved.
Show resolved Hide resolved
total=len(train_loader),
disable=train_config["disable_prog_bar"],
):

# moving data to GPU
if train_config["use_cuda"]:
data = [d.cuda() for d in data]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data is a Tensor here, so you can just do data.cuda()?


# Zero the parameter gradients
optimizer.zero_grad()

Expand Down Expand Up @@ -195,8 +214,9 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
if evaluate_dev and (epoch % train_config["validation_freq"] == 0):
val_metric = train_config["validation_metric"]
dev_score = self.score(
X_dev, Y_dev, metric=val_metric, verbose=False
dev_loader, metric=val_metric, verbose=False
)

if train_config["checkpoint"]:
checkpointer.checkpoint(self, epoch, dev_score)

Expand Down Expand Up @@ -229,7 +249,8 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None):
print("Finished Training")

if evaluate_dev:
Y_p_dev = self.predict(X_dev)
# Currently use default random break ties in evaluate
Y_p_dev, Y_dev = self.evaluate(dev_loader)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still not sure we need a new function name here; will come back to this...


if not self.multitask:
print("Confusion Matrix (Dev)")
Expand Down Expand Up @@ -271,10 +292,84 @@ def _set_scheduler(self, scheduler_config, optimizer):
)
return lr_scheduler

def _batch_evaluate(self, loader, break_ties="random", **kwargs):
    """Evaluates the model using minibatches

    Args:
        loader: Pytorch DataLoader (any iterable of pairs works) supplying
            (X, Y) batches:
            X: The input for the predict method
            Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels
                in {1,...,k}
        break_ties: How to break ties; forwarded to predict() and, for a
            batch of labels with more than one dimension, to _break_ties()
        **kwargs: Additional keyword arguments forwarded to predict()

    Returns:
        Y_p: an np.ndarray of predictions
        Y: an np.ndarray of ground truth labels
    """
    # The config does not change while iterating, so read the flag once
    use_cuda = self.config["train_config"]["use_cuda"]

    Y = []
    Y_p = []
    # Unpack each batch directly; no batch counter is needed here
    for X_batch, Y_batch in loader:
        if use_cuda:
            X_batch = X_batch.cuda()

        Y_batch = self._to_numpy(Y_batch)

        # Labels with more than one dimension are run through _break_ties
        # before being collected
        if Y_batch.ndim > 1:
            Y_batch = self._break_ties(Y_batch, break_ties)

        Y.append(Y_batch)
        Y_p.append(
            self._to_numpy(
                self.predict(X_batch, break_ties=break_ties, **kwargs)
            )
        )

    # Concatenate the per-batch arrays into single flat arrays
    Y = np.hstack(Y)
    Y_p = np.hstack(Y_p)

    return Y_p, Y

def evaluate(self, data, break_ties="random", **kwargs):
"""Evaluates the model

Args:
data: either a Pytorch DataLoader or tuple supplying (X,Y):
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels
in {1,...,k}

Returns:
Y_p: an np.ndarray of predictions
Y: an np.ndarray of ground truth labels
"""

if type(data) is tuple:
X, Y = data

if self.config["train_config"]["use_cuda"]:
X = X.cuda()

Y = self._to_numpy(Y)

if Y.ndim > 1:
Y = self._break_ties(Y, break_ties)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we ever have ties in the eval data? I don't think we should support this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed--if it's called Y and not Y_p, then it's integer values; no ties to break.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, so apparently we're passing these in as one-hots in some of the tests - I'm going to change it to ints. Y_train might be in vector form, but in the single-task case, the Y used for eval should be ints (we should also be clear about this).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, this is just due to the changes below.


Y_p = self.predict(X, break_ties=break_ties, **kwargs)

elif type(data) is DataLoader:
Y_p, Y = self._batch_evaluate(data, break_ties=break_ties)

else:
raise ValueError(
"Unrecognized input data structure, use tuple or DataLoader!"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we need an exclamation point. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol, sounds good. Will update

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still very confused when I see evaluate, score, and predict... so at the very least, we need better docstrings and/or names. But I think it is KEY to avoid code bloat at all costs!! So I am thinking about whether there is a way to merge them...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so I really don't think there should be three separate methods evaluate, evaluate_batch, and predict. Can all just be predict. As far as I see:

  • First, can just check and turn any tuple into a DataLoader with batch_size=all, so everything is batch eval; that combines two methods
  • Next, why do we need to return Y? It seems much simpler to just pull it out of the Dataset/DataLoader if needed?

Happy to be proven wrong but really think we can find a way to make this all be a single, simple predict method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I proposed the same before--no need for evaluate. predict() handles all predicting, and score() just calls predict and calculates metrics. Batching is all handled by the DataLoader (which we either construct using their config options or which is passed in by them).

)

return Y_p, Y

def score(
self,
X,
Y,
data,
metric=["accuracy"],
break_ties="random",
verbose=True,
Expand All @@ -283,9 +378,10 @@ def score(
"""Scores the predictive performance of the Classifier on all tasks

Args:
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels in
{1,...,k}
data: either a Pytorch DataLoader or tuple supplying (X,Y):
X: The input for the predict method
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels
in {1,...,k}
metric: A metric (string) with which to score performance or a
list of such metrics
break_ties: How to break ties when making predictions
Expand All @@ -295,9 +391,8 @@ def score(
Returns:
scores: A (float) score
"""
Y = self._to_numpy(Y)
Y_p = self.predict(X, break_ties=break_ties, **kwargs)

Y_p, Y = self.evaluate(data, break_ties=break_ties)
metric_list = metric if isinstance(metric, list) else [metric]
scores = []
for metric in metric_list:
Expand Down Expand Up @@ -373,7 +468,7 @@ def _to_numpy(Z):
elif isinstance(Z, list):
return np.array(Z)
elif isinstance(Z, torch.Tensor):
return Z.numpy()
return Z.cpu().numpy()
else:
msg = (
f"Expected None, list, numpy.ndarray or torch.Tensor, "
Expand Down
4 changes: 2 additions & 2 deletions metal/contrib/featurizers/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torchtext==0.2.3
ntlk
scikit-learn
nltk
scikit-learn
1 change: 1 addition & 0 deletions metal/end_model/em_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"train_config": {
# Display
"print_every": 1, # Print after this many epochs
"disable_prog_bar": False, # Disable progress bar each epoch
# GPU
"use_cuda": False,
# Dataloader
Expand Down
45 changes: 33 additions & 12 deletions metal/end_model/end_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
class EndModel(Classifier):
"""A dynamically constructed discriminative classifier

Args:
layer_out_dims: a list of integers corresponding to the output sizes
of the layers of your network. The first element is the
dimensionality of the input layer, the last element is the
Expand Down Expand Up @@ -72,7 +71,9 @@ def _build(self, input_module, middle_modules, head_module):
self.network = nn.Sequential(input_layer, *middle_layers, head)

# Construct loss module
self.criteria = SoftCrossEntropyLoss(reduction="sum")
self.criteria = SoftCrossEntropyLoss(
reduction="sum", use_cuda=self.config["train_config"]["use_cuda"]
ajratner marked this conversation as resolved.
Show resolved Hide resolved
)

def _build_input_layer(self, input_module):
if input_module is None:
Expand Down Expand Up @@ -164,19 +165,39 @@ def _make_data_loader(self, X, Y, data_loader_config):
return data_loader

def _get_loss_fn(self):
loss_fn = lambda X, Y: self.criteria(self.forward(X), Y)
if hasattr(self.config, "use_cuda"):
if self.config["use_cuda"]:
criteria = self.criteria.cuda()
ajratner marked this conversation as resolved.
Show resolved Hide resolved
else:
criteria = self.criteria
loss_fn = lambda X, Y: criteria(self.forward(X), Y)

return loss_fn

def train(self, X_train, Y_train, X_dev=None, Y_dev=None, **kwargs):
self.config = recursive_merge_dicts(self.config, kwargs)
train_config = self.config["train_config"]
def _convert_input_data(self, data):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to Classifier maybe?

if type(data) is tuple:
X, Y = data
Y = self._to_torch(Y, dtype=torch.FloatTensor)
loader_config = self.config["train_config"]["data_loader_config"]
loader = self._make_data_loader(X, Y, loader_config)
elif type(data) is DataLoader:
loader = data
else:
raise ValueError(
"Unrecognized input data structure, use tuple or DataLoader."
)
return loader

def train(self, train_data, dev_data=None, **kwargs):

Y_train = self._to_torch(Y_train, dtype=torch.FloatTensor)
Y_dev = self._to_torch(Y_dev)
self.config = recursive_merge_dicts(self.config, kwargs)

# Make data loaders
loader_config = train_config["data_loader_config"]
train_loader = self._make_data_loader(X_train, Y_train, loader_config)
# Convert input data to data loaders
train_loader = self._convert_input_data(train_data)
if dev_data is not None:
dev_loader = self._convert_input_data(dev_data)
else:
dev_loader = None

# Initialize the model
self.reset()
Expand All @@ -185,7 +206,7 @@ def train(self, X_train, Y_train, X_dev=None, Y_dev=None, **kwargs):
loss_fn = self._get_loss_fn()

# Execute training procedure
self._train(train_loader, loss_fn, X_dev=X_dev, Y_dev=Y_dev)
self._train(train_loader, loss_fn, dev_loader=dev_loader)

def predict_proba(self, X):
"""Returns a [n, k] tensor of soft (float) predictions."""
Expand Down
9 changes: 8 additions & 1 deletion metal/end_model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@ class SoftCrossEntropyLoss(nn.Module):
target: An [n, k] float tensor of target probabilities
"""

def __init__(self, weight=None, reduction="elementwise_mean"):
def __init__(
self, weight=None, reduction="elementwise_mean", use_cuda=False
):
super().__init__()
assert weight is None or isinstance(weight, torch.FloatTensor)
self.weight = weight
self.reduction = reduction
self.use_cuda = use_cuda
ajratner marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, input, target):
n, k = input.shape
cum_losses = torch.zeros(n)
if self.use_cuda:
cum_losses = cum_losses.cuda()
for y in range(k):
cls_idx = torch.full((n,), y, dtype=torch.long)
if self.use_cuda:
cls_idx = cls_idx.cuda()
y_loss = F.cross_entropy(input, cls_idx, reduction="none")
if self.weight is not None:
y_loss = y_loss * self.weight[y]
Expand Down
3 changes: 3 additions & 0 deletions metal/label_model/lm_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,8 @@
# Train loop
"n_epochs": 100,
"print_every": 10,
"disable_prog_bar": True, # Disable progress bar each epoch
# GPU
"use_cuda": False,
},
}
Loading