Skip to content

Commit

Permalink
Fix presigned_url_artifact request (mlflow#13366)
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 authored Oct 11, 2024
1 parent 7f46072 commit 49e0382
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 28 deletions.
2 changes: 1 addition & 1 deletion mlflow/store/artifact/presigned_url_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def list_artifacts(self, path=""):
page_token = ""
while True:
endpoint = posixpath.join(DIRECTORIES_ENDPOINT, self.artifact_uri.lstrip("/"), path)
req_body = json.dumps({"page_token": page_token}) if page_token else ""
req_body = json.dumps({"page_token": page_token}) if page_token else None

response_proto = ListDirectoryResponse()
resp = call_endpoint(
Expand Down
2 changes: 1 addition & 1 deletion mlflow/utils/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def get_set_trace_tag_endpoint(request_id):

def call_endpoint(host_creds, endpoint, method, json_body, response_proto, extra_headers=None):
# Convert json string to json dictionary, to pass to requests
if json_body:
if json_body is not None:
json_body = json.loads(json_body)
call_kwargs = {
"host_creds": host_creds,
Expand Down
27 changes: 1 addition & 26 deletions tests/utils/test_rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,6 @@
from tests import helper_functions


def test_well_formed_json_error_response():
with mock.patch(
"requests.Session.request", return_value=mock.MagicMock(status_code=400, text="{}")
):
host_only = MlflowHostCreds("http://my-host")
response_proto = GetRun.Response()
with pytest.raises(RestException, match="INTERNAL_ERROR"):
call_endpoint(host_only, "/my/endpoint", "GET", "", response_proto)


def test_non_json_ok_response():
with mock.patch(
"requests.Session.request",
return_value=mock.MagicMock(status_code=200, text="<html></html>"),
):
host_only = MlflowHostCreds("http://my-host")
response_proto = GetRun.Response()
with pytest.raises(
MlflowException,
match="API request to endpoint was successful but the response body was not "
"in a valid JSON format",
):
call_endpoint(host_only, "/api/2.0/fetch-model", "GET", "", response_proto)


@pytest.mark.parametrize(
"response_mock",
[
Expand All @@ -68,7 +43,7 @@ def test_malformed_json_error_response(response_mock):
with pytest.raises(
MlflowException, match="API request to endpoint /my/endpoint failed with error code 400"
):
call_endpoint(host_only, "/my/endpoint", "GET", "", response_proto)
call_endpoint(host_only, "/my/endpoint", "GET", None, response_proto)


def test_call_endpoints():
Expand Down

0 comments on commit 49e0382

Please sign in to comment.