Skip to content

Commit

Permalink
Add stlite modifications (#672)
Browse files Browse the repository at this point in the history
* Add stlite modifications

* Add back connection manager

* Move methods

* Some updates

* Remove header

* More cleanup

* More changes

* More cleanup

* More modifications

* Don't update pyiodide

* More cleanup

* Don't Update pyodide

* Don't update pyiodide

* More cleanup

* More cleanup

* Rename of option

* Remove new line

* Minor cleanup

* Fix wrong option name

* Minor renaming

* Fix typo

Co-authored-by: Yuichiro Tachibana (Tsuchiya) <[email protected]>

* Update comment

Co-authored-by: Yuichiro Tachibana (Tsuchiya) <[email protected]>

* Remove paths

* Change to put

* Update URL

* Add gatherUsageStats flag

---------

Co-authored-by: Yuichiro Tachibana (Tsuchiya) <[email protected]>
  • Loading branch information
lukasmasuch and whitphx authored Jan 17, 2024
1 parent 05396fc commit 42f3a4a
Show file tree
Hide file tree
Showing 15 changed files with 128 additions and 198 deletions.
10 changes: 9 additions & 1 deletion packages/common-react/src/toastify-components/callbacks.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ interface ToastKernelCallbacks {
onLoad: NonNullable<StliteKernelOptions["onLoad"]>;
onError: NonNullable<StliteKernelOptions["onError"]>;
}
export function makeToastKernelCallbacks(): ToastKernelCallbacks {
export function makeToastKernelCallbacks(disableProgressToasts = false, disableErrorToasts = false): ToastKernelCallbacks {
let prevToastId: ToastId | null = null;
const toastIds: ToastId[] = [];
const onProgress: StliteKernelOptions["onProgress"] = (message) => {
if (disableProgressToasts) {
return;
}

const id = toast(message, {
position: toast.POSITION.BOTTOM_RIGHT,
transition: Slide,
Expand All @@ -33,6 +37,10 @@ export function makeToastKernelCallbacks(): ToastKernelCallbacks {
toastIds.forEach((id) => toast.dismiss(id));
};
const onError: StliteKernelOptions["onError"] = (error) => {
if (disableErrorToasts) {
return;
}

toast(
<ErrorToastContent message="Error during booting up" error={error} />,
{
Expand Down
6 changes: 1 addition & 5 deletions packages/common-react/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@
// "rootDir": "./", /* Specify the root folder within your source files. */
"moduleResolution": "node" /* Specify how TypeScript looks up a file from a given module specifier. */,
// "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
"paths": {
"src/theme": ["../../streamlit/frontend/src/theme"],
"src/theme/*": ["../../streamlit/frontend/src/theme/*"],
"src/lib/*": ["../../streamlit/frontend/src/lib/*"],
} /* Specify a set of entries that re-map imports to additional lookup locations. */,
// "paths": {} /* Specify a set of entries that re-map imports to additional lookup locations. */,
// "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
// "typeRoots": [], /* Specify multiple folders that act like `./node_modules/@types`. */
// "types": [], /* Specify type package names to be included without being referenced in a source file. */
Expand Down
26 changes: 26 additions & 0 deletions packages/kernel/py/stlite-server/stlite_server/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,30 @@ def _on_pages_changed(_path: str) -> None:
allow_nonexistent=True,
)

def _fix_altair():
"""Fix an issue with Altair and the mocked pyarrow module of stlite."""
try:
from altair.utils import _importers

def _pyarrow_available():
return False

_importers.pyarrow_available = _pyarrow_available

def _import_pyarrow_interchange():
raise ImportError("Pyarrow is not available in stlite.")

_importers.import_pyarrow_interchange = _import_pyarrow_interchange
except:
pass

def _fix_requests():
try:
import pyodide_http
pyodide_http.patch_all() # Patch all libraries
except ImportError:
# pyodide_http is not installed. No need to do anything.
pass

def prepare(
main_script_path: str,
Expand All @@ -135,6 +159,8 @@ def prepare(
"""
_fix_sys_path(main_script_path)
_fix_matplotlib_crash()
_fix_altair()
_fix_requests()
_fix_sys_argv(main_script_path, args)
_fix_pydeck_mapbox_api_warning()
_install_pages_watcher(main_script_path)
5 changes: 4 additions & 1 deletion packages/kernel/py/stlite-server/stlite_server/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ class RequestHandler(abc.ABC):
def get(self, request: Request) -> Response | Awaitable[Response]:
return Response(status_code=405, headers={}, body="")

def post(self, request: Request) -> Response | Awaitable[Response]:
def put(self, request: Request) -> Response | Awaitable[Response]:
return Response(status_code=405, headers={}, body="")

def delete(self, request: Request) -> Response | Awaitable[Response]:
return Response(status_code=405, headers={}, body="")
16 changes: 12 additions & 4 deletions packages/kernel/py/stlite-server/stlite_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from streamlit.runtime import Runtime, RuntimeConfig, SessionClient
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
)

from .component_request_handler import ComponentRequestHandler
from .handler import RequestHandler
Expand All @@ -21,9 +25,10 @@

LOGGER = logging.getLogger(__name__)

# These route definitions are copied from the original impl at https://github.com/streamlit/streamlit/blob/1.18.1/lib/streamlit/web/server/server.py#L81-L89 # noqa: E501
# These route definitions are copied from the original impl at https://github.com/streamlit/streamlit/blob/1.27.0/lib/streamlit/web/server/server.py#L83-L92 # noqa: E501
UPLOAD_FILE_ENDPOINT: Final = "/_stcore/upload_file"
MEDIA_ENDPOINT: Final = "/media"
STREAM_ENDPOINT: Final = r"_stcore/stream"
STREAM_ENDPOINT: Final = "_stcore/stream"
HEALTH_ENDPOINT: Final = r"(?:healthz|_stcore/health)"


Expand All @@ -34,12 +39,15 @@ def __init__(self, main_script_path: str, command_line: str | None) -> None:
self._main_script_path = main_script_path

self._media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
self.uploaded_file_mgr = MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT)

self._runtime = Runtime(
RuntimeConfig(
script_path=main_script_path,
command_line=command_line,
media_file_storage=self._media_file_storage,
uploaded_file_manager=self.uploaded_file_mgr,
cache_storage_manager=create_default_cache_storage_manager(),
),
)

Expand Down Expand Up @@ -152,8 +160,8 @@ def receive_http(
on_response(404, {}, b"No handler found")
return
method_name = method.lower()
if method_name not in ("get", "post"):
on_response(405, {}, b"Now allowed")
if method_name not in ("get", "put", "delete"):
on_response(405, {}, b"Not allowed")
return
handler_method = getattr(handler, method_name, None)
if handler_method is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copied from https://github.com/streamlit/streamlit/blob/1.18.1/lib/streamlit/web/server/server_util.py#L73-L77 # noqa: E501
# Copied from https://github.com/streamlit/streamlit/blob/1.27.1/lib/streamlit/web/server/server_util.py#L73-L77 # noqa: E501
def make_url_path_regex(*path: str, **kwargs) -> str:
"""Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz)."""
valid_path = [x.strip("/") for x in path if x] # Filter out falsely components.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import logging
from typing import Callable, Dict, List
import re

from streamlit.runtime.uploaded_file_manager import UploadedFileManager, UploadedFileRec

from .handler import Request, RequestHandler, Response
from .httputil import HTTPFile, parse_body_arguments

# /_stcore/upload_file/(optional session id)/(optional widget id)
# /_stcore/upload_file/(optional session id)/(optional file id)
UPLOAD_FILE_ROUTE = (
r"/_stcore/upload_file/?(?P<session_id>[^/]*)?/?(?P<widget_id>[^/]*)?"
r"/_stcore/upload_file/(?P<session_id>[^/]+)/(?P<file_id>[^/]+)"
)
LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,7 +41,7 @@ def _require_arg(args: Dict[str, List[bytes]], name: str) -> str:
# Convert bytes to string
return arg[0].decode("utf-8")

def post(self, request: Request, **kwargs) -> Response:
def put(self, request: Request, **kwargs) -> Response:
# NOTE: The original implementation uses an async function,
# but it didn't make use of any async features,
# so we made it a regular function here for simplicity sake.
Expand All @@ -61,8 +62,11 @@ def post(self, request: Request, **kwargs) -> Response:
)

try:
session_id = self._require_arg(args, "sessionId")
widget_id = self._require_arg(args, "widgetId")
path_args = re.match(UPLOAD_FILE_ROUTE, request.path)
session_id = path_args.group('session_id')
file_id = path_args.group('file_id')
# session_id = self._require_arg(args, "sessionId")
# file_id = self._require_arg(args, "fileId")
if not self._is_active_session(session_id):
raise Exception(f"Invalid session_id: '{session_id}'")

Expand All @@ -78,7 +82,7 @@ def post(self, request: Request, **kwargs) -> Response:
for file in flist:
uploaded_files.append(
UploadedFileRec(
id=0,
file_id=file_id,
name=file.filename,
type=file.content_type,
data=file.body,
Expand All @@ -92,10 +96,17 @@ def post(self, request: Request, **kwargs) -> Response:
body=f"Expected 1 file, but got {len(uploaded_files)}",
)

added_file = self._file_mgr.add_file(
session_id=session_id, widget_id=widget_id, file=uploaded_files[0]
self._file_mgr.add_file(
session_id=session_id, file=uploaded_files[0]
)
return Response(status_code=204, headers={}, body="")

# Return the file_id to the client. (The client will parse
# the string back to an int.)
return Response(status_code=200, headers={}, body=str(added_file.id))
def delete(self, request: Request, **kwargs):
"""Delete file request handler."""

path_args = re.match(UPLOAD_FILE_ROUTE, request.path)
session_id = path_args.group('session_id')
file_id = path_args.group('file_id')

self._file_mgr.remove_file(session_id=session_id, file_id=file_id)
self.set_status(204)
1 change: 0 additions & 1 deletion packages/kernel/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
export * from "./kernel";
export * from "./streamlit-replacements/lib/ConnectionManager";
export * from "./streamlit-replacements/lib/FileUploadClient";
export * from "./react-helpers";
export * from "./types";
6 changes: 6 additions & 0 deletions packages/kernel/src/kernel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ export interface StliteKernelOptions {
*/
streamlitConfig?: StreamlitConfig;

/**
* If true, no toasts will be shown on loading progress steps.
*/
disableProgressToasts?: boolean;

onProgress?: (message: string) => void;

onLoad?: () => void;
Expand Down Expand Up @@ -186,6 +191,7 @@ export class StliteKernel {
archives: options.archives,
requirements: options.requirements,
pyodideUrl: options.pyodideUrl,
disableProgressToasts: options.disableProgressToasts,
wheels,
mountedSitePackagesSnapshotFilePath:
options.mountedSitePackagesSnapshotFilePath,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
// Mimic https://github.com/streamlit/streamlit/blob/1.9.0/frontend/src/lib/ConnectionManager.ts
// Mimic https://github.com/streamlit/streamlit/blob/1.27.0/frontend/app/src/connection/ConnectionManager.ts
// and WebsocketConnection.

import type { ReactNode } from "react"

import { BackMsg, ForwardMsg } from "@streamlit/lib/src/proto"
import type { IAllowedMessageOriginsResponse } from "@streamlit/lib/src/hostComm/types"
import type { BaseUriParts } from "@streamlit/lib/src/util/UriUtil"
import {
IAllowedMessageOriginsResponse,
BaseUriParts,
SessionInfo,
StreamlitEndpoints,
ensureError,
BackMsg,
ForwardMsg,
} from "@streamlit/lib"

import type { StliteKernel } from "../../kernel"
import { ConnectionState } from "@streamlit/app/src/connection/ConnectionState"
import type { SessionInfo } from "@streamlit/lib/src/SessionInfo"
import { ensureError } from "@streamlit/lib/src/util/ErrorHandling"
import { DUMMY_BASE_HOSTNAME, DUMMY_BASE_PORT } from "../../consts"
import { ConnectionState } from "./ConnectionState"

import type { StliteKernel } from "@stlite/kernel"

interface MessageQueue {
[index: number]: any

Check warning on line 22 in packages/kernel/src/streamlit-replacements/lib/ConnectionManager.ts

View workflow job for this annotation

GitHub Actions / test-kernel

Unexpected any. Specify a different type
Expand All @@ -26,6 +31,9 @@ interface Props {
/** The app's SessionInfo instance */
sessionInfo: SessionInfo

/** The app's StreamlitEndpoints instance */
endpoints: StreamlitEndpoints

/**
* Function to be called when we receive a message from the server.
*/
Expand All @@ -41,13 +49,29 @@ interface Props {
*/
connectionStateChanged: (connectionState: ConnectionState) => void

/**
* Function to get the auth token set by the host of this app (if in a
* relevant deployment scenario).
*/
claimHostAuthToken: () => Promise<string | undefined>

/**
* Function to clear the withHostCommunication hoc's auth token. This should
* be called after the promise returned by claimHostAuthToken successfully
* resolves.
*/
resetHostAuthToken: () => void

/**
* Function to set the list of origins that this app should accept
* cross-origin messages from (if in a relevant deployment scenario).
*/
setAllowedOriginsResp: (resp: IAllowedMessageOriginsResponse) => void
}

/**
* Manages our connection to the Server.
*/
export class ConnectionManager {
private readonly props: Props

Expand Down Expand Up @@ -137,6 +161,13 @@ export class ConnectionManager {
// Because caching is disabled in stlite. See https://github.com/whitphx/stlite/issues/495
}

/**
* No-op in stlite.
*/
disconnect(): void {
// no-op.
}

private async handleMessage(data: ArrayBuffer): Promise<void> {
// Assign this message an index.
const messageIndex = this.nextMessageIndex
Expand Down
Loading

0 comments on commit 42f3a4a

Please sign in to comment.