diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e7149e6b..a90f0b6f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -32,7 +32,7 @@ jobs:
name: Build and test (${{ matrix.python-version }})
strategy:
matrix:
- python-version: ["3.7"]
+ python-version: ["3.9"]
defaults:
run:
shell: bash -l {0}
diff --git a/.gitignore b/.gitignore
index ac11a7c1..52d4fe89 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,10 @@
-lit_nlp/yarn-error.log
+**/npm-debug.log*
+**/yarn-debug.log*
+**/yarn-error.log*
website/www/**
**/build/**
**/node_modules/**
**/__pycache__/**
**/*.pyc
+
+**/.DS_Store
diff --git a/Dockerfile b/Dockerfile
index 0053073c..255d3a12 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,6 @@
# Use the official lightweight Python image.
# https://hub.docker.com/_/python
-FROM python:3.7-slim
+FROM python:3.9-slim
# Update Ubuntu packages and install basic utils
RUN apt-get update
diff --git a/README.md b/README.md
index b55383da..2a0c4960 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,12 @@
-# 🔥 Language Interpretability Tool (LIT)
+# 🔥 Learning Interpretability Tool (LIT)
-
+
-The Language Interpretability Tool (LIT) is a visual, interactive
-model-understanding tool for ML models, focusing on NLP use-cases. It can be run
-as a standalone server, or inside of notebook environments such as Colab,
-Jupyter, and Google Cloud Vertex AI notebooks.
+The Learning Interpretability Tool (🔥LIT, formerly known as the Language
+Interpretability Tool) is a visual, interactive ML model-understanding tool that
+supports text, image, and tabular data. It can be run as a standalone server, or
+inside of notebook environments such as Colab, Jupyter, and Google Cloud Vertex
+AI notebooks.
LIT is built to answer questions such as:
@@ -50,12 +51,12 @@ For a broader overview, check out [our paper](https://arxiv.org/abs/2008.05122)
## Download and Installation
-LIT can be installed via pip, or can be built from source. Building from source
-is necessary if you wish to update any of the front-end or core back-end code.
+LIT can be installed via `pip` or built from source. Building from source is
+necessary if you wish to update any of the front-end or core back-end code.
### Install from source
-Download the repo and set up a Python environment:
+Clone the repo and set up a Python environment:
```sh
git clone https://github.com/PAIR-code/lit.git ~/lit
@@ -68,11 +69,11 @@ conda install cudnn cupti # optional, for GPU support
conda install -c pytorch pytorch # optional, for PyTorch
# Build the frontend
-pushd lit_nlp; yarn && yarn build; popd
+(cd lit_nlp; yarn && yarn build)
```
Note: if you see [an error](https://github.com/yarnpkg/yarn/issues/2821)
-running yarn on Ubuntu/Debian, be sure you have the
+running `yarn` on Ubuntu/Debian, be sure you have the
[correct version installed](https://yarnpkg.com/en/docs/install#linux-tab).
### pip installation
@@ -81,62 +82,79 @@ running yarn on Ubuntu/Debian, be sure you have the
pip install lit-nlp
```
-The pip installation will install all necessary prerequisite packages for use
-of the core LIT package. It also installs the code to run our demo examples.
-It does not install the prerequisites for those demos, so you need to install
-those yourself if you wish to run the demos. See
-[environment.yml](./environment.yml) for the list of all packages needed for
-running the demos.
+The `pip` installation will install all necessary prerequisite packages for use
+of the core LIT package.
+
+It **does not** install the prerequisites for the provided demos, so you need to
+install those yourself. See [environment.yml](./environment.yml) for the list of
+packages required to run the demos.
## Running LIT
Explore a collection of hosted demos on the
[LIT website demos page](https://pair-code.github.io/lit/demos).
-Colab notebooks showing the use of LIT inside of notebooks can be found at [lit_nlp/examples/notebooks](./lit_nlp/examples/notebooks).
-A simple example can be viewed
-[here](https://colab.research.google.com/github/pair-code/lit/blob/main/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb).
-
### Quick-start: classification and regression
-To explore classification and regression models tasks from the popular [GLUE benchmark](https://gluebenchmark.com/):
+To explore classification and regression models on tasks from the popular
+[GLUE benchmark](https://gluebenchmark.com/):
```sh
python -m lit_nlp.examples.glue_demo --port=5432 --quickstart
```
-Navigate to http://localhost:5432 to access the LIT UI.
+Navigate to http://localhost:5432 to access the LIT UI.
-Your default view will be a
+Your default view will be a
[small BERT-based model](https://arxiv.org/abs/1908.08962) fine-tuned on the
[Stanford Sentiment Treebank](https://nlp.stanford.edu/sentiment/treebank.html),
-but you can switch to
-[STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) or [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) using the toolbar or the gear icon in
-the upper right.
+but you can switch to
+[STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) or
+[MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) using the toolbar or the
+gear icon in the upper right.
+### Quick-start: language modeling
-### Quick start: language modeling
-
-To explore predictions from a pretrained language model (BERT or GPT-2), run:
+To explore predictions from a pre-trained language model (BERT or GPT-2), run:
```sh
-python -m lit_nlp.examples.lm_demo --models=bert-base-uncased \
- --port=5432
+python -m lit_nlp.examples.lm_demo --models=bert-base-uncased --port=5432
```
And navigate to http://localhost:5432 for the UI.
### Notebook usage
-A simple colab demo can be found [here](https://colab.research.google.com/github/PAIR-code/lit/blob/main/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb).
-Just run all the cells to see LIT on an example classification model right in
-the notebook.
+Colab notebooks showing the use of LIT inside of notebooks can be found at
+[lit_nlp/examples/notebooks](./lit_nlp/examples/notebooks).
+
+We provide a simple
+[Colab demo](https://colab.research.google.com/github/PAIR-code/lit/blob/main/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb).
+Run all the cells to see LIT on an example classification model in the notebook.
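+
+For reference, here is a minimal sketch of the notebook flow, assuming you
+already have LIT-compatible model and dataset objects (`MyModel` and
+`MyDataset` below are hypothetical placeholders):
+
+```python
+# Wrap models and datasets in a LitWidget and render the UI inline.
+from lit_nlp import notebook
+
+widget = notebook.LitWidget(
+    models={"my_model": MyModel()},
+    datasets={"my_data": MyDataset()},
+)
+widget.render(height=600)  # display the LIT UI in the output cell
+```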
### Run LIT in a Docker container
-See [docker.md](https://github.com/PAIR-code/lit/wiki/docker.md) for instructions on running LIT as
-a containerized web app. This is the approach we take for our
-[website demos](https://pair-code.github.io/lit/demos/).
+LIT can be run as a containerized app using [Docker](https://www.docker.com/) or
+your preferred container engine. Use the following shell commands to build the default
+Docker image for LIT from the provided `Dockerfile`, and then run a container
+from that image. Comments are provided in-line to help explain each step.
+
+```shell
+# Build the docker image using the -t argument to name the image. Remember to
+# include the trailing . so Docker knows where to look for the Dockerfile
+docker build -t lit_app .
+
+# Now you can run LIT as a containerized app using the following command. Note
+# that the last parameter to the run command is the value you passed to the -t
+# argument in the build command above.
+docker run --rm -p 5432:5432 lit_app
+```
+
+The image above defaults to launching the GLUE demo on port 5432, but you can
+override this using environment variables. See our
+[advanced guide](https://github.com/PAIR-code/lit/wiki/docker.md) for detailed
+instructions on using the default LIT Docker image, running LIT as a
+containerized web app in different scenarios, and how to create your own LIT
+images.
### More Examples
@@ -154,15 +172,13 @@ watch this [video](https://www.youtube.com/watch?v=CuRI_VK83dU).
## Adding your own models or data
You can easily run LIT with your own model by creating a custom `demo.py`
-launcher, similar to those in [lit_nlp/examples](./lit_nlp/examples). The basic
-steps are:
+launcher, similar to those in [lit_nlp/examples](./lit_nlp/examples). The
+basic steps are:
-* Write a data loader which follows the
- [`Dataset` API](https://github.com/PAIR-code/lit/wiki/api.md#datasets)
+* Write a data loader which follows the [`Dataset` API](https://github.com/PAIR-code/lit/wiki/api.md#datasets)
* Write a model wrapper which follows the [`Model` API](https://github.com/PAIR-code/lit/wiki/api.md#models)
* Pass models, datasets, and any additional
- [components](https://github.com/PAIR-code/lit/wiki/api.md#interpretation-components) to the LIT server
- class
+ [components](https://github.com/PAIR-code/lit/wiki/api.md#interpretation-components) to the LIT server class
For a full walkthrough, see
[adding models and data](https://github.com/PAIR-code/lit/wiki/api.md#adding-models-and-data).
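
As a sketch (not one of the shipped examples), a minimal `demo.py` following
these APIs might look like the code below; the toy data and scoring logic are
stand-ins:

```python
from lit_nlp import dev_server
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types


class ToyDataset(lit_dataset.Dataset):
  """Two hand-written sentiment examples."""

  def __init__(self):
    self._examples = [
        {"sentence": "a great movie", "label": "1"},
        {"sentence": "a terrible plot", "label": "0"},
    ]

  def spec(self):
    return {
        "sentence": lit_types.TextSegment(),
        "label": lit_types.CategoryLabel(vocab=["0", "1"]),
    }


class ToyModel(lit_model.Model):
  """Scores 'positive' when the input contains the word 'great'."""

  def input_spec(self):
    return {"sentence": lit_types.TextSegment()}

  def output_spec(self):
    return {"probas": lit_types.MulticlassPreds(vocab=["0", "1"], parent="label")}

  def predict_minibatch(self, inputs):
    for ex in inputs:
      p = 0.9 if "great" in ex["sentence"] else 0.1
      yield {"probas": [1 - p, p]}


if __name__ == "__main__":
  server = dev_server.Server(
      models={"toy": ToyModel()}, datasets={"toy": ToyDataset()}, port=5432)
  server.serve()
```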
@@ -170,17 +186,38 @@ For a full walkthrough, see
## Extending LIT with new components
LIT is easy to extend with new interpretability components, generators, and
-more, both on the frontend or the backend. See our
-[documentation](https://github.com/PAIR-code/lit/wiki) to get started.
+more, both on the frontend or the backend. See our [documentation](https://github.com/PAIR-code/lit/wiki) to get
+started.
## Pull Request Process
-To make code changes to LIT, please work off of the `dev` branch and create
-pull requests against that branch. The `main` branch is for stable releases, and it is expected that the `dev` branch will always be ahead of `main` in terms of commits.
+To make code changes to LIT, please work off of the `dev` branch and
+[create pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request)
+(PRs) against that branch. The `main` branch is for stable releases, and it is
+expected that the `dev` branch will always be ahead of `main`.
+
+[Draft PRs](https://github.blog/2019-02-14-introducing-draft-pull-requests/) are
+encouraged, especially for first-time contributors or contributors working on
+complex tasks (e.g., Google Summer of Code contributors). Please use these to
+communicate ideas and implementations with the LIT team, in addition to issues.
+
+Prior to sending your PR or marking a Draft PR as "Ready for Review", please run
+the Python and TypeScript linters on your code to ensure compliance with
+Google's [Python](https://google.github.io/styleguide/pyguide.html) and
+[TypeScript](https://google.github.io/styleguide/tsguide.html) Style Guides.
+
+```sh
+# Run Pylint on your code using the following command from the root of this repo
+pushd lit_nlp && pylint && popd
+
+# Run ESLint on your code using the following command from the root of this repo
+pushd lit_nlp && yarn lint && popd
+```
## Citing LIT
-If you use LIT as part of your work, please cite [our EMNLP paper](https://arxiv.org/abs/2008.05122):
+If you use LIT as part of your work, please cite
+[our EMNLP paper](https://arxiv.org/abs/2008.05122):
```
@misc{tenney2020language,
@@ -198,8 +235,8 @@ If you use LIT as part of your work, please cite [our EMNLP paper](https://arxiv
This is not an official Google product.
-LIT is a research project, and under active development by a small team.
-There will be some bugs and rough edges, but we're releasing at an early stage
-because we think it's pretty useful already. We want LIT to be an open platform,
-not a walled garden, and we'd love your suggestions and feedback - drop us a
-line in the [issues](https://github.com/pair-code/lit/issues).
+LIT is a research project and under active development by a small team. There
+will be some bugs and rough edges, but we're releasing at an early stage because
+we think it's pretty useful already. We want LIT to be an open platform, not a
+walled garden, and we would love your suggestions and feedback - drop us a line
+in the [issues](https://github.com/pair-code/lit/issues).
diff --git a/RELEASE.md b/RELEASE.md
index c01393eb..e4ef25a7 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,4 +1,228 @@
-# Language Interpretability Tool releases
+# Learning Interpretability Tool Release Notes
+
+## Release 0.5
+
+This is a major release, covering many new features from the `dev` branch since
+the v0.4 release nearly 11 months ago. Most notably, we're renaming! It's still
+LIT, but now the L stands for "Learning" instead of "Language", to better
+reflect the scope of LIT and support for non-text modalities like images and
+tabular data. Additionally, we've made lots of improvements, including:
+
+* New modules including salience clustering, tabular feature attribution, and
+ a new Dive module for data exploration (inspired by our prior work on
+ [Facets Dive](https://pair-code.github.io/facets/)).
+* New demos and tutorials for input salience comparison and tabular feature
+ attribution.
+* Many UI improvements, with better consistency across modules and shared
+ functionality for colors, slicing, and faceting of data.
+* Better performance on large datasets (up to 100k examples), as well as
+ improvements to the type system and new validation routines (`--validate`) for
+ models and datasets.
+* Download data as CSV directly from tables in the UI, and in notebook mode
+ access selected examples directly from Python.
+* Update to Python 3.9 and TypeScript 4.7.
+
+This release would not have been possible without the work of many new
+contributors in 2022. Many thanks to
+[Crystal Qian](https://github.com/cjqian),
+[Shane Wong](https://github.com/jswong65),
+[Anjishnu Mukherjee](https://github.com/iamshnoo),
+[Aryan Chaurasia](https://github.com/aryan1107),
+[Animesh Okhade](https://github.com/animeshokhade),
+[Daniel Levenson](https://github.com/dleve123),
+[Danila Sinopalnikov](https://github.com/sinopalnikov),
+[Deepak Ramachandran](https://github.com/DeepakRamachandran),
+[Rebecca Chen](https://github.com/rchen152),
+[Sebastian Ebert](https://github.com/eberts-google), and
+[Yilei Yang](https://github.com/yilei)
+for your support and contributions to this project!
+
+### Breaking Changes
+
+* Upgraded to Python 3.9 –
+ [17bfabd](https://github.com/PAIR-code/lit/commit/17bfabd75959feae4d64e79db695fe38be7a14b0)
+* Upgraded to TypeScript 4.7 –
+ [10e2548](https://github.com/PAIR-code/lit/commit/10e25480d43ecfa1800ed77fd5e2b49b69723c39)
+* Layout definitions moved to Python –
+ [05824c8](https://github.com/PAIR-code/lit/commit/05824c88296e9fed48ed6757b2f459ff6cc29968),
+ [d3d19d2](https://github.com/PAIR-code/lit/commit/d3d19d2fbada9c12ab06630494c7cc84f9b3a9c8),
+ [2994d7e](https://github.com/PAIR-code/lit/commit/2994d7e00582cff528e3753b43ce81ced00a1b30),
+ [b78c962](https://github.com/PAIR-code/lit/commit/b78c96227bc760bb5009a1ed119b8fd568076767),
+ [0eacdd0](https://github.com/PAIR-code/lit/commit/0eacdd026d2a0933f67d8aa2b5a1ec9d37a0d2d6)
+* Moving classification and regression results to Interpreters –
+ [2b4e622](https://github.com/PAIR-code/lit/commit/2b4e622922ba35df79e538c3a157356b854a54c6),
+ [bcdbb80](https://github.com/PAIR-code/lit/commit/bcdbb8050ed1cdcd6350a556bbc394e67d4113fe),
+ [dad8edb](https://github.com/PAIR-code/lit/commit/dad8edb8f05af8e4c3e46c352ee689988ec5cc11)
+* Use a Pinning construct instead of comparison mode –
+ [05bfc90](https://github.com/PAIR-code/lit/commit/05bfc906c91b3b748ffc7f3b414a046629ca16b1),
+ [d7bdc65](https://github.com/PAIR-code/lit/commit/d7bdc654f147f879dec97e96f25d95d963fb7caa),
+ [6a4ca00](https://github.com/PAIR-code/lit/commit/6a4ca0018211ed52e7eb24ec3d01ca4c683f179a),
+ [0fe3c79](https://github.com/PAIR-code/lit/commit/0fe3c79352832c594a82a8e52d853c1c29742910),
+ [5b2b737](https://github.com/PAIR-code/lit/commit/5b2b73767a2fb81f90c222786ed2a73b9171969d)
+* Parallel, class-based Specs and LitTypes in Python and TypeScript code
+ * Prep work –
+ [db1ef3d](https://github.com/PAIR-code/lit/commit/db1ef3ddc7bd35df8c75325b9662fa41facf4359),
+ [c85e556](https://github.com/PAIR-code/lit/commit/c85e556eedddf555449cd8e92b3218503b46dbb4),
+ [660b8ef](https://github.com/PAIR-code/lit/commit/660b8ef3d47430e71fc0f9fcfad32a0e7b360557),
+ [db58fa4](https://github.com/PAIR-code/lit/commit/db58fa42d18e605d997dce84f0b08797cc2729dc),
+ [c020d25](https://github.com/PAIR-code/lit/commit/c020d2535a10ea137e25ea5ba87fa6d3d4cecc58),
+ [eb02465](https://github.com/PAIR-code/lit/commit/eb024651e3b09e8bcd836e3558b6cef7e7b70160),
+ [72edd26](https://github.com/PAIR-code/lit/commit/72edd26ed4f71d6b8d81ecefa5d09b508a29861d),
+ [65c5b8a](https://github.com/PAIR-code/lit/commit/65c5b8a93643d4735c51e6ded48dcb3434203e60),
+ [abb8889](https://github.com/PAIR-code/lit/commit/abb88890898848bb5a8fbe84f184a4b2b3a244cf),
+ [4c93b62](https://github.com/PAIR-code/lit/commit/4c93b62da400ae30a86b65e415bd495f3e611449),
+ [40d14e5](https://github.com/PAIR-code/lit/commit/40d14e5985c8dcde384de0b9f5bc469239e269f0),
+ [9ec5324](https://github.com/PAIR-code/lit/commit/9ec53248e8c7b0a2e1ba0996e6084709ce2080ea),
+ [40a661e](https://github.com/PAIR-code/lit/commit/40a661edafc71e1a0ae4f2d88eeb529d04c1172a)
+ * Breaking changes to front-end typing infrastructure –
+ [8c6ac11](https://github.com/PAIR-code/lit/commit/8c6ac1174cd1020c00491736a3d0fa78e05e0eed),
+ [2522e4f](https://github.com/PAIR-code/lit/commit/2522e4f72e96c09a019630623b9061e73b4dce54),
+ [0f8ff8e](https://github.com/PAIR-code/lit/commit/0f8ff8e251aee27654a9e1590c50aa5f75598edc),
+ [58970de](https://github.com/PAIR-code/lit/commit/58970de691dea2be533e9c80e52768b2eb7b8f07),
+ [ef72bfc](https://github.com/PAIR-code/lit/commit/ef72bfc4fcfc2bde06db0db0a7f105e9401d4cd2),
+ [ccbb72c](https://github.com/PAIR-code/lit/commit/ccbb72c60d1eefc71c1eeca50408613cb65e445c),
+ [a5b9f65](https://github.com/PAIR-code/lit/commit/a5b9f658188339c11c28fd43dbe25ff167e06c0b),
+ [ab1e06a](https://github.com/PAIR-code/lit/commit/ab1e06a016fd7b309ee77237adf35f52d43e52d6),
+ [853edd0](https://github.com/PAIR-code/lit/commit/853edd0b03f695aaa5d708312325dc13758070da),
+ [cb528f1](https://github.com/PAIR-code/lit/commit/cb528f1bd502edf9f6ed25734a1ef81cfbff007b),
+ [a36a936](https://github.com/PAIR-code/lit/commit/a36a936689443b4ed2417299e17dcd5a0b49de39),
+ [74b5dbb](https://github.com/PAIR-code/lit/commit/74b5dbbb23259df7c3233cfcedce588ef62def82),
+ [e811359](https://github.com/PAIR-code/lit/commit/e811359cabd092bacf14799ab811c314f6a8bf84)
+ * Build fixes –
+ [948adb3](https://github.com/PAIR-code/lit/commit/948adb3d35894cbd78cc73ddbe2ea8da5a883ace)
+* Minimizing duplication in modules
+ * Classification Results –
+ [4f2b53d](https://github.com/PAIR-code/lit/commit/4f2b53d94c73e210a1def9043623590e077ee1b8)
+ * Scalars, including its migration to Megaplot –
+ [353b96e](https://github.com/PAIR-code/lit/commit/353b96ea5fd0aca9ace2ac47491b99d58cbbbc67),
+ [ed07199](https://github.com/PAIR-code/lit/commit/ed07199189bce50446e05506cdfb8260781977eb),
+ [184c8c6](https://github.com/PAIR-code/lit/commit/184c8c684c1f497f8911a5e886cec604b46c12f9),
+ [14f82d5](https://github.com/PAIR-code/lit/commit/14f82d53b2e41b2cc088db3c1df3ebac5aee193a),
+ [764674a](https://github.com/PAIR-code/lit/commit/764674a0430fc8e55535e09ad4bae4dc1eac1234)
+* Changes to component `is_compatible()` signature
+ * Added checks to some generators –
+ [9b2de92](https://github.com/PAIR-code/lit/commit/9b2de92101b0a0c4961007a0a37fa936ee708e29),
+ [db94849](https://github.com/PAIR-code/lit/commit/db948496d7b040463328ce926499d79e9a4d434d)
+ * Added Dataset parameter to all checks –
+ [ecd3a66](https://github.com/PAIR-code/lit/commit/ecd3a6623f2a0d45ae26c74d0d72fb68b7bcb9aa)
+* Adds `core` components library to encapsulate default interpreters,
+ generators, and metrics –
+ [9ea4ab2](https://github.com/PAIR-code/lit/commit/9ea4ab264f6d9b03ee19ab8af4309e97862c089a)
+* Removed the Color module –
+ [b18d887](https://github.com/PAIR-code/lit/commit/b18d8871ea7ab1d2b5e4c671d33653d32f87d952)
+* Removed the Slice module –
+ [7db22ae](https://github.com/PAIR-code/lit/commit/7db22ae197650935ab916b248ca3c06f8593afb5)
+* Moved star button to Data Table module –
+ [cd14f35](https://github.com/PAIR-code/lit/commit/cd14f355781500b07a433a5df58d2ca0ec8ed6f8)
+* Salience Maps now inside of expansion panels with popup controls –
+ [1994425](https://github.com/PAIR-code/lit/commit/199442552586fa48780a33166cd6927ba4ab3530)
+* Metrics
+ * Promotion to a major `component` type –
+ [de7d8ba](https://github.com/PAIR-code/lit/commit/de7d8ba26e74ecf2fd8a7700352e0d6d469d22ac)
+ * Improved compatibility checks –
+ [0d8341d](https://github.com/PAIR-code/lit/commit/0d8341d9f120359bec86c983c5618dd59bb6f591)
+
+### New Stuff
+
+* Common Color Legend element –
+ [f846772](https://github.com/PAIR-code/lit/commit/f8467720d33dd8ef3d0da5c5a12eed2db37bb4b0),
+ [7a1e26a](https://github.com/PAIR-code/lit/commit/7a1e26a9759882e0bf697363298e68f969c24a84),
+ [0cc934c](https://github.com/PAIR-code/lit/commit/0cc934c8980a6d4563319087fd7e9ee5201acd04)
+* Common Expansion Panel element –
+  [2d670ce](https://github.com/PAIR-code/lit/commit/2d670ce70a6e41d7c2fc1d4d9b8c37c2b3b8876b)
+* Common Faceting Control –
+ [0f46e16](https://github.com/PAIR-code/lit/commit/0f46e166595c83773611a715be694100d89cace0),
+ [b109f9b](https://github.com/PAIR-code/lit/commit/b109f9b8cad9c26f328c1634122fe874309d5b53),
+ [8993f9b](https://github.com/PAIR-code/lit/commit/8993f9b5cd92f0d4fdfcd1c9e654c2aa4e15fb98),
+ [670abeb](https://github.com/PAIR-code/lit/commit/670abeb25dbdc747067fae725a50a873355eb368)
+* Common Popup element –
+ [1994425](https://github.com/PAIR-code/lit/commit/199442552586fa48780a33166cd6927ba4ab3530),
+ [cca3511](https://github.com/PAIR-code/lit/commit/cca3511322189ddb49bb6a533576d01f532a6f23)
+* A new Dive module for exploring your data –
+ [155e0c4](https://github.com/PAIR-code/lit/commit/155e0c4f1fb8198a18186c432bdb1516e9910f9e),
+ [1d17ca2](https://github.com/PAIR-code/lit/commit/1d17ca23245765d2ded6790902eb5c4b9af3c954),
+ [a0da9cf](https://github.com/PAIR-code/lit/commit/a0da9cf0643c2468b06d964b942aa523cd06069c)
+* Copy or download data from Table elements –
+ [d23ecfc](https://github.com/PAIR-code/lit/commit/d23ecfc74993dc932d88e412170cbb3cf6998408)
+* Training Data Attribution module –
+ [5ff9102](https://github.com/PAIR-code/lit/commit/5ff91029b05bea2d47835b81e840387ce8e70294),
+ [c7398f8](https://github.com/PAIR-code/lit/commit/c7398f82f845180192a76eba2c0caade05a5c0bc)
+* Tabular Feature Attribution module with a heatmap mode and
+ [SHAP](https://shap.readthedocs.io/en/latest/index.html) interpreter –
+ [45e526c](https://github.com/PAIR-code/lit/commit/45e526c76c586ba3539f28c0e03ab4adb9825def),
+ [76379ad](https://github.com/PAIR-code/lit/commit/76379adac37f7e284faf979673cbb0399a36d8ee)
+* Salience Clustering module –
+ [8f3c26c](https://github.com/PAIR-code/lit/commit/8f3c26c60b652ae22cbb8c64e4b2212747c40413),
+ [fb795e8](https://github.com/PAIR-code/lit/commit/fb795e8949b4b430c96e6d02d001e0a9aedd6c42),
+ [49faa00](https://github.com/PAIR-code/lit/commit/49faa002d648a4b128c862f63b8202bf739c75d2),
+ [e35d8d8](https://github.com/PAIR-code/lit/commit/e35d8d84edb9bc1ced5c4bc5e7bbcd8307dc99ac),
+ [7505861](https://github.com/PAIR-code/lit/commit/75058615bc46b53c82b8561ae2bf80ff4c0eb2aa),
+ [f970958](https://github.com/PAIR-code/lit/commit/f970958024c821880b3238d7a2f293b947a4e1e7)
+* Selection state syncing in Python notebooks –
+ [08abc2c](https://github.com/PAIR-code/lit/commit/08abc2ca3a25f368823a4a9f3ba9d5b5ebeac7a6),
+ [06613b9](https://github.com/PAIR-code/lit/commit/06613b909173978c1d4648c8b37c28269a783c14)
+* Unified DataService –
+ [9bdc23e](https://github.com/PAIR-code/lit/commit/9bdc23e7890e8afeb7ab6dcc89c8cb7730c10b26),
+ [00749fc](https://github.com/PAIR-code/lit/commit/00749fc0d4f83cad204a69d792f602c63b1ff676)
+* AUC ROC and AUC PR Curve interpreters and module –
+ [51842ba](https://github.com/PAIR-code/lit/commit/51842babef63f9aa29d1d2add14633c4640627fc),
+ [0f9fd4d](https://github.com/PAIR-code/lit/commit/0f9fd4dccc9e6375c577012191f89c3fb7067b01),
+ [0558ef5](https://github.com/PAIR-code/lit/commit/0558ef52276ed6797a7a6f9d88721a50b6d6a792),
+ [4efd58e](https://github.com/PAIR-code/lit/commit/4efd58e788a3cd38852961b136d8461f3b75b3d7)
+* Splash screen documentation –
+ [1f09ae9](https://github.com/PAIR-code/lit/commit/1f09ae9ca326dbaf0e5541f0f24370b56bcc6d1b),
+ [cfabe78](https://github.com/PAIR-code/lit/commit/cfabe7865df5fd51ff8c483296f3fccc0fa30d28),
+ [aca35d8](https://github.com/PAIR-code/lit/commit/aca35d832ad00a4bf35fd27adf35ba76f4d0d87f)
+* Added a `GeneratedURL` type that displays in the Generated Text module –
+ [bb06368](https://github.com/PAIR-code/lit/commit/bb06368602cfcca656746525de16a603e2359cb3)
+* Added new built-in
+ [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) and Exact Match
+ metrics –
+ [6773927](https://github.com/PAIR-code/lit/commit/67739270434388a63627ee1bc405bc16923dd631),
+ [eac9382](https://github.com/PAIR-code/lit/commit/eac9382cebbc9d1e974ec0e7b6bc1cd528a4df1a)
+* Input Salience demo –
+ [a98edce](https://github.com/PAIR-code/lit/commit/a98edce9caf8e8481f4105cb26b57d5d0429f963),
+ [75ff835](https://github.com/PAIR-code/lit/commit/75ff835ed2e051899d3839ab0ca4360bbf0b9897),
+ [55579de](https://github.com/PAIR-code/lit/commit/55579de33fd292b34fecbcd058686cab1f05fd74)
+* Model and Dataset validation –
+ [0fef77a](https://github.com/PAIR-code/lit/commit/0fef77a7835bbfc9a022a9b1c99b10fc9f5a55c7)
+* Tutorials written by our Google Summer of Code contributor,
+ [Anjishnu Mukherjee](https://github.com/iamshnoo)
+ * Using LIT for Tabular Feature Attribution –
+ [2c0703d](https://github.com/PAIR-code/lit/commit/2c0703d69b3c5d3f9ef5aa4c03fe3c6262e707c3)
+ * Making Sense of Salience Maps –
+ [4159527](https://github.com/PAIR-code/lit/commit/415952702893febbcea9d631ad1a289a3e43e27c)
+
+### Non-breaking Changes, Bug Fixes, and Enhancements
+
+* Added Dataset embeddings to Embeddings projector –
+ [78e2e9c](https://github.com/PAIR-code/lit/commit/78e2e9c05c831fafd360da5f1c3b9b4e12054df9),
+ [3c0929f](https://github.com/PAIR-code/lit/commit/3c0929f9bb293391471d5bc3c1219b6025946354),
+ [e7ac98b](https://github.com/PAIR-code/lit/commit/e7ac98bbabb5b0bf40bd956724cc5a63aef10350)
+* Added a “sparse” mode to Classification Results –
+ [20a8f31](https://github.com/PAIR-code/lit/commit/20a8f316ec0b3d68cd131b785b8dfd6fa61ab3e5)
+* Added “Show only generated” option to Data Table module –
+ [4851c9d](https://github.com/PAIR-code/lit/commit/4851c9de8917d35e2e1cc66d8d33d52f78418acf)
+* Added threshold property for `MulticlassPreds` that allows for default
+ threshold values other than 0.5 –
+ [5e91b19](https://github.com/PAIR-code/lit/commit/5e91b1984700f6c1bb25b05d25e091d8d522c7e9)
+* Added toggle for module duplication direction –
+ [4e05a75](https://github.com/PAIR-code/lit/commit/4e05a759bca13afe857abd10abfdb5229d1ae622)
+* Clickable links in the Generated Images module –
+ [8cf8119](https://github.com/PAIR-code/lit/commit/8cf8119cbdaa5beea2b615d2eadb66630234af38)
+* Constructor parameters for salience interpreters –
+  [ab057b5](https://github.com/PAIR-code/lit/commit/ab057b55a938a59b87f08597050af5adfa2b8bcc)
+* Image upload in Datapoint Editor –
+ [a23b146](https://github.com/PAIR-code/lit/commit/a23b14676c7cb4fa7b82e42f9b6c036108801a54)
+* Markdown support in LIT component descriptions –
+ [0eaa00c](https://github.com/PAIR-code/lit/commit/0eaa00c1f58097c6e77354678f0b603eeabe74cd)
+* Selection updates based on interactions in Metrics module –
+ [c3b6a0c](https://github.com/PAIR-code/lit/commit/c3b6a0cceb300de5e18ff9bd68cf8c29b49b49b8)
+* Support for many new types of inputs in the Datapoint editor, including
+ `GeneratedText`, `GeneratedTextCandidates`, `MultiSegmentAnnotation`,
+ `Tokens`, `SparseMultilabel`, and `SparseMultilabelPreds`
+* Various styling fixes and code cleanup efforts
+* Docs, FAQ, and README updates
## Release 0.4.1
diff --git a/docs/assets/css/new.css b/docs/assets/css/new.css
index e40f75d0..982b7744 100644
--- a/docs/assets/css/new.css
+++ b/docs/assets/css/new.css
@@ -544,7 +544,9 @@ h5 {
margin-bottom: 20px;
}
-.info-box-text, .info-box-text p {
+.info-box-text,
+.info-box-text p,
+.info-box-text ul {
font-family: "Open Sans";
font-style: normal;
font-weight: normal;
diff --git a/docs/assets/images/actor_to_actress.png b/docs/assets/images/actor_to_actress.png
index 40083bf5..a524b851 100644
Binary files a/docs/assets/images/actor_to_actress.png and b/docs/assets/images/actor_to_actress.png differ
diff --git a/docs/assets/images/lit-coref-compare.png b/docs/assets/images/lit-coref-compare.png
index a008c66d..eaa0c7e7 100644
Binary files a/docs/assets/images/lit-coref-compare.png and b/docs/assets/images/lit-coref-compare.png differ
diff --git a/docs/assets/images/lit-coref-data.png b/docs/assets/images/lit-coref-data.png
index adc2297d..cbfe7ccb 100644
Binary files a/docs/assets/images/lit-coref-data.png and b/docs/assets/images/lit-coref-data.png differ
diff --git a/docs/assets/images/lit-coref-metric-top.png b/docs/assets/images/lit-coref-metric-top.png
index 0cb95772..1c799e91 100644
Binary files a/docs/assets/images/lit-coref-metric-top.png and b/docs/assets/images/lit-coref-metric-top.png differ
diff --git a/docs/assets/images/lit-coref-metrics.png b/docs/assets/images/lit-coref-metrics.png
index 995c4187..2be62bcc 100644
Binary files a/docs/assets/images/lit-coref-metrics.png and b/docs/assets/images/lit-coref-metrics.png differ
diff --git a/docs/assets/images/lit-coref-pred.png b/docs/assets/images/lit-coref-pred.png
index dfe8099b..1d95a94a 100644
Binary files a/docs/assets/images/lit-coref-pred.png and b/docs/assets/images/lit-coref-pred.png differ
diff --git a/docs/assets/images/lit-coref-select.png b/docs/assets/images/lit-coref-select.png
index ca4e7c05..4e71fa1d 100644
Binary files a/docs/assets/images/lit-coref-select.png and b/docs/assets/images/lit-coref-select.png differ
diff --git a/docs/assets/images/lit-metrics-not.png b/docs/assets/images/lit-metrics-not.png
index 7d074c19..3bfe8e10 100644
Binary files a/docs/assets/images/lit-metrics-not.png and b/docs/assets/images/lit-metrics-not.png differ
diff --git a/docs/assets/images/lit-not-saliency.png b/docs/assets/images/lit-not-saliency.png
index 33d9cba6..d7b27791 100644
Binary files a/docs/assets/images/lit-not-saliency.png and b/docs/assets/images/lit-not-saliency.png differ
diff --git a/docs/assets/images/lit-saliency.png b/docs/assets/images/lit-saliency.png
index f9af0032..409810f0 100644
Binary files a/docs/assets/images/lit-saliency.png and b/docs/assets/images/lit-saliency.png differ
diff --git a/docs/assets/images/lit-t5.png b/docs/assets/images/lit-t5.png
index 8db30f4a..78071b9f 100644
Binary files a/docs/assets/images/lit-t5.png and b/docs/assets/images/lit-t5.png differ
diff --git a/docs/assets/images/lit-toolbars.gif b/docs/assets/images/lit-toolbars.gif
index d0b856ea..5588dc4f 100644
Binary files a/docs/assets/images/lit-toolbars.gif and b/docs/assets/images/lit-toolbars.gif differ
diff --git a/docs/assets/images/lit-tweet.gif b/docs/assets/images/lit-tweet.gif
index 4525c18e..473e5b66 100644
Binary files a/docs/assets/images/lit-tweet.gif and b/docs/assets/images/lit-tweet.gif differ
diff --git a/docs/assets/images/lit-workspaces.png b/docs/assets/images/lit-workspaces.png
index ca0b9ff7..92802f47 100644
Binary files a/docs/assets/images/lit-workspaces.png and b/docs/assets/images/lit-workspaces.png differ
diff --git a/docs/assets/images/lit_data_table_annotated.png b/docs/assets/images/lit_data_table_annotated.png
index 830ad979..6c199d0a 100644
Binary files a/docs/assets/images/lit_data_table_annotated.png and b/docs/assets/images/lit_data_table_annotated.png differ
diff --git a/docs/assets/images/lit_slice_editor_annotated.png b/docs/assets/images/lit_slice_editor_annotated.png
index 066eaca9..5749020e 100644
Binary files a/docs/assets/images/lit_slice_editor_annotated.png and b/docs/assets/images/lit_slice_editor_annotated.png differ
diff --git a/docs/assets/images/lit_tcav_screen_annotated.png b/docs/assets/images/lit_tcav_screen_annotated.png
index f81b03f7..526c3312 100644
Binary files a/docs/assets/images/lit_tcav_screen_annotated.png and b/docs/assets/images/lit_tcav_screen_annotated.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-1.png b/docs/assets/images/tab-feat-attr-image-1.png
new file mode 100644
index 00000000..2505b1d0
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-1.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-10.png b/docs/assets/images/tab-feat-attr-image-10.png
new file mode 100644
index 00000000..adbf225b
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-10.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-11.png b/docs/assets/images/tab-feat-attr-image-11.png
new file mode 100644
index 00000000..95eaf9cc
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-11.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-12.png b/docs/assets/images/tab-feat-attr-image-12.png
new file mode 100644
index 00000000..76345e2e
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-12.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-2.png b/docs/assets/images/tab-feat-attr-image-2.png
new file mode 100644
index 00000000..3afb969e
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-2.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-3.png b/docs/assets/images/tab-feat-attr-image-3.png
new file mode 100644
index 00000000..471352c1
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-3.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-4.png b/docs/assets/images/tab-feat-attr-image-4.png
new file mode 100644
index 00000000..5510cf15
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-4.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-5.png b/docs/assets/images/tab-feat-attr-image-5.png
new file mode 100644
index 00000000..46c16cd0
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-5.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-6.png b/docs/assets/images/tab-feat-attr-image-6.png
new file mode 100644
index 00000000..1ea99f60
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-6.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-7.png b/docs/assets/images/tab-feat-attr-image-7.png
new file mode 100644
index 00000000..26a2cd58
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-7.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-8.png b/docs/assets/images/tab-feat-attr-image-8.png
new file mode 100644
index 00000000..23a78404
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-8.png differ
diff --git a/docs/assets/images/tab-feat-attr-image-9.png b/docs/assets/images/tab-feat-attr-image-9.png
new file mode 100644
index 00000000..93c35ea5
Binary files /dev/null and b/docs/assets/images/tab-feat-attr-image-9.png differ
diff --git a/docs/assets/images/text-salience-image-1.png b/docs/assets/images/text-salience-image-1.png
new file mode 100644
index 00000000..30f695a0
Binary files /dev/null and b/docs/assets/images/text-salience-image-1.png differ
diff --git a/docs/assets/images/text-salience-image-10.png b/docs/assets/images/text-salience-image-10.png
new file mode 100644
index 00000000..fdc3bdfa
Binary files /dev/null and b/docs/assets/images/text-salience-image-10.png differ
diff --git a/docs/assets/images/text-salience-image-11.png b/docs/assets/images/text-salience-image-11.png
new file mode 100644
index 00000000..f771bcfe
Binary files /dev/null and b/docs/assets/images/text-salience-image-11.png differ
diff --git a/docs/assets/images/text-salience-image-12.png b/docs/assets/images/text-salience-image-12.png
new file mode 100644
index 00000000..60cf6083
Binary files /dev/null and b/docs/assets/images/text-salience-image-12.png differ
diff --git a/docs/assets/images/text-salience-image-13.png b/docs/assets/images/text-salience-image-13.png
new file mode 100644
index 00000000..c703eda0
Binary files /dev/null and b/docs/assets/images/text-salience-image-13.png differ
diff --git a/docs/assets/images/text-salience-image-14.png b/docs/assets/images/text-salience-image-14.png
new file mode 100644
index 00000000..83631aeb
Binary files /dev/null and b/docs/assets/images/text-salience-image-14.png differ
diff --git a/docs/assets/images/text-salience-image-15.png b/docs/assets/images/text-salience-image-15.png
new file mode 100644
index 00000000..a8e9eeb1
Binary files /dev/null and b/docs/assets/images/text-salience-image-15.png differ
diff --git a/docs/assets/images/text-salience-image-16.png b/docs/assets/images/text-salience-image-16.png
new file mode 100644
index 00000000..3ff7be26
Binary files /dev/null and b/docs/assets/images/text-salience-image-16.png differ
diff --git a/docs/assets/images/text-salience-image-2.png b/docs/assets/images/text-salience-image-2.png
new file mode 100644
index 00000000..8bc48b74
Binary files /dev/null and b/docs/assets/images/text-salience-image-2.png differ
diff --git a/docs/assets/images/text-salience-image-3.png b/docs/assets/images/text-salience-image-3.png
new file mode 100644
index 00000000..692695b5
Binary files /dev/null and b/docs/assets/images/text-salience-image-3.png differ
diff --git a/docs/assets/images/text-salience-image-4.png b/docs/assets/images/text-salience-image-4.png
new file mode 100644
index 00000000..96582a42
Binary files /dev/null and b/docs/assets/images/text-salience-image-4.png differ
diff --git a/docs/demos/index.html b/docs/demos/index.html
index 6b3b3ecc..89602036 100644
--- a/docs/demos/index.html
+++ b/docs/demos/index.html
@@ -42,7 +42,7 @@
- Language Interpretability Tool
+ Learning Interpretability Tool
@@ -59,7 +59,7 @@
- Use LIT with any of three tasks from the General Language Understanding Evaluation (GLUE) benchmark suite. This demo contains binary classification (for sentiment analysis, using SST2), multi-class classification (for textual entailment, using MultiNLI), and regression (for measuringtext similarity, using STS-B).
+
+ Use LIT with any of three tasks from the General Language Understanding Evaluation (GLUE) benchmark suite. This demo contains binary classification (for sentiment analysis, using SST2), multi-class classification (for textual entailment, using MultiNLI), and regression (for measuring text similarity, using STS-B).
@@ -162,7 +162,19 @@
Use a T5 model to summarize text. For any example of interest, quickly find similar examples from the training set, using an approximate nearest-neighbors index.
LIT supports many techniques like salience maps and counterfactual generators
+for text data. But what if you have a tabular dataset? You might want to find
+out which features (columns) are most relevant to the model’s predictions. LIT's
+Feature Attribution module for
+tabular datasets
+supports identification of these important features. This tutorial provides a
+walkthrough for this module within LIT, on the
+Palmer Penguins dataset.
+
+
Kernel SHAP based Feature Attribution
+
The Feature Attribution functionality is
+achieved using SHAP.
+In particular, LIT uses
+Kernel SHAP
+over tabular data, which is essentially a specially weighted local linear
+regression that estimates SHAP values and works for any model. For now,
+the feature attribution module is only shown in the UI when working with
+tabular data.
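
For intuition, here is a standalone sketch of the same Kernel SHAP computation
outside of LIT, using the `shap` package directly; the random-forest model and
synthetic data below are stand-ins for the penguin classifier:

```python
# Standalone Kernel SHAP sketch using the `shap` package (the technique LIT's
# tabular interpreter is based on); model and data are illustrative stand-ins.
import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(333, 6))              # stand-in for the 6 penguin features
y = (X[:, 0] + X[:, 1] > 0).astype(int)    # synthetic binary labels
model = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# Kernel SHAP fits a specially weighted local linear regression around each
# example; the background sample models "feature removed".
explainer = shap.KernelExplainer(model.predict_proba, shap.sample(X, 30))
shap_values = explainer.shap_values(X[:10])  # per-example, per-feature scores
```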
+
+
Overview
+
The penguins demo is a
+simple classifier for predicting penguin species from the Palmer Penguins
+dataset. It classifies the penguins as either Adelie, Chinstrap, or Gentoo based
+on 6 features—body mass (g), culmen
+depth (mm), culmen length (mm), flipper length (mm), island, and sex.
+
+
Filtering out incomplete data points
+
Palmer Penguins is a tabular dataset with 344 penguin specimens. LIT’s
+penguin demo filters out 11 of these penguins due to missing information (sex
+is missing for all 11, though some are missing additional information),
+resulting in 333 data points being loaded for analysis.
+
+
The Feature Attribution module shows up in the bottom right of the demo within
+the Explanations tab. It computes
+Shapley Additive exPlanation (SHAP)
+values for each feature in a set of inputs and displays these values in a table.
+The controls for this module are:
+
+
The sample size slider, which defaults to a value of 30. SHAP
+computations are very expensive and it is infeasible to compute them for the
+entire dataset. Through testing, we found that 30 is about the maximum
+number of samples we can run SHAP on before performance takes a significant
+hit, and it becomes difficult to use above 50 examples. Clicking the Apply
+button will automatically check the Show attributions from the Tabular SHAP
+checkbox, and LIT will start computing the SHAP values.
+
The prediction key selects the model output value for which influence is
+computed. Since the penguin model only predicts one feature, species, this is
+set to species and cannot be changed. If a model can predict multiple values
+in different fields, for example predicting species and island or species
+and sex, then you could change which output field to explain before clicking
+Apply.
+
The heatmap toggle can be enabled to color-code the SHAP values.
+
The facets button and show attributions for selection checkbox
+enable conditionally running the Kernel SHAP interpreter over subsets of the
+data. We will get into the specifics of this with an example later on in
+this tutorial.
+
+
+
+
An overview of the Penguins demo; note the tabular feature
+ attribution (1) and salience maps (2) modules in the bottom right and
+ center, respectively.
+
+
+
+
The tabular feature attribution module has three main elements of
+ interactivity: an expansion panel where you can configure the SHAP
+ parameters (1), a heatmap toggle that colors the cells in the
+ results table based on the scores (2), and a facets control for
+ exploring subsets of the data (3).
+
+
A Simple Use Case: Feature Attribution for 10 samples
+
To get started with the module, we set sample size to a small value, 10, and
+start the SHAP computation with heatmap enabled.
+
+
Edge cases for the sample size slider
+
Kernel SHAP computes feature importance relative to a pseudo-random
+sample of the dataset. The sample size is set with the slider, and the samples
+are drawn from either the current selection (i.e., a subset of the data that
+were manually selected or are included as part of a slice) or the entire
+dataset. When sampling from the current selection, the sample size can have
+interesting edge cases:
+
+
If the selection is empty, LIT samples the “sample size” number of data points
+from the entire dataset.
+
If the sample size is zero or larger than the selection, then LIT computes
+SHAP for the entire selection and does not sample additional data from the
+dataset.
+
If sample size is smaller than the selection, then LIT samples the “sample
+size” number of data points from the selected inputs.
+
+
+
Enabling the heatmap provides a visual indicator of the polarity and strength of
+a feature's influence. A reddish hue indicates negative attribution for that
+particular feature and a bluish hue indicates positive attribution. The deeper
+the color, the stronger the influence on the predictions.
+
+
Interpreting salience polarity
+
Salience is always relative to the model's prediction of one class.
+Intuitively, a positive attribution score for a feature of an example
+means that if this feature were removed, we expect a drop in model
+confidence in the prediction of this class. Similarly, removing a
+feature with a negative score would correspond to an increase in the
+model's confidence in the prediction of this class.
+
+
SHAP values are computed per feature per example, from which LIT computes the
+mean, min, median, and max feature values across the examples. The min and max
+values can be used to spot any outliers during analysis. The difference between
+the mean and the median can be used to gain more insights about the
+distribution. All of this enables statistical comparisons and will be enhanced
+in future releases of LIT.
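
In numpy terms, the per-feature summary is a simple reduction across the
sampled examples (toy values below, not LIT internals):

```python
# Toy illustration of the summary statistics: rows are examples, columns are
# features; each statistic is reduced per-feature across the examples.
import numpy as np

shap_values = np.array([[0.12, -0.30],
                        [0.05, -0.10],
                        [0.40, 0.02]])  # 3 examples x 2 features

summary = {
    "mean": shap_values.mean(axis=0),
    "min": shap_values.min(axis=0),
    "median": np.median(shap_values, axis=0),
    "max": shap_values.max(axis=0),
}
```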
+
Each of the columns in the table can be sorted using the up (ascending) or down
+(descending) arrow symbols in the column headers. The table is sorted in
+ascending alphabetical order of input feature names (field) by default. If there
+are many features in a dataset this space will get crowded, so LIT offers a
+filter button for each of the columns to look up a particular feature or value
+directly.
+
+
+
Start by reducing the sample size from 30 to 10; this will speed
+ up the SHAP computations.
+
+
+
+
The results of the SHAP run over a sample of 10 inputs from the
+ entire dataset. Notice how subtle the salience values are in the "mean"
+ column.
+
+
Faceting & Binning of Features
+
Simply speaking, facets are subsets of the dataset based on specific feature
+values. We can use facets to explore differences in SHAP values between subsets.
+For example, instead of looking at SHAP values from 10 samples containing both
+male and female penguins, we can look at male penguins and female penguins
+separately by faceting based on sex. LIT also allows you to select multiple
+features for faceting, and it will generate the facets by feature crosses. For
+example, if you select both sex (either male or female) and island (one of
+Biscoe, Dream and Torgersen), then LIT will create 6 facets for (Male, Biscoe),
+(Male, Dream), (Male, Torgersen), (Female, Biscoe), (Female, Dream), (Female,
+Torgersen) and show the SHAP values for whichever facets have a non-zero number
+of samples.
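
The feature-cross behavior amounts to a cartesian product over the selected
features' values, as in this plain-Python sketch (not LIT internals):

```python
# Sketch of the feature-cross idea behind multi-feature faceting: one facet
# per combination of values.
import itertools

sex = ["Male", "Female"]
island = ["Biscoe", "Dream", "Torgersen"]
facets = list(itertools.product(sex, island))  # 6 (sex, island) combinations
```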
+
+
+
Each facet of the dataset is given its own expansion panel. Click
+ on the down arrow on the right to expand the section and see the results
+ for that facet.
+
+
Numerical features support more complex faceting options. Faceting based on
+numerical features allows for defining bins using 4 methods: discrete, equal
+intervals, quantile, and threshold. Equal intervals will evenly divide the
+feature’s domain into N
+equal-sized bins. Quantile will create N bins that each contain (approximately)
+the same number of examples. Threshold creates two bins, one for the examples
+with values up to and including the threshold value, and one for examples with
+values above the threshold value. The discrete method requires specific dataset
+or model spec configuration, and we do not recommend using that method with this
+demo.
+
Categorical and boolean features do not have controllable binning behavior. A
+bin is created for each label in their vocabulary.
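
As a rough illustration (not LIT's implementation), three of the numeric
binning strategies above can be sketched with numpy for N = 4 bins:

```python
import numpy as np

values = np.array([2.0, 3.5, 4.1, 4.4, 5.0, 6.2, 7.8, 9.9])
num_bins = 4

# Equal intervals: evenly divide the feature's range into N bins.
equal_edges = np.linspace(values.min(), values.max(), num_bins + 1)

# Quantile: N bins, each holding roughly the same number of examples.
quantile_edges = np.quantile(values, np.linspace(0.0, 1.0, num_bins + 1))

# Threshold: two bins, split at a chosen value (inclusive on the low side).
threshold = 5.0
low_bin, high_bin = values[values <= threshold], values[values > threshold]
```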
+
+
+
Clicking the facets button will open the configuration controls.
+ Use these to configure how to divide the dataset into subsets.
+
+
LIT supports as many as 100 facets (aka bins). An indicator in the faceting
+config dialog lets you know how many would be created given the current
+settings.
+
Faceting is not supported for selections, meaning that if you already have a
+selection of elements (let’s say 10 penguins), then facets won’t split it
+further.
+
+
+
LIT limits the number of facets to 100 bins for performance
+ reasons. Attempting to exceed this limit will cause the active features
+ to highlight red so you can adjust their configurations.
+
+
Side-by-side comparison: Salience Maps vs. Tabular Feature Attribution
+
The Feature Attribution module works well in conjunction with other modules. In
+particular, we are going to look at the Salience Maps module, which allows us to
+enhance our analysis. Salience Maps work on one data point at a time, whereas
+the Tabular Feature Attribution usually looks at a set of data points.
+
+
Slightly different color scales
+
The color scales are slightly different between the salience maps
+module and the tabular feature attribution module. Salience maps use a
+gamma-adjusted color scale to make values more prominent.
+
+
One random data point
+
In this example, a random data point is chosen using the select random button in
+the top right corner and the unselected data points are hidden in the Data
+Table. After running both the salience maps module and the feature attribution
+module for the selected point, we can see that the values in the mean column of
+Tabular SHAP output match the saliency scores exactly. Note also that the mean,
+min, median and max values are all the same when a single datapoint is selected.
+
+
+
The results in the tabular feature attribution and salience maps
+ modules will be the same for single datapoint selections.
+
+
A slice of 5 random data points
+
LIT uses a complex selection model
+and different modules react to it differently. Salience Maps only care about the
+primary selection (the data point highlighted in a deep cyan hue in the data
+table) in a slice of elements, whereas Feature Attribution uses the entire list
+of selected elements.
+
+
Using Salience Maps to support Tabular Feature Attribution
+
Changing primary selection reruns SHAP in the Salience Maps module
+but not in Tabular Feature Attribution. So, we can effectively toggle
+through the items in our selection one-by-one and see how they compare
+to the mean values in the Feature Attribution module. Another thing to
+note is that the Salience Maps module supports comparison between a
+pinned datapoint and the primary selection, so we can do the above
+comparisons in a pair-wise manner as well.
+
+
As we can see in this example, where we run both modules on a slice of 5
+elements, the Salience Maps module is only providing its output for the primary
+selection (data point 0), whereas the Tabular Feature Attribution module is
+providing values for the entire selection by enabling the “Show attributions for
+selection” checkbox. This allows us to use the salience map module as a kind of
+magnifying glass to focus on any individual example even when we are considering
+a slice of examples in our exploration of the dataset.
+
+
+
The salience maps module is a great way to compare the scores for
+ each datapoint in a selection against the scores for that entire
+ selection from the tabular feature attribution module.
+
+
Conclusion
+
Tabular Feature Attribution based on Kernel SHAP allows LIT users to explore
+their tabular data and find the most influential features affecting model
+predictions. It also integrates nicely with the Salience Maps module to allow
+for fine-grained inspections. This is the first of many features in LIT for
+exploring tabular data, and more exciting updates are coming in future
+releases!
+
+
+
+
+
LIT enables users to analyze individual predictions for text input using
+salience maps, for which gradient-based and/or blackbox methods are available.
+In this tutorial, we will explore how to use salience maps to analyze a text
+classifier in the Classification and Regression models demo
+from the LIT website, and
+how these findings can support counterfactual analysis using LIT’s generators,
+such as Hotflip, to test hypotheses. The Salience Maps module can be found under
+the Explanations tab in the bottom half of this demo. It supports four
+different methods for the GLUE model under test (other models may support a
+different subset of these methods):
+Grad L2 Norm,
+Grad · Input,
+Integrated Gradients (IG)
+and LIME.
+
Heuristics: Which salience method for which task?
With those limitations in mind, the question remains as to which methods should
+be used and when. To offer some guidance, we have come up with the following
+decision aid that provides some ideas about which salience method(s) might be
+appropriate.
+
+
+
This flow chart can help you decide which salience interpreter to
+ apply given the information provided by your model.
+
+
If your model does not output gradients with its predictions (i.e., is a
+blackbox), LIME is your only choice as
+it is currently the only blackbox method LIT supports for text data.
+
If your model does output gradients, then you can choose among three methods:
+Grad L2 Norm,
+Grad · Input, and
+Integrated Gradients (IG).
+Grad L2 Norm and Grad · Input are easy to use and fast to compute, but can
+suffer from gradient saturation. IG addresses the gradient saturation issue in
+the Grad methods (described in detail below), but requires that the model output
+both gradients and embeddings, is much more expensive to compute, and requires
+parameterization to optimize results.
+
Remember that a good investigative process will check for commonalities and
+patterns across salience values from multiple salience methods. Further,
+salience methods should be an entry point for developing hypotheses about your
+model’s behavior, and for identifying subsets of examples and/or creating
+counterfactual examples that test those hypotheses.
+
Salience Maps for Text: Theoretical background and LIT overview
+
All methods calculate salience, but there are subtle differences in their
+approaches to calculating a salience score for each token. Grad L2 Norm
+only produces absolute salience scores, while other methods like Grad · Input
+(and also Integrated Gradients and LIME) produce signed values, leading to an
+improved interpretation of whether a token has positive or negative influence on
+the prediction.
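
The difference is easy to see in a toy numpy example, given one token's
embedding and the gradient of the class score with respect to that embedding
(both hypothetical stand-ins):

```python
import numpy as np

grad = np.array([0.2, -0.5, 0.1])         # d(class score) / d(embedding)
embedding = np.array([1.0, 0.3, -0.4])

grad_l2 = np.linalg.norm(grad)            # unsigned salience (Grad L2 Norm)
grad_dot_input = float(grad @ embedding)  # signed salience (Grad · Input)
```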
+
LIT uses different color scales to represent signed and unsigned salience scores.
+Methods that produce unsigned salience values, such as Grad L2 Norm, use a
+purple scale where darker colors indicate greater salience, whereas the other
+methods use a red-to-green scale, with red denoting negative scores and green
+denoting positive.
+
+
Interpreting salience polarity
+
Salience is always relative to the model’s prediction of one class.
+Intuitively, a positive influence score (attribution) for a token (or
+word, depending on your method) in an example means that if this token
+were removed, we expect a drop in model confidence in the prediction of
+the class. Similarly, removing a negative token would correspond to an
+increase in the model's confidence in the prediction of this class.
+
+
+
+
The tokens from the same example in the SST-2 dataset can have
+ dramatically different scores depending on the interpreter, as seen in
+ this screenshot. Different salience interpreters output scores in
+ different ranges, for example, Grad L2 Norm outputs unsigned values in
+ the range from 0 to 1, denoted by the purple colors (more purple means
+ closer to one), whereas others output signed scores in the range -1 to
+ 1, denoted by the pink to green color scale.
+
+
Token-Based Methods
+
Gradient saturation
+is a potential problem for all of the gradient-based methods, such as
+Grad L2 Norm and
+Grad · Input, that we need to look out
+for. Essentially, if the model's learning saturates for a particular token, then
+its gradient goes to zero and the token appears to have zero salience. At the same time,
+some tokens actually have a zero salience score, because they do not affect the
+predictions. And there is no simple way to tell if a token that we are
+interested in is legitimately irrelevant or if we are just observing the effects
+of gradient saturation.
+
The integrated gradients method
+addresses the gradient saturation problem by enriching gradients with
+embeddings.
+Tokens
+are the discrete building blocks of text sequences, but they can also be
+represented as vectors in a
+continuous embedding space.
+IG computes per-token salience as the average salience over a set of local
+gradients computed by interpolating between the token’s embedding vectors and a
+baseline (typically the zero vector). The tradeoff is that IG requires more
+effort to identify the right number of interpolation steps to be
+effective (configurable in LIT’s interface), with the number of steps
+correlating directly with runtime. It also requires more information,
+which the model may or may not be able to provide.
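
A minimal numpy sketch of the interpolation, assuming a hypothetical `grad_fn`
that returns the gradient of the target class score with respect to an
(interpolated) embedding:

```python
import numpy as np

def integrated_gradients(embedding, grad_fn, steps=30):
    baseline = np.zeros_like(embedding)    # typical zero-vector baseline
    alphas = np.linspace(0.0, 1.0, steps)  # interpolation coefficients
    grads = np.stack([
        grad_fn(baseline + a * (embedding - baseline)) for a in alphas
    ])
    # Average the local gradients, then scale by (input - baseline).
    return (embedding - baseline) * grads.mean(axis=0)
```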
+
+
+
Integrated gradients can be configured to explain a specific
+ class, to normalize the data during analysis, and to interpolate a
+ given number of steps (between 5 and 100).
+
+
Blackbox Methods
+
Some models do not provide tokens or token-level gradients, effectively making
+them blackboxes. LIME can be used with
+these models. LIME works by generating a set of perturbed inputs, generally by
+dropping out or masking tokens, and training a local linear model to reconstruct
+the original model's predictions. The weights of this linear model are treated
+as the salience values.
+
LIME has two limitations, compared to gradient-based methods:
+
+
it can be slow as it requires many evaluations of the model, and
+
it can be noisy on longer inputs where there are more tokens to ablate.
+
+
+We can increase the number of samples used for LIME within LIT to counter the
+potential noisiness; however, this comes at the cost of computation time.
+
+
+
LIME can be configured to explain a specific output field and/or
+ class, to use a specific masking token, and to use a specific seed for
+    its random number generator. The most frequently used configuration
+    parameters are the number of samples and kernel size, which can reduce
+ noise in the results, but also affect the time required for each run.
+
+
+Another interesting difference between the gradient-based methods and LIME lies
+in how they analyze the input. The gradient-based methods use the model’s
+tokenizer, which splits up words into smaller constituents, whereas LIME splits
+the text into words at whitespaces. Thus, LIME’s word-level results are often
+incomparable with the token-level results from other methods, as you can see in
+the salience maps below.
+
+
+
LIME splits the input sentence based on whitespace and punctuation
+ characters, whereas the other methods use the model's tokenizer to
+ separate the input into its constituent parts.
+
+
Single example use-case: Interpreting the salience maps module
+
+Let’s take a concrete example and walk through how we might use the salience maps
+module and counterfactual generators to analyze the behavior of the sst2-tiny
+model on the classification task.
+
First, let’s refer back to our heuristic for choosing appropriate methods.
+Because sst2-tiny does not have an LSTM architecture, we shouldn't rely too
+much on Grad · Input. So, we are left with Grad L2 Norm, Integrated Gradients
+and LIME to base our decisions on.
+
To gain some confidence in our heuristic, we look for examples where Grad ·
+Input performs poorly compared to the other methods. There are quite a few in
+the dataset, for example the sentence below where Grad · Input predicts
+completely opposite salience scores to its counterparts.
+
+
+
An example of how Grad · Input can perform poorly—all pink
+    values, the opposite of what other methods found—on certain input
+ and model architecture combinations.
+
+
Use Case 1: Sexism Analysis with Counterfactuals
+
Coming back to our use-case, we want to investigate if the model displays sexist
+behavior for a particular input sentence. We take a datapoint with a negative
+sentiment label, which talks about the performance of an actress in the movie.
+
The key words/tokens (based on salience scores across the three chosen methods)
+in this sentence are “hampered”, “lifetime-channel”, “lead”, “actress”, “her”
+and “depth”. The only words among these that are related to gender are
+“actress” and “her”. Both words receive significant weight
+from Grad L2 Norm and IG, and are assigned positive scores (the IG scores are
+slightly stronger than the Grad L2 Norm scores), indicating that the gender of
+the person is helping the model be confident in its prediction that this
+sentence is a negative review. However, for LIME, the salience scores for these
+two words are slightly negative, indicating that the gender of the person is
+actually causing a small decrease in model confidence in the prediction of this
+being a negative review. Even with this small disparity between the token-based
+and blackbox methods on the gender-related words in the sentence, it turns out
+that these are not the most important words. “Hampered”, “lifetime-channel” and
+“plot” are the dominating words/tokens for this particular example in helping
+the model make its decision. We still want to explore if reversing the gender
+might change this. Would it make the model give more or less importance to other
+tokens or the tokens we replaced? Would it change the model prediction
+confidence scores?
+
+To do this, we select our datapoint of interest and generate a counterfactual
+example in the Datapoint Editor, which is located right beside the Data Table
+in the UI, replacing "actress" with "actor" and "her" with "his". An
+alternative approach is to use the Word replacer under the Counterfactuals tab
+in the bottom half of the LIT app to achieve the same task.
+If our model is predicting a negative sentiment due to sexist influences towards
+“actress” or “her”, then the hypothesis is that it should show opposite
+sentiments if we flip those key tokens.
+
+
+
Manually generating a counterfactual example in the Datapoint
+ Editor, in this case changing "actress" to "actor" and "her" to "his",
+ does not induce much change in the token salience scores.
+
+
However, it turns out that there is very minimal (and hence negligible) change
+in the salience score values of any of the tokens. The model doesn't change its
+prediction either. It still predicts this to be a negative review sentiment with
+approximately the same prediction confidence. This indicates that at least for
+this particular example, our model isn’t displaying sexist behavior and is
+actually making its prediction based on key tokens in the sentence which are not
+related to the gender of the actress/actor.
+
Use Case 2: Pairwise Comparisons
+
Let’s take another example. This time we consider the sentence “a sometimes
+tedious film” and generate three counterfactuals, first by replacing the two
+words “sometimes” and “tedious” with their respective antonyms one-by-one and
+then together to observe the changes in predictions and salience.
+
To create the counterfactuals, we can simply use the Datapoint Editor which is
+located right beside the Data Table in the UI. We can just select our data point
+of interest (data point 6), and then replace the words we are interested in with
+the respective substitutes. Then we assign a label to the newly created
+sentence and add it to our data. For this particular example, we are assigning 0
+when "tedious" appears and 1 when "exciting" appears in the sentence. An
+alternative to this approach is to use the Word replacer under the
+Counterfactuals tab in the bottom half of the LIT app to achieve the same task.
+
+
+
The Data Table and Datapoint Editor modules showing the three
+ manually generated counterfactuals that will be used to explore pairwise
+ comparisons of salience results.
+
+
We can pin the original sentence in the data table and then cycle through the
+three available pairs by selecting each of the new sentences as our primary
+selection. This will give us a comparison-type output in the Salience Maps
+module between the pinned and the selected examples.
+
+
+
A table of the salience scores for each token in the inputs.
+
+
When we replace “sometimes” with “often”, it gets a negative score of almost
+equal magnitude (reversing polarity) from LIME, which makes sense because
+“often” makes the next word in the sentence more impactful, linguistically. The
+model prediction doesn’t change either, and this new review is still classified
+as having a negative sentiment.
+
+
+
Replacing "sometimes" with "often" had a minimal impact on the
+ gradient-based salience interpreters, but it did flip the polarity of
+ that token in the LIME results.
+
+
+On replacing “tedious” with “exciting”, the salience for “sometimes” changes
+from a positive score to a negative one in the LIME output. In the IG output,
+“sometimes” changes from a strong positive score to a weak positive score. These
+changes are also justified because in this new sentence “sometimes” counters the
+positive effect of the word “exciting”. The main negative word in our original
+datapoint was “tedious” and by replacing this with a positive word “exciting”,
+the model’s classification of this new sentence also changes and the new
+sentence is classified as positive with a very high confidence score.
+
+
+
Replacing "tedious" with "exciting" had a substantial impact on
+ the salience interpreters that output signed results, but only a minimal
+ impact on the Grad L2 Norm interpreter.
+
+
And finally, when we replace both “sometimes tedious” with “often exciting”, we
+get strong positive scores from both LIME and IG, which is in line with the
+overall strong positive sentiment of the sentence. The model predicts this new
+sentence as positive sentiment, and the confidence score for this prediction is
+slightly higher than the previous sentence where instead of “often” we had used
+“sometimes”. This makes sense as well because “often” enhances the positive
+sentiment slightly more than using “sometimes” in a positive review.
+
+
+
Replacing both "sometimes" and "tedious" has a substantial impact
+ on all salience interpreters, attenuating some results, accentuating
+ others, and in the case of Grad · Input, demonstrating how this
+ counterfactual captures an opposing sentiment to the original.
+
+
+In this second example, we mostly based our observations on LIME and IG, because
+we could observe visual changes directly from the outputs of these methods. Grad
+L2 Norm outputs were comparatively inconclusive, highlighting the need to
+select appropriate methods and compare results between them. The model
+predictions were in
+line with our expected class labels and the confidence scores for predictions on
+the counterfactuals could be justified using salience scores assigned to the new
+tokens.
+
Use Case 3: Quality Assurance
+
+A real-life use case for the salience maps module is Quality Assurance.
+For example, if there is a failure in production (e.g., wrong results for a search
+query), we know the text input and the label the model predicted. We can use LIT
+Salience Maps to debug this failure and figure out which tokens were most
+influential in the prediction of the wrong label, and which alternative labels
+could have been predicted (i.e., is there one clear winner, or are there a few
+that are roughly the same?). Once we are done with debugging using LIT, we can
+make the necessary changes to the model or training data (e.g., adding fail-safes
+or checks) to solve the production failure.
+
Conclusion
+
+Three gradient-based salience methods and one blackbox method are provided out
+of the box to LIT users who need to use these post-hoc interpretations to make
+sense of their language model’s predictions. This diverse array of built-in
+techniques can be used in combination with other LIT modules like
+counterfactuals to support robust exploration of a model's behavior, as
+illustrated in this tutorial. And as always, LIT strives to enable users to
+add their own salience interpreters
+to allow for a wider variety of use cases beyond these default capabilities!
+
+
+
+
+
time to read
+
15 minutes
+
takeaways
+
Learn how to use salience maps for text data in LIT.
+
diff --git a/docs/tutorials/tour/index.html b/docs/tutorials/tour/index.html
index 7aa59d3c..0434e68a 100644
--- a/docs/tutorials/tour/index.html
+++ b/docs/tutorials/tour/index.html
@@ -14,7 +14,7 @@
gtag('config', 'G-Q74F5RJLXB');
- A Quick Tour of the Language Interpretability Tool
+ A Quick Tour of the Learning Interpretability Tool
@@ -38,7 +38,7 @@
- Language Interpretability Tool
+ Learning Interpretability Tool
@@ -55,7 +55,7 @@
- The Language Interpretability Tool (LIT) is a modular and extensible tool to interactively analyze and debug a variety of NLP models. LIT brings together common machine learning performance checks with interpretability methods specifically designed for NLP.
+ The Learning Interpretability Tool (LIT) is a modular and extensible tool to interactively analyze and debug a variety of NLP models. LIT brings together common machine learning performance checks with interpretability methods specifically designed for NLP.
Building blocks - modules, groups, and workspaces
Modules, groups, and workspaces form the building blocks of LIT. Modules are discrete windows in which you can perform a specific set of tasks or analyses. Workspaces display combinations of modules known as groups, so you can view different visualizations and interpretability methods side-by-side.
- Above: Building blocks of the Language Interpretability Tool: (1) Modules, (2) Groups, (3) Static workspace, (4) Group-based workspace.
+ Above: Building blocks of the Learning Interpretability Tool: (1) Modules, (2) Groups, (3) Static workspace, (4) Group-based workspace.
LIT is divided into two workspaces - a Main workspace in the upper half of the interface, and a Group-based workspace in the lower half.
The Main workspace contains core modules that play a role in many analyses. By default, these include:
@@ -131,11 +131,11 @@
Using Modules
time to read
7 minutes
takeaways
- Get familiar with the interface of the Language Interpretability Tool.
+ Get familiar with the interface of the Learning Interpretability Tool.
diff --git a/documentation/api.md b/documentation/api.md
index 53f41ec1..714c85cb 100644
--- a/documentation/api.md
+++ b/documentation/api.md
@@ -1,17 +1,17 @@
# LIT Python API
-
+
## Design Overview
-LIT is a modular system, consisting of a collection of backend components
-(written in Python) and frontend modules (written in TypeScript). Most users
-will develop against the Python API, which is documented below and allows LIT to
-be extended with custom models, datasets, metrics, counterfactual generators,
-and more. The LIT server and components are provided as a library which users
-can use through their own demo binaries or via Colab.
+LIT is a modular system, comprising a collection of backend components (written
+in Python) and frontend modules (written in TypeScript). Most users will develop
+against the Python API, which is documented below and allows LIT to be extended
+with custom models, datasets, metrics, counterfactual generators, and more. The
+LIT server and components are provided as a library which users can use through
+their own demo binaries or via Colab.
The components can also be used as regular Python classes without starting a
server; see [below](#using-components-outside-lit) for details.
@@ -26,15 +26,15 @@ simplifies component design and allows interactive use of large models like BERT
or T5.
The frontend is a stateful single-page app, built using
-[lit-element](https://lit-element.polymer-project.org/)[^1] for modularity and
-[MobX](https://mobx.js.org/) for state management. It consists of a core UI
-framework, a set of shared "services" which manage persistent state, and a set
-of independent modules which render visualizations and support user interaction.
-For more details, see the [UI guide](./ui_guide.md) and the
+[Lit](https://lit.dev/)[^1] for modularity and [MobX](https://mobx.js.org/) for
+state management. It consists of a core UI framework, a set of shared "services"
+which manage persistent state, and a set of independent modules which render
+visualizations and support user interaction. For more details, see the
+[UI guide](./ui_guide.md) and the
[frontend developer guide](./frontend_development.md).
-[^1]: Naming is just a happy coincidence; the Language Interpretability Tool is
- not related to the lit-html or lit-element projects.
+[^1]: Naming is just a happy coincidence; the Learning Interpretability Tool is
+ not related to the Lit projects.
## Adding Models and Data
@@ -70,6 +70,24 @@ and [`Model`](#models) classes implement this, and provide metadata (see the
For pre-built `demo.py` examples, check out
https://github.com/PAIR-code/lit/tree/main/lit_nlp/examples
+### Validating Models and Data
+
+Datasets and models can optionally be validated by LIT to ensure that dataset
+examples match their spec and that model output values match their spec.
+This can be very helpful during development of new model and dataset wrappers
+to ensure correct behavior in LIT.
+
+At LIT server startup, the `validate` flag can be used to enable validation.
+There are three modes:
+
+* `--validate=first` will check the first example in each dataset.
+* `--validate=sample` will validate a sample of 5% of each dataset.
+* `--validate=all` will run validation on all examples from all datasets.
+
+Additionally, if using LIT datasets and models outside of the LIT server,
+validation can be called directly through the
+[`validation`](../lit_nlp/lib/validation.py) module.
+
## Datasets
Datasets ([`Dataset`](../lit_nlp/api/dataset.py)) are
@@ -282,7 +300,7 @@ You can also implement multi-headed models this way: simply add additional
output fields for each prediction (such as another `MulticlassPreds`), and
they'll be automatically detected.
-See the [type system documentation](#type-system) for more details on avaible
+See the [type system documentation](#type-system) for more details on available
types and their semantics.
### Optional inputs
@@ -313,6 +331,55 @@ use these and bypass the tokenizer:
lit_types.CategoryLabel(required=False)`), though these can also be omitted from
the input spec entirely if they are not needed to compute model outputs.
+## UI Layouts
+
+You can also specify one or more custom layouts for the frontend UI. To do this,
+pass a dict of `LitCanonicalLayout` objects in `layouts=` when initializing the
+server. These objects represent a tabbed layout of modules, such as:
+
+```python
+LM_LAYOUT = layout.LitCanonicalLayout(
+ upper={
+ "Main": [
+ modules.EmbeddingsModule,
+ modules.DataTableModule,
+ modules.DatapointEditorModule,
+ ]
+ },
+ lower={
+ "Predictions": [
+ modules.LanguageModelPredictionModule,
+ modules.ConfusionMatrixModule,
+ ],
+ "Counterfactuals": [modules.GeneratorModule],
+ },
+ description="Custom layout for language models.",
+)
+```
+
+You can pass this to the server as:
+
+```python
+lit_demo = dev_server.Server(
+ models,
+ datasets,
+ # other args...
+ layouts={"lm": LM_LAYOUT},
+ **server_flags.get_flags())
+return lit_demo.serve()
+```
+
+For a full example, see
+[`lm_demo.py`](../lit_nlp/examples/lm_demo.py). You
+can see the default layouts as well as the list of available modules in
+[`layout.py`](../lit_nlp/api/layout.py).
+
+To use a specific layout for a given LIT instance, pass the key (e.g., "simple"
+or "default" or the name of a custom layout defined in Python) as a server flag
+when initializing LIT (`--default_layout=`). Commonly, this is done
+using `FLAGS.set_default('default_layout', 'my_layout_name')`. The layout can
+also be set on-the-fly using the `layout=` URL param, which will take precedence.
+
## Interpretation Components
Backend interpretation components include metrics, salience maps, visualization
@@ -441,8 +508,8 @@ on the unpacked values.
### Generators
Conceptually, a generator is just an interpreter that returns new input
-examples. These may depend on the input only, as for techniques such as
-backtranslation, or can involve feedback from the model, such as for adversarial
+examples. These may depend on the input only, as for techniques such as back-
+translation, or can involve feedback from the model, such as for adversarial
attacks.
The core generator API is:
@@ -472,7 +539,7 @@ class Generator(Interpreter):
Where the output is a list of lists: a set of generated examples for each input.
For convenience, there is also a `generate()` method which takes a single
example and returns a single list; we provide the more general `generate_all()`
-API to support model-based generators (such as backtranslation) which benefit
+API to support model-based generators (such as back-translation) which benefit
from batched requests.
As with other interpreter components, a generator can take custom arguments
@@ -496,7 +563,7 @@ backtranlator generator if you pass it as a generator in the Server constructor.
Interpreter components support an optional `config` option to specify run-time
options, such as the number of samples for LIME or the pivot languages for
-backtranslation. LIT provides a simple DSL to define these options, which will
+back-translation. LIT provides a simple DSL to define these options, which will
auto-generate a form on the frontend. The DSL uses the same
[type system](#type-system) as used to define data and model outputs, and the
`config` argument will be passed a dict with the form values.
@@ -514,9 +581,9 @@ For example, the following spec:
}
```
-will give this form to configure backtranslation:
+will give this form to configure back-translation:
-![Backtranslation Config Form](./images/api/backtranslation-form-example.png)
+![Back-translation Config Form](./images/api/backtranslation-form-example.png)
Currently `config_spec()` is supported only for generators and salience methods,
though any component can support the `config` argument to its `run()` method,
@@ -525,11 +592,12 @@ which can be useful if
The following [types](#available-types) are supported (see
[interpreter_controls.ts](../lit_nlp/client/elements/interpreter_controls.ts)):
+
* `Scalar`, which creates a slider for setting a numeric option. You can
specify the `min_val`, `max_val`, `default`, and `step`, values for the
slider through arguments to the `Scalar` constructor.
-* `Boolean`, which creates a checkbox, with a `default` value to be set in
- the constructor.
+* `Boolean` (`BooleanLitType` in TypeScript), which creates a checkbox, with
+ a `default` value to be set in the constructor.
* `CategoryLabel`, which creates a dropdown with options specified in the
`vocab` argument.
* `SparseMultilabel`, which creates a series of checkboxes for each option
@@ -539,13 +607,14 @@ The following [types](#available-types) are supported (see
* `Tokens`, which creates an input text box for entry of multiple,
comma-separated strings which are parsed into a list of strings to be
supplied to the interpreter.
-* `FieldMatcher`, which acts like a `CategoryLabel` but where the vocab is
- automatically populated by the names of fields from the data or model spec.
- For example, `FieldMatcher(spec='dataset', types=['TextSegment'])` will give
- a dropdown with the names of all `TextSegment` fields in the dataset.
-* `MultiFieldMatcher` is similar to `FieldMatcher` except it gives a set of
- checkboxes to select one or more matching field names. The returned value in
- `config` will be a list of string values.
+* `SingleFieldMatcher`, which acts like a `CategoryLabel` but where the vocab
+ is automatically populated by the names of fields from the data or model
+ spec. For example, `SingleFieldMatcher(spec='dataset',
+ types=['TextSegment'])` will give a dropdown with the names of all
+ `TextSegment` fields in the dataset.
+* `MultiFieldMatcher` is similar to `SingleFieldMatcher` except it gives a set
+ of checkboxes to select one or more matching field names. The returned value
+ in `config` will be a list of string values.
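+
+Putting these together, a component's `config_spec()` might look like the
+following sketch. The option names here are illustrative, not part of LIT's
+API; each entry auto-generates the corresponding frontend control.
+
+```python
+from lit_nlp.api import types as lit_types
+
+def config_spec(self) -> lit_types.Spec:
+  return {
+      # Slider from 8 to 512, defaulting to 256.
+      "num_samples": lit_types.Scalar(
+          min_val=8, max_val=512, default=256, step=8),
+      # Checkbox, checked by default.
+      "normalize": lit_types.Boolean(default=True),
+      # Dropdown with a fixed vocabulary.
+      "mask_token": lit_types.CategoryLabel(vocab=["[MASK]", "[UNK]"]),
+      # Checkboxes for all TextSegment fields in the model's input spec.
+      "fields": lit_types.MultiFieldMatcher(
+          spec="input", types=["TextSegment"]),
+  }
+```
+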
The field matching controls can be useful for selecting one or more fields to
+operate on. For example, to choose which input fields to perturb, or which output
@@ -589,15 +658,16 @@ lime.run([dataset.examples[0]], model, dataset)
# will return {"tokens": ..., "salience": ...} for each example given
```
-For a full working example in Colab, see https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_Components_Example.ipynb.
+For a full working example in Colab, see https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_components_example.ipynb.
## Type System
Input examples and model outputs in LIT are flat records (i.e. Python `dict` and
JavaScript `object`). Field names (keys) are user-specified strings, and we use
a system of "specs" to describe the types of the values. This spec system is
-semantic: in addition to defining the datatype (string, float, etc.), spec types
-define how a field should be interpreted by LIT components and frontend modules.
+semantic: in addition to defining the data type (string, float, etc.), spec
+types define how a field should be interpreted by LIT components and frontend
+modules.
For example, the [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) dataset
might define the following spec:
@@ -674,7 +744,9 @@ to provide access to model internals. For a more detailed example, see the
The actual spec types, such as `MulticlassLabel`, are simple dataclasses (built
using [`attr.s`](https://www.attrs.org/en/stable/)). They are defined in Python,
-but are available in the [TypeScript client](client.md) as well.
+but are available in
+[TypeScript](../lit_nlp/client/lib/lit_types.ts) as
+well.
[`utils.find_spec_keys()`](../lit_nlp/lib/utils.py)
(Python) and
@@ -729,6 +801,10 @@ Values can be plain data, NumPy arrays, or custom dataclasses - see
[serialize.py](../lit_nlp/api/serialize.py) for
further detail.
+*Note: `String`, `Boolean` and `URL` types in Python are represented
+as `StringLitType`, `BooleanLitType` and `URLLitType` in TypeScript to avoid
+naming collisions with protected TypeScript keywords.*
+
### Conventions
The semantics of each type are defined individually, and documented in
@@ -768,6 +844,12 @@ to `dev_server.Server()`. These include:
the section below for available layouts.
* `demo_mode`: demo / kiosk mode, which disables some functionality (such as
save/load datapoints) which you may not want to expose to untrusted users.
+* `inline_doc`: a markdown string that will be rendered in a documentation
+ module in the main LIT panel.
+* `onboard_start_doc`: a markdown string that will be rendered as the first
+ panel of the LIT onboarding splash-screen.
+* `onboard_end_doc`: a markdown string that will be rendered as the last
+ panel of the LIT onboarding splash-screen.
For detailed documentation, see
[server_flags.py](../lit_nlp/server_flags.py).
diff --git a/documentation/components.md b/documentation/components.md
index 4bc2927e..ae38dfd3 100644
--- a/documentation/components.md
+++ b/documentation/components.md
@@ -1,6 +1,6 @@
# Components and Features
-
+
@@ -115,8 +115,9 @@ implemented with the `MulticlassPreds` and `CategoryLabel` types.
field should set the `parent=` attribute to the name of this field.
* A negative class can be designated using the `null_idx` attribute of
`MulticlassPreds` (most commonly, `null_idx=0`), and metrics such as
- precision, recall, and F1 will be computed for the remaining classes. For an
- example, see the
+  precision, recall, and F1 will be computed for the remaining classes. AUC and
+ AUCPR will be computed for binary classification tasks. For an example, see
+ the
[comment toxicity model](../lit_nlp/examples/models/glue_models.py?l=518&rcl=386779180).
* If `null_idx` is set and there is only one other class, the other class
(often, class `1`) is treated as a positive class, and the LIT UI can be
@@ -215,13 +216,21 @@ and otherwise to different parts of the input.
### Tabular data
-While many uses of LIT involve natural language inputs, data can also contain or
-consist of categorical or scalar features using the `CategoryLabel` or `Scalar`
-types. LIT can be used as a replacement for the [What-If Tool](https://whatif-tool.dev),
-containing a similar feature set but with more extensibility.
+LIT can be used as a replacement for the [What-If Tool](https://whatif-tool.dev)
+but with more extensibility, when working with predictions over tabular data.
-* For a demo using a penguin stats dataset/binary classification task, see
- [lit_nlp/examples/penguin_demo.py](../lit_nlp/examples/penguin_demo.py).
+Some interpreters, such as Kernel SHAP, require models that use tabular data. In
+these cases, LIT validates model compatibility by checking that:
+
+* The model inputs (`input_spec()`) are exclusively categorical
+ (`CategoryLabel`) or numeric (`Scalar`), and none of these are marked as
+ optional (`required=False`).
+* The model outputs include at least one classification (`MulticlassPreds`),
+ regression (`RegressionScore` or `Scalar`), or multilabel
+ (`SparseMultilabel`) field.
+
+For a demo using a penguin stats dataset/binary classification task, see
+[`penguin_demo.py`](../lit_nlp/examples/penguin_demo.py).
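+
+As a sketch (feature names and vocabularies here are illustrative), specs that
+satisfy these checks might look like:
+
+```python
+from lit_nlp.api import types as lit_types
+
+def input_spec(self) -> lit_types.Spec:
+  # Exclusively categorical or numeric features, none marked optional.
+  return {
+      "island": lit_types.CategoryLabel(vocab=["Biscoe", "Dream", "Torgersen"]),
+      "body_mass_g": lit_types.Scalar(),
+  }
+
+def output_spec(self) -> lit_types.Spec:
+  # At least one classification, regression, or multilabel field.
+  return {
+      "preds": lit_types.MulticlassPreds(
+          vocab=["Adelie", "Chinstrap", "Gentoo"], parent="species"),
+  }
+```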
### Images
@@ -259,8 +268,9 @@ soon. Available methods include:
### Gradient Norm
This is a simple method, in which salience scores are proportional to the L2
-norm of the gradient, i.e. the score for token $$i$$ is $$ S(i) \propto
-||\nabla_{x_i} \hat{y}||_2 $$.
+norm of the gradient, i.e. the score for token $i$ is:
+
+$$S(i) \propto ||\nabla_{x_i} \hat{y}||_2$$
To enable this method, your model should, as part of the
[output spec and `predict()` implementation](./api.md#models):
@@ -270,22 +280,25 @@ To enable this method, your model should, as part of the
* Return a `TokenGradients` field with the `align` attribute pointing to the
name of the `Tokens` field (i.e. `align="tokens"`). Values should be arrays
of shape `[num_tokens, emb_dim]` representing the gradient
- $$\nabla_{x} \hat{y}$$ of the embeddings with respect to the prediction
- $$\hat{y}$$.
+ $\nabla_{x} \hat{y}$ of the embeddings with respect to the prediction
+ $\hat{y}$.
Because LIT is framework-agnostic, the model code is responsible for performing
the gradient computation and returning the result as a NumPy array. The choice
-of $$\hat{y}$$ is up to the developer; typically for regression/scoring this is
+of $\hat{y}$ is up to the developer; typically for regression/scoring this is
the raw score and for classification this is the score of the predicted (argmax)
class.
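+
+Given those arrays, the per-token scores are simple to compute. A minimal NumPy
+sketch (the final normalization is a presentational choice, not part of the
+definition):
+
+```python
+import numpy as np
+
+def grad_l2_salience(token_grads):
+    """token_grads: [num_tokens, emb_dim] array from a TokenGradients field."""
+    scores = np.linalg.norm(token_grads, axis=-1)  # S(i) = ||grad_i||_2
+    return scores / scores.sum()  # normalize for display as unsigned salience
+```
+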
### Gradient-dot-Input
In this method, salience scores are proportional to the dot product of the input
-embeddings and their gradients, i.e. for token $$i$$ we take $$ S(i) \propto x_i
-\cdot \nabla_{x_i} \hat{y}$$. Compared to grad-norm, this gives directional
+embeddings and their gradients, i.e. for token $i$ we compute:
+
+$$S(i) \propto x_i \cdot \nabla_{x_i} \hat{y}$$
+
+Compared to grad-norm, this gives directional
scores: a positive score can be interpreted as that token having a positive
-influence on the prediction $$\hat{y}$$, while a negative score suggests that
+influence on the prediction $\hat{y}$, while a negative score suggests that
the prediction would be stronger if that token was removed.
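+
+A minimal NumPy sketch of this computation, given the same per-token arrays as
+for grad-norm:
+
+```python
+import numpy as np
+
+def grad_dot_input_salience(token_embs, token_grads):
+    """Both inputs are [num_tokens, emb_dim]; returns signed per-token scores."""
+    return np.einsum("te,te->t", token_embs, token_grads)  # S(i) = x_i . grad_i
+```
+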
To enable this method, your model should, as part of the
@@ -294,13 +307,13 @@ To enable this method, your model should, as part of the
* Return a `Tokens` field with values (as `List[str]`) containing the
tokenized input.
* Return a `TokenEmbeddings` field with values as arrays of shape
- `[num_tokens, emb_dim]` containing the input embeddings $$x$$.
+ `[num_tokens, emb_dim]` containing the input embeddings $x$.
* Return a `TokenGradients` field with the `align` attribute pointing to the
name of the `Tokens` field (i.e. `align="tokens"`), and the `grad_for`
attribute pointing to the name of the `TokenEmbeddings` field. Values should
be arrays of shape `[num_tokens, emb_dim]` representing the gradient
- $$\nabla_{x} \hat{y}$$ of the embeddings with respect to the prediction
- $$\hat{y}$$.
+ $\nabla_{x} \hat{y}$ of the embeddings with respect to the prediction
+ $\hat{y}$.
As with grad-norm, the model should return embeddings and gradients as NumPy
arrays. The LIT `GradientDotInput` component will compute the dot products and
@@ -326,7 +339,7 @@ needed for grad-dot-input, and also to *accept* modified embeddings as input.
* The model should have an additional field ("grad_class", below) which is
used to pin the gradients to a particular target class. This is necessary
because we want to integrate gradients with respect to a single target
- $$\hat{y}$$, but the argmax prediction may change over the integration path.
+ $\hat{y}$, but the argmax prediction may change over the integration path.
This field can be any type, though for classification models it is typically
a `CategoryLabel`. The value of this on the original input (usually, the
argmax class) is stored and fed back in to the model during integration.
@@ -381,6 +394,21 @@ can increase the number of samples:
LIME works out-of-the-box with any classification (`MulticlassPreds`) or
regression/scoring (`RegressionScore`) model.
+### Salience Clustering
+
+LIT includes a basic implementation of the salience clustering method from
+[Ebert et al. 2022](https://arxiv.org/abs/2211.05485), which uses k-means on a
+salience-weighted bag-of-words representation to find patterns in model
+behavior. This method is available using any of the token-based salience methods
+above, and if enabled will appear in the "Salience Clustering" tab:
+
+![Salience clustering UI](./images/components/salience-clustering.png)
+
+To run clustering, select a group of examples or the entire dataset, choose a
+salience method, and run using the "Apply" button. The result will be a set of
+top tokens for each cluster, as in Table 6 of
+[the paper](https://arxiv.org/pdf/2211.05485.pdf).
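+
+Conceptually, the method reduces to k-means over salience-weighted bag-of-words
+vectors. A rough sketch of that idea (not the LIT component API):
+
+```python
+import numpy as np
+from sklearn.cluster import KMeans
+
+def cluster_salience_maps(salience_maps, k=3):
+    """salience_maps: one (tokens, scores) pair per example."""
+    vocab = sorted({tok for tokens, _ in salience_maps for tok in tokens})
+    index = {tok: i for i, tok in enumerate(vocab)}
+    bows = np.zeros((len(salience_maps), len(vocab)))
+    for row, (tokens, scores) in enumerate(salience_maps):
+        for tok, score in zip(tokens, scores):
+            bows[row, index[tok]] += score  # salience-weighted bag of words
+    return KMeans(n_clusters=k, n_init=10).fit_predict(bows)
+```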
+
## Pixel-based Salience
LIT also supports pixel-based salience methods, for models that take images as
@@ -410,9 +438,9 @@ your model should, as part of the
A variety of image saliency techniques are implemented for models that return
image gradients, through use of the
-[PAIR-code saliency library](https://github.com/PAIR-code/saliency), including *
-Integrated gradients * Guided integrated gradients * Blurred integrated
-gradients * XRAI
+[PAIR-code saliency library](https://github.com/PAIR-code/saliency), including
+integrated gradients, guided integrated gradients, blurred integrated gradients,
+and XRAI.
Each of these techniques returns a saliency map image as a base64-encoded string
through the `ImageSalience` type.
@@ -603,6 +631,33 @@ datapoints, giving a global view of feature effects.
![Partial Dependence Plots Module](./images/components/lit-pdps.png)
+### Dive
+
+Dive is a visualization module, inspired by our prior work on
+[Facets Dive](https://pair-code.github.io/facets/) and its use in the
+[What-If Tool](https://pair-code.github.io/what-if-tool/), that enables
+exploration of data subsets grouped by feature values.
+
+![Dive module](./images/components/dive.png)
+
+Data are displayed in a matrix of groups based on feature values, with each
+group containing the datapoints at the intersection of the feature values for
+that column and row. Use the drop-downs at the top to select the feature to use
+for the rows and columns in the matrix. You can use the "Color By" drop-down in
+the main toolbar to change the feature by which datapoints are colored in the
+matrix.
+
+This visualization is powered by
+[Megaplot](https://github.com/PAIR-code/megaplot), which allows it to support up
+to 100k datapoints. Dive supports mouse-based zoom (scroll) and pan (drag)
+interactions to help you navigate these very large datasets. You can also use
+the "zoom in", "zoom out", and "reset view" buttons in the module toolbar to
+help navigate with more precision.
+
+Dive is currently integrated in the
+[Penguins demo](https://pair-code.github.io/lit/demos/penguins.html), and will
+be supported in other demos in future releases.
+
## TCAV
Many interpretability methods provide importance values per input feature (e.g,
@@ -672,8 +727,8 @@ CAVs using the selected concept slice and random splits of the same size from
the remainder of the dataset. We also generate 15 random CAVs using random
splits against random splits. We then do a t-test to check if these two sets of
scores are from the same distribution and reject CAVs as insignificant if the
-p-value is greater than 0.05. (If this happens, a warning is displayed in place of
-the TCAV score in the UI.)
+p-value is greater than 0.05. (If this happens, a warning is displayed in place
+of the TCAV score in the UI.)
For relative TCAV, users would ideally test concepts with at least ~100 examples
each so we can perform ~15 runs on unique subsets. In practice, users may not
diff --git a/documentation/demos.md b/documentation/demos.md
index c031abf0..13ca9f73 100644
--- a/documentation/demos.md
+++ b/documentation/demos.md
@@ -1,6 +1,6 @@
# Demos
-
@@ -162,3 +162,4 @@ https://pair-code.github.io/lit/tutorials/coref
* Showing use of LIT on image data.
* Explore results of multiple gradient-based image saliency techniques in the
Salience Maps module.
+
diff --git a/documentation/docker.md b/documentation/docker.md
index e98382a6..26e1098f 100644
--- a/documentation/docker.md
+++ b/documentation/docker.md
@@ -1,39 +1,145 @@
# Running LIT in a Docker container
-
+
Users might want to deploy LIT onto servers for public-facing, long-running
instances. This is how we host the LIT demos found on
-https://pair-code.github.io/lit/demos/. Specifically, we deploy containerized
-LIT instances through Google Cloud's Google Kubernetes Engine (GKE).
+https://pair-code.github.io/lit/demos/. This doc describes the basic usage of
+LIT's built-in demos, how to integrate your custom demo into the default Docker
+image, and how to build your own image.
-The code required to deploy LIT as a containerized web app can be seen by
-looking at our masked language model demo.
+## Basic Usage
-First, let's look at the relevant code in
-[`lm_demo.py`](../lit_nlp/examples/lm_demo.py):
+LIT can be run as a containerized app using [Docker](https://www.docker.com/) or
+your preferred engine. This is how we run our
+[hosted demos](https://pair-code.github.io/lit/demos/).
-The `get_wsgi_app()` method is what is invoked by the Dockerfile. It sets the
-`server_type` to `"external"`, constructs the LIT `Server` instance, and returns
-the result of it's `serve()` method which is the underlying `LitApp` WSGI
-application.
-
-Now, let's explore the [`Dockerfile`](https://github.com/PAIR-code/lit/blob/main/Dockerfile):
-
-The Dockerfile installs all necessary dependencies for LIT and builds the
+We provide a basic
+[`Dockerfile`](../lit_nlp/Dockerfile) that you can
+use to build and run any of the demos in the `lit_nlp/examples` directory. The
+`Dockerfile` installs all necessary dependencies for LIT and builds the
front-end code from source. Then it runs [gunicorn](https://gunicorn.org/) as
the HTTP server, invoking the `get_wsgi_app()` method from our demo file to get
the WSGI app to serve. The options provided to gunicorn for our use-case can be
found in
[`gunicorn_config.py`](../lit_nlp/examples/gunicorn_config.py).
+You can find a reference implementation in
+[`glue_demo.py`](../lit_nlp/examples/glue_demo.py) or
+[`lm_demo.py`](../lit_nlp/examples/lm_demo.py).
+
+Use the following shell commands to build the default Docker image for LIT from
+the provided `Dockerfile`, and then run a container from that image. Comments
+are provided in-line to help explain what each step does.
+
+```shell
+# Build the docker image using the -t argument to name the image. Remember to
+# include the trailing . so Docker knows where to look for the Dockerfile.
+docker build -t lit-app .
+
+# Now you can run LIT as a containerized app using the following command. Note
+# that the last parameter to the run command is the value you passed to the -t
+# argument in the build command above.
+docker run --rm -p 5432:5432 lit-app
+```
+
+The image above defaults to launching the GLUE demo on port 5432, but you can
+override this using the DEMO_NAME and DEMO_PORT environment variables, as shown
+below.
+
+```shell
+# DEMO_NAME is used to complete the Python module path
+#
+# "lit_nlp.examples.$DEMO_NAME"
+#
+# Therefore, valid values for DEMO_NAME are Python module paths in the
+# lit_nlp/examples directory, such as
+#
+# * direct children -- glue_demo, lm_demo, image_demo, t5_demo, etc.
+# * And nested children -- coref.coref_demo, is_eval.is_eval_demo, etc.
+docker run --rm -p 5432:5432 -e DEMO_NAME=lm_demo lit-app
+
+# Use the DEMO_PORT environment variable to change the port that LIT uses in
+# the container. Be sure to also change the -p option to map the container's
+# DEMO_PORT to a port on the host system.
+docker run --rm -p 2345:2345 -e DEMO_PORT=2345 lit-app
+
+# Bringing this all together, you can run multiple LIT apps in separate
+# containers on your machine using the combination of the DEMO_NAME and
+# DEMO_PORT arguments, and docker run with the -d flag to run the container in
+# the background.
+docker run -d -p 5432:5432 -e DEMO_NAME=t5_demo lit-app
+docker run -d -p 2345:2345 -e DEMO_NAME=lm_demo -e DEMO_PORT=2345 lit-app
+```
+
+## Integrating Custom LIT Instances with the Default Docker Image
+
+Many LIT users create their own custom LIT server script to demo or serve, which
+involves creating an executable Python module with a `main()` method, as
+described in the [Python API docs](./api.md#adding-models-and-data).
+
+These custom server scripts can be easily integrated with LIT's default image as
+long as your server meets two requirements:
+
+1. Ensure your server script is located in the `lit_nlp/examples` directory (or
+ in a nested directory under `lit_nlp/examples`).
+2. Ensure that your server script defines a `get_wsgi_app()` function similar
+ to the minimal example shown below.
+
+```python
+def get_wsgi_app() -> Optional[dev_server.LitServerType]:
+ """Return WSGI app for container-hosted demos."""
+ # Set any flag defaults for this LIT instance
+ FLAGS.set_default("server_type", "external")
+ FLAGS.set_default("demo_mode", True)
+ # Parse flags before calling main()
+ unused = flags.FLAGS(sys.argv, known_only=True)
+ if unused:
+ logging.info("get_wsgi_app() called with unused args: %s", unused)
+ return main([])
+```
+
+Assuming your custom script meets the two requirements above, you can simply
+rebuild the default Docker image and run a container using the steps above,
+ensuring that you pass `-e DEMO_NAME=your_server_script_path_here` to the
+run command.
+
+A more detailed description of the `get_wsgi_app()` code can be found below.
+
+```python
+def get_wsgi_app() -> Optional[dev_server.LitServerType]:
+ """Returns a WSGI app for gunicorn to consume in container-hosted demos."""
+ # Start by setting any default values for flags your LIT instance requires.
+ # Here we set:
+ #
+ # server_type to "external" (required), and
+ # demo_mode to "True" (optional)
+ #
+ # You can add additional defaults as required for your use case.
+ FLAGS.set_default("server_type", "external")
+ FLAGS.set_default("demo_mode", True)
+
+ # Parse any parameters from flags before calling main(). All flags should
+  # be defined using one of absl's flags.DEFINE methods.
+ #
+ # Note the use of the known_only=True parameter here. This ensures that only
+  # those flags that have been defined using one of absl's flags.DEFINE methods
+ # will be parsed from the command line arguments in sys.argv. All unused
+ # arguments will be returned as a Sequence[str].
+ unused = flags.FLAGS(sys.argv, known_only=True)
+
+ # Running a LIT instance in a container based on the default Dockerfile and
+ # image will always produce unused arguments, because sys.argv contains the
+  # command and parameters used to run the gunicorn server. While not strictly
+ # required, we recommend logging these to the console, e.g., in case you need
+ # to verify the value of an environment variable.
+ if unused:
+ logging.info("get_wsgi_app() called with unused args: %s", unused)
+
+ # Always pass an empty list to main() inside of get_wsgi_app() functions, as
+ # absl apps are supposed to use absl.flags to define any and all flags
+ # required to run the app.
+ return main([])
+```
-Then, our container is built and deployed following the basics of the
-[GKE tutorial](https://cloud.google.com/kubernetes-engine/docs/tutorials/hello-app).
+## Building Your Own Image
-You can launch any of the built-in demos from the same Docker image. First,
-build the image with `docker build -t lit:latest .`. Running a container from
-this image, as with `docker run --rm -p 5432:5432 lit:latest`, will start
-the GLUE demo and mount it on port 5432.You cna change the demo by setting the
-`$DEMO_NAME` environment variable to one of the valid demo names, and you can
-change the port by setting the `$DEMO_PORT` environment variable. Remember to
-change the `-p` option to forward the container's port to the host.
+Coming soon.
diff --git a/documentation/faq.md b/documentation/faq.md
index 81e774b5..8d26dfb8 100644
--- a/documentation/faq.md
+++ b/documentation/faq.md
@@ -1,6 +1,6 @@
# Frequently Asked Questions
-
+
@@ -8,40 +8,48 @@
### Dataset Size
-Currently, LIT can comfortably handle around 10,000 datapoints, though with a
-couple caveats:
+LIT can comfortably handle 10k-100k datapoints, depending on the speed of the
+server (for hosting the model) and your local machine (for viewing the UI). When
+working with large datasets, there are a couple of caveats:
* LIT expects predictions to be available on the whole dataset when the UI
loads. This can take a while if you have a lot of examples or a larger model
- like BERT. In this case, you can pass `warm_start=1.0` to the server (or use
- `--warm_start=1.0`) to warm up the cache on server load.
-
-* If you're using the embedding projector - i.e. if your model returns any
- `Embeddings` fields to visualize - this runs in the browser using WebGL (via
- [ScatterGL](https://github.com/PAIR-code/scatter-gl)), and so may be slow on
- older machines if you have more than a few thousand points.
-
-We're hoping to scale the UI to support 50-100k points soon. In the mean time,
-you can use `Dataset.sample` or `Dataset.slice` to select a smaller number of
-examples to load. You can also pass individual examples to LIT through URL
-params, or load custom data files at runtime using the settings (⚙️) menu.
+ like BERT. In this case, we recommend adding the flag `--warm_start=1` (or
+ pass `warm_start=1` to the `Server` constructor in Python) to pre-compute
+ predictions before starting the server.
+
+* Datasets containing images may take a while to load. If full "native"
+  resolution is not needed (e.g., if the model operates on a smaller size
+ anyway, such as 256x256), then you can speed things up by resizing images in
+ your `Dataset` loading code.
+
+* LIT uses WebGL for the embedding projector (via
+ [ScatterGL](https://github.com/PAIR-code/scatter-gl)) and for the Scalars
+ and Dive modules (via [Megaplot](https://github.com/PAIR-code/megaplot)),
+ which may be slow on older machines if you have more than a few thousand
+ points.
+
+If you have more data, you can use `Dataset.sample` or `Dataset.slice` to select
+a smaller number of examples to load. You can also pass individual examples to
+LIT [through URL params](#sending-examples-from-another-tool), or load custom
+data files at runtime using the settings (⚙️) menu.
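+
+For example (a sketch, assuming `dataset` is a LIT `Dataset` instance):
+
+```python
+# Select a subset before passing the data to the LIT server.
+small = dataset.sample(1000)   # random subset of 1000 examples
+head = dataset.slice[:1000]    # the first 1000 examples
+```
+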
### Large Models
LIT can work with large or slow models, as long as you can wrap them into the
-model API. If you have more than a few pre-loaded datapoints, however, you'll
-probably want to use `warm_start=1.0` (or pass `--warm_start=1.0` as a flag) to
+model API. If you have more than a few preloaded datapoints, however, you'll
+probably want to use `warm_start=1` (or pass `--warm_start=1` as a flag) to
pre-compute predictions when the server loads, so you don't have to wait when
you first visit the UI.
Also, beware of memory usage: since LIT keeps the models in memory to support
-new queries, only so many can fit on a single node or GPU. If you want to load
-more models than can fit in local memory, you can host your model with your
-favorite serving framework and interface with it using a custom
-[`Model`](python_api.md#models) class.
+new queries, only so many models can fit on a single node or GPU. If you want to
+load more or larger models than can fit in local memory, you can host your model
+with your favorite serving framework and connect to it using a custom
+[`Model`](api.md#models) class.
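+
+A minimal sketch of such a wrapper is shown below; the endpoint URL and the
+response format (a list of class distributions) are assumptions for
+illustration.
+
+```python
+import requests
+from lit_nlp.api import model as lit_model
+from lit_nlp.api import types as lit_types
+
+class RemoteClassifier(lit_model.Model):
+  """Proxies predictions to a hypothetical external serving endpoint."""
+
+  def __init__(self, url="http://model-server:8500/predict"):
+    self._url = url  # assumed to accept a JSON list of sentences
+
+  def input_spec(self) -> lit_types.Spec:
+    return {"sentence": lit_types.TextSegment()}
+
+  def output_spec(self) -> lit_types.Spec:
+    return {"probas": lit_types.MulticlassPreds(vocab=["0", "1"], parent="label")}
+
+  def predict_minibatch(self, inputs):
+    response = requests.post(self._url,
+                             json=[ex["sentence"] for ex in inputs])
+    return [{"probas": probas} for probas in response.json()]
+```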
We also have experimental support for using LIT as a lightweight model server;
-this can be useful e.g. for comparing an experimental model on your workstation
+this can be useful, e.g., for comparing an experimental model running locally
against a production model already running in an existing LIT demo. See
[`remote_model.py`](../lit_nlp/components/remote_model.py)
for more details.
@@ -54,13 +62,17 @@ however, model predictions and any newly-generated examples (including as
manually entered in the web UI) are stored in server memory, and if `--data_dir`
is specified, may be cached to disk.
-LIT contains the ability to create or edit datapoints in the UI and then save
-them to disk. If you do not want the tool to to be able to write edited
-datapoints to disk, then pass the `--demo_mode` runtime flag to the LIT server.
+LIT has the ability to create or edit datapoints in the UI and then save them to
+disk. If you do not want the tool to be able to write edited datapoints to
+disk, then pass the `--demo_mode` runtime flag to the LIT server.
### Managing Access
-The default LIT development server does not implement any explicit access controls. However, this is just a thin convenience wrapper, and the underlying WSGI App can be easily exported and used with additional middleware layers or external serving frameworks. See [Running LIT in a Docker container](./docker.md) for an example of this usage.
+The default LIT development server does not implement any explicit access
+controls. However, this is just a thin convenience wrapper, and the underlying
+WSGI App can be easily exported and used with additional middleware layers or
+external serving frameworks. See
+[Running LIT in a Docker container](./docker.md) for an example of this usage.
## Languages
@@ -74,6 +86,21 @@ scripts should work without any modifications. For examples, see:
* [T5 demo](../lit_nlp/examples/t5_demo.py) -
includes WMT data for machine translation
+## Data Types
+
+In addition to text, LIT has good support for different input and output
+modalities, including images and tabular data. For examples, see:
+
+* [Image demo](../lit_nlp/examples/image_demo.py) -
+ image classification, using a Mobilenet model.
+* [Tabular demo](../lit_nlp/examples/penguin_demo.py) -
+  multi-class classification on tabular (numeric and categorical string) data,
+ using the
+ [Palmer Penguins](https://www.tensorflow.org/datasets/catalog/penguins)
+ dataset.
+
+For more details, see [the features guide to input and output types](api.md#input-and-output-types).
+
## Workflow and Integrations
### Sending examples from another tool
@@ -88,13 +115,17 @@ but using `data0`, `data1`, `data2`, e.g. `data0_=`.
### Downloading or exporting data
-There is currently limited support for this via the settings (⚙️) menu. Click
-the "Dataset" tab and enter a path to save to. This is done server-side, so be
-sure the path is accessible to the server process.
+Currently, there are three ways to export data from the LIT UI:
+
+- In the Data Table, you can copy or download the current view in CSV format -
+ see [the UI guide](./ui_guide.md#data-table) for more details.
+- In the "Dataset" tab of the settings (⚙️) menu, you can enter a path to save
+ data to. Data is pushed to the server and written by the server backend, so
+ be sure the path is writable.
-In the future, we hope to make this workflow more robust, including more control
-over data format, as well as browser-based uploads and downloads of the examples
-(such as from csv files or Google Sheets).
+- If using LIT in a Colab or other notebook environment, you can access the
+ current selection from another cell using `widget.ui_state.primary`,
+ `widget.ui_state.selection`, and `widget.ui_state.pinned`.
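+
+For example, a minimal notebook sketch (assuming `models` and `datasets` are
+defined as you would pass them to the server):
+
+```python
+from lit_nlp import notebook
+
+widget = notebook.LitWidget(models, datasets)
+widget.render()  # displays the LIT UI inline
+
+# In a later cell, after interacting with the UI:
+primary = widget.ui_state.primary      # the primary selected datapoint
+selected = widget.ui_state.selection   # all selected datapoints
+pinned = widget.ui_state.pinned        # the pinned datapoint, if any
+```
+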
### Loading data from the UI
@@ -126,12 +157,12 @@ then be re-used in other environments.
### Training models with LIT
-LIT is primarily an evaluation/infererence-time tool, so we don't provide any
+LIT is primarily an evaluation/inference-time tool, so we don't provide any
official training APIs. However, to facilitate code reuse you can easily add
training methods to your model class. In fact, several of our demos do exactly
this, using LIT's `Dataset` objects to manage training data along with standard
training APIs (such as Keras' `model.fit()`). See
[`quickstart_sst_demo.py`](../lit_nlp/examples/quickstart_sst_demo.py)
-and
+and/or
[`glue_models.py`](../lit_nlp/examples/models/glue_models.py)
-for an example.
+for examples.
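+
+As a hedged sketch of the pattern (the field names "sentence" and "label" are
+assumptions that depend on the dataset's spec, and the Keras model is assumed
+to accept raw strings, e.g. via a TextVectorization layer):
+
+```python
+import tensorflow as tf
+
+def train(keras_model: tf.keras.Model, dataset):
+  """dataset: a LIT Dataset, whose examples are flat dicts keyed by field."""
+  texts = [ex["sentence"] for ex in dataset.examples]
+  labels = [int(ex["label"]) for ex in dataset.examples]
+  keras_model.fit(tf.constant(texts), tf.constant(labels), epochs=3)
+```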
diff --git a/documentation/frontend_development.md b/documentation/frontend_development.md
index 91602eef..593d58c4 100644
--- a/documentation/frontend_development.md
+++ b/documentation/frontend_development.md
@@ -1,6 +1,6 @@
# Frontend Developer Guide
-
+
@@ -62,36 +62,13 @@ A layout is defined by a structure of `LitModule` classes, and includes a set of
main components that are always visible (designated in the object by the "main"
key) and a set of tabs that each contain a group of other components.
-A simplified version for a classifier model might look like:
-
-```typescript
-const layout: LitComponentLayout = {
- components : {
- 'Main': [DataTableModule, DatapointEditorModule],
-
- 'Classifiers': [
- ConfusionMatrixModule,
- ],
- 'Counterfactuals': [GeneratorModule],
- 'Predictions': [
- ScalarModule,
- ClassificationModule,
- ],
- 'Explanations': [
- ClassificationModule,
- SalienceMapModule,
- AttentionModule,
- ]
- }
-};
-```
-
-The full layouts are defined in
-[`layout.ts`](../lit_nlp/client/default/layout.ts). To
-use a specific layout for a given LIT instance, pass the key (e.g., "simple" or
-"mlm") in as a server flag when initializing LIT(`--layout=`). The
-layout can be set on-the-fly a URL param (the url param overrides the server
-flag).
+Layouts are generally specified in Python (see
+[Custom Layouts](./api.md#ui-layouts)) through the `LitCanonicalLayout` object.
+The default layouts are defined in
+[`layout.py`](../lit_nlp/api/layout.py), and you can
+add your own by defining one or more `LitCanonicalLayout` objects and passing
+them to the server. For an example, see `CUSTOM_LAYOUTS` in
+[`lm_demo.py`](../lit_nlp/examples/lm_demo.py).
The actual layout of components in
[``](../lit_nlp/client/core/modules.ts)
@@ -109,7 +86,7 @@ starting the initial load of data from the server. This process consists of:
1. Parsing the URL query params to get the url configuration
1. Fetching the app metadata, which includes what models/datasets are available
to use.
-1. Determining which models/datasets to load and then loding them.
+1. Determining which models/datasets to load and then loading them.
## Modules (LitModule)
@@ -129,9 +106,10 @@ outlined below:
@customElement('demo-module') // (0)
export class DemoTextModule extends LitModule {
static override title = 'Demo Module'; // (1)
- static override template = (model = '') => { // (2)
- return html``;
- };
+ static override template =
+ (model: string, selectionServiceIndex: number, shouldReact: number) => // (2)
+ html``;
static override duplicateForModelComparison = true; // (3)
static override get styles() {
@@ -159,7 +137,7 @@ export class DemoTextModule extends LitModule {
this.pigLatin = results;
}
- override render() { // (10)
+ override renderImpl() { // (10)
const color = this.colorService.getDatapointColor(
this.selectionService.primarySelectedInputData);
return html`
@@ -231,14 +209,14 @@ of the data. Since we're using mobx observables to store and compute our state,
we do this all in a reactive way.
First, since the `LitModule` base class derives from `MobxLitElement`, any
-observable data that we use in the `render` method automatically triggers a
-rerener when updated. This is excellent for simple use cases, but what about
+observable data that we use in the `renderImpl` method automatically triggers a
+re-render when updated. This is excellent for simple use cases, but what about
when we want to trigger more complex behavior, such as the asynchronous request
outlined above?
-The pattern that we leverage across the app is as follows: The `render` method
-(10) accesses a private observable `pigLatin` property (6) that, when updated,
-will rerender the template and show the results of the translation
+The pattern that we leverage across the app is as follows: The `renderImpl`
+method (10) accesses a private observable `pigLatin` property (6) that, when
+updated, will re-render the template and show the results of the translation
automatically. In order to update the `pigLatin` observable, we need to set up a
bit of machinery. In the lit-element lifecycle method `firstUpdated`, we use a
helper method `reactImmediately` (7) to set up an explicit reaction to the user
@@ -247,9 +225,12 @@ selecting data. Whatever is returned by the first function (in this case
second function immediately **and** whenever it changes, allowing us to do
something whenever the selection changes. Note, another helper method `react` is
used in the same way as `reactImmediately`, in instances where you don't want to
-immediately invoke the reaction.
+immediately invoke the reaction. Also note that modules should override
+`renderImpl` and not the base `render` method as our `LitModule` base class
+overrides `render` with custom logic which calls our `renderImpl` method for
+modules to perform their rendering in.
-We pass the selction to the `getTranslation` method to fetch the data from our
+We pass the selection to the `getTranslation` method to fetch the data from our
API service. However, rather than awaiting our API request directly, we pass the
request promise (8) to another helper method `loadLatest` (9). This ensures that
we won't have any race conditions if, for instance, the user selects different
@@ -260,7 +241,7 @@ template is automatically rerendered, displaying our data.
This may seem like a bit of work for a simple module, but the pattern of using
purely observable data to declaratively specify what gets rendered is very
-powerful for simpligying the logic around building larger, more complex
+powerful for simplifying the logic around building larger, more complex
components.
### Escape Hatches
@@ -291,11 +272,48 @@ reconciliation of what needs to be updated per render.
this.drawCanvas(canvas);
}
- render() {
+ override renderImpl() {
return html`<canvas></canvas>`;
}
```
+### Stateful Child Elements
+
+Some modules may contain stateful child elements, where the element has some
+internal state that can affect the module that contains it. Examples include
+any module that contains the
+[elements/faceting_control.ts](../lit_nlp/client/elements/faceting_control.ts)
+element.
+
+With these types of child elements, it's important for the containing module
+to construct them programmatically and store them in a class member variable,
+as opposed to only constructing them in the module's HTML template
+string returned by the `renderImpl` method. Otherwise they will be destroyed
+and recreated when a module is hidden off-screen and then brought back
+on-screen, leading them to lose whatever state they previously held.
+Below is a snippet of example code to handle these types of elements.
+
+```typescript
+// An example of a LitModule using a stateful child element.
+@customElement('example-module')
+export class ExampleModule extends LitModule {
+ private readonly facetingControl = document.createElement('faceting-control');
+
+ constructor() {
+ super();
+
+ const facetsChange = (event: CustomEvent) => {
+ // Do something with the information from the event.
+ };
+ // Set the necessary properties on the faceting-control element.
+ this.facetingControl.contextName = ExampleModule.title;
+ this.facetingControl.addEventListener(
+ 'facets-change', facetsChange as EventListener);
+ }
+
+ override renderImpl() {
+ // Render the faceting-control element.
+ return html`${this.facetingControl}`;
+ }
+}
+```
## Style Guide
* Please disable clang-format on `lit-html` templates and format these
@@ -358,15 +376,16 @@ source from the build output.
If you're modifying the Python backend, there is experimental support for
hot-reloading the LIT application logic (`app.py`) and some dependencies without
-needing to re-load models or datasets. See
-[`dev_server.py`](../lit_nlp/dev_server.py) for details.
+needing to reload models or datasets. See
+[`dev_server.py`](../lit_nlp/dev_server.py) for
+details.
You can use the `--data_dir` flag (see
-[`server_flags.py`](../lit_nlp/server_flags.py) to save the predictions cache to
-disk, and automatically re-load it on a subsequent run. In conjunction with
-`--warm_start`, you can use this to avoid re-running inference during
-development - though if you modify the model at all, you should be sure to
-remove any stale cache files.
+[`server_flags.py`](../lit_nlp/server_flags.py) to
+save the predictions cache to disk, and automatically reload it on a subsequent
+run. In conjunction with `--warm_start`, you can use this to avoid re-running
+inference during development - though if you modify the model at all, you should
+be sure to remove any stale cache files.
## Custom Client / Modules
@@ -375,26 +394,12 @@ modules, though this is currently provided as "best effort" support and the API
is not as mature as for Python extensions.
An example of a custom LIT client application, including a custom
-(potato-themed) module can be found in `lit_nlp/examples/custom_module`. In
-short, to build and serve a custom LIT client application, create a new
-directory containing a `main.ts` entrypoint. This should import any custom
-modules, define a layout that includes them, and call `app.initialize`. For
-example:
-
-```ts
-import {PotatoModule} from './potato';
-
-LAYOUTS = {}; // or import existing set from client/default/layout.ts
-LAYOUTS['potato'] = {
- components: {
- 'Main': [DatapointEditorModule, ClassificationModule],
- 'Data': [DataTableModule, PotatoModule],
- },
-};
-app.initialize(LAYOUTS);
-```
+(potato-themed) module can be found in
+[`lit_nlp/examples/custom_module`](../lit_nlp/examples/custom_module).
+You need only define any custom modules (subclasses of `LitModule`) and include
+them in the build.
-Then, build the app, specifying the directory to build with the `env.build`
+When you build the app, specify the directory to build with the `env.build`
flag. For example, to build the `custom_module` demo app:
```sh
@@ -414,3 +419,24 @@ in `examples/custom_module/potato_demo.py`.
parent_dir = pathlib.Path(__file__).parent.absolute()
FLAGS.set_default("client_root", os.path.join(parent_dir, "build"))
```
+
+You must also define a [custom layout definition](./api.md#ui-layouts) in Python
+which references your new module. Note that because Python enums are not
+extensible, you need to reference the custom module using its HTML tag name:
+
+```python
+modules = layout.LitModuleName
+POTATO_LAYOUT = layout.LitCanonicalLayout(
+ upper={
+ "Main": [modules.DatapointEditorModule, modules.ClassificationModule],
+ },
+ lower={
+ "Data": [modules.DataTableModule, "potato-module"],
+ },
+ description="Custom layout with our spud-tastic potato module.",
+)
+```
+
+See
+[`potato_demo.py`](../lit_nlp/examples/custom_module/potato_demo.py)
+for the full example.
diff --git a/documentation/getting_started.md b/documentation/getting_started.md
new file mode 100644
index 00000000..c7d76c17
--- /dev/null
+++ b/documentation/getting_started.md
@@ -0,0 +1,80 @@
+# Getting Started with LIT
+
+
+
+
+
+## Hosted demos
+
+If you want to jump in and start playing with the LIT UI, check out
+https://pair-code.github.io/lit/demos/ for links to our hosted demos.
+
+For a guide to the many features available, check out the
+[UI guide](./ui_guide.md) or this
+[short video](https://www.youtube.com/watch?v=j0OfBWFUqIE).
+
+## LIT with your model
+
+LIT provides a simple [Python API](./api.md) for use with custom models and
+data, as well as components such as metrics and counterfactual generators. Most
+LIT users will take this route, which involves writing a short `demo.py` binary
+to link in `Model` and `Dataset` implementations and configure the server. In
+most cases this can be just a few lines:
+
+```python
+ datasets = {
+ 'foo_data': FooDataset('/path/to/foo.tsv'),
+ 'bar_data': BarDataset('/path/to/bar.tfrecord'),
+ }
+ models = {'my_model': MyModel('/path/to/model/files')}
+ lit_demo = lit_nlp.dev_server.Server(models, datasets, port=4321)
+ lit_demo.serve()
+```
+
+Check out the [API documentation](./api.md#adding-models-and-data) for more, and
+the [demos directory](./demos.md) for a wealth of examples. The
+[components guide](./components.md) also gives an overview of interpretability
+methods and other features available in LIT, and describes how to enable each
+for your task.
+
+## Using LIT in notebooks
+
+LIT can also be used directly from Colab and Jupyter notebooks, with the LIT UI
+rendered in an output cell. See https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb for an example.
+
+Note: if you see a 403 error in the output cell where LIT should render, you may
+need to enable cookies on the Colab site, or pass a custom `port=` to the
+`LitWidget` constructor.
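+
+As a minimal notebook sketch (hedged: `LitWidget` comes from the linked example
+notebook; treat the exact constructor arguments, including `port=`, as
+assumptions to adapt to your own models and datasets):
+
+```python
+from lit_nlp import notebook
+
+# models and datasets are dicts, as in the server example above.
+widget = notebook.LitWidget(models, datasets, port=8890)
+widget.render()  # Renders the LIT UI in this cell's output.
+```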
+
+## Stand-alone components
+
+Many LIT components - such as models, datasets, metrics, and salience methods -
+are stand-alone Python classes and can be easily used outside of the LIT UI. For
+additional details, see the
+[API documentation](./api.md#using-components-outside-lit) and an example Colab
+at https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_components_example.ipynb.
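+
+For instance, a salience method can be run directly. A hedged sketch, assuming
+a `model` and `dataset` built as above (`LIME` lives under
+`lit_nlp.components.lime_explainer`; check the exact signature for your
+version):
+
+```python
+from lit_nlp.components import lime_explainer
+
+lime = lime_explainer.LIME()
+# Interpreters follow run(inputs, model, dataset, ...) per the LIT API.
+results = lime.run(dataset.examples[:5], model, dataset)
+```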
+
+## Run an existing example
+
+The [demos page](./demos.md) lists some of the pre-built demos available for a
+variety of model types. The code for these is under
+[lit_nlp/examples](../lit_nlp/examples); each is a small script that loads one
+or more models and starts a LIT server.
+
+Most demos can be run with a single command. To run the default one, you
+can do:
+
+```sh
+python -m lit_nlp.examples.glue_demo \
+ --quickstart --port=4321 --alsologtostderr
+```
+
+Then navigate to http://localhost:4321 to access the UI.
+
+For most models we recommend using a GPU, though the `--quickstart` flag above
+loads a set of smaller models that run well on CPU. You can also pass
+`--warm_start=1.0`, and LIT will run inference and cache the results before
+server start.
+
+For an overview of supported model types and frameworks, see the
+[components guide](./components.md).
diff --git a/documentation/glossary.md b/documentation/glossary.md
index ae177865..cf2ba892 100644
--- a/documentation/glossary.md
+++ b/documentation/glossary.md
@@ -13,8 +13,9 @@ LIT APIs and codebase:
feed to models and get predictions back.
* **Instance**, a specific implementation of LIT (e.g. a demo.py binary) or
server job running the former.
-* **LIT**, the Language Interpretability Tool. Always fully capitalized,
+* **LIT**, the Learning Interpretability Tool. Always fully capitalized,
sometimes accompanied by a 🔥 emoji. Pronounced "lit", not "ell-eye-tee".
+ Formerly known as the Language Interpretability Tool.
* **Lit**, the web framework consisting of
[lit-element](https://lit-element.polymer-project.org/guide) and
[lit-html](https://lit-html.polymer-project.org/guide) and maintained by the
diff --git a/documentation/images/components/dive.png b/documentation/images/components/dive.png
new file mode 100644
index 00000000..8dbbce65
Binary files /dev/null and b/documentation/images/components/dive.png differ
diff --git a/documentation/images/components/salience-clustering.png b/documentation/images/components/salience-clustering.png
new file mode 100644
index 00000000..9f2d92d7
Binary files /dev/null and b/documentation/images/components/salience-clustering.png differ
diff --git a/documentation/images/components/tabular-feature-attribution.png b/documentation/images/components/tabular-feature-attribution.png
new file mode 100644
index 00000000..0723fb5d
Binary files /dev/null and b/documentation/images/components/tabular-feature-attribution.png differ
diff --git a/documentation/images/figure-1.png b/documentation/images/figure-1.png
index b14d6cbe..d4aa3694 100644
Binary files a/documentation/images/figure-1.png and b/documentation/images/figure-1.png differ
diff --git a/documentation/images/lit-attention.png b/documentation/images/lit-attention.png
index 93102b58..bdab26a5 100644
Binary files a/documentation/images/lit-attention.png and b/documentation/images/lit-attention.png differ
diff --git a/documentation/images/lit-classification-results.png b/documentation/images/lit-classification-results.png
index ebfdb8fe..02274a28 100644
Binary files a/documentation/images/lit-classification-results.png and b/documentation/images/lit-classification-results.png differ
diff --git a/documentation/images/lit-conf-matrix.png b/documentation/images/lit-conf-matrix.png
index 4d17b79e..7bdde8a8 100644
Binary files a/documentation/images/lit-conf-matrix.png and b/documentation/images/lit-conf-matrix.png differ
diff --git a/documentation/images/lit-datapoint-compare.png b/documentation/images/lit-datapoint-compare.png
index 6d42bb9a..dbdda712 100644
Binary files a/documentation/images/lit-datapoint-compare.png and b/documentation/images/lit-datapoint-compare.png differ
diff --git a/documentation/images/lit-datapoint-generator.png b/documentation/images/lit-datapoint-generator.png
index 921625df..63f4d290 100644
Binary files a/documentation/images/lit-datapoint-generator.png and b/documentation/images/lit-datapoint-generator.png differ
diff --git a/documentation/images/lit-datatable-export.png b/documentation/images/lit-datatable-export.png
new file mode 100644
index 00000000..fb43f636
Binary files /dev/null and b/documentation/images/lit-datatable-export.png differ
diff --git a/documentation/images/lit-datatable.png b/documentation/images/lit-datatable.png
index 52fac18f..396c1d32 100644
Binary files a/documentation/images/lit-datatable.png and b/documentation/images/lit-datatable.png differ
diff --git a/documentation/images/lit-embeddings.png b/documentation/images/lit-embeddings.png
index 9cfc72ab..9cc731ba 100644
Binary files a/documentation/images/lit-embeddings.png and b/documentation/images/lit-embeddings.png differ
diff --git a/documentation/images/lit-metrics.png b/documentation/images/lit-metrics.png
index 04e612e5..c0de352d 100644
Binary files a/documentation/images/lit-metrics.png and b/documentation/images/lit-metrics.png differ
diff --git a/documentation/images/lit-model-compare.png b/documentation/images/lit-model-compare.png
index 6acad1be..4d19dde5 100644
Binary files a/documentation/images/lit-model-compare.png and b/documentation/images/lit-model-compare.png differ
diff --git a/documentation/images/lit-pred-score.png b/documentation/images/lit-pred-score.png
index 052c501e..17a22d09 100644
Binary files a/documentation/images/lit-pred-score.png and b/documentation/images/lit-pred-score.png differ
diff --git a/documentation/images/lit-salience.png b/documentation/images/lit-salience.png
index 0b484d49..040738e4 100644
Binary files a/documentation/images/lit-salience.png and b/documentation/images/lit-salience.png differ
diff --git a/documentation/images/lit-settings.png b/documentation/images/lit-settings.png
index 5362c6d9..1d38e173 100644
Binary files a/documentation/images/lit-settings.png and b/documentation/images/lit-settings.png differ
diff --git a/documentation/images/lit-slices.png b/documentation/images/lit-slices.png
index 1796e7d6..a9d835ce 100644
Binary files a/documentation/images/lit-slices.png and b/documentation/images/lit-slices.png differ
diff --git a/documentation/images/lit-ui.png b/documentation/images/lit-ui.png
index e9c05770..01cb2b7e 100644
Binary files a/documentation/images/lit-ui.png and b/documentation/images/lit-ui.png differ
diff --git a/documentation/includes/highlight_demos.md b/documentation/includes/highlight_demos.md
new file mode 100644
index 00000000..e69de29b
diff --git a/documentation/index.md b/documentation/index.md
index e6378ad4..ddd60ca7 100644
--- a/documentation/index.md
+++ b/documentation/index.md
@@ -1,88 +1,19 @@
-# Language Interpretability Tool (LIT)
+# Learning Interpretability Tool (LIT)
-
+
-Welcome to the Language Interpretability Tool (🔥LIT)!
-
-## Hosted demos
+Welcome to 🔥LIT (Learning Interpretability Tool, formerly Language
+Interpretability Tool)!
If you want to jump in and start playing with the LIT UI, check out
https://pair-code.github.io/lit/demos/ for links to our hosted demos.
-For a guide to the many features available, check out the
-[UI guide](./ui_guide.md) or this
-[short video](https://www.youtube.com/watch?v=j0OfBWFUqIE).
-
-## LIT with your model
-
-LIT provides a simple [Python API](./api.md) for use with custom models and
-data, as well as components such as metrics and counterfactual generators. Most
-LIT users will take this route, which involves writing a short `demo.py` binary
-to link in `Model` and `Dataset` implementations. In many cases, this can be
-just a few lines of logic:
-
-```python
- datasets = {
- 'foo_data': FooDataset('/path/to/foo.tsv'),
- 'bar_data': BarDataset('/path/to/bar.tfrecord'),
- }
- models = {'my_model': MyModel('/path/to/model/files')}
- lit_demo = lit_nlp.dev_server.Server(models, datasets, port=4321)
- lit_demo.serve()
-```
-
-Check out the [API documentation](./api.md#adding-models-and-data) for more, and
-the [demos directory](./demos.md) for a wealth of examples. The
-[components guide](./components.md) also gives a good overview of the different
-features that are available, and how to enable them for your model.
-
-Also, join https://groups.google.com/g/lit-annoucements to receive announcements and updates on new LIT features.
-
-## Using LIT in notebooks
-
-LIT can also be used directly from Colab and Jupyter notebooks, with the LIT UI
-rendered in an output cell. See https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_sentiment_classifier.ipynb for an example.
-
-Note: if you see a 403 error in the output cell where LIT should render, you may
-need to enable cookies on the Colab site.
-
-## Stand-alone components
-
-Many LIT components - such as models, datasets, metrics, and salience methods -
-are stand-alone Python classes and can be easily used outside of the LIT UI. For
-additional details, see the
-[API documentation](./api.md#using-components-outside-lit) and an example Colab
-at https://colab.research.google.com/github/pair-code/lit/blob/dev/lit_nlp/examples/notebooks/LIT_Components_Example.ipynb.
-
-## Run an existing example
-
-The [demos page](./demos.md) lists some of the pre-built demos available for a
-variety of model types. The code for these is under [lit_nlp/examples](../lit_nlp/examples)
-;
-each is a small script that loads one or more models and starts a LIT server.
-
-Most demos can be run with a single blaze command. To run the default one, you
-can do:
-
-```sh
-python -m lit_nlp.examples.glue_demo \
- --quickstart --port=4321 --alsologtostderr
-```
-
-Then navigate to https://localhost:4321 to access the UI.
-
-For most models we recommend using a GPU, though the `--quickstart` flag above
-loads a set of smaller models that run well on CPU. You can also pass
-`--warm_start=1.0`, and LIT will run inference and cache the results before
-server start.
-
-For an overview of supported model types and frameworks, see the
-[components guide](./components.md).
-
## Documentation Links
+* [Getting Started](./getting_started.md)
+
* [Demos](./demos.md)
* [UI Guide](./ui_guide.md)
diff --git a/documentation/ui_guide.md b/documentation/ui_guide.md
index 00153397..7d258c4b 100644
--- a/documentation/ui_guide.md
+++ b/documentation/ui_guide.md
@@ -1,8 +1,8 @@
# UI Guide
-
+
-This is a user guide for the Language Interpretability Tool (LIT) UI.
+This is a user guide for the Learning Interpretability Tool (LIT) UI.
For a quick video tour of LIT, check out this
[video](https://www.youtube.com/watch?v=CuRI_VK83dU).
@@ -105,10 +105,10 @@ includes controls such as:
features or model outputs on those datapoints (such as coloring by some
categorical input feature, or by prediction error for a regression task).
-Next to the menus is a **"Compare datapoints"** switch. Enabling this puts LIT
-into datapoint comparison mode, where two datapoints can be compared against
-each other, across all applicable modules. This mode is described in more detail
-[below](#comparing-datapoints).
+Next to the menus is a button for pinning/unpinning a datapoint. Pinning a
+datapoint puts LIT into datapoint comparison mode, where two datapoints can be
+compared against each other, across all applicable modules. This mode is
+described in more detail [below](#comparing-datapoints).
On the right side of the toolbar, it displays
how many datapoints are in the loaded dataset and how many of those are
@@ -147,14 +147,16 @@ models.
## Comparing Datapoints
-Toggling the **"Compare datapoints"** switch in the mail toolbar puts LIT
-into **datapoint comparison mode**. In this mode, the primary datapoint
-selection is used a reference datapoint, and any subsequent setting of the
-primary selection causes it to be compared against the reference point. The
-reference datapoint is highlighted in a blue border in the data table.
+Pinning a datapoint, through either the toolbar button or controls in modules
+(e.g., the pin icons in Data Table rows), puts LIT into
+**datapoint comparison mode**. In this mode, the pinned datapoint is used as a
+reference to compare the primary selection. The pinned datapoint is indicated
+by a pin icon in modules that support datapoint comparison, such as the Data
+Table. Any changes to the primary selection will update datapoint comparison
+visualizations in all supporting modules.
-Just as with model comparison, certain modules are repeated, one showing the
-reference datapoint and one showing the primary selected datapoint.
+As with model comparison, some modules may be duplicated, one showing the pinned
+datapoint and one showing the primary selected datapoint.
This allows for easy comparison of model results on a datapoint to any generated
counterfactual datapoints, or any other datapoint from the loaded dataset.
@@ -246,7 +248,7 @@ header row. All columns that have filters set on them have their search button
outlined. Clicking the **"x"** button in the search box for a column will clear
that column's filter.
-The **"only show selected"** checkbox toggles the data table to only show the
+The **"show only selected"** checkbox toggles the data table to only show the
datapoints that are currently selected.
The **"reset view"** button returns the data table to its standard, default
@@ -261,8 +263,21 @@ The below data table shows one sorted by the "label" field, with the "passage"
field being filtered to only those datapoints that contain the word "sound" in
them.
+A datapoint can be pinned to enable comparison by clicking the pin icon on the
+left side of the datapoint's table entry when the datapoint is hovered over or
+selected. A pinned datapoint can be unpinned by clicking on its pin icon again.
+
![LIT data table](./images/lit-datatable.png "LIT data table")
+You can also export data to CSV using the copy or download buttons in the bottom
+right:
+
+![LIT data table](./images/lit-datatable-export.png "LIT data table export controls")
+
+This will export all data in the current table view. To export only the
+selection, use the "show only selected" toggle. To include additional
+columns such as model predictions, enable them from the "Columns" dropdown.
+
### Datapoint Editor
The datapoint editor shows the details of the primary selected datapoint, if one
@@ -411,7 +426,7 @@ model's prediction on the primary selection. This module can contain multiple
methodologies for calculating this salience, depending on the capabilities of
the model being analyzed (e.g. if the model provides gradients, then
gradient-based token-wise salience can be calculated and displayed -- see
-[adding models and data](python_api.md#adding-models-and-data) for more). The
+[adding models and data](api.md#adding-models-and-data) for more). The
background of each text piece is colored by the salience of that piece on the
prediction, and hovering on any piece will display the exact value calculated
for that piece.
diff --git a/environment.yml b/environment.yml
index f970e02d..0133cb50 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,20 +14,18 @@
# ==============================================================================
name: lit-nlp
dependencies:
- - python=3.7
- - absl-py
- - numpy
- - scipy
- - pandas
- - scikit-learn
- - gunicorn
+ - python=3.9
- pip
- pip:
- - tensorflow==2.6.0
- - keras==2.6.0
+ - absl-py
+ - numpy
+ - scipy
+ - pandas
+ - scikit-learn==1.0.2
+ - tensorflow==2.10.0
+ - keras==2.10.0
- tfds-nightly
- - tensorflow-text==2.6.0
- - tensorflow-estimator==2.6.0
+ - tensorflow-text==2.10.0
- rouge-score
- sacrebleu
- numba==0.53.1
@@ -38,8 +36,11 @@ dependencies:
- portpicker
- annoy
- ml-collections
- - saliency
- - matplotlib
- - attrs
- - Jinja2
+ - saliency==0.1.3
+ - matplotlib==3.3.4
+ - attrs==22.1.0
- stanza
+ - shap==0.37.0
+ - tqdm==4.64.0
+ - gunicorn==20.1.0
+ - jinja2
diff --git a/lit_nlp/.eslintrc.json b/lit_nlp/.eslintrc.json
new file mode 100644
index 00000000..a1712aab
--- /dev/null
+++ b/lit_nlp/.eslintrc.json
@@ -0,0 +1,80 @@
+{
+ "extends": [
+ "eslint:recommended",
+ "plugin:node/recommended"
+ ],
+ "plugins": [
+ "node"
+ ],
+ "rules": {
+ "block-scoped-var": "error",
+ "eqeqeq": [
+ "error",
+ "always",
+ {
+ "null": "ignore"
+ }
+ ],
+ "no-var": "error",
+ "prefer-const": "error",
+ "eol-last": "error",
+ "prefer-arrow-callback": "error",
+ "no-trailing-spaces": "error",
+ "quotes": [
+ "warn",
+ "single",
+ {
+ "avoidEscape": true
+ }
+ ],
+ "no-restricted-properties": [
+ "error",
+ {
+ "object": "describe",
+ "property": "only"
+ },
+ {
+ "object": "it",
+ "property": "only"
+ }
+ ]
+ },
+ "overrides": [
+ {
+ "files": [
+ "**/*.ts",
+ "**/*.tsx"
+ ],
+ "parser": "@typescript-eslint/parser",
+ "extends": [
+ "plugin:@typescript-eslint/recommended"
+ ],
+ "rules": {
+ "@typescript-eslint/no-non-null-assertion": "off",
+ "@typescript-eslint/no-use-before-define": "off",
+ "@typescript-eslint/no-warning-comments": "off",
+ "@typescript-eslint/no-empty-function": "off",
+ "@typescript-eslint/no-var-requires": "off",
+ "@typescript-eslint/no-inferrable-types": "off",
+ "@typescript-eslint/explicit-function-return-type": "off",
+ "@typescript-eslint/explicit-module-boundary-types": "off",
+ "@typescript-eslint/ban-types": "off",
+ "@typescript-eslint/camelcase": "off",
+ "node/no-missing-import": "off",
+ "node/no-empty-function": "off",
+ "node/no-unpublished-import": "off",
+ "node/no-unsupported-features/es-syntax": "off",
+ "node/no-missing-require": "off",
+ "node/shebang": "off",
+ "no-dupe-class-members": "off",
+ "no-prototype-builtins": "off",
+ "quotes": "off",
+ "require-atomic-updates": "off"
+ },
+ "parserOptions": {
+ "ecmaVersion": 2018,
+ "sourceType": "module"
+ }
+ }
+ ]
+}
diff --git a/lit_nlp/.pylintrc b/lit_nlp/.pylintrc
new file mode 100644
index 00000000..35c5fc16
--- /dev/null
+++ b/lit_nlp/.pylintrc
@@ -0,0 +1,429 @@
+# This Pylint rcfile contains a best-effort configuration to uphold the
+# best-practices and style described in the Google Python style guide:
+# https://google.github.io/styleguide/pyguide.html
+#
+# Its canonical open-source location is:
+# https://google.github.io/styleguide/pylintrc
+
+[MAIN]
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=third_party
+
+# Files or directories matching the regex patterns are skipped. The regex
+# matches against base names, not paths.
+ignore-patterns=
+
+# Pickle collected data for later comparisons.
+persistent=no
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Use multiple processes to speed up Pylint.
+jobs=4
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+#enable=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+disable=abstract-method,
+ apply-builtin,
+ arguments-differ,
+ attribute-defined-outside-init,
+ backtick,
+ bad-option-value,
+ basestring-builtin,
+ buffer-builtin,
+ c-extension-no-member,
+ consider-using-enumerate,
+ cmp-builtin,
+ cmp-method,
+ coerce-builtin,
+ coerce-method,
+ delslice-method,
+ div-method,
+ duplicate-code,
+ eq-without-hash,
+ execfile-builtin,
+ file-builtin,
+ filter-builtin-not-iterating,
+ fixme,
+ getslice-method,
+ global-statement,
+ hex-method,
+ idiv-method,
+ implicit-str-concat,
+ import-error,
+ import-self,
+ import-star-module-level,
+ inconsistent-return-statements,
+ input-builtin,
+ intern-builtin,
+ invalid-str-codec,
+ locally-disabled,
+ long-builtin,
+ long-suffix,
+ map-builtin-not-iterating,
+ misplaced-comparison-constant,
+ missing-function-docstring,
+ metaclass-assignment,
+ next-method-called,
+ next-method-defined,
+ no-absolute-import,
+ no-else-break,
+ no-else-continue,
+ no-else-raise,
+ no-else-return,
+ no-init, # added
+ no-member,
+ no-name-in-module,
+ no-self-use,
+ nonzero-method,
+ oct-method,
+ old-division,
+ old-ne-operator,
+ old-octal-literal,
+ old-raise-syntax,
+ parameter-unpacking,
+ print-statement,
+ raising-string,
+ range-builtin-not-iterating,
+ raw_input-builtin,
+ rdiv-method,
+ reduce-builtin,
+ relative-import,
+ reload-builtin,
+ round-builtin,
+ setslice-method,
+ signature-differs,
+ standarderror-builtin,
+ suppressed-message,
+ sys-max-int,
+ too-few-public-methods,
+ too-many-ancestors,
+ too-many-arguments,
+ too-many-boolean-expressions,
+ too-many-branches,
+ too-many-instance-attributes,
+ too-many-locals,
+ too-many-nested-blocks,
+ too-many-public-methods,
+ too-many-return-statements,
+ too-many-statements,
+ trailing-newlines,
+ unichr-builtin,
+ unicode-builtin,
+ unnecessary-pass,
+ unpacking-in-except,
+ useless-else-on-loop,
+ useless-object-inheritance,
+ useless-suppression,
+ using-cmp-argument,
+ wrong-import-order,
+ xrange-builtin,
+ zip-builtin-not-iterating,
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+
+[BASIC]
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=main,_
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
+
+# Regular expression matching correct function names
+function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
+
+# Regular expression matching correct variable names
+variable-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct constant names
+const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression matching correct attribute names
+attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
+
+# Regular expression matching correct argument names
+argument-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct class attribute names
+class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression matching correct inline iteration names
+inlinevar-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct class names
+class-rgx=^_?[A-Z][a-zA-Z0-9]*$
+
+# Regular expression matching correct module names
+module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
+
+# Regular expression matching correct method names
+method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=10
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+
+[FORMAT]
+
+# Maximum number of characters on a single line.
+max-line-length=80
+
+# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
+# lines made too long by directives to pytype.
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=(?x)(
+ ^\s*(\#\ )?<?https?://\S+>?$|
+ ^\s*(from\s+\S+\s+)?import\s+.+$)
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=yes
+
+# Maximum number of lines in a module
+max-module-lines=99999
+
+# String used as indentation unit. The internal Google style guide mandates 2
+# spaces. Google's externally-published style guide says 4, consistent with
+# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
+# projects (like TensorFlow).
+indent-string=' '
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=TODO
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=yes
+
+
+[VARIABLES]
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,_cb
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging,absl.logging,tensorflow.io.logging
+
+
+[SIMILARITIES]
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+
+[SPELLING]
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[IMPORTS]
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,
+ TERMIOS,
+ Bastion,
+ rexec,
+ sets
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant, absl
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls,
+ class_
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=StandardError,
+ Exception,
+ BaseException
diff --git a/lit_nlp/.vscode/settings.json b/lit_nlp/.vscode/settings.json
new file mode 100644
index 00000000..5a285f01
--- /dev/null
+++ b/lit_nlp/.vscode/settings.json
@@ -0,0 +1,20 @@
+{
+ "editor.defaultFormatter": "esbenp.prettier-vscode",
+ "editor.detectIndentation": false,
+ "editor.formatOnSave": true,
+ "editor.formatOnSaveMode": "modifications",
+ "editor.insertSpaces": true,
+ "editor.rulers": [80],
+ "editor.tabSize": 2,
+ "editor.wrappingIndent": "none",
+ "files.insertFinalNewline": true,
+ "files.trimTrailingWhitespace": true,
+ "search.exclude": {
+ "**/node_modules": true,
+ "**/yarn.lock": true,
+ "**/.cache/**/*": true
+ },
+ "typescript.format.insertSpaceAfterOpeningAndBeforeClosingNonemptyBraces": false,
+ "typescript.tsdk": "${workspaceRoot}/lit_nlp/client/node_modules/typescript/lib",
+ "update.mode": "none"
+}
diff --git a/lit_nlp/api/components.py b/lit_nlp/api/components.py
index 9eaa00ab..51e2b890 100644
--- a/lit_nlp/api/components.py
+++ b/lit_nlp/api/components.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Base classes for LIT backend components."""
import abc
import inspect
-from typing import Dict, List, Optional, Sequence, Text
+from typing import Any, Optional, Sequence
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
@@ -24,6 +23,7 @@
JsonDict = types.JsonDict
IndexedInput = types.IndexedInput
+MetricsDict = dict[str, float]
class Interpreter(metaclass=abc.ABCMeta):
@@ -42,10 +42,10 @@ def description(self) -> str:
return inspect.getdoc(self) or ''
def run(self,
- inputs: List[JsonDict],
+ inputs: list[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
- model_outputs: Optional[List[JsonDict]] = None,
+ model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
"""Run this component, given a model and input(s)."""
raise NotImplementedError(
@@ -56,15 +56,16 @@ def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
- model_outputs: Optional[List[JsonDict]] = None,
+ model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
"""Run this component, with access to data indices and metadata."""
inputs = [ex['data'] for ex in indexed_inputs]
return self.run(inputs, model, dataset, model_outputs, config)
- def is_compatible(self, model: lit_model.Model):
- """Return if interpreter is compatible with the given model."""
- del model
+ def is_compatible(self, model: lit_model.Model,
+ dataset: lit_dataset.Dataset) -> bool:
+ """Return if interpreter is compatible with the dataset and model."""
+ del dataset, model # Unused in base class
return True
def config_spec(self) -> types.Spec:
@@ -91,22 +92,36 @@ def meta_spec(self) -> types.Spec:
return {}
+# TODO(b/254832560): Remove ComponentGroup class after promoting Metrics.
class ComponentGroup(Interpreter):
"""Convenience class to package a group of components together."""
- def __init__(self, subcomponents: Dict[Text, Interpreter]):
+ def __init__(self, subcomponents: dict[str, Interpreter]):
self._subcomponents = subcomponents
+ def meta_spec(self) -> types.Spec:
+ spec: types.Spec = {}
+ for component_name, component in self._subcomponents.items():
+ for field_name, field_spec in component.meta_spec().items():
+ spec[f'{component_name}: {field_name}'] = field_spec
+ return spec
+
def run_with_metadata(
self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
- model_outputs: Optional[List[JsonDict]] = None,
- config: Optional[JsonDict] = None) -> Dict[Text, JsonDict]:
+ model_outputs: Optional[list[JsonDict]] = None,
+ config: Optional[JsonDict] = None) -> dict[str, JsonDict]:
"""Run this component, given a model and input(s)."""
- assert model_outputs is not None
- assert len(model_outputs) == len(indexed_inputs)
+ if model_outputs is None:
+ raise ValueError('model_outputs cannot be None')
+
+ if len(model_outputs) != len(indexed_inputs):
+ raise ValueError('indexed_inputs and model_outputs must be the same size,'
+ f' received {len(indexed_inputs)} indexed_inputs and '
+ f'{len(model_outputs)} model_outputs')
+
ret = {}
for name, component in self._subcomponents.items():
ret[name] = component.run_with_metadata(indexed_inputs, model, dataset,
@@ -121,7 +136,7 @@ def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
- model_outputs: Optional[List[JsonDict]] = None,
+ model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
"""Run this component, with access to data indices and metadata."""
# IndexedInput[] -> Input[]
@@ -129,10 +144,10 @@ def run_with_metadata(self,
return self.generate_all(inputs, model, dataset, config)
def generate_all(self,
- inputs: List[JsonDict],
+ inputs: list[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
- config: Optional[JsonDict] = None) -> List[List[JsonDict]]:
+ config: Optional[JsonDict] = None) -> list[list[JsonDict]]:
"""Run generation on a set of inputs.
Args:
@@ -154,9 +169,83 @@ def generate(self,
example: JsonDict,
model: lit_model.Model,
dataset: lit_dataset.Dataset,
- config: Optional[JsonDict] = None) -> List[JsonDict]:
+ config: Optional[JsonDict] = None) -> list[JsonDict]:
"""Return a list of generated examples."""
- return
+ pass
+
+
+class Metrics(Interpreter):
+ """Base class for LIT metrics components."""
+
+ # Required methods implementations from Interpreter base class
+
+ def is_compatible(self, model: lit_model.Model,
+ dataset: lit_dataset.Dataset) -> bool:
+ """True if the model and dataset support metric computation."""
+ for pred_spec in model.output_spec().values():
+ parent_key: Optional[str] = getattr(pred_spec, 'parent', None)
+ parent_spec: Optional[types.LitType] = dataset.spec().get(parent_key)
+ if self.is_field_compatible(pred_spec, parent_spec):
+ return True
+ return False
+
+ def meta_spec(self):
+ """A dict of MetricResults defining the metrics computed by this class."""
+ raise NotImplementedError('Subclass should define its own meta spec.')
+
+ def run(
+ self,
+ inputs: Sequence[JsonDict],
+ model: lit_model.Model,
+ dataset: lit_dataset.Dataset,
+ model_outputs: Optional[list[JsonDict]] = None,
+ config: Optional[JsonDict] = None) -> list[JsonDict]:
+ raise NotImplementedError(
+ 'Subclass should implement its own run using compute.')
+
+ def run_with_metadata(
+ self,
+ indexed_inputs: Sequence[IndexedInput],
+ model: lit_model.Model,
+ dataset: lit_dataset.IndexedDataset,
+ model_outputs: Optional[list[JsonDict]] = None,
+ config: Optional[JsonDict] = None) -> list[JsonDict]:
+ inputs = [inp['data'] for inp in indexed_inputs]
+ return self.run(inputs, model, dataset, model_outputs, config)
+
+ # New methods introduced by this subclass
+
+ def is_field_compatible(
+ self,
+ pred_spec: types.LitType,
+ parent_spec: Optional[types.LitType]) -> bool:
+ """True if compatible with the prediction field and its parent."""
+ del pred_spec, parent_spec # Unused in base class
+ raise NotImplementedError('Subclass should implement field compatibility.')
+
+ def compute(
+ self,
+ labels: Sequence[Any],
+ preds: Sequence[Any],
+ label_spec: types.LitType,
+ pred_spec: types.LitType,
+ config: Optional[JsonDict] = None) -> MetricsDict:
+ """Compute metric(s) given labels and predictions."""
+ raise NotImplementedError('Subclass should implement this, or override '
+ 'compute_with_metadata() directly.')
+
+ def compute_with_metadata(
+ self,
+ labels: Sequence[Any],
+ preds: Sequence[Any],
+ label_spec: types.LitType,
+ pred_spec: types.LitType,
+ indices: Sequence[types.ExampleId],
+ metas: Sequence[JsonDict],
+ config: Optional[JsonDict] = None) -> MetricsDict:
+ """As compute(), but with access to indices and metadata."""
+ del indices, metas # unused by Metrics base class
+ return self.compute(labels, preds, label_spec, pred_spec, config)
class Annotator(metaclass=abc.ABCMeta):
@@ -177,7 +266,7 @@ def __init__(self, name: str, annotator_model: lit_model.Model):
self._annotator_model = annotator_model
@abc.abstractmethod
- def annotate(self, inputs: List[JsonDict],
+ def annotate(self, inputs: list[JsonDict],
dataset: lit_dataset.Dataset,
dataset_spec_to_annotate: Optional[types.Spec] = None):
"""Annotate the provided inputs.
diff --git a/lit_nlp/api/dataset.py b/lit_nlp/api/dataset.py
index ae446460..d3166927 100644
--- a/lit_nlp/api/dataset.py
+++ b/lit_nlp/api/dataset.py
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Base classes for LIT models."""
import glob
import inspect
import os
import random
from types import MappingProxyType # pylint: disable=g-importing-member
-from typing import cast, List, Dict, Optional, Callable, Mapping, Sequence
+from typing import cast, Optional, Callable, Mapping, Sequence
from absl import logging
@@ -47,20 +46,16 @@ def __getitem__(self, slice_obj):
class Dataset(object):
- """Base class for LIT datasets.
-
- We recommend pre-loading the data in the constructor, but you can also stream
- on the fly in Dataset.examples() if desired.
- """
+ """Base class for LIT datasets."""
_spec: Spec = {}
- _examples: List[JsonDict] = []
+ _examples: list[JsonDict] = []
_description: Optional[str] = None
_base: Optional['Dataset'] = None
def __init__(self,
spec: Optional[Spec] = None,
- examples: Optional[List[JsonDict]] = None,
+ examples: Optional[list[JsonDict]] = None,
description: Optional[str] = None,
base: Optional['Dataset'] = None):
"""Base class constructor.
@@ -83,8 +78,8 @@ def __init__(self,
# In case user child class requires the instance to convert examples
# this makes sure the user class is preserved. We cannot do this below
# as the default method is static and does not require instance.
- self.lit_example_to_bytes = self._base.lit_example_to_bytes
- self.bytes_to_lit_example = self._base.bytes_to_lit_example
+ self.bytes_from_lit_example = self._base.bytes_from_lit_example
+ self.lit_example_from_bytes = self._base.lit_example_from_bytes
# Override from direct arguments.
self._examples = examples if examples is not None else self._examples
@@ -116,7 +111,7 @@ def load(self, path: str):
return self._base.load(path)
pass
- def save(self, examples: List[IndexedInput], path: str):
+ def save(self, examples: list[IndexedInput], path: str):
"""Save newly-created datapoints to disk in a dataset-specific format.
Subclasses should override this method if they wish to save new, persisted
@@ -139,7 +134,7 @@ def spec(self) -> Spec:
return self._spec
@property
- def examples(self) -> List[JsonDict]:
+ def examples(self) -> list[JsonDict]:
"""Return examples, in format described by spec."""
return self._examples
@@ -167,24 +162,28 @@ def sample(self, n, seed=42):
examples = list(self.examples)
return Dataset(examples=examples, base=self)
+ def filter(self, predicate: Callable[[JsonDict], bool]):
+ selected_examples = list(filter(predicate, self.examples))
+ return Dataset(examples=selected_examples, base=self)
+
def shuffle(self, seed=42):
"""Return a new dataset with randomized example order."""
# random.shuffle will shuffle in-place; use sample to make a new list.
return self.sample(n=len(self), seed=seed)
- def remap(self, field_map: Dict[str, str]):
+ def remap(self, field_map: dict[str, str]):
"""Return a copy of this dataset with some fields renamed."""
new_spec = utils.remap_dict(self.spec(), field_map)
new_examples = [utils.remap_dict(ex, field_map) for ex in self.examples]
return Dataset(new_spec, new_examples, base=self)
@staticmethod
- def bytes_to_lit_example(input_bytes: bytes) -> Optional[JsonDict]:
+ def lit_example_from_bytes(input_bytes: bytes) -> Optional[JsonDict]:
"""Convert bytes representation to LIT example."""
return serialize.from_json(input_bytes.decode('utf-8'))
@staticmethod
- def lit_example_to_bytes(lit_example: JsonDict) -> bytes:
+ def bytes_from_lit_example(lit_example: JsonDict) -> bytes:
"""Convert LIT example to bytes representation."""
return serialize.to_json(lit_example).encode('utf-8')
@@ -195,19 +194,24 @@ def lit_example_to_bytes(lit_example: JsonDict) -> bytes:
class IndexedDataset(Dataset):
"""Dataset with additional indexing information."""
- _index: Dict[ExampleId, IndexedInput] = {}
+ _index: dict[ExampleId, IndexedInput] = {}
- def index_inputs(self, examples: List[types.Input]) -> List[IndexedInput]:
+ def index_inputs(self, examples: list[types.Input]) -> list[IndexedInput]:
"""Create indexed versions of inputs."""
+ # pylint: disable=g-complex-comprehension not complex, just a line-too-long
return [
- IndexedInput({'data': example, 'id': self.id_fn(example), 'meta': {}})
+ IndexedInput(
+ data=example,
+ id=self.id_fn(example),
+ meta=types.InputMetadata(added=None, parentId=None, source=None))
for example in examples
- ] # pyformat: disable
+ ]
+ # pylint: enable=g-complex-comprehension
def __init__(self,
*args,
id_fn: Optional[IdFnType] = None,
- indexed_examples: Optional[List[IndexedInput]] = None,
+ indexed_examples: Optional[list[IndexedInput]] = None,
**kw):
super().__init__(*args, **kw)
assert id_fn is not None, 'id_fn must be specified.'
@@ -219,6 +223,18 @@ def __init__(self,
self._indexed_examples = self.index_inputs(self._examples)
self._index = {ex['id']: ex for ex in self._indexed_examples}
+ @property
+ def slice(self):
+ """Syntactic sugar, allows .slice[i:j] to return a new IndexedDataset."""
+
+ def _slicer(slice_obj):
+ return IndexedDataset(
+ indexed_examples=self.indexed_examples[slice_obj],
+ id_fn=self.id_fn,
+ base=self)
+
+ return SliceWrapper(_slicer)
+
@classmethod
def index_all(cls, datasets: Mapping[str, Dataset], id_fn: IdFnType):
"""Convenience function to convert a dict of datasets."""
@@ -233,7 +249,7 @@ def index(self) -> Mapping[ExampleId, IndexedInput]:
"""Return a read-only view of the index."""
return MappingProxyType(self._index)
- def save(self, examples: List[IndexedInput], path: str):
+ def save(self, examples: list[IndexedInput], path: str):
"""Save newly-created datapoints to disk.
Args:
diff --git a/lit_nlp/api/dataset_test.py b/lit_nlp/api/dataset_test.py
index 86edc7d9..9b493a75 100644
--- a/lit_nlp/api/dataset_test.py
+++ b/lit_nlp/api/dataset_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Tests for lit_nlp.lib.model."""
from absl.testing import absltest
diff --git a/lit_nlp/api/dtypes.py b/lit_nlp/api/dtypes.py
index 3ee325a7..5a3ae45e 100644
--- a/lit_nlp/api/dtypes.py
+++ b/lit_nlp/api/dtypes.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Dataclasses for representing structured output.
Classes in this file should be used for actual input/output data,
@@ -30,13 +29,18 @@
on the frontend as corresponding JavaScript objects.
"""
import abc
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Text, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Text, Tuple, Union
import attr
JsonDict = Dict[Text, Any]
+class EnumSerializableAsValues(object):
+ """Dummy class to mark that an enum's members should be serialized as their values."""
+ pass
+
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class DataTuple(metaclass=abc.ABCMeta):
"""Simple dataclasses.
@@ -97,8 +101,8 @@ def to_json(self) -> JsonDict:
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TokenSalience(DataTuple):
"""Dataclass for a salience map over tokens."""
- tokens: List[str]
- salience: List[float] # parallel to tokens
+ tokens: Sequence[str]
+ salience: Sequence[float] # parallel to tokens
@attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -117,49 +121,17 @@ class SequenceSalienceMap(DataTuple):
salience: Sequence[Sequence[float]] # usually, a np.ndarray
-# LINT.IfChange
-# pylint: disable=invalid-name
-@attr.s(auto_attribs=True)
-class LayoutSettings(DataTuple):
- hideToolbar: bool = False
- mainHeight: int = 45
- centerPage: bool = False
-
-
-@attr.s(auto_attribs=True)
-class LitComponentLayout(DataTuple):
- """Frontend UI layout (legacy); should match client/lib/types.ts."""
- # Keys are names of tabs; one must be called "Main".
- # Values are names of LitModule HTML elements,
- # e.g. data-table-module for the DataTableModule class.
- components: Dict[str, List[str]]
- layoutSettings: LayoutSettings = attr.ib(factory=LayoutSettings)
- description: Optional[str] = None
-
- def to_json(self) -> JsonDict:
- """Override serialization to properly convert nested objects."""
- # Not invertible, but these only go from server -> frontend anyway.
- return attr.asdict(self, recurse=True)
-
-
-@attr.s(auto_attribs=True)
-class LitCanonicalLayout(DataTuple):
- """Frontend UI layout; should match client/lib/types.ts."""
- # Keys are names of tabs, and values are names of LitModule HTML elements,
- # e.g. data-table-module for the DataTableModule class.
- upper: Dict[str, List[str]]
- lower: Dict[str, List[str]] = attr.ib(factory=dict)
- layoutSettings: LayoutSettings = attr.ib(factory=LayoutSettings)
- description: Optional[str] = None
-
- def to_json(self) -> JsonDict:
- """Override serialization to properly convert nested objects."""
- # Not invertible, but these only go from server -> frontend anyway.
- return attr.asdict(self, recurse=True)
-
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RegressionResult(DataTuple):
+ """Dataclass for regression interpreter result."""
+ score: float
+ error: Optional[float]
+ squared_error: Optional[float]
-LitComponentLayouts = Mapping[str, Union[LitComponentLayout,
- LitCanonicalLayout]]
-# pylint: enable=invalid-name
-# LINT.ThenChange(../client/lib/types.ts)
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class ClassificationResult(DataTuple):
+ """Dataclass for classification interpreter result."""
+ scores: List[float]
+ predicted_class: str
+ correct: Optional[bool]
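To illustrate the two new result dataclasses (the field values below are made up):

```python
from lit_nlp.api import dtypes

# A regression prediction with its (optional) error terms.
reg = dtypes.RegressionResult(score=0.73, error=0.07, squared_error=0.0049)

# A classification prediction with per-class scores.
clf = dtypes.ClassificationResult(
    scores=[0.1, 0.9], predicted_class="positive", correct=True)
```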
diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py
new file mode 100644
index 00000000..83d4c881
--- /dev/null
+++ b/lit_nlp/api/layout.py
@@ -0,0 +1,237 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module names and type definitions for frontend UI layouts."""
+import enum
+from typing import Any, Dict, List, Mapping, Optional, Text, Union
+
+import attr
+from lit_nlp.api import dtypes
+
+JsonDict = Dict[Text, Any]
+
+
+# LINT.IfChange
+# pylint: disable=invalid-name
+@enum.unique
+class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum):
+ """List of available frontend modules.
+
+ Entries should map the TypeScript class name to the HTML element name,
+ as declared in HTMLElementTagNameMap in the .ts file defining each LitModule.
+ """
+ AnnotatedTextModule = 'annotated-text-module'
+ AnnotatedTextGoldModule = 'annotated-text-gold-module'
+ AttentionModule = 'attention-module'
+ ClassificationModule = 'classification-module'
+ ConfusionMatrixModule = 'confusion-matrix-module'
+ CounterfactualExplainerModule = 'counterfactual-explainer-module'
+ CurvesModule = 'curves-module'
+ DataTableModule = 'data-table-module'
+ SimpleDataTableModule = 'simple-data-table-module'
+ DatapointEditorModule = 'datapoint-editor-module'
+ SimpleDatapointEditorModule = 'simple-datapoint-editor-module'
+ DiveModule = 'dive-module'
+ DocumentationModule = 'documentation-module'
+ EmbeddingsModule = 'embeddings-module'
+ FeatureAttributionModule = 'feature-attribution-module'
+ GeneratedImageModule = 'generated-image-module'
+ GeneratedTextModule = 'generated-text-module'
+ GeneratorModule = 'generator-module'
+ LanguageModelPredictionModule = 'lm-prediction-module'
+ MetricsModule = 'metrics-module'
+ MultilabelModule = 'multilabel-module'
+ PdpModule = 'pdp-module'
+ RegressionModule = 'regression-module'
+ SalienceClusteringModule = 'salience-clustering-module'
+ SalienceMapModule = 'salience-map-module'
+ ScalarModule = 'scalar-module'
+ SequenceSalienceModule = 'sequence-salience-module'
+ SpanGraphGoldModule = 'span-graph-gold-module'
+ SpanGraphModule = 'span-graph-module'
+ SpanGraphGoldModuleVertical = 'span-graph-gold-module-vertical'
+ SpanGraphModuleVertical = 'span-graph-module-vertical'
+ TCAVModule = 'tcav-module'
+ TrainingDataAttributionModule = 'tda-module'
+ ThresholderModule = 'thresholder-module'
+
+ def __call__(self, **kw):
+ return ModuleConfig(self.value, **kw)
+
+
+# TODO(lit-dev): consider making modules subclass this instead of LitModuleName.
+@attr.s(auto_attribs=True)
+class ModuleConfig(dtypes.DataTuple):
+ module: Union[str, LitModuleName]
+ requiredForTab: bool = False
+ # TODO(b/172979677): support title, duplicateAsRow, numCols,
+ # and startMinimized.
+
+
+# Most users should use LitModuleName, but we allow fallback to strings
+# so that users can reference custom modules which are defined in TypeScript
+# but not included in the LitModuleName enum above.
+# If a string is used, it should be the HTML element name, like foo-bar-module.
+LitModuleList = List[Union[str, LitModuleName, ModuleConfig]]
+
+
+@attr.s(auto_attribs=True)
+class LayoutSettings(dtypes.DataTuple):
+ hideToolbar: bool = False
+ mainHeight: int = 45
+ centerPage: bool = False
+
+
+@attr.s(auto_attribs=True)
+class LitComponentLayout(dtypes.DataTuple):
+ """Frontend UI layout (legacy); should match client/lib/types.ts."""
+ # Keys are names of tabs; one must be called "Main".
+ # Values are names of LitModule HTML elements,
+ # e.g. data-table-module for the DataTableModule class.
+ components: Dict[str, LitModuleList]
+ layoutSettings: LayoutSettings = attr.ib(factory=LayoutSettings)
+ description: Optional[str] = None
+
+ def to_json(self) -> JsonDict:
+ """Override serialization to properly convert nested objects."""
+ # Not invertible, but these only go from server -> frontend anyway.
+ return attr.asdict(self, recurse=True)
+
+
+@attr.s(auto_attribs=True)
+class LitCanonicalLayout(dtypes.DataTuple):
+ """Frontend UI layout; should match client/lib/types.ts."""
+ # Keys are names of tabs, and values are names of LitModule HTML elements,
+ # e.g. data-table-module for the DataTableModule class.
+ upper: Dict[str, LitModuleList]
+ lower: Dict[str, LitModuleList] = attr.ib(factory=dict)
+ layoutSettings: LayoutSettings = attr.ib(factory=LayoutSettings)
+ description: Optional[str] = None
+
+ def to_json(self) -> JsonDict:
+ """Override serialization to properly convert nested objects."""
+ # Not invertible, but these only go from server -> frontend anyway.
+ return attr.asdict(self, recurse=True)
+
+
+LitComponentLayouts = Mapping[str, Union[LitComponentLayout,
+ LitCanonicalLayout]]
+
+# pylint: enable=invalid-name
+# LINT.ThenChange(../client/lib/types.ts)
+
+##
+# Common layout definitions.
+
+modules = LitModuleName # pylint: disable=invalid-name
+
+MODEL_PREDS_MODULES = (
+ modules.SpanGraphGoldModuleVertical,
+ modules.SpanGraphModuleVertical,
+ modules.ClassificationModule,
+ modules.MultilabelModule,
+ modules.RegressionModule,
+ modules.LanguageModelPredictionModule,
+ modules.GeneratedTextModule,
+ modules.AnnotatedTextGoldModule,
+ modules.AnnotatedTextModule,
+ modules.GeneratedImageModule,
+)
+
+DEFAULT_MAIN_GROUP = (
+ modules.DataTableModule,
+ modules.DatapointEditorModule,
+)
+
+##
+# A "simple demo server" layout.
+SIMPLE_LAYOUT = LitCanonicalLayout(
+ upper={
+ 'Editor': [
+ modules.DocumentationModule,
+ modules.SimpleDatapointEditorModule,
+ ],
+ 'Examples': [modules.SimpleDataTableModule],
+ },
+ lower={
+ 'Predictions': list(MODEL_PREDS_MODULES),
+ 'Salience': [
+ *MODEL_PREDS_MODULES,
+ modules.SalienceMapModule(requiredForTab=True),
+ ],
+ 'Sequence Salience': [
+ *MODEL_PREDS_MODULES,
+ modules.SequenceSalienceModule(requiredForTab=True),
+ ],
+ 'Influence': [modules.TrainingDataAttributionModule],
+ },
+ layoutSettings=LayoutSettings(
+ hideToolbar=True,
+ mainHeight=30,
+ centerPage=True,
+ ),
+ description=(
+ 'A basic layout just containing a datapoint creator/editor, the '
+ 'predictions, and the data table. There are also some visual '
+ 'simplifications: the toolbar is hidden, and the modules are centered '
+ 'on the page rather than being full width.'),
+)
+
+##
+# A "kitchen sink" layout with maximum functionality.
+STANDARD_LAYOUT = LitCanonicalLayout(
+ upper={
+ 'Main': [
+ modules.DocumentationModule,
+ modules.EmbeddingsModule,
+ *DEFAULT_MAIN_GROUP,
+ ]
+ },
+ lower={
+ 'Predictions': [
+ *MODEL_PREDS_MODULES,
+ modules.ScalarModule,
+ modules.PdpModule,
+ ],
+ 'Explanations': [
+ *MODEL_PREDS_MODULES,
+ modules.SalienceMapModule,
+ modules.SequenceSalienceModule,
+ modules.AttentionModule,
+ modules.FeatureAttributionModule,
+ ],
+ 'Salience Clustering': [modules.SalienceClusteringModule],
+ 'Metrics': [
+ modules.MetricsModule,
+ modules.ConfusionMatrixModule,
+ modules.CurvesModule,
+ modules.ThresholderModule,
+ ],
+ 'Influence': [modules.TrainingDataAttributionModule],
+ 'Counterfactuals': [
+ modules.GeneratorModule,
+ modules.CounterfactualExplainerModule,
+ ],
+ 'TCAV': [modules.TCAVModule],
+ },
+ description=(
+ 'The default LIT layout, which includes the data table and data point '
+ 'editor, the performance and metrics, predictions, explanations, and '
+ 'counterfactuals.'),
+)
+
+DEFAULT_LAYOUTS = {
+ 'simple': SIMPLE_LAYOUT,
+ 'default': STANDARD_LAYOUT,
+}
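To show how these pieces compose, here is a sketch of a custom layout built from the same primitives; the tab names and module selection are arbitrary, and registering it under a new key (as `DEFAULT_LAYOUTS` does above) is an assumption about how a server would consume it:

```python
from lit_nlp.api import layout

modules = layout.LitModuleName  # pylint: disable=invalid-name

CLASSIFICATION_LAYOUT = layout.LitCanonicalLayout(
    upper={
        'Main': [
            modules.DocumentationModule,
            modules.DataTableModule,
            modules.DatapointEditorModule,
        ],
    },
    lower={
        'Predictions': [modules.ClassificationModule],
        'Salience': [modules.SalienceMapModule(requiredForTab=True)],
        'Metrics': [modules.MetricsModule, modules.ConfusionMatrixModule],
    },
    description='Hypothetical classification-focused layout (example only).',
)
```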
diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py
index 0526d754..248b8f6f 100644
--- a/lit_nlp/api/model.py
+++ b/lit_nlp/api/model.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Base classes for LIT models."""
import abc
import inspect
@@ -107,6 +106,18 @@ def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1
+ @property
+ def supports_concurrent_predictions(self):
+    """Indicates support for multiple concurrent predict calls across threads.
+
+    Defaults to False.
+
+ Returns:
+ (bool) True if the model can handle multiple concurrent calls to its
+ `predict_minibatch` method.
+ """
+ return False
+
@abc.abstractmethod
def predict_minibatch(self, inputs: List[JsonDict]) -> List[JsonDict]:
"""Run prediction on a batch of inputs.
@@ -240,11 +251,16 @@ def description(self) -> str:
def max_minibatch_size(self) -> int:
return self.wrapped.max_minibatch_size()
+ @property
+ def supports_concurrent_predictions(self):
+ return self.wrapped.supports_concurrent_predictions
+
def predict_minibatch(self, inputs: List[JsonDict], **kw) -> List[JsonDict]:
return self.wrapped.predict_minibatch(inputs, **kw)
- def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterator[JsonDict]:
- return self.wrapped.predict(inputs, **kw)
+ def predict(self, inputs: Iterable[JsonDict], *args,
+ **kw) -> Iterator[JsonDict]:
+ return self.wrapped.predict(inputs, *args, **kw)
# NOTE: if a subclass modifies predict(), it should also override this to
# call the custom predict() method - otherwise this will delegate to the
@@ -297,7 +313,8 @@ def __init__(self,
self._max_qps = max_qps
self._pool = multiprocessing.pool.ThreadPool(max_concurrent_requests)
- def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterator[JsonDict]:
+ def predict(self, inputs: Iterable[JsonDict], *unused_args,
+ **unused_kwargs) -> Iterator[JsonDict]:
batches = utils.batch_iterator(
inputs, max_batch_size=self.max_minibatch_size())
batches = utils.rate_limit(batches, self._max_qps)
@@ -308,6 +325,11 @@ def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model. Subclass can override this."""
return 1
+ @property
+ def supports_concurrent_predictions(self):
+ """Remote models can handle concurrent predictions by default."""
+ return True
+
@abc.abstractmethod
def predict_minibatch(self, inputs: List[JsonDict]) -> List[JsonDict]:
"""Run prediction on a batch of inputs.
diff --git a/lit_nlp/api/model_test.py b/lit_nlp/api/model_test.py
index b39c1d93..4eae04c0 100644
--- a/lit_nlp/api/model_test.py
+++ b/lit_nlp/api/model_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Tests for lit_nlp.lib.model."""
from absl.testing import absltest
diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py
index 4aaa5ef9..d8a349bb 100644
--- a/lit_nlp/api/types.py
+++ b/lit_nlp/api/types.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""Type classes for LIT inputs and outputs.
These are simple dataclasses used in model.input_spec() and model.output_spec()
@@ -26,15 +25,35 @@
should be rendered.
"""
import abc
-from typing import Any, Dict, List, NewType, Optional, Sequence, Text, Tuple, Union
+import enum
+import math
+import numbers
+from typing import Any, NewType, Optional, Sequence, Type, TypedDict, Union
import attr
+from lit_nlp.api import dtypes
+import numpy as np
-JsonDict = Dict[Text, Any]
-Input = JsonDict # TODO(lit-dev): stronger typing using NewType
-IndexedInput = NewType("IndexedInput", JsonDict) # has keys: id, data, meta
-ExampleId = Text
-TokenTopKPredsList = List[List[Tuple[str, float]]]
+JsonDict = dict[str, Any]
+Input = NewType("Input", JsonDict)
+ExampleId = NewType("ExampleId", str)
+ScoredTextCandidates = Sequence[tuple[str, Optional[float]]]
+TokenTopKPredsList = Sequence[ScoredTextCandidates]
+NumericTypes = numbers.Number
+
+
+class InputMetadata(TypedDict):
+ added: Optional[bool]
+ # pylint: disable=invalid-name
+ parentId: Optional[ExampleId] # Named to match TypeScript data structure
+ # pylint: enable=invalid-name
+ source: Optional[str]
+
+
+class IndexedInput(TypedDict):
+ data: Input
+ id: ExampleId
+ meta: InputMetadata
##
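Concretely, the new TypedDicts describe payloads like the following (the id and metadata values are made up; TypedDicts are not enforced at runtime):

```python
from lit_nlp.api import types

example = types.Input({"text": "hello world", "label": "0"})
indexed: types.IndexedInput = {
    "data": example,
    "id": types.ExampleId("d41d8cd98f"),
    "meta": {"added": True, "parentId": None, "source": "manual"},
}
```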
@@ -44,9 +63,50 @@ class LitType(metaclass=abc.ABCMeta):
"""Base class for LIT Types."""
required: bool = True # for input fields, mark if required by the model.
annotated: bool = False # If this type is created from an Annotator.
+  show_in_data_table = True  # If true, show this info in the data table.
# TODO(lit-dev): Add defaults for all LitTypes
default = None # an optional default value for a given type.
+ def validate_input(self, value: Any, spec: "Spec", example: Input):
+    """Validate a dataset example's value against its spec.
+
+ Subtypes should override to validate a provided value and raise a ValueError
+ if the value is not valid.
+
+ Args:
+ value: The value to validate against the specific LitType.
+ spec: The spec of the dataset.
+      example: The entire example of which the value is a part.
+
+ Raises:
+ ValueError if validation fails.
+ """
+ pass
+
+ def validate_output(self, value: Any, output_spec: "Spec",
+ output_dict: JsonDict, input_spec: "Spec",
+ dataset_spec: "Spec", input_example: Input):
+ """Validate a model output value against its spec and input example.
+
+ Subtypes should override to validate a provided value and raise a ValueError
+ if the value is not valid.
+
+ Args:
+ value: The value to validate against the specific LitType.
+ output_spec: The output spec of the model.
+ output_dict: The entire model output for the example.
+ input_spec: The input spec of the model.
+ dataset_spec: The dataset spec.
+ input_example: The example from which the output value is returned.
+
+ Raises:
+ ValueError if validation fails.
+ """
+ del output_spec, output_dict, dataset_spec
+    # If not overridden by a subclass, validate the value as an input to
+    # reuse the simple validation code.
+ self.validate_input(value, input_spec, input_example)
+
def is_compatible(self, other):
"""Check equality, ignoring some fields."""
# We allow this class to be a subclass of the other.
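For instance, a user-defined type could hook into this validation API as follows (the `EmailAddress` type is invented for illustration):

```python
import attr
from lit_nlp.api import types

@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class EmailAddress(types.LitType):
  """Hypothetical type: a string that must look like an email address."""
  default: str = ""

  def validate_input(self, value, spec: types.Spec, example: types.Input):
    # Raise ValueError on bad values; validate_output falls back to this.
    if not isinstance(value, str) or "@" not in value:
      raise ValueError(f"{value} is not an email address")
```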
@@ -61,29 +121,48 @@ def is_compatible(self, other):
def to_json(self) -> JsonDict:
"""Used by serialize.py."""
d = attr.asdict(self)
- d["__class__"] = "LitType"
d["__name__"] = self.__class__.__name__
- # All parent classes, from method resolution order (mro).
- # Use this to check inheritance on the frontend.
- d["__mro__"] = [a.__name__ for a in self.__class__.__mro__]
return d
@staticmethod
def from_json(d: JsonDict):
- """Used by serialize.py."""
- cls = globals()[d.pop("__name__")] # class by name from this module
- del d["__mro__"]
- return cls(**d)
+ """Used by serialize.py.
+
+ Args:
+ d: The JSON Object-like dictionary to attempt to parse.
+ Returns:
+ An instance of a LitType subclass defined by the contents of `d`.
-Spec = Dict[Text, LitType]
+ Raises:
+ KeyError: If `d` does not have a `__name__` property.
+ NameError: If `d["__name__"]` is not a `LitType` subclass.
+ TypeError: If `d["__name__"]` is not a string.
+ """
+ try:
+ type_name = d.pop("__name__")
+ except KeyError as e:
+ raise KeyError("A __name__ property is required to parse a LitType from "
+ "JSON.") from e
+
+ if not isinstance(type_name, str):
+ raise TypeError("The value of __name__ must be a string.")
+
+ base_cls = globals().get("LitType")
+ cls = globals().get(type_name) # class by name from this module
+ if cls is None or not issubclass(cls, base_cls):
+ raise NameError(f"{type_name} is not a valid LitType.")
+
+ return cls(**d)
+
+Spec = dict[str, LitType]
# Attributes that should be treated as a reference to other fields.
FIELD_REF_ATTRIBUTES = frozenset(
{"parent", "align", "align_in", "align_out", "grad_for"})
-def _remap_leaf(leaf: LitType, keymap: Dict[str, str]) -> LitType:
+def _remap_leaf(leaf: LitType, keymap: dict[str, str]) -> LitType:
"""Remap any field references on a LitType."""
d = attr.asdict(leaf) # mutable
d = {
@@ -93,7 +172,7 @@ def _remap_leaf(leaf: LitType, keymap: Dict[str, str]) -> LitType:
return leaf.__class__(**d)
-def remap_spec(spec: Spec, keymap: Dict[str, str]) -> Spec:
+def remap_spec(spec: Spec, keymap: dict[str, str]) -> Spec:
"""Rename fields in a spec, with a best-effort to also remap field references."""
ret = {}
for k, v in spec.items():
@@ -109,7 +188,7 @@ def remap_spec(spec: Spec, keymap: Dict[str, str]) -> Spec:
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class String(LitType):
+class StringLitType(LitType):
"""User-editable text input.
All automated edits are disabled for this type.
@@ -117,54 +196,138 @@ class String(LitType):
Mainly used for string inputs that have special formatting, and should only
be edited manually.
"""
- default: Text = ""
+ default: str = ""
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, str):
+ raise ValueError(f"{value} is of type {type(value)}, expected str")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class TextSegment(LitType):
+class TextSegment(StringLitType):
"""Text input (untokenized), a single string."""
- default: Text = ""
+ pass
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class ImageBytes(LitType):
"""An image, an encoded base64 ascii string (starts with 'data:image...')."""
- pass
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, str) or not value.startswith("data:image"):
+ raise ValueError(f"{value} is not an encoded image string.")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class JPEGBytes(ImageBytes):
+ """As ImageBytes, but assumed to be in jpeg format."""
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, str) or not value.startswith("data:image/jpg"):
+ raise ValueError(f"{value} is not an encoded JPEG image string.")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class PNGBytes(ImageBytes):
+ """As ImageBytes, but assumed to be in png format."""
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, str) or not value.startswith("data:image/png"):
+ raise ValueError(f"{value} is not an encoded PNG image string.")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class GeneratedText(TextSegment):
"""Generated (untokenized) text."""
# Name of a TextSegment field to evaluate against
- parent: Optional[Text] = None
+ parent: Optional[str] = None
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, str):
+ raise ValueError(f"{value} is of type {type(value)}, expected str")
+ if self.parent and not isinstance(input_spec[self.parent], TextSegment):
+      raise ValueError(
+          f"parent field {self.parent} is of type "
+          f"{type(input_spec[self.parent])}, expected TextSegment")
-ScoredTextCandidates = List[Tuple[str, Optional[float]]]
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class ListLitType(LitType):
+ """List type."""
+ default: Sequence[Any] = None
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class GeneratedTextCandidates(TextSegment):
- """Multiple candidates for GeneratedText; values are List[(text, score)]."""
+class _StringCandidateList(ListLitType):
+ """A list of (text, score) tuples."""
+ default: ScoredTextCandidates = None
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, list):
+ raise ValueError(f"{value} is not a list")
+
+ for v in value:
+ if not (isinstance(v, tuple) and isinstance(v[0], str) and
+ (v[1] is None or isinstance(v[1], NumericTypes))):
+        raise ValueError(f"{v} list item is not a (str, float) tuple")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class GeneratedTextCandidates(_StringCandidateList):
+ """Multiple candidates for GeneratedText."""
# Name of a TextSegment field to evaluate against
- parent: Optional[Text] = None
+ parent: Optional[str] = None
@staticmethod
def top_text(value: ScoredTextCandidates) -> str:
return value[0][0] if len(value) else ""
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ if self.parent and not isinstance(input_spec[self.parent], TextSegment):
+ raise ValueError(f"parent field {self.parent} is of type "
+ f"{type(input_spec[self.parent])}, expected TextSegment")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class ReferenceTexts(_StringCandidateList):
+ """Multiple candidates for TextSegment."""
+ pass
+
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class ReferenceTexts(LitType):
- """Multiple candidates for TextSegment; values are List[(text, score)]."""
+class TopTokens(_StringCandidateList):
+ """Multiple tokens with weight."""
pass
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class URL(TextSegment):
+class URLLitType(TextSegment):
"""TextSegment that should be interpreted as a URL."""
pass
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class GeneratedURL(TextSegment):
+ """A URL that was generated as part of a model prediction."""
+ align: Optional[str] = None # name of a field in the model output
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(value, output_spec, output_dict, input_spec,
+ dataset_spec, input_example)
+ if self.align and self.align not in output_spec:
+ raise ValueError(f"aligned field {self.align} is not in output_spec")
+
+
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class SearchQuery(TextSegment):
"""TextSegment that should be interpreted as a search query."""
@@ -172,24 +335,67 @@ class SearchQuery(TextSegment):
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class Tokens(LitType):
- """Tokenized text, as List[str]."""
- default: List[Text] = attr.Factory(list)
+class _StringList(ListLitType):
+ """A list of strings."""
+ default: Sequence[str] = []
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, list) or not all(
+ [isinstance(v, str) for v in value]):
+ raise ValueError(f"{value} is not a list of strings")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class Tokens(_StringList):
+ """Tokenized text."""
+ default: Sequence[str] = attr.Factory(list)
# Name of a TextSegment field from the input
# TODO(b/167617375): should we use 'align' here?
- parent: Optional[Text] = None
- mask_token: Optional[Text] = None # optional mask token for input
+ parent: Optional[str] = None
+ mask_token: Optional[str] = None # optional mask token for input
+ token_prefix: Optional[str] = "##" # optional prefix used in tokens
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class TokenTopKPreds(LitType):
+class TokenTopKPreds(ListLitType):
"""Predicted tokens, as from a language model.
- Data should be a List[List[Tuple[str, float]]], where the inner list contains
- (word, probability) in descending order.
+ The inner list should contain (word, probability) in descending order.
"""
- align: Text = None # name of a Tokens field in the model output
- parent: Optional[Text] = None
+ default: Sequence[ScoredTextCandidates] = None
+
+ align: str = None # name of a Tokens field in the model output
+ parent: Optional[str] = None
+
+ def _validate_scored_candidates(self, scored_candidates):
+ """Validates a list of scored candidates."""
+ prev_val = math.inf
+ for scored_candidate in scored_candidates:
+ if not isinstance(scored_candidate, tuple):
+ raise ValueError(f"{scored_candidate} is not a tuple")
+ if not isinstance(scored_candidate[0], str):
+ raise ValueError(f"{scored_candidate} first element is not a str")
+ if scored_candidate[1] is not None:
+ if not isinstance(scored_candidate[1], NumericTypes):
+ raise ValueError(f"{scored_candidate} second element is not a num")
+ if prev_val < scored_candidate[1]:
+ raise ValueError(
+ "TokenTopKPreds candidates are not in descending order")
+ else:
+ prev_val = scored_candidate[1]
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+
+ if not isinstance(value, list):
+ raise ValueError(f"{value} is not a list of scored text candidates")
+ for scored_candidates in value:
+ self._validate_scored_candidates(scored_candidates)
+ if self.align and not isinstance(output_spec[self.align], Tokens):
+ raise ValueError(
+ f"aligned field {self.align} is {type(output_spec[self.align])}, "
+ "expected Tokens")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
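To make the descending-order requirement concrete (minimal specs; the failing call is shown commented out):

```python
from lit_nlp.api import types

preds = types.TokenTopKPreds(align="tokens")
out_spec = {"tokens": types.Tokens(), "preds": preds}
example = types.Input({"text": "hi"})

# OK: each inner list is (token, probability), highest score first.
preds.validate_output([[("one", .9), ("two", .4)]],
                      out_spec, {}, {}, {}, example)

# Raises ValueError: scores are ascending.
# preds.validate_output([[("one", .2), ("two", .4)]],
#                       out_spec, {}, {}, {}, example)
```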
@@ -200,67 +406,168 @@ class Scalar(LitType):
default: float = 0
step: float = .01
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, NumericTypes):
+ raise ValueError(f"{value} is of type {type(value)}, expected a number")
+
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class RegressionScore(Scalar):
"""Regression score, a single float."""
# name of a Scalar or RegressionScore field in input
- parent: Optional[Text] = None
+ parent: Optional[str] = None
-
-@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class ReferenceScores(LitType):
- """Score of one or more target sequences, as List[float]."""
- # name of a TextSegment or ReferenceTexts field in the input
- parent: Optional[Text] = None
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, NumericTypes):
+ raise ValueError(f"{value} is of type {type(value)}, expected a number")
+ if self.parent and not isinstance(dataset_spec[self.parent], Scalar):
+      raise ValueError(
+          f"parent field {self.parent} is of type "
+          f"{type(dataset_spec[self.parent])}, expected Scalar")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class CategoryLabel(LitType):
+class ReferenceScores(ListLitType):
+ """Score of one or more target sequences."""
+ default: Sequence[float] = None
+
+ # name of a TextSegment or ReferenceTexts field in the input
+ parent: Optional[str] = None
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if isinstance(value, list):
+ if not all([isinstance(v, NumericTypes) for v in value]):
+ raise ValueError(f"{value} is of type {type(value)}, expected a list "
+ "of numbers")
+ elif not isinstance(value, np.ndarray) or not np.issubdtype(
+ value.dtype, np.number):
+ raise ValueError(f"{value} is of type {type(value)}, expected a list of "
+ "numbers")
+ if self.parent and not isinstance(
+ input_spec[self.parent], (TextSegment, ReferenceTexts)):
+      raise ValueError(
+          f"parent field {self.parent} is of type "
+          f"{type(input_spec[self.parent])}, expected TextSegment or "
+          "ReferenceTexts")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class CategoryLabel(StringLitType):
"""Category or class label, a single string."""
# Optional vocabulary to specify allowed values.
# If omitted, any value is accepted.
- vocab: Optional[Sequence[Text]] = None # label names
-
-
-@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class MulticlassPreds(LitType):
+ vocab: Optional[Sequence[str]] = None # label names
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if not isinstance(value, str):
+ raise ValueError(f"{value} is of type {type(value)}, expected str")
+ if self.vocab and value not in list(self.vocab):
+ raise ValueError(f"{value} is not in provided vocab")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class _Tensor(LitType):
+ """A tensor type."""
+ default: Sequence[float] = None
+
+ def validate_input(self, value, spec: Spec, example: Input):
+ if isinstance(value, list):
+ if not all([isinstance(v, NumericTypes) for v in value]):
+ raise ValueError(f"{value} is not a list of numbers")
+ elif isinstance(value, np.ndarray):
+ if not np.issubdtype(value.dtype, np.number):
+ raise ValueError(f"{value} is not an array of numbers")
+ else:
+ raise ValueError(f"{value} is not a list or ndarray of numbers")
+
+ def validate_ndim(self, value, ndim: Union[int, list[int]]):
+ """Validate the number of dimensions in a tensor.
+
+ Args:
+ value: The tensor to validate.
+      ndim: Either the required number of dimensions, or a list of
+        dimensionalities, any one of which the value may have.
+
+ Raises:
+ ValueError if validation fails.
+ """
+ if isinstance(ndim, int):
+ ndim = [ndim]
+ if isinstance(value, np.ndarray):
+ if value.ndim not in ndim:
+ raise ValueError(f"{value} ndim is not one of {ndim}")
+ else:
+ if 1 not in ndim:
+ raise ValueError(f"{value} ndim is not 1. "
+ "Use a numpy array for multidimensional arrays")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class MulticlassPreds(_Tensor):
"""Multiclass predicted probabilities, as [num_labels]."""
# Vocabulary is required here for decoding model output.
# Usually this will match the vocabulary in the corresponding label field.
- vocab: Sequence[Text] # label names
+ vocab: Sequence[str] # label names
null_idx: Optional[int] = None # vocab index of negative (null) label
- parent: Optional[Text] = None # CategoryLabel field in input
+ parent: Optional[str] = None # CategoryLabel field in input
autosort: Optional[bool] = False # Enable automatic sorting
+ threshold: Optional[float] = None # binary threshold, used to compute margin
@property
def num_labels(self):
return len(self.vocab)
+ def validate_input(self, value, spec: Spec, example: Input):
+ super().validate_input(value, spec, example)
+ if self.null_idx is not None:
+ if self.null_idx < 0 or self.null_idx >= self.num_labels:
+ raise ValueError(f"null_idx {self.null_idx} is not in the vocab range")
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ self.validate_input(value, output_spec, input_example)
+ if self.parent and not isinstance(
+ dataset_spec[self.parent], CategoryLabel):
+      raise ValueError(
+          f"parent field {self.parent} is of type "
+          f"{type(dataset_spec[self.parent])}, expected CategoryLabel")
+
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SequenceTags(LitType):
+class SequenceTags(_StringList):
"""Sequence tags, aligned to tokens.
The data should be a list of string labels, one for each token.
"""
- align: Text # name of Tokens field
+ align: str # name of Tokens field
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SpanLabels(LitType):
- """Span labels, a List[dtypes.SpanLabel] aligned to tokens.
+class SpanLabels(ListLitType):
+ """Span labels aligned to tokens.
Span labels can cover more than one token, may not cover all tokens in the
sentence, and may overlap with each other.
"""
- align: Text # name of Tokens field
- parent: Optional[Text] = None
+ default: Sequence[dtypes.SpanLabel] = None
+ align: str # name of Tokens field
+ parent: Optional[str] = None
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, list) or not all(
+ [isinstance(v, dtypes.SpanLabel) for v in value]):
+ raise ValueError(f"{value} is not a list of SpanLabels")
+ if not isinstance(output_spec[self.align], Tokens):
+ raise ValueError(f"{self.align} is not a Tokens field")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class EdgeLabels(LitType):
- """Edge labels, a List[dtypes.EdgeLabel] between pairs of spans.
+class EdgeLabels(ListLitType):
+ """Edge labels between pairs of spans.
This is a general form for structured prediction output; each entry consists
of (span1, span2, label). See
@@ -268,12 +575,22 @@ class EdgeLabels(LitType):
https://github.com/nyu-mll/jiant/tree/master/probing#data-format for more
details.
"""
- align: Text # name of Tokens field
+ default: Sequence[dtypes.EdgeLabel] = None
+ align: str # name of Tokens field
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, list) or not all(
+ [isinstance(v, dtypes.EdgeLabel) for v in value]):
+ raise ValueError(f"{value} is not a list of EdgeLabel")
+ if not isinstance(output_spec[self.align], Tokens):
+ raise ValueError(f"{self.align} is not a Tokens field")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class MultiSegmentAnnotations(LitType):
- """Very general type for in-line text annotations, as List[AnnotationCluster].
+class MultiSegmentAnnotations(ListLitType):
+ """Very general type for in-line text annotations.
This is a more general version of SpanLabel, EdgeLabel, and other annotation
types, designed to represent annotations that may span multiple segments.
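The shared `_Tensor.validate_ndim` helper accepts a single dimensionality or a list of allowed ones; a brief sketch:

```python
import numpy as np
from lit_nlp.api import types

emb = types.Embeddings()
emb.validate_ndim([0.1, 0.2, 0.3], 1)           # plain lists are treated as 1-D
emb.validate_ndim(np.zeros((4, 8)), 2)          # e.g. [num_tokens, emb_dim]
emb.validate_ndim(np.zeros((2, 3, 3)), [2, 3])  # any listed ndim is accepted
# emb.validate_ndim([1, 2], 2) would raise: only ndarrays can be multi-dim.
```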
@@ -285,91 +602,180 @@ class MultiSegmentAnnotations(LitType):
TODO(lit-dev): by default, spans are treated as bytes in this context.
Make this configurable, if some spans need to refer to tokens instead.
"""
+ default: Sequence[dtypes.AnnotationCluster] = None
exclusive: bool = False # if true, treat as candidate list
background: bool = False # if true, don't emphasize in visualization
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ if not isinstance(value, list) or not all(
+ [isinstance(v, dtypes.AnnotationCluster) for v in value]):
+ raise ValueError(f"{value} is not a list of AnnotationCluster")
##
# Model internals, for interpretation.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class Embeddings(LitType):
+class Embeddings(_Tensor):
"""Embeddings or model activations, as fixed-length [emb_dim]."""
- pass
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(value, output_spec, output_dict, input_spec,
+ dataset_spec, input_example)
+ self.validate_ndim(value, 1)
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class Gradients(LitType):
- """Gradients with respect to embeddings."""
- grad_for: Optional[Text] = None # name of Embeddings field
+class _GradientsBase(_Tensor):
+ """Shared gradient attributes."""
+ align: Optional[str] = None # name of a Tokens field
+ grad_for: Optional[str] = None # name of Embeddings field
# Name of the field in the input that can be used to specify the target class
# for the gradients.
- grad_target_field_key: Optional[Text] = None
+ grad_target_field_key: Optional[str] = None
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ if self.align is not None:
+ align_entry = (output_spec[self.align] if self.align in output_spec
+ else input_spec[self.align])
+ if not isinstance(align_entry, (Tokens, ImageBytes)):
+ raise ValueError(f"{self.align} is not a Tokens or ImageBytes field")
+ if self.grad_for is not None and not isinstance(
+ output_spec[self.grad_for], (Embeddings, TokenEmbeddings)):
+      raise ValueError(f"{self.grad_for} is not an Embeddings field")
+ if (self.grad_target_field_key is not None and
+ self.grad_target_field_key not in input_spec):
+ raise ValueError(f"{self.grad_target_field_key} is not in input_spec")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class Gradients(_GradientsBase):
+ """1D gradients with respect to embeddings."""
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, 1)
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class TokenEmbeddings(LitType):
+class _InfluenceEncodings(_Tensor):
+ """A single vector of [enc_dim]."""
+ grad_target: Optional[str] = None # class for computing gradients (string)
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, 1)
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class TokenEmbeddings(_Tensor):
"""Per-token embeddings, as [num_tokens, emb_dim]."""
- align: Optional[Text] = None # name of a Tokens field
+ align: Optional[str] = None # name of a Tokens field
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, 2)
+ if self.align is not None and not isinstance(
+ output_spec[self.align], Tokens):
+ raise ValueError(f"{self.align} is not a Tokens field")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class TokenGradients(LitType):
- """Gradients with respect to per-token inputs, as [num_tokens, emb_dim]."""
- align: Optional[Text] = None # name of a Tokens field
- grad_for: Optional[Text] = None # name of TokenEmbeddings field
- # Name of the field in the input that can be used to specify the target class
- # for the gradients.
- grad_target_field_key: Optional[Text] = None
+class TokenGradients(_GradientsBase):
+ """Gradients for per-token inputs, as [num_tokens, emb_dim]."""
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, 2)
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class ImageGradients(LitType):
+class ImageGradients(_GradientsBase):
"""Gradients with respect to per-pixel inputs, as a multidimensional array."""
- # Name of the field in the input for which the gradients are computed.
- align: Optional[Text] = None
- # Name of the field in the input that can be used to specify the target class
- # for the gradients.
- grad_target_field_key: Optional[Text] = None
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, [2, 3])
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class AttentionHeads(LitType):
+class AttentionHeads(_Tensor):
"""One or more attention heads, as [num_heads, num_tokens, num_tokens]."""
# input and output Tokens fields; for self-attention these can be the same
- align_in: Text
- align_out: Text
+ align_in: str
+ align_out: str
+
+ def validate_output(self, value, output_spec: Spec, output_dict: JsonDict,
+ input_spec: Spec, dataset_spec: Spec,
+ input_example: Input):
+ super().validate_output(
+ value, output_spec, output_dict, input_spec, dataset_spec,
+ input_example)
+ self.validate_ndim(value, 3)
+ if self.align_in is None or not isinstance(
+ output_spec[self.align_in], Tokens):
+ raise ValueError(f"{self.align_in} is not a Tokens field")
+ if self.align_out is None or not isinstance(
+ output_spec[self.align_out], Tokens):
+ raise ValueError(f"{self.align_out} is not a Tokens field")
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SubwordOffsets(LitType):
- """Offsets to align input tokens to wordpieces or characters, as List[int].
+class SubwordOffsets(ListLitType):
+ """Offsets to align input tokens to wordpieces or characters.
offsets[i] should be the index of the first wordpiece for input token i.
"""
- align_in: Text # name of field in data spec
- align_out: Text # name of field in model output spec
+ default: Sequence[int] = None
+ align_in: str # name of field in data spec
+ align_out: str # name of field in model output spec
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SparseMultilabel(LitType):
- """Sparse multi-label represented as a list of strings, as List[str]."""
- vocab: Optional[Sequence[Text]] = None # label names
- default: Sequence[Text] = []
- # TODO(b/162269499) Migrate non-comma separators to custom type.
- separator: Text = "," # Used for display purposes.
+class SparseMultilabel(_StringList):
+ """Sparse multi-label represented as a list of strings."""
+ vocab: Optional[Sequence[str]] = None # label names
+ separator: str = "," # Used for display purposes.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SparseMultilabelPreds(LitType):
+class SparseMultilabelPreds(_StringCandidateList):
"""Sparse multi-label predictions represented as a list of tuples.
- The tuples are of the label and the score. So as a List[(str, float)].
+ The tuples are of the label and the score.
"""
- vocab: Optional[Sequence[Text]] = None # label names
- parent: Optional[Text] = None
- default: Sequence[Text] = []
+ default: ScoredTextCandidates = None
+ vocab: Optional[Sequence[str]] = None # label names
+ parent: Optional[str] = None
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
@@ -377,64 +783,128 @@ class FieldMatcher(LitType):
"""For matching spec fields.
The front-end will perform spec matching and fill in the vocab field
- accordingly. UI will materialize this to a dropdown-list.
- Use MultiFieldMatcher when your intent is selecting more than one field in UI.
+ accordingly.
"""
- spec: Text # which spec to check, 'dataset', 'input', or 'output'.
- types: Union[Text, Sequence[Text]] # types of LitType to match in the spec.
- vocab: Optional[Sequence[Text]] = None # names matched from the spec.
+ spec: str # which spec to check, 'dataset', 'input', or 'output'.
+ types: Union[str, Sequence[str]] # types of LitType to match in the spec.
+ vocab: Optional[Sequence[str]] = None # names matched from the spec.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class MultiFieldMatcher(LitType):
- """For matching spec fields.
+class SingleFieldMatcher(FieldMatcher):
+ """For matching a single spec field.
- The front-end will perform spec matching and fill in the vocab field
- accordingly. UI will materialize this to multiple checkboxes. Use this when
- the user needs to pick more than one field in UI.
+  UI will materialize this as a dropdown list.
"""
- spec: Text # which spec to check, 'dataset', 'input', or 'output'.
- types: Union[Text, Sequence[Text]] # types of LitType to match in the spec.
- vocab: Optional[Sequence[Text]] = None # names matched from the spec.
- default: Sequence[Text] = [] # default names of selected items.
+ default: str = None
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class MultiFieldMatcher(FieldMatcher):
+ """For matching multiple spec fields.
+
+  UI will materialize this as multiple checkboxes. Use this when the user
+  needs to pick more than one field in the UI.
+ """
+ default: Sequence[str] = [] # default names of selected items.
  select_all: bool = False  # Select all by default (overrides default).
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class TokenSalience(LitType):
- """Metadata about a returned token salience map, returned as dtypes.TokenSalience."""
+class Salience(LitType):
+ """Metadata about a returned salience map."""
autorun: bool = False # If the saliency technique is automatically run.
signed: bool # If the returned values are signed.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class FeatureSalience(LitType):
- """Metadata about a returned feature salience map, returned as dtypes.FeatureSalience."""
- autorun: bool = True # If the saliency technique is automatically run.
- signed: bool # If the returned values are signed.
+class TokenSalience(Salience):
+ """Metadata about a returned token salience map."""
+ default: dtypes.TokenSalience = None
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class ImageSalience(LitType):
+class FeatureSalience(Salience):
+ """Metadata about a returned feature salience map."""
+ default: dtypes.FeatureSalience = None
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class ImageSalience(Salience):
"""Metadata about a returned image saliency.
The data is returned as an image in the base64 URL encoded format, e.g.,
data:image/jpg;base64,w4J3k1Bfa...
"""
- autorun: bool = False # If the saliency technique is automatically run.
+ signed: bool = False # If the returned values are signed.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class SequenceSalience(LitType):
- """Metadata about a returned sequence salience map, returned as dtypes.SequenceSalienceMap."""
- autorun: bool = False # If the saliency technique is automatically run.
- signed: bool # If the returned values are signed.
+class SequenceSalience(Salience):
+ """Metadata about a returned sequence salience map."""
+ default: dtypes.SequenceSalienceMap = None
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class Boolean(LitType):
+class BooleanLitType(LitType):
"""Boolean value."""
default: bool = False
+ def validate_input(self, value, spec, example: Input):
+ if not isinstance(value, bool):
+ raise ValueError(f"{value} is not a boolean")
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class CurveDataPoints(LitType):
+ """Represents data points of a curve.
+
+ A list of tuples where the first and second elements of the tuple are the
+ x and y coordinates of the corresponding curve point respectively.
+ """
+ pass
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class InfluentialExamples(LitType):
+ """Represents influential examples from the training set.
+
+ This is as returned by a training-data attribution method like TracIn or
+ influence functions.
+
+ This describes a generator component; values are Sequence[Sequence[JsonDict]].
+ """
+ pass
+
+
+@enum.unique
+class MetricBestValue(dtypes.EnumSerializableAsValues, enum.Enum):
+ """The method to use to determine the best value for a Metric."""
+ HIGHEST = "highest"
+ LOWEST = "lowest"
+ NONE = "none"
+ ZERO = "zero"
+
+
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class MetricResult(LitType):
+ """Score returned from the computation of a Metric."""
+ default: float = 0
+ description: str = ""
+ best_value: MetricBestValue = MetricBestValue.NONE
+
+
+# LINT.ThenChange(../client/lib/lit_types.ts)
+
+# Type aliases for backend use.
+# The following names are existing datatypes in TypeScript, so we add a
+# `LitType` suffix to avoid collisions with language features on the front-end.
+Boolean = BooleanLitType
+String = StringLitType
+URL = URLLitType
+
-# LINT.ThenChange(../client/lib/types.ts)
+def get_type_by_name(typename: str) -> Type[LitType]:
+ cls = globals()[typename]
+ assert issubclass(cls, LitType)
+ return cls
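Putting the serialization pieces together: `to_json` now emits only the attrs fields plus `__name__`, which `from_json` and `get_type_by_name` resolve back to a class. A round-trip sketch:

```python
from lit_nlp.api import types

original = types.Scalar(min_val=0, max_val=10, default=5)
d = original.to_json()                       # includes "__name__": "Scalar"
restored = types.LitType.from_json(dict(d))  # copy: from_json pops "__name__"
assert isinstance(restored, types.Scalar) and restored == original

assert types.get_type_by_name("Tokens") is types.Tokens
```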
diff --git a/lit_nlp/api/types_test.py b/lit_nlp/api/types_test.py
new file mode 100644
index 00000000..fa9f2231
--- /dev/null
+++ b/lit_nlp/api/types_test.py
@@ -0,0 +1,579 @@
+"""Tests for types."""
+
+from typing import Any
+from absl.testing import absltest
+from absl.testing import parameterized
+from lit_nlp.api import dtypes
+from lit_nlp.api import types
+import numpy as np
+
+
+class TypesTest(parameterized.TestCase):
+
+ def test_inherit_parent_default_type(self):
+ lit_type = types.StringLitType()
+ self.assertIsInstance(lit_type.default, str)
+
+ def test_inherit_parent_default_value(self):
+ lit_type = types.SingleFieldMatcher(spec="dataset", types=["LitType"])
+ self.assertIsNone(lit_type.default)
+
+ def test_requires_parent_custom_properties(self):
+ # TokenSalience requires the `signed` property of its parent class.
+ with self.assertRaises(TypeError):
+ _ = types.TokenSalience(autorun=True)
+
+ def test_inherit_parent_custom_properties(self):
+ lit_type = types.TokenSalience(autorun=True, signed=True)
+ self.assertIsNone(lit_type.default)
+
+ lit_type = types.TokenGradients(
+ grad_for="cls_emb", grad_target_field_key="grad_class")
+ self.assertTrue(hasattr(lit_type, "align"))
+ self.assertFalse(hasattr(lit_type, "not_a_property"))
+
+ @parameterized.named_parameters(
+ ("list[int]", [1, 2, 3], 1),
+ ("np_array[int]", np.array([1, 2, 3]), 1),
+ ("np_array[list[int]]", np.array([[1, 1], [2, 3]]), 2),
+ ("np_array[list[int]]_2_dim", np.array([[1, 1], [2, 3]]), [2, 4]),
+ )
+ def test_tensor_ndim(self, value, ndim):
+ emb = types.Embeddings()
+ try:
+ emb.validate_ndim(value, ndim)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ @parameterized.named_parameters(
+ ("ndim_wrong_size", [1, 2, 3], 2),
+ ("ndim_wrong_type", np.array([[1, 1], [2, 3]]), [1]),
+ )
+ def test_tensor_ndim_errors(self, value, ndim):
+ with self.assertRaises(ValueError):
+ emb = types.Embeddings()
+ emb.validate_ndim(value, ndim)
+
+ @parameterized.named_parameters(
+ ("boolean", types.Boolean(), True),
+ ("embeddings_list[int]", types.Embeddings(), [1, 2]),
+ ("embeddings_np_array", types.Embeddings(), np.array([1, 2])),
+ ("image", types.ImageBytes(), "data:image/blah..."),
+ ("scalar_float", types.Scalar(), 3.4),
+ ("scalar_int", types.Scalar(), 3),
+ ("scalar_numpy", types.Scalar(), np.int64(2)),
+ ("text", types.TextSegment(), "hi"),
+ ("tokens", types.Tokens(), ["a", "b"]),
+ )
+ def test_type_validate_input(self, lit_type: types.LitType, value: Any):
+ spec = {"score": types.Scalar(), "text": types.TextSegment()}
+ example = {}
+ try:
+ lit_type.validate_input(value, spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ @parameterized.named_parameters(
+ ("boolean_number", types.Boolean(), 3.14159),
+ ("boolean_text", types.Boolean(), "hi"),
+ ("embeddings_bool", types.Embeddings(), True),
+ ("embeddings_number", types.Embeddings(), 3.14159),
+ ("embeddings_text", types.Embeddings(), "hi"),
+ ("image_bool", types.ImageBytes(), True),
+ ("image_number", types.ImageBytes(), 3.14159),
+ ("image_text", types.ImageBytes(), "hi"),
+ ("scalar_text", types.Scalar(), "hi"),
+ ("text_bool", types.TextSegment(), True),
+ ("text_number", types.TextSegment(), 3.14159),
+ ("tokens_bool", types.Tokens(), True),
+ ("tokens_number", types.Tokens(), 3.14159),
+ ("tokens_text", types.Tokens(), "hi"),
+ )
+ def test_type_validate_input_errors(self,
+ lit_type: types.LitType,
+ value: Any):
+ spec = {"score": types.Scalar(), "text": types.TextSegment()}
+ example = {}
+ with self.assertRaises(ValueError):
+ lit_type.validate_input(value, spec, example)
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name="CategoryLabel",
+ json_dict={
+ "required": False,
+ "annotated": False,
+ "default": "",
+ "vocab": ["0", "1"],
+ "__name__": "CategoryLabel",
+ },
+ expected_type=types.CategoryLabel,
+ ),
+ dict(
+ testcase_name="Embeddings",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "default": None,
+ "__name__": "Embeddings",
+ },
+ expected_type=types.Embeddings,
+ ),
+ dict(
+ testcase_name="Gradients",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "default": None,
+ "align": None,
+ "grad_for": "cls_emb",
+ "grad_target_field_key": "grad_class",
+ "__name__": "Gradients",
+ },
+ expected_type=types.Gradients,
+ ),
+ dict(
+ testcase_name="MulticlassPreds",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "default": None,
+ "vocab": ["0", "1"],
+ "null_idx": 0,
+ "parent": "label",
+ "autosort": False,
+ "threshold": None,
+ "__name__": "MulticlassPreds",
+ },
+ expected_type=types.MulticlassPreds,
+ ),
+ dict(
+ testcase_name="RegressionScore",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "min_val": 0,
+ "max_val": 1,
+ "default": 0,
+ "step": 0.01,
+ "parent": "label",
+ "__name__": "RegressionScore",
+ },
+ expected_type=types.RegressionScore,
+ ),
+ dict(
+ testcase_name="Scalar",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "min_val": 2,
+ "max_val": 100,
+ "default": 10,
+ "step": 1,
+ "__name__": "Scalar",
+ },
+ expected_type=types.Scalar,
+ ),
+ dict(
+ testcase_name="TextSegment",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "default": "",
+ "__name__": "TextSegment",
+ },
+ expected_type=types.TextSegment,
+ ),
+ dict(
+ testcase_name="TokenEmbeddings",
+ json_dict={
+ "required": True,
+ "annotated": False,
+ "default": None,
+ "align": "tokens_sentence",
+ "__name__": "TokenEmbeddings",
+ },
+ expected_type=types.TokenEmbeddings,
+ ),
+ dict(
+ testcase_name="Tokens",
+ json_dict={
+ "required": False,
+ "annotated": False,
+ "default": [],
+ "parent": "sentence",
+ "mask_token": None,
+ "token_prefix": "##",
+ "__name__": "Tokens",
+ },
+ expected_type=types.Tokens,
+ ),
+ )
+ def test_from_json(self, json_dict: types.JsonDict,
+ expected_type: types.LitType):
+ lit_type: types.LitType = types.LitType.from_json(json_dict)
+ self.assertIsInstance(lit_type, expected_type)
+ for key in json_dict:
+ if key == "__name__":
+ continue
+ elif hasattr(lit_type, key):
+ self.assertEqual(getattr(lit_type, key), json_dict[key])
+ else:
+ self.fail(f"Encountered unknown property {key} for type "
+ f"{lit_type.__class__.__name__}.")
+
+ @parameterized.named_parameters(
+ ("empty_dict", {}, KeyError),
+ ("invalid_name_empty", {"__name__": ""}, NameError),
+ ("invalid_name_none", {"__name__": None}, TypeError),
+ ("invalid_name_number", {"__name__": 3.14159}, TypeError),
+ ("invalid_type_name", {"__name__": "not_a_lit_type"}, NameError),
+ )
+ def test_from_json_errors(self, value: types.JsonDict, expected_error):
+ with self.assertRaises(expected_error):
+ _ = types.LitType.from_json(value)
+
+ def test_type_validate_gentext_output(self):
+ ds_spec = {
+ "num": types.Scalar(),
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "gentext": types.GeneratedText(parent="text"),
+ "cands": types.GeneratedTextCandidates(parent="text")
+ }
+ example = {"num": 1, "text": "hi"}
+ output = {"gentext": "test", "cands": [("hi", 4), ("bye", None)]}
+
+ gentext = types.GeneratedText(parent="text")
+ gentextcands = types.GeneratedTextCandidates(parent="text")
+ try:
+ gentext.validate_output("hi", out_spec, output, ds_spec, ds_spec, example)
+ gentextcands.validate_output([("hi", 4), ("bye", None)], out_spec, output,
+ ds_spec, ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ bad_gentext = types.GeneratedText(parent="num")
+ self.assertRaises(ValueError, bad_gentext.validate_output, "hi", out_spec,
+ output, ds_spec, ds_spec, example)
+
+ self.assertRaises(ValueError, gentextcands.validate_output,
+ [("hi", "wrong"), ("bye", None)], out_spec, output,
+ ds_spec, ds_spec, example)
+ bad_gentextcands = types.GeneratedTextCandidates(parent="num")
+ self.assertRaises(ValueError, bad_gentextcands.validate_output,
+ [("hi", 4), ("bye", None)], out_spec, output, ds_spec,
+ ds_spec, example)
+
+ def test_type_validate_genurl(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "genurl": types.GeneratedURL(align="cands"),
+ "cands": types.GeneratedTextCandidates(parent="text")
+ }
+ example = {"text": "hi"}
+ output = {"genurl": "https://blah", "cands": [("hi", 4), ("bye", None)]}
+
+ genurl = types.GeneratedURL(align="cands")
+ try:
+ genurl.validate_output("https://blah", out_spec, output, ds_spec, ds_spec,
+ example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(ValueError, genurl.validate_output, 4,
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_genurl = types.GeneratedURL(align="wrong")
+ self.assertRaises(ValueError, bad_genurl.validate_output, "https://blah",
+ out_spec, output, ds_spec, ds_spec, example)
+
+ def test_tokentopk(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "tokens": types.Tokens(),
+ "preds": types.TokenTopKPreds(align="tokens")
+ }
+ example = {"text": "hi"}
+ output = {"tokens": ["hi"], "preds": [[("one", .9), ("two", .4)]]}
+
+ preds = types.TokenTopKPreds(align="tokens")
+ try:
+ preds.validate_output(
+ [[("one", .9), ("two", .4)]], out_spec, output, ds_spec, ds_spec,
+ example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(
+ ValueError, preds.validate_output,
+ [[("one", .2), ("two", .4)]], out_spec, output, ds_spec, ds_spec,
+ example)
+ self.assertRaises(
+ ValueError, preds.validate_output,
+ [["one", "two"]], out_spec, output, ds_spec, ds_spec, example)
+ self.assertRaises(
+ ValueError, preds.validate_output, ["wrong"], out_spec, output,
+ ds_spec, ds_spec, example)
+
+ bad_preds = types.TokenTopKPreds(align="preds")
+ self.assertRaises(
+ ValueError, bad_preds.validate_output,
+ [[("one", .9), ("two", .4)]], out_spec, output, ds_spec, ds_spec,
+ example)
+
+ def test_regression(self):
+ ds_spec = {
+ "val": types.Scalar(),
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "score": types.RegressionScore(parent="val"),
+ }
+ example = {"val": 2}
+ output = {"score": 1}
+
+ score = types.RegressionScore(parent="val")
+ try:
+ score.validate_output(1, out_spec, output, ds_spec, ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(ValueError, score.validate_output, "wrong",
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_score = types.RegressionScore(parent="text")
+ self.assertRaises(ValueError, bad_score.validate_output, 1,
+ out_spec, output, ds_spec, ds_spec, example)
+
+ def test_reference(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ "val": types.Scalar(),
+ }
+ out_spec = {
+ "scores": types.ReferenceScores(parent="text"),
+ }
+ example = {"text": "hi"}
+ output = {"scores": [1, 2]}
+
+ score = types.ReferenceScores(parent="text")
+ try:
+ score.validate_output([1, 2], out_spec, output, ds_spec, ds_spec, example)
+ score.validate_output(np.array([1, 2]), out_spec, output, ds_spec,
+ ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(ValueError, score.validate_output, ["a"],
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_score = types.ReferenceScores(parent="val")
+ self.assertRaises(ValueError, bad_score.validate_output, [1],
+ out_spec, output, ds_spec, ds_spec, example)
+
+ def test_multiclasspreds(self):
+ ds_spec = {
+ "label": types.CategoryLabel(),
+ "val": types.Scalar(),
+ }
+ out_spec = {
+ "scores": types.MulticlassPreds(
+ parent="label", null_idx=0, vocab=["a", "b"]),
+ }
+ example = {"label": "hi", "val": 1}
+ output = {"scores": [1, 2]}
+
+ score = types.MulticlassPreds(parent="label", null_idx=0, vocab=["a", "b"])
+ try:
+ score.validate_output([1, 2], out_spec, output, ds_spec, ds_spec, example)
+ score.validate_output(np.array([1, 2]), out_spec, output, ds_spec,
+ ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(ValueError, score.validate_output, ["a", "b"],
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_score = types.MulticlassPreds(
+ parent="label", null_idx=2, vocab=["a", "b"])
+ self.assertRaises(ValueError, bad_score.validate_output, [1, 2],
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_score = types.MulticlassPreds(
+ parent="val", null_idx=0, vocab=["a", "b"])
+ self.assertRaises(ValueError, bad_score.validate_output, [1, 2],
+ out_spec, output, ds_spec, ds_spec, example)
+
+ def test_annotations(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "tokens": types.Tokens(),
+ "spans": types.SpanLabels(align="tokens"),
+ "edges": types.EdgeLabels(align="tokens"),
+ "annot": types.MultiSegmentAnnotations(),
+ }
+ example = {"text": "hi"}
+ output = {"tokens": ["hi"], "preds": [dtypes.SpanLabel(start=0, end=1)],
+ "edges": [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)],
+ "annot": [dtypes.AnnotationCluster(label="hi", spans=[])]}
+
+ spans = types.SpanLabels(align="tokens")
+ edges = types.EdgeLabels(align="tokens")
+ annot = types.MultiSegmentAnnotations()
+ try:
+ spans.validate_output(
+ [dtypes.SpanLabel(start=0, end=1)], out_spec, output, ds_spec,
+ ds_spec, example)
+ edges.validate_output(
+ [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)], out_spec,
+ output, ds_spec, ds_spec, example)
+ annot.validate_output(
+ [dtypes.AnnotationCluster(label="hi", spans=[])], out_spec,
+ output, ds_spec, ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(
+ ValueError, spans.validate_output, [1], out_spec, output, ds_spec,
+ ds_spec, example)
+ self.assertRaises(
+ ValueError, edges.validate_output, [1], out_spec, output, ds_spec,
+ ds_spec, example)
+ self.assertRaises(
+ ValueError, annot.validate_output, [1], out_spec, output, ds_spec,
+ ds_spec, example)
+
+ bad_spans = types.SpanLabels(align="edges")
+ bad_edges = types.EdgeLabels(align="spans")
+ self.assertRaises(
+ ValueError, bad_spans.validate_output,
+ [dtypes.SpanLabel(start=0, end=1)], out_spec, output, ds_spec, ds_spec,
+ example)
+ self.assertRaises(
+ ValueError, bad_edges.validate_output,
+ [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)], out_spec,
+ output, ds_spec, ds_spec, example)
+
+ def test_gradients(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ "target": types.CategoryLabel()
+ }
+ out_spec = {
+ "tokens": types.Tokens(),
+ "embs": types.Embeddings(),
+ "grads": types.Gradients(align="tokens", grad_for="embs",
+ grad_target_field_key="target")
+ }
+ example = {"text": "hi", "target": "one"}
+ output = {"tokens": ["hi"], "embs": [.1, .2], "grads": [.1]}
+
+ grads = types.Gradients(align="tokens", grad_for="embs",
+ grad_target_field_key="target")
+ embs = types.Embeddings()
+ try:
+ grads.validate_output([.1], out_spec, output, ds_spec, ds_spec, example)
+ embs.validate_output([.1, .2], out_spec, output, ds_spec, ds_spec,
+ example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(
+ ValueError, grads.validate_output, ["bad"], out_spec, output, ds_spec,
+ ds_spec, example)
+ self.assertRaises(
+ ValueError, embs.validate_output, ["bad"], out_spec, output, ds_spec,
+ ds_spec, example)
+
+ bad_grads = types.Gradients(align="text", grad_for="embs",
+ grad_target_field_key="target")
+ self.assertRaises(
+ ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec,
+ ds_spec, example)
+ bad_grads = types.Gradients(align="tokens", grad_for="tokens",
+ grad_target_field_key="target")
+ self.assertRaises(
+ ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec,
+ ds_spec, example)
+ bad_grads = types.Gradients(align="tokens", grad_for="embs",
+ grad_target_field_key="bad")
+ self.assertRaises(
+ ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec,
+ ds_spec, example)
+
+ def test_tokenembsgrads(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ "target": types.CategoryLabel()
+ }
+ out_spec = {
+ "tokens": types.Tokens(),
+ "embs": types.TokenEmbeddings(align="tokens"),
+ "grads": types.TokenGradients(align="tokens", grad_for="embs",
+ grad_target_field_key="target")
+ }
+ example = {"text": "hi", "target": "one"}
+ output = {"tokens": ["hi"], "embs": np.array([[.1], [.2]]),
+ "grads": np.array([[.1], [.2]])}
+
+ grads = types.TokenGradients(align="tokens", grad_for="embs",
+ grad_target_field_key="target")
+ embs = types.TokenEmbeddings(align="tokens")
+ try:
+ grads.validate_output(np.array([[.1], [.2]]), out_spec, output, ds_spec,
+ ds_spec, example)
+ embs.validate_output(np.array([[.1], [.2]]), out_spec, output, ds_spec,
+ ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(
+ ValueError, grads.validate_output, np.array([.1, .2]), out_spec, output,
+ ds_spec, ds_spec, example)
+ self.assertRaises(
+ ValueError, embs.validate_output, np.array([.1, .2]), out_spec, output,
+ ds_spec, ds_spec, example)
+
+ bad_embs = types.TokenEmbeddings(align="grads")
+ self.assertRaises(
+ ValueError, bad_embs.validate_output, np.array([[.1], [.2]]), out_spec,
+ output, ds_spec, ds_spec, example)
+
+ def test_attention(self):
+ ds_spec = {
+ "text": types.TextSegment(),
+ }
+ out_spec = {
+ "tokens": types.Tokens(),
+ "val": types.RegressionScore,
+ "attn": types.AttentionHeads(align_in="tokens", align_out="tokens"),
+ }
+ example = {"text": "hi"}
+ output = {"tokens": ["hi"], "attn": np.array([[[.1]], [[.2]]])}
+
+ attn = types.AttentionHeads(align_in="tokens", align_out="tokens")
+ try:
+ attn.validate_output(np.array([[[.1]], [[.2]]]), out_spec, output,
+ ds_spec, ds_spec, example)
+ except ValueError:
+ self.fail("Raised unexpected error.")
+
+ self.assertRaises(
+ ValueError, attn.validate_output, np.array([.1, .2]), out_spec, output,
+ ds_spec, ds_spec, example)
+
+ bad_attn = types.AttentionHeads(align_in="tokens", align_out="val")
+ self.assertRaises(
+ ValueError, bad_attn.validate_output, np.array([[[.1]], [[.2]]]),
+ out_spec, output, ds_spec, ds_spec, example)
+ bad_attn = types.AttentionHeads(align_in="val", align_out="tokens")
+ self.assertRaises(
+ ValueError, bad_attn.validate_output, np.array([[[.1]], [[.2]]]),
+ out_spec, output, ds_spec, ds_spec, example)
+
+
+if __name__ == "__main__":
+ absltest.main()
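The validate_output tests above all share one calling convention; a condensed sketch using the same spec shapes shown in those tests:

```python
from lit_nlp.api import types

# The dataset spec describes input fields; the output spec describes model
# outputs and the input fields they derive from ("parent").
ds_spec = {"val": types.Scalar(), "text": types.TextSegment()}
out_spec = {"score": types.RegressionScore(parent="val")}
example = {"val": 2, "text": "hi"}
output = {"score": 1.5}

# Raises ValueError on a type mismatch, or when the declared parent is not a
# compatible field in the dataset spec.
out_spec["score"].validate_output(
    1.5, out_spec, output, ds_spec, ds_spec, example)
```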
diff --git a/lit_nlp/app.py b/lit_nlp/app.py
index e49d600c..2c0117ec 100644
--- a/lit_nlp/app.py
+++ b/lit_nlp/app.py
@@ -12,43 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# Lint as: python3
"""LIT backend, as a standard WSGI app."""
import functools
import glob
+import math
import os
import random
+import threading
import time
-from typing import Optional, Text, List, Mapping, Sequence, Union
+from typing import Optional, Mapping, Sequence, Union, Callable, Iterable
from absl import logging
from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
-from lit_nlp.api import dtypes
+from lit_nlp.api import layout
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
-from lit_nlp.components import ablation_flip
-from lit_nlp.components import gradient_maps
-from lit_nlp.components import hotflip
-from lit_nlp.components import lemon_explainer
-from lit_nlp.components import lime_explainer
-from lit_nlp.components import metrics
-from lit_nlp.components import model_salience
-from lit_nlp.components import nearest_neighbors
-from lit_nlp.components import pca
-from lit_nlp.components import pdp
-from lit_nlp.components import projection
-from lit_nlp.components import scrambler
-from lit_nlp.components import tcav
-from lit_nlp.components import thresholder
-from lit_nlp.components import umap
-from lit_nlp.components import word_replacer
+from lit_nlp.components import core
from lit_nlp.lib import caching
+from lit_nlp.lib import flag_helpers
from lit_nlp.lib import serialize
+from lit_nlp.lib import ui_state
from lit_nlp.lib import utils
+from lit_nlp.lib import validation
from lit_nlp.lib import wsgi_app
+import tqdm
JsonDict = types.JsonDict
Input = types.Input
@@ -57,6 +47,8 @@
# Export this symbol, for access from demo.py
PredsCache = caching.PredsCache
+ProgressIndicator = Callable[[Iterable], Iterable]
+
class LitApp(object):
"""LIT WSGI application."""
@@ -66,23 +58,40 @@ def _build_metadata(self):
model_info = {}
for name, m in self._models.items():
mspec: lit_model.ModelSpec = m.spec()
- info = {}
- info['spec'] = {'input': mspec.input, 'output': mspec.output}
+ info = {
+ 'description': m.description(),
+ 'spec': {
+ 'input': mspec.input,
+ 'output': mspec.output
+ }
+ }
+
# List compatible datasets.
info['datasets'] = [
- dname for dname, ds in self._datasets.items()
- if mspec.is_compatible_with_dataset(ds.spec())
+ name for name, dataset in self._datasets.items()
+ if mspec.is_compatible_with_dataset(dataset.spec())
]
if len(info['datasets']) == 0: # pylint: disable=g-explicit-length-test
logging.error("Error: model '%s' has no compatible datasets!", name)
- info['generators'] = [
- name for name, gen in self._generators.items() if gen.is_compatible(m)
- ]
- info['interpreters'] = [
- name for name, interp in self._interpreters.items()
- if interp.is_compatible(m)
- ]
- info['description'] = m.description()
+
+ compat_gens: set[str] = set()
+ compat_interps: set[str] = set()
+
+ for d in info['datasets']:
+ dataset: lit_dataset.Dataset = self._datasets[d]
+ compat_gens.update([
+ name for name, gen in self._generators.items()
+ if gen.is_compatible(model=m, dataset=dataset)
+ ])
+ compat_interps.update([
+ name for name, interp in self._interpreters.items()
+ if interp.is_compatible(model=m, dataset=dataset)
+ ])
+
+ info['generators'] = [name for name in self._generators.keys()
+ if name in compat_gens]
+ info['interpreters'] = [name for name in self._interpreters.keys()
+ if name in compat_interps]
model_info[name] = info
dataset_info = {}
@@ -90,6 +99,7 @@ def _build_metadata(self):
dataset_info[name] = {
'spec': ds.spec(),
'description': ds.description(),
+ 'size': len(ds),
}
generator_info = {}
@@ -120,9 +130,13 @@ def _build_metadata(self):
'defaultLayout': self._default_layout,
'canonicalURL': self._canonical_url,
'pageTitle': self._page_title,
+ 'inlineDoc': self._inline_doc,
+ 'onboardStartDoc': self._onboard_start_doc,
+ 'onboardEndDoc': self._onboard_end_doc,
+ 'syncState': self.ui_state_tracker is not None,
}
- def _get_model_spec(self, name: Text):
+ def _get_model_spec(self, name: str):
return self._info['models'][name]['spec']
def _get_info(self, unused_data, **unused_kw):
@@ -130,7 +144,7 @@ def _get_info(self, unused_data, **unused_kw):
return self._info
def _reconstitute_inputs(self, inputs: Sequence[Union[IndexedInput, str]],
- dataset_name: str) -> List[IndexedInput]:
+ dataset_name: str) -> list[IndexedInput]:
"""Reconstitute any inputs sent as references (bare IDs)."""
index = self._datasets[dataset_name].index
# TODO(b/178228238): set up proper debug logging and hide this by default.
@@ -140,58 +154,60 @@ def _reconstitute_inputs(self, inputs: Sequence[Union[IndexedInput, str]],
num_aliased, len(inputs), dataset_name)
return [index[ex] if isinstance(ex, str) else ex for ex in inputs]
- def _predict(self, inputs: List[JsonDict], model_name: Text,
- dataset_name: Optional[Text]):
- """Run model predictions."""
- return list(self._models[model_name].predict_with_metadata(
- inputs, dataset_name=dataset_name))
-
- def _save_datapoints(self, data, dataset_name: Text, path: Text, **unused_kw):
+ def _save_datapoints(self, data, dataset_name: str, path: str, **unused_kw):
"""Save datapoints to disk."""
if self._demo_mode:
- logging.warn('Attempted to save datapoints in demo mode.')
+ logging.warning('Attempted to save datapoints in demo mode.')
return None
return self._datasets[dataset_name].save(data['inputs'], path)
- def _load_datapoints(self, unused_data, dataset_name: Text, path: Text,
+ def _load_datapoints(self, unused_data, dataset_name: str, path: str,
**unused_kw):
"""Load datapoints from disk."""
if self._demo_mode:
- logging.warn('Attempted to load datapoints in demo mode.')
+ logging.warning('Attempted to load datapoints in demo mode.')
return None
dataset = self._datasets[dataset_name].load(path)
return dataset.indexed_examples
def _get_preds(self,
data,
- model: Text,
- dataset_name: Optional[Text] = None,
- requested_types: Text = 'LitType',
- **unused_kw):
+ model: str,
+ dataset_name: Optional[str] = None,
+ requested_types: Optional[str] = None,
+ requested_fields: Optional[str] = None,
+ **kw):
"""Get model predictions.
Args:
data: data payload, containing 'inputs' field
model: name of the model to run
dataset_name: name of the active dataset
- requested_types: optional, comma-separated list of types to return
+ requested_types: optional, comma-separated list of type names to return
+ requested_fields: optional, comma-separated list of field names to return
+ in addition to the ones returned due to 'requested_types'.
+ **kw: additional args passed to model.predict_with_metadata()
Returns:
- List[JsonDict] containing requested fields of model predictions
+ list[JsonDict] containing requested fields of model predictions
"""
- preds = self._predict(data['inputs'], model, dataset_name)
+ preds = list(self._models[model].predict_with_metadata(
+ data['inputs'], dataset_name=dataset_name, **kw))
+ if not requested_types and not requested_fields:
+ return preds
# Figure out what to return to the frontend.
output_spec = self._get_model_spec(model)['output']
- requested_types = requested_types.split(',')
- logging.info('Requested types: %s', str(requested_types))
- ret_keys = []
+ requested_types = requested_types.split(',') if requested_types else []
+ requested_fields = requested_fields.split(',') if requested_fields else []
+ logging.info('Requested types: %s, fields: %s', str(requested_types),
+ str(requested_fields))
for t_name in requested_types:
t_class = getattr(types, t_name, None)
- assert issubclass(
- t_class, types.LitType), f"Class '{t_name}' is not a valid LitType."
- ret_keys.extend(utils.find_spec_keys(output_spec, t_class))
- ret_keys = set(ret_keys) # de-dupe
+ if not issubclass(t_class, types.LitType):
+ raise TypeError(f"Class '{t_name}' is not a valid LitType.")
+ requested_fields.extend(utils.find_spec_keys(output_spec, t_class))
+ ret_keys = set(requested_fields) # de-dupe
# Return selected keys.
logging.info('Will return keys: %s', str(ret_keys))
@@ -199,8 +215,10 @@ def _get_preds(self,
ret = [utils.filter_by_keys(p, ret_keys.__contains__) for p in preds]
return ret
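For orientation, a hedged sketch of calling this endpoint over HTTP. The server address, model name, and dataset name are illustrative placeholders, and the JSON-over-POST shape follows the `data['inputs']` convention used throughout this file:

```python
import requests

BASE = "http://localhost:5432"  # illustrative server address

examples = []  # IndexedInput dicts, e.g. from a prior /get_dataset call

# Return every MulticlassPreds field plus the explicitly requested "tokens"
# field; both parameters are comma-separated lists of names.
resp = requests.post(
    f"{BASE}/get_preds",
    params={
        "model": "my_model",           # illustrative model name
        "dataset_name": "my_dataset",  # illustrative dataset name
        "requested_types": "MulticlassPreds",
        "requested_fields": "tokens",
    },
    json={"inputs": examples},
)
preds = resp.json()
```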
- def _annotate_new_data(self, data, dataset_name: Optional[Text] = None,
- **unused_kw) -> List[IndexedInput]:
+ def _annotate_new_data(self,
+ data,
+ dataset_name: Optional[str] = None,
+ **unused_kw) -> list[IndexedInput]:
"""Fill in index and other extra data for the provided datapoints."""
# TODO(lit-dev): unify this with hash fn on dataset objects.
assert dataset_name is not None, 'No dataset specified.'
@@ -219,22 +237,67 @@ def _annotate_new_data(self, data, dataset_name: Optional[Text] = None,
return data['inputs']
+ def _post_new_data(
+ self, data, dataset_name: Optional[str] = None,
+ **unused_kw) -> dict[str, str]:
+ """Save datapoints provided, after annotatation, for later retrieval.
+
+ Args:
+      data: JsonDict of datapoints to add, under the 'inputs' key, following
+        the format used by other requests.
+      dataset_name: Name of the dataset that defines the format of the data
+        to add; necessary for proper datapoint annotation.
+
+ Returns:
+      A dict of two URL paths (relative to the webserver root). The 'load'
+      value loads LIT with these datapoints; the 'remove' value removes them
+      from this server after they have been loaded, if desired.
+ """
+ assert 'inputs' in data, 'Data dict does not contain "inputs" field'
+ data_with_metadata = [
+ {'data': d,
+ 'meta': {'added': True, 'source': 'POST', 'parentId': None}}
+ for d in data['inputs']]
+ annotation_input = {'inputs': data_with_metadata}
+ annotated_data = self._annotate_new_data(annotation_input, dataset_name)
+ datapoints_id = utils.get_uuid()
+ with self._saved_datapoints_lock:
+ self._saved_datapoints[datapoints_id] = annotated_data
+ return {
+ 'load': f'?saved_datapoints_id={datapoints_id}',
+ 'remove': f'/remove_new_data?saved_datapoints_id={datapoints_id}'}
+
+ def _fetch_new_data(self, unused_data, saved_datapoints_id: str, **unused_kw):
+ with self._saved_datapoints_lock:
+ assert saved_datapoints_id in self._saved_datapoints, (
+ 'No saved data with ID %s' % saved_datapoints_id)
+ return self._saved_datapoints[saved_datapoints_id]
+
+ def _remove_new_data(
+ self, unused_data, saved_datapoints_id: str, **unused_kw):
+ with self._saved_datapoints_lock:
+ assert saved_datapoints_id in self._saved_datapoints, (
+ 'No saved data with ID %s' % saved_datapoints_id)
+ del self._saved_datapoints[saved_datapoints_id]
+
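A sketch of the full saved-datapoints flow these three handlers implement; the address and dataset name are again illustrative:

```python
import requests

BASE = "http://localhost:5432"  # illustrative server address

# Save datapoints server-side. The response carries a 'load' query string for
# opening LIT with these points and a 'remove' path for deleting them later.
resp = requests.post(
    f"{BASE}/post_new_data",
    params={"dataset_name": "my_dataset"},  # illustrative; sets the spec
    json={"inputs": [{"text": "hi"}]},      # fields must match dataset spec
)
urls = resp.json()
print("Open LIT at:", BASE + "/" + urls["load"])

# Once loaded, the datapoints can be removed from this server:
requests.post(BASE + urls["remove"], json={})
```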
def _get_dataset(self,
unused_data,
- dataset_name: Optional[Text] = None,
- **unused_kw):
+ dataset_name: Optional[str] = None,
+ **unused_kw) -> list[IndexedInput]:
"""Attempt to get dataset, or override with a specific path."""
return self._datasets[dataset_name].indexed_examples
def _create_dataset(self,
unused_data,
- dataset_name: Optional[Text] = None,
- dataset_path: Optional[Text] = None,
+ dataset_name: Optional[str] = None,
+ dataset_path: Optional[str] = None,
**unused_kw):
"""Create dataset from a path, updating and returning the metadata."""
assert dataset_name is not None, 'No dataset specified.'
assert dataset_path is not None, 'No dataset path specified.'
+
new_dataset = self._datasets[dataset_name].load(dataset_path)
if new_dataset is not None:
new_dataset_name = dataset_name + '-' + os.path.basename(dataset_path)
@@ -242,12 +305,13 @@ def _create_dataset(self,
self._info = self._build_metadata()
return (self._info, new_dataset_name)
else:
+    logging.error('Unable to load dataset: %s', dataset_name)
return None
def _create_model(self,
unused_data,
- model_name: Optional[Text] = None,
- model_path: Optional[Text] = None,
+ model_name: Optional[str] = None,
+ model_path: Optional[str] = None,
**unused_kw):
"""Create model from a path, updating and returning the metadata."""
@@ -264,14 +328,14 @@ def _create_model(self,
else:
return None
- def _get_generated(self, data, model: Text, dataset_name: Text,
- generator: Text, **unused_kw):
+ def _get_generated(self, data, model: str, dataset_name: str, generator: str,
+ **unused_kw):
"""Generate new datapoints based on the request."""
generator_name = generator
generator: lit_components.Generator = self._generators[generator_name]
dataset = self._datasets[dataset_name]
# Nested list, containing generated examples from each input.
- all_generated: List[List[Input]] = generator.run_with_metadata(
+ all_generated: list[list[Input]] = generator.run_with_metadata(
data['inputs'], self._models[model], dataset, config=data.get('config'))
# Annotate datapoints
@@ -282,10 +346,11 @@ def annotate_generated(datapoints):
return annotated_dataset.examples
annotated_generated = [
- annotate_generated(generated) for generated in all_generated]
+ annotate_generated(generated) for generated in all_generated
+ ]
# Add metadata.
- all_generated_indexed: List[List[IndexedInput]] = [
+ all_generated_indexed: list[list[IndexedInput]] = [
dataset.index_inputs(generated) for generated in annotated_generated
]
for parent, indexed_generated in zip(data['inputs'], all_generated_indexed):
@@ -297,37 +362,87 @@ def annotate_generated(datapoints):
})
return all_generated_indexed
- def _get_interpretations(self, data, model: Text, dataset_name: Text,
- interpreter: Text, **unused_kw):
+ def _get_interpretations(self, data, model: str, dataset_name: str,
+ interpreter: str, **unused_kw):
"""Run an interpretation component."""
interpreter = self._interpreters[interpreter]
- # Pre-compute using self._predict, which looks for cached results.
- model_outputs = self._predict(data['inputs'], model, dataset_name)
+ # Get model preds before the interpreter call. Usually these are cached.
+ # TODO(lit-dev): see if we can remove this path and just allow interpreters
+ # to call the model directly.
+ if model:
+ assert model in self._models, f"Model '{model}' is not a valid model."
+ model_outputs = self._get_preds(data, model, dataset_name)
+ model = self._models[model]
+ else:
+ model_outputs = None
+ model = None
return interpreter.run_with_metadata(
data['inputs'],
- self._models[model],
+ model,
self._datasets[dataset_name],
model_outputs=model_outputs,
config=data.get('config'))
- def _warm_start(self, rate: float):
+ def _push_ui_state(self, data, dataset_name: str, **unused_kw):
+ """Push UI state back to Python."""
+ if self.ui_state_tracker is None:
+ raise RuntimeError('Attempted to push UI state, but that is not enabled '
+ 'for this server.')
+ options = data.get('config', {})
+ self.ui_state_tracker.update_state(data['inputs'],
+ self._datasets[dataset_name],
+ dataset_name, **options)
+
+ def _validate(self, validate: Optional[flag_helpers.ValidationMode],
+ report_all: bool):
+ """Validate all datasets and models loaded for proper setup."""
+ if validate is None or validate == flag_helpers.ValidationMode.OFF:
+ return
+
+ datasets_to_validate = {}
+ for dataset in self._datasets:
+ if validate == flag_helpers.ValidationMode.ALL:
+ datasets_to_validate[dataset] = self._datasets[dataset]
+ elif validate == flag_helpers.ValidationMode.FIRST:
+ datasets_to_validate[dataset] = self._datasets[dataset].slice[:1]
+ elif validate == flag_helpers.ValidationMode.SAMPLE:
+ sample_size = math.ceil(len(self._datasets[dataset]) * 0.05)
+ datasets_to_validate[dataset] = self._datasets[dataset].sample(
+ sample_size)
+ for dataset in datasets_to_validate:
+ logging.info("Validating dataset '%s'", dataset)
+ validation.validate_dataset(
+ datasets_to_validate[dataset], report_all)
+ for model, model_info in self._info['models'].items():
+ for dataset_name in model_info['datasets']:
+ logging.info("Validating model '%s' on dataset '%s'", model,
+ dataset_name)
+ validation.validate_model(
+ self._models[model], datasets_to_validate[dataset_name], report_all)
+
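A sketch of opting into this validation via the constructor arguments added below; `models` and `datasets` are empty placeholders standing in for real mappings:

```python
from lit_nlp import app as lit_app
from lit_nlp.lib import flag_helpers

models = {}    # your Mapping[str, lit_model.Model]
datasets = {}  # your Mapping[str, lit_dataset.Dataset]

application = lit_app.LitApp(
    models,
    datasets,
    client_root="lit_nlp/client/build/default",  # illustrative path
    # SAMPLE validates ~5% of each dataset, FIRST one example, ALL everything;
    # None or OFF (the default) skips validation entirely.
    validate=flag_helpers.ValidationMode.SAMPLE,
    report_all=True,  # report all issues found, not just the first
)
```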
+ def _warm_start(self,
+ rate: float,
+ progress_indicator: Optional[ProgressIndicator] = None):
"""Warm-up the predictions cache by making some model calls."""
assert rate >= 0 and rate <= 1
for model, model_info in self._info['models'].items():
for dataset_name in model_info['datasets']:
logging.info("Warm-start of model '%s' on dataset '%s'", model,
dataset_name)
- full_dataset = self._get_dataset([], dataset_name)
+ all_examples: list[IndexedInput] = self._get_dataset([], dataset_name)
if rate < 1:
- dataset = random.sample(full_dataset, int(len(full_dataset) * rate))
+ examples = random.sample(all_examples, int(len(all_examples) * rate))
logging.info('Partial warm-start: running on %d/%d examples.',
- len(dataset), len(full_dataset))
+ len(examples), len(all_examples))
else:
- dataset = full_dataset
- _ = self._get_preds({'inputs': dataset}, model, dataset_name)
+ examples = all_examples
+ _ = self._get_preds({'inputs': examples},
+ model,
+ dataset_name,
+ progress_indicator=progress_indicator)
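Since `ProgressIndicator` is just `Callable[[Iterable], Iterable]`, any tqdm-style wrapper can stand in for the default; a minimal hand-rolled sketch:

```python
from typing import Iterable

def log_progress(iterable: Iterable) -> Iterable:
  """Drop-in ProgressIndicator: yields items while logging progress."""
  items = list(iterable)
  for i, item in enumerate(items, start=1):
    if i % 100 == 0 or i == len(items):
      print(f"warm-start: {i}/{len(items)} examples")
    yield item

# Then: LitApp(..., warm_start=0.5, warm_start_progress_indicator=log_progress)
```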
- def _warm_projections(self, interpreters: List[Text]):
+ def _warm_projections(self, interpreters: list[str]):
"""Pre-compute UMAP/PCA projections with default arguments."""
for model, model_info in self._info['models'].items():
for dataset_name in model_info['datasets']:
@@ -340,14 +455,15 @@ def _warm_projections(self, interpreters: List[Text]):
dataset_name=dataset_name,
model_name=model,
field_name=field_name,
+ use_input=False,
proj_kw={'n_components': 3})
data = {'inputs': [], 'config': config}
for interpreter_name in interpreters:
_ = self._get_interpretations(
data, model, dataset_name, interpreter=interpreter_name)
- def _run_annotators(
- self, dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
+ def _run_annotators(self,
+ dataset: lit_dataset.Dataset) -> lit_dataset.Dataset:
datapoints = [dict(ex) for ex in dataset.examples]
annotated_spec = dict(dataset.spec())
for annotator in self._annotators:
@@ -390,22 +506,30 @@ def _handler(app: wsgi_app.App, request, environ):
def __init__(
self,
- models: Mapping[Text, lit_model.Model],
- datasets: Mapping[Text, lit_dataset.Dataset],
- generators: Optional[Mapping[Text, lit_components.Generator]] = None,
- interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
- annotators: Optional[List[lit_components.Annotator]] = None,
- layouts: Optional[dtypes.LitComponentLayouts] = None,
+ models: Mapping[str, lit_model.Model],
+ datasets: Mapping[str, lit_dataset.Dataset],
+ generators: Optional[Mapping[str, lit_components.Generator]] = None,
+ interpreters: Optional[Mapping[str, lit_components.Interpreter]] = None,
+ annotators: Optional[list[lit_components.Annotator]] = None,
+ layouts: Optional[layout.LitComponentLayouts] = None,
# General server config; see server_flags.py.
- data_dir: Optional[Text] = None,
+ data_dir: Optional[str] = None,
warm_start: float = 0.0,
+ warm_start_progress_indicator: Optional[ProgressIndicator] = tqdm
+ .tqdm, # not in server_flags
warm_projections: bool = False,
- client_root: Optional[Text] = None,
+ client_root: Optional[str] = None,
demo_mode: bool = False,
default_layout: Optional[str] = None,
canonical_url: Optional[str] = None,
page_title: Optional[str] = None,
development_demo: bool = False,
+ inline_doc: Optional[str] = None,
+ onboard_start_doc: Optional[str] = None,
+ onboard_end_doc: Optional[str] = None,
+ sync_state: bool = False, # notebook-only; not in server_flags
+ validate: Optional[flag_helpers.ValidationMode] = None,
+ report_all: bool = False,
):
if client_root is None:
raise ValueError('client_root must be set on application')
@@ -414,22 +538,34 @@ def __init__(
self._default_layout = default_layout
self._canonical_url = canonical_url
self._page_title = page_title
+ self._inline_doc = inline_doc
+ self._onboard_start_doc = onboard_start_doc
+ self._onboard_end_doc = onboard_end_doc
self._data_dir = data_dir
- self._layouts = layouts or {}
if data_dir and not os.path.isdir(data_dir):
os.mkdir(data_dir)
+ # TODO(lit-dev): override layouts instead of merging, to allow clients
+ # to opt-out of the default bundled layouts. This will require updating
+ # client code to manually merge when this is the desired behavior.
+ self._layouts = dict(layout.DEFAULT_LAYOUTS, **(layouts or {}))
+
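A sketch of supplying custom layouts under this merge behavior, assuming the bundled `DEFAULT_LAYOUTS` includes a 'default' entry; the new key is illustrative:

```python
from lit_nlp.api import layout

# The app computes dict(layout.DEFAULT_LAYOUTS, **layouts), so entries here
# are added to the bundled layouts and same-named keys shadow the defaults.
my_layouts = {
    "my_layout": layout.DEFAULT_LAYOUTS["default"],  # illustrative reuse
}
# Then: LitApp(..., layouts=my_layouts, default_layout="my_layout")
```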
# Wrap models in caching wrapper
self._models = {
name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
for name, model in models.items()
}
- self._datasets = dict(datasets)
+ self._datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
+ # TODO(b/202210900): get rid of this, just dynamically create the empty
+ # dataset on the frontend.
self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)
self._annotators = annotators or []
+ self._saved_datapoints = {}
+ self._saved_datapoints_lock = threading.Lock()
+
# Run annotation on each dataset, creating an annotated dataset and
# replace the datasets with the annotated versions.
for ds_key, ds in self._datasets.items():
@@ -439,46 +575,30 @@ def __init__(
self._datasets = lit_dataset.IndexedDataset.index_all(
self._datasets, caching.input_hash)
+ # Generator initialization
if generators is not None:
self._generators = generators
else:
- self._generators = {
- 'Ablation Flip': ablation_flip.AblationFlip(),
- 'Hotflip': hotflip.HotFlip(),
- 'Scrambler': scrambler.Scrambler(),
- 'Word Replacer': word_replacer.WordReplacer(),
- }
+ self._generators = core.default_generators()
+ # Interpreter initialization
if interpreters is not None:
self._interpreters = interpreters
else:
- metrics_group = lit_components.ComponentGroup({
- 'regression': metrics.RegressionMetrics(),
- 'multiclass': metrics.MulticlassMetrics(),
- 'paired': metrics.MulticlassPairedMetrics(),
- 'bleu': metrics.CorpusBLEU(),
- })
- self._interpreters = {
- 'Grad L2 Norm': gradient_maps.GradientNorm(),
- 'Grad ⋅ Input': gradient_maps.GradientDotInput(),
- 'Integrated Gradients': gradient_maps.IntegratedGradients(),
- 'LIME': lime_explainer.LIME(),
- 'Model-provided salience': model_salience.ModelSalience(self._models),
- 'counterfactual explainer': lemon_explainer.LEMON(),
- 'tcav': tcav.TCAV(),
- 'thresholder': thresholder.Thresholder(),
- 'nearest neighbors': nearest_neighbors.NearestNeighbors(),
- 'metrics': metrics_group,
- 'pdp': pdp.PdpInterpreter(),
- # Embedding projectors expose a standard interface, but get special
- # handling so we can precompute the projections if requested.
- 'pca': projection.ProjectionManager(pca.PCAModel),
- 'umap': projection.ProjectionManager(umap.UmapModel),
- }
+ self._interpreters = core.default_interpreters(self._models)
+
+ # Component to sync state from TS -> Python. Used in notebooks.
+ if sync_state:
+ self.ui_state_tracker = ui_state.UIStateTracker()
+ else:
+ self.ui_state_tracker = None
# Information on models, datasets, and other components.
self._info = self._build_metadata()
+ # Validate datasets and models if specified.
+ self._validate(validate, report_all)
+
# Optionally, run models to pre-populate cache.
if warm_projections:
logging.info(
@@ -488,8 +608,11 @@ def __init__(
warm_start = 1.0
if warm_start > 0:
- self._warm_start(rate=warm_start)
+ self._warm_start(
+ rate=warm_start, progress_indicator=warm_start_progress_indicator)
self.save_cache()
+ if warm_start >= 1:
+ warm_projections = True
# If you add a new embedding projector that should be warm-started,
# also add it to the list here.
@@ -509,6 +632,10 @@ def __init__(
'/save_datapoints': self._save_datapoints,
'/load_datapoints': self._load_datapoints,
'/annotate_new_data': self._annotate_new_data,
+ '/post_new_data': self._post_new_data,
+ '/fetch_new_data': self._fetch_new_data,
+ '/remove_new_data': self._remove_new_data,
+ '/push_ui_state': self._push_ui_state,
# Model prediction endpoints.
'/get_preds': self._get_preds,
'/get_interpretations': self._get_interpretations,
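Finally, a notebook-flavored sketch of the state-sync wiring registered above; reading state back through `ui_state.UIStateTracker` attributes is an assumption, not shown in this diff:

```python
from lit_nlp import app as lit_app

models = {}    # your Mapping[str, lit_model.Model]
datasets = {}  # your Mapping[str, lit_dataset.Dataset]

application = lit_app.LitApp(
    models,
    datasets,
    client_root="lit_nlp/client/build/default",  # illustrative path
    sync_state=True,  # frontend will POST selections to /push_ui_state
)

# After UI interaction the tracker holds the synced selection state; its
# exact attributes are defined in lit_nlp.lib.ui_state (assumption here).
tracker = application.ui_state_tracker
```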
diff --git a/lit_nlp/client/core/app.ts b/lit_nlp/client/core/app.ts
index d6894457..778b86e1 100644
--- a/lit_nlp/client/core/app.ts
+++ b/lit_nlp/client/core/app.ts
@@ -17,18 +17,17 @@
// Import Services
// Import and add injection functionality to LitModule
-import {reaction} from 'mobx';
-
-import {Constructor, LitComponentLayouts} from '../lib/types';
+import {autorun, toJS} from 'mobx';
+import {Constructor} from '../lib/types';
import {ApiService} from '../services/api_service';
import {ClassificationService} from '../services/classification_service';
import {ColorService} from '../services/color_service';
+import {DataService} from '../services/data_service';
import {FocusService} from '../services/focus_service';
import {GroupService} from '../services/group_service';
import {LitService} from '../services/lit_service';
import {ModulesService} from '../services/modules_service';
-import {RegressionService} from '../services/regression_service';
import {SelectionService} from '../services/selection_service';
import {SettingsService} from '../services/settings_service';
import {SliceService} from '../services/slice_service';
@@ -49,50 +48,75 @@ export class LitApp {
* Begins loading data from the LIT server, and computes the layout that
* the `modules` component will use to render.
*/
- async initialize(layouts: LitComponentLayouts) {
+ async initialize() {
+ const apiService = this.getService(ApiService);
const appState = this.getService(AppState);
const modulesService = this.getService(ModulesService);
- appState.addLayouts(layouts);
+ const [selectionService, pinnedSelectionService] =
+ this.getServiceArray(SelectionService);
+ const urlService = this.getService(UrlService);
+ const colorService = this.getService(ColorService);
- await appState.initialize();
+ // Load the app metadata before any further initialization
+ appState.metadata = await apiService.getInfo();
+ console.log('[LIT - metadata]', toJS(appState.metadata));
+
+ // Update page title based on metadata
if (appState.metadata.pageTitle) {
- document.querySelector('html head title')!.textContent = appState.metadata.pageTitle;
+ document.querySelector('html head title')!.textContent =
+ appState.metadata.pageTitle;
}
+
+ // Sync app state based on URL search params
+ urlService.syncStateToUrl(appState, modulesService, selectionService,
+ pinnedSelectionService, colorService);
+
+ // Initialize the rest of the app state
+ await appState.initialize();
+
+    // Initialize the module layout
modulesService.initializeLayout(
appState.layout, appState.currentModelSpecs,
appState.currentDatasetSpec, appState.compareExamplesEnabled);
- // Select the initial datapoint, if one was set in the url.
- const selectionServices = this.getServiceArray(SelectionService);
- await this.getService(UrlService).syncSelectedDatapointToUrl(appState, selectionServices[0]);
+ // Select the initial datapoint, if one was set in the URL.
+ await urlService.syncSelectedDatapointToUrl(appState, selectionService);
- // Reaction to sync other selection services to selections of the main one.
- reaction(() => appState.compareExamplesEnabled, compareExamplesEnabled => {
- this.syncSelectionServices();
- }, {fireImmediately: true});
+    // Enable comparison mode if a datapoint has been pinned
+ if (pinnedSelectionService.primarySelectedId) {
+ appState.compareExamplesEnabled = true;
+ }
+
+ // If enabled, set up state syncing back to Python.
+ if (appState.metadata.syncState) {
+ autorun(() => {
+ apiService.pushState(
+ selectionService.selectedInputData, appState.currentDataset, {
+ 'primary_id': selectionService.primarySelectedId,
+ 'pinned_id': pinnedSelectionService.primarySelectedId,
+ });
+ });
+ }
}
private readonly services =
      new Map<Constructor<LitService>, LitService|LitService[]>();
- /** Sync selection services */
- syncSelectionServices() {
- const selectionServices = this.getServiceArray(SelectionService);
- // TODO(lit-dev): can we just copy the object instead, and skip this
- // logic?
- selectionServices[1].syncFrom(selectionServices[0]);
- }
-
/** Simple DI service system */
-  getService<T extends LitService>(t: Constructor<T>): T {
+  getService<T extends LitService>(t: Constructor<T>, instance?: string): T {
let service = this.services.get(t);
/**
* Modules that don't support example comparison will always get index
* 0 of selectionService. This way we do not have to edit any module that
- * does not explicitly support cloning
+ * does not explicitly support cloning. For modules that support comparison,
+ * if the `pinned` instance is specified then return the appropriate
+ * instance.
*/
if (Array.isArray(service)) {
- service = service[0];
+ if (instance != null && instance !== 'pinned') {
+ throw new Error(`Invalid service instance name: ${instance}`);
+ }
+ service = service[instance === 'pinned' ? 1 : 0];
}
if (service === undefined) {
throw new Error(`Service is undefined: ${t.name}`);
@@ -122,34 +146,32 @@ export class LitApp {
const statusService = new StatusService();
const apiService = new ApiService(statusService);
const modulesService = new ModulesService();
- const urlService = new UrlService();
+ const urlService = new UrlService(apiService);
const appState = new AppState(apiService, statusService);
- const selectionService0 = new SelectionService(appState);
- const selectionService1 = new SelectionService(appState);
- const sliceService = new SliceService(selectionService0, appState);
- const regressionService = new RegressionService(apiService, appState);
+ const selectionService = new SelectionService(appState);
+ const pinnedSelectionService = new SelectionService(appState);
+ const sliceService = new SliceService(selectionService, appState);
const settingsService =
- new SettingsService(appState, modulesService, selectionService0);
- const groupService = new GroupService(appState);
- const classificationService =
- new ClassificationService(apiService, appState, groupService);
- const colorService = new ColorService(
- appState, groupService, classificationService, regressionService);
- const focusService = new FocusService(selectionService0);
-
- // Initialize url syncing of state
- urlService.syncStateToUrl(appState, selectionService0, modulesService);
+ new SettingsService(appState, modulesService, selectionService);
+ const classificationService = new ClassificationService(appState);
+ const dataService = new DataService(
+ appState, classificationService, apiService, settingsService);
+ const groupService = new GroupService(appState, dataService);
+ const colorService = new ColorService(groupService, dataService);
+ const focusService = new FocusService(selectionService);
// Populate the internal services map for dependency injection
this.services.set(ApiService, apiService);
this.services.set(AppState, appState);
this.services.set(ClassificationService, classificationService);
this.services.set(ColorService, colorService);
+ this.services.set(DataService, dataService);
this.services.set(FocusService, focusService);
this.services.set(GroupService, groupService);
this.services.set(ModulesService, modulesService);
- this.services.set(RegressionService, regressionService);
- this.services.set(SelectionService, [selectionService0, selectionService1]);
+ this.services.set(SelectionService, [
+ selectionService, pinnedSelectionService
+ ]);
this.services.set(SettingsService, settingsService);
this.services.set(SliceService, sliceService);
this.services.set(StatusService, statusService);
diff --git a/lit_nlp/client/core/app_statusbar.css b/lit_nlp/client/core/app_statusbar.css
index bf9a391f..521175ad 100644
--- a/lit_nlp/client/core/app_statusbar.css
+++ b/lit_nlp/client/core/app_statusbar.css
@@ -41,19 +41,6 @@
color: red;
}
-mwc-icon.icon-button {
- height: 18px;
- width: 18px;
- min-width: 18px;
- --mdc-icon-size: 18px;
- cursor: pointer;
- color: var(--lit-cyea-800);
-}
-
-mwc-icon.icon-button:hover {
- color: var(--lit-cyea-600);
-}
-
.emoji {
height: 10pt;
margin-bottom: -5px;
@@ -112,7 +99,7 @@ mwc-icon.icon-button:hover {
width: 80vw;
box-sizing: border-box;
background-color: white;
- box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14), 0 3px 1px -2px rgba(0, 0, 0, .2), 0 1px 5px 0 rgba(0, 0, 0, .12);
+ box-shadow: var(--lit-box-shadow);
display: flex;
flex-direction: column;
justify-content: space-between;
@@ -130,6 +117,7 @@ mwc-icon.icon-button:hover {
.close-button-holder {
display: flex;
justify-content: flex-end;
+ column-gap: 8px;
margin: 10px;
}
diff --git a/lit_nlp/client/core/app_statusbar.ts b/lit_nlp/client/core/app_statusbar.ts
index 092a9f2c..85efb136 100644
--- a/lit_nlp/client/core/app_statusbar.ts
+++ b/lit_nlp/client/core/app_statusbar.ts
@@ -25,13 +25,13 @@ import './global_settings';
import '../elements/spinner';
import {MobxLitElement} from '@adobe/lit-mobx';
+import {html} from 'lit';
import {customElement} from 'lit/decorators';
-import { html} from 'lit';
import {classMap} from 'lit/directives/class-map';
import {observable} from 'mobx';
import {styles as sharedStyles} from '../lib/shared_styles.css';
-import {StatusService} from '../services/services';
+import {AppState, StatusService} from '../services/services';
import {app} from './app';
import {styles} from './app_statusbar.css';
@@ -45,6 +45,7 @@ export class StatusbarComponent extends MobxLitElement {
return [sharedStyles, styles];
}
+ private readonly appState = app.getService(AppState);
private readonly statusService = app.getService(StatusService);
@observable private renderFullMessages = false;
@@ -65,8 +66,8 @@ export class StatusbarComponent extends MobxLitElement {
`;
});
}
diff --git a/lit_nlp/client/core/modules.ts b/lit_nlp/client/core/modules.ts
--- a/lit_nlp/client/core/modules.ts
+++ b/lit_nlp/client/core/modules.ts
@@ -340,43 +426,82 @@ export class LitModules extends ReactiveElement {
}
renderWidgetGroups(
- configs: RenderConfig[][], section: string, layoutWidths: LayoutWidths) {
- // Calllback for widget isMinimized state changes.
- const onMin = (event: Event) => {
- // Recalculate the widget group widths in this section.
+ configs: RenderConfig[][], section: string, layoutWidths: LayoutWidths,
+ idPrefix: string, visible: boolean) {
+ // Recalculate the widget group widths when isMinimized state changes.
+ const onMin = () => {
this.calculatePanelWidths(section, configs, layoutWidths);
};
return configs.map((configGroup, i) => {
+      const width = layoutWidths[section] ? layoutWidths[section][i] : 0;
+ const isLastGroup = i === configs.length - 1;
+ const id = `${idPrefix}-${i}`;
+
+ let nextShownGroupIndex = -1;
- // Callback from widget width drag events.
- const onDrag = (event: Event) => {
- // tslint:disable-next-line:no-any
- const dragWidth = (event as any).detail.dragWidth;
-
- // If the dragged group isn't the right-most group, then balance the
- // delta in width with the widget directly to it's left (so if a widget
- // is expanded, then its adjacent widget is shrunk by the same amount).
- if (i < configs.length - 1) {
- const adjacentConfig = configs[i + 1];
- if (!this.modulesService.isModuleGroupHidden(adjacentConfig[0])) {
- const widthChange = dragWidth - layoutWidths[section][i];
- const oldAdjacentWidth = layoutWidths[section][i + 1];
- layoutWidths[section][i + 1] =
- Math.max(MIN_GROUP_WIDTH_PX, oldAdjacentWidth - widthChange);
- }
+ // Try to find an open widget group to the right of this one
+ for (let adj = i + 1; adj < configs.length; adj++) {
+ if (!this.modulesService.isModuleGroupHidden(configs[adj][0])) {
+ nextShownGroupIndex = adj;
+ break;
}
+ }
- // Set the width of the dragged widget group.
- layoutWidths[section][i] = dragWidth;
+ const isDraggable =
+ nextShownGroupIndex > i &&
+ !this.modulesService.isModuleGroupHidden(configGroup[0]);
+
+ const expanderStyles = styleMap({
+ 'cursor': isDraggable ? 'ew-resize' : 'default'
+ });
+
+ const dragged = (e: DragEvent) => {
+ // If this is the rightmost group, or this isn't draggable, or this is
+ // minimized, do nothing.
+ if (isLastGroup || !isDraggable ||
+ this.modulesService.isModuleGroupHidden(configGroup[0])) return;
+
+ const widgetGroup = this.shadowRoot!.querySelector(`#${id}`);
+ const left = widgetGroup!.getBoundingClientRect().left;
+ const dragWidth = Math.round(e.clientX - left - EXPANDER_WIDTH);
+ const dragLength = dragWidth - width;
+
+ // Groups have a minimum width, so the user can't drag any further to
+ // the left than that
+ const atMinimum = dragWidth <= MIN_GROUP_WIDTH_PX;
+ // We enforce a minimum drag distance before requesting an update,
+ // effectively a distance-based throttle for render performance
+ const isSufficient = Math.abs(dragLength) > MIN_GROUP_WIDTH_DELTA_PX;
+ if (atMinimum || !isSufficient) return;
+
+ // Balance the delta in width with the next open widget to its right, so
+ // if a widget is expanded, then the next open widget to its right is
+ // shrunk by the same amount and vice versa.
+ const oldAdjacentWidth = layoutWidths[section][nextShownGroupIndex];
+ const newWidth = Math.round(oldAdjacentWidth - dragLength);
+ const newAdjacentWidth = Math.max(MIN_GROUP_WIDTH_PX, newWidth);
+ const deltaFromDrag = newAdjacentWidth - newWidth;
+ layoutWidths[section][nextShownGroupIndex] = newAdjacentWidth;
+ layoutWidths[section][i] =
+ dragWidth - (newAdjacentWidth > newWidth ? deltaFromDrag : 0);
this.requestUpdate();
};
- const width = layoutWidths[section] ? layoutWidths[section][i] : 0;
- return html``;
+      // clang-format off
+      return html`
+        <lit-widget-group id=${id} .configGroup=${configGroup} .width=${width}
+          ?visible=${visible} @widget-group-minimized-changed=${onMin}>
+        </lit-widget-group>
+        ${isLastGroup ? html`` : html`
+          <div class="expander" style=${expanderStyles}>
+            <div class="expander-drag-target" draggable=${isDraggable}
+              @drag=${(e: DragEvent) => { dragged(e); }}>
+            </div>
+          </div>`}`;
+ // clang-format on
});
}
}
diff --git a/lit_nlp/client/modules/slice_module.css b/lit_nlp/client/core/slice_module.css
similarity index 93%
rename from lit_nlp/client/modules/slice_module.css
rename to lit_nlp/client/core/slice_module.css
index efec38f7..7a7594f0 100644
--- a/lit_nlp/client/modules/slice_module.css
+++ b/lit_nlp/client/core/slice_module.css
@@ -4,6 +4,8 @@
.module-container {
padding: 6px;
+ max-height: 100%;
+ overflow-y: hidden;
}
.row-container {
@@ -20,12 +22,13 @@
flex-direction: column;
flex: 1;
width: 100%;
+ overflow-y: hidden;
}
#slice-selector {
border: 1px solid rgb(218, 220, 224);
flex: 1;
- min-height: 250px;
+ max-height: calc(100% - 18px);
overflow-x: hidden;
overflow-y: auto;
}
diff --git a/lit_nlp/client/modules/slice_module.ts b/lit_nlp/client/core/slice_module.ts
similarity index 79%
rename from lit_nlp/client/modules/slice_module.ts
rename to lit_nlp/client/core/slice_module.ts
index 22925b3b..b5e02d30 100644
--- a/lit_nlp/client/modules/slice_module.ts
+++ b/lit_nlp/client/core/slice_module.ts
@@ -17,17 +17,18 @@
// tslint:disable:no-new-decorators
import {customElement} from 'lit/decorators';
-import { html} from 'lit';
+import {html} from 'lit';
import {classMap} from 'lit/directives/class-map';
import {computed, observable} from 'mobx';
-import {app} from '../core/app';
-import {LitModule} from '../core/lit_module';
+import {app} from './app';
+import {LitModule} from './lit_module';
import {ModelInfoMap, Spec} from '../lib/types';
import {handleEnterKey} from '../lib/utils';
-import {GroupService, FacetingMethod, FacetingConfig, NumericFeatureBins} from '../services/group_service';
+import {GroupService, NumericFeatureBins} from '../services/group_service';
import {SliceService} from '../services/services';
import {STARRED_SLICE_NAME} from '../services/slice_service';
+import {FacetsChange} from '../core/faceting_control';
import {styles as sharedStyles} from '../lib/shared_styles.css';
@@ -43,22 +44,40 @@ export class SliceModule extends LitModule {
}
static override title = 'Slice Editor';
+ static override referenceURL =
+ 'https://github.com/PAIR-code/lit/wiki/ui_guide.md#slices';
static override numCols = 2;
static override collapseByDefault = true;
static override duplicateForModelComparison = false;
- static override template = () => {
- return html``;
- };
+  static override template =
+      (model: string, selectionServiceIndex: number, shouldReact: number) =>
+          html`
+            <slice-module model=${model} .shouldReact=${shouldReact}
+              selectionServiceIndex=${selectionServiceIndex}>
+            </slice-module>`;
private readonly sliceService = app.getService(SliceService);
private readonly groupService = app.getService(GroupService);
+ private readonly facetingControl = document.createElement('faceting-control');
private sliceByBins: NumericFeatureBins = {};
@observable private sliceByFeatures: string[] = [];
@observable private sliceName: string|null = null;
+ constructor() {
+ super();
+
+    const facetsChange = (event: CustomEvent<FacetsChange>) => {
+ this.sliceByFeatures = event.detail.features;
+ this.sliceByBins = event.detail.bins;
+ };
+ this.facetingControl.contextName = SliceModule.title;
+ this.facetingControl.addEventListener(
+ 'facets-change', facetsChange as EventListener);
+ }
+
@computed
private get createButtonEnabled() {
const sliceFromFilters =
@@ -228,69 +247,13 @@ export class SliceModule extends LitModule {
// clang-format on
}
- /**
- * Create checkboxes for each value of each categorical feature.
- */
- renderFilters() {
- // Update the filterdict to match the checkboxes.
- const onChange = (e: Event, key: string) => {
- if ((e.target as HTMLInputElement).checked) {
- this.sliceByFeatures.push(key);
- } else {
- const index = this.sliceByFeatures.indexOf(key);
- this.sliceByFeatures.splice(index, 1);
- }
-
- const configs: FacetingConfig[] = this.sliceByFeatures.map(feature => ({
- featureName: feature,
- method: this.groupService.numericalFeatureNames.includes(feature) ?
- FacetingMethod.EQUAL_INTERVAL : FacetingMethod.DISCRETE
- }));
-
- this.sliceByBins = this.groupService.numericalFeatureBins(configs);
- };
-
- const renderFeatureCheckbox = (key: string) => {
- // clang-format off
- return html`
-
`;
diff --git a/lit_nlp/client/default/layout.ts b/lit_nlp/client/default/layout.ts
deleted file mode 100644
index 2b2c2940..00000000
--- a/lit_nlp/client/default/layout.ts
+++ /dev/null
@@ -1,124 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Import Modules
-import '../modules/span_graph_module';
-
-import {LitModuleType} from '../core/lit_module';
-import {LitComponentLayouts} from '../lib/types';
-import {AnnotatedTextGoldModule, AnnotatedTextModule} from '../modules/annotated_text_module';
-import {AttentionModule} from '../modules/attention_module';
-import {ClassificationModule} from '../modules/classification_module';
-import {ColorModule} from '../modules/color_module';
-import {ConfusionMatrixModule} from '../modules/confusion_matrix_module';
-import {CounterfactualExplainerModule} from '../modules/counterfactual_explainer_module';
-import {DataTableModule, SimpleDataTableModule} from '../modules/data_table_module';
-import {DatapointEditorModule, SimpleDatapointEditorModule} from '../modules/datapoint_editor_module';
-import {EmbeddingsModule} from '../modules/embeddings_module';
-import {GeneratedImageModule} from '../modules/generated_image_module';
-import {GeneratedTextModule} from '../modules/generated_text_module';
-import {GeneratorModule} from '../modules/generator_module';
-import {LanguageModelPredictionModule} from '../modules/lm_prediction_module';
-import {MetricsModule} from '../modules/metrics_module';
-import {MultilabelModule} from '../modules/multilabel_module';
-import {PdpModule} from '../modules/pdp_module';
-import {RegressionModule} from '../modules/regression_module';
-import {SalienceMapModule} from '../modules/salience_map_module';
-import {ScalarModule} from '../modules/scalar_module';
-import {SequenceSalienceModule} from '../modules/sequence_salience_module';
-import {SliceModule} from '../modules/slice_module';
-import {SpanGraphGoldModuleVertical, SpanGraphModuleVertical} from '../modules/span_graph_module';
-import {TCAVModule} from '../modules/tcav_module';
-import {ThresholderModule} from '../modules/thresholder_module';
-
-// clang-format off
-const MODEL_PREDS_MODULES: LitModuleType[] = [
- SpanGraphGoldModuleVertical,
- SpanGraphModuleVertical,
- ClassificationModule,
- MultilabelModule,
- RegressionModule,
- LanguageModelPredictionModule,
- GeneratedTextModule,
- AnnotatedTextGoldModule,
- AnnotatedTextModule,
- GeneratedImageModule,
-];
-
-const DEFAULT_MAIN_GROUP: LitModuleType[] = [
- DataTableModule,
- DatapointEditorModule,
- SliceModule,
- ColorModule,
-];
-// clang-format on
-
-// clang-format off
-/**
- * Possible layouts for LIT (component groups and settigns.)
- */
-export const LAYOUTS: LitComponentLayouts = {
- /**
- * A "simple demo server" layout.
- */
- 'simple': {
- upper: {
- "Editor": [SimpleDatapointEditorModule],
- "Examples": [SimpleDataTableModule],
- },
- lower: {
- 'Predictions': [ ...MODEL_PREDS_MODULES],
- 'Salience': [SalienceMapModule, SequenceSalienceModule],
- },
- layoutSettings: {
- hideToolbar: true,
- mainHeight: 30,
- centerPage: true
- },
- description: 'A basic layout just containing a datapoint creator/editor, the predictions, and the data table. There are also some visual simplifications: the toolbar is hidden, and the modules are centered on the page rather than being full width.'
- },
- /**
- * A default layout for LIT Modules
- */
- 'default': {
- components : {
- 'Main': [EmbeddingsModule, ...DEFAULT_MAIN_GROUP],
- 'Predictions': [
- ...MODEL_PREDS_MODULES,
- ScalarModule,
- PdpModule,
- ],
- 'Explanations': [
- ...MODEL_PREDS_MODULES,
- SalienceMapModule,
- SequenceSalienceModule,
- AttentionModule,
- ],
- 'Metrics': [
- MetricsModule,
- ConfusionMatrixModule,
- ThresholderModule,
- ],
- 'Counterfactuals': [GeneratorModule, CounterfactualExplainerModule],
- 'TCAV': [
- TCAVModule,
- ],
- },
- description: "The default LIT layout, which includes the data table and data point editor, the performance and metrics, predictions, explanations, and counterfactuals."
- },
-};
-// clang-format on
diff --git a/lit_nlp/client/elements/annotated_text_vis.ts b/lit_nlp/client/elements/annotated_text_vis.ts
index c59c6c3e..c7f608e9 100644
--- a/lit_nlp/client/elements/annotated_text_vis.ts
+++ b/lit_nlp/client/elements/annotated_text_vis.ts
@@ -13,8 +13,10 @@ import {computed, observable} from 'mobx';
import {getVizColor} from '../lib/colors';
import {ReactiveElement} from '../lib/elements';
+import {SpanLabel} from '../lib/dtypes';
+import {URLLitType} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
-import {formatSpanLabel, SpanLabel} from '../lib/types';
+import {formatSpanLabel} from '../lib/types';
import {styles} from './annotated_text_vis.css';
@@ -194,7 +196,7 @@ export class AnnotatedTextVis extends ReactiveElement {
renderTextSegment(name: string) {
const text = this.segments[name];
const spans = this.activeAnnotations[name];
- const isURL = this.segmentSpec[name].__name__ === 'URL';
+ const isURL = this.segmentSpec[name] instanceof URLLitType;
return html`
`;
}
diff --git a/lit_nlp/client/elements/checkbox_test.ts b/lit_nlp/client/elements/checkbox_test.ts
new file mode 100644
index 00000000..499e2027
--- /dev/null
+++ b/lit_nlp/client/elements/checkbox_test.ts
@@ -0,0 +1,67 @@
+import 'jasmine';
+import {Checkbox} from '@material/mwc-checkbox';
+import {LitElement} from 'lit';
+import {LitCheckbox} from './checkbox';
+
+
+describe('checkbox test', () => {
+ let checkbox: LitCheckbox;
+
+ beforeEach(async () => {
+ // Set up.
+ checkbox = new LitCheckbox();
+ document.body.appendChild(checkbox);
+ await checkbox.updateComplete;
+ });
+
+ afterEach(() => {
+ document.body.removeChild(checkbox);
+ });
+
+ it('can be instantiated', () => {
+ expect(checkbox instanceof HTMLElement).toBeTrue();
+ expect(checkbox instanceof LitElement).toBeTrue();
+ });
+
+ it('comprises a div with an MWC Checkbox and a span as children', () => {
+ expect(checkbox.renderRoot.children.length).toEqual(1);
+
+ const [innerDiv] = checkbox.renderRoot.children;
+ expect(innerDiv instanceof HTMLDivElement).toBeTrue();
+ expect((innerDiv as HTMLDivElement).className).toEqual(' wrapper ');
+ expect(innerDiv.children.length).toEqual(2);
+
+ const [mwcCheckbox, label] = innerDiv.children;
+ expect(mwcCheckbox instanceof Checkbox).toBeTrue();
+ expect(label instanceof HTMLSpanElement).toBeTrue();
+ expect((label as HTMLSpanElement).className).toEqual('checkbox-label');
+ });
+
+ it('toggles checked state when the box is clicked', async () => {
+ expect(checkbox.checked).toBeFalse();
+
+    const mwcCheckbox =
+        checkbox.renderRoot.querySelector<Checkbox>('lit-mwc-checkbox-internal')!;
+ mwcCheckbox.click();
+ await checkbox.updateComplete;
+ expect(checkbox.checked).toBeTrue();
+
+ mwcCheckbox.click();
+ await checkbox.updateComplete;
+ expect(checkbox.checked).toBeFalse();
+ });
+
+ it('toggles checked state when the label is clicked', async () => {
+ expect(checkbox.checked).toBeFalse();
+
+    const label =
+        checkbox.renderRoot.querySelector<HTMLSpanElement>('span.checkbox-label')!;
+ label.click();
+ await checkbox.updateComplete;
+ expect(checkbox.checked).toBeTrue();
+
+ label.click();
+ await checkbox.updateComplete;
+ expect(checkbox.checked).toBeFalse();
+ });
+});
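The tests above pin down the element's public surface: a `checked` property that toggles when either the inner MWC checkbox or the label span is clicked. Typical usage might look like this sketch; the `lit-checkbox` tag name, `label` property, and `change` event are inferred from the class and test selectors, not confirmed by this diff:

```ts
import {html, render} from 'lit';
import {LitCheckbox} from './checkbox';

// Sketch only: attribute and event names are assumptions.
const onChange = (e: Event) => {
  console.log('checked:', (e.target as LitCheckbox).checked);
};
render(
    html`<lit-checkbox label="Show labels" @change=${onChange}></lit-checkbox>`,
    document.body);
```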
diff --git a/lit_nlp/client/elements/color_legend.css b/lit_nlp/client/elements/color_legend.css
new file mode 100644
index 00000000..7369eb19
--- /dev/null
+++ b/lit_nlp/client/elements/color_legend.css
@@ -0,0 +1,46 @@
+.legend-container {
+ display: flex;
+ flex-direction: row;
+ align-items: center;
+ height: 30px;
+ overflow-x: auto;
+ overflow-y: hidden;
+}
+
+.legend-line {
+ display: flex;
+ flex-direction: row;
+ align-items: center;
+}
+
+.legend-box {
+ width: 13px;
+ height: 13px;
+ margin: 2px;
+ border: 1px solid gray;
+ border-radius: 1px;
+}
+
+.legend-label {
+ margin-left: 3px;
+ margin-right: 2px;
+ line-height: 13px;
+ font-family: 'Roboto';
+ font-style: 'normal';
+ font-size: 13px;
+ color: var(--lit-neutral-800);
+}
+
+.color-label {
+ min-width: 20px;
+ margin-left: 3px;
+ margin-right: 3px;
+ line-height: 13px;
+ font-family: 'Roboto';
+ font-style: 'normal';
+ font-size: 13px;
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+ color: var(--lit-neutral-800);
+}
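These classes style the markup emitted by the `color-legend` element added in the next file. A sketch of the DOM shape they appear to target, inferred from the class names and the render methods below (colors and label text are placeholders):

```ts
import {html} from 'lit';

// Sketch: one categorical legend entry inside the scrollable container.
const legendEntry = html`
  <div class="legend-container">
    <div class="legend-line">
      <div class="legend-box" style="background: steelblue;"></div>
      <div class="legend-label">positive</div>
    </div>
  </div>`;
```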
diff --git a/lit_nlp/client/elements/color_legend.ts b/lit_nlp/client/elements/color_legend.ts
new file mode 100644
index 00000000..456c0c03
--- /dev/null
+++ b/lit_nlp/client/elements/color_legend.ts
@@ -0,0 +1,291 @@
+/**
+ * @fileoverview Element for displaying color legend.
+ *
+ * @license
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// tslint:disable:no-new-decorators
+import * as d3 from 'd3';
+import {html} from 'lit';
+import {customElement, property} from 'lit/decorators';
+import {styleMap} from 'lit/directives/style-map';
+import {computed, observable} from 'mobx';
+
+import {DEFAULT} from '../lib/colors';
+import {ReactiveElement} from '../lib/elements';
+import {styles as sharedStyles} from '../lib/shared_styles.css';
+import {D3Scale} from '../lib/types';
+import {getTextWidth, linearSpace} from '../lib/utils';
+
+import {styles} from './color_legend.css';
+
+/**
+ * Enumeration of the different legend types
+ */
+export enum LegendType {
+ SEQUENTIAL = 'sequential',
+ CATEGORICAL = 'categorical'
+}
+
+// default width of a character
+const DEFAULT_CHAR_WIDTH: number = 5.7;
+
+/** Removes non-digit chars from a style value and converts it to a number. */
+function stylePropToNumber(styles: CSSStyleDeclaration,
+ property: string): number {
+ try {
+ return Number(styles.getPropertyValue(property).replace(/[^\d\.]/g, ''));
+ } catch {
+ return 0;
+ }
+}
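+// Example: a computed 'margin-left' of '3px' yields 3. Unparsable values come
+// back as NaN, which the `|| default` fallbacks in firstUpdated() below
+// resolve to the hard-coded defaults.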
+
+/**
+ * Color legend visualization component.
+ */
+@customElement('color-legend')
+export class ColorLegend extends ReactiveElement {
+ @observable @property({type: Object}) scale: D3Scale =
+ d3.scaleOrdinal([DEFAULT]).domain(['all']) as D3Scale;
+ @property({type: String}) legendType = LegendType.CATEGORICAL;
+ @property({type: String}) selectedColorName = '';
+ /** Width of the container. Used to determine if blocks should be labeled. */
+ @property({type: Number}) legendWidth = 150;
+
+ // font attributes used to compute whether or not to show the text labels
+ private fontFamily: string = '';
+ private fontStyle: string = '';
+ private fontSize: string = '';
+
+  // Label margin values; updated to the computed values in firstUpdated().
+ private labelMarginLeft: number = 3;
+ private labelMarginRight: number = 2;
+
+ private boxWidth: number = 13;
+ private boxMargin: number = 2;
+
+ private selectedColorLabelWidth: number = 46;
+ private iconWidth: number = 16;
+
+ static override get styles() {
+ return [sharedStyles, styles];
+ }
+
+ override firstUpdated() {
+ const {host} = this.shadowRoot!;
+ if (host) {
+ const style = window.getComputedStyle(host);
+ this.legendWidth = stylePropToNumber(style, 'width') || this.legendWidth;
+ }
+
+ /** Get font styles from the legend-label */
+ const legendLabelElement = this.shadowRoot!.querySelector('.legend-label');
+ if (legendLabelElement) {
+ const style = window.getComputedStyle(legendLabelElement);
+ this.fontFamily = style.getPropertyValue('font-family');
+ this.fontStyle = style.getPropertyValue('font-style');
+ this.fontSize = style.getPropertyValue('font-size');
+
+ this.labelMarginLeft =
+ stylePropToNumber(style, 'margin-left') || this.labelMarginLeft;
+ this.labelMarginRight =
+ stylePropToNumber(style, 'margin-right') || this.labelMarginRight;
+ }
+
+ /** Get styles from the legend-box */
+ const boxElement = this.shadowRoot!.querySelector('.legend-box');
+ if (boxElement) {
+ const style = window.getComputedStyle(boxElement);
+ this.boxWidth = stylePropToNumber(style, 'width') || this.boxWidth;
+ this.boxMargin = stylePropToNumber(style, 'margin') || this.boxMargin;
+ }
+
+ /** Get styles from the color-label */
+ const colorLabelElement = this.shadowRoot!.querySelector('.color-label');
+ if (colorLabelElement) {
+ const style = window.getComputedStyle(colorLabelElement);
+ const marginLeft = stylePropToNumber(style, 'margin-left') || 3;
+ const marginRight = stylePropToNumber(style, 'margin-right') || 3;
+      this.selectedColorLabelWidth =
+          (marginLeft + marginRight + stylePropToNumber(style, 'width')) ||
+          this.selectedColorLabelWidth;
+ }
+
+ /** Get styles from the palette-icon */
+ const iconElement = this.shadowRoot!.querySelector('.palette-icon');
+ if (iconElement) {
+ const style = window.getComputedStyle(iconElement);
+ this.iconWidth = stylePropToNumber(style, 'width') || this.iconWidth;
+ }
+ }
+
+ // TODO(b/237418328): Add a custom tooltip for a faster display time.
+  /**
+   * Renders an individual color block and its associated label.
+   * Labels are hidden for the sequential legendType, and for a categorical
+   * legendType whose rendered width would exceed legendWidth.
+   */
+ private renderLegendBlock(val: string|number, hideLabels: boolean) {
+ const background = this.scale(val);
+ const style = styleMap({'background': background});
+
+ // clang-format off
+    return html`
+      <div class='legend-line'>
+        <div class='legend-box' title=${val} style=${style}></div>
+        <div class='legend-label' ?hidden=${hideLabels}>${val}</div>
+      </div>
+    `;
+ // clang-format on
+ }
+
+ /**
+ * Render color blocks for sequential values.
+   * When a block is hovered, the range of values mapped to it is displayed.
+ * @param {string|number} startVal - the min value of a range
+ * @param {string|number} endVal - the max value of a range
+ * @param {string|number} colorVal - for coloring the block
+ * @param {boolean} includeMax - whether to include the max value in a range
+ */
+ private renderSequentialBlock(startVal: string|number, endVal: number|string,
+ colorVal: string|number, includeMax: boolean = false) {
+ const title = startVal === endVal ? startVal :
+ includeMax ? `[${startVal}, ${endVal}]`
+ : `[${startVal}, ${endVal})`;
+ const background = this.scale(colorVal);
+ const style = styleMap({'background': background});
+
+ // TODO(b/237418328): Add a custom tooltip for a faster display time.
+ // clang-format off
+ return html`
+