From 060a2b9aea45e11404bce57fbd80017a2a34561c Mon Sep 17 00:00:00 2001 From: Kevin Bates Date: Fri, 19 Oct 2018 16:08:46 -0700 Subject: [PATCH 1/3] Embed NB2KG into Notebook server This change alleviates a significant pain-point for consumers of Jupyter Kernel and Enterprise Gateway projects by embedding the few classes defined in the NB2KG server extension directly into the Notebook server. All code resides in a separate gateway directory and the 'extension' is enabled via a new configuration option `--gateway-url`. Renamed classes from those used in standard NB2KG code so that Notebook servers using the existing NB2KG extension will still work. Added test_gateway.py to exercise overridden methods. It does this by mocking the call that issues requests to the gateway server. Updated the _Running a notebook server_ topic to include a description of this feature. --- docs/source/public_server.rst | 29 ++ notebook/gateway/__init__.py | 0 notebook/gateway/handlers.py | 393 ++++++++++++++++++++++++++ notebook/gateway/managers.py | 493 +++++++++++++++++++++++++++++++++ notebook/notebookapp.py | 55 +++- notebook/tests/test_gateway.py | 327 ++++++++++++++++++++++ 6 files changed, 1287 insertions(+), 10 deletions(-) create mode 100644 notebook/gateway/__init__.py create mode 100644 notebook/gateway/handlers.py create mode 100644 notebook/gateway/managers.py create mode 100644 notebook/tests/test_gateway.py diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst index 3796a2a4fb..079916441a 100644 --- a/docs/source/public_server.rst +++ b/docs/source/public_server.rst @@ -343,6 +343,35 @@ single-tab mode: }); +Using a gateway server for kernel management +-------------------------------------------- + +You are now able to redirect the management of your kernels to a Gateway Server +(i.e., `Jupyter Kernel Gateway `_ or +`Jupyter Enterprise Gateway `_) +simply by specifying a Gateway url via the following command-line option: + + .. code-block:: bash + + $ jupyter notebook --gateway-url=http://my-gateway-server:8888 + +the environment: + + .. code-block:: bash + + GATEWAY_URL=http://my-gateway-server:8888 + +or in :file:`jupyter_notebook_config.py`: + + .. code-block:: python + + c.NotebookApp.gateway_url = http://my-gateway-server:8888 + +When provided, all kernel specifications will be retrieved from the specified Gateway server and all +kernels will be managed by that server. This option enables the ability to target kernel processes +against managed clusters while allowing for the notebook's management to remain local to the Notebook +server. + Known issues ------------ diff --git a/notebook/gateway/__init__.py b/notebook/gateway/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/notebook/gateway/handlers.py b/notebook/gateway/handlers.py new file mode 100644 index 0000000000..30098787e4 --- /dev/null +++ b/notebook/gateway/handlers.py @@ -0,0 +1,393 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import os +import json +import logging +from socket import gaierror + +from ..base.handlers import APIHandler, IPythonHandler +from ..utils import url_path_join + +from tornado import gen, web +from tornado.concurrent import Future +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler, websocket_connect +from tornado.httpclient import HTTPRequest +from tornado.simple_httpclient import HTTPTimeoutError +from tornado.escape import url_escape, json_decode, utf8 + +from ipython_genutils.py3compat import cast_unicode +from jupyter_client.session import Session +from traitlets.config.configurable import LoggingConfigurable + +# Note: Although some of these are available via NotebookApp (command line), we will +# take the approach of using environment variables to enable separate sets of values +# for use with the local Notebook server (via command-line) and remote Gateway server +# (via environment variables). + +GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) +GATEWAY_HEADERS.update({ + 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) +}) +VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] + +GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') +GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') +GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') + +GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') +GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') + +# Get env variables to handle timeout of request and connection +GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) +GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) + + +class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler): + + session = None + gateway = None + kernel_id = None + + def set_default_headers(self): + """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets""" + pass + + def get_compression_options(self): + # use deflate compress websocket + return {} + + def authenticate(self): + """Run before finishing the GET request + + Extend this method to add logic that should fire before + the websocket finishes completing. + """ + # authenticate the request before opening the websocket + if self.get_current_user() is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + if self.get_argument('session_id', False): + self.session.session = cast_unicode(self.get_argument('session_id')) + else: + self.log.warning("No session ID specified") + + def initialize(self): + self.log.debug("Initializing websocket connection %s", self.request.path) + self.session = Session(config=self.config) + self.gateway = GatewayWebSocketClient(gateway_url=self.kernel_manager.parent.gateway_url) + + @gen.coroutine + def get(self, kernel_id, *args, **kwargs): + self.authenticate() + self.kernel_id = cast_unicode(kernel_id, 'ascii') + super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs) + + def open(self, kernel_id, *args, **kwargs): + """Handle web socket connection open to notebook server and delegate to gateway web socket handler """ + self.gateway.on_open( + kernel_id=kernel_id, + message_callback=self.write_message, + compression_options=self.get_compression_options() + ) + + def on_message(self, message): + """Forward message to gateway web socket handler.""" + self.log.debug("Sending message to gateway: {}".format(message)) + self.gateway.on_message(message) + + def write_message(self, message, binary=False): + """Send message back to notebook client. This is called via callback from self.gateway._read_messages.""" + self.log.debug("Receiving message from gateway: {}".format(message)) + if self.ws_connection: # prevent WebSocketClosedError + super(WebSocketChannelsHandler, self).write_message(message, binary=binary) + elif self.log.isEnabledFor(logging.DEBUG): + msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message))) + self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary)) + + def on_close(self): + self.log.debug("Closing websocket connection %s", self.request.path) + self.gateway.on_close() + super(WebSocketChannelsHandler, self).on_close() + + @staticmethod + def _get_message_summary(message): + summary = [] + message_type = message['msg_type'] + summary.append('type: {}'.format(message_type)) + + if message_type == 'status': + summary.append(', state: {}'.format(message['content']['execution_state'])) + elif message_type == 'error': + summary.append(', {}:{}:{}'.format(message['content']['ename'], + message['content']['evalue'], + message['content']['traceback'])) + else: + summary.append(', ...') # don't display potentially sensitive data + + return ''.join(summary) + + +class GatewayWebSocketClient(LoggingConfigurable): + """Proxy web socket connection to a kernel/enterprise gateway.""" + + def __init__(self, **kwargs): + super(GatewayWebSocketClient, self).__init__(**kwargs) + self.gateway_url = kwargs['gateway_url'] + self.kernel_id = None + self.ws = None + self.ws_future = Future() + self.ws_future_cancelled = False + + @gen.coroutine + def _connect(self, kernel_id): + self.kernel_id = kernel_id + ws_url = url_path_join( + os.getenv('GATEWAY_WS_URL', self.gateway_url.replace('http', 'ws')), + '/api/kernels', + url_escape(kernel_id), + 'channels' + ) + self.log.info('Connecting to {}'.format(ws_url)) + parameters = { + "headers": GATEWAY_HEADERS, + "validate_cert": VALIDATE_GATEWAY_CERT, + "connect_timeout": GATEWAY_CONNECT_TIMEOUT, + "request_timeout": GATEWAY_REQUEST_TIMEOUT + } + if GATEWAY_HTTP_USER: + parameters["auth_username"] = GATEWAY_HTTP_USER + if GATEWAY_HTTP_PASS: + parameters["auth_password"] = GATEWAY_HTTP_PASS + if GATEWAY_CLIENT_KEY: + parameters["client_key"] = GATEWAY_CLIENT_KEY + parameters["client_cert"] = GATEWAY_CLIENT_CERT + if GATEWAY_CLIENT_CA: + parameters["ca_certs"] = GATEWAY_CLIENT_CA + + request = HTTPRequest(ws_url, **parameters) + self.ws_future = websocket_connect(request) + self.ws_future.add_done_callback(self._connection_done) + + def _connection_done(self, fut): + if not self.ws_future_cancelled: # prevent concurrent.futures._base.CancelledError + self.ws = fut.result() + self.log.debug("Connection is ready: ws: {}".format(self.ws)) + else: + self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. " + "Kernel with ID '{}' may not be terminated on Gateway: {}". + format(self.kernel_id, self.gateway_url)) + + def _disconnect(self): + if self.ws is not None: + # Close connection + self.ws.close() + elif not self.ws_future.done(): + # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally + self.ws_future.cancel() + self.ws_future_cancelled = True + self.log.debug("_disconnect: ws_future_cancelled: {}".format(self.ws_future_cancelled)) + + @gen.coroutine + def _read_messages(self, callback): + """Read messages from gateway server.""" + while True: + message = None + if not self.ws_future_cancelled: + try: + message = yield self.ws.read_message() + except Exception as e: + self.log.error("Exception reading message from websocket: {}".format(e)) # , exc_info=True) + if message is None: + break + callback(message) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open) + else: # ws cancelled - stop reading + break + + def on_open(self, kernel_id, message_callback, **kwargs): + """Web socket connection open against gateway server.""" + self._connect(kernel_id) + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._read_messages(message_callback) + ) + + def on_message(self, message): + """Send message to gateway server.""" + if self.ws is None: + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._write_message(message) + ) + else: + self._write_message(message) + + def _write_message(self, message): + """Send message to gateway server.""" + try: + if not self.ws_future_cancelled: + self.ws.write_message(message) + except Exception as e: + self.log.error("Exception writing message to websocket: {}".format(e)) # , exc_info=True) + + def on_close(self): + """Web socket closed event.""" + self._disconnect() + + +# ----------------------------------------------------------------------------- +# kernel handlers +# ----------------------------------------------------------------------------- + +class MainKernelHandler(APIHandler): + """Replace default MainKernelHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def get(self): + km = self.kernel_manager + kernels = yield gen.maybe_future(km.list_kernels()) + self.finish(json.dumps(kernels)) + + @web.authenticated + @gen.coroutine + def post(self): + km = self.kernel_manager + model = self.get_json_body() + if model is None: + model = { + 'name': km.default_kernel_name + } + else: + model.setdefault('name', km.default_kernel_name) + + kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name'])) + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id)) + self.set_header('Location', location) + self.set_status(201) + self.finish(json.dumps(model)) + + +class KernelHandler(APIHandler): + """Replace default KernelHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def get(self, kernel_id): + km = self.kernel_manager + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + if model is None: + raise web.HTTPError(404, u'Kernel does not exist: %s' % kernel_id) + self.finish(json.dumps(model)) + + @web.authenticated + @gen.coroutine + def delete(self, kernel_id): + km = self.kernel_manager + yield gen.maybe_future(km.shutdown_kernel(kernel_id)) + self.set_status(204) + self.finish() + + +class KernelActionHandler(APIHandler): + """Replace default KernelActionHandler to enable async lookup of kernels.""" + + @web.authenticated + @gen.coroutine + def post(self, kernel_id, action): + km = self.kernel_manager + + if action == 'interrupt': + km.interrupt_kernel(kernel_id) + self.set_status(204) + + if action == 'restart': + try: + yield gen.maybe_future(km.restart_kernel(kernel_id)) + except Exception as e: + self.log.error("Exception restarting kernel", exc_info=True) + self.set_status(500) + else: + # This is now an async operation + model = yield gen.maybe_future(km.kernel_model(kernel_id)) + self.write(json.dumps(model)) + self.finish() + +# ----------------------------------------------------------------------------- +# kernel spec handlers +# ----------------------------------------------------------------------------- + + +class MainKernelSpecHandler(APIHandler): + @web.authenticated + @gen.coroutine + def get(self): + ksm = self.kernel_spec_manager + try: + kernel_specs = yield gen.maybe_future(ksm.list_kernel_specs()) + # TODO: Remove resources until we support them + for name, spec in kernel_specs['kernelspecs'].items(): + spec['resources'] = {} + self.set_header("Content-Type", 'application/json') + self.write(json.dumps(kernel_specs)) + + # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect + # or the server is not running. + # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes + # of the tree view. + except ConnectionRefusedError: + gateway_url = ksm.parent.gateway_url + self.log.error("Connection refused from Gateway server url '{}'. " + "Check to be sure the Gateway instance is running.".format(gateway_url)) + except HTTPTimeoutError: + # This can occur if the host is valid (e.g., foo.com) but there's nothing there. + gateway_url = ksm.parent.gateway_url + self.log.error("Timeout error attempting to connect to Gateway server url '{}'. " + "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + except gaierror as e: + gateway_url = ksm.parent.gateway_url + self.log.error("The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " + "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + + self.finish() + + +class KernelSpecHandler(APIHandler): + @web.authenticated + @gen.coroutine + def get(self, kernel_name): + ksm = self.kernel_spec_manager + kernel_spec = yield ksm.get_kernel_spec(kernel_name) + if kernel_spec is None: + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) + # TODO: Remove resources until we support them + kernel_spec['resources'] = {} + self.set_header("Content-Type", 'application/json') + self.finish(json.dumps(kernel_spec)) + +# ----------------------------------------------------------------------------- +# URL to handler mappings +# ----------------------------------------------------------------------------- + + +from ..services.kernels.handlers import _kernel_id_regex, _kernel_action_regex +from ..services.kernelspecs.handlers import kernel_name_regex + +default_handlers = [ + (r"/api/kernels", MainKernelHandler), + (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler), + (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler), + (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler), + (r"/api/kernelspecs", MainKernelSpecHandler), + (r"/api/kernelspecs/%s" % kernel_name_regex, KernelSpecHandler), + # TODO: support kernel spec resources + # (r"/kernelspecs/%s/(?P.*)" % kernel_name_regex, KernelSpecResourceHandler), + +] diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py new file mode 100644 index 0000000000..31dcba9cdc --- /dev/null +++ b/notebook/gateway/managers.py @@ -0,0 +1,493 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import os +import json + +from tornado import gen +from tornado.escape import json_encode, json_decode, url_escape +from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError + +from ..services.kernels.kernelmanager import MappingKernelManager +from ..services.sessions.sessionmanager import SessionManager + +from jupyter_client.kernelspec import KernelSpecManager +from ..utils import url_path_join + +from traitlets import Instance, Unicode, default + +# Note: Although some of these are available via NotebookApp (command line), we will +# take the approach of using environment variables to enable separate sets of values +# for use with the local Notebook server (via command-line) and remote Gateway server +# (via environment variables). + +GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) +GATEWAY_HEADERS.update({ + 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) +}) +VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] + +GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') +GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') +GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') + +GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') +GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') + +GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) +GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) + + +def load_connection_args(**kwargs): + + if GATEWAY_CLIENT_CERT: + kwargs["client_key"] = kwargs.get("client_key", GATEWAY_CLIENT_KEY) + kwargs["client_cert"] = kwargs.get("client_cert", GATEWAY_CLIENT_CERT) + if GATEWAY_CLIENT_CA: + kwargs["ca_certs"] = kwargs.get("ca_certs", GATEWAY_CLIENT_CA) + kwargs['connect_timeout'] = kwargs.get('connect_timeout', GATEWAY_CONNECT_TIMEOUT) + kwargs['request_timeout'] = kwargs.get('request_timeout', GATEWAY_REQUEST_TIMEOUT) + kwargs['headers'] = kwargs.get('headers', GATEWAY_HEADERS) + kwargs['validate_cert'] = kwargs.get('validate_cert', VALIDATE_GATEWAY_CERT) + if GATEWAY_HTTP_USER: + kwargs['auth_username'] = kwargs.get('auth_username', GATEWAY_HTTP_USER) + if GATEWAY_HTTP_PASS: + kwargs['auth_password'] = kwargs.get('auth_password', GATEWAY_HTTP_PASS) + + return kwargs + + +@gen.coroutine +def fetch_gateway(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint.""" + client = AsyncHTTPClient() + + kwargs = load_connection_args(**kwargs) + + response = yield client.fetch(endpoint, **kwargs) + raise gen.Return(response) + + +class GatewayKernelManager(MappingKernelManager): + """Kernel manager that supports remote kernels hosted by Jupyter + kernel gateway.""" + + kernels_endpoint_env = 'GATEWAY_KERNELS_ENDPOINT' + kernels_endpoint = Unicode(config=True, + help="""The gateway API endpoint for accessing kernel resources (GATEWAY_KERNELS_ENDPOINT env var)""") + + @default('kernels_endpoint') + def kernels_endpoint_default(self): + return os.getenv(self.kernels_endpoint_env, '/api/kernels') + + # We'll maintain our own set of kernel ids + _kernels = {} + + def __init__(self, **kwargs): + super(GatewayKernelManager, self).__init__(**kwargs) + self.gateway_url = self.parent.gateway_url + + def __contains__(self, kernel_id): + return kernel_id in self._kernels + + def remove_kernel(self, kernel_id): + """Complete override since we want to be more tolerant of missing keys """ + try: + return self._kernels.pop(kernel_id) + except KeyError: + pass + + def _get_kernel_endpoint_url(self, kernel_id=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_id: kernel UUID (optional) + """ + if kernel_id: + return url_path_join(self.gateway_url, self.kernels_endpoint, url_escape(str(kernel_id))) + + return url_path_join(self.gateway_url, self.kernels_endpoint) + + @gen.coroutine + def start_kernel(self, kernel_id=None, path=None, **kwargs): + """Start a kernel for a session and return its kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid to associate the new kernel with. If this + is not None, this kernel will be persistent whenever it is + requested. + path : API path + The API path (unicode, '/' delimited) for the cwd. + Will be transformed to an OS path relative to root_dir. + """ + self.log.info('Request start kernel: kernel_id=%s, path="%s"', kernel_id, path) + + if kernel_id is None: + kernel_name = kwargs.get('kernel_name', 'python3') + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request new kernel at: %s" % kernel_url) + + kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') + or k in os.environ.get('GATEWAY_ENV_WHITELIST', '').split(",")} + json_body = json_encode({'name': kernel_name, 'env': kernel_env}) + + response = yield fetch_gateway(kernel_url, method='POST', body=json_body) + kernel = json_decode(response.body) + kernel_id = kernel['id'] + self.log.info("Kernel started: %s" % kernel_id) + else: + kernel = yield self.get_kernel(kernel_id) + kernel_id = kernel['id'] + self.log.info("Using existing kernel: %s" % kernel_id) + + self._kernels[kernel_id] = kernel + raise gen.Return(kernel_id) + + @gen.coroutine + def get_kernel(self, kernel_id=None, **kwargs): + """Get kernel for kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request kernel at: %s" % kernel_url) + try: + response = yield fetch_gateway(kernel_url, method='GET') + except HTTPError as error: + if error.code == 404: + self.log.warn("Kernel not found at: %s" % kernel_url) + self.remove_kernel(kernel_id) + kernel = None + else: + raise + else: + kernel = json_decode(response.body) + self._kernels[kernel_id] = kernel + self.log.info("Kernel retrieved: %s" % kernel) + raise gen.Return(kernel) + + @gen.coroutine + def kernel_model(self, kernel_id): + """Return a dictionary of kernel information described in the + JSON standard model. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id) + model = yield self.get_kernel(kernel_id) + raise gen.Return(model) + + @gen.coroutine + def list_kernels(self, **kwargs): + """Get a list of kernels.""" + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request list kernels: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='GET') + kernels = json_decode(response.body) + self._kernels = {x['id']:x for x in kernels} + raise gen.Return(kernels) + + @gen.coroutine + def shutdown_kernel(self, kernel_id): + """Shutdown a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to shutdown. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request shutdown kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='DELETE') + self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason) + self.remove_kernel(kernel_id) + + @gen.coroutine + def restart_kernel(self, kernel_id, now=False, **kwargs): + """Restart a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to restart. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart' + self.log.debug("Request restart kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Restart kernel response: %d %s", response.code, response.reason) + + @gen.coroutine + def interrupt_kernel(self, kernel_id, **kwargs): + """Interrupt a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to interrupt. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt' + self.log.debug("Request interrupt kernel at: %s", kernel_url) + response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason) + + def shutdown_all(self): + """Shutdown all kernels.""" + # Note: We have to make this sync because the NotebookApp does not wait for async. + kwargs = {'method': 'DELETE'} + kwargs = load_connection_args(**kwargs) + client = HTTPClient() + for kernel_id in self._kernels.keys(): + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request delete kernel at: %s", kernel_url) + try: + response = client.fetch(kernel_url, **kwargs) + except HTTPError: + pass + self.log.debug("Delete kernel response: %d %s", response.code, response.reason) + self.remove_kernel(kernel_id) + client.close() + + +class GatewayKernelSpecManager(KernelSpecManager): + + kernelspecs_endpoint_env = 'GATEWAY_KERNELSPECS_ENDPOINT' + kernelspecs_endpoint = Unicode(config=True, + help="""The kernel gateway API endpoint for accessing kernelspecs + (GATEWAY_KERNELSPECS_ENDPOINT env var)""") + + @default('kernelspecs_endpoint') + def kernelspecs_endpoint_default(self): + return os.getenv(self.kernelspecs_endpoint_env, '/api/kernelspecs') + + def __init__(self, **kwargs): + super(GatewayKernelSpecManager, self).__init__(**kwargs) + self.gateway_url = self.parent.gateway_url + + def _get_kernelspecs_endpoint_url(self, kernel_name=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_name: kernel name (optional) + """ + if kernel_name: + return url_path_join(self.gateway_url, self.kernelspecs_endpoint, url_escape(kernel_name)) + + return url_path_join(self.gateway_url, self.kernelspecs_endpoint) + + @gen.coroutine + def list_kernel_specs(self): + """Get a list of kernel specs.""" + kernel_spec_url = self._get_kernelspecs_endpoint_url() + self.log.debug("Request list kernel specs at: %s", kernel_spec_url) + response = yield fetch_gateway(kernel_spec_url, method='GET') + kernel_specs = json_decode(response.body) + raise gen.Return(kernel_specs) + + @gen.coroutine + def get_kernel_spec(self, kernel_name, **kwargs): + """Get kernel spec for kernel_name. + + Parameters + ---------- + kernel_name : str + The name of the kernel. + """ + kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name)) + self.log.debug("Request kernel spec at: %s" % kernel_spec_url) + try: + response = yield fetch_gateway(kernel_spec_url, method='GET') + except HTTPError as error: + if error.code == 404: + self.log.warn("Kernel spec not found at: %s" % kernel_spec_url) + kernel_spec = None + else: + raise + else: + kernel_spec = json_decode(response.body) + raise gen.Return(kernel_spec) + + +class GatewaySessionManager(SessionManager): + kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager') + + @gen.coroutine + def create_session(self, path=None, name=None, type=None, + kernel_name=None, kernel_id=None): + """Creates a session and returns its model. + + Overrides base class method to turn into an async operation. + """ + session_id = self.new_session_id() + + kernel = None + if kernel_id is not None: + # This is now an async operation + kernel = yield self.kernel_manager.get_kernel(kernel_id) + + if kernel is not None: + pass + else: + kernel_id = yield self.start_kernel_for_session( + session_id, path, name, type, kernel_name, + ) + + result = yield self.save_session( + session_id, path=path, name=name, type=type, kernel_id=kernel_id, + ) + raise gen.Return(result) + + @gen.coroutine + def save_session(self, session_id, path=None, name=None, type=None, + kernel_id=None): + """Saves the items for the session with the given session_id + + Given a session_id (and any other of the arguments), this method + creates a row in the sqlite session database that holds the information + for a session. + + Parameters + ---------- + session_id : str + uuid for the session; this method must be given a session_id + path : str + the path for the given notebook + kernel_id : str + a uuid for the kernel associated with this session + + Returns + ------- + model : dict + a dictionary of the session model + """ + # This is now an async operation + session = yield super(GatewaySessionManager, self).save_session( + session_id, path=path, name=name, type=type, kernel_id=kernel_id + ) + raise gen.Return(session) + + @gen.coroutine + def get_session(self, **kwargs): + """Returns the model for a particular session. + + Takes a keyword argument and searches for the value in the session + database, then returns the rest of the session's info. + + Overrides base class method to turn into an async operation. + + Parameters + ---------- + **kwargs : keyword argument + must be given one of the keywords and values from the session database + (i.e. session_id, path, kernel_id) + + Returns + ------- + model : dict + returns a dictionary that includes all the information from the + session described by the kwarg. + """ + # This is now an async operation + session = yield super(GatewaySessionManager, self).get_session(**kwargs) + raise gen.Return(session) + + @gen.coroutine + def update_session(self, session_id, **kwargs): + """Updates the values in the session database. + + Changes the values of the session with the given session_id + with the values from the keyword arguments. + + Overrides base class method to turn into an async operation. + + Parameters + ---------- + session_id : str + a uuid that identifies a session in the sqlite3 database + **kwargs : str + the key must correspond to a column title in session database, + and the value replaces the current value in the session + with session_id. + """ + # This is now an async operation + session = yield self.get_session(session_id=session_id) + + if not kwargs: + # no changes + return + + sets = [] + for column in kwargs.keys(): + if column not in self._columns: + raise TypeError("No such column: %r" % column) + sets.append("%s=?" % column) + query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets)) + self.cursor.execute(query, list(kwargs.values()) + [session_id]) + + @gen.coroutine + def row_to_model(self, row): + """Takes sqlite database session row and turns it into a dictionary. + + Overrides base class method to turn into an async operation. + """ + # Retrieve kernel for session, which is now an async operation + kernel = yield self.kernel_manager.get_kernel(row['kernel_id']) + if kernel is None: + # The kernel was killed or died without deleting the session. + # We can't use delete_session here because that tries to find + # and shut down the kernel. + self.cursor.execute("DELETE FROM session WHERE session_id=?", + (row['session_id'],)) + raise KeyError + + model = { + 'id': row['session_id'], + 'path': row['path'], + 'name': row['name'], + 'type': row['type'], + 'kernel': kernel + } + if row['type'] == 'notebook': # Provide the deprecated API. + model['notebook'] = {'path': row['path'], 'name': row['name']} + + raise gen.Return(model) + + @gen.coroutine + def list_sessions(self): + """Returns a list of dictionaries containing all the information from + the session database. + + Overrides base class method to turn into an async operation. + """ + c = self.cursor.execute("SELECT * FROM session") + result = [] + # We need to use fetchall() here, because row_to_model can delete rows, + # which messes up the cursor if we're iterating over rows. + for row in c.fetchall(): + try: + # This is now an async operation + model = yield self.row_to_model(row) + result.append(model) + except KeyError: + pass + raise gen.Return(result) + + @gen.coroutine + def delete_session(self, session_id): + """Deletes the row in the session database with given session_id. + + Overrides base class method to turn into an async operation. + """ + # This is now an async operation + session = yield self.get_session(session_id=session_id) + yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id'])) + self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index 05e44cf29f..abdefd274f 100755 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -84,6 +84,9 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager +from .gateway.managers import GatewayKernelManager +from .gateway.managers import GatewayKernelSpecManager +from .gateway.managers import GatewaySessionManager from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -96,7 +99,7 @@ ) from jupyter_core.paths import jupyter_config_path from jupyter_client import KernelManager -from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME +from jupyter_client.kernelspec import KernelSpecManager from jupyter_client.session import Session from nbformat.sign import NotebookNotary from traitlets import ( @@ -144,19 +147,20 @@ def load_handlers(name): # The Tornado web application #----------------------------------------------------------------------------- + class NotebookWebApplication(web.Application): def __init__(self, jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, - base_url, default_url, settings_overrides, jinja_env_options): - + base_url, default_url, settings_overrides, jinja_env_options, + gateway_url): settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, - default_url, settings_overrides, jinja_env_options) + default_url, settings_overrides, jinja_env_options, gateway_url) handlers = self.init_handlers(settings) super(NotebookWebApplication, self).__init__(handlers, **settings) @@ -165,7 +169,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, default_url, settings_overrides, - jinja_env_options=None): + jinja_env_options=None, gateway_url=None): _template_path = settings_overrides.get( "template_path", @@ -279,6 +283,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, server_root_dir=root_dir, jinja2_env=env, terminals_available=False, # Set later if terminals are available + gateway_url=gateway_url, ) # allow custom overrides for the tornado web app. @@ -305,13 +310,19 @@ def init_handlers(self, settings): handlers.extend(load_handlers('notebook.edit.handlers')) handlers.extend(load_handlers('notebook.services.api.handlers')) handlers.extend(load_handlers('notebook.services.config.handlers')) - handlers.extend(load_handlers('notebook.services.kernels.handlers')) handlers.extend(load_handlers('notebook.services.contents.handlers')) handlers.extend(load_handlers('notebook.services.sessions.handlers')) handlers.extend(load_handlers('notebook.services.nbconvert.handlers')) - handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) handlers.extend(load_handlers('notebook.services.security.handlers')) handlers.extend(load_handlers('notebook.services.shutdown')) + + # If gateway server is configured, replace appropriate handlers to perform redirection + if settings['gateway_url']: + handlers.extend(load_handlers('notebook.gateway.handlers')) + else: + handlers.extend(load_handlers('notebook.services.kernels.handlers')) + handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) + handlers.extend(settings['contents_manager'].get_extra_handlers()) handlers.append( @@ -547,6 +558,7 @@ def start(self): 'notebook-dir': 'NotebookApp.notebook_dir', 'browser': 'NotebookApp.browser', 'pylab': 'NotebookApp.pylab', + 'gateway-url': 'NotebookApp.gateway_url', }) #----------------------------------------------------------------------------- @@ -565,9 +577,9 @@ class NotebookApp(JupyterApp): flags = flags classes = [ - KernelManager, Session, MappingKernelManager, + KernelManager, Session, MappingKernelManager, KernelSpecManager, ContentsManager, FileContentsManager, NotebookNotary, - KernelSpecManager, + GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, ] flags = Dict(flags) aliases = Dict(aliases) @@ -1295,6 +1307,20 @@ def _update_server_extensions(self, change): is not available. """)) + gateway_url_env = 'GATEWAY_URL' + + @default('gateway_url') + def gateway_url_default(self): + return os.getenv(self.gateway_url_env) + + gateway_url = Unicode(default_value=None, allow_none=True, config=True, + help="""The url of the Kernel or Enterprise Gateway server where + kernel specifications are defined and kernel management takes place. + If defined, this Notebook server acts as a proxy for all kernel + management and kernel specification retrieval. (GATEWAY_URL env var) + """ + ) + def parse_command_line(self, argv=None): super(NotebookApp, self).parse_command_line(argv) @@ -1316,6 +1342,13 @@ def parse_command_line(self, argv=None): self.update_config(c) def init_configurables(self): + + # If gateway server is configured, replace appropriate managers to perform redirection + if self.gateway_url: + self.kernel_manager_class = 'notebook.gateway.managers.GatewayKernelManager' + self.session_manager_class = 'notebook.gateway.managers.GatewaySessionManager' + self.kernel_spec_manager_class = 'notebook.gateway.managers.GatewayKernelSpecManager' + self.kernel_spec_manager = self.kernel_spec_manager_class( parent=self, ) @@ -1381,7 +1414,7 @@ def init_webapp(self): self.session_manager, self.kernel_spec_manager, self.config_manager, self.extra_services, self.log, self.base_url, self.default_url, self.tornado_settings, - self.jinja_environment_options, + self.jinja_environment_options, self.gateway_url, ) ssl_options = self.ssl_options if self.certfile: @@ -1661,6 +1694,8 @@ def notebook_info(self, kernel_count=True): info += "\n" # Format the info so that the URL fits on a single line in 80 char display info += _("The Jupyter Notebook is running at:\n%s") % self.display_url + if self.gateway_url: + info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_url return info def server_info(self): diff --git a/notebook/tests/test_gateway.py b/notebook/tests/test_gateway.py new file mode 100644 index 0000000000..385c1f1cd5 --- /dev/null +++ b/notebook/tests/test_gateway.py @@ -0,0 +1,327 @@ +"""Test Gateway""" +import os +import json +import uuid +from datetime import datetime +from tornado import gen +from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError +from traitlets.config import Config +from .launchnotebook import NotebookTestBase + +try: + from unittest.mock import patch, Mock +except ImportError: + from mock import patch, Mock # py2 + +try: + from io import StringIO +except ImportError: + import StringIO + +import nose.tools as nt + + +def generate_kernelspec(name): + argv_stanza = ['python', '-m', 'ipykernel_launcher', '-f', '{connection_file}'] + spec_stanza = {'spec': {'argv': argv_stanza, 'env': {}, 'display_name': name, 'language': 'python', 'interrupt_mode': 'signal', 'metadata': {}}} + kernelspec_stanza = {name: {'name': name, 'spec': spec_stanza, 'resources': {}}} + return kernelspec_stanza + + +# We'll mock up two kernelspecs - kspec_foo and kspec_bar +kernelspecs = {'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}} + + +# maintain a dictionary of expected running kernels. Key = kernel_id, Value = model. +running_kernels = dict() + + +def generate_model(name): + """Generate a mocked kernel model. Caller is responsible for adding model to running_kernels dictionary.""" + dt = datetime.utcnow().isoformat() + 'Z' + kernel_id = str(uuid.uuid4()) + model = {'id': kernel_id, 'name': name, 'last_activity': str(dt), 'execution_state': 'idle', 'connections': 1} + return model + + +@gen.coroutine +def mock_fetch_gateway(url, **kwargs): + method = 'GET' + if kwargs['method']: + method = kwargs['method'] + + request = HTTPRequest(url=url, **kwargs) + + endpoint = str(url) + + # Fetch all kernelspecs + if endpoint.endswith('/api/kernelspecs') and method == 'GET': + response_buf = StringIO(json.dumps(kernelspecs)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Fetch named kernelspec + if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET': + requested_kernelspec = endpoint.rpartition('/')[2] + kspecs = kernelspecs.get('kernelspecs') + if requested_kernelspec in kspecs: + response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec) + + # Create kernel + if endpoint.endswith('/api/kernels') and method == 'POST': + json_body = json.loads(kwargs['body']) + name = json_body.get('name') + env = json_body.get('env') + kspec_name = env.get('KERNEL_KSPEC_NAME') + nt.assert_equal(name, kspec_name) # Ensure that KERNEL_ env values get propagated + model = generate_model(name) + running_kernels[model.get('id')] = model # Register model as a running kernel + response_buf = StringIO(json.dumps(model)) + response = yield gen.maybe_future(HTTPResponse(request, 201, buffer=response_buf)) + raise gen.Return(response) + + # Fetch list of running kernels + if endpoint.endswith('/api/kernels') and method == 'GET': + kernels = [] + for kernel_id in running_kernels.keys(): + model = running_kernels.get(kernel_id) + kernels.append(model) + response_buf = StringIO(json.dumps(kernels)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Interrupt or restart existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST': + requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/') + + if action == 'interrupt': + if requested_kernel_id in running_kernels: + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + elif action == 'restart': + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 204, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + else: + raise HTTPError(404, message='Bad action detected: %s' % action) + + # Shutdown existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE': + requested_kernel_id = endpoint.rpartition('/')[2] + running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + + # Fetch existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET': + requested_kernel_id = endpoint.rpartition('/')[2] + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + + +mocked_gateway = patch('notebook.gateway.managers.fetch_gateway', mock_fetch_gateway) + + +class TestGateway(NotebookTestBase): + + @classmethod + def setup_class(cls): + cls.config = Config() + cls.config.NotebookApp.gateway_url = 'http://mock-gateway-server:8889' + super(TestGateway, cls).setup_class() + + def test_gateway_class_mappings(self): + # Ensure appropriate class mappings are in place. + nt.assert_equal(self.notebook.kernel_manager_class.__name__, 'GatewayKernelManager') + nt.assert_equal(self.notebook.session_manager_class.__name__, 'GatewaySessionManager') + nt.assert_equal(self.notebook.kernel_spec_manager_class.__name__, 'GatewayKernelSpecManager') + + def test_gateway_get_kernelspecs(self): + # Validate that kernelspecs come from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs') + self.assertEqual(response.status_code, 200) + content = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kspecs = content.get('kernelspecs') + self.assertEqual(len(kspecs), 2) + self.assertEqual(kspecs.get('kspec_bar').get('kspec_bar').get('name'), 'kspec_bar') + + def test_gateway_get_named_kernelspec(self): + # Validate that a specific kernelspec can be retrieved from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs/kspec_foo') + self.assertEqual(response.status_code, 200) + content = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kspec_foo = content.get('kspec_foo') + self.assertEqual(kspec_foo.get('name'), 'kspec_foo') + + response = self.request('GET', '/api/kernelspecs/no_such_spec') + self.assertEqual(response.status_code, 404) + + def test_gateway_session_lifecycle(self): + # Validate session lifecycle functions; create and delete. + + # create + session_id, kernel_id = self.create_session('kspec_foo') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_session(session_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def test_gateway_kernel_lifecycle(self): + # Validate kernel lifecycle functions; create, interrupt, restart and delete. + + # create + kernel_id = self.create_kernel('kspec_bar') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_kernel(kernel_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def create_session(self, kernel_name): + """Creates a session for a kernel. The session is created against the notebook server + which then uses the gateway for kernel management. + """ + with mocked_gateway: + nb_path = os.path.join(self.notebook_dir, 'testgw.ipynb') + kwargs = dict() + kwargs['json'] = {'path': nb_path, 'type': 'notebook', 'kernel': {'name': kernel_name}} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + # Create the kernel... (also tests get_kernel) + response = self.request('POST', '/api/sessions', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(model.get('path'), nb_path) + kernel_id = model.get('kernel').get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('kernel').get('name'), running_kernel.get('name')) + session_id = model.get('id') + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return session_id, kernel_id + + def delete_session(self, session_id): + """Deletes a session corresponding to the given session id. + """ + with mocked_gateway: + # Delete the session (and kernel) + response = self.request('DELETE', '/api/sessions/' + session_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def is_kernel_running(self, kernel_id): + """Issues request to get the set of running kernels + """ + with mocked_gateway: + # Get list of running kernels + response = self.request('GET', '/api/kernels') + self.assertEqual(response.status_code, 200) + kernels = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(len(kernels), len(running_kernels)) + for model in kernels: + if model.get('id') == kernel_id: + return True + return False + + def create_kernel(self, kernel_name): + """Issues request to retart the given kernel + """ + with mocked_gateway: + kwargs = dict() + kwargs['json'] = {'name': kernel_name} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + response = self.request('POST', '/api/kernels', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), kernel_name) + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return kernel_id + + def interrupt_kernel(self, kernel_id): + """Issues request to interrupt the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/interrupt') + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def restart_kernel(self, kernel_id): + """Issues request to retart the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/restart') + self.assertEqual(response.status_code, 200) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + restarted_kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(restarted_kernel_id) + self.assertEqual(restarted_kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), running_kernel.get('name')) + + def delete_kernel(self, kernel_id): + """Deletes kernel corresponding to the given kernel id. + """ + with mocked_gateway: + # Delete the session (and kernel) + response = self.request('DELETE', '/api/kernels/' + kernel_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + From f74ef2f69121acfe07f2200a88d4f2119b871551 Mon Sep 17 00:00:00 2001 From: Kevin Bates Date: Tue, 4 Dec 2018 14:19:56 -0800 Subject: [PATCH 2/3] Move environment variables to SingletonConfigurable Created a singleton class `Gateway` to store all configuration options for a Gateway. This class also holds some help methods to make it easier to use the options and determine if the gateway option is enabled. Updated the NotebookTestBase class to allow for subclasses to infuence the patched environment as well as command line options via argv. Added a test to ensure various gateway configuration items can be set via the environment or command-line. --- docs/source/public_server.rst | 4 +- notebook/gateway/handlers.py | 63 ++----- notebook/gateway/managers.py | 284 ++++++++++++++++++++++++------- notebook/notebookapp.py | 42 ++--- notebook/tests/launchnotebook.py | 28 +-- notebook/tests/test_gateway.py | 32 +++- 6 files changed, 295 insertions(+), 158 deletions(-) diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst index 079916441a..5bbce15429 100644 --- a/docs/source/public_server.rst +++ b/docs/source/public_server.rst @@ -359,13 +359,13 @@ the environment: .. code-block:: bash - GATEWAY_URL=http://my-gateway-server:8888 + JUPYTER_GATEWAY_URL=http://my-gateway-server:8888 or in :file:`jupyter_notebook_config.py`: .. code-block:: python - c.NotebookApp.gateway_url = http://my-gateway-server:8888 + c.Gateway.url = http://my-gateway-server:8888 When provided, all kernel specifications will be retrieved from the specified Gateway server and all kernels will be managed by that server. This option enables the ability to target kernel processes diff --git a/notebook/gateway/handlers.py b/notebook/gateway/handlers.py index 30098787e4..1e28bfa779 100644 --- a/notebook/gateway/handlers.py +++ b/notebook/gateway/handlers.py @@ -21,27 +21,7 @@ from jupyter_client.session import Session from traitlets.config.configurable import LoggingConfigurable -# Note: Although some of these are available via NotebookApp (command line), we will -# take the approach of using environment variables to enable separate sets of values -# for use with the local Notebook server (via command-line) and remote Gateway server -# (via environment variables). - -GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) -GATEWAY_HEADERS.update({ - 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) -}) -VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] - -GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') -GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') -GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') - -GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') -GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') - -# Get env variables to handle timeout of request and connection -GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) -GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) +from .managers import Gateway class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler): @@ -77,7 +57,7 @@ def authenticate(self): def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) self.session = Session(config=self.config) - self.gateway = GatewayWebSocketClient(gateway_url=self.kernel_manager.parent.gateway_url) + self.gateway = GatewayWebSocketClient(gateway_url=Gateway.instance().url) @gen.coroutine def get(self, kernel_id, *args, **kwargs): @@ -135,7 +115,6 @@ class GatewayWebSocketClient(LoggingConfigurable): def __init__(self, **kwargs): super(GatewayWebSocketClient, self).__init__(**kwargs) - self.gateway_url = kwargs['gateway_url'] self.kernel_id = None self.ws = None self.ws_future = Future() @@ -145,29 +124,14 @@ def __init__(self, **kwargs): def _connect(self, kernel_id): self.kernel_id = kernel_id ws_url = url_path_join( - os.getenv('GATEWAY_WS_URL', self.gateway_url.replace('http', 'ws')), - '/api/kernels', - url_escape(kernel_id), - 'channels' + Gateway.instance().ws_url, + Gateway.instance().kernels_endpoint, url_escape(kernel_id), 'channels' ) self.log.info('Connecting to {}'.format(ws_url)) - parameters = { - "headers": GATEWAY_HEADERS, - "validate_cert": VALIDATE_GATEWAY_CERT, - "connect_timeout": GATEWAY_CONNECT_TIMEOUT, - "request_timeout": GATEWAY_REQUEST_TIMEOUT - } - if GATEWAY_HTTP_USER: - parameters["auth_username"] = GATEWAY_HTTP_USER - if GATEWAY_HTTP_PASS: - parameters["auth_password"] = GATEWAY_HTTP_PASS - if GATEWAY_CLIENT_KEY: - parameters["client_key"] = GATEWAY_CLIENT_KEY - parameters["client_cert"] = GATEWAY_CLIENT_CERT - if GATEWAY_CLIENT_CA: - parameters["ca_certs"] = GATEWAY_CLIENT_CA - - request = HTTPRequest(ws_url, **parameters) + kwargs = {} + kwargs = Gateway.instance().load_connection_args(**kwargs) + + request = HTTPRequest(ws_url, **kwargs) self.ws_future = websocket_connect(request) self.ws_future.add_done_callback(self._connection_done) @@ -178,7 +142,7 @@ def _connection_done(self, fut): else: self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. " "Kernel with ID '{}' may not be terminated on Gateway: {}". - format(self.kernel_id, self.gateway_url)) + format(self.kernel_id, Gateway.instance().url)) def _disconnect(self): if self.ws is not None: @@ -343,18 +307,15 @@ def get(self): # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes # of the tree view. except ConnectionRefusedError: - gateway_url = ksm.parent.gateway_url self.log.error("Connection refused from Gateway server url '{}'. " - "Check to be sure the Gateway instance is running.".format(gateway_url)) + "Check to be sure the Gateway instance is running.".format(Gateway.instance().url)) except HTTPTimeoutError: # This can occur if the host is valid (e.g., foo.com) but there's nothing there. - gateway_url = ksm.parent.gateway_url self.log.error("Timeout error attempting to connect to Gateway server url '{}'. " - "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + "Ensure gateway url is valid and the Gateway instance is running.".format(Gateway.instance().url)) except gaierror as e: - gateway_url = ksm.parent.gateway_url self.log.error("The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " - "Ensure gateway_url is valid and the Gateway instance is running.".format(gateway_url)) + "Ensure gateway url is valid and the Gateway instance is running.".format(Gateway.instance().url)) self.finish() diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py index 31dcba9cdc..42da1696fc 100644 --- a/notebook/gateway/managers.py +++ b/notebook/gateway/managers.py @@ -14,78 +14,243 @@ from jupyter_client.kernelspec import KernelSpecManager from ..utils import url_path_join -from traitlets import Instance, Unicode, default +from traitlets import Instance, Unicode, Float, Bool, default, validate, TraitError +from traitlets.config import SingletonConfigurable -# Note: Although some of these are available via NotebookApp (command line), we will -# take the approach of using environment variables to enable separate sets of values -# for use with the local Notebook server (via command-line) and remote Gateway server -# (via environment variables). -GATEWAY_HEADERS = json.loads(os.getenv('GATEWAY_HEADERS', '{}')) -GATEWAY_HEADERS.update({ - 'Authorization': 'token {}'.format(os.getenv('GATEWAY_AUTH_TOKEN', '')) -}) -VALIDATE_GATEWAY_CERT = os.getenv('VALIDATE_GATEWAY_CERT') not in ['no', 'false'] +@gen.coroutine +def fetch_gateway(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint.""" + client = AsyncHTTPClient() -GATEWAY_CLIENT_KEY = os.getenv('GATEWAY_CLIENT_KEY') -GATEWAY_CLIENT_CERT = os.getenv('GATEWAY_CLIENT_CERT') -GATEWAY_CLIENT_CA = os.getenv('GATEWAY_CLIENT_CA') + kwargs = Gateway.instance().load_connection_args(**kwargs) -GATEWAY_HTTP_USER = os.getenv('GATEWAY_HTTP_USER') -GATEWAY_HTTP_PASS = os.getenv('GATEWAY_HTTP_PASS') + response = yield client.fetch(endpoint, **kwargs) + raise gen.Return(response) -GATEWAY_CONNECT_TIMEOUT = float(os.getenv('GATEWAY_CONNECT_TIMEOUT', 20.0)) -GATEWAY_REQUEST_TIMEOUT = float(os.getenv('GATEWAY_REQUEST_TIMEOUT', 20.0)) +class Gateway(SingletonConfigurable): + """This class manages the configuration. It's its own class so that we can avoid having command + line options of the likes `--GatewayKernelManager.connect_timeout` and use the shorter and more + applicable `--Gateway.connect_timeout`, etc. It also contains some helper methods to build + request arguments out of the various config options. -def load_connection_args(**kwargs): + """ - if GATEWAY_CLIENT_CERT: - kwargs["client_key"] = kwargs.get("client_key", GATEWAY_CLIENT_KEY) - kwargs["client_cert"] = kwargs.get("client_cert", GATEWAY_CLIENT_CERT) - if GATEWAY_CLIENT_CA: - kwargs["ca_certs"] = kwargs.get("ca_certs", GATEWAY_CLIENT_CA) - kwargs['connect_timeout'] = kwargs.get('connect_timeout', GATEWAY_CONNECT_TIMEOUT) - kwargs['request_timeout'] = kwargs.get('request_timeout', GATEWAY_REQUEST_TIMEOUT) - kwargs['headers'] = kwargs.get('headers', GATEWAY_HEADERS) - kwargs['validate_cert'] = kwargs.get('validate_cert', VALIDATE_GATEWAY_CERT) - if GATEWAY_HTTP_USER: - kwargs['auth_username'] = kwargs.get('auth_username', GATEWAY_HTTP_USER) - if GATEWAY_HTTP_PASS: - kwargs['auth_password'] = kwargs.get('auth_password', GATEWAY_HTTP_PASS) + url = Unicode(default_value=None, allow_none=True, config=True, + help="""The url of the Kernel or Enterprise Gateway server where + kernel specifications are defined and kernel management takes place. + If defined, this Notebook server acts as a proxy for all kernel + management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) + """ + ) + + url_env = 'JUPYTER_GATEWAY_URL' + @default('url') + def _url_default(self): + return os.environ.get(self.url_env) + + @validate('url') + def _url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'http' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('http'): + raise TraitError("Gateway url must start with 'http': '%r'" % value) + return value + + ws_url = Unicode(default_value=None, allow_none=True, config=True, + help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value + will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) + """ + ) + + ws_url_env = 'JUPYTER_GATEWAY_WS_URL' + @default('ws_url') + def _ws_url_default(self): + default_value = os.environ.get(self.ws_url_env) + if default_value is None: + if self.gateway_enabled: + default_value = self.url.lower().replace('http', 'ws') + return default_value + + @validate('ws_url') + def _ws_url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'ws' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('ws'): + raise TraitError("Gateway ws_url must start with 'ws': '%r'" % value) + return value + + kernels_endpoint_default_value = '/api/kernels' + kernels_endpoint_env = 'JUPYTER_GATEWAY_KERNELS_ENDPOINT' + kernels_endpoint = Unicode(default_value=kernels_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""") - return kwargs + @default('kernels_endpoint') + def _kernels_endpoint_default(self): + return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value) + kernelspecs_endpoint_default_value = '/api/kernelspecs' + kernelspecs_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT' + kernelspecs_endpoint = Unicode(default_value=kernelspecs_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""") -@gen.coroutine -def fetch_gateway(endpoint, **kwargs): - """Make an async request to kernel gateway endpoint.""" - client = AsyncHTTPClient() + @default('kernelspecs_endpoint') + def _kernelspecs_endpoint_default(self): + return os.environ.get(self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value) + + connect_timeout_default_value = 20.0 + connect_timeout_env = 'JUPYTER_GATEWAY_CONNECT_TIMEOUT' + connect_timeout = Float(default_value=connect_timeout_default_value, config=True, + help="""The time allowed for HTTP connection establishment with the Gateway server. + (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""") + + @default('connect_timeout') + def connect_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_CONNECT_TIMEOUT', self.connect_timeout_default_value)) + + request_timeout_default_value = 20.0 + request_timeout_env = 'JUPYTER_GATEWAY_REQUEST_TIMEOUT' + request_timeout = Float(default_value=request_timeout_default_value, config=True, + help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""") + + @default('request_timeout') + def request_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_REQUEST_TIMEOUT', self.request_timeout_default_value)) + + client_key = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var) + """ + ) + client_key_env = 'JUPYTER_GATEWAY_CLIENT_KEY' - kwargs = load_connection_args(**kwargs) + @default('client_key') + def _client_key_default(self): + return os.environ.get(self.client_key_env) - response = yield client.fetch(endpoint, **kwargs) - raise gen.Return(response) + client_cert = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var) + """ + ) + client_cert_env = 'JUPYTER_GATEWAY_CLIENT_CERT' + @default('client_cert') + def _client_cert_default(self): + return os.environ.get(self.client_cert_env) -class GatewayKernelManager(MappingKernelManager): - """Kernel manager that supports remote kernels hosted by Jupyter - kernel gateway.""" + ca_certs = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var) + """ + ) + ca_certs_env = 'JUPYTER_GATEWAY_CA_CERTS' - kernels_endpoint_env = 'GATEWAY_KERNELS_ENDPOINT' - kernels_endpoint = Unicode(config=True, - help="""The gateway API endpoint for accessing kernel resources (GATEWAY_KERNELS_ENDPOINT env var)""") + @default('ca_certs') + def _ca_certs_default(self): + return os.environ.get(self.ca_certs_env) - @default('kernels_endpoint') - def kernels_endpoint_default(self): - return os.getenv(self.kernels_endpoint_env, '/api/kernels') + http_user = Unicode(default_value=None, allow_none=True, config=True, + help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var) + """ + ) + http_user_env = 'JUPYTER_GATEWAY_HTTP_USER' + + @default('http_user') + def _http_user_default(self): + return os.environ.get(self.http_user_env) + + http_pwd = Unicode(default_value=None, allow_none=True, config=True, + help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var) + """ + ) + http_pwd_env = 'JUPYTER_GATEWAY_HTTP_PWD' + + @default('http_pwd') + def _http_pwd_default(self): + return os.environ.get(self.http_pwd_env) + + headers_default_value = '{}' + headers_env = 'JUPYTER_GATEWAY_HEADERS' + headers = Unicode(default_value=headers_default_value, allow_none=True,config=True, + help="""Additional HTTP headers to pass on the request. This value will be converted to a dict. + (JUPYTER_GATEWAY_HEADERS env var) + """ + ) + + @default('headers') + def _headers_default(self): + return os.environ.get(self.headers_env, self.headers_default_value) + + auth_token = Unicode(default_value=None, allow_none=True, config=True, + help="""The authorization token used in the HTTP headers. (JUPYTER_GATEWAY_AUTH_TOKEN env var) + """ + ) + auth_token_env = 'JUPYTER_GATEWAY_AUTH_TOKEN' + + @default('auth_token') + def _auth_token_default(self): + return os.environ.get(self.auth_token_env) + + validate_cert_default_value = True + validate_cert_env = 'JUPYTER_GATEWAY_VALIDATE_CERT' + validate_cert = Bool(default_value=validate_cert_default_value, config=True, + help="""For HTTPS requests, determines if server's certificate should be validated or not. + (JUPYTER_GATEWAY_VALIDATE_CERT env var)""" + ) + + @default('validate_cert') + def validate_cert_default(self): + return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false']) + + def __init__(self, **kwargs): + super(Gateway, self).__init__(**kwargs) + self._static_args = {} # initialized on first use + + @property + def gateway_enabled(self): + return bool(self.url is not None and len(self.url) > 0) + + def init_static_args(self): + """Initialize arguments used on every request. Since these are static values, we'll + perform this operation once. + + """ + self._static_args['headers'] = json.loads(self.headers) + self._static_args['headers'].update({'Authorization': 'token {}'.format(self.auth_token)}) + self._static_args['connect_timeout'] = self.connect_timeout + self._static_args['request_timeout'] = self.request_timeout + self._static_args['validate_cert'] = self.validate_cert + if self.client_cert: + self._static_args['client_cert'] = self.client_cert + self._static_args['client_key'] = self.client_key + if self.ca_certs: + self._static_args['ca_certs'] = self.ca_certs + if self.http_user: + self._static_args['auth_username'] = self.http_user + if self.http_pwd: + self._static_args['auth_password'] = self.http_pwd + + def load_connection_args(self, **kwargs): + """Merges the static args relative to the connection, with the given keyword arguments. If statics + have yet to be initialized, we'll do that here. + + """ + if len(self._static_args) == 0: + self.init_static_args() + + kwargs.update(self._static_args) + return kwargs + +class GatewayKernelManager(MappingKernelManager): + """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" # We'll maintain our own set of kernel ids _kernels = {} def __init__(self, **kwargs): super(GatewayKernelManager, self).__init__(**kwargs) - self.gateway_url = self.parent.gateway_url + self.base_endpoint = url_path_join(Gateway.instance().url, Gateway.instance().kernels_endpoint) def __contains__(self, kernel_id): return kernel_id in self._kernels @@ -105,9 +270,9 @@ def _get_kernel_endpoint_url(self, kernel_id=None): kernel_id: kernel UUID (optional) """ if kernel_id: - return url_path_join(self.gateway_url, self.kernels_endpoint, url_escape(str(kernel_id))) + return url_path_join(self.base_endpoint, url_escape(str(kernel_id))) - return url_path_join(self.gateway_url, self.kernels_endpoint) + return self.base_endpoint @gen.coroutine def start_kernel(self, kernel_id=None, path=None, **kwargs): @@ -243,7 +408,7 @@ def shutdown_all(self): """Shutdown all kernels.""" # Note: We have to make this sync because the NotebookApp does not wait for async. kwargs = {'method': 'DELETE'} - kwargs = load_connection_args(**kwargs) + kwargs = Gateway.instance().load_connection_args(**kwargs) client = HTTPClient() for kernel_id in self._kernels.keys(): kernel_url = self._get_kernel_endpoint_url(kernel_id) @@ -259,18 +424,9 @@ def shutdown_all(self): class GatewayKernelSpecManager(KernelSpecManager): - kernelspecs_endpoint_env = 'GATEWAY_KERNELSPECS_ENDPOINT' - kernelspecs_endpoint = Unicode(config=True, - help="""The kernel gateway API endpoint for accessing kernelspecs - (GATEWAY_KERNELSPECS_ENDPOINT env var)""") - - @default('kernelspecs_endpoint') - def kernelspecs_endpoint_default(self): - return os.getenv(self.kernelspecs_endpoint_env, '/api/kernelspecs') - def __init__(self, **kwargs): super(GatewayKernelSpecManager, self).__init__(**kwargs) - self.gateway_url = self.parent.gateway_url + self.base_endpoint = url_path_join(Gateway.instance().url, Gateway.instance().kernelspecs_endpoint) def _get_kernelspecs_endpoint_url(self, kernel_name=None): """Builds a url for the kernels endpoint @@ -280,9 +436,9 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None): kernel_name: kernel name (optional) """ if kernel_name: - return url_path_join(self.gateway_url, self.kernelspecs_endpoint, url_escape(kernel_name)) + return url_path_join(self.base_endpoint, url_escape(kernel_name)) - return url_path_join(self.gateway_url, self.kernelspecs_endpoint) + return self.base_endpoint @gen.coroutine def list_kernel_specs(self): diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index abdefd274f..559e775794 100755 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -84,9 +84,7 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager -from .gateway.managers import GatewayKernelManager -from .gateway.managers import GatewayKernelSpecManager -from .gateway.managers import GatewaySessionManager +from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, Gateway from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -153,14 +151,13 @@ class NotebookWebApplication(web.Application): def __init__(self, jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, - base_url, default_url, settings_overrides, jinja_env_options, - gateway_url): + base_url, default_url, settings_overrides, jinja_env_options): settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, - default_url, settings_overrides, jinja_env_options, gateway_url) + default_url, settings_overrides, jinja_env_options) handlers = self.init_handlers(settings) super(NotebookWebApplication, self).__init__(handlers, **settings) @@ -169,7 +166,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, extra_services, log, base_url, default_url, settings_overrides, - jinja_env_options=None, gateway_url=None): + jinja_env_options=None): _template_path = settings_overrides.get( "template_path", @@ -283,7 +280,6 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, server_root_dir=root_dir, jinja2_env=env, terminals_available=False, # Set later if terminals are available - gateway_url=gateway_url, ) # allow custom overrides for the tornado web app. @@ -317,7 +313,7 @@ def init_handlers(self, settings): handlers.extend(load_handlers('notebook.services.shutdown')) # If gateway server is configured, replace appropriate handlers to perform redirection - if settings['gateway_url']: + if Gateway.instance().gateway_enabled: handlers.extend(load_handlers('notebook.gateway.handlers')) else: handlers.extend(load_handlers('notebook.services.kernels.handlers')) @@ -558,7 +554,7 @@ def start(self): 'notebook-dir': 'NotebookApp.notebook_dir', 'browser': 'NotebookApp.browser', 'pylab': 'NotebookApp.pylab', - 'gateway-url': 'NotebookApp.gateway_url', + 'gateway-url': 'Gateway.url', }) #----------------------------------------------------------------------------- @@ -579,7 +575,7 @@ class NotebookApp(JupyterApp): classes = [ KernelManager, Session, MappingKernelManager, KernelSpecManager, ContentsManager, FileContentsManager, NotebookNotary, - GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, + GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, Gateway, ] flags = Dict(flags) aliases = Dict(aliases) @@ -1307,20 +1303,6 @@ def _update_server_extensions(self, change): is not available. """)) - gateway_url_env = 'GATEWAY_URL' - - @default('gateway_url') - def gateway_url_default(self): - return os.getenv(self.gateway_url_env) - - gateway_url = Unicode(default_value=None, allow_none=True, config=True, - help="""The url of the Kernel or Enterprise Gateway server where - kernel specifications are defined and kernel management takes place. - If defined, this Notebook server acts as a proxy for all kernel - management and kernel specification retrieval. (GATEWAY_URL env var) - """ - ) - def parse_command_line(self, argv=None): super(NotebookApp, self).parse_command_line(argv) @@ -1344,7 +1326,9 @@ def parse_command_line(self, argv=None): def init_configurables(self): # If gateway server is configured, replace appropriate managers to perform redirection - if self.gateway_url: + self.gateway_config = Gateway.instance(parent=self) + + if self.gateway_config.gateway_enabled: self.kernel_manager_class = 'notebook.gateway.managers.GatewayKernelManager' self.session_manager_class = 'notebook.gateway.managers.GatewaySessionManager' self.kernel_spec_manager_class = 'notebook.gateway.managers.GatewayKernelSpecManager' @@ -1414,7 +1398,7 @@ def init_webapp(self): self.session_manager, self.kernel_spec_manager, self.config_manager, self.extra_services, self.log, self.base_url, self.default_url, self.tornado_settings, - self.jinja_environment_options, self.gateway_url, + self.jinja_environment_options, ) ssl_options = self.ssl_options if self.certfile: @@ -1694,8 +1678,8 @@ def notebook_info(self, kernel_count=True): info += "\n" # Format the info so that the URL fits on a single line in 80 char display info += _("The Jupyter Notebook is running at:\n%s") % self.display_url - if self.gateway_url: - info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_url + if self.gateway_config.gateway_enabled: + info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_config.url return info def server_info(self): diff --git a/notebook/tests/launchnotebook.py b/notebook/tests/launchnotebook.py index 1b685df0ca..9e84a5964b 100644 --- a/notebook/tests/launchnotebook.py +++ b/notebook/tests/launchnotebook.py @@ -91,6 +91,22 @@ def request(cls, verb, path, **kwargs): url_path_join(cls.base_url(), path), **kwargs) return response + + @classmethod + def get_patch_env(cls): + return { + 'HOME': cls.home_dir, + 'PYTHONPATH': os.pathsep.join(sys.path), + 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'), + 'JUPYTER_NO_CONFIG': '1', # needed in the future + 'JUPYTER_CONFIG_DIR' : cls.config_dir, + 'JUPYTER_DATA_DIR' : cls.data_dir, + 'JUPYTER_RUNTIME_DIR': cls.runtime_dir, + } + + @classmethod + def get_argv(cls): + return [] @classmethod def setup_class(cls): @@ -109,15 +125,7 @@ def tmp(*parts): config_dir = cls.config_dir = tmp('config') runtime_dir = cls.runtime_dir = tmp('runtime') cls.notebook_dir = tmp('notebooks') - cls.env_patch = patch.dict('os.environ', { - 'HOME': cls.home_dir, - 'PYTHONPATH': os.pathsep.join(sys.path), - 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'), - 'JUPYTER_NO_CONFIG': '1', # needed in the future - 'JUPYTER_CONFIG_DIR' : config_dir, - 'JUPYTER_DATA_DIR' : data_dir, - 'JUPYTER_RUNTIME_DIR': runtime_dir, - }) + cls.env_patch = patch.dict('os.environ', cls.get_patch_env()) cls.env_patch.start() cls.path_patch = patch.multiple( jupyter_core.paths, @@ -157,7 +165,7 @@ def start_thread(): # needs to be redone after initialize, which reconfigures logging app.log.propagate = True app.log.handlers = [] - app.initialize(argv=[]) + app.initialize(argv=cls.get_argv()) app.log.propagate = True app.log.handlers = [] loop = IOLoop.current() diff --git a/notebook/tests/test_gateway.py b/notebook/tests/test_gateway.py index 385c1f1cd5..a007046966 100644 --- a/notebook/tests/test_gateway.py +++ b/notebook/tests/test_gateway.py @@ -7,6 +7,7 @@ from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError from traitlets.config import Config from .launchnotebook import NotebookTestBase +from notebook.gateway.managers import Gateway try: from unittest.mock import patch, Mock @@ -137,12 +138,39 @@ def mock_fetch_gateway(url, **kwargs): class TestGateway(NotebookTestBase): + mock_gateway_url = 'http://mock-gateway-server:8889' + mock_http_user = 'alice' + @classmethod def setup_class(cls): - cls.config = Config() - cls.config.NotebookApp.gateway_url = 'http://mock-gateway-server:8889' + Gateway.clear_instance() super(TestGateway, cls).setup_class() + @classmethod + def teardown_class(cls): + Gateway.clear_instance() + super(TestGateway, cls).teardown_class() + + @classmethod + def get_patch_env(cls): + test_env = super(TestGateway, cls).get_patch_env() + test_env.update({'JUPYTER_GATEWAY_URL': TestGateway.mock_gateway_url, + 'JUPYTER_GATEWAY_REQUEST_TIMEOUT': '44.4'}) + return test_env + + @classmethod + def get_argv(cls): + argv = super(TestGateway, cls).get_argv() + argv.extend(['--Gateway.connect_timeout=44.4', '--Gateway.http_user=' + TestGateway.mock_http_user]) + return argv + + def test_gateway_options(self): + nt.assert_equal(self.notebook.gateway_config.gateway_enabled, True) + nt.assert_equal(self.notebook.gateway_config.url, TestGateway.mock_gateway_url) + nt.assert_equal(self.notebook.gateway_config.http_user, TestGateway.mock_http_user) + nt.assert_equal(self.notebook.gateway_config.connect_timeout, self.notebook.gateway_config.connect_timeout) + nt.assert_equal(self.notebook.gateway_config.connect_timeout, 44.4) + def test_gateway_class_mappings(self): # Ensure appropriate class mappings are in place. nt.assert_equal(self.notebook.kernel_manager_class.__name__, 'GatewayKernelManager') From acba19033b720902287eaed11f9ed656fd8f456d Mon Sep 17 00:00:00 2001 From: Kevin Bates Date: Thu, 6 Dec 2018 14:18:28 -0800 Subject: [PATCH 3/3] Minimize handlers and manager methods Eliminated the Kernel and Kernelspec handlers. The Websocket (ZMQ) channels handler still remains. This required turning a few methods into coroutines in the Notebook server. Renamed the Gateway config object to GatewayClient in case we want to extend NB server (probably jupyter_server at that point) with Gateway server functionality - so an NB server could be a Gateway client or a server depending on launch settings. Add code to _replace_ the channels handler rather than rely on position within the handlers lists. Updated mock-gateway to return the appropriate form of results. Updated the session manager tests to use a sync ioloop to call the now async manager methods. --- docs/source/public_server.rst | 2 +- notebook/gateway/handlers.py | 165 +-------- notebook/gateway/managers.py | 326 +++++++----------- notebook/notebookapp.py | 31 +- notebook/services/kernels/handlers.py | 5 +- notebook/services/kernelspecs/handlers.py | 9 +- notebook/services/sessions/sessionmanager.py | 47 ++- .../sessions/tests/test_sessionmanager.py | 26 +- notebook/tests/test_gateway.py | 23 +- 9 files changed, 209 insertions(+), 425 deletions(-) diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst index 5bbce15429..edadbe3ffc 100644 --- a/docs/source/public_server.rst +++ b/docs/source/public_server.rst @@ -365,7 +365,7 @@ or in :file:`jupyter_notebook_config.py`: .. code-block:: python - c.Gateway.url = http://my-gateway-server:8888 + c.GatewayClient.url = http://my-gateway-server:8888 When provided, all kernel specifications will be retrieved from the specified Gateway server and all kernels will be managed by that server. This option enables the ability to target kernel processes diff --git a/notebook/gateway/handlers.py b/notebook/gateway/handlers.py index 1e28bfa779..8e09b10861 100644 --- a/notebook/gateway/handlers.py +++ b/notebook/gateway/handlers.py @@ -2,11 +2,9 @@ # Distributed under the terms of the Modified BSD License. import os -import json import logging -from socket import gaierror -from ..base.handlers import APIHandler, IPythonHandler +from ..base.handlers import IPythonHandler from ..utils import url_path_join from tornado import gen, web @@ -14,14 +12,13 @@ from tornado.ioloop import IOLoop from tornado.websocket import WebSocketHandler, websocket_connect from tornado.httpclient import HTTPRequest -from tornado.simple_httpclient import HTTPTimeoutError from tornado.escape import url_escape, json_decode, utf8 from ipython_genutils.py3compat import cast_unicode from jupyter_client.session import Session from traitlets.config.configurable import LoggingConfigurable -from .managers import Gateway +from .managers import GatewayClient class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler): @@ -57,7 +54,7 @@ def authenticate(self): def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) self.session = Session(config=self.config) - self.gateway = GatewayWebSocketClient(gateway_url=Gateway.instance().url) + self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url) @gen.coroutine def get(self, kernel_id, *args, **kwargs): @@ -124,12 +121,12 @@ def __init__(self, **kwargs): def _connect(self, kernel_id): self.kernel_id = kernel_id ws_url = url_path_join( - Gateway.instance().ws_url, - Gateway.instance().kernels_endpoint, url_escape(kernel_id), 'channels' + GatewayClient.instance().ws_url, + GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels' ) self.log.info('Connecting to {}'.format(ws_url)) kwargs = {} - kwargs = Gateway.instance().load_connection_args(**kwargs) + kwargs = GatewayClient.instance().load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) self.ws_future = websocket_connect(request) @@ -141,8 +138,8 @@ def _connection_done(self, fut): self.log.debug("Connection is ready: ws: {}".format(self.ws)) else: self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. " - "Kernel with ID '{}' may not be terminated on Gateway: {}". - format(self.kernel_id, Gateway.instance().url)) + "Kernel with ID '{}' may not be terminated on GatewayClient: {}". + format(self.kernel_id, GatewayClient.instance().url)) def _disconnect(self): if self.ws is not None: @@ -203,152 +200,8 @@ def on_close(self): self._disconnect() -# ----------------------------------------------------------------------------- -# kernel handlers -# ----------------------------------------------------------------------------- - -class MainKernelHandler(APIHandler): - """Replace default MainKernelHandler to enable async lookup of kernels.""" - - @web.authenticated - @gen.coroutine - def get(self): - km = self.kernel_manager - kernels = yield gen.maybe_future(km.list_kernels()) - self.finish(json.dumps(kernels)) - - @web.authenticated - @gen.coroutine - def post(self): - km = self.kernel_manager - model = self.get_json_body() - if model is None: - model = { - 'name': km.default_kernel_name - } - else: - model.setdefault('name', km.default_kernel_name) - - kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name'])) - # This is now an async operation - model = yield gen.maybe_future(km.kernel_model(kernel_id)) - location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id)) - self.set_header('Location', location) - self.set_status(201) - self.finish(json.dumps(model)) - - -class KernelHandler(APIHandler): - """Replace default KernelHandler to enable async lookup of kernels.""" - - @web.authenticated - @gen.coroutine - def get(self, kernel_id): - km = self.kernel_manager - # This is now an async operation - model = yield gen.maybe_future(km.kernel_model(kernel_id)) - if model is None: - raise web.HTTPError(404, u'Kernel does not exist: %s' % kernel_id) - self.finish(json.dumps(model)) - - @web.authenticated - @gen.coroutine - def delete(self, kernel_id): - km = self.kernel_manager - yield gen.maybe_future(km.shutdown_kernel(kernel_id)) - self.set_status(204) - self.finish() - - -class KernelActionHandler(APIHandler): - """Replace default KernelActionHandler to enable async lookup of kernels.""" - - @web.authenticated - @gen.coroutine - def post(self, kernel_id, action): - km = self.kernel_manager - - if action == 'interrupt': - km.interrupt_kernel(kernel_id) - self.set_status(204) - - if action == 'restart': - try: - yield gen.maybe_future(km.restart_kernel(kernel_id)) - except Exception as e: - self.log.error("Exception restarting kernel", exc_info=True) - self.set_status(500) - else: - # This is now an async operation - model = yield gen.maybe_future(km.kernel_model(kernel_id)) - self.write(json.dumps(model)) - self.finish() - -# ----------------------------------------------------------------------------- -# kernel spec handlers -# ----------------------------------------------------------------------------- - - -class MainKernelSpecHandler(APIHandler): - @web.authenticated - @gen.coroutine - def get(self): - ksm = self.kernel_spec_manager - try: - kernel_specs = yield gen.maybe_future(ksm.list_kernel_specs()) - # TODO: Remove resources until we support them - for name, spec in kernel_specs['kernelspecs'].items(): - spec['resources'] = {} - self.set_header("Content-Type", 'application/json') - self.write(json.dumps(kernel_specs)) - - # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect - # or the server is not running. - # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes - # of the tree view. - except ConnectionRefusedError: - self.log.error("Connection refused from Gateway server url '{}'. " - "Check to be sure the Gateway instance is running.".format(Gateway.instance().url)) - except HTTPTimeoutError: - # This can occur if the host is valid (e.g., foo.com) but there's nothing there. - self.log.error("Timeout error attempting to connect to Gateway server url '{}'. " - "Ensure gateway url is valid and the Gateway instance is running.".format(Gateway.instance().url)) - except gaierror as e: - self.log.error("The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " - "Ensure gateway url is valid and the Gateway instance is running.".format(Gateway.instance().url)) - - self.finish() - - -class KernelSpecHandler(APIHandler): - @web.authenticated - @gen.coroutine - def get(self, kernel_name): - ksm = self.kernel_spec_manager - kernel_spec = yield ksm.get_kernel_spec(kernel_name) - if kernel_spec is None: - raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) - # TODO: Remove resources until we support them - kernel_spec['resources'] = {} - self.set_header("Content-Type", 'application/json') - self.finish(json.dumps(kernel_spec)) - -# ----------------------------------------------------------------------------- -# URL to handler mappings -# ----------------------------------------------------------------------------- - - -from ..services.kernels.handlers import _kernel_id_regex, _kernel_action_regex -from ..services.kernelspecs.handlers import kernel_name_regex +from ..services.kernels.handlers import _kernel_id_regex default_handlers = [ - (r"/api/kernels", MainKernelHandler), - (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler), - (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler), (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler), - (r"/api/kernelspecs", MainKernelSpecHandler), - (r"/api/kernelspecs/%s" % kernel_name_regex, KernelSpecHandler), - # TODO: support kernel spec resources - # (r"/kernelspecs/%s/(?P.*)" % kernel_name_regex, KernelSpecResourceHandler), - ] diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py index 42da1696fc..73af7d9799 100644 --- a/notebook/gateway/managers.py +++ b/notebook/gateway/managers.py @@ -4,9 +4,11 @@ import os import json -from tornado import gen +from socket import gaierror +from tornado import gen, web from tornado.escape import json_encode, json_decode, url_escape from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError +from tornado.simple_httpclient import HTTPTimeoutError from ..services.kernels.kernelmanager import MappingKernelManager from ..services.sessions.sessionmanager import SessionManager @@ -18,22 +20,10 @@ from traitlets.config import SingletonConfigurable -@gen.coroutine -def fetch_gateway(endpoint, **kwargs): - """Make an async request to kernel gateway endpoint.""" - client = AsyncHTTPClient() - - kwargs = Gateway.instance().load_connection_args(**kwargs) - - response = yield client.fetch(endpoint, **kwargs) - raise gen.Return(response) - - -class Gateway(SingletonConfigurable): - """This class manages the configuration. It's its own class so that we can avoid having command - line options of the likes `--GatewayKernelManager.connect_timeout` and use the shorter and more - applicable `--Gateway.connect_timeout`, etc. It also contains some helper methods to build - request arguments out of the various config options. +class GatewayClient(SingletonConfigurable): + """This class manages the configuration. It's its own singleton class so that we + can share these values across all objects. It also contains some helper methods + to build request arguments out of the various config options. """ @@ -56,7 +46,7 @@ def _url_validate(self, proposal): # Ensure value, if present, starts with 'http' if value is not None and len(value) > 0: if not str(value).lower().startswith('http'): - raise TraitError("Gateway url must start with 'http': '%r'" % value) + raise TraitError("GatewayClient url must start with 'http': '%r'" % value) return value ws_url = Unicode(default_value=None, allow_none=True, config=True, @@ -80,7 +70,7 @@ def _ws_url_validate(self, proposal): # Ensure value, if present, starts with 'ws' if value is not None and len(value) > 0: if not str(value).lower().startswith('ws'): - raise TraitError("Gateway ws_url must start with 'ws': '%r'" % value) + raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value) return value kernels_endpoint_default_value = '/api/kernels' @@ -204,9 +194,21 @@ def validate_cert_default(self): return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false']) def __init__(self, **kwargs): - super(Gateway, self).__init__(**kwargs) + super(GatewayClient, self).__init__(**kwargs) self._static_args = {} # initialized on first use + env_whitelist_default_value = '' + env_whitelist_env = 'JUPYTER_GATEWAY_ENV_WHITELIST' + env_whitelist = Unicode(default_value=env_whitelist_default_value, config=True, + help="""A comma-separated list of environment variable names that will be included, along with + their values, in the kernel startup request. The corresponding `env_whitelist` configuration + value must also be set on the Gateway server - since that configuration value indicates which + environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""") + + @default('env_whitelist') + def _env_whitelist_default(self): + return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value) + @property def gateway_enabled(self): return bool(self.url is not None and len(self.url) > 0) @@ -242,6 +244,34 @@ def load_connection_args(self, **kwargs): kwargs.update(self._static_args) return kwargs + +@gen.coroutine +def gateway_request(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint, returns a response """ + client = AsyncHTTPClient() + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + try: + response = yield client.fetch(endpoint, **kwargs) + # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect + # or the server is not running. + # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes + # of the tree view. + except ConnectionRefusedError: + raise web.HTTPError(503, "Connection refused from Gateway server url '{}'. " + "Check to be sure the Gateway instance is running.".format(GatewayClient.instance().url)) + except HTTPTimeoutError: + # This can occur if the host is valid (e.g., foo.com) but there's nothing there. + raise web.HTTPError(504, "Timeout error attempting to connect to Gateway server url '{}'. " \ + "Ensure gateway url is valid and the Gateway instance is running.".format( + GatewayClient.instance().url)) + except gaierror as e: + raise web.HTTPError(404, "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " + "Ensure gateway url is valid and the Gateway instance is running.".format( + GatewayClient.instance().url)) + + raise gen.Return(response) + + class GatewayKernelManager(MappingKernelManager): """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" @@ -250,7 +280,7 @@ class GatewayKernelManager(MappingKernelManager): def __init__(self, **kwargs): super(GatewayKernelManager, self).__init__(**kwargs) - self.base_endpoint = url_path_join(Gateway.instance().url, Gateway.instance().kernels_endpoint) + self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint) def __contains__(self, kernel_id): return kernel_id in self._kernels @@ -291,18 +321,25 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs): self.log.info('Request start kernel: kernel_id=%s, path="%s"', kernel_id, path) if kernel_id is None: + if path is not None: + kwargs['cwd'] = self.cwd_for_path(path) kernel_name = kwargs.get('kernel_name', 'python3') kernel_url = self._get_kernel_endpoint_url() self.log.debug("Request new kernel at: %s" % kernel_url) + # Let KERNEL_USERNAME take precedent over http_user config option. + if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user: + os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user + kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') - or k in os.environ.get('GATEWAY_ENV_WHITELIST', '').split(",")} + or k in GatewayClient.instance().env_whitelist.split(",")} json_body = json_encode({'name': kernel_name, 'env': kernel_env}) - response = yield fetch_gateway(kernel_url, method='POST', body=json_body) + response = yield gateway_request(kernel_url, method='POST', body=json_body) kernel = json_decode(response.body) kernel_id = kernel['id'] self.log.info("Kernel started: %s" % kernel_id) + self.log.debug("Kernel args: %r" % kwargs) else: kernel = yield self.get_kernel(kernel_id) kernel_id = kernel['id'] @@ -323,7 +360,7 @@ def get_kernel(self, kernel_id=None, **kwargs): kernel_url = self._get_kernel_endpoint_url(kernel_id) self.log.debug("Request kernel at: %s" % kernel_url) try: - response = yield fetch_gateway(kernel_url, method='GET') + response = yield gateway_request(kernel_url, method='GET') except HTTPError as error: if error.code == 404: self.log.warn("Kernel not found at: %s" % kernel_url) @@ -334,7 +371,7 @@ def get_kernel(self, kernel_id=None, **kwargs): else: kernel = json_decode(response.body) self._kernels[kernel_id] = kernel - self.log.info("Kernel retrieved: %s" % kernel) + self.log.debug("Kernel retrieved: %s" % kernel) raise gen.Return(kernel) @gen.coroutine @@ -356,13 +393,13 @@ def list_kernels(self, **kwargs): """Get a list of kernels.""" kernel_url = self._get_kernel_endpoint_url() self.log.debug("Request list kernels: %s", kernel_url) - response = yield fetch_gateway(kernel_url, method='GET') + response = yield gateway_request(kernel_url, method='GET') kernels = json_decode(response.body) self._kernels = {x['id']:x for x in kernels} raise gen.Return(kernels) @gen.coroutine - def shutdown_kernel(self, kernel_id): + def shutdown_kernel(self, kernel_id, now=False, restart=False): """Shutdown a kernel by its kernel uuid. Parameters @@ -372,7 +409,7 @@ def shutdown_kernel(self, kernel_id): """ kernel_url = self._get_kernel_endpoint_url(kernel_id) self.log.debug("Request shutdown kernel at: %s", kernel_url) - response = yield fetch_gateway(kernel_url, method='DELETE') + response = yield gateway_request(kernel_url, method='DELETE') self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason) self.remove_kernel(kernel_id) @@ -387,7 +424,7 @@ def restart_kernel(self, kernel_id, now=False, **kwargs): """ kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart' self.log.debug("Request restart kernel at: %s", kernel_url) - response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + response = yield gateway_request(kernel_url, method='POST', body=json_encode({})) self.log.debug("Restart kernel response: %d %s", response.code, response.reason) @gen.coroutine @@ -401,14 +438,15 @@ def interrupt_kernel(self, kernel_id, **kwargs): """ kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt' self.log.debug("Request interrupt kernel at: %s", kernel_url) - response = yield fetch_gateway(kernel_url, method='POST', body=json_encode({})) + response = yield gateway_request(kernel_url, method='POST', body=json_encode({})) self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason) - def shutdown_all(self): + def shutdown_all(self, now=False): """Shutdown all kernels.""" # Note: We have to make this sync because the NotebookApp does not wait for async. + shutdown_kernels = [] kwargs = {'method': 'DELETE'} - kwargs = Gateway.instance().load_connection_args(**kwargs) + kwargs = GatewayClient.instance().load_connection_args(**kwargs) client = HTTPClient() for kernel_id in self._kernels.keys(): kernel_url = self._get_kernel_endpoint_url(kernel_id) @@ -417,16 +455,19 @@ def shutdown_all(self): response = client.fetch(kernel_url, **kwargs) except HTTPError: pass - self.log.debug("Delete kernel response: %d %s", response.code, response.reason) - self.remove_kernel(kernel_id) + else: + self.log.debug("Delete kernel response: %d %s", response.code, response.reason) + shutdown_kernels.append(kernel_id) # avoid changing dict size during iteration client.close() + for kernel_id in shutdown_kernels: + self.remove_kernel(kernel_id) class GatewayKernelSpecManager(KernelSpecManager): def __init__(self, **kwargs): super(GatewayKernelSpecManager, self).__init__(**kwargs) - self.base_endpoint = url_path_join(Gateway.instance().url, Gateway.instance().kernelspecs_endpoint) + self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint) def _get_kernelspecs_endpoint_url(self, kernel_name=None): """Builds a url for the kernels endpoint @@ -440,12 +481,39 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None): return self.base_endpoint + @gen.coroutine + def get_all_specs(self): + fetched_kspecs = yield self.list_kernel_specs() + + # get the default kernel name and compare to that of this server. + # If different log a warning and reset the default. However, the + # caller of this method will still return this server's value until + # the next fetch of kernelspecs - at which time they'll match. + km = self.parent.kernel_manager + remote_default_kernel_name = fetched_kspecs.get('default') + if remote_default_kernel_name != km.default_kernel_name: + self.log.info("Default kernel name on Gateway server ({gateway_default}) differs from " + "Notebook server ({notebook_default}). Updating to Gateway server's value.". + format(gateway_default=remote_default_kernel_name, + notebook_default=km.default_kernel_name)) + km.default_kernel_name = remote_default_kernel_name + + # gateway doesn't support resources (requires transfer for use by NB client) + # so add `resource_dir` to each kernelspec and value of 'not supported in gateway mode' + remote_kspecs = fetched_kspecs.get('kernelspecs') + for kernel_name, kspec_info in remote_kspecs.items(): + if not kspec_info.get('resource_dir'): + kspec_info['resource_dir'] = 'not supported in gateway mode' + remote_kspecs[kernel_name].update(kspec_info) + + raise gen.Return(remote_kspecs) + @gen.coroutine def list_kernel_specs(self): """Get a list of kernel specs.""" kernel_spec_url = self._get_kernelspecs_endpoint_url() self.log.debug("Request list kernel specs at: %s", kernel_spec_url) - response = yield fetch_gateway(kernel_spec_url, method='GET') + response = yield gateway_request(kernel_spec_url, method='GET') kernel_specs = json_decode(response.body) raise gen.Return(kernel_specs) @@ -461,189 +529,27 @@ def get_kernel_spec(self, kernel_name, **kwargs): kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name)) self.log.debug("Request kernel spec at: %s" % kernel_spec_url) try: - response = yield fetch_gateway(kernel_spec_url, method='GET') + response = yield gateway_request(kernel_spec_url, method='GET') except HTTPError as error: if error.code == 404: - self.log.warn("Kernel spec not found at: %s" % kernel_spec_url) - kernel_spec = None + # Convert not found to KeyError since that's what the Notebook handler expects + # message is not used, but might as well make it useful for troubleshooting + raise KeyError('kernelspec {kernel_name} not found on Gateway server at: {gateway_url}'. + format(kernel_name=kernel_name, gateway_url=GatewayClient.instance().url)) else: raise else: kernel_spec = json_decode(response.body) - raise gen.Return(kernel_spec) + # Convert to instance of Kernelspec + kspec_instance = self.kernel_spec_class(resource_dir=u'', **kernel_spec['spec']) + raise gen.Return(kspec_instance) class GatewaySessionManager(SessionManager): kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager') @gen.coroutine - def create_session(self, path=None, name=None, type=None, - kernel_name=None, kernel_id=None): - """Creates a session and returns its model. - - Overrides base class method to turn into an async operation. - """ - session_id = self.new_session_id() - - kernel = None - if kernel_id is not None: - # This is now an async operation - kernel = yield self.kernel_manager.get_kernel(kernel_id) - - if kernel is not None: - pass - else: - kernel_id = yield self.start_kernel_for_session( - session_id, path, name, type, kernel_name, - ) - - result = yield self.save_session( - session_id, path=path, name=name, type=type, kernel_id=kernel_id, - ) - raise gen.Return(result) - - @gen.coroutine - def save_session(self, session_id, path=None, name=None, type=None, - kernel_id=None): - """Saves the items for the session with the given session_id - - Given a session_id (and any other of the arguments), this method - creates a row in the sqlite session database that holds the information - for a session. - - Parameters - ---------- - session_id : str - uuid for the session; this method must be given a session_id - path : str - the path for the given notebook - kernel_id : str - a uuid for the kernel associated with this session - - Returns - ------- - model : dict - a dictionary of the session model - """ - # This is now an async operation - session = yield super(GatewaySessionManager, self).save_session( - session_id, path=path, name=name, type=type, kernel_id=kernel_id - ) - raise gen.Return(session) - - @gen.coroutine - def get_session(self, **kwargs): - """Returns the model for a particular session. - - Takes a keyword argument and searches for the value in the session - database, then returns the rest of the session's info. - - Overrides base class method to turn into an async operation. - - Parameters - ---------- - **kwargs : keyword argument - must be given one of the keywords and values from the session database - (i.e. session_id, path, kernel_id) - - Returns - ------- - model : dict - returns a dictionary that includes all the information from the - session described by the kwarg. - """ - # This is now an async operation - session = yield super(GatewaySessionManager, self).get_session(**kwargs) - raise gen.Return(session) - - @gen.coroutine - def update_session(self, session_id, **kwargs): - """Updates the values in the session database. - - Changes the values of the session with the given session_id - with the values from the keyword arguments. - - Overrides base class method to turn into an async operation. - - Parameters - ---------- - session_id : str - a uuid that identifies a session in the sqlite3 database - **kwargs : str - the key must correspond to a column title in session database, - and the value replaces the current value in the session - with session_id. - """ - # This is now an async operation - session = yield self.get_session(session_id=session_id) - - if not kwargs: - # no changes - return - - sets = [] - for column in kwargs.keys(): - if column not in self._columns: - raise TypeError("No such column: %r" % column) - sets.append("%s=?" % column) - query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets)) - self.cursor.execute(query, list(kwargs.values()) + [session_id]) - - @gen.coroutine - def row_to_model(self, row): - """Takes sqlite database session row and turns it into a dictionary. - - Overrides base class method to turn into an async operation. - """ - # Retrieve kernel for session, which is now an async operation - kernel = yield self.kernel_manager.get_kernel(row['kernel_id']) - if kernel is None: - # The kernel was killed or died without deleting the session. - # We can't use delete_session here because that tries to find - # and shut down the kernel. - self.cursor.execute("DELETE FROM session WHERE session_id=?", - (row['session_id'],)) - raise KeyError - - model = { - 'id': row['session_id'], - 'path': row['path'], - 'name': row['name'], - 'type': row['type'], - 'kernel': kernel - } - if row['type'] == 'notebook': # Provide the deprecated API. - model['notebook'] = {'path': row['path'], 'name': row['name']} - - raise gen.Return(model) - - @gen.coroutine - def list_sessions(self): - """Returns a list of dictionaries containing all the information from - the session database. - - Overrides base class method to turn into an async operation. - """ - c = self.cursor.execute("SELECT * FROM session") - result = [] - # We need to use fetchall() here, because row_to_model can delete rows, - # which messes up the cursor if we're iterating over rows. - for row in c.fetchall(): - try: - # This is now an async operation - model = yield self.row_to_model(row) - result.append(model) - except KeyError: - pass - raise gen.Return(result) - - @gen.coroutine - def delete_session(self, session_id): - """Deletes the row in the session database with given session_id. - - Overrides base class method to turn into an async operation. - """ - # This is now an async operation - session = yield self.get_session(session_id=session_id) - yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id'])) - self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) + def kernel_culled(self, kernel_id): + """Checks if the kernel is still considered alive and returns true if its not found. """ + kernel = yield self.kernel_manager.get_kernel(kernel_id) + raise gen.Return(kernel is None) diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index 559e775794..2639b4faa8 100755 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -84,7 +84,7 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager -from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, Gateway +from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -311,16 +311,22 @@ def init_handlers(self, settings): handlers.extend(load_handlers('notebook.services.nbconvert.handlers')) handlers.extend(load_handlers('notebook.services.security.handlers')) handlers.extend(load_handlers('notebook.services.shutdown')) - - # If gateway server is configured, replace appropriate handlers to perform redirection - if Gateway.instance().gateway_enabled: - handlers.extend(load_handlers('notebook.gateway.handlers')) - else: - handlers.extend(load_handlers('notebook.services.kernels.handlers')) - handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) + handlers.extend(load_handlers('notebook.services.kernels.handlers')) + handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) handlers.extend(settings['contents_manager'].get_extra_handlers()) + # If gateway mode is enabled, replace appropriate handlers to perform redirection + if GatewayClient.instance().gateway_enabled: + # for each handler required for gateway, locate its pattern + # in the current list and replace that entry... + gateway_handlers = load_handlers('notebook.gateway.handlers') + for i, gwh in enumerate(gateway_handlers): + for j, h in enumerate(handlers): + if gwh[0] == h[0]: + handlers[j] = (gwh[0], gwh[1]) + break + handlers.append( (r"/nbextensions/(.*)", FileFindHandler, { 'path': settings['nbextensions_path'], @@ -554,7 +560,7 @@ def start(self): 'notebook-dir': 'NotebookApp.notebook_dir', 'browser': 'NotebookApp.browser', 'pylab': 'NotebookApp.pylab', - 'gateway-url': 'Gateway.url', + 'gateway-url': 'GatewayClient.url', }) #----------------------------------------------------------------------------- @@ -575,7 +581,7 @@ class NotebookApp(JupyterApp): classes = [ KernelManager, Session, MappingKernelManager, KernelSpecManager, ContentsManager, FileContentsManager, NotebookNotary, - GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, Gateway, + GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient, ] flags = Dict(flags) aliases = Dict(aliases) @@ -1325,8 +1331,9 @@ def parse_command_line(self, argv=None): def init_configurables(self): - # If gateway server is configured, replace appropriate managers to perform redirection - self.gateway_config = Gateway.instance(parent=self) + # If gateway server is configured, replace appropriate managers to perform redirection. To make + # this determination, instantiate the GatewayClient config singleton. + self.gateway_config = GatewayClient.instance(parent=self) if self.gateway_config.gateway_enabled: self.kernel_manager_class = 'notebook.gateway.managers.GatewayKernelManager' diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index cfef2a4a0e..897fa51db2 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -45,7 +45,7 @@ def post(self): model.setdefault('name', km.default_kernel_name) kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name'])) - model = km.kernel_model(kernel_id) + model = yield gen.maybe_future(km.kernel_model(kernel_id)) location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id)) self.set_header('Location', location) self.set_status(201) @@ -57,7 +57,6 @@ class KernelHandler(APIHandler): @web.authenticated def get(self, kernel_id): km = self.kernel_manager - km._check_kernel_id(kernel_id) model = km.kernel_model(kernel_id) self.finish(json.dumps(model, default=date_default)) @@ -87,7 +86,7 @@ def post(self, kernel_id, action): self.log.error("Exception restarting kernel", exc_info=True) self.set_status(500) else: - model = km.kernel_model(kernel_id) + model = yield gen.maybe_future(km.kernel_model(kernel_id)) self.write(json.dumps(model, default=date_default)) self.finish() diff --git a/notebook/services/kernelspecs/handlers.py b/notebook/services/kernelspecs/handlers.py index d272db2f71..c0157e4c57 100644 --- a/notebook/services/kernelspecs/handlers.py +++ b/notebook/services/kernelspecs/handlers.py @@ -11,7 +11,7 @@ import os pjoin = os.path.join -from tornado import web +from tornado import web, gen from ...base.handlers import APIHandler from ...utils import url_path_join, url_unescape @@ -48,13 +48,15 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): class MainKernelSpecHandler(APIHandler): @web.authenticated + @gen.coroutine def get(self): ksm = self.kernel_spec_manager km = self.kernel_manager model = {} model['default'] = km.default_kernel_name model['kernelspecs'] = specs = {} - for kernel_name, kernel_info in ksm.get_all_specs().items(): + kspecs = yield gen.maybe_future(ksm.get_all_specs()) + for kernel_name, kernel_info in kspecs.items(): try: d = kernelspec_model(self, kernel_name, kernel_info['spec'], kernel_info['resource_dir']) @@ -69,11 +71,12 @@ def get(self): class KernelSpecHandler(APIHandler): @web.authenticated + @gen.coroutine def get(self, kernel_name): ksm = self.kernel_spec_manager kernel_name = url_unescape(kernel_name) try: - spec = ksm.get_kernel_spec(kernel_name) + spec = yield gen.maybe_future(ksm.get_kernel_spec(kernel_name)) except KeyError: raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) model = kernelspec_model(self, kernel_name, spec.to_dict(), spec.resource_dir) diff --git a/notebook/services/sessions/sessionmanager.py b/notebook/services/sessions/sessionmanager.py index ee70eb0810..4497cfbc33 100644 --- a/notebook/services/sessions/sessionmanager.py +++ b/notebook/services/sessions/sessionmanager.py @@ -56,21 +56,22 @@ def __del__(self): """Close connection once SessionManager closes""" self.close() + @gen.coroutine def session_exists(self, path): """Check to see if the session of a given name exists""" + exists = False self.cursor.execute("SELECT * FROM session WHERE path=?", (path,)) row = self.cursor.fetchone() - if row is None: - return False - else: + if row is not None: # Note, although we found a row for the session, the associated kernel may have # been culled or died unexpectedly. If that's the case, we should delete the # row, thereby terminating the session. This can be done via a call to # row_to_model that tolerates that condition. If row_to_model returns None, # we'll return false, since, at that point, the session doesn't exist anyway. - if self.row_to_model(row, tolerate_culled=True) is None: - return False - return True + model = yield gen.maybe_future(self.row_to_model(row, tolerate_culled=True)) + if model is not None: + exists = True + raise gen.Return(exists) def new_session_id(self): "Create a uuid for a new session" @@ -101,6 +102,7 @@ def start_kernel_for_session(self, session_id, path, name, type, kernel_name): # py2-compat raise gen.Return(kernel_id) + @gen.coroutine def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None): """Saves the items for the session with the given session_id @@ -129,8 +131,10 @@ def save_session(self, session_id, path=None, name=None, type=None, kernel_id=No self.cursor.execute("INSERT INTO session VALUES (?,?,?,?,?)", (session_id, path, name, type, kernel_id) ) - return self.get_session(session_id=session_id) + result = yield gen.maybe_future(self.get_session(session_id=session_id)) + raise gen.Return(result) + @gen.coroutine def get_session(self, **kwargs): """Returns the model for a particular session. @@ -174,8 +178,10 @@ def get_session(self, **kwargs): raise web.HTTPError(404, u'Session not found: %s' % (', '.join(q))) - return self.row_to_model(row) + model = yield gen.maybe_future(self.row_to_model(row)) + raise gen.Return(model) + @gen.coroutine def update_session(self, session_id, **kwargs): """Updates the values in the session database. @@ -191,7 +197,7 @@ def update_session(self, session_id, **kwargs): and the value replaces the current value in the session with session_id. """ - self.get_session(session_id=session_id) + yield gen.maybe_future(self.get_session(session_id=session_id)) if not kwargs: # no changes @@ -205,9 +211,15 @@ def update_session(self, session_id, **kwargs): query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets)) self.cursor.execute(query, list(kwargs.values()) + [session_id]) + def kernel_culled(self, kernel_id): + """Checks if the kernel is still considered alive and returns true if its not found. """ + return kernel_id not in self.kernel_manager + + @gen.coroutine def row_to_model(self, row, tolerate_culled=False): """Takes sqlite database session row and turns it into a dictionary""" - if row['kernel_id'] not in self.kernel_manager: + kernel_culled = yield gen.maybe_future(self.kernel_culled(row['kernel_id'])) + if kernel_culled: # The kernel was culled or died without deleting the session. # We can't use delete_session here because that tries to find # and shut down the kernel - so we'll delete the row directly. @@ -222,21 +234,23 @@ def row_to_model(self, row, tolerate_culled=False): format(kernel_id=row['kernel_id'],session_id=row['session_id']) if tolerate_culled: self.log.warning(msg + " Continuing...") - return None + raise gen.Return(None) raise KeyError(msg) + kernel_model = yield gen.maybe_future(self.kernel_manager.kernel_model(row['kernel_id'])) model = { 'id': row['session_id'], 'path': row['path'], 'name': row['name'], 'type': row['type'], - 'kernel': self.kernel_manager.kernel_model(row['kernel_id']) + 'kernel': kernel_model } if row['type'] == 'notebook': # Provide the deprecated API. model['notebook'] = {'path': row['path'], 'name': row['name']} - return model + raise gen.Return(model) + @gen.coroutine def list_sessions(self): """Returns a list of dictionaries containing all the information from the session database""" @@ -246,14 +260,15 @@ def list_sessions(self): # which messes up the cursor if we're iterating over rows. for row in c.fetchall(): try: - result.append(self.row_to_model(row)) + model = yield gen.maybe_future(self.row_to_model(row)) + result.append(model) except KeyError: pass - return result + raise gen.Return(result) @gen.coroutine def delete_session(self, session_id): """Deletes the row in the session database with given session_id""" - session = self.get_session(session_id=session_id) + session = yield gen.maybe_future(self.get_session(session_id=session_id)) yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id'])) self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) diff --git a/notebook/services/sessions/tests/test_sessionmanager.py b/notebook/services/sessions/tests/test_sessionmanager.py index 96847a868a..97331ebf9b 100644 --- a/notebook/services/sessions/tests/test_sessionmanager.py +++ b/notebook/services/sessions/tests/test_sessionmanager.py @@ -62,11 +62,11 @@ def co_add(): def create_session(self, **kwargs): return self.create_sessions(kwargs)[0] - + def test_get_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='bar')['id'] - model = sm.get_session(session_id=session_id) + model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id)) expected = {'id':session_id, 'path': u'/path/to/test.ipynb', 'notebook': {'path': u'/path/to/test.ipynb', 'name': None}, @@ -86,7 +86,8 @@ def test_bad_get_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='foo')['id'] - self.assertRaises(TypeError, sm.get_session, bad_id=session_id) # Bad keyword + with self.assertRaises(TypeError): + self.loop.run_sync(lambda: sm.get_session(bad_id=session_id)) # Bad keyword def test_get_session_dead_kernel(self): sm = self.sm @@ -94,9 +95,9 @@ def test_get_session_dead_kernel(self): # kill the kernel sm.kernel_manager.shutdown_kernel(session['kernel']['id']) with self.assertRaises(KeyError): - sm.get_session(session_id=session['id']) + self.loop.run_sync(lambda: sm.get_session(session_id=session['id'])) # no sessions left - listed = sm.list_sessions() + listed = self.loop.run_sync(lambda: sm.list_sessions()) self.assertEqual(listed, []) def test_list_sessions(self): @@ -107,7 +108,7 @@ def test_list_sessions(self): dict(path='/path/to/3', name='foo', type='console', kernel_name='python'), ) - sessions = sm.list_sessions() + sessions = self.loop.run_sync(lambda: sm.list_sessions()) expected = [ { 'id':sessions[0]['id'], @@ -158,7 +159,7 @@ def test_list_sessions_dead_kernel(self): ) # kill one of the kernels sm.kernel_manager.shutdown_kernel(sessions[0]['kernel']['id']) - listed = sm.list_sessions() + listed = self.loop.run_sync(lambda: sm.list_sessions()) expected = [ { 'id': sessions[1]['id'], @@ -181,8 +182,8 @@ def test_update_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='julia')['id'] - sm.update_session(session_id, path='/path/to/new_name.ipynb') - model = sm.get_session(session_id=session_id) + self.loop.run_sync(lambda: sm.update_session(session_id, path='/path/to/new_name.ipynb')) + model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id)) expected = {'id':session_id, 'path': u'/path/to/new_name.ipynb', 'type': 'notebook', @@ -203,7 +204,8 @@ def test_bad_update_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='ir')['id'] - self.assertRaises(TypeError, sm.update_session, session_id=session_id, bad_kw='test.ipynb') # Bad keyword + with self.assertRaises(TypeError): + self.loop.run_sync(lambda: sm.update_session(session_id=session_id, bad_kw='test.ipynb')) # Bad keyword def test_delete_session(self): sm = self.sm @@ -212,8 +214,8 @@ def test_delete_session(self): dict(path='/path/to/2/test2.ipynb', kernel_name='python'), dict(path='/path/to/3', name='foo', type='console', kernel_name='python'), ) - sm.delete_session(sessions[1]['id']) - new_sessions = sm.list_sessions() + self.loop.run_sync(lambda: sm.delete_session(sessions[1]['id'])) + new_sessions = self.loop.run_sync(lambda: sm.list_sessions()) expected = [{ 'id': sessions[0]['id'], 'path': u'/path/to/1/test1.ipynb', diff --git a/notebook/tests/test_gateway.py b/notebook/tests/test_gateway.py index a007046966..ef3cd7ef56 100644 --- a/notebook/tests/test_gateway.py +++ b/notebook/tests/test_gateway.py @@ -1,4 +1,4 @@ -"""Test Gateway""" +"""Test GatewayClient""" import os import json import uuid @@ -7,7 +7,7 @@ from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError from traitlets.config import Config from .launchnotebook import NotebookTestBase -from notebook.gateway.managers import Gateway +from notebook.gateway.managers import GatewayClient try: from unittest.mock import patch, Mock @@ -25,12 +25,12 @@ def generate_kernelspec(name): argv_stanza = ['python', '-m', 'ipykernel_launcher', '-f', '{connection_file}'] spec_stanza = {'spec': {'argv': argv_stanza, 'env': {}, 'display_name': name, 'language': 'python', 'interrupt_mode': 'signal', 'metadata': {}}} - kernelspec_stanza = {name: {'name': name, 'spec': spec_stanza, 'resources': {}}} + kernelspec_stanza = {'name': name, 'spec': spec_stanza, 'resources': {}} return kernelspec_stanza # We'll mock up two kernelspecs - kspec_foo and kspec_bar -kernelspecs = {'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}} +kernelspecs = {'default': 'kspec_foo', 'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}} # maintain a dictionary of expected running kernels. Key = kernel_id, Value = model. @@ -46,7 +46,7 @@ def generate_model(name): @gen.coroutine -def mock_fetch_gateway(url, **kwargs): +def mock_gateway_request(url, **kwargs): method = 'GET' if kwargs['method']: method = kwargs['method'] @@ -133,7 +133,7 @@ def mock_fetch_gateway(url, **kwargs): raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) -mocked_gateway = patch('notebook.gateway.managers.fetch_gateway', mock_fetch_gateway) +mocked_gateway = patch('notebook.gateway.managers.gateway_request', mock_gateway_request) class TestGateway(NotebookTestBase): @@ -143,12 +143,12 @@ class TestGateway(NotebookTestBase): @classmethod def setup_class(cls): - Gateway.clear_instance() + GatewayClient.clear_instance() super(TestGateway, cls).setup_class() @classmethod def teardown_class(cls): - Gateway.clear_instance() + GatewayClient.clear_instance() super(TestGateway, cls).teardown_class() @classmethod @@ -161,7 +161,7 @@ def get_patch_env(cls): @classmethod def get_argv(cls): argv = super(TestGateway, cls).get_argv() - argv.extend(['--Gateway.connect_timeout=44.4', '--Gateway.http_user=' + TestGateway.mock_http_user]) + argv.extend(['--GatewayClient.connect_timeout=44.4', '--GatewayClient.http_user=' + TestGateway.mock_http_user]) return argv def test_gateway_options(self): @@ -185,15 +185,14 @@ def test_gateway_get_kernelspecs(self): content = json.loads(response.content.decode('utf-8'), encoding='utf-8') kspecs = content.get('kernelspecs') self.assertEqual(len(kspecs), 2) - self.assertEqual(kspecs.get('kspec_bar').get('kspec_bar').get('name'), 'kspec_bar') + self.assertEqual(kspecs.get('kspec_bar').get('name'), 'kspec_bar') def test_gateway_get_named_kernelspec(self): # Validate that a specific kernelspec can be retrieved from gateway. with mocked_gateway: response = self.request('GET', '/api/kernelspecs/kspec_foo') self.assertEqual(response.status_code, 200) - content = json.loads(response.content.decode('utf-8'), encoding='utf-8') - kspec_foo = content.get('kspec_foo') + kspec_foo = json.loads(response.content.decode('utf-8'), encoding='utf-8') self.assertEqual(kspec_foo.get('name'), 'kspec_foo') response = self.request('GET', '/api/kernelspecs/no_such_spec')