Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
readonly matchingSpace: TokenizationCorrector;
readonly matchedResult: ReadonlyArray<TokenResult>;

constructor(tokenization: TokenResult[], corrector: TokenizationCorrector) {
constructor(tokenization: TokenResult[], corrector?: TokenizationCorrector) {
this.matchingSpace = corrector;
this.matchedResult = tokenization;
}

get spaceId(): number {
return this.matchingSpace.tokenization.spaceId;
return this.matchingSpace?.tokenization.spaceId;
}

// /**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import { ExecutionTimer } from './correction/execution-timer.js';
import { ModelCompositor } from './model-compositor.js';
import { EDIT_DISTANCE_COST_SCALE, getBestTokenMatches } from './correction/distance-modeler.js';
import { TokenResult } from './correction/tokenization-corrector.js';
import { TokenizationCorrector } from './correction/tokenization-corrector.js';
import { TokenizationResultMapping } from './correction/tokenization-result-mapping.js';

import CasingForm = LexicalModelTypes.CasingForm;
import Context = LexicalModelTypes.Context;
Expand Down Expand Up @@ -233,7 +235,8 @@ export function determineTraversallessCorrectionSequences(
return match;
});

const suggestionParams = buildCorrectionSequence(transitionEffects, context, correctionRoots[correctionRoots.length - 1]);
// But, for now, only actually use the last one.
const suggestionParams = buildCorrectionSequence(transitionEffects, context, new TokenizationResultMapping([correctionRoots[correctionRoots.length - 1]], null));

const tokenizationMapping = mapWhitespacedTokenization(tokenization.left.map((t) => { return {exampleInput: t.text, codepointLength: KMWString.length(t.text)} }), lexicalModel, correction.sample);
const tokenizedCorrection = tokenizationMapping.tokenizedTransform;
Expand Down Expand Up @@ -377,12 +380,8 @@ export function determineSuggestionRange<T extends ContextTokenLike>(
return temp(a, b);
}

const deleteLeftCalc = (tokenSet: T[], predictCount: number) => {
// TODO: once we start activating multi-tokenization for real, only the
// 'reduce' component should remain.
return (predictCount > 1)
? (tokenSet[tokenSet.length - 1]?.codepointLength ?? 0)
: tokenSet.reduce((prev, curr) => prev + curr.codepointLength, 0);
const deleteLeftCalc = (tokenSet: T[]) => {
return tokenSet.reduce((prev, curr) => prev + curr.codepointLength, 0);
}

const tokenSetA = userContextTokenization.slice();
Expand All @@ -396,7 +395,7 @@ export function determineSuggestionRange<T extends ContextTokenLike>(
return {
tokensToRemove: tokenSetA,
tokensToPredict: tokenSetB,
deleteLeft: deleteLeftCalc(tokenSetA, tokenSetB.length)
deleteLeft: deleteLeftCalc(tokenSetA)
}
} else if(aHeadIndexInB != 0 && bHeadIndexInA != 0) {
throw new Error("Leading edge of context should not differ in both tokenizations.");
Expand All @@ -422,7 +421,7 @@ export function determineSuggestionRange<T extends ContextTokenLike>(
return {
tokensToRemove,
tokensToPredict,
deleteLeft: deleteLeftCalc(tokensToRemove, tokensToPredict.length)
deleteLeft: deleteLeftCalc(tokensToRemove)
}
}

Expand Down Expand Up @@ -456,44 +455,49 @@ export interface PredictionParameters {
}

export function buildCorrectionSequence(
transitionEffects: ReturnType<typeof determineSuggestionRange>,
transitionEffects: SuggestionReplacement<ContextTokenLike>,
context: Context,
match: Readonly<TokenResult>,
tokenizationCorrection: TokenizationResultMapping
) {
const { deleteLeft } = transitionEffects;

const rootContext = models.applyTransform({insert: '', deleteLeft}, context);

// Replace the existing context with the correction.
const correctionTransform: Transform = {
insert: match.matchString, // insert correction string
deleteLeft: 0,
}
const tokenizedCorrections = tokenizationCorrection.matchedResult.map((correction, i) => {
/* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost
* the exponent to ensure only VERY nearby corrections have a chance of winning, and only if
* there are significantly more likely words. We only need this to allow very minor fat-finger
* adjustments for 100% keystroke-sequence corrections in order to prevent finickiness on
* key borders.
*
* Technically, the probabilities this produces won't be normalized as-is... but there's no
* true NEED to do so for it, even if it'd be 'nice to have'. Consistently tracking when
* to apply it could become tricky, so it's simpler to leave out.
*
* Worst-case, it's possible to temporarily add normalization if a code deep-dive
* is needed in the future.
*/
const costFactor = (correction.inputCount <= 1) ? ModelCompositor.SINGLE_CHAR_KEY_PROB_EXPONENT : 1;

const entry = {
sample: {
insert: correction.matchString, // insert correction string
deleteLeft: 0,
} as Transform,
p: Math.exp(-correction.totalCost * costFactor)
};

/* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost
* the exponent to ensure only VERY nearby corrections have a chance of winning, and only if
* there are significantly more likely words. We only need this to allow very minor fat-finger
* adjustments for 100% keystroke-sequence corrections in order to prevent finickiness on
* key borders.
*
* Technically, the probabilities this produces won't be normalized as-is... but there's no
* true NEED to do so for it, even if it'd be 'nice to have'. Consistently tracking when
* to apply it could become tricky, so it's simpler to leave out.
*
* Worst-case, it's possible to temporarily add normalization if a code deep-dive
* is needed in the future.
*/
const costFactor = (match.inputCount <= 1) ? ModelCompositor.SINGLE_CHAR_KEY_PROB_EXPONENT : 1;

const rootCost = match.totalCost;
const predictionRoot = {
sample: correctionTransform,
p: Math.exp(-rootCost * costFactor)
};
if(transitionEffects.transitionId !== undefined) {
entry.sample.id = transitionEffects.transitionId;
}

return entry;
});

return {
rootContext,
tokenizedCorrection: [predictionRoot]
tokenizedCorrection: tokenizedCorrections
};
}

Expand All @@ -512,10 +516,11 @@ export function buildCorrectionSequence(
export function determineTokenizedCorrectionSequence(
transition: ContextTransition,
tokenization: ContextTokenization,
match: Readonly<TokenResult>
match: TokenizationResultMapping
): PredictionParameters {
const applicationTarget = transition.base.displayTokenization;
const transitionParams = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId);
transitionParams.transitionId = transition.transitionId;

const suggestionParams = buildCorrectionSequence(transitionParams, transition.base.context, match);

Expand Down Expand Up @@ -641,7 +646,11 @@ export async function correctAndEnumerate(
continue;
}

const predictionPrep = determineTokenizedCorrectionSequence(transition, tokenization, match);
const suggestionRange = determineSuggestionRange(transition.base.displayTokenization.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId);
suggestionRange.transitionId = transition.transitionId;
const corrector = new TokenizationCorrector(tokenization, suggestionRange.tokensToPredict.length, () => true);
const predictionPrep = determineTokenizedCorrectionSequence(transition, tokenization, new TokenizationResultMapping([match], corrector));

Comment on lines +649 to +653

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This moves the main correction loop one step closer to integrating multi-token capabilities. Within this PR, though, it is only a stop-gap.

const predictions = predictFromCorrectionSequence(lexicalModel, predictionPrep.tokenizedCorrection, predictionPrep.rootContext, transition.transitionId);
predictions.forEach((p) => predictionPrep.applyInPost(p));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@ import * as wordBreakers from '@keymanapp/models-wordbreakers';
import { jsonFixture } from '@keymanapp/common-test-resources/model-helpers.mjs';
import { KMWString } from 'keyman/common/web-utils';

import { determineTokenizedCorrectionSequence, models, ContextState, ContextToken, ContextTokenization, CorrectionPredictionTuple, ModelCompositor } from "@keymanapp/lm-worker/test-index";
import {
determineTokenizedCorrectionSequence,
models,
ContextState,
ContextToken,
ContextTokenization,
CorrectionPredictionTuple,
ModelCompositor,
TokenizationResultMapping
} from "@keymanapp/lm-worker/test-index";

import Context = LexicalModelTypes.Context;
import ProbabilityMass = LexicalModelTypes.ProbabilityMass;
Expand Down Expand Up @@ -51,13 +60,14 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: 'fo',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 2,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
Expand Down Expand Up @@ -101,13 +111,14 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: ' ',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 1,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
Expand Down Expand Up @@ -148,13 +159,14 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: 'f',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 1,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
Expand Down Expand Up @@ -202,13 +214,14 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: 'can\'t',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 5,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
Expand Down Expand Up @@ -253,13 +266,14 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: ' ',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 1,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
Expand Down Expand Up @@ -300,26 +314,26 @@ describe('determineTokenizedCorrectionSequence', () => {

const results = determineTokenizedCorrectionSequence(
transition,
transition.final.displayTokenization, {
transition.final.displayTokenization,
new TokenizationResultMapping([{
matchString: 'd',
inputSamplingCost: -Math.log(trueInput.p),
inputCount: 1,
knownCost: 0,
totalCost: -Math.log(trueInput.p)
}
}], null)
);

// Large-scale deletions will receive enhanced handling soon. But, for now, it's
// deleted by the `preservationTransform`, not here.
assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, {
casingForm: undefined,
left: 'the quick brown ',
left: 'the ',
right: '',
startOfBuffer: true,
endOfBuffer: true
});


// Coming up next - actually providing ALL correction elements, not just the final one.
// We're not _quite_ ready for that yet, though.
assert.equal(results.tokenizedCorrection.length, 1);
assert.deepEqual(results.tokenizedCorrection[0].sample, {
insert: 'd',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,15 +361,15 @@ describe('determineTraversallessCorrectionSequences', () => {
...entry.rootContext, casingForm: entry.rootContext.casingForm ?? undefined
}, {
casingForm: undefined,
// Large-scale deletions will receive enhanced handling soon. But, for now, it's
// deleted by the `preservationTransform`, not here.
left: 'the quick brown ',
left: 'the ',
right: '',
startOfBuffer: true,
endOfBuffer: true
}
);

// Coming up next - actually providing ALL correction elements, not just the final one.
// We're not _quite_ ready for that yet, though.
assert.equal(entry.tokenizedCorrection.length, 1);
assert.deepEqual(entry.tokenizedCorrection[0].sample, {
insert: 'd',
Expand Down
Loading