==== tokenization-corrector.ts ====

@@ -14,9 +14,10 @@ import { ContextToken } from "./context-token.js";
import { CorrectionSearchable, PathResult } from "./correction-searchable.js";
import { ContextTokenization } from "./context-tokenization.js";
import { QuotientNodeFinalizer } from "./quotient-node-finalizer.js";
import { TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { TokenizationResult, TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { EDIT_DISTANCE_COST_SCALE } from "./distance-modeler.js";
import { MAX_EDIT_THRESHOLD_FACTOR } from "./search-quotient-spur.js";
import { TokenResultMapping } from "./token-result-mapping.js";

// PathResult needs to be generic:
// - a result for correcting a single Token - "TokenResult"?
@@ -46,7 +47,7 @@ export type TokenResult = {
* all correctable tokens, generating corrections for the full represented
* range.
*/
export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray<TokenResult>, TokenizationResultMapping> {
export class TokenizationCorrector implements CorrectionSearchable<TokenizationResult, TokenizationResultMapping> {
public readonly tokenization: ContextTokenization;
private readonly tailCorrectionLength: number;

@@ -56,6 +57,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
private _predictable?: QuotientNodeFinalizer;
private _generatedTokenResults: Map<number, TokenResult>;
private _previousResults: TokenizationResultMapping[] = [];
private _correctableCodepointLength: number = 0;

// fully private
public readonly modelsCorrectables: boolean;
@@ -65,6 +67,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
private lastTotalCost: number;
private handleHasBeenCalled: boolean = false;
private predictableMatchFound: boolean = false;
private matchableTokenCount = 0;

get currentCost(): number {
const correctable = this.selectionQueue.peek();
@@ -175,16 +178,23 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.tokenLookupMap = new Map();
let modelsCorrectables = false;

// Index 0: the first index in the range to be modeled, as split off from the main tokenization.
orderedTokens.forEach((token, index) => {
// New issue: this mangles the space IDs! We almost certainly need some
// sort of proper map to the source token.
const searchModule = new QuotientNodeFinalizer(token.searchModule, index == orderedTokens.length - 1);
this.tokenLookupMap.set(searchModule.spaceId, token);
const passesFilter = filterClosure(token);
// Index within the token subset being examined.
const passesFilter = filterClosure(token, index);
modelsCorrectables ||= passesFilter;
if(!passesFilter) {
this._uncorrectables.push(searchModule);
} else if(index == tailCorrectionLength - 1) {
return;
}

this.matchableTokenCount++;
this._correctableCodepointLength += searchModule.codepointLength;
if(index == tailCorrectionLength - 1) {
// The sole assignment case for this field. It may only be assigned for
// the final token, and only if its text is of a form considered
// correctable by the filter.
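
A note on the restructured loop above: uncorrectable tokens now exit the forEach early, so the matchable-token count and correctable codepoint length accumulate only over tokens that pass the filter. Below is a minimal self-contained sketch of that control flow; SketchToken and partitionTokens are illustrative stand-ins, not the real QuotientNodeFinalizer-based implementation.

// Sketch only: simplified stand-ins for the constructor's partitioning pass.
interface SketchToken { text: string; }

function partitionTokens(
  tokens: SketchToken[],
  passesFilter: (token: SketchToken, index: number) => boolean
) {
  const uncorrectables: SketchToken[] = [];
  let matchableTokenCount = 0;
  let correctableCodepointLength = 0;

  tokens.forEach((token, index) => {
    if (!passesFilter(token, index)) {
      uncorrectables.push(token);
      return; // early exit: skip the bookkeeping below
    }
    matchableTokenCount++;
    // Spread iterates codepoints (not UTF-16 code units), approximating the
    // role of searchModule.codepointLength in the real code.
    correctableCodepointLength += [...token.text].length;
  });

  return { uncorrectables, matchableTokenCount, correctableCodepointLength };
}
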
@@ -249,6 +259,10 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
return new TokenizationResultMapping(results, this);
}

// Counts only tokens whose current best result came from an actual model
// match (a TokenResultMapping), as opposed to a locked fallback result.
private get matchedTokenCount() {
return [...this._generatedTokenResults.values()].filter((r) => r instanceof TokenResultMapping).length;
}

// The actual method used to iteratively search for tokenization-level corrections.
handleNextNode(): PathResult<TokenizationResultMapping> {
// Notable states:
@@ -272,19 +286,24 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.handleHasBeenCalled = true;
const results = this.collateResults();
this._previousResults.push(results);
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};

// If no matchable tokens exist, there's no prediction to make; report 'none' rather than a completed result.
if(this.matchedTokenCount > 0) {
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};
} else {
return { type: 'none' };
}
}
}

this.handleHasBeenCalled = true;

const correctableToUpdate = this.selectionQueue.dequeue();
const tokenResult = correctableToUpdate?.handleNextNode();

const delistCorrectable = () => {
if(correctableToUpdate != this._predictable) {
// Lock the 'correctable' token now that either a valid correction for
@@ -298,8 +317,12 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
}

const correctionIsThePredictable = correctableToUpdate == this._predictable;

// TODO: adjust this._correctableCodepointLength when converting a token from
// correctable/predictable to uncorrectable!
if(tokenResult.type == 'none') {
if(!correctionIsThePredictable || !this.predictableMatchFound) {
this._correctableCodepointLength -= correctableToUpdate.codepointLength;
// Transition the node from 'correctable' to 'uncorrectable' - we were
// unable to find valid corrections for it.
const lockedResult = correctableToUpdate.bestExample;
@@ -359,25 +382,33 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.selectionQueue.enqueue(this._predictable);
}

const correctionResults = this.collateResults();
if(correctionResults.matchedResult.findIndex((c) => c == undefined) != -1) {
// If any token lacks a matching lookup value, abort.
if([...this.tokenLookupMap.keys()].some((k) => !this._generatedTokenResults.has(k))) {
return {
type: 'intermediate',
cost: tokenizationCost
};
}
const correctionResults = this.collateResults();

// Determine the appropriate return type and construct the corresponding return object.
//
// If the predictable yielded no result and a result was previously found,
// no further predictions can be found.
if(tokenResult.type != 'none' || !correctionIsThePredictable || !this.predictableMatchFound) {
this._previousResults.push(correctionResults);
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};

if(this.matchedTokenCount > 0) {
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};
} else {
return {
type: 'none'
}
}
} else {
return {
type: 'none'
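
With the matchedTokenCount guard in place, handleNextNode() now distinguishes three outcomes. A simplified sketch of that protocol as the tests below consume it, with payload types reduced (the real mapping is a TokenizationResultMapping):

type SketchPathResult =
  | { type: 'intermediate'; cost: number }                 // keep searching
  | { type: 'complete'; cost: number; mapping: unknown }   // a usable correction
  | { type: 'none' };                                      // nothing (further) to report

// Drive the search until exhaustion, collecting completed results - the same
// loop shape the updated unit tests use.
function drainSearch(next: () => SketchPathResult) {
  const completed: SketchPathResult[] = [];
  for (let result = next(); result.type !== 'none'; result = next()) {
    if (result.type === 'complete') {
      completed.push(result);
    }
  }
  return completed;
}
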
==== tokenization-result-mapping.ts ====

@@ -1,13 +1,23 @@
import { CorrectionResultMapping } from "./correction-result-mapping.js";
import { TokenizationCorrector, TokenResult } from './tokenization-corrector.js';

export class TokenizationResultMapping implements CorrectionResultMapping<ReadonlyArray<TokenResult>> {
export interface TokenizationResult {
tokenCorrections: ReadonlyArray<TokenResult>,
totalEditCount: number,
totalEditableCodepoints: number
}

export class TokenizationResultMapping implements CorrectionResultMapping<TokenizationResult> {
readonly matchingSpace: TokenizationCorrector;
readonly matchedResult: ReadonlyArray<TokenResult>;
readonly matchedResult: TokenizationResult;

constructor(tokenization: TokenResult[], corrector: TokenizationCorrector) {
this.matchingSpace = corrector;
this.matchedResult = tokenization;
this.matchedResult = {
tokenCorrections: tokenization,
totalEditCount: tokenization.reduce((accum, curr) => accum + curr.knownCost, 0),
totalEditableCodepoints: 0 // TODO: wire to the corrector's correctable-codepoint length.
}
}

get spaceId(): number {
@@ -22,7 +32,7 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
// * `totalCost`.)
// */
// get knownCost(): number {
// return this.node.editCount;
// return this.matchedResult.tokenCorrections.reduce((accum, curr) => accum + curr.knownCost, 0);
// }

// /**
@@ -40,6 +50,6 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
* to the resulting output.
*/
get totalCost(): number {
return this.matchedResult.reduce((total, curr) => total + curr.totalCost, 0);
return this.matchedResult.tokenCorrections.reduce((total, curr) => total + curr.totalCost, 0);
}
}
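
The TokenizationResult shape introduced above pairs per-token corrections with aggregate edit statistics. A compact sketch of the aggregation performed by the constructor and the totalCost getter, assuming a pared-down TokenResult shape:

// Sketch only: SketchTokenResult is a stand-in for the real TokenResult type.
interface SketchTokenResult { matchString: string; knownCost: number; totalCost: number; }

function summarizeCorrections(tokenCorrections: ReadonlyArray<SketchTokenResult>) {
  return {
    tokenCorrections,
    // Sum of per-token confirmed edit counts.
    totalEditCount: tokenCorrections.reduce((accum, curr) => accum + curr.knownCost, 0),
    // Sum of full correction costs; mirrors TokenizationResultMapping.totalCost.
    totalCost: tokenCorrections.reduce((total, curr) => total + curr.totalCost, 0),
  };
}

// summarizeCorrections([{ matchString: 'theref', knownCost: 1, totalCost: 3 }])
// yields totalEditCount == 1 and totalCost == 3.
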
==== (prediction-search helpers; filename not shown) ====

@@ -410,15 +410,15 @@ export function buildAndMapPredictions(
endOfBuffer: false
};

const correctionTransforms = tokenizationCorrection.matchedResult.map((correction, i) => {
const correctionTransforms = tokenizationCorrection.matchedResult.tokenCorrections.map((correction, i) => {
return {
insert: correction.matchString, // insert correction string
deleteLeft: i == 0 ? deleteLeft : 0,
id: transition.transitionId // The correction should always be based on the most recent external transform/transcription ID.
};
});

const correctionCost = tokenizationCorrection.matchedResult.map((correction) => {
const correctionCost = tokenizationCorrection.matchedResult.tokenCorrections.map((correction) => {
let rootCost = correction.totalCost;
/* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost
* the exponent to ensure only VERY nearby corrections have a chance of winning, and only if
@@ -452,7 +462,16 @@
}).reduce((accum, curr) => accum * curr, 1);

const predictionComponents = correctionTransforms.map((correctionTransform, i) => {
const predictions = model.predict(correctionTransform, emptyContext);
let predictions = model.predict(correctionTransform, emptyContext);

// For non-tail tokens, each prediction's codepoint length must equal the
// token's codepointLength; filter out predictions that do not conform.
if(i != correctionTransforms.length - 1) {
predictions = predictions.filter((p) => {
const codepointLength = tokenizationCorrection.matchingSpace.orderedTokens[i].searchModule.codepointLength;
return KMWString.length(p.sample.transform.insert) == codepointLength;
});
}

// Failsafe: if there are no matching predictions, create a fake prediction
// matching the original text.
@@ -577,7 +586,8 @@ export function prepareTokenizationSearch(
return new TokenizationCorrector(tuple.tokenization, mutatedLength, (token, index) => {
return index >= unaffectedTokenCount // is a modified token
&& index == mutatedLength - 1 // TEMP: adjacent to the caret (TO BE REMOVED)
&& correctionValidForAutoSelect(token.exampleInput); // and is eligible text-correction
// and is eligible for text-correction
&& (token.searchModule.codepointLength == 0 || correctionValidForAutoSelect(token.exampleInput));
});
});

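
The prediction filter added above keeps, for every non-tail token, only predictions whose inserted text preserves that token's codepoint length; only the still-being-typed tail token may change length. A self-contained sketch of the same check, approximating KMWString.length with a codepoint-aware spread (SketchPrediction is a stand-in for the real prediction type):

interface SketchPrediction { insert: string; }

function filterByCodepointLength(
  predictions: SketchPrediction[],
  tokenCodepointLength: number,
  isTail: boolean
): SketchPrediction[] {
  // Tail tokens are still being typed, so their length may legitimately change.
  if (isTail) {
    return predictions;
  }
  // [...s] iterates codepoints, not UTF-16 code units.
  return predictions.filter((p) => [...p.insert].length === tokenCodepointLength);
}
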
==== (TokenizationCorrector test suite; filename not shown) ====

@@ -29,7 +29,8 @@ import {
SubstitutionQuotientSpur,
TokenizationCorrector,
TokenResult,
TokenizationResultMapping
TokenizationResultMapping,
TokenizationResult
} from '@keymanapp/lm-worker/test-index';

import Distribution = LexicalModelTypes.Distribution;
@@ -302,7 +303,7 @@ describe('TokenizationCorrector', () => {
assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
assert.equal(tokenResults.length, 1);
@@ -327,7 +328,7 @@ describe('TokenizationCorrector', () => {
assert.equal(searchResult.type, 'none');
});

it('finds a default correction for a single correctable token without a model match', () => {
it('returns no result when a single correctable token lacks a model match', () => {
const fixture = buildFixture_therefore();

const theref = fixture.theref.tail;
@@ -371,23 +372,6 @@ describe('TokenizationCorrector', () => {
searchResult = instance.handleNextNode();
} while(searchResult.type == 'intermediate');

assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
assert.equal(tokenResults.length, 1);
assert.sameOrderedMembers(tokenResults.map((r) => r.matchString), ['therefxyz']);

// Now that an entry has been found, verify the corrector's state.
assert.isNotOk(instance.predictableToken); // should become an uncorrectable.
assert.isTrue(instance.generatedTokenResults.has(therefxyz));
assert.equal(instance.generatedTokenResults.get(therefxyz), tokenResults[0]);
}

// There should be no further possible suggestions.
searchResult = instance.handleNextNode();
assert.equal(searchResult.type, 'none');
});

@@ -411,7 +395,7 @@ describe('TokenizationCorrector', () => {
let firstResults: ReadonlyArray<TokenResult>;
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;
firstResults = tokenResults;
assert.isNotNaN(searchResult.cost);
assert.equal(searchResult.cost, searchResult.mapping.totalCost);
@@ -434,7 +418,7 @@ describe('TokenizationCorrector', () => {
searchResult = instance.handleNextNode();
if(searchResult.type == 'complete') {
const mapping = searchResult.mapping;
const tokenResults = mapping.matchedResult;
const tokenResults = mapping.matchedResult.tokenCorrections;

// Verify that the first (bound) token is not altered further.
// It should receive no further correction attempts.
@@ -445,7 +429,7 @@ describe('TokenizationCorrector', () => {
} while(searchResult.type != 'none');
});

it('immediately returns a single result when the only represented token is uncorrectable', () => {
it('immediately returns with no result when the only represented token is uncorrectable', () => {
const fixture = buildFixture_terminalWhitespace();

const tokenization = fixture.spaceOnly;
@@ -457,13 +441,7 @@ describe('TokenizationCorrector', () => {
);

const searchResult = instance.handleNextNode();
assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
assert.equal(searchResult.mapping.matchedResult[0].matchString, ' ');
}

const nilResult = instance.handleNextNode();
assert.equal(nilResult.type, 'none');
assert.equal(searchResult.type, 'none');
});

it('returns a single result when the final token is uncorrectable', () => {
Expand All @@ -484,8 +462,8 @@ describe('TokenizationCorrector', () => {

assert.equal(searchResult.type, 'complete');
if(searchResult.type == 'complete') {
assert.equal(searchResult.mapping.matchedResult[0].matchString, 'space');
assert.equal(searchResult.mapping.matchedResult[1].matchString, ' ');
assert.equal(searchResult.mapping.matchedResult.tokenCorrections[0].matchString, 'space');
assert.equal(searchResult.mapping.matchedResult.tokenCorrections[1].matchString, ' ');
}

const nilResult = instance.handleNextNode();
@@ -502,20 +480,20 @@ describe('TokenizationCorrector', () => {
let haveSeenSingleTokenCorrection = false;
let haveSeenThreeTokenCorrection = false;
for await(let phraseMatch of getBestMatches<
ReadonlyArray<TokenResult>,
TokenizationResult,
TokenizationResultMapping,
TokenizationCorrector
>(correctors, buildTestTimer())) {

if(phraseMatch.matchedResult.length == 1) {
if(phraseMatch.matchedResult.tokenCorrections.length == 1) {
if(!haveSeenSingleTokenCorrection) {
assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['theref' /* -ore */]);
assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['theref' /* -ore */]);
}

haveSeenSingleTokenCorrection = true;
} else if(phraseMatch.matchedResult.length == 3) {
} else if(phraseMatch.matchedResult.tokenCorrections.length == 3) {
if(!haveSeenThreeTokenCorrection) {
assert.sameOrderedMembers(phraseMatch.matchedResult.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]);
assert.sameOrderedMembers(phraseMatch.matchedResult.tokenCorrections.map((t) => t.matchString), ['the', ' ', 'ef' /* -fort */]);
}
haveSeenThreeTokenCorrection = true;
}
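
Several of the revised expectations above hinge on the new matchedTokenCount guard: when no token result comes from a genuine model match, the corrector now reports { type: 'none' } instead of fabricating a 'complete' result. A stand-in sketch of that count; BaseResult and MatchedResult are hypothetical substitutes for TokenResult and TokenResultMapping:

class BaseResult {}                        // stand-in for a locked/fallback TokenResult
class MatchedResult extends BaseResult {}  // stand-in for TokenResultMapping

function countMatchedTokens(results: Map<number, BaseResult>): number {
  // Only results produced by a genuine model match count toward a
  // reportable prediction.
  return [...results.values()].filter((r) => r instanceof MatchedResult).length;
}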