diff --git a/integration-tests/tests/test_random_forest_classifier.py b/integration-tests/tests/test_random_forest_classifier.py
index 0a8360ec..03966675 100644
--- a/integration-tests/tests/test_random_forest_classifier.py
+++ b/integration-tests/tests/test_random_forest_classifier.py
@@ -62,7 +62,7 @@ def test_random_forest_classifier(tc):
     predict_frame = model.predict(f)
     assert(set(predict_frame.column_names) == set(['Class', 'Dim_1', 'Dim_2','predicted_class']))
     assert(len(predict_frame.column_names) == 4)
-    metrics = model.test(f)
+    metrics = model.test(f, 'Class')
     assert(metrics.accuracy == 1.0)
     assert(metrics.f_measure == 1.0)
     assert(metrics.precision == 1.0)
diff --git a/integration-tests/tests/test_svm.py b/integration-tests/tests/test_svm.py
index 73aa2bb6..50baf31d 100644
--- a/integration-tests/tests/test_svm.py
+++ b/integration-tests/tests/test_svm.py
@@ -56,7 +56,7 @@ def test_svm(tc):
     assert(set(predicted_frame.column_names) == set(['data', 'label', 'predicted_label']))
     assert(len(predicted_frame.column_names) == 3)
     assert(len(f.column_names) == 2)
-    metrics = model.test(predicted_frame)
+    metrics = model.test(predicted_frame, 'label')
     assert(metrics.accuracy == 1.0)
     assert(metrics.f_measure == 1.0)
     assert(metrics.precision == 1.0)
diff --git a/python/sparktk/models/_selection/cross_validate.py b/python/sparktk/models/_selection/cross_validate.py
index c80e4b8b..b18650b5 100644
--- a/python/sparktk/models/_selection/cross_validate.py
+++ b/python/sparktk/models/_selection/cross_validate.py
@@ -17,22 +17,22 @@

 from sparktk import TkContext
-from collections import namedtuple
 from sparktk.frame.frame import Frame
 from sparktk import arguments
-from grid_search import grid_values, expand_kwarg_grids, grid_search, GridPoint, GridSearchResults
+from grid_search import grid_search


 def cross_validate(frame, train_descriptors, num_folds=3, verbose=False, tc=TkContext.implicit):
     """
     Computes k-fold cross validation on model with the given frame and parameter values

     :param frame: The frame to perform cross-validation on
-    :param model_type: The model reference
-    :param descriptor: Dictionary of model parameters and their value/values in list of type grid_values
+    :param train_descriptors: Tuple (or list of tuples) of the model type and a dictionary of its train parameters, where each
+     parameter is given either as a single value or as a set of values of type grid_values
     :param num_folds: Number of folds to run the cross-validator on
     :param verbose: Flag indicating if the results of each fold are to be viewed. Default is set to False
-    :param tc: spark-tk context
-    :return: Summary of model's performance
+    :param tc: spark-tk context (provided implicitly)
+    :return: Summary of the model's performance, consisting of the metrics for each combination of train_descriptor values per fold
+     and the averages across all folds

     Example
     -------
@@ -193,8 +193,8 @@ def cross_validate(frame, train_descriptors, num_folds=3, verbose=False, tc=TkCo
     all_grid_search_results = []
     grid_search_results_accumulator = None
-    for validate_frame, train_frame in split_data(frame, num_folds , tc):
-        scores = grid_search(train_frame, validate_frame, train_descriptors, tc)
+    for train_frame, test_frame in split_data(frame, num_folds, tc):
+        scores = grid_search(train_frame, test_frame, train_descriptors, tc)
         if grid_search_results_accumulator is None:
             grid_search_results_accumulator = scores
         else:
@@ -213,8 +213,8 @@ def split_data(frame, num_folds, tc=TkContext.implicit):
     Randomly split data based on num_folds specified. Implementation logic borrowed from pyspark.
     :param frame: The frame to be split into train and validation frames
     :param num_folds: Number of folds to be split into
-    :param tc: spark-tk context
-    :return: validation frame and train frame for each fold
+    :param tc: spark-tk context passed implicitly
+    :return: Yields a (train frame, test frame) pair for each fold
     """
     from pyspark.sql.functions import rand
     df = frame.dataframe
@@ -222,26 +222,46 @@ def split_data(frame, num_folds, tc=TkContext.implicit):
     rand_col = "rand_1"
     df_indexed = df.select("*", rand(0).alias(rand_col))
     for i in xrange(num_folds):
-        validation_lower_bound = i*h
-        validation_upper_bound = (i+1)*h
-        condition = (df_indexed[rand_col] >= validation_lower_bound) & (df_indexed[rand_col] < validation_upper_bound)
-        validation_df = df_indexed.filter(condition)
+        test_lower_bound = i*h
+        test_upper_bound = (i+1)*h
+        condition = (df_indexed[rand_col] >= test_lower_bound) & (df_indexed[rand_col] < test_upper_bound)
+        test_df = df_indexed.filter(condition)
         train_df = df_indexed.filter(~condition)
         train_frame = tc.frame.create(train_df)
-        validation_frame = tc.frame.create(validation_df)
-        yield validation_frame, train_frame
+        test_frame = tc.frame.create(test_df)
+        yield train_frame, test_frame


 class CrossValidateClassificationResults(object):
+    """
+    Class storing the results of cross-validation for classification
+    """
     def __init__(self, all_grid_search_results, averages, verbose=False):
+        """
+        Initializes the CrossValidateClassificationResults object with all the fold results, the averages across folds,
+        and the desired verbosity
+        :param all_grid_search_results: Metrics for all models and their configurations on each fold
+        :param averages: Average metrics for each model and configuration across all folds
+        :param verbose: The verbosity desired.
+         If False, only the averages are displayed.
+         If True, all the fold results and the averages are displayed
+        """
         self.all_results = all_grid_search_results
         self.averages = averages
         self.verbose = verbose

     def _get_all_str(self):
+        """
+        Builds a string listing all the metrics
+        :return: All the metrics, formatted as a string
+        """
         return "\n".join(["\n".join([str(point) for point in cm.grid_points]) for cm in self.all_results])

     def show_all(self):
+        """
+        Shows the results for all models and configurations on each fold
+        :return: The classification metrics for all models and configurations on each fold, as a string
+        """
         return self._get_all_str()

     def __repr__(self):
diff --git a/python/sparktk/models/_selection/grid_search.py b/python/sparktk/models/_selection/grid_search.py
index b1e3ea58..2c7e49f7 100644
--- a/python/sparktk/models/_selection/grid_search.py
+++ b/python/sparktk/models/_selection/grid_search.py
@@ -16,23 +16,31 @@
 #

-import inspect
-from collections import namedtuple
 from sparktk import TkContext
 from sparktk.frame.ops.classification_metrics_value import ClassificationMetricsValue
-from sparktk.arguments import extract_call, validate_call, affirm_type, require_type, value_error
+from collections import namedtuple
+from sparktk.arguments import extract_call, validate_call
 from sparktk.frame.frame import Frame
 from sparktk import arguments


+def grid_values(*args):
+    """
+    Wraps the given args in a GridValues named tuple, marking them as a set of values to sweep over
+    :param args: The value(s) to try for a model parameter
+    :return: GridValues named tuple holding the args
+    """
+    return GridValues(args)
+
+
 def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implicit):
     """
     Implements grid search by training the specified model on all combinations of descriptor and testing on test frame

     :param train_frame: The frame to train the model on
     :param test_frame: The frame to test the model on
-    :param model_type: The model reference
-    :param descriptor: Dictionary of model parameters and their value/values in list of type grid_values
-    :param tc: spark-tk context
+    :param train_descriptors: Tuple (or list of tuples) of the model type and a dictionary of its train parameters, where each
+     parameter is given either as a single value or as a set of values of type grid_values
+    :param tc: spark-tk context passed implicitly
     :return: Summary of metrics for different combinations of the grid and the best performing parameter combination

     Example
@@ -148,18 +156,11 @@ def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implic
     """

     # validate input
     TkContext.validate(tc)

-    descriptors = affirm_type.list_of_anything(train_descriptors, "train_descriptors")
-    for i in xrange(len(descriptors)):
-        item = descriptors[i]
-        if not isinstance(item, TrainDescriptor):
-            require_type(tuple, item, "item", "grid_search needs a list of items which are either of type TrainDescriptor or tuples of (model, train_kwargs)")
-            if len(item) != 2:
-                raise value_error("list requires tuples of len 2", item, "item in train_descriptors")
-            if not hasattr(item[0], 'train'):
-                raise value_error("first item in tuple needs to be a object with a 'train' function", item, "item in train_descriptors")
-            descriptors[i] = TrainDescriptor(item[0], item[1])
+    if not isinstance(train_descriptors, list):
+        train_descriptors = [train_descriptors]
+    descriptors = [x if isinstance(x, TrainDescriptor) else TrainDescriptor(x[0], x[1]) for x in train_descriptors]

     arguments.require_type(Frame, train_frame, "frame")
     arguments.require_type(Frame, test_frame, "frame")
@@ -167,28 +168,40 @@ def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implic
     grid_points = []
     for descriptor in descriptors:
         train_method = getattr(descriptor.model_type, "train")
         list_of_kwargs = expand_kwarg_grids([descriptor.kwargs])
         for kwargs in list_of_kwargs:
             train_kwargs = dict(kwargs)
             train_kwargs['frame'] = train_frame
             validate_call(train_method, train_kwargs, ignore_self=True)
             model = descriptor.model_type.train(**train_kwargs)
-            global count
             test_kwargs = dict(kwargs)
             test_kwargs['frame'] = test_frame
             test_kwargs = extract_call(model.test, test_kwargs, ignore_self=True)
             metrics = model.test(**test_kwargs)
             grid_points.append(GridPoint(descriptor=TrainDescriptor(descriptor.model_type, train_kwargs), metrics=metrics))
-            count += 1  # sanity count

     return GridSearchResults(grid_points)


 class MetricsCompare(object):
+    """
+    Class that compares classification metrics and picks the best performing model configuration based on a chosen metric (default 'accuracy')
+    """
     def __init__(self, emphasis="accuracy", compare=None):
+        """
+        Initializes the MetricsCompare object
+        :param emphasis: The metric to compare on; defaults to 'accuracy'
+        :param compare: Optional comparison function; if not given, one is derived from the emphasis metric
+        """
         self.compare = compare or self._get_compare(emphasis)

     def is_a_better_than_b(self, a, b):
+        """
+        Compares two metrics objects
+        :param a: First metrics object
+        :param b: Second metrics object
+        :return: True if a is better than b, otherwise False
+        """
         result = self.compare(a, b)
         return result > 0

@@ -205,12 +218,20 @@ def default_compare(a, b):
         return default_compare


-class Metrics(ClassificationMetricsValue):
+class GridSearchClassificationMetrics(ClassificationMetricsValue):
+    """
+    Class holding the classification metrics produced by grid_search
+    """
     def __init__(self):
-        super(Metrics, self).__init__(None, None)
+        super(GridSearchClassificationMetrics, self).__init__(None, None)

     def _divide(self, denominator):
+        """
+        Divides each classification metric by the given value, for averaging
+        :param denominator: The number to divide by when computing the average
+        :return: None; the metrics are updated in place
+        """
         self.precision = (self.precision / float(denominator))
         self.accuracy = (self.accuracy / float(denominator))
         self.recall = (self.recall / float(denominator))
@@ -218,7 +239,13 @@ def _divide(self, denominator):

     @staticmethod
     def _create_metric_sum(a, b):
-        metric_sum = Metrics()
+        """
+        Computes the element-wise sum of two sets of classification metrics
+        :param a: First metrics object
+        :param b: Second metrics object
+        :return: A new metrics object holding the sums of the metrics of a and b
+        """
+        metric_sum = GridSearchClassificationMetrics()
         metric_sum.accuracy = a.accuracy + b.accuracy
         metric_sum.precision = a.precision + b.precision
         metric_sum.f_measure = a.f_measure + b.f_measure
@@ -231,14 +258,15 @@ def _create_metric_sum(a, b):


 class TrainDescriptor(object):
-    """Describes a train operation: a model type and the arguments for its train method"""
+    """
+    Describes a train operation: the model type and the keyword arguments for its train method, and handles their representation
+ """ def __init__(self, model_type, kwargs): """ - Creates a TrainDescriptor - - :param model_type: type object representing the model in question - :param kwargs: dict of key-value-pairs holding values for the train method's parameters + Initializes the model_type and model's arguments + :param model_type: The name of the model + :param kwargs: The list of model parameters """ self.model_type = model_type self.kwargs = kwargs @@ -253,23 +281,17 @@ def __repr__(self): return "%s: %s" % (mt, kw) -def grid_values(*args): - return GridValues(args) - -count = 0 - - def expand_kwarg_grids(dictionaries): """ Method to expand the dictionary of arguments - :param dictionaries: Parameters for the model + :param dictionaries: Parameters for the model of type (list of dict) :return: Expanded list of parameters for the model """ - if not isinstance(dictionaries, list): - raise ValueError("descriptors was not a list but: %s" % dictionaries) + arguments.require_type(list, dictionaries, "dictionaries") new_dictionaries = [] for dictionary in dictionaries: for k, v in dictionary.items(): + arguments.require_type(dict, dictionary, "item in dictionaries") if isinstance(v, GridValues): for a in v.args: d = dictionary.copy() @@ -284,13 +306,25 @@ def expand_kwarg_grids(dictionaries): class GridSearchResults(object): + """ + Class that stores the results of grid_search. The classification metrics of all model configurations are stored. + """ def __init__(self, grid_points, metrics_compare=None): + """ + Initializes the GridSearchResults class object with the grid_points and the metric to compare the results + :param grid_points: The results of the grid_search computation + :param metrics_compare: The user specified metric to compare the results on + """ self.grid_points = grid_points # add require_type for metrics_compare self.metrics_compare = metrics_compare or MetricsCompare() def copy(self): + """ + Copies the GridSearchResults into the desired format + :return: GridSearchResults object + """ return GridSearchResults([GridPoint(gp.descriptor, gp.metrics) for gp in self.grid_points], self.metrics_compare) def __repr__(self): @@ -313,6 +347,13 @@ def find_best(self, metrics_compare=None): @staticmethod def _validate_descriptors_are_equal(a, b, ignore_args=None): + """ + Method to compare if the two descriptors being compared are equal + :param a: First descriptor + :param b: Second descriptor + :param ignore_args: List of descriptors that need to be ignored + :return: Checks if the two descriptors being compared are equal and raises errors if they are not + """ if ignore_args is None: ignore_args = [] if a.model_type != b.model_type: @@ -326,14 +367,24 @@ def _validate_descriptors_are_equal(a, b, ignore_args=None): raise ValueError("Descriptors a != b because of different values for '%s': %s != %s" % (k, v, b.kwargs[k])) def _accumulate_matching_points(self, points): + """ + Method to compute the sum of the metrics for a given model and parameter configuration. 
+        :param points: List of GridPoints holding the model/parameter configurations and their computed metrics
+        :return: None; the sums are accumulated into this object's grid points
+        """
         if len(self.grid_points) != len(points):
             raise ValueError("Expected list of points of len %s, got %s" % (len(self.grid_points), len(points)))
         for index in xrange(len(self.grid_points)):
             self._validate_descriptors_are_equal(self.grid_points[index].descriptor, points[index].descriptor, ["frame"])
-            m = Metrics._create_metric_sum(self.grid_points[index].metrics, points[index].metrics)
+            m = GridSearchClassificationMetrics._create_metric_sum(self.grid_points[index].metrics, points[index].metrics)
             self.grid_points[index] = GridPoint(self.grid_points[index].descriptor, m)

     def _divide_metrics(self, denominator):
+        """
+        Averages the metrics of every grid point by dividing them by the given denominator
+        :param denominator: The denominator used for the average computation
+        :return: None; each grid point's metrics are averaged in place
+        """
         for point in self.grid_points:
             point.metrics._divide(denominator)
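
Reviewer note: below is a minimal usage sketch, not part of the patch, showing how the reshaped train_descriptors argument feeds cross_validate and grid_search. The descriptor shape, grid_values, the cross_validate/grid_search signatures, and the averages/show_all/find_best members come from the diff above; it assumes cross_validate returns the CrossValidateClassificationResults object defined here. The model path (tc.models.classification.svm) and its train parameter names (observation_columns, label_column, num_iterations) are placeholders assumed for illustration and may not match the real spark-tk API.

    # Usage sketch only -- assumes "tc" is a TkContext and "frame" is a sparktk Frame
    # with columns "data" and "label"; the svm model path and its train parameter
    # names are assumptions, adjust them to the actual spark-tk API.
    from sparktk.models._selection.cross_validate import cross_validate
    from sparktk.models._selection.grid_search import grid_search, grid_values

    # (model type, train kwargs) tuple; grid_values marks a parameter to sweep over
    descriptor = (tc.models.classification.svm,
                  {"observation_columns": ["data"],        # assumed parameter name
                   "label_column": "label",                # assumed parameter name
                   "num_iterations": grid_values(5, 10)})  # try both values

    # k-fold cross validation over every combination in the grid
    result = cross_validate(frame, descriptor, num_folds=3, verbose=True)
    print(result.averages)      # metrics per configuration, averaged across the folds
    print(result.show_all())    # metrics per configuration for each individual fold

    # or a single grid search against explicit train/test frames
    best = grid_search(train_frame, test_frame, descriptor).find_best()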