Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: accept DatasetListItem where DatasetReference is accepted #597

Merged
merged 20 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Converted list_model tests to pytest and included check for dataset p…
…olymorphism
  • Loading branch information
Jim Fulton committed Apr 8, 2021
commit 712878982c61b31080797de52c74f71bb6bed855
82 changes: 0 additions & 82 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2926,88 +2926,6 @@ def test_update_table_delete_property(self):
self.assertEqual(req[1]["data"], sent)
self.assertIsNone(table3.description)

def test_list_models_empty_w_timeout(self):
path = "/projects/{}/datasets/{}/models".format(self.PROJECT, self.DS_ID)
creds = _make_credentials()
client = self._make_one(project=self.PROJECT, credentials=creds)
conn = client._connection = make_connection({})

dataset_id = "{}.{}".format(self.PROJECT, self.DS_ID)
iterator = client.list_models(dataset_id, timeout=7.5)
with mock.patch(
"google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
) as final_attributes:
page = next(iterator.pages)

final_attributes.assert_called_once_with({"path": path}, client, None)
models = list(page)
token = iterator.next_page_token

self.assertEqual(models, [])
self.assertIsNone(token)
conn.api_request.assert_called_once_with(
method="GET", path=path, query_params={}, timeout=7.5
)

def test_list_models_defaults(self):
from google.cloud.bigquery.model import Model

MODEL_1 = "model_one"
MODEL_2 = "model_two"
PATH = "projects/%s/datasets/%s/models" % (self.PROJECT, self.DS_ID)
TOKEN = "TOKEN"
DATA = {
"nextPageToken": TOKEN,
"models": [
{
"modelReference": {
"modelId": MODEL_1,
"datasetId": self.DS_ID,
"projectId": self.PROJECT,
}
},
{
"modelReference": {
"modelId": MODEL_2,
"datasetId": self.DS_ID,
"projectId": self.PROJECT,
}
},
],
}

creds = _make_credentials()
client = self._make_one(project=self.PROJECT, credentials=creds)
conn = client._connection = make_connection(DATA)
dataset = DatasetReference(self.PROJECT, self.DS_ID)

iterator = client.list_models(dataset)
self.assertIs(iterator.dataset, dataset)
with mock.patch(
"google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
) as final_attributes:
page = next(iterator.pages)

final_attributes.assert_called_once_with({"path": "/%s" % PATH}, client, None)
models = list(page)
token = iterator.next_page_token

self.assertEqual(len(models), len(DATA["models"]))
for found, expected in zip(models, DATA["models"]):
self.assertIsInstance(found, Model)
self.assertEqual(found.model_id, expected["modelReference"]["modelId"])
self.assertEqual(token, TOKEN)

conn.api_request.assert_called_once_with(
method="GET", path="/%s" % PATH, query_params={}, timeout=None
)

def test_list_models_wrong_type(self):
creds = _make_credentials()
client = self._make_one(project=self.PROJECT, credentials=creds)
with self.assertRaises(TypeError):
client.list_models(DatasetReference(self.PROJECT, self.DS_ID).model("foo"))

def test_list_routines_empty_w_timeout(self):
creds = _make_credentials()
client = self._make_one(project=self.PROJECT, credentials=creds)
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/test_list_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from .helpers import make_connection, dataset_polymorphic
import google.cloud.bigquery.dataset
import mock
import pytest


def test_list_models_empty_w_timeout(client, PROJECT, DS_ID):
path = "/projects/{}/datasets/{}/models".format(PROJECT, DS_ID)
conn = client._connection = make_connection({})

dataset_id = "{}.{}".format(PROJECT, DS_ID)
iterator = client.list_models(dataset_id, timeout=7.5)
page = next(iterator.pages)
models = list(page)
token = iterator.next_page_token

assert models == []
assert token is None
conn.api_request.assert_called_once_with(
method="GET", path=path, query_params={}, timeout=7.5
)

@dataset_polymorphic
def test_list_models_defaults(make_dataset, get_reference, client, PROJECT, DS_ID):
from google.cloud.bigquery.model import Model

MODEL_1 = "model_one"
MODEL_2 = "model_two"
PATH = "projects/%s/datasets/%s/models" % (PROJECT, DS_ID)
TOKEN = "TOKEN"
DATA = {
"nextPageToken": TOKEN,
"models": [
{
"modelReference": {
"modelId": MODEL_1,
"datasetId": DS_ID,
"projectId": PROJECT,
}
},
{
"modelReference": {
"modelId": MODEL_2,
"datasetId": DS_ID,
"projectId": PROJECT,
}
},
],
}

conn = client._connection = make_connection(DATA)
dataset = make_dataset(PROJECT, DS_ID)

iterator = client.list_models(dataset)
assert iterator.dataset == get_reference(dataset)
page = next(iterator.pages)
models = list(page)
token = iterator.next_page_token

assert len(models) == len(DATA["models"])
for found, expected in zip(models, DATA["models"]):
assert isinstance(found, Model)
assert found.model_id == expected["modelReference"]["modelId"]
assert token == TOKEN

conn.api_request.assert_called_once_with(
method="GET", path="/%s" % PATH, query_params={}, timeout=None
)

def test_list_models_wrong_type(client):
with pytest.raises(TypeError):
client.list_models(42)