-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding GPU functionality, enabling data loaders/tuples as input #38
Changes from 17 commits
2153b73
df50a3d
1f53c78
7bc4ec0
c3a8e59
5dd17b3
a68fa17
d44ff5a
be8ea6a
a4031e4
362b523
7d61e42
9610ba5
4a608c7
f43d804
0375ad9
7b3d399
9996c91
ab9aae7
053a473
2c010dd
0f1fa36
bf35050
3327f2d
126f55d
e300967
84459db
5d1387e
e2beb92
aa281c6
67fd6f6
1437541
20fc090
f261b33
419d844
607d3bd
5646a63
754a63f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ dependencies: | |
- pandas | ||
- pytorch=0.4.1 | ||
- runipy | ||
- scipy | ||
- scipy | ||
- tqdm |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.utils.data.dataloader import DataLoader | ||
from tqdm import tqdm | ||
|
||
from metal.analysis import confusion_matrix | ||
from metal.metrics import metric_score | ||
|
@@ -110,21 +112,21 @@ def train(self, *args, **kwargs): | |
""" | ||
raise NotImplementedError | ||
|
||
def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None): | ||
def _train(self, train_loader, loss_fn, dev_loader=None): | ||
"""The internal training routine called by train() after initial setup | ||
|
||
Args: | ||
train_loader: a torch DataLoader of X (data) and Y (labels) for | ||
the train split | ||
loss_fn: the loss function to minimize (maps *data -> loss) | ||
X_dev: the dev set model input | ||
Y_dev: the dev set target labels | ||
dev_loader: a torch DataLoader of X (data) and Y (labels) for | ||
the dev split | ||
|
||
If either of X_dev or Y_dev is not provided, then no checkpointing or | ||
If dev_loader is not provided, then no checkpointing or | ||
evaluation on the dev set will occur. | ||
""" | ||
train_config = self.config["train_config"] | ||
evaluate_dev = X_dev is not None and Y_dev is not None | ||
evaluate_dev = dev_loader is not None | ||
|
||
# Set the optimizer | ||
optimizer_config = train_config["optimizer_config"] | ||
|
@@ -142,10 +144,27 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None): | |
model_class, **checkpoint_config, verbose=self.config["verbose"] | ||
) | ||
|
||
# Moving model to GPU | ||
if train_config["use_cuda"]: | ||
|
||
if self.config["verbose"]: | ||
print("Using GPU...") | ||
|
||
self.cuda() | ||
|
||
# Train the model | ||
for epoch in range(train_config["n_epochs"]): | ||
epoch_loss = 0.0 | ||
for data in train_loader: | ||
for batch, data in tqdm( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 'batch' can be a lot of things -- e.g., the data or a counter. In this case, it's a counter, right? So maybe rename it to batch_num. |
||
enumerate(train_loader), | ||
ajratner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
total=len(train_loader), | ||
disable=train_config["disable_prog_bar"], | ||
): | ||
|
||
# moving data to GPU | ||
if train_config["use_cuda"]: | ||
data = [d.cuda() for d in data] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Zero the parameter gradients | ||
optimizer.zero_grad() | ||
|
||
|
@@ -195,8 +214,9 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None): | |
if evaluate_dev and (epoch % train_config["validation_freq"] == 0): | ||
val_metric = train_config["validation_metric"] | ||
dev_score = self.score( | ||
X_dev, Y_dev, metric=val_metric, verbose=False | ||
dev_loader, metric=val_metric, verbose=False | ||
) | ||
|
||
if train_config["checkpoint"]: | ||
checkpointer.checkpoint(self, epoch, dev_score) | ||
|
||
|
@@ -229,7 +249,8 @@ def _train(self, train_loader, loss_fn, X_dev=None, Y_dev=None): | |
print("Finished Training") | ||
|
||
if evaluate_dev: | ||
Y_p_dev = self.predict(X_dev) | ||
# Currently use default random break ties in evaluate | ||
Y_p_dev, Y_dev = self.evaluate(dev_loader) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Still not sure we need a new function name here; will come back to this... |
||
|
||
if not self.multitask: | ||
print("Confusion Matrix (Dev)") | ||
|
@@ -271,10 +292,84 @@ def _set_scheduler(self, scheduler_config, optimizer): | |
) | ||
return lr_scheduler | ||
|
||
def _batch_evaluate(self, loader, break_ties="random", **kwargs): | ||
"""Evaluates the model using minibatches | ||
|
||
Args: | ||
loader: Pytorch DataLoader supplying (X,Y): | ||
X: The input for the predict method | ||
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels | ||
in {1,...,k}; can be None for cases with no ground truth | ||
|
||
Returns: | ||
Y_p: an np.ndarray of predictions | ||
Y: an np.ndarray of ground truth labels | ||
""" | ||
Y = [] | ||
Y_p = [] | ||
for batch, data in enumerate(loader): | ||
X_batch, Y_batch = data | ||
|
||
if self.config["train_config"]["use_cuda"]: | ||
X_batch = X_batch.cuda() | ||
|
||
Y_batch = self._to_numpy(Y_batch) | ||
|
||
if Y_batch.ndim > 1: | ||
Y_batch = self._break_ties(Y_batch, break_ties) | ||
|
||
Y.append(Y_batch) | ||
Y_p.append( | ||
self._to_numpy( | ||
self.predict(X_batch, break_ties=break_ties, **kwargs) | ||
) | ||
) | ||
|
||
Y = np.hstack(Y) | ||
Y_p = np.hstack(Y_p) | ||
|
||
return Y_p, Y | ||
|
||
def evaluate(self, data, break_ties="random", **kwargs): | ||
"""Evaluates the model | ||
|
||
Args: | ||
data: either a Pytorch DataLoader or tuple supplying (X,Y): | ||
X: The input for the predict method | ||
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels | ||
in {1,...,k} | ||
|
||
Returns: | ||
Y_p: an np.ndarray of predictions | ||
Y: an np.ndarray of ground truth labels | ||
""" | ||
|
||
if type(data) is tuple: | ||
X, Y = data | ||
|
||
if self.config["train_config"]["use_cuda"]: | ||
X = X.cuda() | ||
|
||
Y = self._to_numpy(Y) | ||
|
||
if Y.ndim > 1: | ||
Y = self._break_ties(Y, break_ties) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why would we ever have ties in the eval data? I don't think we should support this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed -- if it's called Y and not Y_p, then it holds integer values; there are no ties to break. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, so apparently we're passing these in as one-hots in some of the tests -- I'm going to change them to ints. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Never mind, this is just due to the changes below. |
||
|
||
Y_p = self.predict(X, break_ties=break_ties, **kwargs) | ||
|
||
elif type(data) is DataLoader: | ||
Y_p, Y = self._batch_evaluate(data, break_ties=break_ties) | ||
|
||
else: | ||
raise ValueError( | ||
"Unrecognized input data structure, use tuple or DataLoader!" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we need an exclamation point. :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lol, sounds good. Will update There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am still very confused when I see There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, so I really don't think there should be three separate methods
Happy to be proven wrong but really think we can find a way to make this all be a single, simple There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I proposed the same before--no need for evaluate. predict() handles all predicting, and score() just calls predict and calculates metrics. Batching is all handled by the DataLoader (which we either construct using their config options or which is passed in by them). |
||
) | ||
|
||
return Y_p, Y | ||
|
||
def score( | ||
self, | ||
X, | ||
Y, | ||
data, | ||
metric=["accuracy"], | ||
break_ties="random", | ||
verbose=True, | ||
|
@@ -283,9 +378,10 @@ def score( | |
"""Scores the predictive performance of the Classifier on all tasks | ||
|
||
Args: | ||
X: The input for the predict method | ||
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels in | ||
{1,...,k} | ||
data: either a Pytorch DataLoader or tuple supplying (X,Y): | ||
X: The input for the predict method | ||
Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target labels | ||
in {1,...,k} | ||
metric: A metric (string) with which to score performance or a | ||
list of such metrics | ||
break_ties: How to break ties when making predictions | ||
|
@@ -295,9 +391,8 @@ def score( | |
Returns: | ||
scores: A (float) score | ||
""" | ||
Y = self._to_numpy(Y) | ||
Y_p = self.predict(X, break_ties=break_ties, **kwargs) | ||
|
||
Y_p, Y = self.evaluate(data, break_ties=break_ties) | ||
metric_list = metric if isinstance(metric, list) else [metric] | ||
scores = [] | ||
for metric in metric_list: | ||
|
@@ -373,7 +468,7 @@ def _to_numpy(Z): | |
elif isinstance(Z, list): | ||
return np.array(Z) | ||
elif isinstance(Z, torch.Tensor): | ||
return Z.numpy() | ||
return Z.cpu().numpy() | ||
else: | ||
msg = ( | ||
f"Expected None, list, numpy.ndarray or torch.Tensor, " | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
torchtext==0.2.3 | ||
ntlk | ||
scikit-learn | ||
nltk | ||
scikit-learn |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,6 @@ | |
class EndModel(Classifier): | ||
"""A dynamically constructed discriminative classifier | ||
|
||
Args: | ||
layer_out_dims: a list of integers corresponding to the output sizes | ||
of the layers of your network. The first element is the | ||
dimensionality of the input layer, the last element is the | ||
|
@@ -72,7 +71,9 @@ def _build(self, input_module, middle_modules, head_module): | |
self.network = nn.Sequential(input_layer, *middle_layers, head) | ||
|
||
# Construct loss module | ||
self.criteria = SoftCrossEntropyLoss(reduction="sum") | ||
self.criteria = SoftCrossEntropyLoss( | ||
reduction="sum", use_cuda=self.config["train_config"]["use_cuda"] | ||
ajratner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
def _build_input_layer(self, input_module): | ||
if input_module is None: | ||
|
@@ -164,19 +165,39 @@ def _make_data_loader(self, X, Y, data_loader_config): | |
return data_loader | ||
|
||
def _get_loss_fn(self): | ||
loss_fn = lambda X, Y: self.criteria(self.forward(X), Y) | ||
if hasattr(self.config, "use_cuda"): | ||
if self.config["use_cuda"]: | ||
criteria = self.criteria.cuda() | ||
ajratner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
criteria = self.criteria | ||
loss_fn = lambda X, Y: criteria(self.forward(X), Y) | ||
|
||
return loss_fn | ||
|
||
def train(self, X_train, Y_train, X_dev=None, Y_dev=None, **kwargs): | ||
self.config = recursive_merge_dicts(self.config, kwargs) | ||
train_config = self.config["train_config"] | ||
def _convert_input_data(self, data): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this to |
||
if type(data) is tuple: | ||
X, Y = data | ||
Y = self._to_torch(Y, dtype=torch.FloatTensor) | ||
loader_config = self.config["train_config"]["data_loader_config"] | ||
loader = self._make_data_loader(X, Y, loader_config) | ||
elif type(data) is DataLoader: | ||
loader = data | ||
else: | ||
raise ValueError( | ||
"Unrecognized input data structure, use tuple or DataLoader." | ||
) | ||
return loader | ||
|
||
def train(self, train_data, dev_data=None, **kwargs): | ||
|
||
Y_train = self._to_torch(Y_train, dtype=torch.FloatTensor) | ||
Y_dev = self._to_torch(Y_dev) | ||
self.config = recursive_merge_dicts(self.config, kwargs) | ||
|
||
# Make data loaders | ||
loader_config = train_config["data_loader_config"] | ||
train_loader = self._make_data_loader(X_train, Y_train, loader_config) | ||
# Convert input data to data loaders | ||
train_loader = self._convert_input_data(train_data) | ||
if dev_data is not None: | ||
dev_loader = self._convert_input_data(dev_data) | ||
else: | ||
dev_loader = None | ||
|
||
# Initialize the model | ||
self.reset() | ||
|
@@ -185,7 +206,7 @@ def train(self, X_train, Y_train, X_dev=None, Y_dev=None, **kwargs): | |
loss_fn = self._get_loss_fn() | ||
|
||
# Execute training procedure | ||
self._train(train_loader, loss_fn, X_dev=X_dev, Y_dev=Y_dev) | ||
self._train(train_loader, loss_fn, dev_loader=dev_loader) | ||
|
||
def predict_proba(self, X): | ||
"""Returns a [n, k] tensor of soft (float) predictions.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need all these spaces