From 1a7bb31bda0c9378548bca3d1a1c6d15a1e5aacc Mon Sep 17 00:00:00 2001 From: Joshua Horton Date: Thu, 30 Apr 2026 08:28:43 -0500 Subject: [PATCH] feat(web): tokenize input corrections and provide for multi-token predictions Build-bot: skip build:web Test-bot: skip --- .../correction/tokenization-result-mapping.ts | 4 +- .../worker-thread/src/main/predict-helpers.ts | 85 ++++++++++--------- ...ine-tokenized-correction-sequence.tests.ts | 48 +++++++---- ...raversalless-correction-sequences.tests.ts | 6 +- 4 files changed, 83 insertions(+), 60 deletions(-) diff --git a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts index 32e0fb48fce..c5588424afe 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/correction/tokenization-result-mapping.ts @@ -5,13 +5,13 @@ export class TokenizationResultMapping implements CorrectionResultMapping; - constructor(tokenization: TokenResult[], corrector: TokenizationCorrector) { + constructor(tokenization: TokenResult[], corrector?: TokenizationCorrector) { this.matchingSpace = corrector; this.matchedResult = tokenization; } get spaceId(): number { - return this.matchingSpace.tokenization.spaceId; + return this.matchingSpace?.tokenization.spaceId; } // /** diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index 196a6ac235c..67b04137bbc 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -14,6 +14,8 @@ import { ExecutionTimer } from './correction/execution-timer.js'; import { ModelCompositor } from './model-compositor.js'; import { EDIT_DISTANCE_COST_SCALE, getBestTokenMatches } from './correction/distance-modeler.js'; import { TokenResult } from './correction/tokenization-corrector.js'; +import { TokenizationCorrector } from './correction/tokenization-corrector.js'; +import { TokenizationResultMapping } from './correction/tokenization-result-mapping.js'; import CasingForm = LexicalModelTypes.CasingForm; import Context = LexicalModelTypes.Context; @@ -233,7 +235,8 @@ export function determineTraversallessCorrectionSequences( return match; }); - const suggestionParams = buildCorrectionSequence(transitionEffects, context, correctionRoots[correctionRoots.length - 1]); + // But, for now, only actually use the last one. + const suggestionParams = buildCorrectionSequence(transitionEffects, context, new TokenizationResultMapping([correctionRoots[correctionRoots.length - 1]], null)); const tokenizationMapping = mapWhitespacedTokenization(tokenization.left.map((t) => { return {exampleInput: t.text, codepointLength: KMWString.length(t.text)} }), lexicalModel, correction.sample); const tokenizedCorrection = tokenizationMapping.tokenizedTransform; @@ -377,12 +380,8 @@ export function determineSuggestionRange( return temp(a, b); } - const deleteLeftCalc = (tokenSet: T[], predictCount: number) => { - // TODO: once we start activating multi-tokenization for real, only the - // 'reduce' component should remain. - return (predictCount > 1) - ? (tokenSet[tokenSet.length - 1]?.codepointLength ?? 0) - : tokenSet.reduce((prev, curr) => prev + curr.codepointLength, 0); + const deleteLeftCalc = (tokenSet: T[]) => { + return tokenSet.reduce((prev, curr) => prev + curr.codepointLength, 0); } const tokenSetA = userContextTokenization.slice(); @@ -396,7 +395,7 @@ export function determineSuggestionRange( return { tokensToRemove: tokenSetA, tokensToPredict: tokenSetB, - deleteLeft: deleteLeftCalc(tokenSetA, tokenSetB.length) + deleteLeft: deleteLeftCalc(tokenSetA) } } else if(aHeadIndexInB != 0 && bHeadIndexInA != 0) { throw new Error("Leading edge of context should not differ in both tokenizations."); @@ -422,7 +421,7 @@ export function determineSuggestionRange( return { tokensToRemove, tokensToPredict, - deleteLeft: deleteLeftCalc(tokensToRemove, tokensToPredict.length) + deleteLeft: deleteLeftCalc(tokensToRemove) } } @@ -456,44 +455,49 @@ export interface PredictionParameters { } export function buildCorrectionSequence( - transitionEffects: ReturnType, + transitionEffects: SuggestionReplacement, context: Context, - match: Readonly, + tokenizationCorrection: TokenizationResultMapping ) { const { deleteLeft } = transitionEffects; const rootContext = models.applyTransform({insert: '', deleteLeft}, context); // Replace the existing context with the correction. - const correctionTransform: Transform = { - insert: match.matchString, // insert correction string - deleteLeft: 0, - } + const tokenizedCorrections = tokenizationCorrection.matchedResult.map((correction, i) => { + /* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost + * the exponent to ensure only VERY nearby corrections have a chance of winning, and only if + * there are significantly more likely words. We only need this to allow very minor fat-finger + * adjustments for 100% keystroke-sequence corrections in order to prevent finickiness on + * key borders. + * + * Technically, the probabilities this produces won't be normalized as-is... but there's no + * true NEED to do so for it, even if it'd be 'nice to have'. Consistently tracking when + * to apply it could become tricky, so it's simpler to leave out. + * + * Worst-case, it's possible to temporarily add normalization if a code deep-dive + * is needed in the future. + */ + const costFactor = (correction.inputCount <= 1) ? ModelCompositor.SINGLE_CHAR_KEY_PROB_EXPONENT : 1; + + const entry = { + sample: { + insert: correction.matchString, // insert correction string + deleteLeft: 0, + } as Transform, + p: Math.exp(-correction.totalCost * costFactor) + }; - /* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost - * the exponent to ensure only VERY nearby corrections have a chance of winning, and only if - * there are significantly more likely words. We only need this to allow very minor fat-finger - * adjustments for 100% keystroke-sequence corrections in order to prevent finickiness on - * key borders. - * - * Technically, the probabilities this produces won't be normalized as-is... but there's no - * true NEED to do so for it, even if it'd be 'nice to have'. Consistently tracking when - * to apply it could become tricky, so it's simpler to leave out. - * - * Worst-case, it's possible to temporarily add normalization if a code deep-dive - * is needed in the future. - */ - const costFactor = (match.inputCount <= 1) ? ModelCompositor.SINGLE_CHAR_KEY_PROB_EXPONENT : 1; - - const rootCost = match.totalCost; - const predictionRoot = { - sample: correctionTransform, - p: Math.exp(-rootCost * costFactor) - }; + if(transitionEffects.transitionId !== undefined) { + entry.sample.id = transitionEffects.transitionId; + } + + return entry; + }); return { rootContext, - tokenizedCorrection: [predictionRoot] + tokenizedCorrection: tokenizedCorrections }; } @@ -512,10 +516,11 @@ export function buildCorrectionSequence( export function determineTokenizedCorrectionSequence( transition: ContextTransition, tokenization: ContextTokenization, - match: Readonly + match: TokenizationResultMapping ): PredictionParameters { const applicationTarget = transition.base.displayTokenization; const transitionParams = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); + transitionParams.transitionId = transition.transitionId; const suggestionParams = buildCorrectionSequence(transitionParams, transition.base.context, match); @@ -641,7 +646,11 @@ export async function correctAndEnumerate( continue; } - const predictionPrep = determineTokenizedCorrectionSequence(transition, tokenization, match); + const suggestionRange = determineSuggestionRange(transition.base.displayTokenization.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); + suggestionRange.transitionId = transition.transitionId; + const corrector = new TokenizationCorrector(tokenization, suggestionRange.tokensToPredict.length, () => true); + const predictionPrep = determineTokenizedCorrectionSequence(transition, tokenization, new TokenizationResultMapping([match], corrector)); + const predictions = predictFromCorrectionSequence(lexicalModel, predictionPrep.tokenizedCorrection, predictionPrep.rootContext, transition.transitionId); predictions.forEach((p) => predictionPrep.applyInPost(p)); diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts index 7623bfffe51..ee8e2ba9109 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts @@ -15,7 +15,16 @@ import * as wordBreakers from '@keymanapp/models-wordbreakers'; import { jsonFixture } from '@keymanapp/common-test-resources/model-helpers.mjs'; import { KMWString } from 'keyman/common/web-utils'; -import { determineTokenizedCorrectionSequence, models, ContextState, ContextToken, ContextTokenization, CorrectionPredictionTuple, ModelCompositor } from "@keymanapp/lm-worker/test-index"; +import { + determineTokenizedCorrectionSequence, + models, + ContextState, + ContextToken, + ContextTokenization, + CorrectionPredictionTuple, + ModelCompositor, + TokenizationResultMapping +} from "@keymanapp/lm-worker/test-index"; import Context = LexicalModelTypes.Context; import ProbabilityMass = LexicalModelTypes.ProbabilityMass; @@ -51,13 +60,14 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: 'fo', inputSamplingCost: -Math.log(trueInput.p), inputCount: 2, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { @@ -101,13 +111,14 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: ' ', inputSamplingCost: -Math.log(trueInput.p), inputCount: 1, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { @@ -148,13 +159,14 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: 'f', inputSamplingCost: -Math.log(trueInput.p), inputCount: 1, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { @@ -202,13 +214,14 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: 'can\'t', inputSamplingCost: -Math.log(trueInput.p), inputCount: 5, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { @@ -253,13 +266,14 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: ' ', inputSamplingCost: -Math.log(trueInput.p), inputCount: 1, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { @@ -300,26 +314,26 @@ describe('determineTokenizedCorrectionSequence', () => { const results = determineTokenizedCorrectionSequence( transition, - transition.final.displayTokenization, { + transition.final.displayTokenization, + new TokenizationResultMapping([{ matchString: 'd', inputSamplingCost: -Math.log(trueInput.p), inputCount: 1, knownCost: 0, totalCost: -Math.log(trueInput.p) - } + }], null) ); - // Large-scale deletions will receive enhanced handling soon. But, for now, it's - // deleted by the `preservationTransform`, not here. assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { casingForm: undefined, - left: 'the quick brown ', + left: 'the ', right: '', startOfBuffer: true, endOfBuffer: true }); - + // Coming up next - actually providing ALL correction elements, not just the final one. + // We're not _quite_ ready for that yet, though. assert.equal(results.tokenizedCorrection.length, 1); assert.deepEqual(results.tokenizedCorrection[0].sample, { insert: 'd', diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts index 11a4e841311..f9c7dccd550 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts @@ -361,15 +361,15 @@ describe('determineTraversallessCorrectionSequences', () => { ...entry.rootContext, casingForm: entry.rootContext.casingForm ?? undefined }, { casingForm: undefined, - // Large-scale deletions will receive enhanced handling soon. But, for now, it's - // deleted by the `preservationTransform`, not here. - left: 'the quick brown ', + left: 'the ', right: '', startOfBuffer: true, endOfBuffer: true } ); + // Coming up next - actually providing ALL correction elements, not just the final one. + // We're not _quite_ ready for that yet, though. assert.equal(entry.tokenizedCorrection.length, 1); assert.deepEqual(entry.tokenizedCorrection[0].sample, { insert: 'd',