diff --git a/.bumpversion.cfg b/.bumpversion.cfg index b0f5f6aab..26bd8a191 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -4,7 +4,7 @@ parse = (?P<major>\d+) \.(?P<minor>\d+) \.(?P<patch>\d+) ((?P<prerelease>a|b|rc)(?P<num>\d+))? -serialize = +serialize = {major}.{minor}.{patch}{prerelease}{num} {major}.{minor}.{patch} commit = False @@ -13,7 +13,7 @@ tag = False [bumpversion:part:prerelease] first_value = a optional_value = final -values = +values = a b rc diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..5e6fa8903 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +select = + E + W + F +ignore = + W503 # makes Flake8 work like black + W504 + E203 # makes Flake8 work like black + E741 + E501 +exclude = tests diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 0f793f232..beedab7ad 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -18,4 +18,4 @@ resolves # - [ ] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements) - [ ] I have run this code in development and it appears to resolve the stated issue - [ ] This PR includes tests, or tests are not required/relevant for this PR -- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-bigquery next" section. \ No newline at end of file +- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-bigquery next" section. diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 01c7dfba8..2097fded1 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -157,7 +157,7 @@ jobs: pip install tox pip --version tox --version - + - name: Install dbt-core latest run: | pip install "git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core" @@ -212,7 +212,7 @@ jobs: post-failure: runs-on: ubuntu-latest - needs: test + needs: test if: ${{ failure() }} steps: diff --git a/.github/workflows/jira-creation.yml b/.github/workflows/jira-creation.yml index c84e106a7..b4016befc 100644 --- a/.github/workflows/jira-creation.yml +++ b/.github/workflows/jira-creation.yml @@ -13,7 +13,7 @@ name: Jira Issue Creation on: issues: types: [opened, labeled] - + permissions: issues: write diff --git a/.github/workflows/jira-label.yml b/.github/workflows/jira-label.yml index fd533a170..3da2e3a38 100644 --- a/.github/workflows/jira-label.yml +++ b/.github/workflows/jira-label.yml @@ -13,7 +13,7 @@ name: Jira Label Mirroring on: issues: types: [labeled, unlabeled] - + permissions: issues: read @@ -24,4 +24,3 @@ jobs: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} - diff --git a/.github/workflows/jira-transition.yml b/.github/workflows/jira-transition.yml index 71273c7a9..ed9f9cd4f 100644 --- a/.github/workflows/jira-transition.yml +++ b/.github/workflows/jira-transition.yml @@ -21,4 +21,4 @@ jobs: secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} \ No newline at end of file + JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 23cdb4502..77e8e18e2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,19 +37,10 @@ defaults: jobs: code-quality: - name: ${{ matrix.toxenv }} + name: code-quality runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: -
toxenv: [flake8] - - env: - TOXENV: ${{ matrix.toxenv }} - PYTEST_ADDOPTS: "-v --color=yes" - steps: - name: Check out the repository uses: actions/checkout@v2 @@ -62,12 +53,13 @@ jobs: - name: Install python dependencies run: | pip install --user --upgrade pip - pip install tox + pip install -r dev_requirements.txt pip --version - tox --version - - - name: Run tox - run: tox + pre-commit --version + mypy --version + dbt --version + - name: Run pre-commit hooks + run: pre-commit run --all-files --show-diff-on-failure unit: name: unit test / python ${{ matrix.python-version }} @@ -100,11 +92,9 @@ jobs: pip install tox pip --version tox --version - - name: Install dbt-core latest run: | pip install "git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core" - - name: Run tox run: tox @@ -140,7 +130,6 @@ jobs: pip install --user --upgrade pip pip install --upgrade setuptools wheel twine check-wheel-contents pip --version - - name: Build distributions run: ./scripts/build-dist.sh @@ -150,11 +139,9 @@ jobs: - name: Check distribution descriptions run: | twine check dist/* - - name: Check wheel contents run: | check-wheel-contents dist/*.whl --ignore W007,W008 - - uses: actions/upload-artifact@v2 with: name: dist @@ -184,7 +171,6 @@ jobs: pip install --user --upgrade pip pip install --upgrade wheel pip --version - - uses: actions/download-artifact@v2 with: name: dist @@ -196,15 +182,12 @@ jobs: - name: Install wheel distributions run: | find ./dist/*.whl -maxdepth 1 -type f | xargs pip install --force-reinstall --find-links=dist/ - - name: Check wheel distributions run: | dbt --version - - name: Install source distributions run: | find ./dist/*.gz -maxdepth 1 -type f | xargs pip install --force-reinstall --find-links=dist/ - - name: Check source distributions run: | dbt --version diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2848ce8f7..a1ca95861 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -14,4 +14,4 @@ jobs: # mark issues/PRs stale when they haven't seen activity in 180 days days-before-stale: 180 # ignore checking issues with the following labels - exempt-issue-labels: "epic, discussion" \ No newline at end of file + exempt-issue-labels: "epic, discussion" diff --git a/.github/workflows/version-bump.yml b/.github/workflows/version-bump.yml index 4913a6e84..b0a3174df 100644 --- a/.github/workflows/version-bump.yml +++ b/.github/workflows/version-bump.yml @@ -1,16 +1,16 @@ # **what?** # This workflow will take a version number and a dry run flag. With that -# it will run versionbump to update the version number everywhere in the +# it will run versionbump to update the version number everywhere in the # code base and then generate an update Docker requirements file. If this # is a dry run, a draft PR will open with the changes. If this isn't a dry # run, the changes will be committed to the branch this is run on. # **why?** -# This is to aid in releasing dbt and making sure we have updated +# This is to aid in releasing dbt and making sure we have updated # the versions and Docker requirements in all places.
# **when?** -# This is triggered either manually OR +# This is triggered either manually OR # from the repository_dispatch event "version-bump" which is sent from # the dbt-release repo Action @@ -25,11 +25,11 @@ on: is_dry_run: description: 'Creates a draft PR to allow testing instead of committing to a branch' required: true - default: 'true' + default: 'true' repository_dispatch: types: [version-bump] -jobs: +jobs: bump: runs-on: ubuntu-latest steps: @@ -57,19 +57,19 @@ jobs: run: | python3 -m venv env source env/bin/activate - pip install --upgrade pip - + pip install --upgrade pip + - name: Create PR branch if: ${{ steps.variables.outputs.IS_DRY_RUN == 'true' }} run: | git checkout -b bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git push origin bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git branch --set-upstream-to=origin/bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID - + - name: Bumping version run: | source env/bin/activate - pip install -r dev_requirements.txt + pip install -r dev_requirements.txt env/bin/bumpversion --allow-dirty --new-version ${{steps.variables.outputs.VERSION_NUMBER}} major git status @@ -99,4 +99,4 @@ jobs: draft: true base: ${{github.ref}} title: 'Bumping version to ${{steps.variables.outputs.VERSION_NUMBER}}' - branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' + branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' diff --git a/.gitignore b/.gitignore index 43724b61e..780d98f70 100644 --- a/.gitignore +++ b/.gitignore @@ -49,9 +49,7 @@ coverage.xml *,cover .hypothesis/ test.env - -# Mypy -.mypy_cache/ +*.pytest_cache/ # Translations *.mo @@ -66,10 +64,10 @@ docs/_build/ # PyBuilder target/ -#Ipython Notebook +# Ipython Notebook .ipynb_checkpoints -#Emacs +# Emacs *~ # Sublime Text @@ -78,6 +76,7 @@ target/ # Vim *.sw* +# Pyenv .python-version # Vim @@ -90,6 +89,7 @@ venv/ # AWS credentials .aws/ +# MacOS .DS_Store # vscode diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..9d247581b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,53 @@ +# For more on configuring pre-commit hooks (see https://pre-commit.com/) + +# TODO: remove global exclusion of tests when testing overhaul is complete +exclude: '^tests/.*' + +default_language_version: + python: python3.8 + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: check-yaml + args: [--unsafe] + - id: check-json + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-case-conflict +- repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + args: + - "--line-length=99" + - "--target-version=py38" + - id: black + alias: black-check + stages: [manual] + args: + - "--line-length=99" + - "--target-version=py38" + - "--check" + - "--diff" +- repo: https://gitlab.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - id: flake8 + alias: flake8-check + stages: [manual] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.782 + hooks: + - id: mypy + args: [--show-error-codes, --ignore-missing-imports] + files: ^dbt/adapters/.* + language: system + - id: mypy + alias: mypy-check + stages: [manual] + args: [--show-error-codes, --pretty, --ignore-missing-imports] + files: ^dbt/adapters + language: system diff --git a/CHANGELOG.md b/CHANGELOG.md index 
7d17f1ebb..987bc79c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ### Under the hood - Use dbt.tests.adapter.basic in tests (new test framework) ([#135](https://github.com/dbt-labs/dbt-bigquery/issues/135), [#142](https://github.com/dbt-labs/dbt-bigquery/pull/142)) +- Adding pre-commit and black formatter hooks ([#147](https://github.com/dbt-labs/dbt-bigquery/pull/147)) +- Adding pre-commit code changes ([#148](https://github.com/dbt-labs/dbt-bigquery/pull/148)) ## dbt-bigquery 1.1.0b1 (March 23, 2022) ### Features diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 504710c70..aec184e46 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,7 +17,7 @@ For those wishing to contribute we highly suggest reading the [dbt-core](https:/ Please note that all contributors to `dbt-bigquery` must sign the [Contributor License Agreement](https://docs.getdbt.com/docs/contributor-license-agreements) to have their Pull Request merged into an `dbt-bigquery` codebase. If you are unable to sign the CLA, then the `dbt-bigquery` maintainers will unfortunately be unable to merge your Pull Request. You are, however, welcome to open issues and comment on existing ones. -## Getting the code +## Getting the code You will need `git` in order to download and modify the `dbt-bigquery` source code. You can find direction [here](https://github.com/git-guides/install-git) on how to install `git`. @@ -93,7 +93,7 @@ Many changes will require and update to the `dbt-bigquery` docs here are some us ## Submitting a Pull Request -dbt Labs provides a CI environment to test changes to the `dbt-bigquery` adapter and periodic checks against the development version of `dbt-core` through Github Actions. +dbt Labs provides a CI environment to test changes to the `dbt-bigquery` adapter and periodic checks against the development version of `dbt-core` through Github Actions. A `dbt-bigquery` maintainer will review your PR. They may suggest code revision for style or clarity, or request that you add unit or integration test(s). These are good things! We believe that, with a little bit of help, anyone can contribute high-quality code. 
diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5b8..cfbc714ed 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/dbt/adapters/bigquery/__init__.py b/dbt/adapters/bigquery/__init__.py index c85e79b38..b66ef278a 100644 --- a/dbt/adapters/bigquery/__init__.py +++ b/dbt/adapters/bigquery/__init__.py @@ -8,6 +8,5 @@ from dbt.include import bigquery Plugin = AdapterPlugin( - adapter=BigQueryAdapter, - credentials=BigQueryCredentials, - include_path=bigquery.PACKAGE_PATH) + adapter=BigQueryAdapter, credentials=BigQueryCredentials, include_path=bigquery.PACKAGE_PATH # type: ignore[arg-type] +) diff --git a/dbt/adapters/bigquery/__version__.py b/dbt/adapters/bigquery/__version__.py index a86cb5c59..56ec17a89 100644 --- a/dbt/adapters/bigquery/__version__.py +++ b/dbt/adapters/bigquery/__version__.py @@ -1 +1 @@ -version = '1.1.0b1' +version = "1.1.0b1" diff --git a/dbt/adapters/bigquery/column.py b/dbt/adapters/bigquery/column.py index 6c8a70df1..2422fcc0b 100644 --- a/dbt/adapters/bigquery/column.py +++ b/dbt/adapters/bigquery/column.py @@ -5,27 +5,27 @@ from google.cloud.bigquery import SchemaField -Self = TypeVar('Self', bound='BigQueryColumn') +Self = TypeVar("Self", bound="BigQueryColumn") @dataclass(init=False) class BigQueryColumn(Column): TYPE_LABELS = { - 'STRING': 'STRING', - 'TIMESTAMP': 'TIMESTAMP', - 'FLOAT': 'FLOAT64', - 'INTEGER': 'INT64', - 'RECORD': 'RECORD', + "STRING": "STRING", + "TIMESTAMP": "TIMESTAMP", + "FLOAT": "FLOAT64", + "INTEGER": "INT64", + "RECORD": "RECORD", } - fields: List[Self] - mode: str + fields: List[Self] # type: ignore + mode: str # type: ignore def __init__( self, column: str, dtype: str, fields: Optional[Iterable[SchemaField]] = None, - mode: str = 'NULLABLE', + mode: str = "NULLABLE", ) -> None: super().__init__(column, dtype) @@ -36,9 +36,7 @@ def __init__( self.mode = mode @classmethod - def wrap_subfields( - cls: Type[Self], fields: Iterable[SchemaField] - ) -> List[Self]: + def wrap_subfields(cls: Type[Self], fields: Iterable[SchemaField]) -> List[Self]: return [cls.create_from_field(field) for field in fields] @classmethod @@ -51,20 +49,18 @@ def create_from_field(cls: Type[Self], field: SchemaField) -> Self: ) @classmethod - def _flatten_recursive( - cls: Type[Self], col: Self, prefix: Optional[str] = None - ) -> List[Self]: + def _flatten_recursive(cls: Type[Self], col: Self, prefix: Optional[str] = None) -> List[Self]: if prefix is None: - prefix = [] + prefix = [] # type: ignore[assignment] if len(col.fields) == 0: - prefixed_name = ".".join(prefix + [col.column]) + prefixed_name = ".".join(prefix + [col.column]) # type: ignore[operator] new_col = cls(prefixed_name, col.dtype, col.fields, col.mode) return [new_col] new_fields = [] for field in col.fields: - new_prefix = prefix + [col.column] + new_prefix = prefix + [col.column] # type: ignore[operator] new_fields.extend(cls._flatten_recursive(field, new_prefix)) return new_fields @@ -74,54 +70,52 @@ def flatten(self): @property def quoted(self): - return '`{}`'.format(self.column) + return "`{}`".format(self.column) def literal(self, value): return "cast({} as {})".format(value, self.dtype) @property def data_type(self) -> str: - if self.dtype.upper() == 'RECORD': + if self.dtype.upper() == "RECORD": subcols = [ - "{} {}".format(col.name, col.data_type) for col in self.fields + "{} {}".format(col.name, col.data_type) for col in self.fields # type: 
ignore[attr-defined] ] - field_type = 'STRUCT<{}>'.format(", ".join(subcols)) + field_type = "STRUCT<{}>".format(", ".join(subcols)) else: field_type = self.dtype - if self.mode.upper() == 'REPEATED': - return 'ARRAY<{}>'.format(field_type) + if self.mode.upper() == "REPEATED": + return "ARRAY<{}>".format(field_type) else: return field_type def is_string(self) -> bool: - return self.dtype.lower() == 'string' + return self.dtype.lower() == "string" def is_integer(self) -> bool: - return self.dtype.lower() == 'int64' + return self.dtype.lower() == "int64" def is_numeric(self) -> bool: - return self.dtype.lower() == 'numeric' + return self.dtype.lower() == "numeric" def is_float(self): - return self.dtype.lower() == 'float64' + return self.dtype.lower() == "float64" - def can_expand_to(self: Self, other_column: Self) -> bool: + def can_expand_to(self: Self, other_column: Self) -> bool: # type: ignore """returns True if both columns are strings""" return self.is_string() and other_column.is_string() def __repr__(self) -> str: - return "<BigQueryColumn {} ({}, {})>".format(self.name, self.data_type, - self.mode) + return "<BigQueryColumn {} ({}, {})>".format(self.name, self.data_type, self.mode) def column_to_bq_schema(self) -> SchemaField: - """Convert a column to a bigquery schema object. - """ + """Convert a column to a bigquery schema object.""" kwargs = {} if len(self.fields) > 0: - fields = [field.column_to_bq_schema() for field in self.fields] + fields = [field.column_to_bq_schema() for field in self.fields] # type: ignore[attr-defined] kwargs = {"fields": fields} return SchemaField(self.name, self.dtype, self.mode, **kwargs) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index ad49cfd9f..05f236a55 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -15,7 +15,7 @@ from google.auth import impersonated_credentials from google.oauth2 import ( credentials as GoogleCredentials, - service_account as GoogleServiceAccountCredentials + service_account as GoogleServiceAccountCredentials, ) from dbt.adapters.bigquery import gcloud @@ -24,7 +24,10 @@ from dbt.tracking import active_user from dbt.contracts.connection import ConnectionState, AdapterResponse from dbt.exceptions import ( - FailedToConnectException, RuntimeException, DatabaseException, DbtProfileError + FailedToConnectException, + RuntimeException, + DatabaseException, + DbtProfileError, ) from dbt.adapters.base import BaseConnectionManager, Credentials from dbt.events import AdapterLogger @@ -36,7 +39,7 @@ logger = AdapterLogger("BigQuery") -BQ_QUERY_JOB_SPLIT = '-----Query Job SQL Follows-----' +BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" WRITE_TRUNCATE = google.cloud.bigquery.job.WriteDisposition.WRITE_TRUNCATE @@ -69,15 +72,15 @@ def get_bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: class Priority(StrEnum): - Interactive = 'interactive' - Batch = 'batch' + Interactive = "interactive" + Batch = "batch" class BigQueryConnectionMethod(StrEnum): - OAUTH = 'oauth' - SERVICE_ACCOUNT = 'service-account' - SERVICE_ACCOUNT_JSON = 'service-account-json' - OAUTH_SECRETS = 'oauth-secrets' + OAUTH = "oauth" + SERVICE_ACCOUNT = "service-account" + SERVICE_ACCOUNT_JSON = "service-account-json" + OAUTH_SECRETS = "oauth-secrets" @dataclass @@ -90,7 +93,7 @@ class BigQueryCredentials(Credentials): method: BigQueryConnectionMethod # BigQuery allows an empty database / project, where it defers to the # environment for the project - database: Optional[str] + database: Optional[str] # type: ignore
execution_project: Optional[str] = None location: Optional[str] = None priority: Optional[Priority] = None @@ -114,34 +117,44 @@ class BigQueryCredentials(Credentials): token_uri: Optional[str] = None scopes: Optional[Tuple[str, ...]] = ( - 'https://www.googleapis.com/auth/bigquery', - 'https://www.googleapis.com/auth/cloud-platform', - 'https://www.googleapis.com/auth/drive' + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", ) _ALIASES = { # 'legacy_name': 'current_name' - 'project': 'database', - 'dataset': 'schema', - 'target_project': 'target_database', - 'target_dataset': 'target_schema', - 'retries': 'job_retries', - 'timeout_seconds': 'job_execution_timeout_seconds', + "project": "database", + "dataset": "schema", + "target_project": "target_database", + "target_dataset": "target_schema", + "retries": "job_retries", + "timeout_seconds": "job_execution_timeout_seconds", } @property def type(self): - return 'bigquery' + return "bigquery" @property def unique_field(self): return self.database def _connection_keys(self): - return ('method', 'database', 'schema', 'location', 'priority', - 'timeout_seconds', 'maximum_bytes_billed', 'execution_project', - 'job_retry_deadline_seconds', 'job_retries', - 'job_creation_timeout_seconds', 'job_execution_timeout_seconds') + return ( + "method", + "database", + "schema", + "location", + "priority", + "timeout_seconds", + "maximum_bytes_billed", + "execution_project", + "job_retry_deadline_seconds", + "job_retries", + "job_creation_timeout_seconds", + "job_execution_timeout_seconds", + ) @classmethod def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: @@ -150,17 +163,17 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: # https://github.com/dbt-labs/dbt/pull/2908#discussion_r532927436. 
# `database` is an alias of `project` in BigQuery - if 'database' not in d: + if "database" not in d: _, database = get_bigquery_defaults() - d['database'] = database + d["database"] = database # `execution_project` default to dataset/project - if 'execution_project' not in d: - d['execution_project'] = d['database'] + if "execution_project" not in d: + d["execution_project"] = d["database"] return d class BigQueryConnectionManager(BaseConnectionManager): - TYPE = 'bigquery' + TYPE = "bigquery" QUERY_TIMEOUT = 300 RETRIES = 1 @@ -169,7 +182,7 @@ class BigQueryConnectionManager(BaseConnectionManager): @classmethod def handle_error(cls, error, message): - error_msg = "\n".join([item['message'] for item in error.errors]) + error_msg = "\n".join([item["message"] for item in error.errors]) raise DatabaseException(error_msg) def clear_transaction(self): @@ -189,12 +202,14 @@ def exception_handler(self, sql): self.handle_error(e, message) except google.auth.exceptions.RefreshError as e: - message = "Unable to generate access token, if you're using " \ - "impersonate_service_account, make sure your " \ - 'initial account has the "roles/' \ - 'iam.serviceAccountTokenCreator" role on the ' \ - 'account you are trying to impersonate.\n\n' \ - f'{str(e)}' + message = ( + "Unable to generate access token, if you're using " + "impersonate_service_account, make sure your " + 'initial account has the "roles/' + 'iam.serviceAccountTokenCreator" role on the ' + "account you are trying to impersonate.\n\n" + f"{str(e)}" + ) raise RuntimeException(message) except Exception as e: @@ -229,7 +244,7 @@ def commit(self): def format_bytes(self, num_bytes): if num_bytes: - for unit in ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB']: + for unit in ["Bytes", "KB", "MB", "GB", "TB", "PB"]: if abs(num_bytes) < 1024.0: return f"{num_bytes:3.1f} {unit}" num_bytes /= 1024.0 @@ -241,7 +256,7 @@ def format_bytes(self, num_bytes): return num_bytes def format_rows_number(self, rows_number): - for unit in ['', 'k', 'm', 'b', 't']: + for unit in ["", "k", "m", "b", "t"]: if abs(rows_number) < 1000.0: return f"{rows_number:3.1f}{unit}".strip() rows_number /= 1000.0 @@ -273,10 +288,10 @@ def get_bigquery_credentials(cls, profile_credentials): client_id=profile_credentials.client_id, client_secret=profile_credentials.client_secret, token_uri=profile_credentials.token_uri, - scopes=profile_credentials.scopes + scopes=profile_credentials.scopes, ) - error = ('Invalid `method` in profile: "{}"'.format(method)) + error = 'Invalid `method` in profile: "{}"'.format(method) raise FailedToConnectException(error) @classmethod @@ -292,14 +307,13 @@ def get_impersonated_bigquery_credentials(cls, profile_credentials): @classmethod def get_bigquery_client(cls, profile_credentials): if profile_credentials.impersonate_service_account: - creds =\ - cls.get_impersonated_bigquery_credentials(profile_credentials) + creds = cls.get_impersonated_bigquery_credentials(profile_credentials) else: creds = cls.get_bigquery_credentials(profile_credentials) execution_project = profile_credentials.execution_project - location = getattr(profile_credentials, 'location', None) + location = getattr(profile_credentials, "location", None) - info = client_info.ClientInfo(user_agent=f'dbt-{dbt_version}') + info = client_info.ClientInfo(user_agent=f"dbt-{dbt_version}") return google.cloud.bigquery.Client( execution_project, creds, @@ -309,8 +323,8 @@ def get_bigquery_client(cls, profile_credentials): @classmethod def open(cls, connection): - if connection.state == 'open': - 
logger.debug('Connection is already open, skipping open.') + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") return connection try: @@ -323,16 +337,17 @@ def open(cls, connection): handle = cls.get_bigquery_client(connection.credentials) except Exception as e: - logger.debug("Got an error when attempting to create a bigquery " - "client: '{}'".format(e)) + logger.debug( + "Got an error when attempting to create a bigquery " "client: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise FailedToConnectException(str(e)) connection.handle = handle - connection.state = 'open' + connection.state = "open" return connection @classmethod @@ -373,29 +388,30 @@ def raw_execute(self, sql, fetch=False, *, use_legacy_sql=False): labels = {} if active_user: - labels['dbt_invocation_id'] = active_user.invocation_id + labels["dbt_invocation_id"] = active_user.invocation_id - job_params = {'use_legacy_sql': use_legacy_sql, 'labels': labels} + job_params = {"use_legacy_sql": use_legacy_sql, "labels": labels} priority = conn.credentials.priority if priority == Priority.Batch: - job_params['priority'] = google.cloud.bigquery.QueryPriority.BATCH + job_params["priority"] = google.cloud.bigquery.QueryPriority.BATCH else: - job_params[ - 'priority'] = google.cloud.bigquery.QueryPriority.INTERACTIVE + job_params["priority"] = google.cloud.bigquery.QueryPriority.INTERACTIVE maximum_bytes_billed = conn.credentials.maximum_bytes_billed if maximum_bytes_billed is not None and maximum_bytes_billed != 0: - job_params['maximum_bytes_billed'] = maximum_bytes_billed + job_params["maximum_bytes_billed"] = maximum_bytes_billed job_creation_timeout = self.get_job_creation_timeout_seconds(conn) job_execution_timeout = self.get_job_execution_timeout_seconds(conn) def fn(): return self._query_and_results( - client, sql, job_params, + client, + sql, + job_params, job_creation_timeout=job_creation_timeout, - job_execution_timeout=job_execution_timeout + job_execution_timeout=job_execution_timeout, ) query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn) @@ -414,84 +430,79 @@ def execute( else: table = agate_helper.empty_table() - message = 'OK' + message = "OK" code = None num_rows = None bytes_processed = None - if query_job.statement_type == 'CREATE_VIEW': - code = 'CREATE VIEW' + if query_job.statement_type == "CREATE_VIEW": + code = "CREATE VIEW" - elif query_job.statement_type == 'CREATE_TABLE_AS_SELECT': + elif query_job.statement_type == "CREATE_TABLE_AS_SELECT": conn = self.get_thread_connection() client = conn.handle query_table = client.get_table(query_job.destination) - code = 'CREATE TABLE' + code = "CREATE TABLE" num_rows = query_table.num_rows num_rows_formated = self.format_rows_number(num_rows) bytes_processed = query_job.total_bytes_processed processed_bytes = self.format_bytes(bytes_processed) - message = f'{code} ({num_rows_formated} rows, {processed_bytes} processed)' + message = f"{code} ({num_rows_formated} rows, {processed_bytes} processed)" - elif query_job.statement_type == 'SCRIPT': - code = 'SCRIPT' + elif query_job.statement_type == "SCRIPT": + code = "SCRIPT" bytes_processed = query_job.total_bytes_processed - message = f'{code} ({self.format_bytes(bytes_processed)} processed)' + message = f"{code} ({self.format_bytes(bytes_processed)} processed)" - elif query_job.statement_type in ['INSERT', 'DELETE', 'MERGE', 'UPDATE']: + elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", 
"UPDATE"]: code = query_job.statement_type num_rows = query_job.num_dml_affected_rows num_rows_formated = self.format_rows_number(num_rows) bytes_processed = query_job.total_bytes_processed processed_bytes = self.format_bytes(bytes_processed) - message = f'{code} ({num_rows_formated} rows, {processed_bytes} processed)' + message = f"{code} ({num_rows_formated} rows, {processed_bytes} processed)" - elif query_job.statement_type == 'SELECT': + elif query_job.statement_type == "SELECT": conn = self.get_thread_connection() client = conn.handle # use anonymous table for num_rows query_table = client.get_table(query_job.destination) - code = 'SELECT' + code = "SELECT" num_rows = query_table.num_rows num_rows_formated = self.format_rows_number(num_rows) bytes_processed = query_job.total_bytes_processed processed_bytes = self.format_bytes(bytes_processed) - message = f'{code} ({num_rows_formated} rows, {processed_bytes} processed)' + message = f"{code} ({num_rows_formated} rows, {processed_bytes} processed)" - response = BigQueryAdapterResponse( - _message=message, - rows_affected=num_rows, - code=code, - bytes_processed=bytes_processed + response = BigQueryAdapterResponse( # type: ignore[call-arg] + _message=message, rows_affected=num_rows, code=code, bytes_processed=bytes_processed ) return response, table def get_partitions_metadata(self, table): def standard_to_legacy(table): - return table.project + ':' + table.dataset + '.' + table.identifier + return table.project + ":" + table.dataset + "." + table.identifier - legacy_sql = 'SELECT * FROM ['\ - + standard_to_legacy(table) + '$__PARTITIONS_SUMMARY__]' + legacy_sql = "SELECT * FROM [" + standard_to_legacy(table) + "$__PARTITIONS_SUMMARY__]" sql = self._add_query_comment(legacy_sql) # auto_begin is ignored on bigquery, and only included for consistency - _, iterator =\ - self.raw_execute(sql, fetch='fetch_result', use_legacy_sql=True) + _, iterator = self.raw_execute(sql, fetch="fetch_result", use_legacy_sql=True) return self.get_table_from_response(iterator) def copy_bq_table(self, source, destination, write_disposition): conn = self.get_thread_connection() client = conn.handle -# ------------------------------------------------------------------------------- -# BigQuery allows to use copy API using two different formats: -# 1. client.copy_table(source_table_id, destination_table_id) -# where source_table_id = "your-project.source_dataset.source_table" -# 2. client.copy_table(source_table_ids, destination_table_id) -# where source_table_ids = ["your-project.your_dataset.your_table_name", ...] -# Let's use uniform function call and always pass list there -# ------------------------------------------------------------------------------- + # ------------------------------------------------------------------------------- + # BigQuery allows to use copy API using two different formats: + # 1. client.copy_table(source_table_id, destination_table_id) + # where source_table_id = "your-project.source_dataset.source_table" + # 2. client.copy_table(source_table_ids, destination_table_id) + # where source_table_ids = ["your-project.your_dataset.your_table_name", ...] 
+ # Let's use uniform function call and always pass list there + # ------------------------------------------------------------------------------- if type(source) is not list: source = [source] @@ -505,22 +516,24 @@ def copy_bq_table(self, source, destination, write_disposition): logger.debug( 'Copying table(s) "{}" to "{}" with disposition: "{}"', - ', '.join(source_ref.path for source_ref in source_ref_array), - destination_ref.path, write_disposition) + ", ".join(source_ref.path for source_ref in source_ref_array), + destination_ref.path, + write_disposition, + ) def copy_and_results(): - job_config = google.cloud.bigquery.CopyJobConfig( - write_disposition=write_disposition) - copy_job = client.copy_table( - source_ref_array, destination_ref, job_config=job_config) + job_config = google.cloud.bigquery.CopyJobConfig(write_disposition=write_disposition) + copy_job = client.copy_table(source_ref_array, destination_ref, job_config=job_config) iterator = copy_job.result(timeout=self.get_job_execution_timeout_seconds(conn)) return copy_job, iterator self._retry_and_handle( msg='copy table "{}" to "{}"'.format( - ', '.join(source_ref.path for source_ref in source_ref_array), - destination_ref.path), - conn=conn, fn=copy_and_results) + ", ".join(source_ref.path for source_ref in source_ref_array), destination_ref.path + ), + conn=conn, + fn=copy_and_results, + ) @staticmethod def dataset_ref(database, schema): @@ -545,8 +558,7 @@ def drop_dataset(self, database, schema): def fn(): return client.delete_dataset(dataset_ref, delete_contents=True, not_found_ok=True) - self._retry_and_handle( - msg='drop dataset', conn=conn, fn=fn) + self._retry_and_handle(msg="drop dataset", conn=conn, fn=fn) def create_dataset(self, database, schema): conn = self.get_thread_connection() @@ -555,30 +567,26 @@ def create_dataset(self, database, schema): def fn(): return client.create_dataset(dataset_ref, exists_ok=True) - self._retry_and_handle(msg='create dataset', conn=conn, fn=fn) + + self._retry_and_handle(msg="create dataset", conn=conn, fn=fn) def _query_and_results( - self, client, sql, job_params, - job_creation_timeout=None, - job_execution_timeout=None + self, client, sql, job_params, job_creation_timeout=None, job_execution_timeout=None ): """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used job_config = google.cloud.bigquery.QueryJobConfig(**job_params) - query_job = client.query( - query=sql, - job_config=job_config, - timeout=job_creation_timeout - ) + query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout) iterator = query_job.result(timeout=job_execution_timeout) return query_job, iterator def _retry_and_handle(self, msg, conn, fn): """retry a function call within the context of exception_handler.""" + def reopen_conn_on_error(error): if isinstance(error, REOPENABLE_ERRORS): - logger.warning('Reopening connection after {!r}'.format(error)) + logger.warning("Reopening connection after {!r}".format(error)) self.close(conn) self.open(conn) return @@ -589,19 +597,20 @@ def reopen_conn_on_error(error): predicate=_ErrorCounter(self.get_job_retries(conn)).count_error, sleep_generator=self._retry_generator(), deadline=self.get_job_retry_deadline_seconds(conn), - on_error=reopen_conn_on_error) + on_error=reopen_conn_on_error, + ) def _retry_generator(self): """Generates retry intervals that exponentially back off.""" return retry.exponential_sleep_generator( - initial=self.DEFAULT_INITIAL_DELAY, - 
maximum=self.DEFAULT_MAXIMUM_DELAY) + initial=self.DEFAULT_INITIAL_DELAY, maximum=self.DEFAULT_MAXIMUM_DELAY + ) def _labels_from_query_comment(self, comment: str) -> Dict: try: comment_labels = json.loads(comment) except (TypeError, ValueError): - return {'query_comment': _sanitize_label(comment)} + return {"query_comment": _sanitize_label(comment)} return { _sanitize_label(key): _sanitize_label(str(value)) for key, value in comment_labels.items() @@ -621,9 +630,10 @@ def count_error(self, error): self.error_count += 1 if _is_retryable(error) and self.error_count <= self.retries: logger.debug( - 'Retry attempt {} of {} after error: {}'.format( + "Retry attempt {} of {} after error: {}".format( self.error_count, self.retries, repr(error) - )) + ) + ) return True else: return False @@ -634,7 +644,8 @@ def _is_retryable(error): if isinstance(error, RETRYABLE_ERRORS): return True elif isinstance(error, google.api_core.exceptions.Forbidden) and any( - e['reason'] == 'rateLimitExceeded' for e in error.errors): + e["reason"] == "rateLimitExceeded" for e in error.errors + ): return True return False diff --git a/dbt/adapters/bigquery/gcloud.py b/dbt/adapters/bigquery/gcloud.py index 28e7e1a74..eb418e93b 100644 --- a/dbt/adapters/bigquery/gcloud.py +++ b/dbt/adapters/bigquery/gcloud.py @@ -14,7 +14,7 @@ def gcloud_installed(): try: - run_cmd('.', ['gcloud', '--version']) + run_cmd(".", ["gcloud", "--version"]) return True except OSError as e: logger.debug(e) @@ -23,6 +23,6 @@ def gcloud_installed(): def setup_default_credentials(): if gcloud_installed(): - run_cmd('.', ["gcloud", "auth", "application-default", "login"]) + run_cmd(".", ["gcloud", "auth", "application-default", "login"]) else: raise dbt.exceptions.RuntimeException(NOT_INSTALLED_MSG) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index 03931aae9..0fc5fc1cc 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -6,10 +6,8 @@ import dbt.exceptions import dbt.clients.agate_helper -from dbt import ui -from dbt.adapters.base import ( - BaseAdapter, available, RelationType, SchemaSearchMap, AdapterConfig -) +from dbt import ui # type: ignore +from dbt.adapters.base import BaseAdapter, available, RelationType, SchemaSearchMap, AdapterConfig from dbt.adapters.bigquery.relation import BigQueryRelation from dbt.adapters.bigquery import BigQueryColumn from dbt.adapters.bigquery import BigQueryConnectionManager @@ -38,9 +36,7 @@ def sql_escape(string): if not isinstance(string, str): - dbt.exceptions.raise_compiler_exception( - f'cannot escape a non-string: {string}' - ) + dbt.exceptions.raise_compiler_exception(f"cannot escape a non-string: {string}") return json.dumps(string)[1:-1] @@ -48,25 +44,24 @@ def sql_escape(string): @dataclass class PartitionConfig(dbtClassMixin): field: str - data_type: str = 'date' - granularity: str = 'day' + data_type: str = "date" + granularity: str = "day" range: Optional[Dict[str, Any]] = None def render(self, alias: Optional[str] = None): column: str = self.field if alias: - column = f'{alias}.{self.field}' + column = f"{alias}.{self.field}" - if self.data_type.lower() == 'int64' or ( - self.data_type.lower() == 'date' and - self.granularity.lower() == 'day' + if self.data_type.lower() == "int64" or ( + self.data_type.lower() == "date" and self.granularity.lower() == "day" ): return column else: - return f'{self.data_type}_trunc({column}, {self.granularity})' + return f"{self.data_type}_trunc({column}, {self.granularity})" @classmethod - def parse(cls, 
raw_partition_by) -> Optional['PartitionConfig']: + def parse(cls, raw_partition_by) -> Optional["PartitionConfig"]: # type: ignore [return] if raw_partition_by is None: return None try: @@ -74,13 +69,11 @@ def parse(cls, raw_partition_by) -> Optional['PartitionConfig']: return cls.from_dict(raw_partition_by) except ValidationError as exc: msg = dbt.exceptions.validator_error_message(exc) - dbt.exceptions.raise_compiler_error( - f'Could not parse partition config: {msg}' - ) + dbt.exceptions.raise_compiler_error(f"Could not parse partition config: {msg}") except TypeError: dbt.exceptions.raise_compiler_error( - f'Invalid partition_by config:\n' - f' Got: {raw_partition_by}\n' + f"Invalid partition_by config:\n" + f" Got: {raw_partition_by}\n" f' Expected a dictionary with "field" and "data_type" keys' ) @@ -91,16 +84,12 @@ class GrantTarget(dbtClassMixin): project: str def render(self): - return f'{self.project}.{self.dataset}' + return f"{self.project}.{self.dataset}" def _stub_relation(*args, **kwargs): return BigQueryRelation.create( - database='', - schema='', - identifier='', - quote_policy={}, - type=BigQueryRelation.Table + database="", schema="", identifier="", quote_policy={}, type=BigQueryRelation.Table ) @@ -121,9 +110,9 @@ class BigqueryConfig(AdapterConfig): class BigQueryAdapter(BaseAdapter): RELATION_TYPES = { - 'TABLE': RelationType.Table, - 'VIEW': RelationType.View, - 'EXTERNAL': RelationType.External + "TABLE": RelationType.Table, + "VIEW": RelationType.View, + "EXTERNAL": RelationType.External, } Relation = BigQueryRelation @@ -138,14 +127,14 @@ class BigQueryAdapter(BaseAdapter): @classmethod def date_function(cls) -> str: - return 'CURRENT_TIMESTAMP()' + return "CURRENT_TIMESTAMP()" @classmethod def is_cancelable(cls) -> bool: return False def drop_relation(self, relation: BigQueryRelation) -> None: - is_cached = self._schema_is_cached(relation.database, relation.schema) + is_cached = self._schema_is_cached(relation.database, relation.schema) # type: ignore[arg-type] if is_cached: self.cache_dropped(relation) @@ -156,7 +145,7 @@ def drop_relation(self, relation: BigQueryRelation) -> None: def truncate_relation(self, relation: BigQueryRelation) -> None: raise dbt.exceptions.NotImplementedException( - '`truncate` is not implemented for this adapter!' + "`truncate` is not implemented for this adapter!" ) def rename_relation( @@ -168,11 +157,13 @@ def rename_relation( from_table_ref = self.get_table_ref_from_relation(from_relation) from_table = client.get_table(from_table_ref) - if from_table.table_type == "VIEW" or \ - from_relation.type == RelationType.View or \ - to_relation.type == RelationType.View: + if ( + from_table.table_type == "VIEW" + or from_relation.type == RelationType.View + or to_relation.type == RelationType.View + ): raise dbt.exceptions.RuntimeException( - 'Renaming of views is not currently supported in BigQuery' + "Renaming of views is not currently supported in BigQuery" ) to_table_ref = self.get_table_ref_from_relation(to_relation) @@ -185,18 +176,16 @@ def rename_relation( def list_schemas(self, database: str) -> List[str]: # the database string we get here is potentially quoted. Strip that off # for the API call. 
- database = database.strip('`') + database = database.strip("`") conn = self.connections.get_thread_connection() client = conn.handle def query_schemas(): # this is similar to how we have to deal with listing tables - all_datasets = client.list_datasets(project=database, - max_results=10000) + all_datasets = client.list_datasets(project=database, max_results=10000) return [ds.dataset_id for ds in all_datasets] - return self.connections._retry_and_handle( - msg='list dataset', conn=conn, fn=query_schemas) + return self.connections._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas) @available.parse(lambda *a, **k: False) def check_schema_exists(self, database: str, schema: str) -> bool: @@ -217,14 +206,10 @@ def check_schema_exists(self, database: str, schema: str) -> bool: return False return True - def get_columns_in_relation( - self, relation: BigQueryRelation - ) -> List[BigQueryColumn]: + def get_columns_in_relation(self, relation: BigQueryRelation) -> List[BigQueryColumn]: try: table = self.connections.get_bq_table( - database=relation.database, - schema=relation.schema, - identifier=relation.identifier + database=relation.database, schema=relation.schema, identifier=relation.identifier ) return self._get_dbt_columns_from_bq_table(table) @@ -232,9 +217,7 @@ def get_columns_in_relation( logger.debug("get_columns_in_relation error: {}".format(e)) return [] - def expand_column_types( - self, goal: BigQueryRelation, current: BigQueryRelation - ) -> None: + def expand_column_types(self, goal: BigQueryRelation, current: BigQueryRelation) -> None: # type: ignore[override] # This is a no-op on BigQuery pass @@ -265,7 +248,8 @@ def list_relations_without_caching( # see: https://github.com/dbt-labs/dbt/issues/726 # TODO: cache the list of relations up front, and then we # won't need to do this - max_results=100000) + max_results=100000, + ) # This will 404 if the dataset does not exist. 
This behavior mirrors # the implementation of list_relations for other adapters @@ -274,20 +258,14 @@ def list_relations_without_caching( except google.api_core.exceptions.NotFound: return [] except google.api_core.exceptions.Forbidden as exc: - logger.debug('list_relations_without_caching error: {}'.format(str(exc))) + logger.debug("list_relations_without_caching error: {}".format(str(exc))) return [] - def get_relation( - self, database: str, schema: str, identifier: str - ) -> BigQueryRelation: + def get_relation(self, database: str, schema: str, identifier: str) -> BigQueryRelation: if self._schema_is_cached(database, schema): # if it's in the cache, use the parent's model of going through # the relations cache and picking out the relation - return super().get_relation( - database=database, - schema=schema, - identifier=identifier - ) + return super().get_relation(database=database, schema=schema, identifier=identifier) try: table = self.connections.get_bq_table(database, schema, identifier) @@ -310,29 +288,23 @@ def drop_schema(self, relation: BigQueryRelation) -> None: @classmethod def quote(cls, identifier: str) -> str: - return '`{}`'.format(identifier) + return "`{}`".format(identifier) @classmethod def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "string" @classmethod - def convert_number_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "float64" if decimals else "int64" @classmethod - def convert_boolean_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "bool" @classmethod - def convert_datetime_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "datetime" @classmethod @@ -346,14 +318,14 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: ### # Implementation details ### - def _make_match_kwargs( - self, database: str, schema: str, identifier: str - ) -> Dict[str, str]: - return filter_null_values({ - 'database': database, - 'identifier': identifier, - 'schema': schema, - }) + def _make_match_kwargs(self, database: str, schema: str, identifier: str) -> Dict[str, str]: + return filter_null_values( + { + "database": database, + "identifier": identifier, + "schema": schema, + } + ) def _get_dbt_columns_from_bq_table(self, table) -> List[BigQueryColumn]: "Translates BQ SchemaField dicts into dbt BigQueryColumn objects" @@ -362,8 +334,7 @@ def _get_dbt_columns_from_bq_table(self, table) -> List[BigQueryColumn]: for col in table.schema: # BigQuery returns type labels that are not valid type specifiers dtype = self.Column.translate_type(col.field_type) - column = self.Column( - col.name, dtype, col.fields, col.mode) + column = self.Column(col.name, dtype, col.fields, col.mode) columns.append(column) return columns @@ -371,8 +342,7 @@ def _get_dbt_columns_from_bq_table(self, table) -> List[BigQueryColumn]: def _agate_to_schema( self, agate_table: agate.Table, column_override: Dict[str, str] ) -> List[SchemaField]: - """Convert agate.Table with column names to a list of bigquery schemas. 
- """ + """Convert agate.Table with column names to a list of bigquery schemas.""" bq_schema = [] for idx, col_name in enumerate(agate_table.column_names): inferred_type = self.convert_agate_type(agate_table, idx) @@ -381,17 +351,14 @@ def _agate_to_schema( return bq_schema def _materialize_as_view(self, model: Dict[str, Any]) -> str: - model_database = model.get('database') - model_schema = model.get('schema') - model_alias = model.get('alias') - model_sql = model.get('compiled_sql') + model_database = model.get("database") + model_schema = model.get("schema") + model_alias = model.get("alias") + model_sql = model.get("compiled_sql") logger.debug("Model SQL ({}):\n{}".format(model_alias, model_sql)) self.connections.create_view( - database=model_database, - schema=model_schema, - table_name=model_alias, - sql=model_sql + database=model_database, schema=model_schema, table_name=model_alias, sql=model_sql ) return "CREATE VIEW" @@ -401,9 +368,9 @@ def _materialize_as_table( model_sql: str, decorator: Optional[str] = None, ) -> str: - model_database = model.get('database') - model_schema = model.get('schema') - model_alias = model.get('alias') + model_database = model.get("database") + model_schema = model.get("schema") + model_alias = model.get("alias") if decorator is None: table_name = model_alias @@ -412,28 +379,25 @@ def _materialize_as_table( logger.debug("Model SQL ({}):\n{}".format(table_name, model_sql)) self.connections.create_table( - database=model_database, - schema=model_schema, - table_name=table_name, - sql=model_sql + database=model_database, schema=model_schema, table_name=table_name, sql=model_sql ) return "CREATE TABLE" - @available.parse(lambda *a, **k: '') + @available.parse(lambda *a, **k: "") def copy_table(self, source, destination, materialization): - if materialization == 'incremental': + if materialization == "incremental": write_disposition = WRITE_APPEND - elif materialization == 'table': + elif materialization == "table": write_disposition = WRITE_TRUNCATE else: dbt.exceptions.raise_compiler_error( 'Copy table materialization must be "copy" or "table", but ' f"config.get('copy_materialization', 'table') was " - f'{materialization}') + f"{materialization}" + ) - self.connections.copy_bq_table( - source, destination, write_disposition) + self.connections.copy_bq_table(source, destination, write_disposition) return "COPY TABLE with materialization: {}".format(materialization) @@ -441,18 +405,16 @@ def copy_table(self, source, destination, materialization): def poll_until_job_completes(cls, job, timeout): retry_count = timeout - while retry_count > 0 and job.state != 'DONE': + while retry_count > 0 and job.state != "DONE": retry_count -= 1 time.sleep(1) job.reload() - if job.state != 'DONE': + if job.state != "DONE": raise dbt.exceptions.RuntimeException("BigQuery Timeout Exceeded") elif job.error_result: - message = '\n'.join( - error['message'].strip() for error in job.errors - ) + message = "\n".join(error["message"].strip() for error in job.errors) raise dbt.exceptions.RuntimeException(message) def _bq_table_to_relation(self, bq_table): @@ -463,13 +425,8 @@ def _bq_table_to_relation(self, bq_table): database=bq_table.project, schema=bq_table.dataset_id, identifier=bq_table.table_id, - quote_policy={ - 'schema': True, - 'identifier': True - }, - type=self.RELATION_TYPES.get( - bq_table.table_type, RelationType.External - ), + quote_policy={"schema": True, "identifier": True}, + type=self.RELATION_TYPES.get(bq_table.table_type, RelationType.External), ) 
     @classmethod
@@ -479,21 +436,19 @@ def warning_on_hooks(hook_type):
         logger.info(warn_msg)
 
     @available
-    def add_query(self, sql, auto_begin=True, bindings=None,
-                  abridge_sql_log=False):
-        if self.nice_connection_name() in ['on-run-start', 'on-run-end']:
+    def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
+        if self.nice_connection_name() in ["on-run-start", "on-run-end"]:
             self.warning_on_hooks(self.nice_connection_name())
         else:
             raise dbt.exceptions.NotImplementedException(
-                '`add_query` is not implemented for this adapter!')
+                "`add_query` is not implemented for this adapter!"
+            )
 
     ###
     # Special bigquery adapter methods
     ###
 
-    def _partitions_match(
-        self, table, conf_partition: Optional[PartitionConfig]
-    ) -> bool:
+    def _partitions_match(self, table, conf_partition: Optional[PartitionConfig]) -> bool:
         """
         Check if the actual and configured partitions for a table are a match.
         BigQuery tables can be replaced if:
@@ -502,23 +457,27 @@ def _partitions_match(
         If there is a mismatch, then the table cannot be replaced directly.
         """
-        is_partitioned = (table.range_partitioning or table.time_partitioning)
+        is_partitioned = table.range_partitioning or table.time_partitioning
 
         if not is_partitioned and not conf_partition:
             return True
         elif conf_partition and table.time_partitioning is not None:
             table_field = table.time_partitioning.field.lower()
             table_granularity = table.partitioning_type.lower()
-            return table_field == conf_partition.field.lower() \
+            return (
+                table_field == conf_partition.field.lower()
                 and table_granularity == conf_partition.granularity.lower()
+            )
         elif conf_partition and table.range_partitioning is not None:
             dest_part = table.range_partitioning
             conf_part = conf_partition.range or {}
-            return dest_part.field == conf_partition.field \
-                and dest_part.range_.start == conf_part.get('start') \
-                and dest_part.range_.end == conf_part.get('end') \
-                and dest_part.range_.interval == conf_part.get('interval')
+            return (
+                dest_part.field == conf_partition.field
+                and dest_part.range_.start == conf_part.get("start")
+                and dest_part.range_.end == conf_part.get("end")
+                and dest_part.range_.interval == conf_part.get("interval")
+            )
         else:
             return False
 
@@ -535,10 +494,7 @@ def _clusters_match(self, table, conf_cluster) -> bool:
 
     @available.parse(lambda *a, **k: True)
     def is_replaceable(
-        self,
-        relation,
-        conf_partition: Optional[PartitionConfig],
-        conf_cluster
+        self, relation, conf_partition: Optional[PartitionConfig], conf_cluster
     ) -> bool:
         """
         Check if a given partition and clustering column spec for a table
@@ -552,22 +508,20 @@ def is_replaceable(
 
         try:
             table = self.connections.get_bq_table(
-                database=relation.database,
-                schema=relation.schema,
-                identifier=relation.identifier
+                database=relation.database, schema=relation.schema, identifier=relation.identifier
             )
         except google.cloud.exceptions.NotFound:
             return True
 
-        return all((
-            self._partitions_match(table, conf_partition),
-            self._clusters_match(table, conf_cluster)
-        ))
+        return all(
+            (
+                self._partitions_match(table, conf_partition),
+                self._clusters_match(table, conf_cluster),
+            )
+        )
 
     @available
-    def parse_partition_by(
-        self, raw_partition_by: Any
-    ) -> Optional[PartitionConfig]:
+    def parse_partition_by(self, raw_partition_by: Any) -> Optional[PartitionConfig]:
         """
         dbt v0.16.0 expects `partition_by` to be a dictionary where previously
         it was a string. Check the type of `partition_by`, raise error
@@ -576,11 +530,9 @@ def parse_partition_by(
         return PartitionConfig.parse(raw_partition_by)
 
     def get_table_ref_from_relation(self, relation):
-        return self.connections.table_ref(
-            relation.database, relation.schema, relation.identifier
-        )
+        return self.connections.table_ref(relation.database, relation.schema, relation.identifier)
 
-    def _update_column_dict(self, bq_column_dict, dbt_columns, parent=''):
+    def _update_column_dict(self, bq_column_dict, dbt_columns, parent=""):
         """
         Helper function to recursively traverse the schema of a table in the
         update_column_descriptions function below.
@@ -589,28 +541,24 @@ def _update_column_dict(self, bq_column_dict, dbt_columns, parent=''):
         function of a SchemaField object.
         """
         if parent:
-            dotted_column_name = '{}.{}'.format(parent, bq_column_dict['name'])
+            dotted_column_name = "{}.{}".format(parent, bq_column_dict["name"])
         else:
-            dotted_column_name = bq_column_dict['name']
+            dotted_column_name = bq_column_dict["name"]
 
         if dotted_column_name in dbt_columns:
             column_config = dbt_columns[dotted_column_name]
-            bq_column_dict['description'] = column_config.get('description')
-            if column_config.get('policy_tags'):
-                bq_column_dict['policyTags'] = {
-                    'names': column_config.get('policy_tags')
-                }
+            bq_column_dict["description"] = column_config.get("description")
+            if column_config.get("policy_tags"):
+                bq_column_dict["policyTags"] = {"names": column_config.get("policy_tags")}
 
         new_fields = []
-        for child_col_dict in bq_column_dict.get('fields', list()):
+        for child_col_dict in bq_column_dict.get("fields", list()):
             new_child_column_dict = self._update_column_dict(
-                child_col_dict,
-                dbt_columns,
-                parent=dotted_column_name
+                child_col_dict, dbt_columns, parent=dotted_column_name
             )
             new_fields.append(new_child_column_dict)
 
-        bq_column_dict['fields'] = new_fields
+        bq_column_dict["fields"] = new_fields
 
         return bq_column_dict
 
@@ -626,14 +574,11 @@ def update_columns(self, relation, columns):
         new_schema = []
         for bq_column in table.schema:
             bq_column_dict = bq_column.to_api_repr()
-            new_bq_column_dict = self._update_column_dict(
-                bq_column_dict,
-                columns
-            )
+            new_bq_column_dict = self._update_column_dict(bq_column_dict, columns)
             new_schema.append(SchemaField.from_api_repr(new_bq_column_dict))
 
         new_table = google.cloud.bigquery.Table(table_ref, schema=new_schema)
-        conn.handle.update_table(new_table, ['schema'])
+        conn.handle.update_table(new_table, ["schema"])
 
     @available.parse_none
     def update_table_description(
@@ -645,13 +590,12 @@ def update_table_description(
         table_ref = self.connections.table_ref(database, schema, identifier)
         table = client.get_table(table_ref)
         table.description = description
-        client.update_table(table, ['description'])
+        client.update_table(table, ["description"])
 
     @available.parse_none
     def alter_table_add_columns(self, relation, columns):
 
-        logger.debug('Adding columns ({}) to table {}".'.format(
-            columns, relation))
+        logger.debug('Adding columns ({}) to table {}".'.format(columns, relation))
 
         conn = self.connections.get_thread_connection()
         client = conn.handle
@@ -663,11 +607,10 @@ def alter_table_add_columns(self, relation, columns):
         new_schema = table.schema + new_columns
 
         new_table = google.cloud.bigquery.Table(table_ref, schema=new_schema)
-        client.update_table(new_table, ['schema'])
+        client.update_table(new_table, ["schema"])
 
     @available.parse_none
-    def load_dataframe(self, database, schema, table_name, agate_table,
-                       column_override):
+    def load_dataframe(self, database, schema, table_name, agate_table, column_override):
         bq_schema = self._agate_to_schema(agate_table, column_override)
         conn = self.connections.get_thread_connection()
         client = conn.handle
@@ -679,43 +622,40 @@ def load_dataframe(self, database, schema, table_name, agate_table,
         load_config.schema = bq_schema
 
         with open(agate_table.original_abspath, "rb") as f:
-            job = client.load_table_from_file(f, table_ref, rewind=True,
-                                              job_config=load_config)
+            job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config)
 
         timeout = self.connections.get_job_execution_timeout_seconds(conn)
         with self.connections.exception_handler("LOAD TABLE"):
             self.poll_until_job_completes(job, timeout)
 
     @available.parse_none
-    def upload_file(self, local_file_path: str, database: str, table_schema: str,
-                    table_name: str, **kwargs) -> None:
+    def upload_file(
+        self, local_file_path: str, database: str, table_schema: str, table_name: str, **kwargs
+    ) -> None:
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
         table_ref = self.connections.table_ref(database, table_schema, table_name)
 
         load_config = google.cloud.bigquery.LoadJobConfig()
-        for k, v in kwargs['kwargs'].items():
+        for k, v in kwargs["kwargs"].items():
             if k == "schema":
                 setattr(load_config, k, json.loads(v))
             else:
                 setattr(load_config, k, v)
 
         with open(local_file_path, "rb") as f:
-            job = client.load_table_from_file(f, table_ref, rewind=True,
-                                              job_config=load_config)
+            job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config)
 
         timeout = self.connections.get_job_execution_timeout_seconds(conn)
         with self.connections.exception_handler("LOAD TABLE"):
             self.poll_until_job_completes(job, timeout)
 
     @classmethod
-    def _catalog_filter_table(
-        cls, table: agate.Table, manifest: Manifest
-    ) -> agate.Table:
-        table = table.rename(column_names={
-            col.name: col.name.replace('__', ':') for col in table.columns
-        })
+    def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table:
+        table = table.rename(
+            column_names={col.name: col.name.replace("__", ":") for col in table.columns}
+        )
         return super()._catalog_filter_table(table, manifest)
 
     def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
@@ -726,13 +666,14 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
         for candidate, schemas in candidates.items():
             database = candidate.database
             if database not in db_schemas:
-                db_schemas[database] = set(self.list_schemas(database))
-            if candidate.schema in db_schemas[database]:
+                db_schemas[database] = set(self.list_schemas(database))  # type: ignore[index]
+            if candidate.schema in db_schemas[database]:  # type: ignore[index]
                 result[candidate] = schemas
             else:
                 logger.debug(
-                    'Skipping catalog for {}.{} - schema does not exist'
-                    .format(database, candidate.schema)
+                    "Skipping catalog for {}.{} - schema does not exist".format(
+                        database, candidate.schema
+                    )
                 )
         return result
 
@@ -742,19 +683,19 @@ def get_common_options(
     ) -> Dict[str, Any]:
         opts = {}
 
-        if (config.get('hours_to_expiration') is not None) and (not temporary):
-            expiration = (
-                'TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL '
-                '{} hour)').format(config.get('hours_to_expiration'))
-            opts['expiration_timestamp'] = expiration
+        if (config.get("hours_to_expiration") is not None) and (not temporary):
+            expiration = ("TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL " "{} hour)").format(
+                config.get("hours_to_expiration")
+            )
+            opts["expiration_timestamp"] = expiration
 
-        if config.persist_relation_docs() and 'description' in node:
-            description = sql_escape(node['description'])
-            opts['description'] = '"""{}"""'.format(description)
+        if config.persist_relation_docs() and "description" in node:  # type: ignore[attr-defined]
+            description = sql_escape(node["description"])
+            opts["description"] = '"""{}"""'.format(description)
 
-        if config.get('labels'):
-            labels = config.get('labels', {})
-            opts['labels'] = list(labels.items())
+        if config.get("labels"):
+            labels = config.get("labels", {})
+            opts["labels"] = list(labels.items())  # type: ignore[assignment]
 
         return opts
 
@@ -764,30 +705,28 @@ def get_table_options(
     ) -> Dict[str, Any]:
         opts = self.get_common_options(config, node, temporary)
 
-        if config.get('kms_key_name') is not None:
-            opts['kms_key_name'] = "'{}'".format(config.get('kms_key_name'))
+        if config.get("kms_key_name") is not None:
+            opts["kms_key_name"] = "'{}'".format(config.get("kms_key_name"))
 
         if temporary:
-            expiration = 'TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 12 hour)'
-            opts['expiration_timestamp'] = expiration
+            expiration = "TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 12 hour)"
+            opts["expiration_timestamp"] = expiration
         else:
             # It doesn't apply the `require_partition_filter` option for a temporary table
             # so that we avoid the error by not specifying a partition with a temporary table
             # in the incremental model.
-            if config.get('require_partition_filter') is not None and \
-                    config.get('partition_by') is not None:
-                opts['require_partition_filter'] = config.get(
-                    'require_partition_filter')
-            if config.get('partition_expiration_days') is not None:
-                opts['partition_expiration_days'] = config.get(
-                    'partition_expiration_days')
+            if (
+                config.get("require_partition_filter") is not None
+                and config.get("partition_by") is not None
+            ):
+                opts["require_partition_filter"] = config.get("require_partition_filter")
+            if config.get("partition_expiration_days") is not None:
+                opts["partition_expiration_days"] = config.get("partition_expiration_days")
 
         return opts
 
     @available.parse(lambda *a, **k: {})
-    def get_view_options(
-        self, config: Dict[str, Any], node: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def get_view_options(self, config: Dict[str, Any], node: Dict[str, Any]) -> Dict[str, Any]:
         opts = self.get_common_options(config, node)
         return opts
 
@@ -804,27 +743,26 @@ def grant_access_to(self, entity, entity_type, role, grant_target_dict):
         dataset_ref = self.connections.dataset_ref(grant_target.project, grant_target.dataset)
 
         dataset = client.get_dataset(dataset_ref)
 
-        if entity_type == 'view':
+        if entity_type == "view":
             entity = self.get_table_ref_from_relation(entity).to_api_repr()
 
         access_entry = AccessEntry(role, entity_type, entity)
         access_entries = dataset.access_entries
 
         if access_entry in access_entries:
-            logger.debug(f"Access entry {access_entry} "
-                         f"already exists in dataset")
+            logger.debug(f"Access entry {access_entry} " f"already exists in dataset")
             return
 
         access_entries.append(AccessEntry(role, entity_type, entity))
         dataset.access_entries = access_entries
-        client.update_dataset(dataset, ['access_entries'])
+        client.update_dataset(dataset, ["access_entries"])
 
-    def get_rows_different_sql(
+    def get_rows_different_sql(  # type: ignore[override]
         self,
         relation_a: BigQueryRelation,
         relation_b: BigQueryRelation,
         column_names: Optional[List[str]] = None,
-        except_operator='EXCEPT DISTINCT'
+        except_operator="EXCEPT DISTINCT",
     ) -> str:
         return super().get_rows_different_sql(
             relation_a=relation_a,
@@ -833,17 +771,18 @@ def get_rows_different_sql(
             except_operator=except_operator,
         )
 
-    def timestamp_add_sql(
-        self, add_to: str, number: int = 1, interval: str = 'hour'
-    ) -> str:
-        return f'timestamp_add({add_to}, interval {number} {interval})'
+    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
+        return f"timestamp_add({add_to}, interval {number} {interval})"
 
     def string_add_sql(
-        self, add_to: str, value: str, location='append',
+        self,
+        add_to: str,
+        value: str,
+        location="append",
     ) -> str:
-        if location == 'append':
+        if location == "append":
             return f"concat({add_to}, '{value}')"
-        elif location == 'prepend':
+        elif location == "prepend":
             return f"concat('{value}', {add_to})"
         else:
             raise dbt.exceptions.RuntimeException(
@@ -852,15 +791,15 @@ def string_add_sql(
 
     # This is used by the test suite
     def run_sql_for_tests(self, sql, fetch, conn=None):
-        """ For the testing framework.
+        """For the testing framework.
         Run an SQL query on a bigquery adapter. No cursors, transactions,
         etc. to worry about"""
 
-        do_fetch = fetch != 'None'
+        do_fetch = fetch != "None"
         _, res = self.execute(sql, fetch=do_fetch)
 
         # convert dataframe to matrix-ish repr
-        if fetch == 'one':
+        if fetch == "one":
             return res[0]
         else:
             return list(res)
diff --git a/dbt/adapters/bigquery/relation.py b/dbt/adapters/bigquery/relation.py
index 08f2c8f06..8156e360d 100644
--- a/dbt/adapters/bigquery/relation.py
+++ b/dbt/adapters/bigquery/relation.py
@@ -1,19 +1,17 @@
 from dataclasses import dataclass
 from typing import Optional
 
-from dbt.adapters.base.relation import (
-    BaseRelation, ComponentName, InformationSchema
-)
+from dbt.adapters.base.relation import BaseRelation, ComponentName, InformationSchema
 
 from dbt.utils import filter_null_values
 from typing import TypeVar
 
-Self = TypeVar('Self', bound='BigQueryRelation')
+Self = TypeVar("Self", bound="BigQueryRelation")
 
 
 @dataclass(frozen=True, eq=False, repr=False)
 class BigQueryRelation(BaseRelation):
-    quote_character: str = '`'
+    quote_character: str = "`"
 
     def matches(
         self,
@@ -21,11 +19,13 @@ def matches(
         schema: Optional[str] = None,
         identifier: Optional[str] = None,
     ) -> bool:
-        search = filter_null_values({
-            ComponentName.Database: database,
-            ComponentName.Schema: schema,
-            ComponentName.Identifier: identifier
-        })
+        search = filter_null_values(
+            {
+                ComponentName.Database: database,
+                ComponentName.Schema: schema,
+                ComponentName.Identifier: identifier,
+            }
+        )
 
         if not search:
             # nothing was passed in
@@ -45,24 +45,22 @@ def project(self):
     def dataset(self):
         return self.schema
 
-    def information_schema(
-        self, identifier: Optional[str] = None
-    ) -> 'BigQueryInformationSchema':
+    def information_schema(self, identifier: Optional[str] = None) -> "BigQueryInformationSchema":
        return BigQueryInformationSchema.from_relation(self, identifier)
 
 
 @dataclass(frozen=True, eq=False, repr=False)
 class BigQueryInformationSchema(InformationSchema):
-    quote_character: str = '`'
+    quote_character: str = "`"
 
     @classmethod
     def get_include_policy(cls, relation, information_schema_view):
         schema = True
-        if information_schema_view in ('SCHEMATA', 'SCHEMATA_OPTIONS', None):
+        if information_schema_view in ("SCHEMATA", "SCHEMATA_OPTIONS", None):
             schema = False
 
         identifier = True
-        if information_schema_view == '__TABLES__':
+        if information_schema_view == "__TABLES__":
             identifier = False
 
         return relation.include_policy.replace(
@@ -71,10 +69,10 @@ def get_include_policy(cls, relation, information_schema_view):
         )
 
     def replace(self, **kwargs):
-        if 'information_schema_view' in kwargs:
-            view = kwargs['information_schema_view']
+        if "information_schema_view" in kwargs:
+            view = kwargs["information_schema_view"]
             # we also need to update the include policy, unless the caller did
             # in which case it's their problem
-            if 'include_policy' not in kwargs:
-                kwargs['include_policy'] = self.get_include_policy(self, view)
+            if "include_policy" not in kwargs:
+                kwargs["include_policy"] = self.get_include_policy(self, view)
 
         return super().replace(**kwargs)
diff --git a/dbt/include/bigquery/__init__.py b/dbt/include/bigquery/__init__.py
index 564a3d1e8..b177e5d49 100644
--- a/dbt/include/bigquery/__init__.py
+++ b/dbt/include/bigquery/__init__.py
@@ -1,2 +1,3 @@
 import os
+
 PACKAGE_PATH = os.path.dirname(__file__)
diff --git a/dbt/include/bigquery/macros/adapters.sql b/dbt/include/bigquery/macros/adapters.sql
index fa30922fe..f5a732d4d 100644
--- a/dbt/include/bigquery/macros/adapters.sql
+++ b/dbt/include/bigquery/macros/adapters.sql
@@ -136,14 +136,14 @@
 {% endmacro %}
 
 {% macro bigquery__alter_relation_add_columns(relation, add_columns) %}
-  
+
   {% set sql -%}
-     
+
      alter {{ relation.type }} {{ relation }}
        {% for column in add_columns %}
          add column {{ column.name }} {{ column.data_type }}{{ ',' if not loop.last }}
        {% endfor %}
-  
+
   {%- endset -%}
 
   {{ return(run_query(sql)) }}
 
@@ -151,17 +151,17 @@
 {% endmacro %}
 
 {% macro bigquery__alter_relation_drop_columns(relation, drop_columns) %}
-  
+
   {% set sql -%}
-     
+
      alter {{ relation.type }} {{ relation }}
        {% for column in drop_columns %}
          drop column {{ column.name }}{{ ',' if not loop.last }}
        {% endfor %}
-  
+
   {%- endset -%}
-  
+
   {{ return(run_query(sql)) }}
 
 {% endmacro %}
 
@@ -198,11 +198,11 @@
 
 {% macro bigquery__test_unique(model, column_name) %}
 
 with dbt_test__target as (
-  
+
   select {{ column_name }} as unique_field
   from {{ model }}
   where {{ column_name }} is not null
-  
+
 )
 
 select
diff --git a/dbt/include/bigquery/macros/materializations/incremental.sql b/dbt/include/bigquery/macros/materializations/incremental.sql
index 56811234f..f9d36aead 100644
--- a/dbt/include/bigquery/macros/materializations/incremental.sql
+++ b/dbt/include/bigquery/macros/materializations/incremental.sql
@@ -6,7 +6,7 @@
       select max({{ partition_by.field }}) from {{ this }}
       where {{ partition_by.field }} is not null
     );
-  
+
   {% endif %}
 
 {% endmacro %}
@@ -66,7 +66,7 @@
     {# have we already created the temp table to check for schema changes? #}
     {% if not tmp_relation_exists %}
      {{ declare_dbt_max_partition(this, partition_by, sql) }}
-     
+
      -- 1. create a temp table
      {{ create_table_as(True, tmp_relation, sql) }}
     {% else %}
@@ -156,12 +156,12 @@
 
   {% if existing_relation is none %}
      {% set build_sql = create_table_as(False, target_relation, sql) %}
-  
+
   {% elif existing_relation.is_view %}
      {#-- There's no way to atomically replace a view with a table on BQ --#}
      {{ adapter.drop_relation(existing_relation) }}
      {% set build_sql = create_table_as(False, target_relation, sql) %}
-  
+
   {% elif full_refresh_mode %}
      {#-- If the partition/cluster config has changed, then we must drop and recreate --#}
      {% if not adapter.is_replaceable(existing_relation, partition_by, cluster_by) %}
@@ -169,7 +169,7 @@
        {{ adapter.drop_relation(existing_relation) }}
      {% endif %}
      {% set build_sql = create_table_as(False, target_relation, sql) %}
-  
+
   {% else %}
    {% set tmp_relation_exists = false %}
    {% if on_schema_change != 'ignore' %} {# Check first, since otherwise we may not build a temp table #}
diff --git a/dev_requirements.txt b/dev_requirements.txt
index e33fb051e..ff7b6522b 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -3,6 +3,7 @@
 git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
 git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter
 
+black==22.3.0
 bumpversion
 flake8
 flaky
@@ -10,6 +11,7 @@ freezegun==1.1.0
 ipdb
 mypy==0.782
 pip-tools
+pre-commit
 pytest
 pytest-dotenv
 pytest-logbook
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 000000000..51fada1b1
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,3 @@
+[mypy]
+mypy_path = ./third-party-stubs
+namespace_packages = True
diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh
index 65e6dbc97..3c3808399 100755
--- a/scripts/build-dist.sh
+++ b/scripts/build-dist.sh
@@ -1,4 +1,4 @@
-#!/bin/bash 
+#!/bin/bash
 
 set -eo pipefail
 
diff --git a/setup.py b/setup.py
index 42e0d662f..3d1a95766 100644
--- a/setup.py
+++ b/setup.py
@@ -5,41 +5,39 @@
 
 # require python 3.7 or newer
 if sys.version_info < (3, 7):
-    print('Error: dbt does not support this version of Python.')
-    print('Please upgrade to Python 3.7 or higher.')
+    print("Error: dbt does not support this version of Python.")
+    print("Please upgrade to Python 3.7 or higher.")
     sys.exit(1)
 
 
 # require version of setuptools that supports find_namespace_packages
 from setuptools import setup
+
 try:
     from setuptools import find_namespace_packages
 except ImportError:
     # the user has a downlevel version of setuptools.
-    print('Error: dbt requires setuptools v40.1.0 or higher.')
-    print('Please upgrade setuptools with "pip install --upgrade setuptools" '
-          'and try again')
+    print("Error: dbt requires setuptools v40.1.0 or higher.")
+    print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again")
     sys.exit(1)
 
 
 # pull long description from README
 this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md')) as f:
+with open(os.path.join(this_directory, "README.md")) as f:
     long_description = f.read()
 
 
 # get this package's version from dbt/adapters//__version__.py
 def _get_plugin_version_dict():
-    _version_path = os.path.join(
-        this_directory, 'dbt', 'adapters', 'bigquery', '__version__.py'
-    )
-    _semver = r'''(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'''
-    _pre = r'''((?P<prekind>a|b|rc)(?P<num>\d+))?'''
-    _version_pattern = fr'''version\s*=\s*["']{_semver}{_pre}["']'''
+    _version_path = os.path.join(this_directory, "dbt", "adapters", "bigquery", "__version__.py")
+    _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
+    _pre = r"""((?P<prekind>a|b|rc)(?P<num>\d+))?"""
+    _version_pattern = rf"""version\s*=\s*["']{_semver}{_pre}["']"""
     with open(_version_path) as f:
         match = re.search(_version_pattern, f.read().strip())
         if match is None:
-            raise ValueError(f'invalid version at {_version_path}')
+            raise ValueError(f"invalid version at {_version_path}")
         return match.groupdict()
 
 
@@ -47,7 +45,7 @@ def _get_plugin_version_dict():
 def _get_dbt_core_version():
     parts = _get_plugin_version_dict()
     minor = "{major}.{minor}.0".format(**parts)
-    pre = (parts["prekind"]+"1" if parts["prekind"] else "")
+    pre = parts["prekind"] + "1" if parts["prekind"] else ""
     return f"{minor}{pre}"
 
 
@@ -61,33 +59,30 @@ def _get_dbt_core_version():
     version=package_version,
     description=description,
     long_description=long_description,
-    long_description_content_type='text/markdown',
+    long_description_content_type="text/markdown",
     author="dbt Labs",
     author_email="info@dbtlabs.com",
     url="https://github.com/dbt-labs/dbt-bigquery",
-    packages=find_namespace_packages(include=['dbt', 'dbt.*']),
+    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        'dbt-core~={}'.format(dbt_core_version),
-        'protobuf>=3.13.0,<4',
-        'google-cloud-core>=1.3.0,<3',
-        'google-cloud-bigquery>=1.25.0,<3',
-        'google-api-core>=1.16.0,<3',
-        'googleapis-common-protos>=1.6.0,<2',
+        "dbt-core~={}".format(dbt_core_version),
+        "protobuf>=3.13.0,<4",
+        "google-cloud-core>=1.3.0,<3",
+        "google-cloud-bigquery>=1.25.0,<3",
+        "google-api-core>=1.16.0,<3",
+        "googleapis-common-protos>=1.6.0,<2",
     ],
     zip_safe=False,
     classifiers=[
-        'Development Status :: 5 - Production/Stable',
-
-        'License :: OSI Approved :: Apache Software License',
-
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: POSIX :: Linux',
-
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
+        "Development Status :: 5 - Production/Stable",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     python_requires=">=3.7",
 )
diff --git a/test.env.example b/test.env.example
index a69a35fe7..031968c60 100644
--- a/test.env.example
+++ b/test.env.example
@@ -1,3 +1,4 @@
 BIGQUERY_TEST_DATABASE=
 BIGQUERY_TEST_ALT_DATABASE=
-BIGQUERY_TEST_SERVICE_ACCOUNT_JSON='{}'
\ No newline at end of file
+BIGQUERY_TEST_NO_ACCESS_DATABASE=
+BIGQUERY_TEST_SERVICE_ACCOUNT_JSON='{}'
diff --git a/tox.ini b/tox.ini
index e73a436b2..4b23d97e2 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,15 +1,6 @@
 [tox]
 skipsdist = True
-envlist = py37,py38,py39,flake8
-
-[testenv:flake8]
-description = flake8 code checks
-basepython = python3.8
-skip_install = true
-commands = flake8 --select=E,W,F --ignore=W504,E741 --max-line-length 99 \
-  dbt
-deps =
-  -rdev_requirements.txt
+envlist = py37,py38,py39
 
 [testenv:{unit,py37,py38,py39,py}]
 description = unit testing