diff --git a/lit_nlp/client/lib/token_utils.ts b/lit_nlp/client/lib/token_utils.ts index b81f04ae..75317d36 100644 --- a/lit_nlp/client/lib/token_utils.ts +++ b/lit_nlp/client/lib/token_utils.ts @@ -17,13 +17,22 @@ export function cleanSpmText(text: string): string { /** * Use a regex to match segment prefixes. The prefix and anything * following it (until the next match) are treated as one segment. + * + * @param tokens tokens to group + * @param matcher regex to group by; must have /g set + * @param breakOnMatchEnd if true, will also break segments on the /end/ of + * a matching span in addition to the beginning. */ -export function groupTokensByRegexPrefix( - tokens: string[], - matcher: RegExp, - ): string[][] { +function groupTokensByRegex( + tokens: string[], matcher: RegExp, breakOnMatchEnd: boolean): string[][] { const text = tokens.join(''); - const matches = [...text.matchAll(matcher)]; + const matchIdxs: Array = []; + for (const match of text.matchAll(matcher)) { + matchIdxs.push(match.index); + if (match.index !== undefined && breakOnMatchEnd) { + matchIdxs.push(match.index + match[0].length); + } + } let textCharOffset = 0; // chars into text let matchIdx = 0; // indices into matches @@ -31,12 +40,11 @@ export function groupTokensByRegexPrefix( let acc: string[] = []; for (let i = 0; i < tokens.length; i++) { const token = tokens[i]; - const nextMatch = matches[matchIdx]; + const nextMatch = matchIdxs[matchIdx]; // Look ahead to see if this token intrudes on a match. // If so, start a new segment before pushing the token. - if (nextMatch !== undefined && - textCharOffset + token.length > nextMatch.index!) { + if (nextMatch !== undefined && textCharOffset + token.length > nextMatch) { // Don't push an empty group if the first token is part of a match. if (acc.length > 0 || groups.length > 0) groups.push(acc); acc = []; @@ -50,4 +58,24 @@ export function groupTokensByRegexPrefix( // Finally, push any open group. if (acc.length > 0) groups.push(acc); return groups; +} + +/** + * Use a regex to match segment prefixes. The prefix and anything + * following it (until the next match) are treated as one segment. + * For example, groupTokensByRegexPrefix(tokens, /Example:/g) will + * create a segment each time the text "Example:" is seen. + */ +export function groupTokensByRegexPrefix(tokens: string[], matcher: RegExp) { + return groupTokensByRegex(tokens, matcher, /* breakOnMatchEnd */ false); +} + +/** + * Use a regex to match a separator segment. A matching span is treated + * as a segment, and anything between matches is treated as a separate segment. + * For example, groupTokensByRegexSeparator(tokens, /\n+/g) will group tokens + * in between newlines, with any sequence of \n as its own segment. + */ +export function groupTokensByRegexSeparator(tokens: string[], matcher: RegExp) { + return groupTokensByRegex(tokens, matcher, /* breakOnMatchEnd */ true); } \ No newline at end of file diff --git a/lit_nlp/client/lib/token_utils_test.ts b/lit_nlp/client/lib/token_utils_test.ts index 76223106..6ef966d4 100644 --- a/lit_nlp/client/lib/token_utils_test.ts +++ b/lit_nlp/client/lib/token_utils_test.ts @@ -85,4 +85,40 @@ describe('groupTokensByRegexPrefix test', () => { expect(groups).toEqual(expectedGroups); }); }); +}); + + +describe('groupTokensByRegexSeparator test', () => { + [{ + testcaseName: 'groups tokens by line', + tokens: [ + 'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once', + '▁upon', '▁a', '▁time', '\n', '▁there', '▁was' + ], + // Line separator is one or more \n + regex: /\n+/g, + expectedGroups: [ + ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'], + ['Once', '▁upon', '▁a', '▁time'], ['\n'], ['▁there', '▁was'] + ], + }, + { + testcaseName: 'groups tokens by paragraph', + tokens: [ + 'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once', + '▁upon', '▁a', '▁time', '\n', '▁there', '▁was' + ], + // Line separator is two or more \n + regex: /\n\n+/g, + expectedGroups: [ + ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'], + ['Once', '▁upon', '▁a', '▁time', '\n', '▁there', '▁was'] + ], + }, + ].forEach(({testcaseName, tokens, regex, expectedGroups}) => { + it(testcaseName, () => { + const groups = tokenUtils.groupTokensByRegexSeparator(tokens, regex); + expect(groups).toEqual(expectedGroups); + }); + }); }); \ No newline at end of file diff --git a/lit_nlp/client/modules/lm_salience_module.ts b/lit_nlp/client/modules/lm_salience_module.ts index 746ad158..26ba9996 100644 --- a/lit_nlp/client/modules/lm_salience_module.ts +++ b/lit_nlp/client/modules/lm_salience_module.ts @@ -22,7 +22,7 @@ import {CONTINUOUS_SIGNED_LAB, CONTINUOUS_UNSIGNED_LAB, SalienceCmap, SignedSali import {GENERATION_TYPES, getAllTargetOptions, TargetOption, TargetSource} from '../lib/generated_text_utils'; import {LitType, LitTypeTypesList, Tokens, TokenScores} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {cleanSpmText, groupTokensByRegexPrefix} from '../lib/token_utils'; +import {cleanSpmText, groupTokensByRegexPrefix, groupTokensByRegexSeparator} from '../lib/token_utils'; import {type IndexedInput, type Preds, SCROLL_SYNC_CSS_CLASS, type Spec} from '../lib/types'; import {cumSumArray, filterToKeys, findSpecKeys, groupAlike, makeModifiedInput, sumArray} from '../lib/utils'; @@ -325,16 +325,11 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { return groupTokensByRegexPrefix( this.currentTokens, /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([▁\s]+))/g); } else if (this.segmentationMode === SegmentationMode.LINES) { - // Line start is either: - // - a run of consecutive \n as its own segment - // - any non-\n following \n - return groupTokensByRegexPrefix(this.currentTokens, /(\n+)|([^\n]+)/g); + // Line separator is one or more newlines. + return groupTokensByRegexSeparator(this.currentTokens, /\n+/g); } else if (this.segmentationMode === SegmentationMode.PARAGRAPHS) { - // Paragraph start is either: - // - two or more newlines as its own segment - // - any non-\n following \n\n - return groupTokensByRegexPrefix( - this.currentTokens, /(\n\n+)|(?<=\n\n)([^\n]+)/g); + // Paragraph separator is two or more newlines. + return groupTokensByRegexSeparator(this.currentTokens, /\n\n+/g); } else { throw new Error( `Unsupported segmentation mode ${this.segmentationMode}.`);