From a38ef2dcc035912ebf383e639fec82ed00ae60f5 Mon Sep 17 00:00:00 2001
From: sirtorry
Date: Wed, 13 Nov 2019 16:39:09 -0800
Subject: [PATCH 1/8] pass params

---
 .../cloud/automl_v1beta1/tables/tables_client.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 7b960f0b7b12..70cee94be3e9 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -2596,6 +2596,7 @@ def predict(
         model=None,
         model_name=None,
         model_display_name=None,
+        params=None,
         project=None,
         region=None,
         **kwargs
@@ -2642,6 +2643,14 @@ def predict(
                 The `model` instance you want to predict with . This must be
                 supplied if `model_display_name` or `model_name` are not
                 supplied.
+            params (Dict[str, str]):
+                Additional domain-specific parameters, any string must be up to
+                25000 characters long.
+                ``feature_importance`` - (boolean) Whether
+                [feature\_importance][google.cloud.automl.v1beta1.TablesModelColumnInfo.feature\_importance]
+                should be populated in the returned
+                [TablesAnnotation(-s)][google.cloud.automl.v1beta1.TablesAnnotation].
+                The default is false.
 
         Returns:
             A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
@@ -2683,7 +2692,7 @@ def predict(
 
         request = {"row": {"values": values}}
 
-        return self.prediction_client.predict(model.name, request, **kwargs)
+        return self.prediction_client.predict(model.name, request, params, **kwargs)
 
     def batch_predict(
         self,

From 7eedae1a05cdfc176513d780a2c5af2bffdbb1dd Mon Sep 17 00:00:00 2001
From: Torry Yang
Date: Wed, 13 Nov 2019 16:54:57 -0800
Subject: [PATCH 2/8] Update automl/google/cloud/automl_v1beta1/tables/tables_client.py

Co-Authored-By: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
---
 automl/google/cloud/automl_v1beta1/tables/tables_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 70cee94be3e9..fe4233e492f7 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -2643,7 +2643,7 @@ def predict(
                 The `model` instance you want to predict with . This must be
                 supplied if `model_display_name` or `model_name` are not
                 supplied.
-            params (Dict[str, str]):
+            params (dict[str, str]):
                 Additional domain-specific parameters, any string must be up to
                 25000 characters long.
                 ``feature_importance`` - (boolean) Whether

From 1083d89d0012a1240a44804ed1ade2bde7207fb5 Mon Sep 17 00:00:00 2001
From: sirtorry
Date: Wed, 13 Nov 2019 18:25:07 -0800
Subject: [PATCH 3/8] clean up spec

---
 .../google/cloud/automl_v1beta1/tables/tables_client.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 70cee94be3e9..a082d7382071 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -2644,13 +2644,8 @@ def predict(
                 supplied.
             params (Dict[str, str]):
-                Additional domain-specific parameters, any string must be up to
-                25000 characters long.
-                ``feature_importance`` - (boolean) Whether
-                [feature\_importance][google.cloud.automl.v1beta1.TablesModelColumnInfo.feature\_importance]
-                should be populated in the returned
-                [TablesAnnotation(-s)][google.cloud.automl.v1beta1.TablesAnnotation].
-                The default is false.
+                `feature_importance` can be set as True to enable local
+                explainability. The default is false.
 
         Returns:
             A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`

From 94089c37130c180313d3fe13adf25ed7acc2285d Mon Sep 17 00:00:00 2001
From: sirtorry
Date: Wed, 13 Nov 2019 18:41:57 -0800
Subject: [PATCH 4/8] lint

---
 .../cloud/automl_v1beta1/tables/gcs_client.py |  20 +++-
 .../automl_v1beta1/tables/tables_client.py    | 110 +++++++++++++-----
 2 files changed, 99 insertions(+), 31 deletions(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/gcs_client.py b/automl/google/cloud/automl_v1beta1/tables/gcs_client.py
index 99d40da2867c..e890a8dd645a 100644
--- a/automl/google/cloud/automl_v1beta1/tables/gcs_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/gcs_client.py
@@ -41,7 +41,9 @@
 class GcsClient(object):
     """Uploads Pandas DataFrame to a bucket in Google Cloud Storage."""
 
-    def __init__(self, bucket_name=None, client=None, credentials=None, project=None):
+    def __init__(
+        self, bucket_name=None, client=None, credentials=None, project=None
+    ):
         """Constructor.
 
         Args:
@@ -65,7 +67,9 @@ def __init__(
         if client is not None:
             self.client = client
         elif credentials is not None:
-            self.client = storage.Client(credentials=credentials, project=project)
+            self.client = storage.Client(
+                credentials=credentials, project=project
+            )
         else:
             self.client = storage.Client()
 
@@ -97,7 +101,9 @@ def ensure_bucket_exists(self, project, region):
         except (exceptions.Forbidden, exceptions.NotFound) as e:
             if isinstance(e, exceptions.Forbidden):
                 used_bucket_name = self.bucket_name
-                self.bucket_name = used_bucket_name + "-{}".format(int(time.time()))
+                self.bucket_name = used_bucket_name + "-{}".format(
+                    int(time.time())
+                )
                 _LOGGER.warning(
                     "Created a bucket named {} because a bucket named {} already exists in a different project.".format(
                         self.bucket_name, used_bucket_name
@@ -123,10 +129,14 @@ def upload_pandas_dataframe(self, dataframe, uploaded_csv_name=None):
             raise ImportError(_PANDAS_REQUIRED)
 
         if not isinstance(dataframe, pandas.DataFrame):
-            raise ValueError("'dataframe' must be a pandas.DataFrame instance.")
+            raise ValueError(
+                "'dataframe' must be a pandas.DataFrame instance."
+            )
 
         if self.bucket_name is None:
-            raise ValueError("Must ensure a bucket exists before uploading data.")
+            raise ValueError(
+                "Must ensure a bucket exists before uploading data."
+ ) if uploaded_csv_name is None: uploaded_csv_name = "automl-tables-dataframe-{}.csv".format( diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py index 251bb94c0ae4..0c554f8ff5ed 100644 --- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py +++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -25,7 +25,9 @@ from google.cloud.automl_v1beta1.proto import data_types_pb2 from google.cloud.automl_v1beta1.tables import gcs_client -_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version +_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution( + "google-cloud-automl" +).version _LOGGER = logging.getLogger(__name__) @@ -187,7 +189,11 @@ def __dataset_from_args( region=None, **kwargs ): - if dataset is None and dataset_display_name is None and dataset_name is None: + if ( + dataset is None + and dataset_display_name is None + and dataset_name is None + ): raise ValueError( "One of 'dataset', 'dataset_name' or " "'dataset_display_name' must be set." @@ -216,7 +222,8 @@ def __model_from_args( ): if model is None and model_display_name is None and model_name is None: raise ValueError( - "One of 'model', 'model_name' or " "'model_display_name' must be set." + "One of 'model', 'model_name' or " + "'model_display_name' must be set." ) # we prefer to make a live call here in the case that the # model object is out-of-date @@ -240,7 +247,11 @@ def __dataset_name_from_args( region=None, **kwargs ): - if dataset is None and dataset_display_name is None and dataset_name is None: + if ( + dataset is None + and dataset_display_name is None + and dataset_name is None + ): raise ValueError( "One of 'dataset', 'dataset_name' or " "'dataset_display_name' must be set." @@ -259,7 +270,10 @@ def __dataset_name_from_args( else: # we do this to force a NotFound error when needed self.get_dataset( - dataset_name=dataset_name, project=project, region=region, **kwargs + dataset_name=dataset_name, + project=project, + region=region, + **kwargs ) return dataset_name @@ -283,7 +297,8 @@ def __table_spec_name_from_args( ) table_specs = [ - t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs) + t + for t in self.list_table_specs(dataset_name=dataset_name, **kwargs) ] table_spec_full_id = table_specs[table_spec_index].name @@ -300,7 +315,8 @@ def __model_name_from_args( ): if model is None and model_display_name is None and model_name is None: raise ValueError( - "One of 'model', 'model_name' or " "'model_display_name' must be set." + "One of 'model', 'model_name' or " + "'model_display_name' must be set." ) if model_name is None: @@ -527,7 +543,8 @@ def get_dataset( """ if dataset_name is None and dataset_display_name is None: raise ValueError( - "One of 'dataset_name' or " "'dataset_display_name' must be set." + "One of 'dataset_name' or " + "'dataset_display_name' must be set." ) if dataset_name is not None: @@ -540,7 +557,12 @@ def get_dataset( ) def create_dataset( - self, dataset_display_name, metadata={}, project=None, region=None, **kwargs + self, + dataset_display_name, + metadata={}, + project=None, + region=None, + **kwargs ): """Create a dataset. Keep in mind, importing data is a separate step. 
@@ -580,7 +602,10 @@ def create_dataset(
         """
         return self.auto_ml_client.create_dataset(
             self.__location_path(project, region),
-            {"display_name": dataset_display_name, "tables_dataset_metadata": metadata},
+            {
+                "display_name": dataset_display_name,
+                "tables_dataset_metadata": metadata,
+            },
             **kwargs
         )
 
@@ -767,7 +792,9 @@ def import_data(
             credentials = credentials or self.credentials
             self.__ensure_gcs_client_is_initialized(credentials, project)
             self.gcs_client.ensure_bucket_exists(project, region)
-            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
+            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
+                pandas_dataframe
+            )
             request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
         elif gcs_input_uris is not None:
             if type(gcs_input_uris) != list:
@@ -868,9 +895,13 @@ def export_data(
 
         request = {}
         if gcs_output_uri_prefix is not None:
-            request = {"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}}
+            request = {
+                "gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}
+            }
         elif bigquery_output_uri is not None:
-            request = {"bigquery_destination": {"output_uri": bigquery_output_uri}}
+            request = {
+                "bigquery_destination": {"output_uri": bigquery_output_uri}
+            }
         else:
             raise ValueError(
                 "One of 'gcs_output_uri_prefix', or 'bigquery_output_uri' must be set."
@@ -880,7 +911,9 @@ def export_data(
         self.__log_operation_info("Export data", op)
         return op
 
-    def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs):
+    def get_table_spec(
+        self, table_spec_name, project=None, region=None, **kwargs
+    ):
         """Gets a single table spec in a particular project and region.
 
         Example:
@@ -992,7 +1025,9 @@ def list_table_specs(
 
         return self.auto_ml_client.list_table_specs(dataset_name, **kwargs)
 
-    def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs):
+    def get_column_spec(
+        self, column_spec_name, project=None, region=None, **kwargs
+    ):
         """Gets a single column spec in a particular project and region.
 
         Example:
@@ -1572,7 +1607,10 @@ def clear_time_column(
             dataset_name=dataset_name, **kwargs
         )
 
-        my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None}
+        my_table_spec = {
+            "name": table_spec_full_id,
+            "time_column_spec_id": None,
+        }
 
         return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs)
 
@@ -1766,7 +1804,9 @@ def clear_weight_column(
             **kwargs
         )
         metadata = dataset.tables_dataset_metadata
-        metadata = self.__update_metadata(metadata, "weight_column_spec_id", None)
+        metadata = self.__update_metadata(
+            metadata, "weight_column_spec_id", None
+        )
 
         request = {"name": dataset.name, "tables_dataset_metadata": metadata}
 
@@ -1964,7 +2004,9 @@ def clear_test_train_column(
             **kwargs
         )
         metadata = dataset.tables_dataset_metadata
-        metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None)
+        metadata = self.__update_metadata(
+            metadata, "ml_use_column_spec_id", None
+        )
 
         request = {"name": dataset.name, "tables_dataset_metadata": metadata}
 
@@ -2217,7 +2259,9 @@ def create_model(
             **kwargs
         )
 
-        model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
+        model_metadata[
+            "train_budget_milli_node_hours"
+        ] = train_budget_milli_node_hours
         if optimization_objective is not None:
             model_metadata["optimization_objective"] = optimization_objective
         if disable_early_stopping:
@@ -2255,7 +2299,9 @@ def create_model(
         }
 
         op = self.auto_ml_client.create_model(
-            self.__location_path(project=project, region=region), request, **kwargs
+            self.__location_path(project=project, region=region),
+            request,
+            **kwargs
         )
         self.__log_operation_info("Model creation", op)
         return op
@@ -2377,7 +2423,9 @@ def get_model_evaluation(
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
""" - return self.auto_ml_client.get_model_evaluation(model_evaluation_name, **kwargs) + return self.auto_ml_client.get_model_evaluation( + model_evaluation_name, **kwargs + ) def get_model( self, @@ -2440,7 +2488,9 @@ def get_model( return self.auto_ml_client.get_model(model_name, **kwargs) return self.__lookup_by_display_name( - "model", self.list_models(project, region, **kwargs), model_display_name + "model", + self.list_models(project, region, **kwargs), + model_display_name, ) # TODO(jonathanskim): allow deployment from just model ID @@ -2682,12 +2732,16 @@ def predict( values = [] for i, c in zip(inputs, column_specs): - value_type = self.__type_code_to_value_type(c.data_type.type_code, i) + value_type = self.__type_code_to_value_type( + c.data_type.type_code, i + ) values.append(value_type) request = {"row": {"values": values}} - return self.prediction_client.predict(model.name, request, params, **kwargs) + return self.prediction_client.predict( + model.name, request, params, **kwargs + ) def batch_predict( self, @@ -2799,14 +2853,18 @@ def batch_predict( credentials = credentials or self.credentials self.__ensure_gcs_client_is_initialized(credentials, project) self.gcs_client.ensure_bucket_exists(project, region) - gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe) + gcs_input_uri = self.gcs_client.upload_pandas_dataframe( + pandas_dataframe + ) input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}} elif gcs_input_uris is not None: if type(gcs_input_uris) != list: gcs_input_uris = [gcs_input_uris] input_request = {"gcs_source": {"input_uris": gcs_input_uris}} elif bigquery_input_uri is not None: - input_request = {"bigquery_source": {"input_uri": bigquery_input_uri}} + input_request = { + "bigquery_source": {"input_uri": bigquery_input_uri} + } else: raise ValueError( "One of 'gcs_input_uris'/'bigquery_input_uris' must" "be set" From e684bea063a306cbe5c0e680ef78c8151b2c450e Mon Sep 17 00:00:00 2001 From: Bu Sun Kim Date: Thu, 14 Nov 2019 10:42:00 -0800 Subject: [PATCH 5/8] chore: blacken --- .../cloud/automl_v1beta1/tables/gcs_client.py | 20 +--- .../automl_v1beta1/tables/tables_client.py | 110 +++++------------- 2 files changed, 31 insertions(+), 99 deletions(-) diff --git a/automl/google/cloud/automl_v1beta1/tables/gcs_client.py b/automl/google/cloud/automl_v1beta1/tables/gcs_client.py index e890a8dd645a..99d40da2867c 100644 --- a/automl/google/cloud/automl_v1beta1/tables/gcs_client.py +++ b/automl/google/cloud/automl_v1beta1/tables/gcs_client.py @@ -41,9 +41,7 @@ class GcsClient(object): """Uploads Pandas DataFrame to a bucket in Google Cloud Storage.""" - def __init__( - self, bucket_name=None, client=None, credentials=None, project=None - ): + def __init__(self, bucket_name=None, client=None, credentials=None, project=None): """Constructor. 
 
         Args:
@@ -67,9 +65,7 @@ def __init__(
         if client is not None:
             self.client = client
         elif credentials is not None:
-            self.client = storage.Client(
-                credentials=credentials, project=project
-            )
+            self.client = storage.Client(credentials=credentials, project=project)
         else:
             self.client = storage.Client()
 
@@ -101,9 +97,7 @@ def ensure_bucket_exists(self, project, region):
         except (exceptions.Forbidden, exceptions.NotFound) as e:
             if isinstance(e, exceptions.Forbidden):
                 used_bucket_name = self.bucket_name
-                self.bucket_name = used_bucket_name + "-{}".format(
-                    int(time.time())
-                )
+                self.bucket_name = used_bucket_name + "-{}".format(int(time.time()))
                 _LOGGER.warning(
                     "Created a bucket named {} because a bucket named {} already exists in a different project.".format(
                         self.bucket_name, used_bucket_name
@@ -129,14 +123,10 @@ def upload_pandas_dataframe(self, dataframe, uploaded_csv_name=None):
             raise ImportError(_PANDAS_REQUIRED)
 
         if not isinstance(dataframe, pandas.DataFrame):
-            raise ValueError(
-                "'dataframe' must be a pandas.DataFrame instance."
-            )
+            raise ValueError("'dataframe' must be a pandas.DataFrame instance.")
 
         if self.bucket_name is None:
-            raise ValueError(
-                "Must ensure a bucket exists before uploading data."
-            )
+            raise ValueError("Must ensure a bucket exists before uploading data.")
 
         if uploaded_csv_name is None:
             uploaded_csv_name = "automl-tables-dataframe-{}.csv".format(
diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 0c554f8ff5ed..251bb94c0ae4 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -25,9 +25,7 @@
 from google.cloud.automl_v1beta1.proto import data_types_pb2
 from google.cloud.automl_v1beta1.tables import gcs_client
 
-_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution(
-    "google-cloud-automl"
-).version
+_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -189,11 +187,7 @@ def __dataset_from_args(
         region=None,
         **kwargs
     ):
-        if (
-            dataset is None
-            and dataset_display_name is None
-            and dataset_name is None
-        ):
+        if dataset is None and dataset_display_name is None and dataset_name is None:
             raise ValueError(
                 "One of 'dataset', 'dataset_name' or "
                 "'dataset_display_name' must be set."
@@ -222,8 +216,7 @@ def __model_from_args(
     ):
         if model is None and model_display_name is None and model_name is None:
             raise ValueError(
-                "One of 'model', 'model_name' or "
-                "'model_display_name' must be set."
+                "One of 'model', 'model_name' or " "'model_display_name' must be set."
             )
         # we prefer to make a live call here in the case that the
         # model object is out-of-date
@@ -247,11 +240,7 @@ def __dataset_name_from_args(
         region=None,
         **kwargs
     ):
-        if (
-            dataset is None
-            and dataset_display_name is None
-            and dataset_name is None
-        ):
+        if dataset is None and dataset_display_name is None and dataset_name is None:
             raise ValueError(
                 "One of 'dataset', 'dataset_name' or "
                 "'dataset_display_name' must be set."
@@ -270,10 +259,7 @@ def __dataset_name_from_args(
         else:
             # we do this to force a NotFound error when needed
             self.get_dataset(
-                dataset_name=dataset_name,
-                project=project,
-                region=region,
-                **kwargs
+                dataset_name=dataset_name, project=project, region=region, **kwargs
             )
         return dataset_name
 
@@ -297,8 +283,7 @@ def __table_spec_name_from_args(
         )
 
         table_specs = [
-            t
-            for t in self.list_table_specs(dataset_name=dataset_name, **kwargs)
+            t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs)
         ]
 
         table_spec_full_id = table_specs[table_spec_index].name
@@ -315,8 +300,7 @@ def __model_name_from_args(
     ):
         if model is None and model_display_name is None and model_name is None:
             raise ValueError(
-                "One of 'model', 'model_name' or "
-                "'model_display_name' must be set."
+                "One of 'model', 'model_name' or " "'model_display_name' must be set."
             )
 
         if model_name is None:
@@ -543,8 +527,7 @@ def get_dataset(
         """
         if dataset_name is None and dataset_display_name is None:
             raise ValueError(
-                "One of 'dataset_name' or "
-                "'dataset_display_name' must be set."
+                "One of 'dataset_name' or " "'dataset_display_name' must be set."
             )
 
         if dataset_name is not None:
@@ -557,12 +540,7 @@ def get_dataset(
         )
 
     def create_dataset(
-        self,
-        dataset_display_name,
-        metadata={},
-        project=None,
-        region=None,
-        **kwargs
+        self, dataset_display_name, metadata={}, project=None, region=None, **kwargs
     ):
         """Create a dataset. Keep in mind, importing data is a separate step.
 
@@ -602,10 +580,7 @@ def create_dataset(
         """
         return self.auto_ml_client.create_dataset(
             self.__location_path(project, region),
-            {
-                "display_name": dataset_display_name,
-                "tables_dataset_metadata": metadata,
-            },
+            {"display_name": dataset_display_name, "tables_dataset_metadata": metadata},
             **kwargs
         )
 
@@ -792,9 +767,7 @@ def import_data(
             credentials = credentials or self.credentials
             self.__ensure_gcs_client_is_initialized(credentials, project)
             self.gcs_client.ensure_bucket_exists(project, region)
-            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
-                pandas_dataframe
-            )
+            gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
             request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
         elif gcs_input_uris is not None:
             if type(gcs_input_uris) != list:
@@ -895,13 +868,9 @@ def export_data(
 
         request = {}
         if gcs_output_uri_prefix is not None:
-            request = {
-                "gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}
-            }
+            request = {"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}}
         elif bigquery_output_uri is not None:
-            request = {
-                "bigquery_destination": {"output_uri": bigquery_output_uri}
-            }
+            request = {"bigquery_destination": {"output_uri": bigquery_output_uri}}
         else:
             raise ValueError(
                 "One of 'gcs_output_uri_prefix', or 'bigquery_output_uri' must be set."
@@ -911,9 +880,7 @@ def export_data(
         self.__log_operation_info("Export data", op)
         return op
 
-    def get_table_spec(
-        self, table_spec_name, project=None, region=None, **kwargs
-    ):
+    def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs):
         """Gets a single table spec in a particular project and region.
 
         Example:
@@ -1025,9 +992,7 @@ def list_table_specs(
 
         return self.auto_ml_client.list_table_specs(dataset_name, **kwargs)
 
-    def get_column_spec(
-        self, column_spec_name, project=None, region=None, **kwargs
-    ):
+    def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs):
         """Gets a single column spec in a particular project and region.
 
         Example:
@@ -1607,10 +1572,7 @@ def clear_time_column(
             dataset_name=dataset_name, **kwargs
         )
 
-        my_table_spec = {
-            "name": table_spec_full_id,
-            "time_column_spec_id": None,
-        }
+        my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None}
 
         return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs)
 
@@ -1804,9 +1766,7 @@ def clear_weight_column(
             **kwargs
         )
         metadata = dataset.tables_dataset_metadata
-        metadata = self.__update_metadata(
-            metadata, "weight_column_spec_id", None
-        )
+        metadata = self.__update_metadata(metadata, "weight_column_spec_id", None)
 
         request = {"name": dataset.name, "tables_dataset_metadata": metadata}
 
@@ -2004,9 +1964,7 @@ def clear_test_train_column(
             **kwargs
         )
         metadata = dataset.tables_dataset_metadata
-        metadata = self.__update_metadata(
-            metadata, "ml_use_column_spec_id", None
-        )
+        metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None)
 
         request = {"name": dataset.name, "tables_dataset_metadata": metadata}
 
@@ -2259,9 +2217,7 @@ def create_model(
             **kwargs
         )
 
-        model_metadata[
-            "train_budget_milli_node_hours"
-        ] = train_budget_milli_node_hours
+        model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
         if optimization_objective is not None:
             model_metadata["optimization_objective"] = optimization_objective
         if disable_early_stopping:
@@ -2299,9 +2255,7 @@ def create_model(
         }
 
         op = self.auto_ml_client.create_model(
-            self.__location_path(project=project, region=region),
-            request,
-            **kwargs
+            self.__location_path(project=project, region=region), request, **kwargs
         )
         self.__log_operation_info("Model creation", op)
         return op
@@ -2423,9 +2377,7 @@ def get_model_evaluation(
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
""" - return self.auto_ml_client.get_model_evaluation( - model_evaluation_name, **kwargs - ) + return self.auto_ml_client.get_model_evaluation(model_evaluation_name, **kwargs) def get_model( self, @@ -2488,9 +2440,7 @@ def get_model( return self.auto_ml_client.get_model(model_name, **kwargs) return self.__lookup_by_display_name( - "model", - self.list_models(project, region, **kwargs), - model_display_name, + "model", self.list_models(project, region, **kwargs), model_display_name ) # TODO(jonathanskim): allow deployment from just model ID @@ -2732,16 +2682,12 @@ def predict( values = [] for i, c in zip(inputs, column_specs): - value_type = self.__type_code_to_value_type( - c.data_type.type_code, i - ) + value_type = self.__type_code_to_value_type(c.data_type.type_code, i) values.append(value_type) request = {"row": {"values": values}} - return self.prediction_client.predict( - model.name, request, params, **kwargs - ) + return self.prediction_client.predict(model.name, request, params, **kwargs) def batch_predict( self, @@ -2853,18 +2799,14 @@ def batch_predict( credentials = credentials or self.credentials self.__ensure_gcs_client_is_initialized(credentials, project) self.gcs_client.ensure_bucket_exists(project, region) - gcs_input_uri = self.gcs_client.upload_pandas_dataframe( - pandas_dataframe - ) + gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe) input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}} elif gcs_input_uris is not None: if type(gcs_input_uris) != list: gcs_input_uris = [gcs_input_uris] input_request = {"gcs_source": {"input_uris": gcs_input_uris}} elif bigquery_input_uri is not None: - input_request = { - "bigquery_source": {"input_uri": bigquery_input_uri} - } + input_request = {"bigquery_source": {"input_uri": bigquery_input_uri}} else: raise ValueError( "One of 'gcs_input_uris'/'bigquery_input_uris' must" "be set" From 7d9e11755357233edaf30499bca71bb505525af2 Mon Sep 17 00:00:00 2001 From: sirtorry Date: Fri, 15 Nov 2019 21:48:42 +0000 Subject: [PATCH 6/8] update tests --- .../unit/gapic/v1beta1/test_tables_client_v1beta1.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index 516a4b76080d..e1794c4c4fb5 100644 --- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -1117,7 +1117,7 @@ def test_predict_from_array(self): client = self.tables_client({"get_model.return_value": model}, {}) client.predict(["1"], model_name="my_model") client.prediction_client.predict.assert_called_with( - "my_model", {"row": {"values": [{"string_value": "1"}]}} + "my_model", {"row": {"values": [{"string_value": "1"}]}}, None ) def test_predict_from_dict(self): @@ -1134,6 +1134,7 @@ def test_predict_from_dict(self): client.prediction_client.predict.assert_called_with( "my_model", {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, + None ) def test_predict_from_dict_missing(self): @@ -1148,7 +1149,9 @@ def test_predict_from_dict_missing(self): client = self.tables_client({"get_model.return_value": model}, {}) client.predict({"a": "1"}, model_name="my_model") client.prediction_client.predict.assert_called_with( - "my_model", {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}} + "my_model", + {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}}, + None ) def test_predict_all_types(self): @@ 
@@ -1210,6 +1213,7 @@ def test_predict_all_types(self):
                     ]
                 }
             },
+            None
         )
 
     def test_predict_from_array_missing(self):

From 6ef675e0ce8135455d8f0b3ba6c78dbb18632e58 Mon Sep 17 00:00:00 2001
From: Bu Sun Kim
Date: Fri, 15 Nov 2019 14:17:17 -0800
Subject: [PATCH 7/8] blacken

---
 .../unit/gapic/v1beta1/test_tables_client_v1beta1.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
index e1794c4c4fb5..f164a8875787 100644
--- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
+++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
@@ -1134,7 +1134,7 @@ def test_predict_from_dict(self):
         client.prediction_client.predict.assert_called_with(
             "my_model",
             {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}},
-            None
+            None,
         )
 
     def test_predict_from_dict_missing(self):
@@ -1149,9 +1149,9 @@ def test_predict_from_dict_missing(self):
         client = self.tables_client({"get_model.return_value": model}, {})
         client.predict({"a": "1"}, model_name="my_model")
         client.prediction_client.predict.assert_called_with(
-            "my_model",
-            {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}},
-            None
+            "my_model",
+            {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}},
+            None,
         )
 
     def test_predict_all_types(self):
@@ -1213,7 +1213,7 @@ def test_predict_all_types(self):
                     ]
                 }
             },
-            None
+            None,
         )
 
     def test_predict_from_array_missing(self):

From 3e7c94de507e2e79773717dfdaa9927038218024 Mon Sep 17 00:00:00 2001
From: Bu Sun Kim
Date: Fri, 15 Nov 2019 14:44:18 -0800
Subject: [PATCH 8/8] chore: remove trailing space

---
 automl/google/cloud/automl_v1beta1/tables/tables_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 251bb94c0ae4..3a65ecc8f235 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -2643,7 +2643,7 @@ def predict(
                 The `model` instance you want to predict with . This must be
                 supplied if `model_display_name` or `model_name` are not
                 supplied.
-            params (dict[str, str]): 
+            params (dict[str, str]):
                 `feature_importance` can be set as True to enable local
                 explainability. The default is false.
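
---

Editor's note (not part of the patch series): the net effect of patches 1-8 is that
`TablesClient.predict` gains an optional `params` dict that is forwarded as the third
positional argument to `PredictionServiceClient.predict` — which is why the unit tests
now expect a trailing `None` by default. A minimal usage sketch follows, assuming a
deployed AutoML Tables model; the project, region, model name, and feature row are
hypothetical placeholders, and the response traversal follows the v1beta1
`TablesAnnotation`/`TablesModelColumnInfo` protos referenced in the docstring.

    from google.cloud import automl_v1beta1

    client = automl_v1beta1.TablesClient(
        project="my-project", region="us-central1"  # hypothetical values
    )

    # Values in `params` travel as strings on the wire; "true" asks the service
    # to populate per-column local feature importance for this one prediction.
    response = client.predict(
        {"age": 30, "income": 52000},  # hypothetical feature row
        model_display_name="my_model",
        params={"feature_importance": "true"},
    )

    for payload in response.payload:
        # Each TablesModelColumnInfo carries the local importance of one input
        # column for this particular prediction.
        for column_info in payload.tables.tables_model_column_info:
            print(column_info.column_display_name, column_info.feature_importance)

Omitting `params` (or passing `None`) preserves the old behavior, since the client
simply forwards whatever it receives to the prediction service.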