Skip to content

Commit

Permalink
[BUGFIX] Pass verify (and other transport args) when creating the a…
Browse files Browse the repository at this point in the history
…rgilla client (#5789)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Closes #5548

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jan 21, 2025
1 parent c49305d commit 045dec2
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
3 changes: 3 additions & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ These are the section headers that we use:
- Added support to create users with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786))
- Added support to create workspaces with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786))

### Fixed

- Fixed connection error when passing `verify=False` in the argilla client initialization. ([#5548](https://github.com/argilla-io/argilla/issues/5548)

## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0)

Expand Down
4 changes: 2 additions & 2 deletions argilla/src/argilla/_api/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def __init__(
self.api_key = api_key

http_client_args = http_client_args or {}
http_client_args["timeout"] = timeout
http_client_args["retries"] = retries

self.http_client = create_http_client(
api_url=self.api_url, # type: ignore
api_key=self.api_key, # type: ignore
timeout=timeout,
retries=retries,
**http_client_args,
)

Expand Down
16 changes: 12 additions & 4 deletions argilla/src/argilla/_api/_http/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass

import httpx
Expand All @@ -27,17 +27,25 @@ class HTTPClientConfig:
retries: int = 5


def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Client:
TRANSPORT_ARGS = inspect.getfullargspec(httpx.HTTPTransport.__init__).args


def create_http_client(api_url: str, api_key: str, timeout: int, retries: int, **client_args) -> httpx.Client:
"""Initialize the SDK with the given API URL and API key."""
# This piece of code is needed to make old sdk works in combination with new one

headers = client_args.pop("headers", {})
headers["X-Argilla-Api-Key"] = api_key
retries = client_args.pop("retries", 0)

http_transport = httpx.HTTPTransport(
retries=retries,
**{name: client_args.pop(name) for name in TRANSPORT_ARGS if name in client_args},
)

return httpx.Client(
base_url=api_url,
headers=headers,
transport=httpx.HTTPTransport(retries=retries),
timeout=timeout,
transport=http_transport,
**client_args,
)
23 changes: 21 additions & 2 deletions argilla/tests/unit/api/http/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ssl
from unittest.mock import MagicMock, patch

import pytest
from argilla import Argilla
from httpx import Timeout

from argilla import Argilla


class TestHTTPClient:
def test_create_default_client(self):
Expand Down Expand Up @@ -76,3 +77,21 @@ def test_create_client_with_various_retries(self, retries):
mock_create_http_client.assert_called_once_with(
api_url="http://test.com", api_key="test_key", timeout=60, retries=retries
)

def test_create_client_with_verify(self):
http_client = Argilla(
api_url="http://test.com",
api_key="test_key",
verify=True,
).http_client

# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode
assert http_client._transport._pool._ssl_context.verify_mode == ssl.CERT_REQUIRED

http_client = Argilla(
api_url="http://test.com",
api_key="test_key",
verify=False,
).http_client

assert http_client._transport._pool._ssl_context.verify_mode == ssl.CERT_NONE

0 comments on commit 045dec2

Please sign in to comment.