diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py
index bb8d773e..188c8cd4 100644
--- a/lit_nlp/api/layout.py
+++ b/lit_nlp/api/layout.py
@@ -59,10 +59,6 @@ class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum):
SimpleDatapointEditorModule = 'simple-datapoint-editor-module'
# Non-replicating version of Datapoint Editor
SingleDatapointEditorModule = 'single-datapoint-editor-module'
- SpanGraphGoldModule = 'span-graph-gold-module'
- SpanGraphGoldModuleVertical = 'span-graph-gold-module-vertical'
- SpanGraphModule = 'span-graph-module'
- SpanGraphModuleVertical = 'span-graph-module-vertical'
TCAVModule = 'tcav-module'
ThresholderModule = 'thresholder-module'
TrainingDataAttributionModule = 'tda-module'
@@ -126,8 +122,6 @@ def to_json(self) -> dtypes.JsonDict:
modules = LitModuleName # pylint: disable=invalid-name
MODEL_PREDS_MODULES = (
- modules.SpanGraphGoldModuleVertical,
- modules.SpanGraphModuleVertical,
modules.ClassificationModule,
modules.MultilabelModule,
modules.RegressionModule,
diff --git a/lit_nlp/client/elements/span_graph_vis.css b/lit_nlp/client/elements/span_graph_vis.css
deleted file mode 100644
index 310775f3..00000000
--- a/lit_nlp/client/elements/span_graph_vis.css
+++ /dev/null
@@ -1,74 +0,0 @@
-text.token-text {
- alignment-baseline: middle;
- dominant-baseline: central;
-}
-
-polyline.span-bracket {
- fill: none;
- stroke-width: 1.2px;
- stroke: var(--group-color);
-}
-
-.selected polyline.span-bracket {
- stroke-width: 1.8px;
-}
-
-path.arc-path {
- stroke-width: 1.2px;
- stroke: var(--group-color);
- fill: none;
-}
-
-path.arc-path.arc-neg {
- stroke-dasharray: 3,1;
- stroke: gray;
-}
-
-.selected path.arc-path {
- stroke-width: 1.8px;
-}
-
-path.arc-arrow {
- stroke-width: 1.2px;
- stroke: var(--group-color);
- fill: var(--group-color);
-}
-
-path.arc-arrow.arc-neg {
- stroke: gray;
- fill: gray;
-}
-
-
-.layer-label text {
- font-family: 'Share Tech Mono', monospace;
- dominant-baseline: middle;
- text-anchor: end;
-}
-
-foreignObject.span-label {
- overflow: visible;
-}
-
-.span-label div {
- background-color: white;
- font-family: 'Share Tech Mono', monospace;
- line-height: 1.0;
- padding: 1px;
- padding-right: 3px; /* for occluding labels on mouseover */
- white-space: nowrap;
- overflow: hidden;
- text-overflow: ellipsis;
- color: var(--group-color);
-}
-
-g.selected .span-label div {
- background-color: white;
- overflow-x: visible;
- width: fit-content; /* needed to include background when expanding */
-}
-
-.mousebox {
- fill: white;
- fill-opacity: 0.0;
-}
diff --git a/lit_nlp/client/elements/span_graph_vis.ts b/lit_nlp/client/elements/span_graph_vis.ts
deleted file mode 100644
index 0e10cf3a..00000000
--- a/lit_nlp/client/elements/span_graph_vis.ts
+++ /dev/null
@@ -1,472 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Visualization component for structured prediction over text.
- */
-
-import * as d3 from 'd3';
-import {html, LitElement, svg} from 'lit';
-import {customElement, property} from 'lit/decorators.js';
-import {classMap} from 'lit/directives/class-map.js';
-import {styleMap} from 'lit/directives/style-map.js';
-
-import {getVizColor} from '../lib/colors';
-import {EdgeLabel} from '../lib/dtypes';
-
-import {styles} from './span_graph_vis.css';
-
-
-/**
- * Represents a group of directed graphs anchored to token spans.
- * This is the general "edge probing" representation, which can be used for many
- * problems including sequence tagging, span labeling, and directed graphs like
- * semantic frames, coreference, and dependency parsing. See
- * https://arxiv.org/abs/1905.06316 for more on this formalism.
- */
-export interface SpanGraph {
- 'tokens': string[];
- 'layers': AnnotationLayer[];
-}
-
-/**
- * A single layer of annotations, like 'pos' (part-of-speech)
- * or 'ner' (named entities).
- */
-export interface AnnotationLayer {
- 'name': string;
- 'edges': EdgeLabel[];
-}
-
-/* Compute points for a polyline bracket. */
-function hBracketPoints(width: number, height: number, lean: number) {
- // Points for a polyline bracket.
- const start = [0, 0];
- const ltop = [lean, height];
- const rtop = [width - lean, height];
- const end = [width, 0];
- return [start, ltop, rtop, end];
-}
-
-/**
- * Compute path for a dependency arc.
- */
-function arcPath(
- startY: number, x1: number, x2: number, height: number, aspect: number) {
- const left = Math.min(x1, x2);
- const right = Math.max(x1, x2);
- let pathCommands = `M ${left} ${startY} `;
- if ((right - left) > (2 * aspect * height)) {
- // Long arcs: draw as 90* curve, flat, then 90* curve.
- const majorAxis = aspect * height;
- pathCommands += `A ${majorAxis} ${height} 0 0 1 ${left + majorAxis} ${
- startY - height} `;
- pathCommands += `L ${right - majorAxis} ${startY - height} `;
- pathCommands += `A ${majorAxis} ${height} 0 0 1 ${right} ${startY} `;
- } else {
- // Short arcs: draw as single 180* curve.
- height = (right - left) / (2 * aspect);
- pathCommands +=
- `A ${(right - left) / 2} ${height} 0 0 1 ${right} ${startY} `;
- }
- return pathCommands; /* assign as 'd' attribute to path */
-}
-
-/**
- * Compute path for the arrow at the end of an arc.
- */
-function arcArrow(startY: number, x: number, markSize: number) {
- let pathCommands = `M ${x - markSize} ${startY - (1.5 * markSize)} `;
- pathCommands += `L ${x + markSize} ${startY - (1.5 * markSize)} `;
- pathCommands += `L ${x} ${startY} Z`;
- return pathCommands; /* assign as 'd' attribute to path */
-}
-
-/* Set attributes to match target's size to the source element. */
-function matchBBox(source: SVGGElement, target: SVGRectElement) {
- const bbox = source.getBBox();
- target.setAttribute('x', `${bbox.x}`);
- target.setAttribute('y', `${bbox.y}`);
- target.setAttribute('width', `${bbox.width}`);
- target.setAttribute('height', `${bbox.height}`);
-}
-
-/** Structured prediction (SpanGraph) visualization class. */
-@customElement('span-graph-vis')
-export class SpanGraphVis extends LitElement {
- /* Data binding */
- @property({type: Object}) data: SpanGraph = {tokens: [], layers: []};
- @property({type: Boolean}) showLayerLabel: boolean = true;
-
- /* Rendering parameters */
- @property({type: Number}) lineHeight: number = 18;
- @property({type: Number}) bracketHeight: number = 5.5;
- @property({type: Number}) yPad: number = 5;
- // For arcs between spans.
- @property({type: Number}) arcBaseHeight: number = 20;
- @property({type: Number}) arcMaxHeight: number = 40;
- @property({type: Number}) arcAspect: number = 1.2;
- @property({type: Number}) arcArrowSize: number = 4;
- // Padding for SVG viewport, to avoid clipping some elements (like polyline).
- @property({type: Number}) viewPad: number = 5;
- // Multiplier from SVG units to screen pixels.
- @property({type: Number}) svgScaling: number = 1.2;
-
- /* Internal rendering state */
- private tokenXBounds: Array<[number, number]> = [];
-
- static override get styles() {
- return styles;
- }
-
- renderTokens(tokens: string[]) {
- return svg`
-
-
- ${tokens.map(t => svg`${svg`${t + ' '}`}`)}
-
- `;
- }
-
- private getTokenGroup() {
- return this.shadowRoot!.querySelector('g#token-group') as SVGGElement;
- }
-
- renderEdge(edge: EdgeLabel, color: string) {
- // Positioning relative to the group transform, which will be applied later.
- const labelHeight = this.lineHeight;
- const labelY = -(this.bracketHeight + this.lineHeight);
-
- let labelText = edge.label;
- let isNegativeEdge = false;
- if (typeof edge.label === 'number') {
- labelText = edge.label.toFixed(3);
- isNegativeEdge = edge.label < 0.5;
- }
- const arcPathClass =
- classMap({'arc-path': true, 'arc-neg': isNegativeEdge});
- const arcArrowClass =
- classMap({'arc-arrow': true, 'arc-neg': isNegativeEdge});
- color = isNegativeEdge ? 'gray' : color;
- // clang-format off
- return svg`
-
- ${edge.span2 ? svg`
-
-
- ` : ''}
-
-
-
- ${html`${labelText}
`}
-
-
-
- ${edge.span2 ? svg`
-
- ` : ''}
-
- `;
- // clang-format on
- }
-
- renderLayer(layer: AnnotationLayer, i: number) {
- const rowColor = getVizColor('deep', i).color;
- // Positioning relative to the group transform, which will be applied later.
- const rowLabelX = -10;
- const rowLabelY = -(this.bracketHeight + 0.5 * this.lineHeight);
-
- const orderedEdges = this.sortEdges(layer.edges);
- // clang-format off
- return svg`
-
- ${this.showLayerLabel ? svg`
-
-
- ${svg`${layer.name}`}
-
- ` : null}
- ${orderedEdges.map(edge => this.renderEdge(edge, rowColor))}
-
- `;
- // clang-format on
- }
-
- private getLayerGroup(name: string) {
- return this.shadowRoot!.querySelector(`g#layer-group-${name}`) as
- SVGGElement;
- }
-
- override render() {
- return svg`
- `;
- }
-
- private findTokenBounds() {
- const tokenNodes = this.getTokenGroup().querySelectorAll('tspan');
- const tokenXBounds: Array<[number, number]> = [];
- tokenNodes.forEach(tspan => {
- // Use getBBox() to avoid a crash when tspan.getNumberOfChars() === 0.
- // TODO(lit-dev): figure out why this case happens - maybe
- // the nodes are not yet attached to the DOM?
- const bbox = tspan.getBBox();
- tokenXBounds.push([bbox.x, bbox.x + bbox.width]);
- });
- return tokenXBounds;
- }
-
- /**
- * Consistent sort order.
- * Because span labels overflow to the right, we order these so the rightmost
- * spans appear first in the DOM, and thus render under anything to the left
- * that needs to overflow.
- */
- private sortEdges(edges: EdgeLabel[]) {
- return edges.slice().sort((a, b) => d3.descending(a.span1[1], b.span1[1]));
- }
-
- /* Starting x position for a bracket, in SVG coordinates */
- private getStartX(span: [number, number]) {
- return this.tokenXBounds[span[0]][0];
- }
-
- /* Ending x position for a bracket, in SVG coordinates */
- private getEndX(span: [number, number]) {
- return this.tokenXBounds[span[1] - 1][1];
- }
-
- /* Find available width without clipping the next label */
- private findAvailableWidths(layerGroup: Element, edges: EdgeLabel[]):
- number[] {
- const availableWidths: number[] = edges.map(() => 0);
- // Find available space for each label, by checking where the next label
- // starts. We iterate from right to left through the spans, starting with
- // the second-rightmost (i=1).
- for (let i = 1; i < edges.length; i++) {
- const edge = edges[i]; // this span
- const nextEdge = edges[i - 1]; // right neighboring span
- availableWidths[i] =
- this.getStartX(nextEdge.span1) - this.getStartX(edge.span1);
- }
- // We don't want the rightmost label (index 0) to be cut off by the edge
- // of the SVG draw area, even if the label extends past the end of the
- // text. So we need to:
- // 1) Set this label to fit the content, so the bounding box contains all
- // the label text.
- // 2) Set the available width to this rendered width, so we don't clip it
- // later.
- const firstSpanDiv =
- // tslint:disable-next-line:no-unnecessary-type-assertion
- layerGroup.querySelector('g.edge-group foreignObject div') as
- HTMLDivElement |
- null;
- if (firstSpanDiv !== null) {
- firstSpanDiv.style.width = 'fit-content';
- availableWidths[0] = firstSpanDiv.getBoundingClientRect().width;
- }
- return availableWidths;
- }
-
- /* Set mouseovers, using d3. */
- private setMouseovers(group: SVGGElement, edges: EdgeLabel[]) {
- const rowColor = group.dataset['color'] as string;
- const grayColor = getVizColor('deep', 'other').color;
-
- const spanGroups = d3.select(group).selectAll('g.edge-group').data(edges);
- const tokenSpans = d3.select(this.getTokenGroup()).selectAll('tspan');
-
- // On mouseover, highlight this span and the corresponding text.
- spanGroups.each(function(d, i) {
- const colorFn = (e: unknown, j: number) =>
- (i === j) ? rowColor : grayColor;
- const tokenColorFn = (t: unknown, j: number) => {
- const inSpan1 = (d.span1[0] <= j && j < d.span1[1]);
- const inSpan2 = d.span2 ? (d.span2[0] <= j && j < d.span2[1]) : false;
- return (inSpan1 || inSpan2) ? rowColor : 'black';
- };
- const mouseBox = d3.select(this).select('rect.mousebox');
- mouseBox.on('mouseover', () => {
- spanGroups.style('--group-color', colorFn);
- tokenSpans.attr('fill', tokenColorFn);
- d3.select(this).classed('selected', true);
- // Ideally we'd also move this element so that it renders above
- // the other groups, but SVG2 z-index is not supported by most browsers
- // and simply reordering child nodes does not play well with lit-html's
- // rendering logic, which relies on pointers to specific positions in
- // the DOM.
- // d3.select(this).classed('selected', true).raise();
- // TODO(iftenney): consider implementing a tooltip that clones this
- // element but always renders above the other spans.
- });
- mouseBox.on('mouseout', () => {
- // Reset to original color, stored on group element.
- // TODO(lit-dev): do this with another CSS class instead?
- spanGroups.style('--group-color', function(e) {
- return (this as SVGElement).dataset['color'] as string;
- });
- tokenSpans.attr('fill', 'black');
- d3.select(this).classed('selected', false);
- });
- });
- }
-
- /* Set y-position of rendered layers */
- private positionLayers() {
- let rowStartY = this.getTokenGroup().getBBox().y - this.yPad / 2;
- for (let i = 0; i < this.data.layers.length; i++) {
- const group: SVGGElement = this.getLayerGroup(this.data.layers[i].name);
- group.setAttribute('transform', `translate(0, ${rowStartY})`);
- rowStartY -= group.getBBox().height + this.yPad;
- }
- }
-
- /* Set the SVG viewport to the bounding box of the main group. */
- private setSVGViewport() {
- const mainGroup = this.shadowRoot!.querySelector('g#all') as SVGGElement;
- const bbox = mainGroup.getBBox();
- const svg = this.shadowRoot!.getElementById('svg')!;
- // Set bounding box to cover main group + viewPad on all sides.
- const viewBox = [
- bbox.x - this.viewPad, bbox.y - this.viewPad,
- bbox.width + 2 * this.viewPad, bbox.height + 2 * this.viewPad
- ];
- svg.setAttribute('viewBox', `${viewBox}`);
- // Set the height of the SVG as it will render on the page.
- svg.setAttribute(
- 'height', `${this.svgScaling * (bbox.height + 2 * this.viewPad)}`);
- }
-
- /**
- * Post-render callback. Performs imperative updates to layout and component
- * sizes which need to depend on the positions of each token. Also sets up
- * mouseover behavior.
- */
- override updated() {
- if (this.data == null) {
- this.tokenXBounds = [];
- return;
- }
- this.tokenXBounds = this.findTokenBounds();
-
- // For each layer, position the span groups
- for (const layer of this.data.layers) {
- const orderedEdges = this.sortEdges(layer.edges);
-
- // Container group for this layer.
- const layerGroup: SVGGElement = this.getLayerGroup(layer.name);
-
- // Compute available widths, needed for clipping of labels.
- const availableWidths =
- this.findAvailableWidths(layerGroup, orderedEdges);
-
- // Edge groups within this layer.
- const edgeGroups = layerGroup.querySelectorAll('g.edge-group');
- edgeGroups.forEach((g, i) => {
- const edge = orderedEdges[i];
-
- const g1 = g.querySelector('g.at-span1')!;
- // Set position within this row.
- g1.setAttribute(
- 'transform', `translate(${this.getStartX(edge.span1)}, 0)`);
-
- // Compute span width in SVG units, based on rendered token width.
- const span1Width =
- this.getEndX(edge.span1) - this.getStartX(edge.span1);
- // Set points for span1 bracket.
- const points1 =
- hBracketPoints(span1Width, -1 * (this.bracketHeight - 1), 1);
- g1.querySelector('polyline')!.setAttribute('points', `${points1}`);
-
- // Set the width for the label; this will show ellipsis for the label
- // text if it is longer.
- // Leave a few pixels spacing if we can afford it, but don't go
- // shorter than the token width.
- const displayWidth = Math.max(span1Width, availableWidths[i] - 5);
- g.querySelector('foreignObject')!.setAttribute(
- 'width', `${displayWidth}`);
-
- // If there's a second span, set up bracket
- // and draw arc from span1 -> span2 with the arrow on span1.
- if (edge.span2) {
- const g2 = g.querySelector('g.at-span2')!;
- // Set position within this row.
- g2.setAttribute(
- 'transform', `translate(${this.getStartX(edge.span2)}, 0)`);
- // Compute span width in SVG units, based on rendered token width.
- const span2Width =
- this.getEndX(edge.span2) - this.getStartX(edge.span2);
- const points2 =
- hBracketPoints(span2Width, -1 * (this.bracketHeight - 1), 1);
- g2.querySelector('polyline')!.setAttribute('points', `${points2}`);
-
- // Draw arc.
- const startY =
- -1 * (this.bracketHeight + this.lineHeight + 1 /* pad */);
- const x1 =
- (this.getEndX(edge.span1) + this.getStartX(edge.span1)) / 2;
- let x2 = (this.getEndX(edge.span2) + this.getStartX(edge.span2)) / 2;
- // Adjust arc end to avoid overlapping arrows.
- // See //nlp/saft/rendering/sentence-html-renderer.js
- if (x2 > x1) {
- x2 -= (this.arcArrowSize + 2);
- } else {
- x2 += (this.arcArrowSize + 2);
- }
- // Adjust arc height based on edge length (# tokens between
- // midpoints). See nlp_saft::SentenceRenderer::CalculateDimensions()
- // from //nlp/saft/rendering/sentence-html-rendering.cc
- const mid1 = (edge.span1[1] + edge.span1[0]) / 2;
- const mid2 = (edge.span2[1] + edge.span2[0]) / 2;
- const l = Math.min(30, Math.abs(mid2 - mid1));
- const arcHeight = Math.min(
- this.arcBaseHeight + Math.round((10 - (l / 6.0)) * l),
- this.arcMaxHeight);
- g.querySelector('path.arc-path')!.setAttribute(
- 'd', `${arcPath(startY, x1, x2, arcHeight, this.arcAspect)}`);
- g.querySelector('path.arc-arrow')!.setAttribute(
- 'd', `${arcArrow(startY, x1, this.arcArrowSize)}`);
- }
- });
-
- // Set mouseover behavior for this layer.
- this.setMouseovers(layerGroup, orderedEdges);
- }
-
- // Set mouseover boxes to match the _visible_ size of the label container.
- this.shadowRoot!.querySelectorAll('g.edge-group').forEach(g => {
- matchBBox(
- g.querySelector('foreignObject') as SVGGElement,
- g.querySelector('rect.mousebox') as SVGRectElement);
- });
-
- // Stack layers vertically, using bounding boxes to avoid occlusion.
- this.positionLayers();
- // Finally, after everything is positioned, set the viewport for the whole
- // SVG.
- this.setSVGViewport();
- }
-}
-
-declare global {
- interface HTMLElementTagNameMap {
- 'span-graph-vis': SpanGraphVis;
- }
-}
diff --git a/lit_nlp/client/elements/span_graph_vis_vertical.css b/lit_nlp/client/elements/span_graph_vis_vertical.css
deleted file mode 100644
index 2431fed4..00000000
--- a/lit_nlp/client/elements/span_graph_vis_vertical.css
+++ /dev/null
@@ -1,103 +0,0 @@
-.holder {
- display: flex;
- font-family: 'Share Tech Mono', monospace;
- position: relative;
- color: #555;
-}
-.layer {
- cursor: pointer;
- color: var(--group-color);
-}
-.layer-label-vert {
- top: calc(0px - var(--line-height));
- position: absolute;
- padding-left: 7px;
- display: flex;
- transform: rotate(0deg);
- transition: .25s transform;
- transform-origin: 10px 10px;
- color: var(--group-color);
-}
-.layer-label-vert.hidden {
- transform: rotate(-90deg);
-}
-.column {
- position: relative;
- height: 100%;
- transition: .25s width;
-}
-.column.hidden {
- width: 13px !important;
- opacity: 0;
-}
-.tokens {
- z-index: 1;
-}
-.line {
- height: var(--line-height);
- box-sizing: border-box;
- padding: 0 7px;
- text-align: right;
- white-space: nowrap;
-}
-.token.selected{
- color: black;
-}
-.child {
- font-weight: bold;
- filter: hue-rotate(-40deg);
- border-width: 3px;
-}
-.parent {
- font-weight: bold;
- filter: hue-rotate(40deg);
- border-width: 3px;
-}
-.selected {
- font-weight: bold;
- border-width: 3px;
-}
-
-.gray {
- border-color: #ddd !important;
- color: #ddd
-}
-
-.edge {
- position: absolute;
-}
-.edge-line {
- border: 1px solid var(--group-color);
- border-left: 0;
- width: 3px;
- left: -3px;
-}
-.arrow-head {
- border: 5px solid transparent;
- border-right: 5px solid var(--group-color);
- width: 0;
- height: 0;
- top: -5px;
- left: -5px;
- position: absolute;
-}
-.gray .arrow-head{
- border-right-color: #ddd;
-}
-.arrow-head.bottom{
- bottom: -5px;
- top: unset;
-}
-.background-lines{
- position: absolute;
- width: 100%;
- top: -3px;
-}
-.background-line {
- height: var(--line-height);
- width: 100%;
- padding: 0 7px;
-}
-.background-line:nth-child(odd){
- background: #f5f5f5;
-}
diff --git a/lit_nlp/client/elements/span_graph_vis_vertical.ts b/lit_nlp/client/elements/span_graph_vis_vertical.ts
deleted file mode 100644
index 86915ef5..00000000
--- a/lit_nlp/client/elements/span_graph_vis_vertical.ts
+++ /dev/null
@@ -1,278 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Visualization component for structured prediction over text.
- */
-
-// tslint:disable:no-new-decorators
-
-import {property} from 'lit/decorators.js';
-import {customElement} from 'lit/decorators.js';
-import { html} from 'lit';
-import {classMap} from 'lit/directives/class-map.js';
-import {styleMap} from 'lit/directives/style-map.js';
-import {observable} from 'mobx';
-
-import {getVizColor} from '../lib/colors';
-import {EdgeLabel} from '../lib/dtypes';
-import {ReactiveElement} from '../lib/elements';
-
-import {styles} from './span_graph_vis_vertical.css';
-
-
-/**
- * Represents a group of directed graphs anchored to token spans.
- * This is the general "edge probing" representation, which can be used for many
- * problems including sequence tagging, span labeling, and directed graphs like
- * semantic frames, coreference, and dependency parsing. See
- * https://arxiv.org/abs/1905.06316 for more on this formalism.
- */
-export interface SpanGraph {
- tokens: string[];
- layers: AnnotationLayer[];
-}
-
-/**
- * A single layer of annotations, like 'pos' (part-of-speech)
- * or 'ner' (named entities).
- */
-export interface AnnotationLayer {
- name: string;
- edges: EdgeLabel[];
- hideBracket?: boolean;
-}
-
-function formatEdgeLabel(label: string|number): string {
- if (typeof (label) === 'number') {
- return Number.isInteger(label) ? label.toString() :
- label.toFixed(3).toString();
- }
- return `${label}`;
-}
-
-/** Structured prediction (SpanGraph) visualization class. */
-@customElement('span-graph-vis-vertical')
-export class SpanGraphVis extends ReactiveElement {
- /* Data binding */
- @property({ type: Object }) data: SpanGraph = { tokens: [], layers: [] };
- @property({ type: Boolean }) showLayerLabel: boolean = true;
-
- @observable private selectedTokIdx?: number;
- @observable private readonly columnVisibility: { [key: string]: boolean } = {};
-
- /* Rendering parameters */
- @property({ type: Number }) lineHeight: number = 18;
- @property({ type: Number }) approxFontSize = this.lineHeight / 3;
-
- // Padding for SVG viewport, to avoid clipping some elements (like polyline).
- @property({ type: Number }) viewPad: number = 5;
-
- static override get styles() {
- return styles;
- }
-
- override render() {
- if (!this.data) {
- return ``;
- }
- const host = this.shadowRoot!.host as HTMLElement;
- host.style.setProperty('--line-height', `${this.lineHeight}pt`);
- const tokens = this.data.tokens;
-
- const tokenClasses = (i: number) => classMap({
- line: true,
- token: true,
- selected: i === this.selectedTokIdx
- });
-
- // clang-format off
- return html`
-
-
- ${tokens.map(t => html`
`)}
-
-
- ${tokens.map((t, i) => html`
-
this.selectedTokIdx = i}
- @mouseleave=${() => this.selectedTokIdx = undefined}>
- ${t}
-
- `)}
-
- ${this.data.layers.map((layer, i) => this.renderLayer(layer, i))}
-
`;
- // clang-format on
- }
-
- /**
- * Render a given annotation layer.
- */
- renderLayer(layer: AnnotationLayer, i: number) {
-
- if (!layer.edges.length) {
- return html``;
- }
-
- const layerStyles = styleMap({
- '--group-color': getVizColor('dark', i).color
- });
-
- // The column width is the width of the longest label, in pixels.
- const colWidth =
- Math.max(
- layer.name.length,
- ...layer.edges.map(e => formatEdgeLabel(e.label).length)) *
- this.approxFontSize +
- this.viewPad * 2;
-
- const colStyles = styleMap({ width: `${colWidth}pt` });
- const hidden = this.columnVisibility[layer.name];
- const columnClasses = classMap({
- 'column': true,
- 'hidden': hidden
- });
-
- const headerClasses = classMap({ 'layer-label-vert': true, hidden });
- const onClick = () =>
- this.columnVisibility[layer.name] = !this.columnVisibility[layer.name];
-
- // clang-format off
- return html`
-
- ${this.showLayerLabel ? html`
- ` : null}
-
- ${layer.edges.map(edge => this.renderEdge(edge, layer, colWidth))}
-
-
- `;
- // clang-format on
- }
-
- /**
- * Render an edge and its label. See the note on the SpanGraph interface
- * above for more details.
- */
- private renderEdge(edge: EdgeLabel, layer: AnnotationLayer, colWidth: number) {
- const isArc = 'span2' in edge;
- const span0 = edge.span1[0];
- const span1 = edge.span2 ? edge.span2[0] : edge.span1[1];
- const topSpan = Math.min(span0, span1);
- const botSpan = Math.max(span0, span1);
-
-
- const isInSpan = (i: number, span:[number, number]) => i >= span[0] && i < span[1];
-
- // Span classes (child, parent, etc, based on the currently selected token.)
- const tokSelected = this.selectedTokIdx !== undefined;
- const selected = isInSpan(this.selectedTokIdx!, edge.span1) || (isArc && isInSpan(this.selectedTokIdx!, edge.span2!));
- const child = isArc && this.selectedTokIdx === span1;
- const parent = isArc && this.isChildOfSelected(layer, span0);
- const grayLine = tokSelected && !(selected || child);
- const grayLabel = grayLine && !(parent);
-
- // Edge labels can be either strings or numbers; format the latter nicely.
- const formattedLabel = formatEdgeLabel(edge.label);
-
- // Styling for the label text.
- const labelWidthInPx = formattedLabel.length * this.approxFontSize;
- const labelStyle = styleMap({
- top: `${span0 * this.lineHeight}pt`,
- left: isArc ? `${colWidth - labelWidthInPx - this.viewPad}pt` : '',
- });
- const labelClasses = classMap({
- child, parent, selected,
- gray: grayLabel,
- line: true,
- edge: true
- });
-
- // Styling for the arc (a line and sometimes an arrowhead)
- const arcPad = .3;
- const offset = this.lineHeight / 8;
- const top = isArc ?
- (topSpan + arcPad) * this.lineHeight + (topSpan === span0 ? 0 : this.viewPad) :
- topSpan * (this.lineHeight) - offset;
- const bottom = isArc ?
- (botSpan + arcPad) * this.lineHeight + (botSpan === span0 ? 0 : -this.viewPad) :
- botSpan * (this.lineHeight) - 2 * offset;
-
- const arcHeight = bottom - top;
- const width = isArc ? `${Math.max(arcHeight / 2, this.lineHeight / 2)}pt` : '';
-
- const rad = isArc ? arcHeight / 2 : 3;
- const lineStyle = styleMap({
- top: `${top}pt`,
- height: `${arcHeight}pt`,
- width,
- 'border-radius': `0pt ${rad}pt ${rad}pt 0pt`,
- left: isArc ? `${colWidth + 10}pt` : '',
- visibility: layer.hideBracket ? 'hidden' : 'visble',
- });
-
- const arrowHeadClasses = classMap({
- 'arrow-head': true,
- 'bottom': topSpan === span1,
- });
-
- const arrowClasses = classMap({
- child,
- parent: selected,
- gray: grayLine,
- edge: true,
- 'edge-line': true
- });
-
- return html`
-
- ${isArc ? html`
` : ''}
-
-
- ${formattedLabel}
-
- `;
- }
-
- /**
- * Is this token (indicated by tokenIdx) a child of the selected token at
- * the specified layer. This assumes that the edge goes from span1 to span2,
- * as in a dependency parse tree.
- */
- isChildOfSelected(layer: AnnotationLayer, tokenIdx: number) {
- for (let j = 0; j < layer.edges.length; j++) {
- const edge = layer.edges[j];
- if (edge.span2 &&
- (this.selectedTokIdx === edge.span1[0]) &&
- (tokenIdx === edge.span2[0])) {
- return true;
- }
- }
- return false;
- }
-
-}
-
-declare global {
- interface HTMLElementTagNameMap {
- 'span-graph-vis-vertical': SpanGraphVis;
- }
-}
diff --git a/lit_nlp/client/modules/annotated_text_module.ts b/lit_nlp/client/modules/annotated_text_module.ts
index 286d8c76..7165cb12 100644
--- a/lit_nlp/client/modules/annotated_text_module.ts
+++ b/lit_nlp/client/modules/annotated_text_module.ts
@@ -6,7 +6,7 @@
* spans in running text, which is well-suited for tasks like QA or entity
* recognition which have a small number of spans over a longer passage.
*
- * Similar to span_graph_module, we provide two module classes:
+ * We provide two module classes:
* - AnnotatedTextGoldModule for gold annotations (in the input data)
* - AnnotatedTextModule for model predictions
*/
@@ -14,18 +14,17 @@
// tslint:disable:no-new-decorators
import '../elements/annotated_text_vis';
+import {html} from 'lit';
import {customElement} from 'lit/decorators.js';
-import { html} from 'lit';
import {observable} from 'mobx';
import {LitModule} from '../core/lit_module';
import {type AnnotationGroups, TextSegments} from '../elements/annotated_text_vis';
import {MultiSegmentAnnotations, TextSegment} from '../lib/lit_types';
+import {styles as sharedStyles} from '../lib/shared_styles.css';
import {type IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {doesOutputSpecContain, filterToKeys, findSpecKeys} from '../lib/utils';
-import {styles as sharedStyles} from '../lib/shared_styles.css';
-
/** LIT module for model output. */
@customElement('annotated-text-gold-module')
export class AnnotatedTextGoldModule extends LitModule {
@@ -80,7 +79,8 @@ export class AnnotatedTextGoldModule extends LitModule {
// clang-format on
}
- static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) {
+ static override shouldDisplayModule(
+ modelSpecs: ModelInfoMap, datasetSpec: Spec) {
return findSpecKeys(datasetSpec, MultiSegmentAnnotations).length > 0;
}
}
@@ -159,7 +159,8 @@ export class AnnotatedTextModule extends LitModule {
// clang-format on
}
- static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) {
+ static override shouldDisplayModule(
+ modelSpecs: ModelInfoMap, datasetSpec: Spec) {
return doesOutputSpecContain(modelSpecs, MultiSegmentAnnotations);
}
}
diff --git a/lit_nlp/client/modules/span_graph_module.ts b/lit_nlp/client/modules/span_graph_module.ts
deleted file mode 100644
index 52bc33a5..00000000
--- a/lit_nlp/client/modules/span_graph_module.ts
+++ /dev/null
@@ -1,346 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Module within LIT for showing sequence and span tagging
- * results.
- */
-
-// tslint:disable:no-new-decorators
-import '../elements/span_graph_vis';
-import '../elements/span_graph_vis_vertical';
-
-import {customElement} from 'lit/decorators.js';
-import {css, html} from 'lit';
-import {computed, observable} from 'mobx';
-
-import {LitModule} from '../core/lit_module';
-import {AnnotationLayer, SpanGraph} from '../elements/span_graph_vis_vertical';
-import {EdgeLabel, SpanLabel} from '../lib/dtypes';
-import {EdgeLabels, SequenceTags, SpanLabels, LitTypeTypesList, LitTypeWithAlign, TextSegment, Tokens} from '../lib/lit_types';
-import {IndexedInput, Input, ModelInfoMap, Preds, Spec} from '../lib/types';
-import {findSpecKeys} from '../lib/utils';
-
-import {styles as sharedStyles} from '../lib/shared_styles.css';
-
-interface FieldNameMultimap {
- [fieldName: string]: string[];
-}
-
-interface Annotations {
- [tokenKey: string]: SpanGraph;
-}
-
-// Shared by gold and preds modules.
-const moduleStyles = css`
- .outer-container {
- display: flex;
- flex-direction: column;
- justify-content: center;
- position: relative;
- overflow: hidden;
- }
-
- .token-group {
- padding-top: 40px;
- }
-
- .field-title {
- padding: 4px;
- }
-`;
-
-const supportedPredTypes: LitTypeTypesList =
- [SequenceTags, SpanLabels, EdgeLabels];
-
-/**
- * Convert sequence tags to a list of length-1 span labels.
- */
-function tagsToEdges(tags: string[]): EdgeLabel[] {
- return tags.map((label: string, i: number) => {
- return {span1: [i, i + 1], label} as EdgeLabel;
- });
-}
-
-/**
- * Convert span labels to single-sided edge labels.
- */
-function spansToEdges(spans: SpanLabel[]): EdgeLabel[] {
- return spans.map(
- d => ({span1: [d.start, d.end], label: d.label as string} as EdgeLabel));
-}
-
-function mapTokenToTags(spec: Spec): FieldNameMultimap {
- const tagKeys = findSpecKeys(spec, supportedPredTypes);
- const tokenKeys = findSpecKeys(spec, Tokens);
-
- // Make a mapping of token keys to one or more tag sets
- const tokenToTags: FieldNameMultimap = {};
- for (const tagKey of tagKeys) {
- const {align: tokenKey} = spec[tagKey] as LitTypeWithAlign;
- if (tokenKey == null || !tokenKeys.includes(tokenKey)) {
- continue;
- } else if (tokenToTags[tokenKey] == null) {
- tokenToTags[tokenKey] = [];
- }
- tokenToTags[tokenKey].push(tagKey);
- }
- return tokenToTags;
-}
-
-function parseInput(data: Input|Preds, spec: Spec): Annotations {
- const tokenToTags = mapTokenToTags(spec);
-
- // Render a row for each set of tokens
- const ret: Annotations = {};
- for (const tokenKey of Object.keys(tokenToTags)) {
- const annotationLayers: AnnotationLayer[] = [];
- for (const tagKey of tokenToTags[tokenKey]) {
- let edges = data[tagKey];
- let hideBracket = false;
- // Temporary workaround: if we manually create a new datapoint, the span
- // or tag field may be "" rather than [].
- // TODO(lit-team): remove this once the datapoint editor is type-safe
- // for structured fields.
- if (edges.length === 0) {
- edges = [];
- }
- if (spec[tagKey] instanceof SequenceTags) {
- edges = tagsToEdges(edges);
- hideBracket = true;
- } else if (spec[tagKey] instanceof SpanLabels) {
- edges = spansToEdges(edges);
- }
- annotationLayers.push({name: tagKey, edges, hideBracket});
- }
- // Try to infer tokens from text, if that field is empty.
- let tokens = data[tokenKey];
- if (tokens.length === 0) {
- const textKey = findSpecKeys(spec, TextSegment)[0];
- tokens = data[textKey].split();
- }
- ret[tokenKey] = {tokens, layers: annotationLayers};
- }
- return ret;
-}
-
-function renderTokenGroups(
- data: Annotations, spec: Spec, orientation: 'horizontal'|'vertical') {
- const tokenToTags = mapTokenToTags(spec);
- const visElement = (data: SpanGraph, showLayerLabel: boolean) => {
- if (orientation === 'vertical') {
- return html``;
- } else {
- return html``;
- }
- };
- // clang-format off
- return html`${Object.keys(tokenToTags).map(tokenKey => {
- const labelHere = data[tokenKey]?.layers?.length === 1;
- return html`
-
- ${labelHere ?
- html`
${data[tokenKey].layers[0].name}
`
- : null}
- ${visElement(data[tokenKey], !labelHere)}
-
- `;
- })}`;
- // clang-format on
-}
-
-/** Gold predictions module class. */
-@customElement('span-graph-gold-module')
-export class SpanGraphGoldModule extends LitModule {
- static override title = 'Structured Prediction (gold)';
- static override duplicateForExampleComparison = true;
- static override duplicateForModelComparison = false;
- static override duplicateAsRow = false;
- static override numCols = 4;
- static override template =
- (model: string, selectionServiceIndex: number, shouldReact: number) => html`
-
- `;
- static orientation = 'horizontal';
-
- @computed
- get dataSpec() {
- return this.appState.currentDatasetSpec;
- }
-
- @computed
- get goldDisplayData(): Annotations {
- const input = this.selectionService.primarySelectedInputData;
- if (input === null) {
- return {};
- } else {
- return parseInput(input.data, this.dataSpec);
- }
- }
-
- static override get styles() {
- return [sharedStyles, moduleStyles];
- }
-
- // tslint:disable:no-any
- override renderImpl() {
- // If more than one model is selected, SpanGraphModule will be offset
- // vertically due to the model name header, while this one won't be.
- // So, add an offset so that the content still aligns when there is a
- // SpanGraphGoldModule and a SpanGraphModule side-by-side.
- const offsetForHeader = !this.appState.compareExamplesEnabled &&
- this.appState.currentModels.length > 1;
- // clang-format off
- return html`
- ${offsetForHeader? html`` : null}
-
- ${
- renderTokenGroups(
- this.goldDisplayData, this.dataSpec,
- (this.constructor as any).orientation)}
-
- `;
- // clang-format on
- }
- // tslint:enable:no-any
-
- static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) {
- const hasTokens = findSpecKeys(datasetSpec, Tokens).length > 0;
- const hasSupportedPreds =
- findSpecKeys(datasetSpec, supportedPredTypes).length > 0;
- return (hasTokens && hasSupportedPreds);
- }
-}
-
-/** Model output module class. */
-@customElement('span-graph-module')
-export class SpanGraphModule extends LitModule {
- static override title = 'Structured Prediction (model preds)';
- static override duplicateForExampleComparison = true;
- static override duplicateAsRow = false;
- static override numCols = 4;
- static override template =
- (model: string, selectionServiceIndex: number, shouldReact: number) => html`
-
- `;
- static orientation = 'horizontal';
-
- @computed
- get predSpec() {
- return this.appState.getModelSpec(this.model).output;
- }
-
- // This is updated with an API call, via a reaction.
- @observable predDisplayData: Annotations = {};
-
- private async updatePredDisplayData(input: IndexedInput|null) {
- if (input === null) {
- this.predDisplayData = {};
- } else {
- const promise = this.apiService.getPreds(
- [input], this.model, this.appState.currentDataset,
- [Tokens, ...supportedPredTypes]);
-
- const results = await this.loadLatest('getPreds', promise);
- if (!results) return;
-
- this.predDisplayData = parseInput(results[0], this.predSpec);
- }
- }
-
- static override get styles() {
- return [sharedStyles, moduleStyles];
- }
-
- override firstUpdated() {
- this.reactImmediately(
- () => this.selectionService.primarySelectedInputData, input => {
- this.updatePredDisplayData(input);
- });
- }
-
- // tslint:disable:no-any
- override renderImpl() {
- return html`
-
- ${
- renderTokenGroups(
- this.predDisplayData, this.predSpec,
- (this.constructor as any).orientation)}
-
- `;
- }
- // tslint:enable:no-any
-
- static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) {
- const models = Object.keys(modelSpecs);
- for (let modelNum = 0; modelNum < models.length; modelNum++) {
- const spec = modelSpecs[models[modelNum]].spec;
- const hasTokens = findSpecKeys(spec.output, Tokens).length > 0;
- const hasSupportedPreds =
- findSpecKeys(spec.output, supportedPredTypes).length > 0;
- if (hasTokens && hasSupportedPreds) {
- return true;
- }
- }
- return false;
- }
-}
-
-// tslint:disable:class-as-namespace
-
-/** Gold predictions module class. */
-@customElement('span-graph-gold-module-vertical')
-export class SpanGraphGoldModuleVertical extends SpanGraphGoldModule {
- static override duplicateAsRow = true;
- static override orientation = 'vertical';
- static override numCols = 4;
- static override template =
- (model: string, selectionServiceIndex: number, shouldReact: number) => html`
-
- `;
-}
-
-/** Model output module class. */
-@customElement('span-graph-module-vertical')
-export class SpanGraphModuleVertical extends SpanGraphModule {
- static override duplicateAsRow = true;
- static override orientation = 'vertical';
- static override template =
- (model: string, selectionServiceIndex: number, shouldReact: number) => html`
-
- `;
-}
-
-// tslint:enable:class-as-namespace
-
-declare global {
- interface HTMLElementTagNameMap {
- 'span-graph-gold-module': SpanGraphGoldModule;
- 'span-graph-module': SpanGraphModule;
- // TODO(b/172979677): make these parameterized versions, rather than
- // separate classes.
- 'span-graph-gold-module-vertical': SpanGraphGoldModuleVertical;
- 'span-graph-module-vertical': SpanGraphModuleVertical;
- }
-}