diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index 045a49b4f6e778..6410440739c13b 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -99,9 +99,7 @@ labelPRBasedOnFilePath: - providers/celery/** provider:cloudant: - - providers/src/airflow/providers/cloudant/**/* - - docs/apache-airflow-providers-cloudant/**/* - - providers/tests/cloudant/**/* + - providers/cloudant/** provider:cncf-kubernetes: - airflow/example_dags/example_kubernetes_executor.py diff --git a/airflow/new_provider.yaml.schema.json b/airflow/new_provider.yaml.schema.json index 83f98abb038cb9..4b5e16cedc0eb1 100644 --- a/airflow/new_provider.yaml.schema.json +++ b/airflow/new_provider.yaml.schema.json @@ -31,6 +31,13 @@ "removed" ] }, + "excluded-python-versions": { + "description": "List of python versions excluded for that provider", + "type": "array", + "items": { + "type": "string" + } + }, "integrations": { "description": "List of integrations supported by the provider.", "type": "array", diff --git a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py index 6b705148c94cba..98936310c3c379 100644 --- a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py +++ b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py @@ -1231,6 +1231,7 @@ def _regenerate_pyproject_toml(context: dict[str, Any], provider_details: Provid trim_blocks=True, keep_trailing_newline=True, ) + get_pyproject_toml_path.write_text(get_pyproject_toml_content) get_console().print( f"[info]Generated {get_pyproject_toml_path} for the {provider_details.provider_id} provider\n" diff --git a/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 b/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 index 5da149fa0d542d..62560249e23315 100644 --- a/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 +++ b/dev/breeze/src/airflow_breeze/templates/pyproject_TEMPLATE.toml.jinja2 @@ -68,7 +68,7 @@ classifiers = [ {% endfor %} "Topic :: System :: Monitoring", ] -requires-python = "~=3.9" +requires-python = "{{ REQUIRES_PYTHON }}" # The dependencies should be modified in place in the generated file # Any change in the dependencies is preserved when the file is regenerated diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py b/dev/breeze/src/airflow_breeze/utils/packages.py index fec5ed898d6975..022f093871e371 100644 --- a/dev/breeze/src/airflow_breeze/utils/packages.py +++ b/dev/breeze/src/airflow_breeze/utils/packages.py @@ -32,6 +32,7 @@ from airflow_breeze.global_constants import ( ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS, + DEFAULT_PYTHON_MAJOR_MINOR_VERSION, PROVIDER_DEPENDENCIES, PROVIDER_RUNTIME_DATA_SCHEMA_PATH, REGULAR_DOC_PACKAGES, @@ -788,6 +789,12 @@ def get_provider_jinja_context( p for p in ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS if p not in provider_details.excluded_python_versions ] cross_providers_dependencies = get_cross_provider_dependent_packages(provider_package_id=provider_id) + + # Most providers require the same python versions, but some may have exclusions + requires_python_version: str = f"~={DEFAULT_PYTHON_MAJOR_MINOR_VERSION}" + for excluded_python_version in provider_details.excluded_python_versions: + requires_python_version += f",!={excluded_python_version}" + context: dict[str, Any] = { "PROVIDER_ID": provider_details.provider_id, "PACKAGE_PIP_NAME": get_pip_package_name(provider_details.provider_id), @@ 
-825,6 +832,7 @@ def get_provider_jinja_context( "PIP_REQUIREMENTS_TABLE_RST": convert_pip_requirements_to_table( get_provider_requirements(provider_id), markdown=False ), + "REQUIRES_PYTHON": requires_python_version, } return context diff --git a/docs/.gitignore b/docs/.gitignore index 2125d28bc1c316..9d9b64168546f6 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -18,6 +18,7 @@ apache-airflow-providers-apprise apache-airflow-providers-asana apache-airflow-providers-atlassian-jira apache-airflow-providers-celery +apache-airflow-providers-cloudant apache-airflow-providers-cohere apache-airflow-providers-common-compat apache-airflow-providers-common-io diff --git a/docs/apache-airflow-providers-cloudant/changelog.rst b/docs/apache-airflow-providers-cloudant/changelog.rst deleted file mode 100644 index d969e082c17b2e..00000000000000 --- a/docs/apache-airflow-providers-cloudant/changelog.rst +++ /dev/null @@ -1,25 +0,0 @@ - - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE - OVERWRITTEN WHEN PREPARING PACKAGES. - - .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE - `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY - -.. include:: ../../providers/src/airflow/providers/cloudant/CHANGELOG.rst diff --git a/docs/apache-airflow/howto/variable.rst b/docs/apache-airflow/howto/variable.rst index b4b395dd63c2cc..5e0017fb1c09b9 100644 --- a/docs/apache-airflow/howto/variable.rst +++ b/docs/apache-airflow/howto/variable.rst @@ -37,7 +37,7 @@ Storing Variables in Environment Variables Airflow Variables can also be created and managed using Environment Variables. The environment variable naming convention is :envvar:`AIRFLOW_VAR_{VARIABLE_NAME}`, all uppercase. -So if your variable key is ``FOO`` then the variable name should be ``AIRFLOW_VAR_FOO``. +So if your variable key is ``foo`` then the variable name should be ``AIRFLOW_VAR_FOO``. For example, diff --git a/docs/apache-airflow/start.rst b/docs/apache-airflow/start.rst index ea8d624d74d9b7..ff51325d0c842a 100644 --- a/docs/apache-airflow/start.rst +++ b/docs/apache-airflow/start.rst @@ -24,7 +24,7 @@ This quick start guide will help you bootstrap an Airflow standalone instance on .. note:: - Successful installation requires a Python 3 environment. Starting with Airflow 2.7.0, Airflow supports Python 3.9, 3.10, 3.11 and 3.12. + Successful installation requires a Python 3 environment. Starting with Airflow 2.7.0, Airflow supports Python 3.9, 3.10, 3.11, and 3.12. Only ``pip`` installation is currently officially supported. 
@@ -44,7 +44,27 @@ This quick start guide will help you bootstrap an Airflow standalone instance on The installation of Airflow is straightforward if you follow the instructions below. Airflow uses constraint files to enable reproducible installation, so using ``pip`` and constraint files is recommended. -1. Set Airflow Home (optional): +1. **(Recommended) Create and Activate a Virtual Environment**: + + To avoid issues such as the ``externally-managed-environment`` error, particularly on modern Linux distributions like Ubuntu 22.04+ and Debian 12+, it is highly recommended to install Airflow inside a Python virtual environment. This approach prevents conflicts with system-level Python packages and ensures smooth installation. + + For more details on this error, see the Python Packaging Authority's explanation in the `PEP 668 documentation <https://peps.python.org/pep-0668/>`_. + + .. code-block:: bash + + # Create a virtual environment in your desired directory + python3 -m venv airflow_venv + + # Activate the virtual environment + source airflow_venv/bin/activate + + # Upgrade pip within the virtual environment + pip install --upgrade pip + + # Optional: Deactivate the virtual environment when done + deactivate + +2. **Set Airflow Home (optional)**: Airflow requires a home directory, and uses ``~/airflow`` by default, but you can set a different location if you prefer. The ``AIRFLOW_HOME`` environment variable is used to inform Airflow of the desired location. This step of setting the environment variable should be done before installing Airflow so that the installation process knows where to store the necessary files. @@ -52,7 +72,8 @@ constraint files to enable reproducible installation, so using ``pip`` and const export AIRFLOW_HOME=~/airflow -2. Install Airflow using the constraints file, which is determined based on the URL we pass: + +3. Install Airflow using the constraints file, which is determined based on the URL we pass: .. code-block:: bash :substitutions: @@ -69,7 +90,7 @@ constraint files to enable reproducible installation, so using ``pip`` and const pip install "apache-airflow==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" -3. Run Airflow Standalone: +4. Run Airflow Standalone: The ``airflow standalone`` command initializes the database, creates a user, and starts all components. @@ -77,7 +98,7 @@ constraint files to enable reproducible installation, so using ``pip`` and const airflow standalone -4. Access the Airflow UI: +5. Access the Airflow UI: Visit ``localhost:8080`` in your browser and log in with the admin account details shown in the terminal. Enable the ``example_bash_operator`` DAG in the home page. diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a67e67bfd14333..787ced4abb28f4 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -352,7 +352,7 @@ "cloudant": { "deps": [ "apache-airflow>=2.9.0", - "ibmcloudant==0.9.1 ; python_version >= \"3.10\"" + "ibmcloudant==0.9.1;python_version>=\"3.10\"" ], "devel-deps": [], "plugins": [], diff --git a/providers/cloudant/README.rst b/providers/cloudant/README.rst new file mode 100644 index 00000000000000..84bef64ed2376e --- /dev/null +++ b/providers/cloudant/README.rst @@ -0,0 +1,62 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership.
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_README_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +Package ``apache-airflow-providers-cloudant`` + +Release: ``4.1.0`` + + +`IBM Cloudant <https://www.ibm.com/cloud/cloudant>`__ + + +Provider package +---------------- + +This is a provider package for ``cloudant`` provider. All classes for this provider package +are in ``airflow.providers.cloudant`` python package. + +You can find package information and changelog for the provider +in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0/>`_. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-cloudant`` + +The package supports the following python versions: 3.10,3.11,3.12 + +Requirements +------------ + +================== ===================================== +PIP package Version required +================== ===================================== +``apache-airflow`` ``>=2.9.0`` +``ibmcloudant`` ``==0.9.1; python_version >= "3.10"`` +================== ===================================== + +The changelog for the provider package can be found in the +`changelog <https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0/changelog.html>`_.
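The ``ibmcloudant`` pin in the table above is gated by a PEP 508 environment marker rather than dropped outright, so a single published metadata file stays valid on every interpreter. A minimal sketch of how such a marker evaluates, assuming the ``packaging`` library is available (the version strings are illustrative):

.. code-block:: python

    from packaging.requirements import Requirement

    # The gated dependency exactly as it appears in the table above.
    req = Requirement('ibmcloudant==0.9.1; python_version >= "3.10"')

    # Evaluate the marker against a few hypothetical interpreters; the
    # environment dict overrides just the keys it names.
    for python_version in ("3.9", "3.10", "3.12"):
        keep = req.marker.evaluate({"python_version": python_version})
        print(f"python {python_version}: install ibmcloudant -> {keep}")
    # python 3.9: install ibmcloudant -> False
    # python 3.10: install ibmcloudant -> True
    # python 3.12: install ibmcloudant -> True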
diff --git a/providers/src/airflow/providers/cloudant/.latest-doc-only-change.txt b/providers/cloudant/docs/.latest-doc-only-change.txt similarity index 100% rename from providers/src/airflow/providers/cloudant/.latest-doc-only-change.txt rename to providers/cloudant/docs/.latest-doc-only-change.txt diff --git a/providers/src/airflow/providers/cloudant/CHANGELOG.rst b/providers/cloudant/docs/changelog.rst similarity index 100% rename from providers/src/airflow/providers/cloudant/CHANGELOG.rst rename to providers/cloudant/docs/changelog.rst diff --git a/docs/apache-airflow-providers-cloudant/commits.rst b/providers/cloudant/docs/commits.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/commits.rst rename to providers/cloudant/docs/commits.rst diff --git a/docs/apache-airflow-providers-cloudant/index.rst b/providers/cloudant/docs/index.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/index.rst rename to providers/cloudant/docs/index.rst diff --git a/docs/apache-airflow-providers-cloudant/installing-providers-from-sources.rst b/providers/cloudant/docs/installing-providers-from-sources.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/installing-providers-from-sources.rst rename to providers/cloudant/docs/installing-providers-from-sources.rst diff --git a/docs/integration-logos/cloudant/Cloudant.png b/providers/cloudant/docs/integration-logos/Cloudant.png similarity index 100% rename from docs/integration-logos/cloudant/Cloudant.png rename to providers/cloudant/docs/integration-logos/Cloudant.png diff --git a/docs/apache-airflow-providers-cloudant/security.rst b/providers/cloudant/docs/security.rst similarity index 100% rename from docs/apache-airflow-providers-cloudant/security.rst rename to providers/cloudant/docs/security.rst diff --git a/providers/src/airflow/providers/cloudant/provider.yaml b/providers/cloudant/provider.yaml similarity index 86% rename from providers/src/airflow/providers/cloudant/provider.yaml rename to providers/cloudant/provider.yaml index dbbabf157a88ad..8a4944a2f6278c 100644 --- a/providers/src/airflow/providers/cloudant/provider.yaml +++ b/providers/cloudant/provider.yaml @@ -49,12 +49,6 @@ versions: - 1.0.1 - 1.0.0 -dependencies: - - apache-airflow>=2.9.0 - # Even though 3.9 is excluded below, we need to make this python_version aware so that `uv` can generate a - # full lock file when building lock file from provider sources - - 'ibmcloudant==0.9.1 ; python_version >= "3.10"' - excluded-python-versions: # ibmcloudant transitively brings in urllib3 2.x, but the snowflake provider has a dependency that pins # urllib3 to 1.x on Python 3.9; thus we exclude those Python versions from taking the update @@ -65,7 +59,7 @@ excluded-python-versions: integrations: - integration-name: IBM Cloudant external-doc-url: https://www.ibm.com/cloud/cloudant - logo: /integration-logos/cloudant/Cloudant.png + logo: /docs/integration-logos/Cloudant.png tags: [service] hooks: diff --git a/providers/cloudant/pyproject.toml b/providers/cloudant/pyproject.toml new file mode 100644 index 00000000000000..61f2ce3c0ead79 --- /dev/null +++ b/providers/cloudant/pyproject.toml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.10.1"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-cloudant" +version = "4.1.0" +description = "Provider package apache-airflow-providers-cloudant for Apache Airflow" +readme = "README.rst" +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "cloudant", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: System :: Monitoring", +] +requires-python = "~=3.9,!=3.9" + +# The dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +dependencies = [ + "apache-airflow>=2.9.0", + # Even though 3.9 is excluded below, we need to make this python_version aware so that `uv` can generate a + # full lock file when building lock file from provider sources + "ibmcloudant==0.9.1;python_version>=\"3.10\"", +] + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-cloudant/4.1.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Twitter" = "https://x.com/ApacheAirflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.cloudant.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.cloudant" + +[tool.pytest.ini_options] +ignore = "tests/system/" diff --git a/providers/cloudant/src/airflow/providers/cloudant/LICENSE b/providers/cloudant/src/airflow/providers/cloudant/LICENSE new file mode 100644 index 00000000000000..11069edd79019f --- /dev/null +++ b/providers/cloudant/src/airflow/providers/cloudant/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
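For reference, the ``requires-python`` value in the generated ``pyproject.toml`` earlier in this diff is assembled by the ``get_provider_jinja_context`` change above: ``~=`` plus the default Python version, then one ``!=`` clause per excluded version. A small sketch of how the resulting specifier behaves, assuming the ``packaging`` library; note that under PEP 440 a bare ``!=3.9`` matches only 3.9/3.9.0 exactly, so a patch release such as 3.9.18 still satisfies the specifier (``!=3.9.*`` would exclude the whole series):

.. code-block:: python

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    # Rebuild the specifier the way the breeze helper does.
    default = "3.9"     # DEFAULT_PYTHON_MAJOR_MINOR_VERSION at the time of this diff
    excluded = ["3.9"]  # the provider's excluded-python-versions
    spec = SpecifierSet(f"~={default}" + "".join(f",!={v}" for v in excluded))

    for candidate in ("3.8", "3.9", "3.9.18", "3.10", "3.12"):
        print(candidate, Version(candidate) in spec)
    # 3.8 False    (below ~=3.9)
    # 3.9 False    (excluded by !=3.9)
    # 3.9.18 True  (a bare !=3.9 does not match 3.9.x)
    # 3.10 True
    # 3.12 True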
diff --git a/providers/src/airflow/providers/cloudant/__init__.py b/providers/cloudant/src/airflow/providers/cloudant/__init__.py similarity index 100% rename from providers/src/airflow/providers/cloudant/__init__.py rename to providers/cloudant/src/airflow/providers/cloudant/__init__.py diff --git a/providers/src/airflow/providers/cloudant/cloudant_fake.py b/providers/cloudant/src/airflow/providers/cloudant/cloudant_fake.py similarity index 100% rename from providers/src/airflow/providers/cloudant/cloudant_fake.py rename to providers/cloudant/src/airflow/providers/cloudant/cloudant_fake.py diff --git a/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py b/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py new file mode 100644 index 00000000000000..b8935323687806 --- /dev/null +++ b/providers/cloudant/src/airflow/providers/cloudant/get_provider_info.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! 
+# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-cloudant", + "name": "IBM Cloudant", + "description": "`IBM Cloudant <https://www.ibm.com/cloud/cloudant>`__\n", + "state": "ready", + "source-date-epoch": 1734529058, + "versions": [ + "4.1.0", + "4.0.3", + "4.0.2", + "4.0.1", + "4.0.0", + "3.6.0", + "3.5.2", + "3.5.1", + "3.5.0", + "3.4.1", + "3.4.0", + "3.3.0", + "3.2.1", + "3.2.0", + "3.1.0", + "3.0.0", + "2.0.4", + "2.0.3", + "2.0.2", + "2.0.1", + "2.0.0", + "1.0.1", + "1.0.0", + ], + "excluded-python-versions": ["3.9"], + "integrations": [ + { + "integration-name": "IBM Cloudant", + "external-doc-url": "https://www.ibm.com/cloud/cloudant", + "logo": "/docs/integration-logos/Cloudant.png", + "tags": ["service"], + } + ], + "hooks": [ + { + "integration-name": "IBM Cloudant", + "python-modules": ["airflow.providers.cloudant.hooks.cloudant"], + } + ], + "connection-types": [ + { + "hook-class-name": "airflow.providers.cloudant.hooks.cloudant.CloudantHook", + "connection-type": "cloudant", + } + ], + "dependencies": ["apache-airflow>=2.9.0", 'ibmcloudant==0.9.1;python_version>="3.10"'], + } diff --git a/providers/src/airflow/providers/cloudant/hooks/__init__.py b/providers/cloudant/src/airflow/providers/cloudant/hooks/__init__.py similarity index 100% rename from providers/src/airflow/providers/cloudant/hooks/__init__.py rename to providers/cloudant/src/airflow/providers/cloudant/hooks/__init__.py diff --git a/providers/src/airflow/providers/cloudant/hooks/cloudant.py b/providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py similarity index 100% rename from providers/src/airflow/providers/cloudant/hooks/cloudant.py rename to providers/cloudant/src/airflow/providers/cloudant/hooks/cloudant.py diff --git a/providers/cloudant/tests/conftest.py b/providers/cloudant/tests/conftest.py new file mode 100644 index 00000000000000..068fe6bbf5ae9a --- /dev/null +++ b/providers/cloudant/tests/conftest.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+from __future__ import annotations + +import pathlib + +import pytest + +pytest_plugins = "tests_common.pytest_plugin" + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config: pytest.Config) -> None: + deprecations_ignore_path = pathlib.Path(__file__).parent.joinpath("deprecations_ignore.yml") + dep_path = [deprecations_ignore_path] if deprecations_ignore_path.exists() else [] + config.inicfg["airflow_deprecations_ignore"] = ( + config.inicfg.get("airflow_deprecations_ignore", []) + dep_path # type: ignore[assignment,operator] + ) diff --git a/providers/cloudant/tests/provider_tests/__init__.py b/providers/cloudant/tests/provider_tests/__init__.py new file mode 100644 index 00000000000000..e8fd22856438c4 --- /dev/null +++ b/providers/cloudant/tests/provider_tests/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/providers/tests/cloudant/__init__.py b/providers/cloudant/tests/provider_tests/cloudant/__init__.py similarity index 100% rename from providers/tests/cloudant/__init__.py rename to providers/cloudant/tests/provider_tests/cloudant/__init__.py diff --git a/providers/tests/cloudant/hooks/__init__.py b/providers/cloudant/tests/provider_tests/cloudant/hooks/__init__.py similarity index 100% rename from providers/tests/cloudant/hooks/__init__.py rename to providers/cloudant/tests/provider_tests/cloudant/hooks/__init__.py diff --git a/providers/tests/cloudant/hooks/test_cloudant.py b/providers/cloudant/tests/provider_tests/cloudant/hooks/test_cloudant.py similarity index 100% rename from providers/tests/cloudant/hooks/test_cloudant.py rename to providers/cloudant/tests/provider_tests/cloudant/hooks/test_cloudant.py diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index ab1bc433d94a45..582e4abdb9e129 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from collections.abc import Iterable, Mapping from functools import cached_property from typing import TYPE_CHECKING, Any from urllib import parse @@ -42,6 +43,73 @@ def connect( return ESConnection(host, port, user, password, scheme, **kwargs) +class ElasticsearchSQLCursor: + """A PEP 249-like Cursor class for Elasticsearch SQL API""" + + def __init__(self, es: Elasticsearch, **kwargs): + self.es = es + self.body = { + "fetch_size": kwargs.get("fetch_size", 1000), + "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False), + } + self._response: ObjectApiResponse | None = None + + @property + def response(self) -> ObjectApiResponse: + return self._response or {} # type: ignore + + @response.setter + def response(self, value): + self._response = value + + @property + def cursor(self): + return self.response.get("cursor") + + @property + def rows(self): + return self.response.get("rows", []) + + @property + def rowcount(self) -> int: + return len(self.rows) + + @property + def description(self) -> list[tuple]: + return [(column["name"], column["type"]) for column in self.response.get("columns", [])] + + def execute( + self, statement: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: + self.body["query"] = statement + if params: + self.body["params"] = params + self.response = self.es.sql.query(body=self.body) + if self.cursor: + self.body["cursor"] = self.cursor + else: + self.body.pop("cursor", None) + return self.response + + def fetchone(self): + if self.rows: + return self.rows[0] + return None + + def fetchmany(self, size: int | None = None): + raise NotImplementedError() + + def fetchall(self): + results = self.rows + while self.cursor: + self.execute(statement=self.body["query"]) + results.extend(self.rows) + return results + + def close(self): + self._response = None + + class ESConnection: """wrapper class for elasticsearch.Elasticsearch.""" @@ -67,9 +135,19 @@ def __init__( else: self.es = Elasticsearch(self.url, **self.kwargs) - def execute_sql(self, query: str) -> ObjectApiResponse: - sql_query = {"query": query} - return self.es.sql.query(body=sql_query) + def cursor(self) -> ElasticsearchSQLCursor: + return ElasticsearchSQLCursor(self.es, **self.kwargs) + + def close(self): + self.es.close() + + def commit(self): + pass + + def execute_sql( + self, query: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: + return self.cursor().execute(query, params) class ElasticsearchSQLHook(DbApiHook): @@ -84,13 +162,13 @@ class ElasticsearchSQLHook(DbApiHook): conn_name_attr = "elasticsearch_conn_id" default_conn_name = "elasticsearch_default" + connector = ESConnection conn_type = "elasticsearch" hook_name = "Elasticsearch" def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = schema - self._connection = connection def get_conn(self) -> ESConnection: """Return an elasticsearch connection object.""" @@ -104,11 +182,10 @@ def get_conn(self) -> ESConnection: "scheme": conn.schema or "http", } - if conn.extra_dejson.get("http_compress", False): - conn_args["http_compress"] = bool(["http_compress"]) + conn_args.update(conn.extra_dejson) - if conn.extra_dejson.get("timeout", False): - conn_args["timeout"] = conn.extra_dejson["timeout"] + if conn_args.get("http_compress", False): + conn_args["http_compress"] = bool(conn_args["http_compress"]) return connect(**conn_args) diff --git 
a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index ea34f2532de418..953e7dd50ef725 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -20,15 +20,42 @@ from unittest import mock from unittest.mock import MagicMock +import pytest from elasticsearch import Elasticsearch +from elasticsearch._sync.client import SqlClient +from kgb import SpyAgency from airflow.models import Connection +from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.elasticsearch.hooks.elasticsearch import ( ElasticsearchPythonHook, + ElasticsearchSQLCursor, ElasticsearchSQLHook, ESConnection, ) +ROWS = [ + [1, "Stallone", "Sylvester", "78"], + [2, "Statham", "Jason", "57"], + [3, "Li", "Jet", "61"], + [4, "Lundgren", "Dolph", "66"], + [5, "Norris", "Chuck", "84"], +] +RESPONSE_WITHOUT_CURSOR = { + "columns": [ + {"name": "index", "type": "long"}, + {"name": "name", "type": "text"}, + {"name": "firstname", "type": "text"}, + {"name": "age", "type": "long"}, + ], + "rows": ROWS, +} +RESPONSE = {**RESPONSE_WITHOUT_CURSOR, **{"cursor": "e7f8QwXUruW2mIebzudH4BwAA//8DAA=="}} +RESPONSES = [ + RESPONSE, + RESPONSE_WITHOUT_CURSOR, +] + class TestElasticsearchSQLHookConn: def setup_method(self): @@ -48,10 +75,68 @@ def test_get_conn(self, mock_connect): mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None) +class TestElasticsearchSQLCursor: + def setup_method(self): + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = RESPONSES + self.es = MagicMock(sql=sql, spec=Elasticsearch) + + def test_execute(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + + assert cursor.execute("SELECT * FROM hollywood.actors") == RESPONSE + + def test_rowcount(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.rowcount == len(ROWS) + + def test_description(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.description == [ + ("index", "long"), + ("name", "text"), + ("firstname", "text"), + ("age", "long"), + ] + + def test_fetchone(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.fetchone() == ROWS[0] + + def test_fetchmany(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + with pytest.raises(NotImplementedError): + cursor.fetchmany() + + def test_fetchall(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + records = cursor.fetchall() + + assert len(records) == 10 + assert records == ROWS + + class TestElasticsearchSQLHook: def setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = mock.MagicMock() + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = RESPONSES + es = MagicMock(sql=sql, spec=Elasticsearch) + self.cur = ElasticsearchSQLCursor(es=es, options={}) + self.spy_agency = SpyAgency() + self.spy_agency.spy_on(self.cur.close, call_original=True) + self.spy_agency.spy_on(self.cur.execute, call_original=True) + self.spy_agency.spy_on(self.cur.fetchall, 
call_original=True) + self.conn = MagicMock(spec=ESConnection) self.conn.cursor.return_value = self.cur conn = self.conn @@ -64,55 +149,60 @@ def get_conn(self): self.db_hook = UnitTestElasticsearchSQLHook() def test_get_first_record(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchone.return_value = result_sets[0] + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_first(statement) == ROWS[0] - assert result_sets[0] == self.db_hook.get_first(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_records(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_records(statement) == ROWS - assert result_sets == self.db_hook.get_records(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_pandas_df(self): - statement = "SQL" - column = "col" - result_sets = [("row1",), ("row2",)] - self.cur.description = [(column,)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" df = self.db_hook.get_pandas_df(statement) - assert column == df.columns[0] + assert list(df.columns) == ["index", "name", "firstname", "age"] + assert df.values.tolist() == ROWS + + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) + + def test_run(self): + statement = "SELECT * FROM hollywood.actors" - assert result_sets[0][0] == df.values.tolist()[0][0] - assert result_sets[1][0] == df.values.tolist()[1][0] + assert self.db_hook.run(statement, handler=fetch_all_handler) == ROWS - self.cur.execute.assert_called_once_with(statement) + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch") def test_execute_sql_query(self, mock_es): mock_es_sql_client = MagicMock() - mock_es_sql_client.query.return_value = { - "columns": [{"name": "id"}, {"name": "first_name"}], - "rows": [[1, "John"], [2, "Jane"]], - } + mock_es_sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR mock_es.return_value.sql = mock_es_sql_client es_connection = ESConnection(host="localhost", port=9200) - response = es_connection.execute_sql("SELECT * FROM index1") - mock_es_sql_client.query.assert_called_once_with(body={"query": "SELECT * FROM index1"}) - - assert response["rows"] == [[1, "John"], [2, "Jane"]] - assert response["columns"] == [{"name": "id"}, {"name": "first_name"}] + response = es_connection.execute_sql("SELECT * FROM hollywood.actors") + mock_es_sql_client.query.assert_called_once_with( + body={ + "fetch_size": 1000, + "field_multi_value_leniency": False, + "query": "SELECT * FROM hollywood.actors", + } + ) + + assert response == RESPONSE_WITHOUT_CURSOR class MockElasticsearch: diff --git a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py 
b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py index edad1a11515d54..0c0ba72309a9aa 100644 --- a/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py +++ b/providers/microsoft/mssql/src/airflow/providers/microsoft/mssql/dialects/mssql.py @@ -55,10 +55,15 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: self.log.debug("primary_keys: %s", primary_keys) self.log.debug("columns: %s", columns) - return f"""MERGE INTO {table} WITH (ROWLOCK) AS target + sql = f"""MERGE INTO {table} WITH (ROWLOCK) AS target USING (SELECT {', '.join(map(lambda column: f'{self.placeholder} AS {self.escape_word(column)}', target_fields))}) AS source - ON {' AND '.join(map(lambda column: f'target.{self.escape_word(column)} = source.{self.escape_word(column)}', primary_keys))} + ON {' AND '.join(map(lambda column: f'target.{self.escape_word(column)} = source.{self.escape_word(column)}', primary_keys))}""" + + if columns: + sql = f"""{sql} WHEN MATCHED THEN - UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))} + UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))}""" + + return f"""{sql} WHEN NOT MATCHED THEN INSERT ({', '.join(map(self.escape_word, target_fields))}) VALUES ({', '.join(map(lambda column: f'source.{self.escape_word(column)}', target_fields))});""" diff --git a/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py b/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py index 749a79c13fcd11..c584a15ba3b858 100644 --- a/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py +++ b/providers/microsoft/mssql/tests/provider_tests/microsoft/mssql/dialects/test_mssql.py @@ -17,51 +17,119 @@ # under the License. from __future__ import annotations -from unittest.mock import MagicMock +import pytest -from sqlalchemy.engine import Inspector - -from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect class TestMsSqlDialect: - def setup_method(self): - inspector = MagicMock(spc=Inspector) - inspector.get_columns.side_effect = lambda table_name, schema: [ - {"name": "index", "identity": True}, - {"name": "name"}, - {"name": "firstname"}, - {"name": "age"}, - ] - self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook) - self.test_db_hook.run.side_effect = lambda *args: [("index",)] - self.test_db_hook.reserved_words = {"index", "user"} - self.test_db_hook.escape_word_format = "[{}]" - self.test_db_hook.escape_column_names = False - - def test_placeholder(self): - assert MsSqlDialect(self.test_db_hook).placeholder == "?" + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_placeholder(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).placeholder == "?" 
- def test_get_column_names(self): - assert MsSqlDialect(self.test_db_hook).get_column_names("hollywood.actors") == [ + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_column_names(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_column_names("hollywood.actors") == [ "index", "name", "firstname", "age", ] - def test_get_target_fields(self): - assert MsSqlDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [ + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_target_fields(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_target_fields("hollywood.actors") == [ "name", "firstname", "age", ] - def test_get_primary_keys(self): - assert MsSqlDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["index"] + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_get_primary_keys(self, create_db_api_hook): + assert MsSqlDialect(create_db_api_hook).get_primary_keys("hollywood.actors") == ["index"] - def test_generate_replace_sql(self): + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql(self, create_db_api_hook): values = [ {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, {"index": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, @@ -70,7 +138,7 @@ def test_generate_replace_sql(self): {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, ] target_fields = ["index", "name", "firstname", "age"] - sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields) + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) assert ( sql == """ @@ -84,7 +152,62 @@ def test_generate_replace_sql(self): """.strip() ) - def test_generate_replace_sql_when_escape_column_names_is_enabled(self): + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name", "identity": True}, + {"name": "firstname", "identity": True}, + {"name": "age", "identity": True}, + ], # columns + [("index",), ("name",), ("firstname",), ("age",)], # primary_keys + {"index", "user"}, # reserved_words + False, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql_when_all_columns_are_part_of_primary_key(self, create_db_api_hook): + values = [ + {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"index": 2, "name": "Statham", "firstname": 
"Jason", "age": "57"}, + {"index": 3, "name": "Li", "firstname": "Jet", "age": "61"}, + {"index": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + target_fields = ["index", "name", "firstname", "age"] + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) + assert ( + sql + == """ + MERGE INTO hollywood.actors WITH (ROWLOCK) AS target + USING (SELECT ? AS [index], ? AS name, ? AS firstname, ? AS age) AS source + ON target.[index] = source.[index] AND target.name = source.name AND target.firstname = source.firstname AND target.age = source.age + WHEN NOT MATCHED THEN + INSERT ([index], name, firstname, age) VALUES (source.[index], source.name, source.firstname, source.age); + """.strip() + ) + + @pytest.mark.parametrize( + "create_db_api_hook", + [ + ( + [ + {"name": "index", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ], # columns + [("index",)], # primary_keys + {"index", "user"}, # reserved_words + True, # escape_column_names + ), + ], + indirect=True, + ) + def test_generate_replace_sql_when_escape_column_names_is_enabled(self, create_db_api_hook): values = [ {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, {"index": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, @@ -93,8 +216,7 @@ def test_generate_replace_sql_when_escape_column_names_is_enabled(self): {"index": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, ] target_fields = ["index", "name", "firstname", "age"] - self.test_db_hook.escape_column_names = True - sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields) + sql = MsSqlDialect(create_db_api_hook).generate_replace_sql("hollywood.actors", values, target_fields) assert ( sql == """ diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index e7ca8d94fc3590..a6a90bc6ea5d67 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -22,20 +22,24 @@ import datetime import os import stat -from collections.abc import Sequence +import warnings +from collections.abc import Generator, Sequence +from contextlib import closing, contextmanager from fnmatch import fnmatch +from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import asyncssh from asgiref.sync import sync_to_async -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook from airflow.providers.ssh.hooks.ssh import SSHHook if TYPE_CHECKING: - import paramiko + from paramiko.sftp_attr import SFTPAttributes + from paramiko.sftp_client import SFTPClient from airflow.models.connection import Connection @@ -52,8 +56,6 @@ class SFTPHook(SSHHook): - In contrast with FTPHook describe_directory only returns size, type and modify. It doesn't return unix.owner, unix.mode, perm, unix.group and unique. - - retrieve_file and store_file only take a local full path and not a - buffer. - If no mode is passed to create_directory it will be created with 777 permissions. 
@@ -85,7 +87,22 @@ def __init__( *args, **kwargs, ) -> None: - self.conn: paramiko.SFTPClient | None = None + # TODO: remove support for ssh_hook when it is removed from SFTPOperator + if kwargs.get("ssh_hook") is not None: + warnings.warn( + "Parameter `ssh_hook` is deprecated and will be ignored.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + + ftp_conn_id = kwargs.pop("ftp_conn_id", None) + if ftp_conn_id: + warnings.warn( + "Parameter `ftp_conn_id` is deprecated. Please use `ssh_conn_id` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + ssh_conn_id = ftp_conn_id kwargs["ssh_conn_id"] = ssh_conn_id kwargs["host_proxy_cmd"] = host_proxy_cmd @@ -93,17 +110,11 @@ def __init__( super().__init__(*args, **kwargs) - def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override] - """Open an SFTP connection to the remote host.""" - if self.conn is None: - self.conn = super().get_conn().open_sftp() - return self.conn - - def close_conn(self) -> None: - """Close the SFTP connection.""" - if self.conn is not None: - self.conn.close() - self.conn = None + @contextmanager + def get_conn(self) -> Generator[SFTPClient, None, None]: + """Context manager that closes the connection after use.""" + with closing(super().get_conn().open_sftp()) as conn: + yield conn def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]: """ @@ -114,17 +125,17 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None] :param path: full path to the remote directory """ - conn = self.get_conn() - flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) - files = {} - for f in flist: - modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore - files[f.filename] = { - "size": f.st_size, - "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore - "modify": modify, - } - return files + with self.get_conn() as conn: # type: SFTPClient + flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) + files = {} + for f in flist: + modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore + files[f.filename] = { + "size": f.st_size, + "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore + "modify": modify, + } + return files def list_directory(self, path: str) -> list[str]: """ @@ -132,18 +143,17 @@ def list_directory(self, path: str) -> list[str]: :param path: full path to the remote directory to list """ - conn = self.get_conn() - files = sorted(conn.listdir(path)) - return files + with self.get_conn() as conn: + return sorted(conn.listdir(path)) - def list_directory_with_attr(self, path: str) -> list[paramiko.SFTPAttributes]: + def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: """ List files in a directory on the remote system including their SFTPAttributes. 
:param path: full path to the remote directory to list """ - conn = self.get_conn() - return [file for file in conn.listdir_attr(path)] + with self.get_conn() as conn: + return [file for file in conn.listdir_attr(path)] def mkdir(self, path: str, mode: int = 0o777) -> None: """ @@ -155,8 +165,8 @@ def mkdir(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - conn = self.get_conn() - conn.mkdir(path, mode=mode) + with self.get_conn() as conn: + conn.mkdir(path, mode=mode) def isdir(self, path: str) -> bool: """ @@ -164,12 +174,11 @@ def isdir(self, path: str) -> bool: :param path: full path to the remote directory to check """ - conn = self.get_conn() - try: - result = stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore - except OSError: - result = False - return result + with self.get_conn() as conn: + try: + return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore + except OSError: + return False def isfile(self, path: str) -> bool: """ @@ -177,12 +186,11 @@ def isfile(self, path: str) -> bool: :param path: full path to the remote file to check """ - conn = self.get_conn() - try: - result = stat.S_ISREG(conn.stat(path).st_mode) # type: ignore - except OSError: - result = False - return result + with self.get_conn() as conn: + try: + return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore + except OSError: + return False def create_directory(self, path: str, mode: int = 0o777) -> None: """ @@ -196,19 +204,19 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - conn = self.get_conn() - if self.isdir(path): - self.log.info("%s already exists", path) - return - elif self.isfile(path): - raise AirflowException(f"{path} already exists and is a file") - else: - dirname, basename = os.path.split(path) - if dirname and not self.isdir(dirname): - self.create_directory(dirname, mode) - if basename: - self.log.info("Creating %s", path) - conn.mkdir(path, mode=mode) + with self.get_conn() as conn: + if self.isdir(path): + self.log.info("%s already exists", path) + return + elif self.isfile(path): + raise AirflowException(f"{path} already exists and is a file") + else: + dirname, basename = os.path.split(path) + if dirname and not self.isdir(dirname): + self.create_directory(dirname, mode) + if basename: + self.log.info("Creating %s", path) + conn.mkdir(path, mode=mode) def delete_directory(self, path: str) -> None: """ @@ -216,8 +224,8 @@ def delete_directory(self, path: str) -> None: :param path: full path to the remote directory to delete """ - conn = self.get_conn() - conn.rmdir(path) + with self.get_conn() as conn: + conn.rmdir(path) def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """ @@ -227,11 +235,14 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b at that location. 
:param remote_full_path: full path to the remote file - :param local_full_path: full path to the local file + :param local_full_path: full path to the local file or a file-like buffer :param prefetch: controls whether prefetch is performed (default: True) """ - conn = self.get_conn() - conn.get(remote_full_path, local_full_path, prefetch=prefetch) + with self.get_conn() as conn: + if isinstance(local_full_path, BytesIO): + conn.getfo(remote_full_path, local_full_path, prefetch=prefetch) + else: + conn.get(remote_full_path, local_full_path, prefetch=prefetch) def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None: """ @@ -241,10 +252,13 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool from that location. :param remote_full_path: full path to the remote file - :param local_full_path: full path to the local file + :param local_full_path: full path to the local file or a file-like buffer """ - conn = self.get_conn() - conn.put(local_full_path, remote_full_path, confirm=confirm) + with self.get_conn() as conn: + if isinstance(local_full_path, BytesIO): + conn.putfo(local_full_path, remote_full_path, confirm=confirm) + else: + conn.put(local_full_path, remote_full_path, confirm=confirm) def delete_file(self, path: str) -> None: """ @@ -252,8 +266,8 @@ def delete_file(self, path: str) -> None: :param path: full path to the remote file """ - conn = self.get_conn() - conn.remove(path) + with self.get_conn() as conn: + conn.remove(path) def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """ @@ -306,9 +320,9 @@ def get_mod_time(self, path: str) -> str: :param path: full path to the remote file """ - conn = self.get_conn() - ftp_mdtm = conn.stat(path).st_mtime - return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore + with self.get_conn() as conn: + ftp_mdtm = conn.stat(path).st_mtime + return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore def path_exists(self, path: str) -> bool: """ @@ -316,12 +330,12 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ - conn = self.get_conn() - try: - conn.stat(path) - except OSError: - return False - return True + with self.get_conn() as conn: + try: + conn.stat(path) + except OSError: + return False + return True @staticmethod def _is_path_match(path: str, prefix: str | None = None, delimiter: str | None = None) -> bool: @@ -415,9 +429,9 @@ def append_matching_path_callback(list_: list[str]) -> Callable: def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: - conn = self.get_conn() - conn.normalize(".") - return True, "Connection successfully tested" + with self.get_conn() as conn: + conn.normalize(".") + return True, "Connection successfully tested" except Exception as e: return False, str(e) @@ -432,7 +446,6 @@ def get_file_by_pattern(self, path, fnmatch_pattern) -> str: for file in self.list_directory(path): if fnmatch(file, fnmatch_pattern): return file - return "" def get_files_by_pattern(self, path, fnmatch_pattern) -> list[str]: @@ -600,17 +613,13 @@ async def get_mod_time(self, path: str) -> str: # type: ignore[return] :param path: full path to the remote file """ - ssh_conn = None - try: - ssh_conn = await self._get_conn() - sftp_client = await ssh_conn.start_sftp_client() - ftp_mdtm = await sftp_client.stat(path) - modified_time = ftp_mdtm.mtime - 
mod_time = datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") # type: ignore[arg-type] - self.log.info("Found File %s last modified: %s", str(path), str(mod_time)) - return mod_time - except asyncssh.SFTPNoSuchFile: - raise AirflowException("No files matching") - finally: - if ssh_conn: - ssh_conn.close() + async with await self._get_conn() as ssh_conn: + try: + sftp_client = await ssh_conn.start_sftp_client() + ftp_mdtm = await sftp_client.stat(path) + modified_time = ftp_mdtm.mtime + mod_time = datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") # type: ignore[arg-type] + self.log.info("Found File %s last modified: %s", str(path), str(mod_time)) + return mod_time + except asyncssh.SFTPNoSuchFile: + raise AirflowException("No files matching") diff --git a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py index 54b27e450be9af..91632010d12e69 100644 --- a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py @@ -21,14 +21,15 @@ import json import os import shutil -from io import StringIO -from unittest import mock -from unittest.mock import AsyncMock, patch +from io import BytesIO, StringIO +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import paramiko import pytest from asyncssh import SFTPAttrs, SFTPNoSuchFile from asyncssh.sftp import SFTPName +from paramiko.client import SSHClient +from paramiko.sftp_client import SFTPClient from airflow.exceptions import AirflowException from airflow.models import Connection @@ -87,7 +88,10 @@ def setup_test_cases(self, tmp_path_factory): file.write("Test file") with open(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), "a") as file: file.write("Test file") - os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) + try: + os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) + except AttributeError: + os.makedirs(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)) self.temp_dir = str(temp_dir) @@ -99,14 +103,20 @@ def setup_test_cases(self, tmp_path_factory): self.update_connection(self.old_login) def test_get_conn(self): - output = self.hook.get_conn() - assert isinstance(output, paramiko.SFTPClient) + with self.hook.get_conn() as conn: + assert isinstance(conn, paramiko.SFTPClient) - def test_close_conn(self): - self.hook.conn = self.hook.get_conn() - assert self.hook.conn is not None - self.hook.close_conn() - assert self.hook.conn is None + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_get_close_conn(self, mock_get_conn): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client = MagicMock(spec=SSHClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + mock_get_conn.return_value = mock_ssh_client + + with SFTPHook().get_conn() as conn: + assert conn == mock_sftp_client + + mock_sftp_client.close.assert_called_once() def test_describe_directory(self): output = self.hook.describe_directory(self.temp_dir) @@ -129,8 +139,9 @@ def test_mkdir(self): assert new_dir_name in output # test the directory has default permissions to 777 - umask umask = 0o022 - output = self.hook.get_conn().lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) - assert output.st_mode & 0o777 == 0o777 - umask + with self.hook.get_conn() as conn: + output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) + assert output.st_mode & 0o777 
== 0o777 - umask def test_create_and_delete_directory(self): new_dir_name = "new_dir" @@ -139,8 +150,9 @@ def test_create_and_delete_directory(self): assert new_dir_name in output # test the directory has default permissions to 777 umask = 0o022 - output = self.hook.get_conn().lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) - assert output.st_mode & 0o777 == 0o777 - umask + with self.hook.get_conn() as conn: + output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) + assert output.st_mode & 0o777 == 0o777 - umask # test directory already exists for code coverage, should not raise an exception self.hook.create_directory(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) # test path already exists and is a file, should raise an exception @@ -185,6 +197,24 @@ def test_store_retrieve_and_delete_file(self): output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) assert output == [SUB_DIR, FIFO_FOR_TESTS] + def test_store_retrieve_and_delete_file_using_buffer(self): + file_contents = BytesIO(b"Test file") + self.hook.store_file( + remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=file_contents, + ) + output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR, FIFO_FOR_TESTS, TMP_FILE_FOR_TESTS] + retrieved_file_contents = BytesIO() + self.hook.retrieve_file( + remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=retrieved_file_contents, + ) + assert retrieved_file_contents.getvalue() == file_contents.getvalue() + self.hook.delete_file(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR, FIFO_FOR_TESTS] + def test_get_mod_time(self): self.hook.store_file( remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), @@ -195,14 +225,14 @@ def test_get_mod_time(self): ) assert len(output) == 14 - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_default(self, get_connection): connection = Connection(login="login", host="host") get_connection.return_value = connection hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_enabled(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": true}') @@ -210,7 +240,7 @@ def test_no_host_key_check_enabled(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_disabled(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": false}') @@ -218,7 +248,7 @@ def test_no_host_key_check_disabled(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is False - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_ciphers(self, get_connection): connection = Connection(login="login", 
host="host", extra='{"ciphers": ["A", "B", "C"]}') @@ -226,7 +256,7 @@ def test_ciphers(self, get_connection): hook = SFTPHook() assert hook.ciphers == ["A", "B", "C"] - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): connection = Connection(login="login", host="host", extra='{"no_host_key_check": "foo"}') @@ -234,7 +264,7 @@ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is False - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_ignore(self, get_connection): connection = Connection(login="login", host="host", extra='{"ignore_hostkey_verification": true}') @@ -242,14 +272,14 @@ def test_no_host_key_check_ignore(self, get_connection): hook = SFTPHook() assert hook.no_host_key_check is True - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_default(self, get_connection): connection = Connection(login="login", host="host") get_connection.return_value = connection hook = SFTPHook() assert hook.host_key is None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key(self, get_connection): connection = Connection( login="login", @@ -260,7 +290,7 @@ def test_host_key(self, get_connection): hook = SFTPHook() assert hook.host_key.get_base64() == TEST_HOST_KEY - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_with_type(self, get_connection): connection = Connection( login="login", @@ -271,14 +301,14 @@ def test_host_key_with_type(self, get_connection): hook = SFTPHook() assert hook.host_key.get_base64() == TEST_HOST_KEY - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_host_key_with_no_host_key_check(self, get_connection): connection = Connection(login="login", host="host", extra=json.dumps({"host_key": TEST_HOST_KEY})) get_connection.return_value = connection hook = SFTPHook() assert hook.host_key is not None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_key_content_as_str(self, get_connection): file_obj = StringIO() TEST_PKEY.write_private_key(file_obj) @@ -299,7 +329,7 @@ def test_key_content_as_str(self, get_connection): assert hook.pkey == TEST_PKEY assert hook.key_file is None - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_key_file(self, get_connection): connection = Connection( login="login", @@ -356,37 +386,50 @@ def test_get_tree_map(self): assert dirs == [os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR)] assert unknowns == [os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS)] - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") - def test_connection_failure(self, mock_get_connection): - connection = Connection( - login="login", - host="host", + 
@patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_connection_failure(self, mock_get_conn): + mock_ssh_client = MagicMock(spec=SSHClient) + type(mock_ssh_client.open_sftp.return_value).normalize = PropertyMock( + side_effect=Exception("Connection Error") ) - mock_get_connection.return_value = connection - with mock.patch.object(SFTPHook, "get_conn") as get_conn: - type(get_conn.return_value).normalize = mock.PropertyMock( - side_effect=Exception("Connection Error") - ) + mock_get_conn.return_value = mock_ssh_client + + hook = SFTPHook() + status, msg = hook.test_connection() - hook = SFTPHook() - status, msg = hook.test_connection() assert status is False assert msg == "Connection Error" - @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") - def test_connection_success(self, mock_get_connection): - connection = Connection( - login="login", - host="host", + @pytest.mark.parametrize( + "test_connection_side_effect", + [ + (lambda arg: (True, "Connection successfully tested")), + (lambda arg: RuntimeError("Test connection failed")), + ], + ) + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") + def test_context_manager(self, mock_get_conn, test_connection_side_effect): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client = MagicMock(spec=SSHClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + mock_get_conn.return_value = mock_ssh_client + + type(mock_sftp_client.normalize.return_value).normalize = PropertyMock( + side_effect=test_connection_side_effect ) - mock_get_connection.return_value = connection - with mock.patch.object(SFTPHook, "get_conn") as get_conn: - get_conn.return_value.pwd = "/home/someuser" - hook = SFTPHook() + hook = SFTPHook() + if isinstance(test_connection_side_effect, RuntimeError): + with pytest.raises(RuntimeError, match="Test connection failed"): + hook.test_connection() + else: status, msg = hook.test_connection() - assert status is True - assert msg == "Connection successfully tested" + + assert status is True + assert msg == "Connection successfully tested" + + mock_ssh_client.open_sftp.assert_called_once() + mock_sftp_client.close.assert_called() def test_get_suffix_pattern_match(self): output = self.hook.get_file_by_pattern(self.temp_dir, "*.txt") @@ -447,6 +490,38 @@ def test_store_and_retrieve_directory(self): ) assert retrieved_dir_name in os.listdir(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)) + @patch("paramiko.SSHClient") + @patch("paramiko.ProxyCommand") + def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + + mock_transport = MagicMock() + mock_ssh_client.return_value.get_transport.return_value = mock_transport + mock_proxy_command.return_value = MagicMock() + + host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + + hook = SFTPHook( + remote_host="example.com", + username="user", + host_proxy_cmd=host_proxy_cmd, + ) + + with hook.get_conn(): + mock_proxy_command.assert_called_once_with(host_proxy_cmd) + mock_ssh_client.return_value.connect.assert_called_once_with( + hostname="example.com", + username="user", + timeout=None, + compress=True, + port=22, + sock=mock_proxy_command.return_value, + look_for_keys=True, + banner_timeout=30.0, + auth_timeout=None, + ) + class MockSFTPClient: def __init__(self): @@ -755,8 +830,9 @@ async def test_get_mod_time(self, mock_hook_get_conn): """ Assert that file attribute 
and return the modified time of the file """ - mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient() + mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() hook = SFTPHookAsync() + mod_time = await hook.get_mod_time("/path/exists/file") expected_value = datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S") assert mod_time == expected_value @@ -767,36 +843,9 @@ async def test_get_mod_time_exception(self, mock_hook_get_conn): """ Assert that get_mod_time raise exception when file does not exist """ - mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient() + mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() hook = SFTPHookAsync() + with pytest.raises(AirflowException) as exc: await hook.get_mod_time("/path/does_not/exist/") assert str(exc.value) == "No files matching" - - @patch("paramiko.SSHClient") - @mock.patch("paramiko.ProxyCommand") - def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): - mock_transport = mock.MagicMock() - mock_ssh_client.return_value.get_transport.return_value = mock_transport - mock_proxy_command.return_value = mock.MagicMock() - - host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" - hook = SFTPHook( - remote_host="example.com", - username="user", - host_proxy_cmd=host_proxy_cmd, - ) - hook.get_conn() - - mock_proxy_command.assert_called_once_with(host_proxy_cmd) - mock_ssh_client.return_value.connect.assert_called_once_with( - hostname="example.com", - username="user", - timeout=None, - compress=True, - port=22, - sock=mock_proxy_command.return_value, - look_for_keys=True, - banner_timeout=30.0, - auth_timeout=None, - ) diff --git a/pyproject.toml b/pyproject.toml index 8f024bf37b6b20..48e5ff48ef8b5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -656,6 +656,7 @@ dev = [ "apache-airflow-providers-asana", "apache-airflow-providers-atlassian-jira", "apache-airflow-providers-celery", + "apache-airflow-providers-cloudant", "apache-airflow-providers-cohere", "apache-airflow-providers-common-compat", "apache-airflow-providers-common-io", @@ -744,6 +745,7 @@ apache-airflow-providers-apprise = { workspace = true } apache-airflow-providers-asana = { workspace = true } apache-airflow-providers-atlassian-jira = { workspace = true } apache-airflow-providers-celery = {workspace = true} +apache-airflow-providers-cloudant = { workspace = true } apache-airflow-providers-cohere = { workspace = true } apache-airflow-providers-common-compat = { workspace = true } apache-airflow-providers-common-io = { workspace = true } @@ -830,6 +832,7 @@ members = [ "providers/asana", "providers/atlassian/jira", "providers/celery", + "providers/cloudant", "providers/cohere", "providers/common/compat", "providers/common/io", diff --git a/scripts/ci/docker-compose/remove-sources.yml b/scripts/ci/docker-compose/remove-sources.yml index 056da8cf4fa7be..ec4f9955060de6 100644 --- a/scripts/ci/docker-compose/remove-sources.yml +++ b/scripts/ci/docker-compose/remove-sources.yml @@ -50,6 +50,7 @@ services: - ../../../empty:/opt/airflow/providers/asana/src - ../../../empty:/opt/airflow/providers/atlassian/jira/src - ../../../empty:/opt/airflow/providers/celery/src + - ../../../empty:/opt/airflow/providers/cloudant/src - ../../../empty:/opt/airflow/providers/cohere/src - ../../../empty:/opt/airflow/providers/common/compat/src - ../../../empty:/opt/airflow/providers/common/io/src diff --git 
a/scripts/ci/docker-compose/tests-sources.yml b/scripts/ci/docker-compose/tests-sources.yml
index a1214bfb32229b..c26e5441f48c8b 100644
--- a/scripts/ci/docker-compose/tests-sources.yml
+++ b/scripts/ci/docker-compose/tests-sources.yml
@@ -57,6 +57,7 @@ services:
       - ../../../providers/asana/tests:/opt/airflow/providers/asana/tests
       - ../../../providers/atlassian/jira/tests:/opt/airflow/providers/atlassian/jira/tests
       - ../../../providers/celery/tests:/opt/airflow/providers/celery/tests
+      - ../../../providers/cloudant/tests:/opt/airflow/providers/cloudant/tests
       - ../../../providers/cohere/tests:/opt/airflow/providers/cohere/tests
       - ../../../providers/common/compat/tests:/opt/airflow/providers/common/compat/tests
       - ../../../providers/common/io/tests:/opt/airflow/providers/common/io/tests
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 3698c4cc426546..732714debd9d6e 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -1635,3 +1635,28 @@ def url_safe_serializer(secret_key) -> URLSafeSerializer:
     from itsdangerous import URLSafeSerializer

     return URLSafeSerializer(secret_key)
+
+
+@pytest.fixture
+def create_db_api_hook(request):
+    """Return a mock DbApiHook built from a (columns, primary_keys, reserved_words, escape_column_names) tuple passed via indirect parametrization."""
+    from unittest.mock import MagicMock
+
+    from sqlalchemy.engine import Inspector
+
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+    columns, primary_keys, reserved_words, escape_column_names = request.param
+
+    # Stub the SQLAlchemy inspector so dialects can list columns without a database.
+    inspector = MagicMock(spec=Inspector)
+    inspector.get_columns.side_effect = lambda table_name, schema: columns
+
+    # hook.run() is what dialects invoke to fetch primary keys.
+    test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
+    test_db_hook.run.side_effect = lambda *args: primary_keys
+    test_db_hook.reserved_words = reserved_words
+    test_db_hook.escape_word_format = "[{}]"
+    test_db_hook.escape_column_names = escape_column_names or False
+
+    return test_db_hook
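
A hedged usage sketch for the create_db_api_hook fixture above: with indirect=True, pytest routes each parameter tuple into the fixture as request.param, and the test receives the fully wired mock DbApiHook. MsSqlDialect and the table name mirror the dialect tests earlier in this diff; the import path is assumed from the provider layout:

    import pytest

    from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect

    @pytest.mark.parametrize(
        "create_db_api_hook",
        [
            (
                [{"name": "index", "identity": True}, {"name": "name"}],  # columns
                [("index",)],  # primary_keys, returned by the stubbed hook.run()
                {"index", "user"},  # reserved_words
                False,  # escape_column_names
            ),
        ],
        indirect=True,  # pass the tuple to the fixture via request.param
    )
    def test_get_primary_keys(create_db_api_hook):
        # inspector.get_columns and hook.run are stubbed by the fixture, so the
        # dialect resolves primary keys without touching a real database.
        assert MsSqlDialect(create_db_api_hook).get_primary_keys("hollywood.actors") == ["index"]
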