diff --git a/web/src/engine/predictive-text/worker-thread/src/main/model-compositor.ts b/web/src/engine/predictive-text/worker-thread/src/main/model-compositor.ts index 1d4893aab49..71c7704c144 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/model-compositor.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/model-compositor.ts @@ -1,15 +1,13 @@ import * as models from '@keymanapp/models-templates'; import { LexicalModelTypes } from '@keymanapp/common-types'; -import { applySuggestionCasing, correctAndEnumerate, createDefaultKeep, dedupeSuggestions, finalizeSuggestions, predictionAutoSelect, processSimilarity, toAnnotatedSuggestion, tupleDisplayOrderSort } from './predict-helpers.js'; -import { detectCurrentCasing, determineModelTokenizer, determineModelWordbreaker, determinePunctuationFromModel } from './model-helpers.js'; -import TransformUtils from './transformUtils.js'; - import * as correction from './correction/index.js' +import { applySuggestionCasing, compositeIntermediatePredictions, correctAndEnumerate, createDefaultKeep, dedupeSuggestions, finalizeSuggestions, predictionAutoSelect, processSimilarity, toAnnotatedSuggestion, tupleDisplayOrderSort } from './predict-helpers.js'; +import { determineModelTokenizer, determineModelWordbreaker, determinePunctuationFromModel } from './model-helpers.js'; + import { ContextTracker } from './correction/context-tracker.js'; import { DEFAULT_ALLOTTED_CORRECTION_TIME_INTERVAL } from './correction/distance-modeler.js'; -import CasingForm = LexicalModelTypes.CasingForm; import Configuration = LexicalModelTypes.Configuration; import Context = LexicalModelTypes.Context; import Distribution = LexicalModelTypes.Distribution; @@ -125,24 +123,6 @@ export class ModelCompositor { const transformId = inputTransform.id; this.initContextTracker(context, transformId); - const allowBksp = TransformUtils.isBackspace(inputTransform); - const allowWhitespace = TransformUtils.isWhitespace(inputTransform); - - const postContext = models.applyTransform(inputTransform, context); - - // TODO: It would be best for the correctAndEnumerate method to return the - // suggestion's prefix, as it already has lots of logic oriented to this. - // The context-tracker used there with v14+ models can determine this more - // robustly. - const truePrefix = this.wordbreak(postContext); - // Only use of `truePrefix`. - const basePrefix = (allowBksp || allowWhitespace) ? truePrefix : this.wordbreak(context); - - // Used to restore whitespaces if operations would remove them. - const currentCasing: CasingForm = lexicalModel.languageUsesCasing - ? detectCurrentCasing(lexicalModel, postContext) - : null; - // Section 1: determine 'prediction roots' - enumerate corrections from most to least likely, // searching for results that yield viable predictions from the model. @@ -160,9 +140,9 @@ export class ModelCompositor { // Properly capitalizes the suggestions based on the existing context casing state. // This may result in duplicates if multiple casing options exist within the // lexicon for a word. (Example: "Apple" the company vs "apple" the fruit.) - for(let tuple of rawPredictions) { - if(currentCasing && currentCasing != 'lower') { - applySuggestionCasing(tuple.components.prediction, basePrefix, this.lexicalModel, currentCasing); + if(lexicalModel.languageUsesCasing) { + for(let tuple of rawPredictions) { + tuple.components.forEach((component) => applySuggestionCasing(component, this.lexicalModel)); } } @@ -171,9 +151,10 @@ export class ModelCompositor { // We want to dedupe before trimming the list so that we can present a full set // of viable distinct suggestions if available. - const deduplicatedSuggestionTuples = dedupeSuggestions(this.lexicalModel, rawPredictions, context); + const deduplicatedSuggestionTuples = dedupeSuggestions(this.lexicalModel, compositeIntermediatePredictions(rawPredictions), context); // Needs "casing" to be applied first. + const postContext = postContextState?.context ?? models.applyTransform(inputTransform, context); const hasExistingKeep = processSimilarity(this.lexicalModel, deduplicatedSuggestionTuples, context, postContext); // If no existing suggestion directly matches the user-visible version of diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index 176c80e09f0..5f825e4b9d9 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -4,7 +4,7 @@ import { LexicalModelTypes } from '@keymanapp/common-types'; import { defaultWordbreaker, WordBreakProperty } from '@keymanapp/models-wordbreakers'; import TransformUtils from './transformUtils.js'; -import { determineModelTokenizer, determineModelWordbreaker, determinePunctuationFromModel } from './model-helpers.js'; +import { detectCurrentCasing, determineModelTokenizer, determineModelWordbreaker, determinePunctuationFromModel } from './model-helpers.js'; import { ContextTokenization, mapWhitespacedTokenization } from './correction/context-tokenization.js'; import { ContextTracker } from './correction/context-tracker.js'; import { ContextToken } from './correction/context-token.js'; @@ -76,6 +76,24 @@ export const CORRECTION_SEARCH_THRESHOLDS = { REPLACEMENT_SEARCH_THRESHOLD: 4 as const // e^-4 = 0.0183156388. Allows "80%" of an extra edit. } +export interface TokenizedPredictionData { + /** + * The potential Suggestion + */ + prediction: Suggestion, + /** + * The correction upon which the Suggestion is based + */ + correction: string, + /** + * The ContextToken underlying the correction/prediction. + * + * May be undefined, especially for models that do not leverage the + * LexiconTraversal pattern. + */ + source: ContextToken // useful for getting the unkeyed, original version of the text (in model-compositor, where casing is applied) +} + export interface CompositedPredictionData { /** * The potential Suggestion (or Keep) @@ -138,6 +156,19 @@ export interface PredictionMetadata { preservationTransform?: Transform; } +export interface IntermediateTokenizedPrediction { + /** + * Contains the tokenized components to be used to construct a full + * predictive-text Suggestion, as well as data about the source for each + * component. + */ + components: TokenizedPredictionData[]; + /** + * Tracks common intermediate prediction data, such as its underlying probabilities and its similarity to the actual context. + */ + metadata: PredictionMetadata; +} + export interface IntermediateCompositedPrediction { /** * Contains the fully composited predictive-text Suggestion and its underlying correction string. @@ -149,7 +180,7 @@ export interface IntermediateCompositedPrediction { metadata: PredictionMetadata; } -type IntermediatePrediction = IntermediateCompositedPrediction; +type IntermediatePrediction = IntermediateCompositedPrediction | IntermediateTokenizedPrediction; /** * An enum to be used when categorizing the level of similarity between @@ -363,7 +394,7 @@ export function determineSuggestionRange( export function buildAndMapPredictions( transition: ContextTransition, tokenizationCorrection: TokenizationResultMapping, -): IntermediateCompositedPrediction[] { +): IntermediateTokenizedPrediction[] { const model = transition.final.model; const tokenization = tokenizationCorrection.matchingSpace.tokenization; @@ -451,21 +482,28 @@ export function buildAndMapPredictions( // rather than predicting (and possibly extending) tokens not adjacent to the caret. // // Also, fall back to the actual correction string should prediction not be valid here. - return i == correctionTransforms.length - 1 ? predictions : [predictions[0]]; + const predictionsToReturn = i == correctionTransforms.length - 1 ? predictions : [predictions[0]]; + + return predictionsToReturn.map((prediction) => { + return { + prediction, + correction: correctionTransform.insert + }; + }); }); // Constructs a common prefix for all but the final token's component. const predictionPrefix = predictionComponents .slice(0, predictionComponents.length-1) - .reduce((accum, curr) => models.buildMergedTransform(accum, curr[0].sample.transform), { insert: '', deleteLeft: 0 }); + .map((arr) => arr[0]); const prefixProb = predictionComponents .slice(0, predictionComponents.length-1) - .reduce((accum, curr) => accum * curr[0].p, 1) + .reduce((accum, curr) => accum * curr[0].prediction.p, 1) - const completePredictionTuples: IntermediateCompositedPrediction[] = predictionComponents[predictionComponents.length-1].map((prediction) => { - const predictionCost = prediction.p * prefixProb; + const completePredictionTuples: IntermediateTokenizedPrediction[] = predictionComponents[predictionComponents.length-1].map((tuple) => { + const predictionCost = tuple.prediction.p * prefixProb; - return { + const returnVal: IntermediateTokenizedPrediction = { // Will need to do this differently. We want to have each component // individualized b/c casing. Case should be maintained for prior tokens // and managed independently for each. @@ -476,23 +514,15 @@ export function buildAndMapPredictions( // applySuggestionCasing applies onto suggestions, so we'll want to build // the FULL suggestion AFTER applying casing changes (to each token's // suggestion component). - components: { - prediction: { - transformId: transition.transitionId, - transform: models.buildMergedTransform(predictionPrefix, prediction.sample.transform), - displayAs: models.buildMergedTransform(predictionPrefix, prediction.sample.transform).insert // should composite the displayAs strings instead... - }, - correction: correctionTransforms[correctionTransforms.length-1].insert - }, + components: [], metadata: { probabilities: { prediction: predictionCost, correction: correctionCost, total: predictionCost * correctionCost }, - matchLevel: SuggestionSimilarity.none, autoSelectable: tokenizationCorrection.matchingSpace.modelsCorrectables, - + matchLevel: SuggestionSimilarity.none, // Long-term, we shouldn't have `.preservationTransform` here. // // Needed for now until the search actually operates based on @@ -501,6 +531,21 @@ export function buildAndMapPredictions( preservationTransform: tokenization.taillessTrueKeystroke } } + + // Iteratively add the components into the return value here. + const orderedTokens = tokenizationCorrection.matchingSpace.orderedTokens; + const reportTokenizedPrediction = (tuple: typeof predictionPrefix[0], index: number) => { + returnVal.components.push({ + prediction: tuple.prediction.sample, + correction: tuple.correction, + source: orderedTokens[index] + }); + }; + // Also gets the (changing) tail portion. + predictionPrefix.forEach((tuple, index) => reportTokenizedPrediction(tuple, index)); + reportTokenizedPrediction(tuple, orderedTokens.length - 1); + + return returnVal; }); return completePredictionTuples; @@ -565,7 +610,7 @@ export async function correctAndEnumerate( /** * The suggestions generated based on the user's input state. */ - rawPredictions: IntermediateCompositedPrediction[]; + rawPredictions: IntermediateTokenizedPrediction[]; /** * The id of a prior ContextTransition event that triggered a Suggestion found @@ -617,7 +662,7 @@ export async function correctAndEnumerate( const searchModules = tokenizations.map(t => t.tail.searchModule); // Only run the correction search when corrections are enabled. - let rawPredictions: IntermediateCompositedPrediction[] = []; + let rawPredictions: IntermediateTokenizedPrediction[] = []; let bestCorrectionCost: number; for await(const match of getBestTokenMatches(searchModules, timer)) { // Corrections obtained: now to predict from them! @@ -666,7 +711,7 @@ export async function correctAndEnumerate( export function shouldStopSearchingEarly( bestCorrectionCost: number, currentCorrectionCost: number, - rawPredictions: IntermediateCompositedPrediction[] + rawPredictions: IntermediateTokenizedPrediction[] ) { if(currentCorrectionCost >= bestCorrectionCost + CORRECTION_SEARCH_THRESHOLDS.MAX_SEARCH_THRESHOLD) { return true; @@ -707,9 +752,8 @@ export function correctAndEnumerateWithoutTraversals( lexicalModel: LexicalModel, corrections: ProbabilityMass[], context: Context -): IntermediateCompositedPrediction[] { - let returnedPredictions: IntermediateCompositedPrediction[] = []; - +): IntermediateTokenizedPrediction[] { + let returnedPredictions: IntermediateTokenizedPrediction[] = []; const wordbreak = determineModelWordbreaker(lexicalModel); const tokenizer = determineModelTokenizer(lexicalModel); @@ -720,13 +764,28 @@ export function correctAndEnumerateWithoutTraversals( // support, though. const tokenizedCorrection = mapWhitespacedTokenization(tokenization.left.map((t) => { return {exampleInput: t.text} }), lexicalModel, correction.sample).tokenizedTransform; - const deleteLeft = [...tokenizedCorrection.values()].reduce((total, curr) => total + curr.deleteLeft, 0); + const deleteLeft = tokenization.left.length > 1 ? 0 : tokenization.left.reduce((prev, curr) => prev + KMWString.length(curr.text), 0); + + const intermediateTokens: TokenizedPredictionData[] = []; + [...tokenizedCorrection.entries()].forEach((entry, index) => { + let dl = index == 0 ? deleteLeft: 0; + let text: string; + + if(index != 0) { + text = entry[1].insert; + } else { + text = wordbreak(models.applyTransform(entry[1], context)); + } - const tokenizedCorrectionEntries = [...tokenizedCorrection.entries()]; - const preservationTransform = tokenizedCorrectionEntries.slice(0, -1).map((e) => e[1]).reduce((accum, curr) => { - return models.buildMergedTransform(accum, {...curr, deleteLeft: 0}); - }, { insert: '', deleteLeft: 0, id: correction.sample.id}); - preservationTransform.deleteLeft = deleteLeft; + intermediateTokens.push({ + prediction: { + transform: { insert: text, deleteLeft: dl }, + displayAs: text + }, + correction: text, + source: null + }) + }); // Step 2: predict based on the final token. const emptyContext: Context = { @@ -735,32 +794,28 @@ export function correctAndEnumerateWithoutTraversals( endOfBuffer: true }; - const tailCorrection = tokenizedCorrectionEntries[tokenizedCorrectionEntries.length-1][1]; + const tailCorrection = { insert: intermediateTokens[intermediateTokens.length-1].correction, deleteLeft: 0}; let predictions = lexicalModel.predict(tailCorrection, emptyContext); // Step 3: create the intermediate prediction data entries for each generated prediction let predictionSet = predictions.map((pair: ProbabilityMass) => { + + // Overwrite the last entry with the prediction. + const components = [...intermediateTokens]; + + components[components.length - 1] = { + ...components[components.length - 1], + prediction: pair.sample + }; + // Let's not rely on the model to copy transform IDs. // Only bother is there IS an ID to copy. if(correction.sample.id !== undefined) { - pair.sample.transformId = correction.sample.id; + components.forEach((c) => c.prediction.transformId = correction.sample.id); } - let correctionText: string; - if(tokenizedCorrectionEntries.length != 1) { - correctionText = correction.sample.insert; - // deleteLeft: 0; it's pre-applied within preservationTransform. - } else { - // Use the deleteLeft & tokenize. - const postContext = models.applyTransform(correction.sample, context); - correctionText = wordbreak(postContext); - } - - let tuple: IntermediateCompositedPrediction = { - components: { - prediction: pair.sample, - correction: correctionText - }, + let tuple: IntermediateTokenizedPrediction = { + components, metadata: { probabilities: { prediction: pair.p, @@ -768,8 +823,7 @@ export function correctAndEnumerateWithoutTraversals( total: pair.p * correction.p }, autoSelectable: correctionValidForAutoSelect(tailCorrection.insert), - matchLevel: SuggestionSimilarity.none, - preservationTransform + matchLevel: SuggestionSimilarity.none } }; return tuple; @@ -789,18 +843,60 @@ export function correctAndEnumerateWithoutTraversals( * @param lexicalModel * @param casingForm */ -export function applySuggestionCasing(suggestion: Suggestion, baseWord: string, lexicalModel: LexicalModel, casingForm: CasingForm) { - // Step 1: does the suggestion replace the whole word? If not, we should extend the suggestion to do so. - let unchangedLength = KMWString.length(baseWord) - suggestion.transform.deleteLeft; +export function applySuggestionCasing(predictionToken: TokenizedPredictionData, lexicalModel: LexicalModel) { + const suggestion = predictionToken.prediction; - if(unchangedLength > 0) { - suggestion.transform.deleteLeft += unchangedLength; - suggestion.transform.insert = KMWString.substr(baseWord, 0, unchangedLength) + suggestion.transform.insert; + // Step 0: our pattern for generating predictions and corrections already + // enforces them to encompass the whole word. + + // Step 1: detect the original token's casing + let casingForm: CasingForm; + + // If we are using the context-tracking engine (when traversals are enabled), + // we just leverage the context token's exampleInput to determine casing. + // + // If it's not available, the correction entry reflects a word-broken piece of + // the original context, with its original casing - so we use that instead. + let casingRoot = predictionToken.source ? predictionToken.source.exampleInput : predictionToken.correction; + if(!casingRoot) { + // There's no text in place to verify casing expectations; just leave it + // unchanged. + return; } + casingForm = detectCurrentCasing(lexicalModel, { + left: casingRoot, + startOfBuffer: true, + endOfBuffer: true + }); + // Step 2: Now that the transform affects the whole word, we may safely apply casing rules. - suggestion.transform.insert = lexicalModel.applyCasing(casingForm, suggestion.transform.insert); - suggestion.displayAs = lexicalModel.applyCasing(casingForm, suggestion.displayAs); + if(casingForm && casingForm != 'lower') { + suggestion.transform.insert = lexicalModel.applyCasing(casingForm, suggestion.transform.insert); + suggestion.displayAs = lexicalModel.applyCasing(casingForm, suggestion.displayAs); + } +} + +export function compositeIntermediatePredictions(predictions: IntermediateTokenizedPrediction[]): IntermediateCompositedPrediction[] { + return predictions.map((predictionData) => { + const components = predictionData.components; + + return { + components: components.reduce((total, current) => { + const mergedTransform = models.buildMergedTransform(total.prediction.transform, current.prediction.transform); + const mergedDisplayAs = total.prediction.displayAs + current.prediction.displayAs + + return { + prediction: {...total.prediction, transform: mergedTransform, displayAs: mergedDisplayAs}, + correction: total.correction + current.correction + } + }, { + prediction: {...components[0].prediction, transform: { insert: '', deleteLeft: 0 }, displayAs: ''}, + correction: '' + }), + metadata: predictionData.metadata + }; + }); } /** diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/early-correction-search-stopping.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/early-correction-search-stopping.tests.ts index 430d9c6c7e0..9595f15527a 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/early-correction-search-stopping.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/correction-search/early-correction-search-stopping.tests.ts @@ -1,15 +1,15 @@ import { assert } from 'chai'; -import { CORRECTION_SEARCH_THRESHOLDS, IntermediateCompositedPrediction, ModelCompositor, shouldStopSearchingEarly } from "@keymanapp/lm-worker/test-index"; +import { CORRECTION_SEARCH_THRESHOLDS, IntermediateTokenizedPrediction, ModelCompositor, shouldStopSearchingEarly } from "@keymanapp/lm-worker/test-index"; -function mockIntermediatePrediction(value: number) { +function mockTokenizedPrediction(value: number) { return { metadata: { probabilities: { total: value } } - } as IntermediateCompositedPrediction + } as IntermediateTokenizedPrediction } describe('correction-search: shouldStopSearchingEarly', () => { @@ -22,7 +22,7 @@ describe('correction-search: shouldStopSearchingEarly', () => { assert.equal(predictionProbs.length, ModelCompositor.MAX_SUGGESTIONS, "test setup no longer valid"); // The only part for each entry we actually care about here: .totalProb. - const predictions = predictionProbs.map((entry) => mockIntermediatePrediction(entry)); + const predictions = predictionProbs.map((entry) => mockTokenizedPrediction(entry)); // Thresholding is performed in log-space. // 0.0501 and 0.0499 are offset on each side of 0.05, the last value in the array defined above. @@ -38,8 +38,8 @@ describe('correction-search: shouldStopSearchingEarly', () => { // // Can technically run the method with an empty array, but the actual scenario would have // at least one prediction present in the "found predictions" array. - assert.isFalse(shouldStopSearchingEarly(baseCost, baseCost + expectedThreshold - 0.01, [mockIntermediatePrediction(Math.exp(-1))])); - assert.isTrue(shouldStopSearchingEarly( baseCost, baseCost + expectedThreshold + 0.01, [mockIntermediatePrediction(Math.exp(-1))])); + assert.isFalse(shouldStopSearchingEarly(baseCost, baseCost + expectedThreshold - 0.01, [mockTokenizedPrediction(Math.exp(-1))])); + assert.isTrue(shouldStopSearchingEarly( baseCost, baseCost + expectedThreshold + 0.01, [mockTokenizedPrediction(Math.exp(-1))])); }); it('stops checking corrections earlier when enough predictions have been found', () => { @@ -48,7 +48,7 @@ describe('correction-search: shouldStopSearchingEarly', () => { // The only part for each entry we actually care about here: .totalProb. /** @type {import('#./predict-helpers.js').CorrectionPredictionTuple[]} */ - const predictions = predictionProbs.map((entry) => mockIntermediatePrediction(entry)); + const predictions = predictionProbs.map((entry) => mockTokenizedPrediction(entry)); const baseCost = 1; diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-corrections.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-corrections.tests.ts index 8234c6ba2a9..d18a2d92e43 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-corrections.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/predict-from-corrections.tests.ts @@ -114,11 +114,12 @@ describe('correctAndEnumerateWithoutTraversals', () => { const predictions = correctAndEnumerateWithoutTraversals(model, correctionDistribution, context); - predictions.forEach((entry) => assert.equal(entry.components.correction, 'Its')); + predictions.forEach((entry) => assert.equal(entry.components.length, 1)); + predictions.forEach((entry) => assert.equal(entry.components[0].correction, 'Its')); predictions.forEach((entry) => assert.equal(entry.metadata.probabilities.correction, 0.6)); predictions.sort(tupleDisplayOrderSort); - assert.sameDeepOrderedMembers(predictions.map((entry) => entry.components.prediction), dummied_suggestions); + assert.sameDeepOrderedMembers(predictions.map((entry) => entry.components[0].prediction), dummied_suggestions); assert.approximately(predictions[0].metadata.probabilities.total, 0.18 * 0.6, 0.00001); assert.approximately(predictions[1].metadata.probabilities.total, 0.02 * 0.6, 0.00001); @@ -167,12 +168,13 @@ describe('correctAndEnumerateWithoutTraversals', () => { const predictions = correctAndEnumerateWithoutTraversals(model, correctionDistribution, context); - predictions.forEach((entry) => assert.equal(entry.components.correction, 'Its')); + predictions.forEach((entry) => assert.equal(entry.components.length, 1)); + predictions.forEach((entry) => assert.equal(entry.components[0].correction, 'Its')); predictions.forEach((entry) => assert.equal(entry.metadata.probabilities.correction, 0.6)); predictions.sort(tupleDisplayOrderSort); - assert.sameOrderedMembers(predictions.map((entry) => entry.components.prediction.displayAs), ["it's", "its"]); - assert.sameDeepOrderedMembers(predictions.map((entry) => entry.components.prediction), dummied_suggestions.map((entry) => { + assert.sameOrderedMembers(predictions.map((entry) => entry.components[0].prediction.displayAs), ["it's", "its"]); + assert.sameDeepOrderedMembers(predictions.map((entry) => entry.components[0].prediction), dummied_suggestions.map((entry) => { entry = deepCopy(entry); entry.transformId = 314159; return entry; @@ -252,8 +254,9 @@ describe('correctAndEnumerateWithoutTraversals', () => { const predictions = correctAndEnumerateWithoutTraversals(model, correctionDistribution, context); predictions.sort(tupleDisplayOrderSort); - assert.sameOrderedMembers(predictions.map((entry) => entry.components.prediction.displayAs), ["is", "it's", "isn't", "its"]); - assert.sameDeepMembers(predictions.map((entry) => entry.components.prediction), dummied_suggestions.flatMap((entry) => entry)); + predictions.forEach((entry) => assert.equal(entry.components.length, 1)); + assert.sameOrderedMembers(predictions.map((entry) => entry.components[0].prediction.displayAs), ["is", "it's", "isn't", "its"]); + assert.sameDeepMembers(predictions.map((entry) => entry.components[0].prediction), dummied_suggestions.flatMap((entry) => entry)); assert.approximately(predictions[0].metadata.probabilities.total, 0.4 * 0.4, 0.00001); assert.approximately(predictions[1].metadata.probabilities.total, 0.18 * 0.6, 0.00001); diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/suggestion-casing.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/suggestion-casing.tests.ts index dd586eab646..7de4abac395 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/suggestion-casing.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/suggestion-casing.tests.ts @@ -13,7 +13,7 @@ import * as wordBreakers from '@keymanapp/models-wordbreakers'; import { jsonFixture } from '@keymanapp/common-test-resources/model-helpers.mjs'; import { LexicalModelTypes } from '@keymanapp/common-types'; -import { applySuggestionCasing, models } from '@keymanapp/lm-worker/test-index'; +import { TokenizedPredictionData, applySuggestionCasing, models } from '@keymanapp/lm-worker/test-index'; import CasingFunction = LexicalModelTypes.CasingFunction; import TrieModel = models.TrieModel; @@ -45,117 +45,137 @@ describe('applySuggestionCasing', function() { ); it('properly cases suggestions with no suggestion root', function() { - let suggestion = { - transform: { - insert: 'the', - deleteLeft: 0 + let suggestion: TokenizedPredictionData[] = [{ + prediction: { + transform: { + insert: 'the', + deleteLeft: 0 + }, + displayAs: 'the' }, - displayAs: 'the' - }; - - applySuggestionCasing(suggestion, '', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'The'); - assert.equal(suggestion.transform.insert, 'The'); - - suggestion = { - transform: { - insert: 'thE', - deleteLeft: 0 - }, - displayAs: 'thE' - }; - - applySuggestionCasing(suggestion, '', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'ThE'); - assert.equal(suggestion.transform.insert, 'ThE'); - - suggestion = { - transform: { - insert: 'the', - deleteLeft: 0 + correction: '', + source: null + }]; + + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'the'); + assert.equal(suggestion[0].prediction.transform.insert, 'the'); + + suggestion = [{ + prediction: { + transform: { + insert: 'ThE', + deleteLeft: 0 + }, + displayAs: 'ThE' }, - displayAs: 'the' - }; + correction: '', + source: null + }]; - applySuggestionCasing(suggestion, '', plainCasedModel, 'upper'); - assert.equal(suggestion.displayAs, 'THE'); - assert.equal(suggestion.transform.insert, 'THE'); + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'ThE'); + assert.equal(suggestion[0].prediction.transform.insert, 'ThE'); }); it('properly cases suggestions that fully replace the suggestion root', function() { - let suggestion = { - transform: { - insert: 'therefore', - deleteLeft: 3 + let suggestion: TokenizedPredictionData[] = [{ + prediction: { + transform: { + insert: 'therefore', + deleteLeft: 3 + }, + displayAs: 'therefore' }, - displayAs: 'therefore' - }; - - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'Therefore'); - assert.equal(suggestion.transform.insert, 'Therefore'); - - suggestion = { - transform: { - insert: 'thereFore', - deleteLeft: 3 + correction: 'The', + source: null + }]; + + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'Therefore'); + assert.equal(suggestion[0].prediction.transform.insert, 'Therefore'); + + suggestion = [{ + prediction: { + transform: { + insert: 'thereFore', + deleteLeft: 3 + }, + displayAs: 'thereFore' }, - displayAs: 'thereFore' - }; - - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'ThereFore'); - assert.equal(suggestion.transform.insert, 'ThereFore'); - - suggestion = { - transform: { - insert: 'therefore', - deleteLeft: 3 + correction: 'The', + source: null + }]; + + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'ThereFore'); + assert.equal(suggestion[0].prediction.transform.insert, 'ThereFore'); + + suggestion = [{ + prediction: { + transform: { + insert: 'therefore', + deleteLeft: 3 + }, + displayAs: 'therefore' }, - displayAs: 'therefore' - }; + correction: 'THE', + source: null + }]; - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'upper'); - assert.equal(suggestion.displayAs, 'THEREFORE'); - assert.equal(suggestion.transform.insert, 'THEREFORE'); + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'THEREFORE'); + assert.equal(suggestion[0].prediction.transform.insert, 'THEREFORE'); }); it('properly cases suggestions that do not fully replace the suggestion root', function() { - let suggestion = { - transform: { - insert: 'erefore', - deleteLeft: 1 + let suggestion: TokenizedPredictionData[] = [{ + prediction: { + transform: { + insert: 'therefore', + deleteLeft: 3 + }, + displayAs: 'therefore' }, - displayAs: 'therefore' - }; + correction: 'The', + source: null + }]; // When integrated, the 'the' string comes from a wordbreak operation on the current context. - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'Therefore'); - assert.equal(suggestion.transform.insert, 'Therefore'); - - suggestion = { - transform: { - insert: 'ereFore', - deleteLeft: 1 + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'Therefore'); + assert.equal(suggestion[0].prediction.transform.insert, 'Therefore'); + + suggestion = [{ + prediction: { + transform: { + insert: 'ThereFore', + deleteLeft: 3 + }, + displayAs: 'thereFore' }, - displayAs: 'thereFore' - }; - - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'initial'); - assert.equal(suggestion.displayAs, 'ThereFore'); - assert.equal(suggestion.transform.insert, 'ThereFore'); - - suggestion = { - transform: { - insert: 'erefore', - deleteLeft: 1 + correction: 'The', + source: null + }]; + + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'ThereFore'); + assert.equal(suggestion[0].prediction.transform.insert, 'ThereFore'); + + suggestion = [{ + prediction: { + transform: { + insert: 'therefore', + deleteLeft: 3 + }, + displayAs: 'therefore' }, - displayAs: 'therefore' - }; + correction: 'THE', + source: null + }]; - applySuggestionCasing(suggestion, 'the', plainCasedModel, 'upper'); - assert.equal(suggestion.displayAs, 'THEREFORE'); - assert.equal(suggestion.transform.insert, 'THEREFORE'); + applySuggestionCasing(suggestion[0], plainCasedModel); + assert.equal(suggestion[0].prediction.displayAs, 'THEREFORE'); + assert.equal(suggestion[0].prediction.transform.insert, 'THEREFORE'); }); }); \ No newline at end of file