From 80cf699f92cd77d58cb2a2a60b9314010b1f336c Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Tue, 13 Feb 2024 15:54:53 -0800 Subject: [PATCH] Sequence salience for a decoder-only LM. Demo using GPT-2. PiperOrigin-RevId: 606773805 --- lit_nlp/api/layout.py | 3 +- lit_nlp/client/modules/lm_salience_module.css | 98 +++ lit_nlp/client/modules/lm_salience_module.ts | 833 ++++++++++++++++++ lit_nlp/examples/datasets/lm.py | 51 +- .../examples/datasets/prompt_examples.jsonl | 9 + lit_nlp/examples/lm_salience_demo.py | 177 ++++ lit_nlp/examples/models/pretrained_lms.py | 334 ++++++- .../models/pretrained_lms_int_test.py | 6 +- 8 files changed, 1465 insertions(+), 46 deletions(-) create mode 100644 lit_nlp/client/modules/lm_salience_module.css create mode 100644 lit_nlp/client/modules/lm_salience_module.ts create mode 100644 lit_nlp/examples/datasets/prompt_examples.jsonl create mode 100644 lit_nlp/examples/lm_salience_demo.py diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py index 637b3f19..ce0ab4c1 100644 --- a/lit_nlp/api/layout.py +++ b/lit_nlp/api/layout.py @@ -21,7 +21,6 @@ from lit_nlp.api import dtypes -# LINT.IfChange # pylint: disable=invalid-name @enum.unique class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum): @@ -48,6 +47,7 @@ class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum): GeneratedTextModule = 'generated-text-module' GeneratorModule = 'generator-module' LanguageModelPredictionModule = 'lm-prediction-module' + LMSalienceModule = 'lm-salience-module' MetricsModule = 'metrics-module' MultilabelModule = 'multilabel-module' PdpModule = 'pdp-module' @@ -68,6 +68,7 @@ def __call__(self, **kw): return ModuleConfig(self.value, **kw) +# LINT.IfChange # TODO(lit-dev): consider making modules subclass this instead of LitModuleName. @attr.s(auto_attribs=True) class ModuleConfig(dtypes.DataTuple): diff --git a/lit_nlp/client/modules/lm_salience_module.css b/lit_nlp/client/modules/lm_salience_module.css new file mode 100644 index 00000000..22192870 --- /dev/null +++ b/lit_nlp/client/modules/lm_salience_module.css @@ -0,0 +1,98 @@ +.flex-column { + display: flex; + flex-direction: column; +} + +.chip-container { + padding: 8px; +} + +.chip-container-dense { + padding: 8px; +} + +.pre-wrap { + white-space: pre-wrap; +} + +.gray-text { + color: var(--lit-neutral-400); +} + +.target-info-line { + white-space: nowrap; + text-overflow: ellipsis; + overflow-x: hidden; +} + +lit-switch .icon-button { + vertical-align: middle; +} + +/** + * Module controls + */ +.module-toolbar { + border-bottom: 1px solid #dadce0; + box-sizing: border-box; + justify-content: space-between; +} + +.module-footer { + justify-content: space-between; +} + +.controls-group { + display: flex; + flex-direction: row; + align-items: center; + margin: 0 4px; + gap: 4px; +} + +.controls-group[disabled] { + color: rgb(60, 64, 67); + opacity: 0.38; + pointer-events: none; +} + +/* Allow contents to consume available space, but not to cause line wrapping. */ +.controls-group-variable { + flex: 1; + overflow-x: clip; + margin-right: 8px; +} + +.controls-group-variable > label { + min-width: 45px; +} + +.controls-group-variable .dropdown { + max-width: calc(100% - 45px); +} + +.vertical-separator { + background: #dadce0; + width: 2px; + height: 1.2rem; + padding: 0; + margin: 0 8px; +} + +/* Allow wrap. TODO move this to shared_styles as an option? 
*/ +.module-footer-wrappable { + flex-wrap: wrap; + /* line-height: 30px; */ /* causes alignment issues */ + height: unset; + min-height: 36px; +} + +.module-footer > * { min-width: 0; } + +.controls-group > * { min-width: 0; } + +color-legend { + /* extra space to keep other controls from jumping when legend changes */ + /* width: 400px; */ + margin-right: 16px; +} \ No newline at end of file diff --git a/lit_nlp/client/modules/lm_salience_module.ts b/lit_nlp/client/modules/lm_salience_module.ts new file mode 100644 index 00000000..b2deb9c7 --- /dev/null +++ b/lit_nlp/client/modules/lm_salience_module.ts @@ -0,0 +1,833 @@ +/** + * @fileoverview Custom viz module for causal LM salience. + */ + +import '@material/mwc-icon'; +import '../elements/color_legend'; +import '../elements/numeric_input'; +import '../elements/fused_button_bar'; + +import {css, html} from 'lit'; +// tslint:disable:no-new-decorators +import {customElement} from 'lit/decorators.js'; +import {computed, observable, toJS} from 'mobx'; + +import {LitModule} from '../core/lit_module'; +import {LegendType} from '../elements/color_legend'; +import {NumericInput as LitNumericInput} from '../elements/numeric_input'; +import {TokenChips, TokenWithWeight} from '../elements/token_chips'; +import {SalienceCmap, SignedSalienceCmap, UnsignedSalienceCmap,} from '../lib/colors'; +import {GENERATION_TYPES, getAllTargetOptions, TargetOption, TargetSource} from '../lib/generated_text_utils'; +import {LitType, LitTypeTypesList, Tokens, TokenScores} from '../lib/lit_types'; +import {styles as sharedStyles} from '../lib/shared_styles.css'; +import {cleanSpmText, groupTokensByRegexPrefix} from '../lib/token_utils'; +import {type IndexedInput, type Preds, SCROLL_SYNC_CSS_CLASS, type Spec} from '../lib/types'; +import {cumSumArray, filterToKeys, findSpecKeys, groupAlike, makeModifiedInput, sumArray} from '../lib/utils'; + +import {styles} from './lm_salience_module.css'; + +/** + * Max of absolute value + */ +export function maxAbs(vals: number[]): number { + return Math.max(...vals.map(Math.abs)); +} + +enum SegmentationMode { + TOKENS = 'Tokens', + WORDS = 'Words', + SENTENCES = 'Sentences', + LINES = 'Lines', + // TODO(b/324961811): add phrase or clause chunking? + // TODO(b/324961803): add custom regex? +} + +const LEGEND_INFO_TITLE_SIGNED = + 'Salience is relative to the model\'s prediction of a token. A positive ' + + 'score (more green) for a token means that token influenced the model to ' + + 'predict the selected target, whereas a negaitve score (more pink) means ' + + 'the token influenced the model to not predict the selected target.'; + +const LEGEND_INFO_TITLE_UNSIGNED = + 'Salience is relative to the model\'s prediction of a token. A larger ' + + 'score (more purple) for a token means that token was more influential ' + + 'on the model\'s prediction of the selected target.'; + +/** + * A convenience implementation of LitModule for single model, single example + * use. Implements some standard boilerplate to fetch model predictions. + * + * Subclass should still register this with @customElement, and add to the + * HTMLElementTagNameMap, we well as implement: + * - static template = ... + * - override renderImpl() {...} + * + * And optionally: + * - static styles() {...} + * - static override shouldDisplayModule() {...} + * + * If subclass implements firstUpdated(), be sure to call super.firstUpdated() + * to register the reaction to the primary selection. 
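+ * (LMSalienceModule below follows this pattern: it overrides predsTypes,
+ * resetState(), and updateToSelection(), and calls super.firstUpdated().)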
+ */ +export class SingleExampleSingleModelModule extends LitModule { + static override duplicateForExampleComparison = true; + static override duplicateForModelComparison = true; + + // Override this to request only specific types. + protected predsTypes: LitTypeTypesList = [LitType]; + + @observable protected currentData?: IndexedInput; + @observable protected currentPreds?: Preds; + + // Override this for any post-processing. + protected postprocessPreds(input: IndexedInput, preds: Preds): Preds { + return preds; + } + + protected resetState() { + this.currentData = undefined; + this.currentPreds = undefined; + } + + protected async updateToSelection(input: IndexedInput|null) { + this.resetState(); + + if (input == null) return; + + // Before waiting for the backend call, update data. + // currentPreds should already be cleared by the resetState() call above. + this.currentData = input; + + const promise = this.apiService.getPreds( + [input], + this.model, + this.appState.currentDataset, + this.predsTypes, + [], + 'Getting model predictions.', + ); + const results = await this.loadLatest('modelPreds', promise); + if (results === null) return; + + const preds = this.postprocessPreds(input, results[0]); + + // Update data again, in case selection changed rapidly. + this.currentData = input; + this.currentPreds = preds; + } + + override firstUpdated() { + this.reactImmediately( + () => this.selectionService.primarySelectedInputData, + (data) => { + this.updateToSelection(data); + }, + ); + } +} + +/** + * Custom styled version of for rendering LM salience tokens. + */ +@customElement('lm-salience-chips') +class LMSalienceChips extends TokenChips { + static override get styles() { + return [ + ...TokenChips.styles, + css` + .salient-token { + padding: 1px 3px; /* wider horizontally */ + margin: 2px; + min-width: 4px; /* easier to see whitespace tokens */ + } + .tokens-holder:not(.tokens-holder-dense) .salient-token:not(.selected) { + --token-outline-color: var(--lit-neutral-300); /* outline in non-dense mode */ + } + .tokens-holder-display-block .salient-token { + padding: 3px 0; + margin: 0; + margin-right: 4px; + } + .salient-token.selected { + --token-outline-color: var(--lit-mage-700); + box-shadow: 0px 0px 3px var(--token-outline-color); + } + .tokens-holder-dense .salient-token { + margin: 2px 0px; /* vertical spacing only */ + min-width: 6px; /* not too small. Check if this causes issues inside words. */ + } + .tokens-holder-dense .salient-token.selected { + outline: 2px solid var(--token-outline-color); + border: 0; + box-shadow: unset; + /* TODO see if we can get away from z-index here */ + z-index: 10; + } + `, + ]; + } +} + +interface SalienceResults { + [method: string]: number[]; +} + +// Sentinel value because mobx doesn't react directly to a promise completing. +const REQUEST_PENDING: unique symbol = Symbol('REQUEST_PENDING'); + +/** LIT module for model output. */ +@customElement('lm-salience-module') +export class LMSalienceModule extends SingleExampleSingleModelModule { + static override title = 'LM Salience'; + static override numCols = 6; // 60% of screen width if DataTable on left + static override duplicateAsRow = true; + // prettier-ignore + static override template = ( + model: string, + selectionServiceIndex: number, + shouldReact: number, + ) => html` + `; + + static override get styles() { + return [sharedStyles, styles]; + } + + // For generation model. For salience, see updateSalience() below. 
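+  // Salience and tokenization use separate backend endpoints derived from the
+  // same checkpoint (`_${model}_salience` and `_${model}_tokenizer`); see
+  // salienceModelName and tokenizerModelName below.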
+ override predsTypes = GENERATION_TYPES; + + @observable + private segmentationMode: SegmentationMode = SegmentationMode.WORDS; + // TODO(b/324959547): get default from spec + @observable private selectedSalienceMethod? = 'grad_l2'; + @observable private cmapGamma = 1.0; + @observable private denseView = true; + @observable private showSelfSalience = false; + + @observable.ref private currentTokens: string[] = []; + @observable.ref private salienceTargetOptions: TargetOption[] = []; + @observable private salienceTargetString = ''; + @observable.ref private targetSegmentSpan?: [number, number] = undefined; + + + /** + * Cache for salience results for different target spans. + * Because computing salience can be slow and we don't want to lock the + * frontend, we use this cache as an intermediary between the API calls + * (updateSalience) and the rendering logic. API calls are asynchronous with + * updates and populate this cache with their results; the rendering logic + * then observes this cache and renders only the result with the current + * selected span. + * + * Each cache entry can have three states: + * - undefined: we haven't seen it yet, so updateSalience will issue a backend + * call. + * - REQUEST_PENDING: sentinel value, set while a backend call is in progress. + * - Otherwise, will contain a SalienceResults object with results for that + * key. + */ + @observable + private salienceResultCache: + {[targetKey: string]: SalienceResults|(typeof REQUEST_PENDING)} = {}; + + @computed + get salienceModelName(): string { + return `_${this.model}_salience`; + } + + @computed + get tokenizerModelName(): string { + // TODO: fall back to salience model if not available? + return `_${this.model}_tokenizer`; + } + + private resetTargetSpan() { + this.targetSegmentSpan = undefined; + } + + override resetState() { + // Generation & target string selection + super.resetState(); // currentData and currentPreds + this.salienceTargetOptions = []; + this.salienceTargetString = ''; + // Tokens and selected target span + this.currentTokens = []; + this.resetTargetSpan(); + // Salience results + this.salienceResultCache = {}; + } + + // Get generations; populate this.currentPreds + protected override async updateToSelection(input: IndexedInput|null) { + await super.updateToSelection(input); + this.resetTargetSpan(); + + const dataSpec = this.appState.currentDatasetSpec; + const outputSpec = this.appState.getModelSpec(this.model).output; + this.salienceTargetOptions = getAllTargetOptions( + dataSpec, + outputSpec, + this.currentData, + this.currentPreds, + ); + this.salienceTargetString = this.salienceTargetOptions[0]?.text ?? ''; + } + + // Modified input with selected target sequence. Use this for tokens and + // salience. 
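+  // For example, if the dataset example has {prompt, target}, this overrides
+  // the 'target' field with the user-selected target string (a reference
+  // target or a model response) via makeModifiedInput().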
+ @computed + get modifiedData(): IndexedInput|null { + if (this.currentData == null) return null; + return makeModifiedInput( + this.currentData, {'target': this.salienceTargetString}); + } + + @computed + get currentTokenGroups(): string[][] { + if (this.segmentationMode === SegmentationMode.TOKENS) { + return this.currentTokens.map(t => [t]); + } else if (this.segmentationMode === SegmentationMode.WORDS) { + // Word start is either: + // - whitespace or magic underscore + // - any non-\n following \n + // The latter is needed to avoid forming weird segments like '\n\nfoo'; + // by using the lookbehind, this will end up as ['\n\n', 'foo'] + return groupTokensByRegexPrefix( + this.currentTokens, /([▁\s]+)|(?<=\n)[^\n]/g); + } else if (this.segmentationMode === SegmentationMode.SENTENCES) { + // Sentence start is one of: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + // - whitespace or magic underscore following punctuation [.?!] + return groupTokensByRegexPrefix( + this.currentTokens, /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([▁\s]+))/g); + } else if (this.segmentationMode === SegmentationMode.LINES) { + // Line start is either: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + return groupTokensByRegexPrefix(this.currentTokens, /(\n+)|([^\n]+)/g); + } else { + throw new Error( + `Unsupported segmentation mode ${this.segmentationMode}.`); + } + } + + /** + * Segment offsets, as token indices. + * Segment i corresponds to tokens offsets[i]:offsets[i+1] + */ + @computed + get currentSegmentOffsets(): number[] { + return [0, ...cumSumArray(this.currentTokenGroups.map(g => g.length))]; + } + + @computed + get targetTokenSpan(): number[]|undefined { + if (this.targetSegmentSpan === undefined) return undefined; + const [segmentStart, segmentEnd] = this.targetSegmentSpan; + const offsets = this.currentSegmentOffsets; + return [offsets[segmentStart], offsets[segmentEnd]]; + } + + @computed + get currentSegmentTexts(): string[] { + const segments = this.currentTokenGroups.map(tokens => tokens.join('')); + // Tokens in non-dense view should show exact tokenization, including magic + // underscores. + if (this.segmentationMode === SegmentationMode.TOKENS && !this.denseView) { + return segments; + } + // Otherwise, clean up underscores. + return segments.map(cleanSpmText); + } + + @computed + get salienceSpecInfo(): Spec { + const outputSpec = + this.appState.getModelSpec(this.salienceModelName).output; + const salienceKeys = findSpecKeys(outputSpec, TokenScores); + return filterToKeys(outputSpec, salienceKeys); + } + + /** + * Salience for active model, for all tokens. + */ + @computed + get activeTokenSalience(): number[]|undefined { + if (this.targetTokenSpan === undefined) return undefined; + + const cachedResult = + this.salienceResultCache[this.spanToKey(this.targetTokenSpan)]; + if (cachedResult === undefined || cachedResult === REQUEST_PENDING) { + return undefined; + } + + if (this.selectedSalienceMethod === undefined) { + return undefined; + } + + return cachedResult[this.selectedSalienceMethod]; + } + + /** + * Salience for active mode, for current segments. 
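+   * Token-level scores are summed within each segment (groupAlike + sumArray),
+   * so a segment's score is the total contribution of its tokens.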
+ */ + @computed + get activeSalience(): number[]|undefined { + if (this.activeTokenSalience === undefined) return undefined; + const groupedSalience = + groupAlike(this.activeTokenSalience, this.currentTokenGroups); + return groupedSalience.map(sumArray); + } + + @computed + get cmapRange(): number { + if (this.activeSalience === undefined) return 1; + // If nothing focused, use the max over all (absolute) scores. + return Math.max(1e-3, maxAbs(this.activeSalience)); + } + + @computed + get signedSalienceCmap() { + return new SignedSalienceCmap(this.cmapGamma, [ + -1 * this.cmapRange, + this.cmapRange, + ]); + } + + @computed + get unsignedSalienceCmap() { + return new UnsignedSalienceCmap(this.cmapGamma, [0, this.cmapRange]); + } + + @computed + get cmap(): SalienceCmap { + // TODO(b/324959547): get signed/unsigned info from spec. + // May need to add a signed= bit to the TokenScores type, + // or use the TokenSalience type. + return this.selectedSalienceMethod === 'grad_dot_input' ? + this.signedSalienceCmap : + this.unsignedSalienceCmap; + } + + spanToKey(span: number[]) { + return `${span[0]}:${span[1]}`; + } + + async updateTokens(input: IndexedInput|null) { + if (input == null) { + this.currentTokens = []; + return; + } + + const promise = this.apiService.getPreds( + [input], + this.tokenizerModelName, + this.appState.currentDataset, + [Tokens], + [], + `Fetching tokens`, + ); + const results = await promise; + if (results === null) { + console.warn('No tokens returned for request', input); + return; + } + + // TODO(b/324959547): get field name from spec, rather than hardcoding + // 'tokens'. + const tokens: string[] = results[0]['tokens']; + if (this.modifiedData === input) { + this.currentTokens = tokens; + } else { + console.warn( + 'Stale request; discarding result. Request does not match current target.', + input, toJS(this.modifiedData)); + } + } + + async updateSalience(targetTokenSpan: number[]|undefined) { + if (this.modifiedData == null) return; + if (targetTokenSpan === undefined) return; + + const spanKey = this.spanToKey(targetTokenSpan); + const cachedResult = this.salienceResultCache[spanKey]; + if (cachedResult !== undefined) { + if (cachedResult === REQUEST_PENDING) { + // Another call is waiting and we can let that update the results. + console.log('Duplicate request for target span ', spanKey); + } else { + // Actual results. + console.log('Found cached return for target span ', spanKey); + } + // No need to proceed with backend call in either case. + return; + } + + this.salienceResultCache[spanKey] = REQUEST_PENDING; + + const [start, end] = targetTokenSpan; + const targetMask = this.currentTokens.map( + (t: string, i) => (i >= start && i < end) ? 1 : 0); + + // TODO(b/324959547): don't hard-code 'target_mask', get field name from + // spec. We may want to create a custom TargetMask type for this. 
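+    // Example: with currentTokens of length 5 and targetTokenSpan [2, 4],
+    // targetMask is [0, 0, 1, 1, 0]; only positions where the mask is 1 are
+    // treated as prediction targets by the salience backend.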
+ const maskedData = makeModifiedInput( + this.modifiedData, {'target_mask': targetMask}, 'salience'); + + const promise = this.apiService.getPreds( + [maskedData], + this.salienceModelName, + this.appState.currentDataset, + [TokenScores], + [], + `Getting salience scores for ${this.printTargetForHuman(start, end)}`, + ); + const results = await promise; + if (results === null) { + console.warn('Empty results from request', maskedData, spanKey); + delete this.salienceResultCache[spanKey]; + return; + } + + this.salienceResultCache[spanKey] = results[0]; + } + + override firstUpdated() { + super.firstUpdated(); + + // If selected example OR selected target string change. + // NOTE: you may see a console warning: "Element lm-salience-module + // scheduled an update (generally because a property was set) after an + // update completed, causing a new update to be scheduled." + // This is okay here: this.modifiedData will be updated after + // updateToSelection() runs, which will trigger this to update tokens. + this.reactImmediately(() => this.modifiedData, (data) => { + this.resetTargetSpan(); + this.updateTokens(data); + }); + + this.reactImmediately(() => this.targetTokenSpan, (targetTokenSpan) => { + this.updateSalience(targetTokenSpan); + }); + } + + renderGranularitySelector() { + const onClickToggleDensity = () => { + this.denseView = !this.denseView; + }; + + const segmentationOptions = Object.values(SegmentationMode).map((val) => { + return { + text: val, + selected: this.segmentationMode === val, + onClick: () => { + if (this.segmentationMode !== val) { + this.targetSegmentSpan = undefined; + } + this.segmentationMode = val as SegmentationMode; + }, + }; + }); + + // prettier-ignore + return html` +
+ + + + + + notes + + + grid_view + + +
+ `; + } + + renderSelfScoreSelector() { + const onClickToggleSelfSalience = () => { + this.showSelfSalience = !this.showSelfSalience; + }; + // prettier-ignore + return html` + + + `; + } + + renderMethodSelector() { + const methodOptions = Object.keys(this.salienceSpecInfo).map((key) => { + return { + text: key, + selected: this.selectedSalienceMethod === key, + onClick: () => { + if (this.selectedSalienceMethod !== key) { + this.selectedSalienceMethod = key; + } + }, + }; + }); + + // prettier-ignore + return html` +
+ + + + ${this.renderSelfScoreSelector()} +
+ `; + } + + targetSpanText(start: number, end: number): string { + const tokens = this.currentTokens.slice(start, end); + // Render text in a way that resembles the way the token chips read + // at the current display density. Text should match currentSegmentTexts, + // except: + // - Tokens are joined with spaces in non-dense Tokens mode + // - Whitespace is trimmed in all other modes + if (this.segmentationMode === SegmentationMode.TOKENS && !this.denseView) { + return tokens.join(' '); + } + return cleanSpmText(tokens.join('')).trim(); + } + + printTargetForHuman(start: number, end: number): string { + if (end === start + 1) { + return `[${start}] "${this.targetSpanText(start, end)}"`; + } else { + return `[${start}:${end}] "${this.targetSpanText(start, end)}"`; + } + } + + renderSalienceTargetStringSelector() { + const onChangeTarget = (e: Event) => { + this.salienceTargetString = (e.target as HTMLInputElement).value; + }; + + const options = this.salienceTargetOptions.map(target => { + // TODO(b/324959547): get field names 'target' and 'response' from spec + // via generated_text_utils.ts, rather than hard-coding. + // This information is available on the frontend, but we need to thread + // it through a few layers of code in generated_text_utils.ts + const sourceName = + target.source === TargetSource.REFERENCE ? 'target' : 'response'; + return html``; + }); + + // prettier-ignore + return html` +
+ + +
`; + } + + renderTargetIndicator() { + const printSelectedTargets = () => { + if (this.targetTokenSpan === undefined) { + const segmentType = this.segmentationMode === SegmentationMode.TOKENS ? + 'token(s)' : + 'segment(s)'; + // prettier-ignore + return html` + Click ${segmentType} above to select a target span. + `; + } + const [start, end] = this.targetTokenSpan; + return `Explaining ${this.printTargetForHuman(start, end)}`; + }; + + // prettier-ignore + return html` +
+
+ ${printSelectedTargets()} +
+
+ `; + } + + /** + * Set selection (this.targetSegmentSpan) based on current selection and the + * index of the clicked segment (i). + */ + private setSegmentTarget(i: number, shiftSelect = false) { + if (this.targetSegmentSpan === undefined) { + // If nothing selected, select token i + this.targetSegmentSpan = [i, i + 1]; + return; + } + const [start, end] = this.targetSegmentSpan; + if (shiftSelect) { + // Shift: expand target span to this token. + if (i < start) { + this.targetSegmentSpan = [i, end]; + } else if (i >= end) { + this.targetSegmentSpan = [start, i + 1]; + } + // Otherwise, i is within selection so do nothing. + } else { + // Default: only extend by one, otherwise reset. + if (i === start - 1) { + // Extend by one token earlier. + this.targetSegmentSpan = [i, end]; + } else if (i === end) { + // Extend by one token later. + this.targetSegmentSpan = [start, i + 1]; + } else if (i === start) { + // Deselect start token. + this.targetSegmentSpan = start + 1 < end ? [start + 1, end] : undefined; + } else if (i === end - 1) { + // Deselect end token. + this.targetSegmentSpan = start < end - 1 ? [start, end - 1] : undefined; + } else { + // // Interior or discontiguous: select only token i. + this.targetSegmentSpan = [i, i + 1]; + } + } + } + + private inTargetSpan(i: number) { + if (this.targetSegmentSpan === undefined) return false; + return i >= this.targetSegmentSpan[0] && i < this.targetSegmentSpan[1]; + } + + renderContent() { + if (this.currentSegmentTexts.length === 0) return null; + + const segments: string[] = this.currentSegmentTexts; + const segmentsWithWeights: TokenWithWeight[] = []; + for (let i = 0; i < segments.length; i++) { + const selected = this.inTargetSpan(i); + let weight = this.activeSalience?.[i] ?? 0; + if (selected && !this.showSelfSalience) { + weight = 0; + } + segmentsWithWeights.push({ + token: segments[i], + weight, + selected, + onClick: (e: MouseEvent) => { + this.setSegmentTarget(i, e.shiftKey); + if (e.shiftKey) { + // Holding shift will also select the token text, which can be + // distracting. Use this to clear it. + document.getSelection()?.removeAllRanges(); + } + e.stopPropagation(); + } + }); + } + + // TODO: revert to 4px for non-dense view if we can figure out the + // display mode for token chips? Needs more padding for block mode, + // but also indentation and newlines are wonky. + // prettier-ignore + return html` +
+ + +
+ `; + } + + renderColorLegend() { + const cmap = this.cmap; + const isSigned = cmap instanceof SignedSalienceCmap; + const labelName = 'Salience'; + + const tooltipText = + isSigned ? LEGEND_INFO_TITLE_SIGNED : LEGEND_INFO_TITLE_UNSIGNED; + + // prettier-ignore + return html` + + `; + } + + renderColorControls() { + const onChangeGamma = (e: Event) => { + // Note: HTMLInputElement.valueAsNumber does not work properly for + // + this.cmapGamma = Number((e.target as LitNumericInput).value); + }; + + const resetGamma = () => { + this.cmapGamma = 1.0; + }; + + // prettier-ignore + return html` +
+ ${this.renderColorLegend()} + + + + + restart_alt + +
`; + } + + override renderImpl() { + const clearTargets = () => { + this.resetTargetSpan(); + }; + + // prettier-ignore + return html` +
+
+ ${this.renderSalienceTargetStringSelector()} +
+
+ ${this.renderGranularitySelector()} + ${this.renderMethodSelector()} +
+
+ ${this.renderContent()} +
+ +
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'lm-salience-chips': LMSalienceChips; + 'lm-salience-module': LMSalienceModule; + } +} \ No newline at end of file diff --git a/lit_nlp/examples/datasets/lm.py b/lit_nlp/examples/datasets/lm.py index bd987b84..d2292f44 100644 --- a/lit_nlp/examples/datasets/lm.py +++ b/lit_nlp/examples/datasets/lm.py @@ -1,12 +1,18 @@ """Language modeling datasets.""" +import copy +import json +import os import glob from typing import Optional +from absl import logging from lit_nlp.api import dataset as lit_dataset from lit_nlp.api import types as lit_types import tensorflow_datasets as tfds +SAMPLE_DATA_DIR = os.path.dirname(__file__) + class PlaintextSents(lit_dataset.Dataset): """Load sentences from a flat text file.""" @@ -16,7 +22,9 @@ def __init__( path_or_glob: str, skiplines: int = 0, max_examples: Optional[int] = None, + field_name: str = 'text', ): + self.field_name = field_name self._examples = self.load_datapoints(path_or_glob, skiplines=skiplines)[ :max_examples ] @@ -44,7 +52,7 @@ def load_datapoints(self, path_or_glob: str, skiplines: int = 0): continue line = line.strip() if line: # skip blank lines, these are usually document breaks - examples.append({'text': line}) + examples.append({self.field_name: line}) return examples def load(self, path: str): @@ -52,7 +60,46 @@ def load(self, path: str): def spec(self) -> lit_types.Spec: """Should match MLM's input_spec().""" - return {'text': lit_types.TextSegment()} + return {self.field_name: lit_types.TextSegment()} + + +class PromptExamples(lit_dataset.Dataset): + """Prompt examples for modern LMs.""" + + SAMPLE_DATA_PATH = os.path.join(SAMPLE_DATA_DIR, 'prompt_examples.jsonl') + + def load_datapoints(self, path: str): + if not path: + logging.warn( + 'Empty path to PromptExamples.load_datapoints(). Returning empty' + ' dataset.' 
+ ) + return [] + + default_ex_values = { + k: copy.deepcopy(field_spec.default) + for k, field_spec in self.spec().items() + } + + examples = [] + with open(path) as fd: + for line in fd: + examples.append(default_ex_values | json.loads(line)) + + return examples + + def __init__(self, path: str): + self._examples = self.load_datapoints(path) + + def spec(self) -> lit_types.Spec: + return { + 'source': lit_types.CategoryLabel(), + 'prompt': lit_types.TextSegment(), + 'target': lit_types.TextSegment(), + } + + def load(self, path: str): + return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) class BillionWordBenchmark(lit_dataset.Dataset): diff --git a/lit_nlp/examples/datasets/prompt_examples.jsonl b/lit_nlp/examples/datasets/prompt_examples.jsonl new file mode 100644 index 00000000..c5dd7cda --- /dev/null +++ b/lit_nlp/examples/datasets/prompt_examples.jsonl @@ -0,0 +1,9 @@ +{"source": "gigaword-summarization", "prompt": "Summarize this.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nSummary:\n", "target": "- Transocean and globalsantafe merge to form a large offshore drilling company.\n- The combined company will offer a full range of services in the world's key markets."} +{"source": "gigaword-summarization", "prompt": "Summarize the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nSummary:\n", "target": "* Transocean and Globalsantafe merge to form a new, larger company.\n* The combined company will offer a full range of offshore drilling services.\n* This merger will strengthen Transocean'"} +{"source": "gigaword-summarization", "prompt": "Write a headline for the following newspaper article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "Transocean and Globalsantafe merge: New giant in offshore drilling"} +{"source": "gigaword-summarization", "prompt": "You are an editor at the New York Times. Write a headline for the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "**Transocean and Globalsantafe Merge in a Giant Move for Offshore Drilling**"} +{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?", "target": "\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? 
** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? ** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? **", "target": " Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. Cutting fabric for crafts.\n2. Cutting herbs and vegetables in the kitchen.\n3. Cutting paper for DIY projects.\n\nPlease provide detailed instructions for using any two of the ideas."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item. The list should include creative uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic embellishments:** Use scissors to create intricate designs and patterns on fabric, paper, or other materials.\n2. **Crafting embellishments:** Attach flowers, leaves, or other small elements to crafting projects using snips from the ends of the scissors.\n3. **Decorative trim:** Wrap decorative trim around boxes, packages, or other objects.\n\nBonus Idea:\n\n4. **Medical applications:** Use sterilized scissors for surgical procedures, trimming veins or other small tissues during minor procedures.\n\nExplain your reasoning and provide examples for each idea."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item.The list should include creative uses for the item. The list should include at least three specific uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic tool:** Use scissors to create intricate patterns and designs on fabric, paper, or wood.\n2. **Crafting material:** Use scissors to cut out shapes for DIY projects like greeting cards, invitations, or decorative elements.\n3. **Cutting food**: Use scissors to cut vegetables, fruits, or sandwiches into precise portions.\n\n**Please provide the three specific uses for the scissors. 
The more specific and unique, the better.**"} diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py new file mode 100644 index 00000000..f9637940 --- /dev/null +++ b/lit_nlp/examples/lm_salience_demo.py @@ -0,0 +1,177 @@ +"""Demo for sequence salience with a left-to-right language model.""" + +from collections.abc import Sequence +import functools +import sys +from typing import Optional + +from absl import app +from absl import flags +from absl import logging +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.api import layout +from lit_nlp.examples.datasets import lm as lm_data +from lit_nlp.examples.models import pretrained_lms + +# NOTE: additional flags defined in server_flags.py + +FLAGS = flags.FLAGS + +FLAGS.set_default("development_demo", True) + +_MODELS = flags.DEFINE_list( + "models", + [ + "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz" + ], + "Models to load, as :. Currently supports GPT-2 variants.", +) + +_MAX_EXAMPLES = flags.DEFINE_integer( + "max_examples", + 1000, + ( + "Maximum number of examples to load from each evaluation set. Set to" + " None to load the full set." + ), +) + +# Custom frontend layout; see api/layout.py +modules = layout.LitModuleName +LM_LAYOUT = layout.LitCanonicalLayout( + left={ + "Data Table": [modules.DataTableModule], + "Embeddings": [modules.EmbeddingsModule], + }, + upper={ + "Datapoint Editor": [modules.DatapointEditorModule], + "Datapoint Generators": [modules.GeneratorModule], + }, + lower={ + "Salience": [modules.LMSalienceModule], + "Metrics": [modules.MetricsModule], + }, + layoutSettings=layout.LayoutSettings( + mainHeight=40, + leftWidth=40, + ), + description="Custom layout for language model salience.", +) +SIMPLE_LM_LAYOUT = layout.LitCanonicalLayout( + upper={ + "Examples": [modules.SimpleDataTableModule], + "Editor": [modules.SimpleDatapointEditorModule], + }, + lower={ + "Salience": [modules.LMSalienceModule], + }, + layoutSettings=layout.LayoutSettings( + hideToolbar=True, + mainHeight=40, + centerPage=True, + ), + description="Simplified layout for language model salience.", +) + +CUSTOM_LAYOUTS = { + "simple": SIMPLE_LM_LAYOUT, + "three_panel": LM_LAYOUT, +} + +FLAGS.set_default("page_title", "LM Salience Demo") +FLAGS.set_default("default_layout", "simple") + +_SPLASH_SCREEN_DOC = """ +# Language Model Salience + +To begin, select an example, then click the segment(s) (tokens, words, etc.) +of the output that you would like to explain. Preceding segments(s) will be +highlighted according to their importance to the selected target segment(s), +with darker colors indicating a greater influence (salience) of that segment on +the model's likelihood of the target segment. +""" + + +def get_wsgi_app() -> Optional[dev_server.LitServerType]: + """Return WSGI app for container-hosted demos.""" + FLAGS.set_default("server_type", "external") + FLAGS.set_default("demo_mode", True) + # Parse flags without calling app.run(main), to avoid conflict with + # gunicorn command line flags. 
+ unused = flags.FLAGS(sys.argv, known_only=True) + if unused: + logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused) + return main([]) + + +def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + plaintextPrompts = functools.partial( # pylint: disable=invalid-name + lm_data.PlaintextSents, field_name="prompt" + ) + # Hack: normally dataset loaders are a class object which has a __name__, + # rather than a functools.partial + plaintextPrompts.__name__ = "PlaintextSents" + + # Pre-loaded datasets. + datasets = { + "sample_prompts": lm_data.PromptExamples( + lm_data.PromptExamples.SAMPLE_DATA_PATH + ), + } + + # For loading from the UI. + dataset_loaders = { + "jsonl_examples": ( + lm_data.PromptExamples, + lm_data.PromptExamples.init_spec(), + ), + "plaintext_inputs": ( + plaintextPrompts, + lm_data.PlaintextSents.init_spec(), + ), + } + + ## + # Load models, according to the --models flag. + models = {} + for model_string in _MODELS.value: + # Only split on the first ':', because path may be a URL + # containing 'https://' + model_name, path = model_string.split(":", 1) + logging.info("Loading model '%s' from '%s'", model_name, path) + if model_name.startswith("gpt2") or model_name in ["distilgpt2"]: + models[model_name] = pretrained_lms.GPT2GenerativeModel(path) + # Salience wrapper, using same underlying Keras models so as not to + # load the weights twice. + models[f"_{model_name}_salience"] = ( + pretrained_lms.GPT2SalienceModel.from_loaded(models[model_name]) + ) + models[f"_{model_name}_tokenizer"] = ( + pretrained_lms.GPT2TokenizerModel.from_loaded(models[model_name]) + ) + else: + raise ValueError( + f"Unsupported model name '{model_name}' from path '{path}'" + ) + + for name in datasets: + datasets[name] = datasets[name].slice[: _MAX_EXAMPLES.value] + logging.info("Dataset: '%s' with %d examples", name, len(datasets[name])) + + lit_demo = dev_server.Server( + models, + datasets, + layouts=CUSTOM_LAYOUTS, + dataset_loaders=dataset_loaders, + onboard_start_doc=_SPLASH_SCREEN_DOC, + **server_flags.get_flags(), + ) + return lit_demo.serve() + + +if __name__ == "__main__": + app.run(main) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 8b72f8fe..28baea46 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -6,6 +6,8 @@ functions to predict a batch of examples and extract information such as hidden states and attention. """ +from collections.abc import Sequence +import functools import re from lit_nlp.api import model as lit_model @@ -147,6 +149,7 @@ def output_spec(self): } +# TODO(lit-dev): merge with below, inherit from GPT2BaseModel. class GPT2LanguageModel(lit_model.BatchedModel): """Wrapper for a Huggingface Transformers GPT-2 model. @@ -203,7 +206,7 @@ def clean_bpe_token(tok): else: return tok.replace("Ġ", "") - def _detokenize(self, ids): + def ids_to_clean_tokens(self, ids): tokens = self.tokenizer.convert_ids_to_tokens(ids) return [self.clean_bpe_token(t) for t in tokens] @@ -255,7 +258,7 @@ def _postprocess(self, preds): """Post-process single-example preds. Operates on numpy arrays.""" ntok = preds.pop("ntok") ids = preds.pop("input_ids")[:ntok] - preds["tokens"] = self._detokenize(ids) + preds["tokens"] = self.ids_to_clean_tokens(ids) # Decode predicted top-k tokens. 
# token_topk_preds will be a list[list[(word, prob)]] @@ -264,7 +267,7 @@ def _postprocess(self, preds): pred_ids = preds.pop("top_k_indices")[:ntok] # [num_tokens, k] pred_probs = preds.pop("top_k_probs")[:ntok] # [num_tokens, k] for token_pred_ids, token_pred_probs in zip(pred_ids, pred_probs): - token_pred_words = self._detokenize(token_pred_ids) + token_pred_words = self.ids_to_clean_tokens(token_pred_ids) token_topk_preds.append(list(zip(token_pred_words, token_pred_probs))) preds["pred_tokens"] = token_topk_preds @@ -326,46 +329,38 @@ def output_spec(self): return spec -class GPT2GenerativeModel(lit_model.BatchedModel): - """Wrapper for a Huggingface Transformers GPT-2 model. - - This class loads a tokenizer and model using the Huggingface library and - provides the LIT-required functions to generate text responses given input - prompts. +class GPT2BaseModel(lit_model.BatchedModel): + """Base class for GPT2 model wrappers.""" - Note that the default model generation config is used such that the response - is produced using multinomial sampling. - """ + @property + def num_layers(self): + return self.model.config.n_layer @classmethod def init_spec(cls) -> lit_model.Spec: return { "model_name_or_path": lit_types.String(default="gpt2"), - "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500), - "batch_size": lit_types.Integer(default=6, min_val=1, max_val=25), + "batch_size": lit_types.Integer(default=6, min_val=1, max_val=64), } def __init__( self, - model=None, - tokenizer=None, model_name_or_path="gpt2", - max_new_tokens=50, batch_size=6, + model=None, + tokenizer=None, ): - """Constructor for GPT2LanguageModel. + """Constructor for GPT2 model wrappers. Note: args "model" and "tokenizer" take priority if both are specified. Otherwise, "model_name_or_path" is used to initialize the model and tokenizer. Args: - model: an initialized GPT2 model compatible with Tensorflow. - tokenizer: an initialized GPT2 tokenizer. - model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, - etc. - max_new_tokens: the maximum number of new tokens to generate. + model_name_or_path: gpt2, gpt2-medium, gpt2-large, distilgpt2, etc. batch_size: the number of items to process per `predict_minibatch` call. + model: an initialized transformers.TFGPT2LMHeadModel. + tokenizer: an initialized GPT2 tokenizer. """ super().__init__() @@ -380,28 +375,103 @@ def __init__( model_name_or_path, extract_compressed_file=True ) + # Note: we need to left-pad for generation to work properly. + # Other modes such as scoring and salience should handle this as well; + # see example in GPT2SalienceModel._postprocess(). self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, use_fast=False + model_name_or_path, + use_fast=False, + padding_side="left", ) # Set this after init, as if pad_token= is passed to # AutoTokenizer.from_pretrained() above it will create a new token with # with id = max_vocab_length and cause out-of-bounds errors in # the embedding lookup. 
- self.tokenizer.pad_token = self.tokenizer.eos_token - self.model = transformers.TFAutoModelForCausalLM.from_pretrained( - model_name_or_path + self.model = transformers.TFGPT2LMHeadModel.from_pretrained( + model_name_or_path, output_hidden_states=True, output_attentions=False ) - self.max_new_tokens = max_new_tokens + self.tokenizer.pad_token = self.tokenizer.eos_token self.batch_size = batch_size - ## - # LIT API implementations + @property + def pad_left(self): + return self.tokenizer.padding_side == "left" + + @classmethod + def from_loaded(cls, existing: "GPT2BaseModel", *args, **kw): + """Share weights and underlying Keras model with another instance.""" + return cls(model=existing.model, tokenizer=existing.tokenizer, *args, **kw) + + def clean_bpe_token(self, tok): + tok = tok.replace("Ċ", "\n") # newlines + tok = tok.replace("Ġ", "▁") # start of word -> magic underscore + return tok + + def ids_to_clean_tokens(self, ids: Sequence[int]) -> list[str]: + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return [self.clean_bpe_token(t) for t in tokens] + def max_minibatch_size(self) -> int: # The BatchedModel base class handles batching automatically in the # implementation of predict(), and uses this value as the batch size. return self.batch_size + def input_spec(self): + return { + "prompt": lit_types.TextSegment(), + "target": lit_types.TextSegment(required=False), + } + + +class GPT2GenerativeModel(GPT2BaseModel): + """Wrapper for a Huggingface Transformers GPT-2 model. + + This class loads a tokenizer and model using the Huggingface library and + provides the LIT-required functions to generate text responses given input + prompts. + + Note that the default model generation config is used such that the response + is produced using multinomial sampling. + """ + + @classmethod + def init_spec(cls) -> lit_model.Spec: + return super().init_spec() | { + "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500) + } + + def __init__(self, *args, max_new_tokens=50, **kw): + """Constructor for GPT2LanguageModel. + + Args: + *args: as to GPT2BaseModel.__init__ + max_new_tokens: the maximum number of new tokens to generate. + **kw: as to GPT2BaseModel.__init__ + """ + super().__init__(*args, **kw) + self.max_new_tokens = max_new_tokens + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # TODO(b/324957491): return actual decoder scores for each generation. + # GeneratedTextCandidates should be a list[(text, score)] + preds["response"] = [(preds["response"], 1.0)] + ntok_in = preds.pop("ntok_in") + embs = preds.pop("embs") + # Mean-pool over input tokens. + preds["prompt_embeddings"] = np.mean( + embs[-(self.max_new_tokens + ntok_in) : -self.max_new_tokens], axis=0 + ) + # Mean-pool over output (generated) tokens. + # TODO(b/324957491): slice this to only "real" output tokens, + # if generation length < max generation length. 
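+    # Example: with ntok_in=7 and max_new_tokens=50, prompt embeddings pool
+    # over embs[-57:-50] (the real prompt tokens, since padding is on the
+    # left) and response embeddings pool over embs[-50:].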
+ preds["response_embeddings"] = np.mean(embs[-self.max_new_tokens :], axis=0) + + return preds + + ## + # LIT API implementations def predict_minibatch(self, inputs): prompts = [ex["prompt"] for ex in inputs] encoded_inputs = self.tokenizer.batch_encode_plus( @@ -413,28 +483,210 @@ def predict_minibatch(self, inputs): ) outputs = self.model.generate( encoded_inputs["input_ids"], + attention_mask=encoded_inputs["attention_mask"], max_new_tokens=self.max_new_tokens, ) + responses = self.tokenizer.batch_decode( outputs[:, -self.max_new_tokens :], skip_special_tokens=True ) + # Input embeddings: [batch_size, num_tokens, emb_dim] embeddings = self.model.transformer.wte(outputs) - return [ - { - "response": responses[i], - "prompt_embeddings": embeddings[i, : -self.max_new_tokens], - "response_embeddings": embeddings[i, -self.max_new_tokens :] - } for i in range(len(outputs)) - ] + batched_outputs = { + "embs": embeddings, + "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1), + # TODO(b/324957491): compute ntok_out if < max_output_tokens ? + } + + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + detached_outputs["response"] = responses + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) + + def output_spec(self) -> lit_types.Spec: + return { + "response": lit_types.GeneratedTextCandidates(parent="target"), + "prompt_embeddings": lit_types.Embeddings(required=False), + "response_embeddings": lit_types.Embeddings(required=False), + } + + +class GPT2SalienceModel(GPT2BaseModel): + """Wrapper for GPT-2 input (token) salience.""" + + def _pred(self, encoded_inputs, target_masks): + """Predicts one batch of tokenized text. + + Also performs some batch-level post-processing in TF. + Single-example postprocessing is done in _postprocess(), and operates on + numpy arrays. + + Args: + encoded_inputs: output of self.tokenizer.batch_encode_plus() + target_masks: list(array_like) of binary (0/1) masks for each input + + Returns: + payload: Dictionary with items described above, each as single Tensor. + """ + input_ids = encoded_inputs["input_ids"] + + # [batch_size, num_tokens]; ignore the last one in each row. + target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1) + ## + # Process target masks + + # It doesn't make sense to interpret the first token, since it is not ever + # predicted. But we need to ensure that the mask[0] is zero, so it doesn't + # cause problems when 'rolled' to the last position below. + modified_masks = [[0] + list(mask[1:]) for mask in target_masks] + seq_len = target_ids.shape[1] + pad_fn = functools.partial( + utils.pad1d, + min_len=seq_len, + max_len=seq_len, + pad_val=0, + pad_left=self.pad_left, + ) + padded_target_masks = np.stack( + [pad_fn(mask) for mask in modified_masks], + axis=0, + ) + + padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32) + # Shift masks back so they align with target_ids. + loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) + + with tf.GradientTape(watch_accessed_variables=True) as tape: + # We need to run the embedding layer ourselves so we can trace it. 
+ # See here for how the model normally does this: + # http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271 + embs = self.model.transformer.wte(input_ids, mode="embedding") + tape.watch(embs) + + out = self.model( + input_ids=None, + inputs_embeds=embs, + attention_mask=encoded_inputs["attention_mask"], + ) + + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + # [batch_size, num_tokens] + per_token_loss = loss_fn(target_ids, out.logits) + masked_loss = per_token_loss * loss_mask + + grads = tape.gradient( + masked_loss, embs + ) # [batch_size, num_tokens, hdim] + + grad_l2 = tf.norm(grads, axis=2) # [batch_size, num_tokens] + grad_dot_input = tf.reduce_sum( + grads * embs, axis=2 + ) # [batch_size, num_tokens] + + batched_outputs = { + "input_ids": encoded_inputs["input_ids"], + "attention_mask": encoded_inputs["attention_mask"], + # Gradients are already aligned to input tokens. + "grad_l2": grad_l2, + "grad_dot_input": grad_dot_input, + # Shift token loss to align with (input) tokens. + "token_loss": tf.roll(per_token_loss, shift=1, axis=1), + } + + return batched_outputs + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # Be sure to cast to bool, otherwise this will select intger positions 0, 1 + # rather than acting as a boolean mask. + mask = preds.pop("attention_mask").astype(bool) + ids = preds.pop("input_ids")[mask] + preds["tokens"] = self.ids_to_clean_tokens(ids) + for key in utils.find_spec_keys(self.output_spec(), lit_types.TokenScores): + preds[key] = preds[key][mask] + # First token (usually ) is not actually predicted, so return 0 for loss. + preds["token_loss"][0] = 0 + + return preds + + # LIT API implementations + def predict_minibatch(self, inputs): + """Predict on a single minibatch of examples.""" + # Preprocess inputs. + texts = [ex["prompt"] + ex.get("target", "") for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + texts, + return_tensors="tf", + add_special_tokens=True, + padding="longest", + truncation="longest_first", + ) + target_masks = [ex.get("target_mask", []) for ex in inputs] + + # Get the predictions. + batched_outputs = self._pred(encoded_inputs, target_masks) + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) def input_spec(self): + return super().input_spec() | { + "target_mask": lit_types.TokenScores(align="", required=False), + } + + def output_spec(self) -> lit_types.Spec: return { - "prompt": lit_types.TextSegment(), + "tokens": lit_types.Tokens(parent=""), # all tokens + "grad_l2": lit_types.TokenScores(align="tokens"), + "grad_dot_input": lit_types.TokenScores(align="tokens"), + "token_loss": lit_types.TokenScores(align="tokens"), + } + + +class GPT2TokenizerModel(GPT2BaseModel): + """Wrapper to run only the tokenizer. + + Should exactly match tokens from GPT2SalienceModel. + """ + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # Be sure to cast to bool, otherwise this will select intger positions 0, 1 + # rather than acting as a boolean mask. 
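+    # (np.array([10, 20, 30])[np.array([1, 0, 1])] selects [20, 10, 20], while
+    # the boolean mask [True, False, True] keeps [10, 30].)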
+ mask = preds.pop("attention_mask").astype(bool) + ids = preds.pop("input_ids")[mask] + preds["tokens"] = self.ids_to_clean_tokens(ids) + return preds + + # LIT API implementations + def predict_minibatch(self, inputs): + """Predict on a single minibatch of examples.""" + # Preprocess inputs. + texts = [ex["prompt"] + ex.get("target", "") for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + texts, + return_tensors="tf", + add_special_tokens=True, + padding="longest", + truncation="longest_first", + ) + batched_outputs = { + "input_ids": encoded_inputs["input_ids"], + "attention_mask": encoded_inputs["attention_mask"], } + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) def output_spec(self) -> lit_types.Spec: return { - "response": lit_types.GeneratedTextCandidates(), - "prompt_embeddings": lit_types.Embeddings(required=False), - "response_embeddings": lit_types.Embeddings(required=False) + "tokens": lit_types.Tokens(parent=""), # all tokens } diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py index 62583ef6..84dce7e7 100644 --- a/lit_nlp/examples/models/pretrained_lms_int_test.py +++ b/lit_nlp/examples/models/pretrained_lms_int_test.py @@ -44,8 +44,10 @@ def test_gpt2_generation(self): self.assertIn(key, model_out[0].keys()) # Check that the embedding dimension is the same for prompt and response. - self.assertEqual(model_out[0]["prompt_embeddings"].shape[1], - model_out[0]["response_embeddings"].shape[1]) + self.assertEqual( + model_out[0]["prompt_embeddings"].shape, + model_out[0]["response_embeddings"].shape, + ) if __name__ == "__main__":
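
For reference, a minimal usage sketch of the three wrappers added in
pretrained_lms.py above. The model name, prompt text, and the use of
Model.predict() here are illustrative assumptions, not part of this patch:

    from lit_nlp.examples.models import pretrained_lms

    # One checkpoint backs all three endpoints; from_loaded() shares weights.
    generator = pretrained_lms.GPT2GenerativeModel(model_name_or_path="distilgpt2")
    salience = pretrained_lms.GPT2SalienceModel.from_loaded(generator)
    tokenizer = pretrained_lms.GPT2TokenizerModel.from_loaded(generator)

    example = {"prompt": "The capital of France is", "target": " Paris"}
    tokens = list(tokenizer.predict([example]))[0]["tokens"]

    # Mark the last token as the explanation target (1 = explain this position).
    example["target_mask"] = [0] * (len(tokens) - 1) + [1]
    preds = list(salience.predict([example]))[0]
    grad_l2 = preds["grad_l2"]  # one unsigned salience score per input token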