Skip to content

Commit

Permalink
vdk-jupyter: add oauth2 authentication implementation
Browse files Browse the repository at this point in the history
This is adding the server part of Oauth2 authentication process.

It adds 1 more APIs: `/login`

When called it without "code" query paramter, it will start the
authentication proces as per OAuth2 standard .
We are using only native app workflow with PKCE (RFC 7636) because we
cannot really secure the server side so we cannot reliably use client
secret.
When called with "code" query paramter it will finish the process and
exchange the code for access token (and refresh token) and safe it in
VDK storage.

This change add integration with jupyter configuration. This way the
extension can be configured more natively using jupyter configuration
mechanism.

In future change we can add integration between VDK configuration
mechanims and jupyter so that properties set in VDK can be recognized in
Jupyter and vice-versa but that's more advanced use-case
  • Loading branch information
antoniivanov committed Aug 21, 2023
1 parent 3fbaca2 commit 103eb65
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import tornado

from ._version import __version__
from .config import VdkJupyterConfig
from .handlers import setup_handlers


Expand All @@ -20,7 +23,12 @@ def _load_jupyter_server_extension(server_app):
server_app: jupyterlab.labapp.LabApp
JupyterLab application instance
"""
setup_handlers(server_app.web_app)
tornado.log.enable_pretty_logging()
tornado.log.app_log.setLevel(tornado.log.logging.DEBUG)

cfg = VdkJupyterConfig(config=server_app.config)

setup_handlers(server_app.web_app, cfg)
name = "vdk_jupyterlab_extension"
server_app.log.info(f"Registered {name} server extension")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
from traitlets import Unicode
from traitlets.config import Configurable


class VdkJupyterConfig(Configurable):
oauth2_authorization_url = Unicode(
"",
config=True,
help="The Oauth2 authorization URL. "
"This is the URL used to start the authentication process."
"Used to redirect to the authorization provider for user to login.",
)
oauth2_token_url = Unicode(
"",
config=True,
help="The Oauth2 token URL. "
"Used in the second phase of authentication process. "
"Used to exchange authorization code with access token.",
)
oauth2_client_id = Unicode(
"",
config=True,
help="The Oauth2 client ID. Note that client secret is not specified "
"since we only support native app workflow with PKCE (RFC 7636)",
)
oauth2_redirect_url = Unicode(
"",
config=True,
help="The Oauth2 Redirect URL (or callback URL)."
" This is the URL that authorization provider will redirect back the user to with the authorization code."
" If empty automatically the request URL will be used.",
)
rest_api_url = Unicode(
default_value="", config=True, help="The VDK Control Service REST API URL"
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from jupyter_server.base.handlers import APIHandler
from jupyter_server.utils import url_path_join

from . import VdkJupyterConfig
from .job_data import JobDataLoader
from .oauth2 import OAuth2Handler
from .vdk_options.vdk_options import VdkOption
from .vdk_ui import VdkUI

Expand Down Expand Up @@ -200,17 +202,18 @@ def get(self):
self.finish(json.dumps(os.getcwd()))


def setup_handlers(web_app):
def setup_handlers(web_app, cfg: VdkJupyterConfig):
host_pattern = ".*$"
base_url = web_app.settings["base_url"]

def add_handler(handler, endpoint):
job_route_pattern = url_path_join(
base_url, "vdk-jupyterlab-extension", endpoint
)
job_handlers = [(job_route_pattern, handler)]
job_handlers = [(job_route_pattern, handler, {"vdk_config": cfg})]
web_app.add_handlers(host_pattern, job_handlers)

add_handler(OAuth2Handler, "login")
add_handler(RunJobHandler, "run")
add_handler(DeleteJobHandler, "delete")
add_handler(DownloadJobHandler, "download")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
import time
from urllib.parse import urlparse
from urllib.parse import urlunparse

import requests
from jupyter_server.base.handlers import APIHandler
from requests.auth import HTTPBasicAuth
from requests_oauthlib import OAuth2Session
from vdk.plugin.control_api_auth.auth_config import InMemAuthConfiguration
from vdk.plugin.control_api_auth.auth_request_values import AuthRequestValues
from vdk.plugin.control_api_auth.autorization_code_auth import generate_pkce_codes
from vdk.plugin.control_api_auth.base_auth import BaseAuth
from vdk_jupyterlab_extension import VdkJupyterConfig

log = logging.getLogger(__name__)


class OAuth2Handler(APIHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def initialize(self, vdk_config: VdkJupyterConfig):
self._authorization_url = vdk_config.oauth2_authorization_url
self._access_token_url = vdk_config.oauth2_token_url
self._client_id = vdk_config.oauth2_client_id
self._redirect_url = vdk_config.oauth2_redirect_url

log.info(f"Authorization URL: {self._authorization_url}")
log.info(f"Access Token URL: {self._access_token_url}")
# log.info(f"client_id: {self._client_id}")

# No client secret. We use only native app workflow with PKCE (RFC 7636)

@staticmethod
def _fix_localhost(uri: str):
"""
This is added for local testing. Oauthorization Providers generally allow 127.0.0.1 to be registered as redirect URL
so we change localhost to 127.0.0.1
:param uri:
:return:
"""
parsed_uri = urlparse(uri)

if parsed_uri.hostname == "localhost":
netloc = parsed_uri.netloc.replace("localhost", "127.0.0.1")
modified_uri = parsed_uri._replace(netloc=netloc, query="")
return urlunparse(modified_uri)
else:
modified_uri = parsed_uri._replace(query="")
return urlunparse(modified_uri)

def get(self):
# TODO: this is duplicating a lot of the code in vdk-control-api-auth
# https://github.com/vmware/versatile-data-kit/tree/main/projects/vdk-plugins/vdk-control-api-auth
# But that module is written with focus on CLI usage a bit making it harder to reuse
# and it needs to be refactored first.
redirect_url = self._redirect_url
if not redirect_url:
redirect_url = self.request.full_url()
redirect_url = self._fix_localhost(redirect_url)

log.info(f"redirect uri is {redirect_url}")

if self.get_argument("code", None):
log.info(
"Authorization code received. Will generate access token using authorization code."
)
tokens = self._exchange_auth_code_for_access_token(redirect_url)
log.info(f"Got tokens data: {tokens}") # TODO: remove this
self._persist_tokens_data(tokens)
else:
log.info(f"Authorization URL is: {self._authorization_url}")
full_authorization_url = self._prepare_authorization_code_request_url(
redirect_url
)
self.finish(full_authorization_url)

def _persist_tokens_data(self, tokens):
auth = BaseAuth(conf=InMemAuthConfiguration())
auth.update_oauth2_authorization_url(self._access_token_url)
auth.update_client_id(self._client_id)
auth.update_access_token(tokens.get(AuthRequestValues.ACCESS_TOKEN_KEY.value))
auth.update_access_token_expiration_time(
time.time() + int(tokens[AuthRequestValues.EXPIRATION_TIME_KEY.value])
)
if AuthRequestValues.REFRESH_TOKEN_GRANT_TYPE in tokens:
auth.update_refresh_token(
tokens.get(AuthRequestValues.REFRESH_TOKEN_GRANT_TYPE)
)

def _prepare_authorization_code_request_url(self, redirect_uri):
(code_verifier, code_challenge, code_challenge_method) = generate_pkce_codes()
self.application.settings["code_verifier"] = code_verifier
oauth = OAuth2Session(client_id=self._client_id, redirect_uri=redirect_uri)
full_authorization_url = oauth.authorization_url(
self._authorization_url,
state="requested",
prompt=AuthRequestValues.LOGIN_PROMPT.value,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
)[0]
return full_authorization_url

def _exchange_auth_code_for_access_token(self, redirect_uri) -> dict:
code = self.get_argument("code")
headers = {
AuthRequestValues.CONTENT_TYPE_HEADER.value: AuthRequestValues.CONTENT_TYPE_URLENCODED.value,
}
code_verifier = self.application.settings["code_verifier"]

data = (
f"code={code}&"
+ f"grant_type=authorization_code&"
+ f"code_verifier={code_verifier}&"
f"redirect_uri={redirect_uri}"
)
basic_auth = HTTPBasicAuth(self._client_id, "")
try:
# TODO : this should be async io
response = requests.post(
self._access_token_url, data=data, headers=headers, auth=basic_auth
)
if response.status_code >= 400:
log.error(
f"Request to {self._access_token_url} with data {data} returned {response.status_code}\n"
rf"Reason: {response.reason}\dn"
f"Response content: {response.content}\n"
f"Response headers: {response.headers}"
)

json_data = json.loads(response.text)
return json_data
except Exception as e:
log.exception(e)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
from vdk_jupyterlab_extension.handlers import OAuth2Handler


def test_fix_redirect_uri():
assert (
OAuth2Handler._fix_localhost("http://localhost?foo=bar") == "http://127.0.0.1"
)
assert (
OAuth2Handler._fix_localhost("http://localhost:8888?foo=bar")
== "http://127.0.0.1:8888"
)
assert (
OAuth2Handler._fix_localhost("http://something?foo=bar") == "http://something"
)
assert (
OAuth2Handler._fix_localhost("http://something:9999?foo=bar")
== "http://something:9999"
)
assert (
OAuth2Handler._fix_localhost("http://something:9999") == "http://something:9999"
)

0 comments on commit 103eb65

Please sign in to comment.