Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vdk-jupyter: fix bug for failed requests and improve error handling #2916

Merged
merged 6 commits into from
Nov 20, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ import {
jobRequest,
jobRunRequest
} from '../serverRequests';
import { Dialog, showErrorMessage } from '@jupyterlab/apputils';

jest.mock('../handler', () => {
return {
requestAPI: jest.fn()
};
});

jest.mock('@jupyterlab/apputils', () => ({
showErrorMessage: jest.fn(),
Dialog: {
okButton: jest.fn()
}
}));

describe('jobdDataRequest', () => {
afterEach(() => {
jest.clearAllMocks();
Expand Down Expand Up @@ -213,6 +221,58 @@ describe('jobRequest()', () => {
isSuccessful: false
});
});

it('should show an error message if a task fails', async () => {
const mockData = {
[VdkOption.NAME]: 'Test Job',
[VdkOption.TEAM]: 'Test Team'
};

jobData.set(VdkOption.NAME, mockData[VdkOption.NAME]);
jobData.set(VdkOption.TEAM, mockData[VdkOption.TEAM]);

const endpoint = 'DEPLOY';
const taskId = endpoint + '-6266cd99-908c-480b-9a3e-8a30564736a4';
const taskInitiationResponse = {
error: '',
message: `Task ${taskId} started`
};
const taskCompletionResponse = {
task_id: taskId,
status: 'failed',
message: '',
error: 'An error occurred'
};

(requestAPI as jest.Mock)
.mockResolvedValueOnce(taskInitiationResponse)
.mockResolvedValue(taskCompletionResponse);

const result = await jobRequest(endpoint);

// Verify the call for initiating the task
expect(requestAPI).toHaveBeenCalledWith(endpoint, {
body: JSON.stringify(getJobDataJsonObject()),
method: 'POST'
});

// Verify the polling for task status
expect(requestAPI).toHaveBeenCalledWith(`taskStatus?taskId=${taskId}`, {
method: 'GET'
});

expect(showErrorMessage).toHaveBeenCalledWith(
'Encountered an error while trying to connect the server. Error:',
taskCompletionResponse.error,
[Dialog.okButton()]
);

// Verify the final result
expect(result).toEqual({
message: taskCompletionResponse.error,
isSuccessful: false
});
});
});

describe('getFailingNotebookInfo', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ const pollForTaskCompletion = async (
`taskStatus?taskId=${taskId}`,
{ method: 'GET' }
);
if (
result.task_id === taskId &&
(result.status === 'completed' || result.error)
) {
return result;
if (result.task_id === taskId) {
if (result.status !== 'running') {
if (result.status === 'failed') {
showError(result.error);
}
return result;
}
}
} catch (error) {
showError(error);
Expand Down Expand Up @@ -151,10 +153,14 @@ export async function jobRunRequest(): Promise<jobRequestResult> {

const taskId = extractTaskIdFromMessage(initialResponse.message);
const finalResult = await pollForTaskCompletion(taskId);
return {
message: finalResult.message as string,
isSuccessful: !finalResult.error
};
if (finalResult.error) {
return { message: finalResult.error, isSuccessful: false };
} else {
return {
message: finalResult.message as string,
isSuccessful: true
};
}
} catch (error) {
showError(error);
return { message: '', isSuccessful: false };
Expand Down Expand Up @@ -191,10 +197,14 @@ export async function jobRequest(endPoint: string): Promise<jobRequestResult> {

const taskId = extractTaskIdFromMessage(initialResponse.message);
const finalResult = await pollForTaskCompletion(taskId);
return {
message: finalResult.message as string,
isSuccessful: !finalResult.error
};
if (finalResult.error) {
return { message: finalResult.error, isSuccessful: false };
} else {
return {
message: finalResult.message as string,
isSuccessful: true
};
}
} catch (error) {
showError(error);
return { message: '', isSuccessful: false };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,23 @@ def get(self):
self.set_status(400)
self.finish(json.dumps({"error": "taskId not provided."}))
return
current_status = task_runner.get_status()
if current_status["task_id"] != task_id:
self.finish(
json.dumps(
{
"status": "failed",
"message": "Mismatched taskId.",
"error": f"Requested status for {task_id} but currently processing {current_status['task_id']}",
}
try:
current_status = task_runner.get_status()
if current_status["task_id"] != task_id:
self.finish(
json.dumps(
{
"status": "failed",
"message": "Mismatched taskId.",
"error": f"Requested status for {task_id} but currently processing {current_status['task_id']}",
}
)
)
)
return
return

self.finish(json.dumps(current_status))
self.finish(json.dumps(current_status))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class LoadJobDataHandler(APIHandler):
Expand Down Expand Up @@ -114,18 +117,21 @@ class RunJobHandler(APIHandler):
@tornado.web.authenticated
def post(self):
input_data = self.get_json_body()
task_id = task_runner.start_task(
"RUN",
lambda: VdkUI.run_job(
input_data[VdkOption.PATH.value],
input_data[VdkOption.ARGUMENTS.value],
),
)
try:
task_id = task_runner.start_task(
"RUN",
lambda: VdkUI.run_job(
input_data[VdkOption.PATH.value],
input_data[VdkOption.ARGUMENTS.value],
),
)

if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("RUN"))
if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("RUN"))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class DownloadJobHandler(APIHandler):
Expand All @@ -140,19 +146,22 @@ class DownloadJobHandler(APIHandler):
@tornado.web.authenticated
def post(self):
input_data = self.get_json_body()
task_id = task_runner.start_task(
"DOWNLOAD",
lambda: VdkUI.download_job(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
),
)
try:
task_id = task_runner.start_task(
"DOWNLOAD",
lambda: VdkUI.download_job(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
),
)

if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("DOWNLOAD"))
if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("DOWNLOAD"))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class ConvertJobHandler(APIHandler):
Expand All @@ -164,15 +173,18 @@ class ConvertJobHandler(APIHandler):
@tornado.web.authenticated
def post(self):
input_data = self.get_json_body()
task_id = task_runner.start_task(
"CONVERTJOBTONOTEBOOK",
lambda: VdkUI.convert_job(input_data[VdkOption.PATH.value]),
)
try:
task_id = task_runner.start_task(
"CONVERTJOBTONOTEBOOK",
lambda: VdkUI.convert_job(input_data[VdkOption.PATH.value]),
)

if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("CONVERTJOBTONOTEBOOK"))
if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("CONVERTJOBTONOTEBOOK"))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class CreateJobHandler(APIHandler):
Expand All @@ -188,19 +200,22 @@ class CreateJobHandler(APIHandler):
@tornado.web.authenticated
def post(self):
input_data = self.get_json_body()
task_id = task_runner.start_task(
"CREATE",
lambda: VdkUI.create_job(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
),
)
try:
task_id = task_runner.start_task(
"CREATE",
lambda: VdkUI.create_job(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
),
)

if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("CREATE"))
if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("CREATE"))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class CreateDeploymentHandler(APIHandler):
Expand All @@ -215,20 +230,23 @@ class CreateDeploymentHandler(APIHandler):
@tornado.web.authenticated
def post(self):
input_data = self.get_json_body()
task_id = task_runner.start_task(
"DEPLOY",
lambda: VdkUI.create_deployment(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
input_data[VdkOption.DEPLOYMENT_REASON.value],
),
)
try:
task_id = task_runner.start_task(
"DEPLOY",
lambda: VdkUI.create_deployment(
input_data[VdkOption.NAME.value],
input_data[VdkOption.TEAM.value],
input_data[VdkOption.PATH.value],
input_data[VdkOption.DEPLOYMENT_REASON.value],
),
)

if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("DEPLOY"))
if task_id:
self.finish(task_start_response_success(task_id))
else:
self.finish(task_start_response_failure("DEPLOY"))
except Exception as e:
self.finish(json.dumps({"message": f"{e}", "error": "true"}))


class GetNotebookInfoHandler(APIHandler):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def start_task(self, task_type, task_handler):
"""
task_id = f"{task_type}-{str(uuid.uuid4())}"
with self.lock:
if self.__task_status["status"] not in ["idle", "completed"]:
if self.__task_status["status"] not in ["idle", "completed", "failed"]:
return None

self.__task_status = {
Expand All @@ -51,9 +51,9 @@ def start_task(self, task_type, task_handler):
"error": None,
}

thread = threading.Thread(target=self._run_task, args=(task_handler,))
thread.start()
return task_id
thread = threading.Thread(target=self._run_task, args=(task_handler,))
thread.start()
return task_id

def get_status(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,25 @@ def test_task_status_after_completion(task_runner):
assert (
task_runner.start_task(task_type, task_handler) is not None
), "Unable to start a new task after completion"


def test_task_status_fail_then_success(task_runner):
task_handler = Mock(
side_effect=[Exception("An error occurred"), "Task completed successfully"]
)
task_type = "TEST"
task_id = task_runner.start_task(task_type, task_handler)
assert task_id is not None

time.sleep(0.1)
status = task_runner.get_status()
assert status["status"] == "failed"
assert "An error occurred" in status["error"]

task_id = task_runner.start_task(task_type, task_handler)
assert task_id is not None

time.sleep(0.1)
status = task_runner.get_status()
assert status["status"] == "completed"
assert status["message"] == "Task completed successfully"