Skip to content


Code Review FB CrossVal (#406)
Browse files Browse the repository at this point in the history
* Initial commit for cross val for classifiers

* Working k-fold cross validation version for logistic regression

* Fixes for classifier apis

* Fix an issue with init's license

* Fix to init before merging master

* tested code. WIP doc and tests

* Updated docs

* Fixed an issue in doctests for grid search and cross val

* Fixed issues in doc tests

* Updated doctests for cross val and grid search to account for updated PropertiesObject repr fix

* Skipping non-deterministic code in doc test

* Updated docs, update code for souper and fix for

* Fixes for lazyloader tests and importing gridvalues

* Adding integration tests

* Removing non-deterministic tests for cross_val

* Code review feedback
  • Loading branch information
abhiwand authored and rodorad committed Jan 7, 2017
1 parent 9ff2da8 commit 970373f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 56 deletions.
2 changes: 1 addition & 1 deletion integration-tests/tests/
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_random_forest_classifier(tc):
predict_frame = model.predict(f)
assert(set(predict_frame.column_names) == set(['Class', 'Dim_1', 'Dim_2','predicted_class']))
assert(len(predict_frame.column_names) == 4)
metrics = model.test(f)
metrics = model.test(f, 'Class')
assert(metrics.accuracy == 1.0)
assert(metrics.f_measure == 1.0)
assert(metrics.precision == 1.0)
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/tests/
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_svm(tc):
assert(set(predicted_frame.column_names) == set(['data', 'label', 'predicted_label']))
assert(len(predicted_frame.column_names) == 3)
assert(len(f.column_names) == 2)
metrics = model.test(predicted_frame)
metrics = model.test(predicted_frame, 'label')
assert(metrics.accuracy == 1.0)
assert(metrics.f_measure == 1.0)
assert(metrics.precision == 1.0)
Expand Down
52 changes: 36 additions & 16 deletions python/sparktk/models/_selection/
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@

from sparktk import TkContext
from collections import namedtuple
from sparktk.frame.frame import Frame
from sparktk import arguments
from grid_search import grid_values, expand_kwarg_grids, grid_search, GridPoint, GridSearchResults
from grid_search import grid_search

def cross_validate(frame, train_descriptors, num_folds=3, verbose=False, tc=TkContext.implicit):
Computes k-fold cross validation on model with the given frame and parameter values
:param frame: The frame to perform cross-validation on
:param model_type: The model reference
:param descriptor: Dictionary of model parameters and their value/values in list of type grid_values
:param train_descriptors: Tuple of model and Dictionary of model parameters and their value/values as singleton
values or a list of type grid_values
:param num_folds: Number of folds to run the cross-validator on
:param verbose: Flag indicating if the results of each fold are to be viewed. Default is set to False
:param tc: spark-tk context
:return: Summary of model's performance
:param tc: spark-tk context (provided implicitly)
:return: Summary of model's performance consisting of metrics of each combination of train_descriptor values per fold
and averages across all folds
Expand Down Expand Up @@ -193,8 +193,8 @@ def cross_validate(frame, train_descriptors, num_folds=3, verbose=False, tc=TkCo

all_grid_search_results = []
grid_search_results_accumulator = None
for validate_frame, train_frame in split_data(frame, num_folds , tc):
scores = grid_search(train_frame, validate_frame, train_descriptors, tc)
for train_frame, test_frame in split_data(frame, num_folds , tc):
scores = grid_search(train_frame, test_frame, train_descriptors, tc)
if grid_search_results_accumulator is None:
grid_search_results_accumulator = scores
Expand All @@ -213,35 +213,55 @@ def split_data(frame, num_folds, tc=TkContext.implicit):
Randomly split data based on num_folds specified. Implementation logic borrowed from pyspark.
:param frame: The frame to be split into train and validation frames
:param num_folds: Number of folds to be split into
:param tc: spark-tk context
:return: validation frame and train frame for each fold
:param tc: spark-tk context passed implicitly
:return: train frame and test frame for each fold
from pyspark.sql.functions import rand
df = frame.dataframe
h = 1.0/num_folds
rand_col = "rand_1"
df_indexed = df.select("*", rand(0).alias(rand_col))
for i in xrange(num_folds):
validation_lower_bound = i*h
validation_upper_bound = (i+1)*h
condition = (df_indexed[rand_col] >= validation_lower_bound) & (df_indexed[rand_col] < validation_upper_bound)
validation_df = df_indexed.filter(condition)
test_lower_bound = i*h
test_upper_bound = (i+1)*h
condition = (df_indexed[rand_col] >= test_lower_bound) & (df_indexed[rand_col] < test_upper_bound)
test_df = df_indexed.filter(condition)
train_df = df_indexed.filter(~condition)
train_frame = tc.frame.create(train_df)
validation_frame = tc.frame.create(validation_df)
yield validation_frame, train_frame
test_frame = tc.frame.create(test_df)
yield train_frame, test_frame

class CrossValidateClassificationResults(object):
Class storing the results of cross validation for classification
def __init__(self, all_grid_search_results, averages, verbose=False):
Initializes the CrossValidateClassificationResults object with all the results, averages across folds and
verbosity desired
:param all_grid_search_results: Metrics for all models and their configurations on each fold
:param averages: Average metrics for each model and configurations across all folds
:param verbose: The verbosity desired.
If false, only the averages are displayed.
If true all the results and averages are displayed
self.all_results = all_grid_search_results
self.averages = averages
self.verbose = verbose

def _get_all_str(self):
Method to print all the metrics
:return: All the metrics
return "\n".join(["\n".join([str(point) for point in cm.grid_points]) for cm in self.all_results])

def show_all(self):
Method to show the results for all models and configurations across each fold
:return: The classification metrics for all models and configurations across each fold
return self._get_all_str()

def __repr__(self):
Expand Down
127 changes: 89 additions & 38 deletions python/sparktk/models/_selection/
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,31 @@

import inspect
from collections import namedtuple
from sparktk import TkContext
from sparktk.frame.ops.classification_metrics_value import ClassificationMetricsValue
from sparktk.arguments import extract_call, validate_call, affirm_type, require_type, value_error
from collections import namedtuple
from sparktk.arguments import extract_call, validate_call
from sparktk.frame.frame import Frame
from sparktk import arguments

def grid_values(*args):
Method that returns the args as a named tuple of GridValues and list of args
:param args: Value/s passed for the model's parameter
:return: named tuple of GridValues and list of args
return GridValues(args)

def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implicit):
Implements grid search by training the specified model on all combinations of descriptor and testing on test frame
:param train_frame: The frame to train the model on
:param test_frame: The frame to test the model on
:param model_type: The model reference
:param descriptor: Dictionary of model parameters and their value/values in list of type grid_values
:param tc: spark-tk context
:param train_descriptors: Tuple of model and Dictionary of model parameters and their value/values as singleton
values or a list of type grid_values
:param tc: spark-tk context passed implicitly
:return: Summary of metrics for different combinations of the grid and the best performing parameter combination
Expand Down Expand Up @@ -148,47 +156,52 @@ def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implic

# validate input
#validate input
descriptors = affirm_type.list_of_anything(train_descriptors, "train_descriptors")
for i in xrange(len(descriptors)):
item = descriptors[i]
if not isinstance(item, TrainDescriptor):
require_type(tuple, item, "item", "grid_search needs a list of items which are either of type TrainDescriptor or tuples of (model, train_kwargs)")
if len(item) != 2:
raise value_error("list requires tuples of len 2", item, "item in train_descriptors")
if not hasattr(item[0], 'train'):
raise value_error("first item in tuple needs to be a object with a 'train' function", item, "item in train_descriptors")
descriptors[i] = TrainDescriptor(item[0], item[1])
if not isinstance(train_descriptors, list):
train_descriptors = [train_descriptors]
descriptors = [TrainDescriptor(x[0], x[1]) for x in train_descriptors if not isinstance(x, TrainDescriptor)]

arguments.require_type(Frame, train_frame, "frame")
arguments.require_type(Frame, test_frame, "frame")

grid_points = []
for descriptor in descriptors:
train_method = getattr(descriptor.model_type, "train")
list_of_kwargs = expand_kwarg_grids([descriptor.kwargs])
list_of_kwargs = expand_kwarg_grids([descriptor.kwargs])
for kwargs in list_of_kwargs:
train_kwargs = dict(kwargs)
train_kwargs['frame'] = train_frame
validate_call(train_method, train_kwargs, ignore_self=True)
model = descriptor.model_type.train(**train_kwargs)
global count
test_kwargs = dict(kwargs)
test_kwargs['frame'] = test_frame
test_kwargs = extract_call(model.test, test_kwargs, ignore_self=True)
metrics = model.test(**test_kwargs)
grid_points.append(GridPoint(descriptor=TrainDescriptor(descriptor.model_type, train_kwargs), metrics=metrics))
count += 1 # sanity count
return GridSearchResults(grid_points)

class MetricsCompare(object):
Class to compare the classification metrics and pick the best performing model configuration based on 'accuracy'

def __init__(self, emphasis="accuracy", compare=None):
Initializes the object of class MetricsCompare
:param emphasis: The metric to be compared on. We are initializing this to 'accuracy'
:param compare: The objects to be compared
"""
self.compare = compare or self._get_compare(emphasis)

def is_a_better_than_b(self, a, b):
Method to compare two metrics objects
:param a: First object
:param b: Second object
:return: Determines if a is better than b
result = self.compare(a, b)
return result > 0

Expand All @@ -205,20 +218,34 @@ def default_compare(a, b):
return default_compare

class Metrics(ClassificationMetricsValue):
class GridSearchClassificationMetrics(ClassificationMetricsValue):
Class containing the results of grid_search for classification

def __init__(self):
super(Metrics, self).__init__(None, None)
super(GridSearchClassificationMetrics, self).__init__(None, None)

def _divide(self, denominator):
Divides the classification metrics with a value for averaging
:param denominator: The number used to compute the average
:return: Average values for the classification metric
self.precision = (self.precision / float(denominator))
self.accuracy = (self.accuracy / float(denominator))
self.recall = (self.recall / float(denominator))
self.f_measure = (self.f_measure / float(denominator))

def _create_metric_sum(a, b):
metric_sum = Metrics()
Computes the sum of the classification metrics
:param a: First element
:param b: Second element
:return: Sum of the ClassificationMetrics of a and b
metric_sum = GridSearchClassificationMetrics()
metric_sum.accuracy = a.accuracy + b.accuracy
metric_sum.precision = a.precision + b.precision
metric_sum.f_measure = a.f_measure + b.f_measure
Expand All @@ -231,14 +258,15 @@ def _create_metric_sum(a, b):

class TrainDescriptor(object):
"""Describes a train operation: a model type and the arguments for its train method"""
Class that separates the model type and args from the input and handles the representation.

def __init__(self, model_type, kwargs):
Creates a TrainDescriptor
:param model_type: type object representing the model in question
:param kwargs: dict of key-value-pairs holding values for the train method's parameters
Initializes the model_type and model's arguments
:param model_type: The name of the model
:param kwargs: The list of model parameters
self.model_type = model_type
self.kwargs = kwargs
Expand All @@ -253,23 +281,17 @@ def __repr__(self):
return "%s: %s" % (mt, kw)

def grid_values(*args):
return GridValues(args)

count = 0

def expand_kwarg_grids(dictionaries):
Method to expand the dictionary of arguments
:param dictionaries: Parameters for the model
:param dictionaries: Parameters for the model of type (list of dict)
:return: Expanded list of parameters for the model
if not isinstance(dictionaries, list):
raise ValueError("descriptors was not a list but: %s" % dictionaries)
arguments.require_type(list, dictionaries, "dictionaries")
new_dictionaries = []
for dictionary in dictionaries:
for k, v in dictionary.items():
arguments.require_type(dict, dictionary, "item in dictionaries")
if isinstance(v, GridValues):
for a in v.args:
d = dictionary.copy()
Expand All @@ -284,13 +306,25 @@ def expand_kwarg_grids(dictionaries):

class GridSearchResults(object):
Class that stores the results of grid_search. The classification metrics of all model configurations are stored.

def __init__(self, grid_points, metrics_compare=None):
Initializes the GridSearchResults class object with the grid_points and the metric to compare the results
:param grid_points: The results of the grid_search computation
:param metrics_compare: The user specified metric to compare the results on
self.grid_points = grid_points
# add require_type for metrics_compare
self.metrics_compare = metrics_compare or MetricsCompare()

def copy(self):
Copies the GridSearchResults into the desired format
:return: GridSearchResults object
return GridSearchResults([GridPoint(gp.descriptor, gp.metrics) for gp in self.grid_points], self.metrics_compare)

def __repr__(self):
Expand All @@ -313,6 +347,13 @@ def find_best(self, metrics_compare=None):

def _validate_descriptors_are_equal(a, b, ignore_args=None):
Method to compare if the two descriptors being compared are equal
:param a: First descriptor
:param b: Second descriptor
:param ignore_args: List of descriptors that need to be ignored
:return: Checks if the two descriptors being compared are equal and raises errors if they are not
if ignore_args is None:
ignore_args = []
if a.model_type != b.model_type:
Expand All @@ -326,14 +367,24 @@ def _validate_descriptors_are_equal(a, b, ignore_args=None):
raise ValueError("Descriptors a != b because of different values for '%s': %s != %s" % (k, v, b.kwargs[k]))

def _accumulate_matching_points(self, points):
Method to compute the sum of the metrics for a given model and parameter configuration.
:param points: Model and parameter configurations and the computed metrics
:return: sum of the metrics for a given model and parameter configuration
if len(self.grid_points) != len(points):
raise ValueError("Expected list of points of len %s, got %s" % (len(self.grid_points), len(points)))

for index in xrange(len(self.grid_points)):
self._validate_descriptors_are_equal(self.grid_points[index].descriptor, points[index].descriptor, ["frame"])
m = Metrics._create_metric_sum(self.grid_points[index].metrics, points[index].metrics)
m = GridSearchClassificationMetrics._create_metric_sum(self.grid_points[index].metrics, points[index].metrics)
self.grid_points[index] = GridPoint(self.grid_points[index].descriptor, m)

def _divide_metrics(self, denominator):
Method to compute the average metrics
:param denominator: The denominator for average computation
:return: Averaged values of the metrics
for point in self.grid_points:

0 comments on commit 970373f

Please sign in to comment.