Add sparse LR support and reorg disc model dir #126

Merged
4 commits merged on Aug 31, 2018
3 changes: 3 additions & 0 deletions CHANGELOG.rst
@@ -3,6 +3,7 @@

Added
^^^^^
* `@senwu`_: Add sparse logistic regression support.
* `@j-rausch`_: Added unit tests for changed lingual parsing pipeline.
* `@senwu`_: Support Python 3.7.
* `@lukehsiao`_: Allow user to change featurization settings by providing
@@ -16,6 +17,8 @@ Added

Changed
^^^^^^^
* `@senwu`_: Reorganize the disc model structure.
(`#126 <https://github.com/HazyResearch/fonduer/pull/126>`_)
* `@j-rausch`_: Speed-up of ``spacy_parser``. We split the lingual parsing
pipeline into two stages. First, we parse structure and gather all
sentences for a document. Then, we merge and feed all sentences per
13 changes: 11 additions & 2 deletions src/fonduer/learning/__init__.py
@@ -1,5 +1,14 @@
from fonduer.learning.disc_models.logistic_regression import LogisticRegression
from fonduer.learning.disc_models.rnn.lstm import LSTM
from fonduer.learning.disc_models.lstm import LSTM
from fonduer.learning.disc_models.sparse_logistic_regression import (
SparseLogisticRegression
)
from fonduer.learning.gen_learning import GenerativeModel, GenerativeModelAnalyzer

__all__ = ["GenerativeModel", "GenerativeModelAnalyzer", "LogisticRegression", "LSTM"]
__all__ = [
"GenerativeModel",
"GenerativeModelAnalyzer",
"LogisticRegression",
"LSTM",
"SparseLogisticRegression",
]
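With the reorganized package root, the new model is importable directly from fonduer.learning. A minimal usage sketch, not taken from this PR: train_cands, F_train, and train_marginals are hypothetical placeholders, and the keyword arguments are assumed from NoiseAwareModel.train.

from fonduer.learning import SparseLogisticRegression

# Sketch only: X is the (candidates, sparse feature matrix) pair expected by
# the model; Y is the noise-aware training marginals.
disc_model = SparseLogisticRegression()
disc_model.train((train_cands, F_train), train_marginals, n_epochs=50, lr=0.001)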
4 changes: 2 additions & 2 deletions src/fonduer/learning/disc_learning.py
@@ -151,14 +151,14 @@ def train(
diffs = Y_train.max(axis=1) - Y_train.min(axis=1)
train_idxs = np.where(diffs > 1e-6)[0]

self.model_kwargs = self._update_kwargs(X_train, **kwargs)

_X_train, _Y_train = self._preprocess_data(
X_train, Y_train, idxs=train_idxs, train=True
)
if X_dev is not None:
_X_dev, _Y_dev = self._preprocess_data(X_dev, Y_dev)

self.model_kwargs = self._update_kwargs(_X_train, **kwargs)

if "host_device" not in self.model_kwargs:
self.model_kwargs["host_device"] = "CPU"
self.logger.info("Using CPU...")
@@ -1,5 +1,5 @@
"""
A recurrent neural network model.
A recurrent neural network module.
"""

import torch
@@ -66,8 +66,6 @@ def forward(self, x, x_mask, state_word):
x : batch_size * length
x_mask : batch_size * length
"""
# print("lstm forward...")
# print(x.size(), x_mask.size(), state_word.size())
x_emb = self.drop(self.lookup(x))
output_word, state_word = self.word_lstm(x_emb, state_word)
output_word = self.drop(output_word)
46 changes: 46 additions & 0 deletions src/fonduer/learning/disc_models/layers/sparse_linear.py
@@ -0,0 +1,46 @@
"""
A sparse linear module.
"""

import math

import torch
import torch.nn as nn


class SparseLinear(nn.Module):
def __init__(self, num_features, num_classes, bias=False, padding_idx=0):

super(SparseLinear, self).__init__()

self.num_features = num_features
self.num_classes = num_classes
self.padding_idx = padding_idx

self.weight = nn.Embedding(
self.num_features, self.num_classes, padding_idx=self.padding_idx
)
if bias:
self.bias = nn.Parameter(torch.Tensor(self.num_classes))
else:
self.bias = None

self.reset_parameters()

def reset_parameters(self):
stdv = 1. / math.sqrt(self.num_features)
self.weight.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
if self.padding_idx is not None:
self.weight.weight.data[self.padding_idx].fill_(0)

def forward(self, x, w):
"""
x : batch_size * length, the feature indices
w : batch_size * length, the weight for each feature
"""
if self.bias is None:
return (w.unsqueeze(2) * self.weight(x)).sum(dim=1)
else:
return (w.unsqueeze(2) * self.weight(x)).sum(dim=1) + self.bias
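To make the layer's indexing convention concrete, the following standalone sketch (not part of the PR) checks that SparseLinear applied to (index, value) pairs matches a dense matrix product; index 0 is the padding row and contributes nothing.

import torch

from fonduer.learning.disc_models.layers.sparse_linear import SparseLinear

# Sketch only: 6 features (row 0 reserved for padding), 3 output classes.
layer = SparseLinear(num_features=6, num_classes=3, bias=False, padding_idx=0)

x = torch.tensor([[2, 5, 0]])        # feature indices for one example; last slot is padding
w = torch.tensor([[1.0, 0.5, 0.0]])  # the value carried by each feature

sparse_out = layer(x, w)             # shape: (1, num_classes)

# Equivalent dense computation: the embedding table acts as the weight matrix.
dense_x = torch.zeros(1, 6)
dense_x[0, 2], dense_x[0, 5] = 1.0, 0.5
dense_out = dense_x @ layer.weight.weight

assert torch.allclose(sparse_out, dense_out)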
2 changes: 1 addition & 1 deletion src/fonduer/learning/disc_models/logistic_regression.py
@@ -62,7 +62,7 @@ def _update_kwargs(self, X, **model_kwargs):
:param X: The input data of the model
:param model_kwargs: The arguments of the model
"""
model_kwargs["input_dim"] = X[0][1].shape[1]
model_kwargs["input_dim"] = X[1].shape[1]
return model_kwargs

def _build_model(self, model_kwargs):
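The one-line change above appears to reflect a new calling convention: _update_kwargs now receives the raw (candidates, feature matrix) pair rather than the preprocessed per-candidate tuples, so the input dimension is read straight off the matrix. A hedged illustration with hypothetical names:

# Illustration only: what _update_kwargs now appears to receive.
X = (train_cands, F_train)     # raw input, as passed to train()
input_dim = X[1].shape[1]      # feature count, taken from the matrix columns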
@@ -4,8 +4,8 @@
from scipy.sparse import issparse

from fonduer.learning.disc_learning import NoiseAwareModel
from fonduer.learning.disc_models.rnn.layers import RNN
from fonduer.learning.disc_models.rnn.utils import (
from fonduer.learning.disc_models.layers.rnn import RNN
from fonduer.learning.disc_models.utils import (
SymbolTable,
mark_sentence,
mention_to_tokens,
@@ -131,7 +131,7 @@ def _update_kwargs(self, X, **model_kwargs):
model_kwargs[key] = settings[key]

model_kwargs["relation_arity"] = len(X[0][0])
model_kwargs["input_dim"] = X[0][1].shape[1] + len(X[0][0]) * model_kwargs[
model_kwargs["input_dim"] = X[1].shape[1] + len(X[0][0]) * model_kwargs[
"hidden_dim"
] * (2 if model_kwargs["bidirectional"] else 1)

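For the LSTM, input_dim combines the feature vector with one LSTM encoding per mention in the relation. A worked example with illustrative numbers (not from the PR):

# Illustrative values only.
num_features = 1000      # X[1].shape[1]: columns of the feature matrix
relation_arity = 2       # len(X[0][0]): mentions in the first candidate
hidden_dim = 100
bidirectional = True

input_dim = num_features + relation_arity * hidden_dim * (2 if bidirectional else 1)
# 1000 + 2 * 100 * 2 = 1400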
44 changes: 0 additions & 44 deletions src/fonduer/learning/disc_models/rnn/config.py

This file was deleted.

152 changes: 152 additions & 0 deletions src/fonduer/learning/disc_models/sparse_logistic_regression.py
@@ -0,0 +1,152 @@
import numpy as np
import torch

from fonduer.learning.disc_learning import NoiseAwareModel
from fonduer.learning.disc_models.layers.sparse_linear import SparseLinear
from fonduer.learning.disc_models.utils import pad_batch


class SparseLogisticRegression(NoiseAwareModel):
def forward(self, x, w):
"""
Run forward pass.

:param x: The input (batch) of the model
"""
return self.sparse_linear(x, w)

def _check_input(self, X):
"""
Check input format.

:param X: The input data of the model
"""
return isinstance(X, tuple)

def _preprocess_data(self, X, Y=None, idxs=None, train=False):
"""
Preprocess the data:
1. Convert sparse matrix to dense matrix.
2. Update the order of candidates based on feature index.
3. Select subset of the input if idxs exists.

:param X: The input data of the model
:param Y: The labels of input data
"""
C, F = X
print(F.shape)
print(max(F.indices))
id2id = dict()
for i in range(F.shape[0]):
id2id[F.row_index[i]] = i

C_ = [None] * len(C)
for c in C:
C_[id2id[c.id]] = c

if idxs is None:
if Y is not None:
return (
[
(
C_[i],
F.indices[F.indptr[i] : F.indptr[i + 1]],
F.data[F.indptr[i] : F.indptr[i + 1]],
)
for i in range(len(C_))
],
Y,
)
else:
return [
(
C_[i],
F.indices[F.indptr[i] : F.indptr[i + 1]],
F.data[F.indptr[i] : F.indptr[i + 1]],
)
for i in range(len(C_))
]
if Y is not None:
return (
[
(
C_[i],
F.indices[F.indptr[i] : F.indptr[i + 1]],
F.data[F.indptr[i] : F.indptr[i + 1]],
)
for i in idxs
],
Y[idxs],
)
else:
return [
(
C_[i],
F.indices[F.indptr[i] : F.indptr[i + 1]],
F.data[F.indptr[i] : F.indptr[i + 1]],
)
for i in idxs
]

def _update_kwargs(self, X, **model_kwargs):
"""
Update the model argument.

:param X: The input data of the model
:param model_kwargs: The arguments of the model
"""
# Add one feature for padding vector (all 0s)
model_kwargs["input_dim"] = X[1].shape[1] + 1
return model_kwargs

def _build_model(self, model_kwargs):
"""
Build the model.

:param model_kwargs: The arguments of the model
"""
if "input_dim" not in model_kwargs:
raise ValueError("Kwarg input_dim cannot be None.")

cardinality = self.cardinality if self.cardinality > 2 else 1
bias = False if "bias" not in model_kwargs else model_kwargs["bias"]

self.sparse_linear = SparseLinear(model_kwargs["input_dim"], cardinality, bias)

def _calc_logits(self, X, batch_size=None):
"""
Calculate the logits.

:param X: The input data of the model
:param batch_size: The batch size
"""
# Generate sparse multi-modal feature input
F = np.array(list(zip(*X))[1]) + 1 # Correct the index since 0 is the padding
V = np.array(list(zip(*X))[2])

outputs = (
torch.Tensor([]).cuda()
if self.model_kwargs["host_device"] in self.gpu
else torch.Tensor([])
)

n = len(F)
if batch_size is None:
batch_size = n
for batch_st in range(0, n, batch_size):
batch_ed = batch_st + batch_size if batch_st + batch_size <= n else n

features, _ = pad_batch(F[batch_st:batch_ed], 0)
values, _ = pad_batch(V[batch_st:batch_ed], 0, type="float")

if self.model_kwargs["host_device"] in self.gpu:
features = features.cuda()
values = values.cuda()

output = self.forward(features, values)
if self.cardinality == 2:
outputs = torch.cat((outputs, output.view(-1)), 0)
else:
outputs = torch.cat((outputs, output), 0)

return outputs
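To see what _preprocess_data and _calc_logits do with the CSR feature matrix, here is a small standalone sketch (made-up values, not PR code): each candidate's active feature columns and weights are sliced out of the CSR arrays, and the indices are later shifted by one so that 0 can act as padding.

import numpy as np
from scipy.sparse import csr_matrix

# Sketch only: 2 candidates, 4 features.
F = csr_matrix(np.array([[0.0, 2.0, 0.0, 1.0],
                         [3.0, 0.0, 0.0, 0.0]]))

for i in range(F.shape[0]):
    idx = F.indices[F.indptr[i]:F.indptr[i + 1]]   # active feature columns
    val = F.data[F.indptr[i]:F.indptr[i + 1]]      # their values
    print(i, idx + 1, val)                         # +1 shift, as in _calc_logits
# 0 [2 4] [2. 1.]
# 1 [1] [3.]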
@@ -70,11 +70,17 @@ def mark_sentence(s, args):
return x


def pad_batch(batch, max_len):
def pad_batch(batch, max_len=0, type="int"):
"""Pad the batch into matrix"""
batch_size = len(batch)
max_sent_len = min(int(np.max([len(x) for x in batch])), max_len)
idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int)
max_sent_len = int(np.max([len(x) for x in batch]))
if max_len > 0 and max_len < max_sent_len:
max_sent_len = max_len
if type == "float":
idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.float32)
else:
idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int)

for idx1, i in enumerate(batch):
for idx2, j in enumerate(i):
if idx2 >= max_sent_len:
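A hedged usage sketch for the extended pad_batch (inputs are made up): feature indices are padded as integers and the matching values as floats, mirroring the two calls in _calc_logits; the exact return values (a padded tensor plus a padding mask) are assumed from the truncated tail of the function above.

from fonduer.learning.disc_models.utils import pad_batch

feat_batch = [[3, 7, 9], [5]]          # per-candidate feature indices (already shifted by +1)
val_batch = [[1.0, 0.5, 2.0], [1.0]]   # matching feature values

features, f_mask = pad_batch(feat_batch, 0)             # int matrix, zero-padded
values, v_mask = pad_batch(val_batch, 0, type="float")  # float32 matrix, zero-padded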