Skip to content

Commit 104209d

Browse files
authored
Merge pull request #403 from platiagro/fix/handle_exception_prediction
add exception handling
2 parents 33bf525 + 9912e19 commit 104209d

File tree

6 files changed

+45
-8
lines changed

6 files changed

+45
-8
lines changed

projects/controllers/predictions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from typing import Optional
66

77
import requests
8+
from sqlalchemy.exc import DataError
89
from platiagro import load_dataset
910

10-
from projects import models, schemas
11+
from projects import models, schemas, exceptions
1112
from projects.controllers.utils import (
1213
parse_dataframe_to_seldon_request,
1314
parse_file_buffer_to_seldon_request,
@@ -121,7 +122,10 @@ def create_prediction_database_object(
121122
)
122123

123124
self.session.add(prediction)
124-
self.session.commit()
125+
try:
126+
self.session.commit()
127+
except DataError:
128+
raise(exceptions.BadRequest("400", "File too large"))
125129
return prediction
126130

127131
def start_and_save_seldon_prediction(self, request_body, prediction_object, url):

projects/controllers/tasks/tasks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from jinja2 import Template
1313
from sqlalchemy import asc, desc, func
1414

15-
from projects import __version__, models, schemas
15+
from projects import models, schemas
1616
from projects.controllers.utils import uuid_alpha
1717
from projects.exceptions import BadRequest, Forbidden, NotFound
1818
from projects.kubernetes.notebook import (

projects/schemas/project.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def validate_description(cls, v):
3737
class ProjectUpdate(ProjectBase):
3838
name: Optional[str]
3939
description: Optional[str]
40+
4041
@validator("name")
4142
def validate_name(cls, v):
4243
generic_validators.raise_if_exceeded(MAX_CHARS_ALLOWED, v)

tests/test_experiments.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def test_update_experiment_invalid_name(self):
517517
)
518518
rv.json()
519519
self.assertEqual(rv.status_code, 400)
520-
520+
521521
def test_update_experiment_invalid_name_size(self):
522522
"""
523523
Should return http status 400.
@@ -531,4 +531,4 @@ def test_update_experiment_invalid_name_size(self):
531531
json={"name": experiment_name},
532532
)
533533
rv.json()
534-
self.assertEqual(rv.status_code, 400)
534+
self.assertEqual(rv.status_code, 400)

tests/test_predictions.py

+31
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest.mock as mock
44

55
from fastapi.testclient import TestClient
6+
from sqlalchemy.exc import DataError
67

78
from projects.api.main import app
89
from projects.database import session_scope
@@ -228,3 +229,33 @@ def test_create_prediction_dataset_image(
228229
}
229230
},
230231
)
232+
233+
@mock.patch(
234+
"projects.controllers.predictions.load_dataset",
235+
return_value=util.IRIS_DATAFRAME,
236+
)
237+
@mock.patch(
238+
"requests.post",
239+
return_value=util.MOCK_POST_PREDICTION,
240+
)
241+
@mock.patch("projects.api.predictions.Session.commit", side_effect=DataError("statement", "params", "orig"))
242+
def test_create_prediction_fail(
243+
self,
244+
mock_requests_post,
245+
mock_load_dataset,
246+
mock_data_error
247+
):
248+
"""
249+
Should return 400 error because the used dataset file is too large.
250+
"""
251+
project_id = util.MOCK_UUID_1
252+
deployment_id = util.MOCK_UUID_1
253+
dataset_name = util.IRIS_DATASET_NAME
254+
255+
rv = TEST_CLIENT.post(
256+
f"/projects/{project_id}/deployments/{deployment_id}/predictions",
257+
json={"dataset": dataset_name},
258+
)
259+
result = rv.json()
260+
self.assertEqual(rv.status_code, 400)
261+
self.assertEqual(result["message"], "File too large")

tests/test_projects.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
LoremipsumdolorsitametconsecteturadipiscingelitInteerelitexauc\
2121
LoremipsumdolorsitametconsecteturadipiscingelitInteerelitexauc"
2222

23+
2324
class TestProjects(unittest.TestCase):
2425
maxDiff = None
2526

@@ -444,7 +445,7 @@ def test_update_project_invalid_name(self):
444445
rv.json()
445446

446447
self.assertEqual(rv.status_code, 400)
447-
448+
448449
def test_update_project_invalid_name_size(self):
449450
"""
450451
Should return http status 400.
@@ -465,7 +466,7 @@ def test_update_project_invalid_description_size(self):
465466
rv.json()
466467

467468
self.assertEqual(rv.status_code, 400)
468-
469+
469470
def test_update_project_description_success(self):
470471
"""
471472
Should return http status 200.
@@ -474,4 +475,4 @@ def test_update_project_description_success(self):
474475
rv = TEST_CLIENT.patch(f"/projects/{project_id}", json={"description": "DESCRIPTION"})
475476
rv.json()
476477

477-
self.assertEqual(rv.status_code, 200)
478+
self.assertEqual(rv.status_code, 200)

0 commit comments

Comments
 (0)