-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
190 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import { describe, expect, test } from 'vitest' | ||
import { createStoredCorpus } from '../corpus' | ||
import { calculateTFIDF, createIndexForTFIDF } from './tfidf' | ||
|
||
describe('createIndexForTFIDF', () => { | ||
const corpus = createStoredCorpus([ | ||
{ docID: 1, text: `a b c c c` }, | ||
{ docID: 2, text: `b c d` }, | ||
{ docID: 3, text: `c d e` }, | ||
]) | ||
const docIDs = corpus.docs.map(({ doc: { docID } }) => docID) | ||
const tfidf = createIndexForTFIDF(corpus) | ||
|
||
test.only('term in 1 doc', () => { | ||
expect(docIDs.map(docID => tfidf('a', docID, 0))).toEqual([ | ||
calculateTFIDF({ termOccurrencesInChunk: 1, chunkTermLength: 5, totalChunks: 3, termChunkFrequency: 1 }), | ||
0, | ||
0, | ||
]) | ||
}) | ||
|
||
test('term in all docs', () => { | ||
expect(docIDs.map(docID => tfidf('c', docID, 0))).toEqual([ | ||
calculateTFIDF({ termOccurrencesInChunk: 3, chunkTermLength: 5, totalChunks: 3, termChunkFrequency: 3 }), | ||
calculateTFIDF({ termOccurrencesInChunk: 1, chunkTermLength: 3, totalChunks: 3, termChunkFrequency: 3 }), | ||
calculateTFIDF({ termOccurrencesInChunk: 1, chunkTermLength: 3, totalChunks: 3, termChunkFrequency: 3 }), | ||
]) | ||
}) | ||
|
||
test('unknown term', () => { | ||
expect(docIDs.map(docID => tfidf('x', docID, 0))).toEqual([0, 0, 0]) | ||
}) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import { ChunkIndex, DocID, StoredCorpus } from '../corpus' | ||
|
||
/**
 * TF-IDF is a way of measuring the relevance of a term to a document in a corpus. See
 * https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
 *
 * TF-IDF = TF * IDF
 * - TF = number of occurrences of term in the chunk / number of (non-unique) terms in the chunk
 * - IDF = log(number of chunks / number of chunks containing the term)
 *
 * @param term A single term; implementations throw if the string tokenizes to more than one term.
 * @param docID The document whose chunk is being scored.
 * @param chunk The index of the chunk within the document.
 * @returns The TF-IDF score (0 when the term does not occur in the chunk).
 */
export type TFIDF = (term: Term, docID: DocID, chunk: ChunkIndex) => number
|
||
/** | ||
* Index the corpus for fast computation of TF-IDF. @see {TFIDF} | ||
*/ | ||
export function createIndexForTFIDF(storage: StoredCorpus): TFIDF { | ||
/** | ||
* Document -> chunk index -> term -> number of occurrences of term in the chunk. | ||
* | ||
* "TF" in "TF-IDF" (with chunks instead of documents as the unit of analysis). | ||
*/ | ||
const termFrequency = new Map<DocID, Map<Term, number>[]>() | ||
|
||
/** | ||
* Document -> chunk index -> number of (non-unique) terms in the chunk. | ||
*/ | ||
const termLength = new Map<DocID, number[]>() | ||
|
||
/** | ||
* Term -> number of chunks containing the term. | ||
* | ||
* "DF" in "IDF" in "TF-IDF" (with chunks instead of documents as the unit of analysis). | ||
*/ | ||
const chunkFrequency = new Map<Term, number>() | ||
|
||
let totalChunks = 0 | ||
|
||
for (const { doc, chunks } of storage.docs) { | ||
const docTermFrequency: Map<Term, number>[] = new Array(chunks.length) | ||
termFrequency.set(doc.docID, docTermFrequency) | ||
|
||
const docTermLength: number[] = new Array(chunks.length) | ||
termLength.set(doc.docID, docTermLength) | ||
|
||
for (const [i, chunk] of chunks.entries()) { | ||
const chunkTerms = terms(chunk.text) | ||
|
||
// Set chunk frequencies. | ||
for (const uniqueTerm of new Set<Term>(chunkTerms).values()) { | ||
chunkFrequency.set(uniqueTerm, (chunkFrequency.get(uniqueTerm) ?? 0) + 1) | ||
} | ||
|
||
// Set term frequencies. | ||
const chunkTermFrequency = new Map<Term, number>() | ||
docTermFrequency[i] = chunkTermFrequency | ||
for (const term of chunkTerms) { | ||
chunkTermFrequency.set(term, (chunkTermFrequency.get(term) ?? 0) + 1) | ||
} | ||
|
||
// Set term cardinality. | ||
docTermLength[i] = chunkTerms.length | ||
|
||
// Increment total chunks. | ||
totalChunks++ | ||
} | ||
} | ||
|
||
return (termRaw: string, doc: DocID, chunk: ChunkIndex): number => { | ||
const processedTerms = terms(termRaw) | ||
if (processedTerms.length !== 1) { | ||
throw new Error(`term ${JSON.stringify(termRaw)} is not a single term`) | ||
} | ||
const term = processedTerms[0] | ||
|
||
const docTermLength = termLength.get(doc) | ||
if (!docTermLength) { | ||
throw new Error(`doc ${doc} not found in termLength`) | ||
} | ||
if (typeof docTermLength[chunk] !== 'number') { | ||
throw new Error(`chunk ${chunk} not found in termLength for doc ${doc}`) | ||
} | ||
|
||
const docTermFrequency = termFrequency.get(doc) | ||
if (!docTermFrequency) { | ||
throw new Error(`doc ${doc} not found in termFrequency`) | ||
} | ||
if (!(docTermFrequency[chunk] instanceof Map)) { | ||
throw new Error(`chunk ${chunk} not found in termFrequency for doc ${doc}`) | ||
} | ||
|
||
return calculateTFIDF({ | ||
termOccurrencesInChunk: docTermFrequency[chunk].get(term) ?? 0, | ||
chunkTermLength: docTermLength[chunk], | ||
totalChunks, | ||
termChunkFrequency: chunkFrequency.get(term) ?? 0, | ||
}) | ||
} | ||
} | ||
|
||
/** | ||
* Calculate TF-IDF given the formula inputs. @see {TFIDF} | ||
* | ||
* Use {@link createIndexForTFIDF} instead of calling this directly. | ||
*/ | ||
export function calculateTFIDF({ | ||
termOccurrencesInChunk, | ||
chunkTermLength, | ||
totalChunks, | ||
termChunkFrequency, | ||
}: { | ||
termOccurrencesInChunk: number | ||
chunkTermLength: number | ||
totalChunks: number | ||
termChunkFrequency: number | ||
}): number { | ||
return (termOccurrencesInChunk / chunkTermLength) * Math.log((1 + totalChunks) / (1 + termChunkFrequency)) | ||
} | ||
|
||
type Term = string | ||
|
||
/** | ||
* All terms in the text, with normalization and stemming applied. | ||
*/ | ||
function terms(text: string): Term[] { | ||
return ( | ||
text | ||
.toLowerCase() | ||
.split(/[^a-zA-Z0-9-_]+/) | ||
// TODO(sqs): get a real stemmer | ||
.map(term => term.replace(/(.*)(?:es|ed|ing|s|ed|ing)$/, '$1')) | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters