From a97c6eaa21624a2710d78c372d5a622479b0ae68 Mon Sep 17 00:00:00 2001 From: Kevin Bates Date: Fri, 19 Oct 2018 16:08:46 -0700 Subject: [PATCH] Embed NB2KG into Jupyter server This change alleviates a significant pain point for consumers of the Jupyter Kernel and Enterprise Gateway projects by embedding the few classes defined in the NB2KG server extension directly into the Notebook server. All code resides in a separate gateway directory and the 'extension' is enabled via a new configuration option `--gateway-url`. Renamed classes from those used in standard NB2KG code so that Jupyter servers using the existing NB2KG extension will still work. Added test_gateway.py to exercise the overridden methods by mocking the call that issues requests to the gateway server. Updated the _Running a notebook server_ topic to include a description of this feature. --- docs/source/public_server.rst | 29 ++ jupyter_server/gateway/__init__.py | 0 jupyter_server/gateway/handlers.py | 393 +++++++++++++++++++++ jupyter_server/gateway/managers.py | 493 +++++++++++++++++++++++++++ jupyter_server/serverapp.py | 67 +++- jupyter_server/tests/test_gateway.py | 327 ++++++++++++++++++ 6 files changed, 1293 insertions(+), 16 deletions(-) create mode 100644 jupyter_server/gateway/__init__.py create mode 100644 jupyter_server/gateway/handlers.py create mode 100644 jupyter_server/gateway/managers.py create mode 100644 jupyter_server/tests/test_gateway.py diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst index 9fc36da110..7ce16ab7f8 100644 --- a/docs/source/public_server.rst +++ b/docs/source/public_server.rst @@ -343,6 +343,35 @@ single-tab mode: }); +Using a gateway server for kernel management +-------------------------------------------- + +You can now redirect the management of your kernels to a Gateway server +(i.e., `Jupyter Kernel Gateway `_ or +`Jupyter Enterprise Gateway `_) +by specifying a Gateway URL via the following command-line option: + + .. code-block:: bash + + $ jupyter notebook --gateway-url=http://my-gateway-server:8888 + +or via the environment variable: + + .. code-block:: bash + + GATEWAY_URL=http://my-gateway-server:8888 + +or in :file:`jupyter_notebook_config.py`: + + .. code-block:: python + + c.NotebookApp.gateway_url = 'http://my-gateway-server:8888' + +When provided, all kernel specifications will be retrieved from the specified Gateway server and all +kernels will be managed by that server. This option makes it possible to target kernel processes +against managed clusters while keeping the notebook's management local to the Notebook +server. + Known issues ------------ diff --git a/jupyter_server/gateway/__init__.py b/jupyter_server/gateway/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py new file mode 100644 index 0000000000..30098787e4 --- /dev/null +++ b/jupyter_server/gateway/handlers.py @@ -0,0 +1,393 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. 
+ +import os +import json +import logging +from socket import gaierror + +from ..base.handlers import APIHandler, IPythonHandler +from ..utils import url_path_join + +from tornado import gen, web +from tornado.concurrent import Future +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler, websocket_connect +from tornado.httpclient import HTTPRequest +from tornado.simple_httpclient import HTTPTimeoutError +from tornado.escape import url_escape, json_decode, utf8 + +from ipython_genutils.py3compat import cast_unicode +from jupyter_client.session import Session +from traitlets.config.configurable import LoggingConfigurable + +# Note: Although some of these are available via NotebookApp (command line), we will +# take the approach of using environment variables to enable separate sets of values +# for use with the local Notebook server (via command-line) and remote Gateway server +# (via environment variables). + +GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) +GATEWAY_HEADERS.update({ + 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) +}) +VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] + +GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') +GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') +GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') + +GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') +GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') + +# Get env variables to handle timeout of request and connection +GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) +GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) + + +class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler): + + session = None + gateway = None + kernel_id = None + + def set_default_headers(self): + """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets""" + pass + + def get_compression_options(self): + # use deflate compress websocket + return {} + + def authenticate(self): + """Run before finishing the GET request + + Extend this method to add logic that should fire before + the websocket finishes completing. 
+ """ + # authenticate the request before opening the websocket + if self.get_current_user() is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + if self.get_argument('session_id', False): + self.session.session = cast_unicode(self.get_argument('session_id')) + else: + self.log.warning("No session ID specified") + + def initialize(self): + self.log.debug("Initializing websocket connection %s", self.request.path) + self.session = Session(config=self.config) + self.gateway = GatewayWebSocketClient(gateway_url=self.kernel_manager.parent.gateway_url) + + @gen.coroutine + def get(self, kernel_id, *args, **kwargs): + self.authenticate() + self.kernel_id = cast_unicode(kernel_id, 'ascii') + super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs) + + def open(self, kernel_id, *args, **kwargs): + """Handle web socket connection open to notebook server and delegate to gateway web socket handler """ + self.gateway.on_open( + kernel_id=kernel_id, + message_callback=self.write_message, + compression_options=self.get_compression_options() + ) + + def on_message(self, message): + """Forward message to gateway web socket handler.""" + self.log.debug("Sending message to gateway: {}".format(message)) + self.gateway.on_message(message) + + def write_message(self, message, binary=False): + """Send message back to notebook client. This is called via callback from self.gateway._read_messages.""" + self.log.debug("Receiving message from gateway: {}".format(message)) + if self.ws_connection: # prevent WebSocketClosedError + super(WebSocketChannelsHandler, self).write_message(message, binary=binary) + elif self.log.isEnabledFor(logging.DEBUG): + msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message))) + self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary)) + + def on_close(self): + self.log.debug("Closing websocket connection %s", self.request.path) + self.gateway.on_close() + super(WebSocketChannelsHandler, self).on_close() + + @staticmethod + def _get_message_summary(message): + summary = [] + message_type = message['msg_type'] + summary.append('type: {}'.format(message_type)) + + if message_type == 'status': + summary.append(', state: {}'.format(message['content']['execution_state'])) + elif message_type == 'error': + summary.append(', {}:{}:{}'.format(message['content']['ename'], + message['content']['evalue'], + message['content']['traceback'])) + else: + summary.append(', ...') # don't display potentially sensitive data + + return ''.join(summary) + + +class GatewayWebSocketClient(LoggingConfigurable): + """Proxy web socket connection to a kernel/enterprise gateway.""" + + def __init__(self, **kwargs): + super(GatewayWebSocketClient, self).__init__(**kwargs) + self.gateway_url = kwargs['gateway_url'] + self.kernel_id = None + self.ws = None + self.ws_future = Future() + self.ws_future_cancelled = False + + @gen.coroutine + def _connect(self, kernel_id): + self.kernel_id = kernel_id + ws_url = url_path_join( + os.getenv('GATEWAY_WS_URL', self.gateway_url.replace('http', 'ws')), + '/api/kernels', + url_escape(kernel_id), + 'channels' + ) + self.log.info('Connecting to {}'.format(ws_url)) + parameters = { + "headers": GATEWAY_HEADERS, + "validate_cert": VALIDATE_GATEWAY_CERT, + "connect_timeout": GATEWAY_CONNECT_TIMEOUT, + "request_timeout": GATEWAY_REQUEST_TIMEOUT + } + if GATEWAY_HTTP_USER: + parameters["auth_username"] = GATEWAY_HTTP_USER + if 
GATEWAY_HTTP_PASS: + parameters["auth_password"] = GATEWAY_HTTP_PASS + if GATEWAY_CLIENT_KEY: + parameters["client_key"] = GATEWAY_CLIENT_KEY + parameters["client_cert"] = GATEWAY_CLIENT_CERT + if GATEWAY_CLIENT_CA: + parameters["ca_certs"] = GATEWAY_CLIENT_CA + + request = HTTPRequest(ws_url, **parameters) + self.ws_future = websocket_connect(request) + self.ws_future.add_done_callback(self._connection_done) + + def _connection_done(self, fut): + if not self.ws_future_cancelled: # prevent concurrent.futures._base.CancelledError + self.ws = fut.result() + self.log.debug("Connection is ready: ws: {}".format(self.ws)) + else: + self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. " + "Kernel with ID '{}' may not be terminated on Gateway: {}". + format(self.kernel_id, self.gateway_url)) + + def _disconnect(self): + if self.ws is not None: + # Close connection + self.ws.close() + elif not self.ws_future.done(): + # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally + self.ws_future.cancel() + self.ws_future_cancelled = True + self.log.debug("_disconnect: ws_future_cancelled: {}".format(self.ws_future_cancelled)) + + @gen.coroutine + def _read_messages(self, callback): + """Read messages from gateway server.""" + while True: + message = None + if not self.ws_future_cancelled: + try: + message = yield self.ws.read_message() + except Exception as e: + self.log.error("Exception reading message from websocket: {}".format(e)) # , exc_info=True) + if message is None: + break + callback(message) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open) + else: # ws cancelled - stop reading + break + + def on_open(self, kernel_id, message_callback, **kwargs): + """Web socket connection open against gateway server.""" + self._connect(kernel_id) + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._read_messages(message_callback) + ) + + def on_message(self, message): + """Send message to gateway server.""" + if self.ws is None: + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._write_message(message) + ) + else: + self._write_message(message) + + def _write_message(self, message): + """Send message to gateway server.""" + try: + if not self.ws_future_cancelled: + self.ws.write_message(message) + except Exception as e: + self.log.error("Exception writing message to websocket: {}".format(e)) # , exc_info=True) + + def on_close(self): + """Web socket closed event.""" + self._disconnect() + + +# ----------------------------------------------------------------------------- +# kernel handlers +# ----------------------------------------------------------------------------- + +class MainKernelHandler(APIHandler): + """Replace default MainKernelHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def get(self): + km = self.kernel_manager + kernels = yield gen.maybe_future(km.list_kernels()) + self.finish(json.dumps(kernels)) + + @web.authenticated + @gen.coroutine + def post(self): + km = self.kernel_manager + model = self.get_json_body() + if model is None: + model = { + 'name': km.default_kernel_name + } + else: + model.setdefault('name', km.default_kernel_name) + + kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name'])) + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + location = url_path_join(self.base_url, 
'api', 'kernels', url_escape(kernel_id)) + self.set_header('Location', location) + self.set_status(201) + self.finish(json.dumps(model)) + + +class KernelHandler(APIHandler): + """Replace default KernelHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def get(self, kernel_id): + km = self.kernel_manager + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + if model is None: + raise web.HTTPError(404, u'Kernel does not exist: %s' % kernel_id) + self.finish(json.dumps(model)) + + @web.authenticated + @gen.coroutine + def delete(self, kernel_id): + km = self.kernel_manager + yield gen.maybe_future(km.shutdown_kernel(kernel_id)) + self.set_status(204) + self.finish() + + +class KernelActionHandler(APIHandler): + """Replace default KernelActionHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def post(self, kernel_id, action): + km = self.kernel_manager + + if action == 'interrupt': + km.interrupt_kernel(kernel_id) + self.set_status(204) + + if action == 'restart': + try: + yield gen.maybe_future(km.restart_kernel(kernel_id)) + except Exception as e: + self.log.error("Exception restarting kernel", exc_info=True) + self.set_status(500) + else: + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + self.write(json.dumps(model)) + self.finish() + +# ----------------------------------------------------------------------------- +# kernel spec handlers +# ----------------------------------------------------------------------------- + + +class MainKernelSpecHandler(APIHandler): + @web.authenticated + @gen.coroutine + def get(self): + ksm = self.kernel_spec_manager + try: + kernel_specs = yield gen.maybe_future(ksm.list_kernel_specs()) + # TODO: Remove resources until we support them + for name, spec in kernel_specs['kernelspecs'].items(): + spec['resources'] = {} + self.set_header("Content-Type", 'application/json') + self.write(json.dumps(kernel_specs)) + + # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect + # or the server is not running. + # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes + # of the tree view. + except ConnectionRefusedError: + gateway_url = ksm.parent.gateway_url + self.log.error("Connection refused from Gateway server url '{}'. " + "Check to be sure the Gateway instance is running.".format(gateway_url)) + except HTTPTimeoutError: + # This can occur if the host is valid (e.g., foo.com) but there's nothing there. + gateway_url = ksm.parent.gateway_url + self.log.error("Timeout error attempting to connect to Gateway server url '{}'. " + "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + except gaierror as e: + gateway_url = ksm.parent.gateway_url + self.log.error("The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. 
" + "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + + self.finish() + + +class KernelSpecHandler(APIHandler): + @web.authenticated + @gen.coroutine + def get(self, kernel_name): + ksm = self.kernel_spec_manager + kernel_spec = yield ksm.get_kernel_spec(kernel_name) + if kernel_spec is None: + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) + # TODO: Remove resources until we support them + kernel_spec['resources'] = {} + self.set_header("Content-Type", 'application/json') + self.finish(json.dumps(kernel_spec)) + +# ----------------------------------------------------------------------------- +# URL to handler mappings +# ----------------------------------------------------------------------------- + + +from ..services.kernels.handlers import _kernel_id_regex, _kernel_action_regex +from ..services.kernelspecs.handlers import kernel_name_regex + +default_handlers = [ + (r"/api/kernels", MainKernelHandler), + (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler), + (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler), + (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler), + (r"/api/kernelspecs", MainKernelSpecHandler), + (r"/api/kernelspecs/%s" % kernel_name_regex, KernelSpecHandler), + # TODO: support kernel spec resources + # (r"/kernelspecs/%s/(?P.*)" % kernel_name_regex, KernelSpecResourceHandler), + +] diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py new file mode 100644 index 0000000000..31dcba9cdc --- /dev/null +++ b/jupyter_server/gateway/managers.py @@ -0,0 +1,493 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import os +import json + +from tornado import gen +from tornado.escape import json_encode, json_decode, url_escape +from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError + +from ..services.kernels.kernelmanager import MappingKernelManager +from ..services.sessions.sessionmanager import SessionManager + +from jupyter_client.kernelspec import KernelSpecManager +from ..utils import url_path_join + +from traitlets import Instance, Unicode, default + +# Note: Although some of these are available via NotebookApp (command line), we will +# take the approach of using environment variables to enable separate sets of values +# for use with the local Notebook server (via command-line) and remote Gateway server +# (via environment variables). 
+ +GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) +GATEWAY_HEADERS.update({ + 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) +}) +VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] + +GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') +GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') +GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') + +GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') +GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') + +GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) +GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) + + +def load_connection_args(**kwargs): + + if GATEWAY_CLIENT_CERT: + kwargs["client_key"] = kwargs.get("client_key", GATEWAY_CLIENT_KEY) + kwargs["client_cert"] = kwargs.get("client_cert", GATEWAY_CLIENT_CERT) + if GATEWAY_CLIENT_CA: + kwargs["ca_certs"] = kwargs.get("ca_certs", GATEWAY_CLIENT_CA) + kwargs['connect_timeout'] = kwargs.get('connect_timeout', GATEWAY_CONNECT_TIMEOUT) + kwargs['request_timeout'] = kwargs.get('request_timeout', GATEWAY_REQUEST_TIMEOUT) + kwargs['headers'] = kwargs.get('headers', GATEWAY_HEADERS) + kwargs['validate_cert'] = kwargs.get('validate_cert', VALIDATE_GATEWAY_CERT) + if GATEWAY_HTTP_USER: + kwargs['auth_username'] = kwargs.get('auth_username', GATEWAY_HTTP_USER) + if GATEWAY_HTTP_PASS: + kwargs['auth_password'] = kwargs.get('auth_password', GATEWAY_HTTP_PASS) + + return kwargs + + +@gen.coroutine +def fetch_gateway(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint.""" + client = AsyncHTTPClient() + + kwargs = load_connection_args(**kwargs) + + response = yield client.fetch(endpoint, **kwargs) + raise gen.Return(response) + + +class GatewayKernelManager(MappingKernelManager): + """Kernel manager that supports remote kernels hosted by Jupyter + kernel gateway.""" + + kernels_endpoint_env = 'GATEWAY_KERNELS_ENDPOINT' + kernels_endpoint = Unicode(config=True, + help="""The gateway API endpoint for accessing kernel resources (GATEWAY_KERNELS_ENDPOINT env var)""") + + @default('kernels_endpoint') + def kernels_endpoint_default(self): + return os.getenv(self.kernels_endpoint_env, '/api/kernels') + + # We'll maintain our own set of kernel ids + _kernels = {} + + def __init__(self, **kwargs): + super(GatewayKernelManager, self).__init__(**kwargs) + self.gateway_url = self.parent.gateway_url + + def __contains__(self, kernel_id): + return kernel_id in self._kernels + + def remove_kernel(self, kernel_id): + """Complete override since we want to be more tolerant of missing keys """ + try: + return self._kernels.pop(kernel_id) + except KeyError: + pass + + def _get_kernel_endpoint_url(self, kernel_id=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_id: kernel UUID (optional) + """ + if kernel_id: + return url_path_join(self.gateway_url, self.kernels_endpoint, url_escape(str(kernel_id))) + + return url_path_join(self.gateway_url, self.kernels_endpoint) + + @gen.coroutine + def start_kernel(self, kernel_id=None, path=None, **kwargs): + """Start a kernel for a session and return its kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid to associate the new kernel with. If this + is not None, this kernel will be persistent whenever it is + requested. + path : API path + The API path (unicode, '/' delimited) for the cwd. + Will be transformed to an OS path relative to root_dir. 
+ """ + self.log.info('Request start kernel: kernel_id=%s, path="%s"', kernel_id, path) + + if kernel_id is None: + kernel_name = kwargs.get('kernel_name', 'python3') + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request new kernel at: %s" % kernel_url) + + kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') + or k in os.environ.get('GATEWAY_ENV_WHITELIST', '').split(",")} + json_body = json_encode({'name': kernel_name, 'env': kernel_env}) + + response = yield fetch_gateway(kernel_url, method='POST', body=json_body) + kernel = json_decode(response.body) + kernel_id = kernel['id'] + self.log.info("Kernel started: %s" % kernel_id) + else: + kernel = yield self.get_kernel(kernel_id) + kernel_id = kernel['id'] + self.log.info("Using existing kernel: %s" % kernel_id) + + self._kernels[kernel_id] = kernel + raise gen.Return(kernel_id) + + @gen.coroutine + def get_kernel(self, kernel_id=None, **kwargs): + """Get kernel for kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request kernel at: %s" % kernel_url) + try: + response = yield fetch_gateway(kernel_url, method='GET') + except HTTPError as error: + if error.code == 404: + self.log.warn("Kernel not found at: %s" % kernel_url) + self.remove_kernel(kernel_id) + kernel = None + else: + raise + else: + kernel = json_decode(response.body) + self._kernels[kernel_id] = kernel + self.log.info("Kernel retrieved: %s" % kernel) + raise gen.Return(kernel) + + @gen.coroutine + def kernel_model(self, kernel_id): + """Return a dictionary of kernel information described in the + JSON standard model. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id) + model = yield self.get_kernel(kernel_id) + raise gen.Return(model) + + @gen.coroutine + def list_kernels(self, **kwargs): + """Get a list of kernels.""" + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request list kernels: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='GET') + kernels = json_decode(response.body) + self._kernels = {x['id']:x for x in kernels} + raise gen.Return(kernels) + + @gen.coroutine + def shutdown_kernel(self, kernel_id): + """Shutdown a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to shutdown. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request shutdown kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='DELETE') + self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason) + self.remove_kernel(kernel_id) + + @gen.coroutine + def restart_kernel(self, kernel_id, now=False, **kwargs): + """Restart a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to restart. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart' + self.log.debug("Request restart kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Restart kernel response: %d %s", response.code, response.reason) + + @gen.coroutine + def interrupt_kernel(self, kernel_id, **kwargs): + """Interrupt a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to interrupt. 
+ """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt' + self.log.debug("Request interrupt kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason) + + def shutdown_all(self): + """Shutdown all kernels.""" + # Note: We have to make this sync because the NotebookApp does not wait for async. + kwargs = {'method': 'DELETE'} + kwargs = load_connection_args(**kwargs) + client = HTTPClient() + for kernel_id in self._kernels.keys(): + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request delete kernel at: %s", kernel_url) + try: + response = client.fetch(kernel_url, **kwargs) + except HTTPError: + pass + self.log.debug("Delete kernel response: %d %s", response.code, response.reason) + self.remove_kernel(kernel_id) + client.close() + + +class GatewayKernelSpecManager(KernelSpecManager): + + kernelspecs_endpoint_env = 'GATEWAY_KERNELSPECS_ENDPOINT' + kernelspecs_endpoint = Unicode(config=True, + help="""The kernel gateway API endpoint for accessing kernelspecs + (GATEWAY_KERNELSPECS_ENDPOINT env var)""") + + @default('kernelspecs_endpoint') + def kernelspecs_endpoint_default(self): + return os.getenv(self.kernelspecs_endpoint_env, '/api/kernelspecs') + + def __init__(self, **kwargs): + super(GatewayKernelSpecManager, self).__init__(**kwargs) + self.gateway_url = self.parent.gateway_url + + def _get_kernelspecs_endpoint_url(self, kernel_name=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_name: kernel name (optional) + """ + if kernel_name: + return url_path_join(self.gateway_url, self.kernelspecs_endpoint, url_escape(kernel_name)) + + return url_path_join(self.gateway_url, self.kernelspecs_endpoint) + + @gen.coroutine + def list_kernel_specs(self): + """Get a list of kernel specs.""" + kernel_spec_url = self._get_kernelspecs_endpoint_url() + self.log.debug("Request list kernel specs at: %s", kernel_spec_url) + response = yield fetch_gateway(kernel_spec_url, method='GET') + kernel_specs = json_decode(response.body) + raise gen.Return(kernel_specs) + + @gen.coroutine + def get_kernel_spec(self, kernel_name, **kwargs): + """Get kernel spec for kernel_name. + + Parameters + ---------- + kernel_name : str + The name of the kernel. + """ + kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name)) + self.log.debug("Request kernel spec at: %s" % kernel_spec_url) + try: + response = yield fetch_gateway(kernel_spec_url, method='GET') + except HTTPError as error: + if error.code == 404: + self.log.warn("Kernel spec not found at: %s" % kernel_spec_url) + kernel_spec = None + else: + raise + else: + kernel_spec = json_decode(response.body) + raise gen.Return(kernel_spec) + + +class GatewaySessionManager(SessionManager): + kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager') + + @gen.coroutine + def create_session(self, path=None, name=None, type=None, + kernel_name=None, kernel_id=None): + """Creates a session and returns its model. + + Overrides base class method to turn into an async operation. 
+ """ + session_id = self.new_session_id() + + kernel = None + if kernel_id is not None: + # This is now an async operation + kernel = yield self.kernel_manager.get_kernel(kernel_id) + + if kernel is not None: + pass + else: + kernel_id = yield self.start_kernel_for_session( + session_id, path, name, type, kernel_name, + ) + + result = yield self.save_session( + session_id, path=path, name=name, type=type, kernel_id=kernel_id, + ) + raise gen.Return(result) + + @gen.coroutine + def save_session(self, session_id, path=None, name=None, type=None, + kernel_id=None): + """Saves the items for the session with the given session_id + + Given a session_id (and any other of the arguments), this method + creates a row in the sqlite session database that holds the information + for a session. + + Parameters + ---------- + session_id : str + uuid for the session; this method must be given a session_id + path : str + the path for the given notebook + kernel_id : str + a uuid for the kernel associated with this session + + Returns + ------- + model : dict + a dictionary of the session model + """ + # This is now an async operation + session = yield super(GatewaySessionManager, self).save_session( + session_id, path=path, name=name, type=type, kernel_id=kernel_id + ) + raise gen.Return(session) + + @gen.coroutine + def get_session(self, **kwargs): + """Returns the model for a particular session. + + Takes a keyword argument and searches for the value in the session + database, then returns the rest of the session's info. + + Overrides base class method to turn into an async operation. + + Parameters + ---------- + **kwargs : keyword argument + must be given one of the keywords and values from the session database + (i.e. session_id, path, kernel_id) + + Returns + ------- + model : dict + returns a dictionary that includes all the information from the + session described by the kwarg. + """ + # This is now an async operation + session = yield super(GatewaySessionManager, self).get_session(**kwargs) + raise gen.Return(session) + + @gen.coroutine + def update_session(self, session_id, **kwargs): + """Updates the values in the session database. + + Changes the values of the session with the given session_id + with the values from the keyword arguments. + + Overrides base class method to turn into an async operation. + + Parameters + ---------- + session_id : str + a uuid that identifies a session in the sqlite3 database + **kwargs : str + the key must correspond to a column title in session database, + and the value replaces the current value in the session + with session_id. + """ + # This is now an async operation + session = yield self.get_session(session_id=session_id) + + if not kwargs: + # no changes + return + + sets = [] + for column in kwargs.keys(): + if column not in self._columns: + raise TypeError("No such column: %r" % column) + sets.append("%s=?" % column) + query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets)) + self.cursor.execute(query, list(kwargs.values()) + [session_id]) + + @gen.coroutine + def row_to_model(self, row): + """Takes sqlite database session row and turns it into a dictionary. + + Overrides base class method to turn into an async operation. + """ + # Retrieve kernel for session, which is now an async operation + kernel = yield self.kernel_manager.get_kernel(row['kernel_id']) + if kernel is None: + # The kernel was killed or died without deleting the session. + # We can't use delete_session here because that tries to find + # and shut down the kernel. 
+ self.cursor.execute("DELETE FROM session WHERE session_id=?", + (row['session_id'],)) + raise KeyError + + model = { + 'id': row['session_id'], + 'path': row['path'], + 'name': row['name'], + 'type': row['type'], + 'kernel': kernel + } + if row['type'] == 'notebook': # Provide the deprecated API. + model['notebook'] = {'path': row['path'], 'name': row['name']} + + raise gen.Return(model) + + @gen.coroutine + def list_sessions(self): + """Returns a list of dictionaries containing all the information from + the session database. + + Overrides base class method to turn into an async operation. + """ + c = self.cursor.execute("SELECT * FROM session") + result = [] + # We need to use fetchall() here, because row_to_model can delete rows, + # which messes up the cursor if we're iterating over rows. + for row in c.fetchall(): + try: + # This is now an async operation + model = yield self.row_to_model(row) + result.append(model) + except KeyError: + pass + raise gen.Return(result) + + @gen.coroutine + def delete_session(self, session_id): + """Deletes the row in the session database with given session_id. + + Overrides base class method to turn into an async operation. + """ + # This is now an async operation + session = yield self.get_session(session_id=session_id) + yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id'])) + self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index d7472d2693..f43f84ba81 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -74,6 +74,9 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager +from .gateway.managers import GatewayKernelManager +from .gateway.managers import GatewayKernelSpecManager +from .gateway.managers import GatewaySessionManager from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -119,10 +122,10 @@ files=['jupyter_server.files.handlers'], kernels=['jupyter_server.services.kernels.handlers'], kernelspecs=[ - 'jupyter_server.kernelspecs.handlers', + 'jupyter_server.kernelspecs.handlers', 'jupyter_server.services.kernelspecs.handlers'], nbconvert=[ - 'jupyter_server.nbconvert.handlers', + 'jupyter_server.nbconvert.handlers', 'jupyter_server.services.nbconvert.handlers'], security=['jupyter_server.services.security.handlers'], sessions=['jupyter_server.services.sessions.handlers'], @@ -159,14 +162,14 @@ class ServerWebApplication(web.Application): def __init__(self, jupyter_app, default_services, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, - base_url, default_url, settings_overrides, jinja_env_options): - + base_url, default_url, settings_overrides, jinja_env_options, + gateway_url): settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, - default_url, settings_overrides, jinja_env_options) + default_url, settings_overrides, jinja_env_options, gateway_url) handlers = self.init_handlers(default_services, settings) super(ServerWebApplication, self).__init__(handlers, **settings) @@ -175,7 +178,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, default_url, 
settings_overrides, - jinja_env_options=None): + jinja_env_options=None, gateway_url=None): _template_path = settings_overrides.get( "template_path", @@ -270,6 +273,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, server_root_dir=root_dir, jinja2_env=env, terminals_available=False, # Set later if terminals are available + gateway_url=gateway_url, ) # allow custom overrides for the tornado web app. @@ -289,8 +293,8 @@ def init_handlers(self, default_services, settings): handlers.extend([(r"/login", settings['login_handler_class'])]) handlers.extend([(r"/logout", settings['logout_handler_class'])]) - # Load default services. Raise exception if service not - # found in JUPYTER_SERVICE_HANLDERS. + # Load default services. Raise exception if service not + # found in JUPYTER_SERVICE_HANLDERS. for service in default_services: if service in JUPYTER_SERVICE_HANDLERS: locations = JUPYTER_SERVICE_HANDLERS[service] @@ -306,6 +310,13 @@ def init_handlers(self, default_services, settings): # Add extra handlers from contents manager. handlers.extend(settings['contents_manager'].get_extra_handlers()) + # If gateway server is configured, replace appropriate handlers to perform redirection + if settings['gateway_url']: + handlers.extend(load_handlers('jupyter_server.gateway.handlers')) + else: + handlers.extend(load_handlers('jupyter_server.services.kernels.handlers')) + handlers.extend(load_handlers('jupyter_server.services.kernelspecs.handlers')) + handlers.append( (r"/custom/(.*)", FileFindHandler, { 'path': settings['static_custom_path'], @@ -524,6 +535,7 @@ def start(self): 'notebook-dir': 'ServerApp.root_dir', 'browser': 'ServerApp.browser', 'pylab': 'ServerApp.pylab', + 'gateway-url': 'ServerApp.gateway_url', }) #----------------------------------------------------------------------------- @@ -542,9 +554,9 @@ class ServerApp(JupyterApp): flags = flags classes = [ - KernelManager, Session, MappingKernelManager, + KernelManager, Session, MappingKernelManager, KernelSpecManager, ContentsManager, FileContentsManager, NotebookNotary, - KernelSpecManager, + GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, ] flags = Dict(flags) aliases = Dict(aliases) @@ -561,12 +573,12 @@ class ServerApp(JupyterApp): default_services = ( 'api', 'auth', - 'config', - 'contents', + 'config', + 'contents', 'edit', - 'files', - 'kernels', - 'kernelspecs', + 'files', + 'kernels', + 'kernelspecs', 'nbconvert', 'security', 'sessions', @@ -1193,6 +1205,20 @@ def _update_server_extensions(self, change): is not available. """)) + gateway_url_env = 'GATEWAY_URL' + + @default('gateway_url') + def gateway_url_default(self): + return os.getenv(self.gateway_url_env) + + gateway_url = Unicode(default_value=None, allow_none=True, config=True, + help="""The url of the Kernel or Enterprise Gateway server where + kernel specifications are defined and kernel management takes place. + If defined, this Notebook server acts as a proxy for all kernel + management and kernel specification retrieval. 
(GATEWAY_URL env var) + """ + ) + def parse_command_line(self, argv=None): super(ServerApp, self).parse_command_line(argv) @@ -1214,6 +1240,13 @@ def parse_command_line(self, argv=None): self.update_config(c) def init_configurables(self): + + # If gateway server is configured, replace appropriate managers to perform redirection + if self.gateway_url: + self.kernel_manager_class = 'jupyter_server.gateway.managers.GatewayKernelManager' + self.session_manager_class = 'jupyter_server.gateway.managers.GatewaySessionManager' + self.kernel_spec_manager_class = 'jupyter_server.gateway.managers.GatewayKernelSpecManager' + self.kernel_spec_manager = self.kernel_spec_manager_class( parent=self, ) @@ -1279,7 +1312,7 @@ def init_webapp(self): self.session_manager, self.kernel_spec_manager, self.config_manager, self.extra_services, self.log, self.base_url, self.default_url, self.tornado_settings, - self.jinja_environment_options, + self.jinja_environment_options, self.gateway_url, ) ssl_options = self.ssl_options if self.certfile: @@ -1557,6 +1590,8 @@ def notebook_info(self, kernel_count=True): # Format the info so that the URL fits on a single line in 80 char display info += _("Jupyter Server {version} is running at:\n{url}". format(version=ServerApp.version, url=self.display_url)) + if self.gateway_url: + info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_url return info def server_info(self): diff --git a/jupyter_server/tests/test_gateway.py b/jupyter_server/tests/test_gateway.py new file mode 100644 index 0000000000..385c1f1cd5 --- /dev/null +++ b/jupyter_server/tests/test_gateway.py @@ -0,0 +1,327 @@ +"""Test Gateway""" +import os +import json +import uuid +from datetime import datetime +from tornado import gen +from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError +from traitlets.config import Config +from .launchnotebook import NotebookTestBase + +try: + from unittest.mock import patch, Mock +except ImportError: + from mock import patch, Mock # py2 + +try: + from io import StringIO +except ImportError: + import StringIO + +import nose.tools as nt + + +def generate_kernelspec(name): + argv_stanza = ['python', '-m', 'ipykernel_launcher', '-f', '{connection_file}'] + spec_stanza = {'spec': {'argv': argv_stanza, 'env': {}, 'display_name': name, 'language': 'python', 'interrupt_mode': 'signal', 'metadata': {}}} + kernelspec_stanza = {name: {'name': name, 'spec': spec_stanza, 'resources': {}}} + return kernelspec_stanza + + +# We'll mock up two kernelspecs - kspec_foo and kspec_bar +kernelspecs = {'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}} + + +# maintain a dictionary of expected running kernels. Key = kernel_id, Value = model. +running_kernels = dict() + + +def generate_model(name): + """Generate a mocked kernel model. 
Caller is responsible for adding model to running_kernels dictionary.""" + dt = datetime.utcnow().isoformat() + 'Z' + kernel_id = str(uuid.uuid4()) + model = {'id': kernel_id, 'name': name, 'last_activity': str(dt), 'execution_state': 'idle', 'connections': 1} + return model + + +@gen.coroutine +def mock_fetch_gateway(url, **kwargs): + method = 'GET' + if kwargs['method']: + method = kwargs['method'] + + request = HTTPRequest(url=url, **kwargs) + + endpoint = str(url) + + # Fetch all kernelspecs + if endpoint.endswith('/api/kernelspecs') and method == 'GET': + response_buf = StringIO(json.dumps(kernelspecs)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Fetch named kernelspec + if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET': + requested_kernelspec = endpoint.rpartition('/')[2] + kspecs = kernelspecs.get('kernelspecs') + if requested_kernelspec in kspecs: + response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec) + + # Create kernel + if endpoint.endswith('/api/kernels') and method == 'POST': + json_body = json.loads(kwargs['body']) + name = json_body.get('name') + env = json_body.get('env') + kspec_name = env.get('KERNEL_KSPEC_NAME') + nt.assert_equal(name, kspec_name) # Ensure that KERNEL_ env values get propagated + model = generate_model(name) + running_kernels[model.get('id')] = model # Register model as a running kernel + response_buf = StringIO(json.dumps(model)) + response = yield gen.maybe_future(HTTPResponse(request, 201, buffer=response_buf)) + raise gen.Return(response) + + # Fetch list of running kernels + if endpoint.endswith('/api/kernels') and method == 'GET': + kernels = [] + for kernel_id in running_kernels.keys(): + model = running_kernels.get(kernel_id) + kernels.append(model) + response_buf = StringIO(json.dumps(kernels)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Interrupt or restart existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST': + requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/') + + if action == 'interrupt': + if requested_kernel_id in running_kernels: + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + elif action == 'restart': + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 204, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + else: + raise HTTPError(404, message='Bad action detected: %s' % action) + + # Shutdown existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE': + requested_kernel_id = endpoint.rpartition('/')[2] + running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + + # Fetch existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET': + requested_kernel_id = 
endpoint.rpartition('/')[2] + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + + +mocked_gateway = patch('jupyter_server.gateway.managers.fetch_gateway', mock_fetch_gateway) + + +class TestGateway(NotebookTestBase): + + @classmethod + def setup_class(cls): + cls.config = Config() + cls.config.NotebookApp.gateway_url = 'http://mock-gateway-server:8889' + super(TestGateway, cls).setup_class() + + def test_gateway_class_mappings(self): + # Ensure appropriate class mappings are in place. + nt.assert_equal(self.notebook.kernel_manager_class.__name__, 'GatewayKernelManager') + nt.assert_equal(self.notebook.session_manager_class.__name__, 'GatewaySessionManager') + nt.assert_equal(self.notebook.kernel_spec_manager_class.__name__, 'GatewayKernelSpecManager') + + def test_gateway_get_kernelspecs(self): + # Validate that kernelspecs come from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs') + self.assertEqual(response.status_code, 200) + content = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kspecs = content.get('kernelspecs') + self.assertEqual(len(kspecs), 2) + self.assertEqual(kspecs.get('kspec_bar').get('kspec_bar').get('name'), 'kspec_bar') + + def test_gateway_get_named_kernelspec(self): + # Validate that a specific kernelspec can be retrieved from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs/kspec_foo') + self.assertEqual(response.status_code, 200) + content = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kspec_foo = content.get('kspec_foo') + self.assertEqual(kspec_foo.get('name'), 'kspec_foo') + + response = self.request('GET', '/api/kernelspecs/no_such_spec') + self.assertEqual(response.status_code, 404) + + def test_gateway_session_lifecycle(self): + # Validate session lifecycle functions; create, interrupt, restart and delete. + + # create + session_id, kernel_id = self.create_session('kspec_foo') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_session(session_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def test_gateway_kernel_lifecycle(self): + # Validate kernel lifecycle functions; create, interrupt, restart and delete. + + # create + kernel_id = self.create_kernel('kspec_bar') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_kernel(kernel_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def create_session(self, kernel_name): + """Creates a session for a kernel. The session is created against the notebook server + which then uses the gateway for kernel management. 
+ """ + with mocked_gateway: + nb_path = os.path.join(self.notebook_dir, 'testgw.ipynb') + kwargs = dict() + kwargs['json'] = {'path': nb_path, 'type': 'notebook', 'kernel': {'name': kernel_name}} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + # Create the kernel... (also tests get_kernel) + response = self.request('POST', '/api/sessions', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(model.get('path'), nb_path) + kernel_id = model.get('kernel').get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('kernel').get('name'), running_kernel.get('name')) + session_id = model.get('id') + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return session_id, kernel_id + + def delete_session(self, session_id): + """Deletes a session corresponding to the given session id. + """ + with mocked_gateway: + # Delete the session (and kernel) + response = self.request('DELETE', '/api/sessions/' + session_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def is_kernel_running(self, kernel_id): + """Issues request to get the set of running kernels + """ + with mocked_gateway: + # Get list of running kernels + response = self.request('GET', '/api/kernels') + self.assertEqual(response.status_code, 200) + kernels = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(len(kernels), len(running_kernels)) + for model in kernels: + if model.get('id') == kernel_id: + return True + return False + + def create_kernel(self, kernel_name): + """Issues request to retart the given kernel + """ + with mocked_gateway: + kwargs = dict() + kwargs['json'] = {'name': kernel_name} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + response = self.request('POST', '/api/kernels', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), kernel_name) + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return kernel_id + + def interrupt_kernel(self, kernel_id): + """Issues request to interrupt the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/interrupt') + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def restart_kernel(self, kernel_id): + """Issues request to retart the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/restart') + self.assertEqual(response.status_code, 200) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + restarted_kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. 
+ running_kernel = running_kernels.get(restarted_kernel_id) + self.assertEqual(restarted_kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), running_kernel.get('name')) + + def delete_kernel(self, kernel_id): + """Deletes kernel corresponding to the given kernel id. + """ + with mocked_gateway: + # Delete the kernel + response = self.request('DELETE', '/api/kernels/' + kernel_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') +
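The embedded gateway handlers and managers above read their connection settings from GATEWAY_* environment variables at import time (auth token, TLS material, HTTP credentials, timeouts, and the whitelist of environment variables forwarded at kernel start). The sketch below is illustrative only: the host, token, timeout, and whitelist values are placeholders, while the variable names and the 20-second timeout defaults come from the code in this change.

    # Token the server sends to the Gateway as "Authorization: token ..."
    export GATEWAY_AUTH_TOKEN=my-gateway-token

    # Optional tuning; both timeouts default to 20.0 seconds
    export GATEWAY_CONNECT_TIMEOUT=30.0
    export GATEWAY_REQUEST_TIMEOUT=30.0

    # Optional: extra environment variables to forward when starting kernels
    # (KERNEL_* variables are always forwarded)
    export GATEWAY_ENV_WHITELIST=SPARK_HOME,HADOOP_CONF_DIR

    # Equivalent to the --gateway-url command-line option
    export GATEWAY_URL=http://my-gateway-server:8888
    jupyter notebook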