Skip to content

Commit

Permalink
Enable LIT in colab and jupyter notebooks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 346311084
  • Loading branch information
jameswex authored and LIT team committed Dec 9, 2020
1 parent c91fbf8 commit afcf134
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 4 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ dependencies:
- umap-learn
- transformers==2.11.0
- google-cloud-translate
- rouge-score
- portpicker
- annoy
6 changes: 4 additions & 2 deletions lit_nlp/dev_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def serve(self):
served by this module.
Returns:
WSGI app if the server type is 'external', otherwise None when
serving is complete.
WSGI app if the server type is 'external', server if the server type
is 'notebook', otherwise None when serving is complete.
"""
while True:
logging.info(get_lit_logo())
Expand All @@ -104,6 +104,8 @@ def serve(self):
# The underlying TSServer registers a SIGINT handler,
# so if you hit Ctrl+C it will return.
server.serve()
if self._server_type == 'notebook':
return server
app.save_cache()
# Optionally, reload server for development.
# Potentially brittle - don't use this for real deployments.
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/lib/wsgi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _LoadResource(path):
"""

try:
with open(path,"rb") as f:
with open(path, 'rb') as f:
return f.read()
except IOError as e:
logging.warning('IOError %s on path %s', e, path)
Expand Down
105 changes: 105 additions & 0 deletions lit_nlp/lib/wsgi_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# Lint as: python3
"""WSGI servers to power the LIT backend."""

import socket
import threading
from typing import Optional, Text, List
from wsgiref import validate
import wsgiref.simple_server

from absl import logging
import portpicker
from werkzeug import serving as werkzeug_serving


Expand All @@ -42,3 +46,104 @@ def serve(self):
self._app,
use_debugger=False,
use_reloader=False)


class WsgiServerIpv6(wsgiref.simple_server.WSGIServer):
"""IPv6 based extension of the simple WSGIServer."""

address_family = socket.AF_INET6


class NotebookWsgiServer(object):
"""WSGI server for notebook environments."""

def __init__(self, wsgi_app, host: Text = 'localhost',
port: Optional[int] = None, **unused_kw):
"""Initialize the WSGI server.
Args:
wsgi_app: WSGI pep-333 application to run.
host: Host to run on, defaults to 'localhost'.
port: Port to run on. If not specified, then an unused one will be picked.
"""
self._app = wsgi_app
self._host = host
self._port = port
self._server_thread = None
self.can_act_as_model_server = True

@property
def port(self):
"""Returns the current port or error if the server is not started.
Raises:
RuntimeError: If server has not been started yet.
Returns:
The port being used by the server.
"""
if self._server_thread is None:
raise RuntimeError('Server not started.')
return self._port

def stop(self):
"""Stops the server thread."""
if self._server_thread is None:
return
self._stopping.set()
self._server_thread = None
self._stopped.wait()

def serve(self):
"""Starts a server in a thread using the WSGI application provided.
Will wait until the thread has started calling with an already serving
application will simple return.
"""
if self._server_thread is not None:
return
if self._port is None:
self._port = portpicker.pick_unused_port()
started = threading.Event()
self._stopped = threading.Event()
self._stopping = threading.Event()

def build_server(started, stopped, stopping):
"""Closure to build the server function to be passed to the thread.
Args:
started: Threading event to notify when started.
stopped: Threading event to notify when stopped.
stopping: Threading event to notify when stopping.
Returns:
A function that function that takes a port and WSGI app and notifies
about its status via the threading events provided.
"""

def server(port, wsgi_app):
"""Serve a WSGI application until stopped.
Args:
port: Port number to serve on.
wsgi_app: WSGI application to serve.
"""
try:
httpd = wsgiref.simple_server.make_server(self._host, port, wsgi_app)
except socket.error:
# Try IPv6
httpd = wsgiref.simple_server.make_server(
self._host, port, wsgi_app, server_class=WsgiServerIpv6)
started.set()
httpd.timeout = 30
while not stopping.is_set():
httpd.handle_request()
stopped.set()

return server

server = build_server(started, self._stopped, self._stopping)
server_thread = threading.Thread(
target=server, args=(self._port, self._app))
self._server_thread = server_thread

server_thread.start()
started.wait()
131 changes: 131 additions & 0 deletions lit_nlp/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Notebook usage of LIT."""

import html
import json
import os
import pathlib
import random
import typing
from absl import flags
# pytype: disable=import-error
from IPython import display
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.lib import wsgi_serving

try:
import google.colab # pylint: disable=g-import-not-at-top,unused-import
is_colab = True
except ImportError:
is_colab = False

flags.FLAGS.set_default('server_type', 'notebook')
flags.FLAGS.set_default('host', 'localhost')
flags.FLAGS.set_default('port', None)


def start_lit(models, datasets, height=1000, proxy_url=None):
"""Start and display a LIT instance in a notebook instance.
Args:
models: A dict of model names to LIT model instances.
datasets: A dict of dataset names to LIT dataset instances.
height: Height to display the LIT UI in pixels. Defaults to 1000.
proxy_url: Optional proxy URL, if using in a notebook with a server proxy.
Defaults to None.
Returns:
Callback method to stop the LIT server.
"""
lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
server = typing.cast(wsgi_serving.NotebookWsgiServer, lit_demo.serve())

if is_colab:
_display_colab(server.port, height)
else:
_display_jupyter(server.port, height, proxy_url)

return server.stop


def _display_colab(port, height):
"""Display the LIT UI in colab.
Args:
port: The port the LIT server is running on.
height: The height of the LIT UI in pixels.
"""

shell = """
(async () => {
const url = new URL(
await google.colab.kernel.proxyPort(%PORT%, {'cache': true}));
const iframe = document.createElement('iframe');
iframe.src = url;
iframe.setAttribute('width', '100%');
iframe.setAttribute('height', '%HEIGHT%px');
iframe.setAttribute('frameborder', 0);
document.body.appendChild(iframe);
})();
"""
replacements = [
('%PORT%', '%d' % port),
('%HEIGHT%', '%d' % height),
]
for (k, v) in replacements:
shell = shell.replace(k, v)

script = display.Javascript(shell)
display.display(script)


def _display_jupyter(port, height, proxy_url):
"""Display the LIT UI in colab.
Args:
port: The port the LIT server is running on.
height: The height of the LIT UI in pixels.
proxy_url: Optional proxy URL, if using in a notebook with a server proxy.
"""

frame_id = 'lit-frame-{:08x}'.format(random.getrandbits(64))
shell = """
<iframe id='%HTML_ID%' width='100%' height='%HEIGHT%' frameborder='0'>
</iframe>
<script>
(function() {
const frame = document.getElementById(%JSON_ID%);
const url = new URL(%URL%, window.location);
const port = %PORT%;
if (port) {
url.port = port;
}
frame.src = url;
})();
</script>
"""
if proxy_url is not None:
# Allow %PORT% in proxy_url.
proxy_url = proxy_url.replace('%PORT%', '%d' % port)
replacements = [
('%HTML_ID%', html.escape(frame_id, quote=True)),
('%JSON_ID%', json.dumps(frame_id)),
('%HEIGHT%', '%d' % height),
('%PORT%', '0'),
('%URL%', json.dumps(proxy_url)),
]
else:
replacements = [
('%HTML_ID%', html.escape(frame_id, quote=True)),
('%JSON_ID%', json.dumps(frame_id)),
('%HEIGHT%', '%d' % height),
('%PORT%', '%d' % port),
('%URL%', json.dumps('/')),
]

for (k, v) in replacements:
shell = shell.replace(k, v)

iframe = display.HTML(shell)
display.display(iframe)
1 change: 1 addition & 0 deletions pip_package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"numpy",
"scipy",
"pandas",
"portpicker",
"scikit-learn",
"sacrebleu",
"umap-learn",
Expand Down

0 comments on commit afcf134

Please sign in to comment.