Commit

sharing service backend (jupyter-server#432)

Zsailer committed Nov 8, 2022
1 parent 14d4e5c commit bd88090

Showing 14 changed files with 720 additions and 5 deletions.
9 changes: 8 additions & 1 deletion data_studio_jupyter_extensions/__init__.py
@@ -38,7 +38,10 @@ def _jupyter_server_extension_points(): # pragma: no cover
KernelWebsocketOverrideExtension,
)
from data_studio_jupyter_extensions.extensions.identity.extension import (
    IdentityExtension,
)
from data_studio_jupyter_extensions.extensions.publishing.extension import (
PublishingExtension,
)

return [
Expand Down Expand Up @@ -74,4 +77,8 @@ def _jupyter_server_extension_points(): # pragma: no cover
"module": "data_studio_jupyter_extensions.extensions.identity.extension",
"app": IdentityExtension,
},
{
"module": "data_studio_jupyter_extensions.extensions.publishing.extension",
"app": PublishingExtension,
},
]
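
For orientation, a hedged sketch (not part of this commit) of how the list above is consumed: Jupyter Server calls _jupyter_server_extension_points() only for packages enabled in its config, so a single standard config line activates the new PublishingExtension along with the others.

# jupyter_server_config.py -- illustrative sketch; enabling the package this
# way is stock Jupyter Server behavior, not something this commit adds.
c.ServerApp.jpserver_extensions = {"data_studio_jupyter_extensions": True}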
1 change: 1 addition & 0 deletions data_studio_jupyter_extensions/app.py
@@ -52,6 +52,7 @@
"session_manager_class": "data_studio_jupyter_extensions.configurables.session_manager.DataStudioSessionManager",
"login_handler_class": "data_studio_jupyter_extensions.auth.login.DataStudioLoginHandler",
"logout_handler_class": "data_studio_jupyter_extensions.auth.logout.DataStudioLogoutHandler",
"contents_manager_class": "data_studio_jupyter_extensions.configurables.contents_manager.DataStudioContentsManager",
"cookie_options": {"expires_days": 1},
"max_body_size": 2 * 1024 * 1024 * 1024,
"max_buffer_size": 2 * 1024 * 1024 * 1024,
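The "contents_manager_class" entry above pre-seeds a server trait; as a hedged illustration, the equivalent explicit configuration would be:

# jupyter_server_config.py -- equivalent, illustrative form of the new entry
c.ServerApp.contents_manager_class = (
    "data_studio_jupyter_extensions.configurables.contents_manager"
    ".DataStudioContentsManager"
)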
40 changes: 38 additions & 2 deletions data_studio_jupyter_extensions/auth/authenticator.py
@@ -1,10 +1,13 @@
from dataclasses import dataclass

import jwcrypto.jws as jws
import jwcrypto.jwt as jwt
from jupyter_server.utils import run_sync
from jwcrypto.common import json_decode
from jwcrypto.jwk import JWK
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest
from traitlets import Dict
from traitlets import List
from traitlets.config import LoggingConfigurable

@@ -13,8 +16,15 @@
from data_studio_jupyter_extensions.traits import UnicodeFromEnv


@dataclass
class User:
first_name: str
last_name: str
dsid: str
email: str


class JWTAuthenticator(LoggingConfigurable):
_alg = "ES256"
_issuer = "PIE DSW"

@@ -45,6 +55,8 @@ class JWTAuthenticator(LoggingConfigurable):
allow_none=True,
).tag(config=True)

user = Dict(allow_none=True)

@property
def _expected_payload(self):
return {"client_id": self.client_id}
@@ -122,11 +134,35 @@ def verified_claims_are_validated(
return False
return True

def get_user_from_token(self, data_studio_jwt):
for public_key in self.get_public_keys():
try:
decoded_jwt = jwt.JWT(
jwt=data_studio_jwt,
key=public_key,
check_claims={"exp": None},
)
c = json_decode(decoded_jwt.claims)
user = {
"dsid": c["corpds:ds:dsid"],
"firstName": c["corpds:ds:firstName"],
"lastName": c["corpds:ds:lastName"],
"email": c["corpds:ds:email"],
}
return user
            except jws.InvalidJWSSignature:
                self.log.debug("Invalid JWS signature; retrying with the next public key")
                continue
        # No configured public key could verify the token.
        return None

def is_authenticated(self, data_studio_jwt: str) -> bool:
public_keys = self.get_public_keys()
for public_key in public_keys:
try:
verified = self.verified_claims_are_validated(
data_studio_jwt, public_key
)
self.user = self.get_user_from_token(data_studio_jwt)
return verified
except jwt.JWTMissingClaim as e:
print("MissingClaim: {}".format(e), flush=True)
break
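For context, a self-contained sketch of the round trip the authenticator performs with jwcrypto. Assumptions: the corpds:ds:* claim names mirror get_user_from_token above, and the key pair is generated locally here, whereas the server fetches real public keys via get_public_keys().

import json
import time

from jwcrypto import jwk, jwt

key = jwk.JWK.generate(kty="EC", crv="P-256")  # ES256 uses a P-256 EC key
token = jwt.JWT(
    header={"alg": "ES256"},
    claims={
        "client_id": "example-client",  # compared against _expected_payload
        "exp": int(time.time()) + 3600,
        "corpds:ds:dsid": "12345",
        "corpds:ds:firstName": "Ada",
        "corpds:ds:lastName": "Lovelace",
        "corpds:ds:email": "ada@example.com",
    },
)
token.make_signed_token(key)
serialized = token.serialize()

# Verification mirrors get_user_from_token: decode against a public key and
# require that the `exp` claim be present.
pub = jwk.JWK.from_json(key.export_public())
decoded = jwt.JWT(jwt=serialized, key=pub, check_claims={"exp": None})
claims = json.loads(decoded.claims)
assert claims["corpds:ds:email"] == "ada@example.com"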
4 changes: 4 additions & 0 deletions data_studio_jupyter_extensions/base_handler.py
@@ -40,3 +40,7 @@ def datastudio_cookie_name(self) -> str:
@property
def identity_provider(self) -> DataStudioIdentityProvider:
return self.settings[self.name]["identity_provider"]

@property
def publishing_client(self):
return self.settings[self.name]["publishing_client"]
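
This diff does not show how the client lands in the settings dict; presumably the publishing extension registers it at startup. A toy sketch of the contract the property assumes (every name below is an assumption, not from this commit):

# Hypothetical wiring sketch. The property resolves
# settings[self.name]["publishing_client"], so the publishing extension must
# register the client under its own settings key, roughly:
name = "data_studio_jupyter_extensions"  # assumed value of handler.name
settings = {}                            # stand-in for the Tornado app settings
settings.setdefault(name, {})["publishing_client"] = object()  # stand-in client
assert settings[name]["publishing_client"] is not None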
39 changes: 39 additions & 0 deletions data_studio_jupyter_extensions/configurables/contents_manager.py
@@ -0,0 +1,39 @@
"""NOTE:
A custom override of the Contents Manager to add
the file ID service.
This is added in an unmerged PR: https://github.com/jupyter-server/jupyter_server/pull/921
Once this PR is merged and released, we can remove this override.
"""
from jupyter_server.services.contents.largefilemanager import LargeFileManager
from traitlets import default
from traitlets import Instance

from .fileid_manager import FileIdManager


class DataStudioContentsManager(LargeFileManager):
file_id_manager = Instance(
FileIdManager,
help="File ID manager instance to use. Defaults to `FileIdManager`.",
)

@default("file_id_manager")
def _default_file_id_manager(self):
return FileIdManager(parent=self, log=self.log)

    def delete(self, path):
        # Capture directory status *before* deleting; afterwards the path no
        # longer exists and dir_exists() would always return False.
        is_dir = self.dir_exists(path)
        super().delete(path)
        self.file_id_manager.delete(path, recursive=is_dir)

    def rename(self, old_path, new_path):
        # Likewise, capture directory status before the rename removes old_path.
        is_dir = self.dir_exists(old_path)
        super().rename(old_path, new_path)
        self.file_id_manager.move(old_path, new_path, recursive=is_dir)

def copy(self, from_path, to_path=None):
super().copy(from_path, to_path=to_path)
is_dir = self.dir_exists(from_path)
self.file_id_manager.copy(from_path, to_path, recursive=is_dir)
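
A rough usage sketch of the bookkeeping this override buys (the root directory and paths are hypothetical, not a test from this commit):

from data_studio_jupyter_extensions.configurables.contents_manager import (
    DataStudioContentsManager,
)

# Hedged sketch: each contents operation keeps the file-ID table in sync, so an
# ID obtained before a rename still resolves afterwards. Assumes /tmp/demo exists.
cm = DataStudioContentsManager(root_dir="/tmp/demo")
cm.save({"type": "file", "format": "text", "content": "hello"}, "a.txt")
fid = cm.file_id_manager.index("a.txt")
cm.rename("a.txt", "b.txt")
assert cm.file_id_manager.get_id("b.txt") == fid  # the id follows the file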
154 changes: 154 additions & 0 deletions data_studio_jupyter_extensions/configurables/fileid_manager.py
@@ -0,0 +1,154 @@
"""A fork of this PR in Jupyter Server: https://github.com/jupyter-server/jupyter_server/pull/921
"""
import os
import sqlite3
import uuid

from traitlets import Unicode
from traitlets.config.configurable import LoggingConfigurable

from data_studio_jupyter_extensions.constants import JUPYTER_PERSISTENT_DB_PATH


class FileIdManager(LoggingConfigurable):

database_filepath = Unicode(JUPYTER_PERSISTENT_DB_PATH)

_cursor = None
_connection = None

@property
def cursor(self):
"""Start a cursor and create a database called 'session'"""
if self._cursor is None:
self.log.debug("Creating File ID tables and indices")
self._cursor = self.connection.cursor()
self._cursor.execute(
"CREATE TABLE IF NOT EXISTS Files(id BLOB PRIMARY KEY, path TEXT NOT NULL UNIQUE)"
)
self._cursor.execute(
"CREATE INDEX IF NOT EXISTS ix_Files_path ON FILES (path)"
)
return self._cursor

@property
def connection(self):
"""Start a database connection"""
if self._connection is None:
# Set isolation level to None to autocommit all changes to the database.
self._connection = sqlite3.connect(
self.database_filepath, isolation_level=None
)
self._connection.row_factory = sqlite3.Row
return self._connection

    def close(self):
        """Close the sqlite cursor and connection."""
        if self._cursor is not None:
            self._cursor.close()
            self._cursor = None
        if self._connection is not None:
            self._connection.close()
            self._connection = None

def __del__(self):
"""Close connection once SessionManager closes"""
self.close()

def _normalize_path(self, path):
"""Normalizes a given file path."""
path = os.path.normcase(path)
path = os.path.normpath(path)
return path

def index(self, path):
"""Adds the file path to the Files table, then returns the file ID. If
the file is already indexed, the file ID is immediately returned."""
path = self._normalize_path(path)
existing_id = self.get_id(path)
if existing_id is not None:
return existing_id
# Generate a unique ID
file_id = uuid.uuid4()
file_id_bytes = file_id.bytes
self.cursor.execute(
"INSERT INTO Files (id, path) VALUES (?, ?)",
(
file_id_bytes,
path,
),
)
return file_id

def get_id(self, path):
"""Retrieves the file ID associated with a file path. Returns None if
the file path has not yet been indexed."""
path = self._normalize_path(path)
row = self.cursor.execute(
"SELECT id FROM Files WHERE path = ?", (path,)
).fetchone()
file_id_bytes = row[0] if row else None
file_id = None
if file_id_bytes:
file_id = uuid.UUID(bytes=file_id_bytes)
return file_id

def get_path(self, id):
"""Retrieves the file path associated with a file ID. Returns None if
the ID does not exist in the Files table."""
        row = self.cursor.execute(
            # Ids are stored as raw UUID bytes, so bind `id.bytes`, not the UUID object.
            "SELECT path FROM Files WHERE id = ?", (id.bytes,)
        ).fetchone()
return row[0] if row else None

def move(self, old_path, new_path, recursive=False):
"""Handles file moves by updating the file path of the associated file
ID. Returns the file ID."""
old_path = self._normalize_path(old_path)
new_path = self._normalize_path(new_path)
self.log.debug(f"Moving file from ${old_path} to ${new_path}")

if recursive:
old_path_glob = os.path.join(old_path, "*")
self.cursor.execute(
"UPDATE Files SET path = ? || substr(path, ?) WHERE path GLOB ?",
(new_path, len(old_path) + 1, old_path_glob),
)

id = self.get_id(old_path)
if id is None:
return self.index(new_path)
else:
            self.cursor.execute(
                # `id` is a uuid.UUID here; the table stores its raw bytes.
                "UPDATE Files SET path = ? WHERE id = ?", (new_path, id.bytes)
            )
return id

def copy(self, from_path, to_path, recursive=False):
"""Handles file copies by creating a new record in the Files table.
Returns the file ID associated with `new_path`. Also indexes `old_path`
if record does not exist in Files table. TODO: emit to event bus to
inform client extensions to copy records associated with old file ID to
the new file ID."""
from_path = self._normalize_path(from_path)
to_path = self._normalize_path(to_path)
self.log.debug(f"Copying file from ${from_path} to ${to_path}")

if recursive:
from_path_glob = os.path.join(from_path, "*")
            self.cursor.execute(
                # Mint a fresh random 16-byte id for each copied child row; the
                # `id` column is a BLOB primary key and must not be left NULL.
                "INSERT INTO Files (id, path) SELECT randomblob(16), (? || substr(path, ?)) FROM Files WHERE path GLOB ?",
                (to_path, len(from_path) + 1, from_path_glob),
            )

self.index(from_path)
return self.index(to_path)

def delete(self, path, recursive=False):
"""Handles file deletions by deleting the associated record in the File
table. Returns None."""
path = self._normalize_path(path)
self.log.debug(f"Deleting file {path}")

if recursive:
path_glob = os.path.join(path, "*")
self.cursor.execute("DELETE FROM Files WHERE path GLOB ?", (path_glob,))

self.cursor.execute("DELETE FROM Files WHERE path = ?", (path,))
data_studio_jupyter_extensions/configurables/session_manager.py
@@ -17,10 +17,10 @@
from traitlets import Unicode
from traitlets.config import SingletonConfigurable

from data_studio_jupyter_extensions.constants import JUPYTER_PERSISTENT_DB_PATH
from data_studio_jupyter_extensions.hubble import HUBBLE_METRICS


KERNEL_CACHE_PATH = jupyter_runtime_dir()


@@ -39,7 +39,7 @@ def _default_event_logger(self):
if self.parent and hasattr(self.parent, "event_logger"):
return self.parent.event_logger

database_filepath = Unicode(JUPYTER_PERSISTENT_DB_PATH)
cache = Cache(
KERNEL_CACHE_PATH,
eviction_policy="least-recently-used",
5 changes: 5 additions & 0 deletions data_studio_jupyter_extensions/constants.py
@@ -1,5 +1,10 @@
import os.path as osp
from dataclasses import dataclass

from jupyter_core.paths import jupyter_runtime_dir

JUPYTER_PERSISTENT_DB_PATH = osp.join(jupyter_runtime_dir(), "jupyter-session.db")

DS_URL = "DATASTUDIO_UI_URL"
DS_API_URL = "DATASTUDIO_API_URL"
DS_PROJECT_ID = "DATASTUDIO_PROJECT_ID"
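Illustrative only: the session manager and the file-ID manager now both key off this constant, so their tables share a single sqlite file.

from data_studio_jupyter_extensions.constants import JUPYTER_PERSISTENT_DB_PATH
print(JUPYTER_PERSISTENT_DB_PATH)  # <jupyter runtime dir>/jupyter-session.db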
Empty file.