-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(tf-encoder): refactor TF encoder
- Loading branch information
Showing
4 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
import { GraphModel } from '@cm2ml/ir' | ||
import { ExecutionError, defineStructuredBatchPlugin } from '@cm2ml/plugin' | ||
|
||
import { DEFAULT_STOP_WORDS } from './stop-words' | ||
import type { ModelTerms } from './term-extraction' | ||
import { extractModelTerms } from './term-extraction' | ||
import { createTermDocumentMatrix, createTermList } from './term-frequency' | ||
|
||
/** Node facets that may serve as one half of a bi-gram term. */
export const biGramTermTypes = ['name', 'type', 'attribute'] as const
// The `| string & Record<never, never>` branch keeps IDE autocomplete for the
// known literals while still accepting arbitrary strings.
export type BiGramTermType = typeof biGramTermTypes[number] | string & Record<never, never>
|
||
/**
 * Resolved configuration of the term-frequency encoder.
 * Mirrors the parameter declarations on the `term-frequency` plugin.
 */
export interface EncoderParameters {
  namesAsTerms: boolean // encode node names as terms
  typesAsTerms: boolean // encode node types as terms
  attributesAsTerms: readonly string[] // additional attribute names to encode as terms
  tokenize: boolean // split and clean raw values into separate tokens
  stem: boolean // apply stemming to each term
  stopWords: readonly string[] // terms removed entirely from the term list
  normalizeTf: boolean // divide counts by the document's total term count
  tfIdf: boolean // weight term frequency by inverse document frequency
  frequencyCutoff: number // minimum corpus-wide frequency for a term to enter the matrix
  // bi-gram
  bigramEnabled: boolean // when enabled, bi-grams replace the single-term sources above
  bigramSeparator: string // string placed between the two bi-gram halves
  bigramFirstTerm: BiGramTermType
  bigramFirstTermAttribute: string // attribute name used when the first term is 'attribute'
  bigramSecondTerm: BiGramTermType
  bigramSecondTermAttribute: string // attribute name used when the second term is 'attribute'
}
|
||
/** Result of encoding a batch of models. */
export interface TermFrequencyEncoding {
  termDocumentMatrix: TermDocumentMatrix // rows keyed by model id, columns aligned with termList
  termList: string[] // vocabulary; column i of every row counts termList[i]
  modelIds: string[] // ids of all successfully encoded models
}
|
||
/**
 * Represents a term-document matrix where model IDs are rows and term occurrences are columns
 *
 * The matrix structure is as follows:
 * |---------|-------|-------|-------|-----|
 * | modelId1| 1 | 0 | 0 | ... |
 * | modelId2| 0 | 1 | 0 | ... |
 * | modelId3| 0 | 0 | 1 | ... |
 * | ... | ... | ... | ... | ... |
 * |---------|-------|-------|-------|-----|
 *
 * Column i of every row refers to the i-th entry of the termList produced alongside it.
 * The concrete terms can be retrieved from the termList in the encoding @see {@link TermFrequencyEncoding}
 */
export type TermDocumentMatrix = Record<string, number[]>
|
||
/**
 * Batch encoder that turns a set of models into a term-document matrix.
 *
 * Pipeline: extract terms per model -> build the global term list (vocabulary)
 * -> build the matrix with optional TF normalization and TF-IDF weighting.
 * Execution errors produced by earlier plugins are passed through untouched.
 */
export const TermFrequencyEncoder = defineStructuredBatchPlugin({
  name: 'term-frequency',
  parameters: {
    namesAsTerms: {
      type: 'boolean',
      defaultValue: true,
      description: 'Encode names as terms',
      group: 'terms',
      displayName: 'Names as terms',
    },
    typesAsTerms: {
      type: 'boolean',
      defaultValue: false,
      description: 'Encode types as terms',
      group: 'terms',
      displayName: 'Types as terms',
    },
    attributesAsTerms: {
      type: 'array<string>',
      defaultValue: [],
      description: 'Additional attributes to encode as terms',
      group: 'terms',
      displayName: 'Attributes as terms',
    },
    tokenize: {
      type: 'boolean',
      defaultValue: false,
      description: 'Split and clean terms into separate tokens',
      group: 'term-normalization',
      displayName: 'Tokenize terms',
    },
    stem: {
      type: 'boolean',
      defaultValue: false,
      description: 'Apply stemming to terms',
      group: 'term-normalization',
      displayName: 'Stem terms',
    },
    stopWords: {
      type: 'array<string>',
      defaultValue: DEFAULT_STOP_WORDS,
      description: 'List of stop words to remove from the term list',
      group: 'term-normalization',
      displayName: 'Stop words',
    },
    normalizeTf: {
      type: 'boolean',
      defaultValue: false,
      description: 'Normalize term frequency by total number of terms in the document',
      group: 'term-frequency',
      displayName: 'Normalize TF',
    },
    tfIdf: {
      type: 'boolean',
      defaultValue: false,
      description: 'Compute Term Frequency-Inverse Document Frequency (TF-IDF) scores for terms',
      group: 'term-frequency',
      displayName: 'Compute TF-IDF',
    },
    frequencyCutoff: {
      type: 'number',
      defaultValue: 0,
      description: 'Minimum frequency for a term to be included in the matrix',
      group: 'term-frequency',
      displayName: 'Frequency Cut-off',
    },
    bigramEnabled: {
      type: 'boolean',
      defaultValue: false,
      description: 'Enable bi-gram extraction',
      group: 'bi-gram',
      displayName: 'Enable Bi-grams',
    },
    bigramSeparator: {
      type: 'string',
      defaultValue: '.',
      description: 'Separator for bi-gram terms',
      group: 'bi-gram',
      displayName: 'Bi-gram Separator',
    },
    bigramFirstTerm: {
      type: 'string',
      defaultValue: biGramTermTypes[0],
      allowedValues: biGramTermTypes,
      description: 'First term of the bi-gram',
      group: 'bi-gram',
      displayName: 'Bi-gram First Term',
    },
    bigramFirstTermAttribute: {
      type: 'string',
      defaultValue: '',
      description: 'Attribute name for the first term (if "attribute" is selected)',
      group: 'bi-gram',
      displayName: 'Bi-gram First Attribute',
    },
    bigramSecondTerm: {
      type: 'string',
      defaultValue: biGramTermTypes[1],
      allowedValues: biGramTermTypes,
      description: 'Second term of the bi-gram',
      group: 'bi-gram',
      displayName: 'Bi-gram Second Term',
    },
    bigramSecondTermAttribute: {
      type: 'string',
      defaultValue: '',
      description: 'Attribute name for the second term (if "attribute" is selected)',
      group: 'bi-gram',
      displayName: 'Bi-gram Second Attribute',
    },
  },
  invoke(input: (GraphModel | ExecutionError)[], parameters: EncoderParameters) {
    // Errors are excluded from encoding here but re-emitted by createOutput.
    const models = filterValidModels(input)
    const modelTerms = extractModelTerms(models, parameters)
    const termList = createTermList(modelTerms, parameters)
    const termDocumentMatrix = createTermDocumentMatrix(modelTerms, termList, parameters)
    return createOutput(input, termDocumentMatrix, termList, modelTerms)
  },
})
|
||
function filterValidModels(input: (GraphModel | ExecutionError)[]): GraphModel[] { | ||
return input.filter((item) => item instanceof GraphModel) | ||
} | ||
|
||
function createOutput( | ||
input: (GraphModel | ExecutionError)[], | ||
termDocumentMatrix: TermDocumentMatrix, | ||
termList: string[], | ||
modelTerms: ModelTerms[], | ||
) { | ||
const modelIds = modelTerms.map(({ modelId }) => modelId) | ||
const result: TermFrequencyEncoding = { termDocumentMatrix, termList, modelIds } | ||
return input.map((item) => { | ||
if (item instanceof ExecutionError) { | ||
return item | ||
} | ||
return { | ||
data: input.length === 1 ? { modelTerms } : { modelId: item.root.id }, | ||
metadata: result, | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/**
 * Default English stop words removed during term normalization.
 *
 * Intentionally small: only high-frequency function words that carry no
 * modeling signal. Callers can override the list via the `stopWords` parameter.
 */
export const DEFAULT_STOP_WORDS = [
  'a',
  'an',
  'and',
  'are',
  'as',
  'at',
  'be',
  'by',
  'for',
  'from',
  'has',
  'he',
  'in',
  'is',
  'it',
  'its',
  'of',
  'on',
  'that',
  'the',
  'to',
  'was',
  'were',
  'will',
  'with',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import type { GraphModel, GraphNode } from '@cm2ml/ir' | ||
import { stemmer } from 'stemmer' | ||
|
||
import type { BiGramTermType, EncoderParameters } from './encoder' | ||
|
||
/** A single term harvested from a model. */
export interface ExtractedTerm {
  nodeId: string // id of the first node that produced this term
  name: string // the (normalized) term text
  occurences: number // NOTE(review): misspelling of 'occurrences', but part of the public shape
}
|
||
/** All terms of one model, keyed by the model's root id. */
export interface ModelTerms {
  modelId: string // id of the model's root node
  terms: ExtractedTerm[] // deduplicated by term name, with occurrence counts
}
|
||
export function extractModelTerms(models: GraphModel[], parameters: EncoderParameters): ModelTerms[] { | ||
return models.map((model) => { | ||
if (model.root.id === undefined) { | ||
throw new Error('Model ID is undefined') | ||
} | ||
return { | ||
modelId: model.root.id, | ||
terms: extractTerms(model, parameters), | ||
} | ||
}) | ||
} | ||
|
||
function extractTerms(model: GraphModel, parameters: EncoderParameters): ExtractedTerm[] { | ||
const termMap = new Map<string, ExtractedTerm>() | ||
|
||
for (const node of model.nodes) { | ||
if (parameters.bigramEnabled) { | ||
const firstTerm = getTermValue(node, parameters.bigramFirstTerm, parameters.bigramFirstTermAttribute) | ||
const secondTerm = getTermValue(node, parameters.bigramSecondTerm, parameters.bigramSecondTermAttribute) | ||
|
||
if (firstTerm && secondTerm) { | ||
const bigram = `${firstTerm}${parameters.bigramSeparator}${secondTerm}` | ||
updateOrAddTerm(termMap, node.id!, bigram) | ||
} | ||
} else { | ||
// process names | ||
const name = node.getAttribute('name')?.value.literal | ||
if (parameters.namesAsTerms && name) { | ||
processTerms(termMap, node.id!, name, parameters) | ||
} | ||
|
||
// process types (do not tokenize) | ||
if (parameters.typesAsTerms && node.type) { | ||
updateOrAddTerm(termMap, node.id!, node.type) | ||
} | ||
|
||
// process additional attributes | ||
for (const attr of parameters.attributesAsTerms) { | ||
const attrValue = node.getAttribute(attr)?.value.literal | ||
if (attrValue) { | ||
processTerms(termMap, node.id!, attrValue, parameters) | ||
} | ||
} | ||
} | ||
} | ||
|
||
return Array.from(termMap.values()) | ||
} | ||
|
||
function processTerms(termMap: Map<string, ExtractedTerm>, nodeId: string, value: string, parameters: EncoderParameters) { | ||
const terms = parameters.tokenize ? tokenize(value) : [value] | ||
const stopWords = new Set(parameters.stopWords) | ||
|
||
terms.forEach((term: string) => { | ||
const processedTerm = parameters.stem ? stemmer(term) : term | ||
if (!stopWords.has(processedTerm)) { | ||
updateOrAddTerm(termMap, nodeId, processedTerm) | ||
} | ||
}) | ||
} | ||
|
||
function updateOrAddTerm(termMap: Map<string, ExtractedTerm>, nodeId: string, termName: string) { | ||
if (termMap.has(termName)) { | ||
const term = termMap.get(termName)! | ||
term.occurences++ | ||
} else { | ||
termMap.set(termName, { nodeId, name: termName, occurences: 1 }) | ||
} | ||
} | ||
|
||
function tokenize(text: string): string[] { | ||
// split text into words, considering spaces and punctuation as separators | ||
const rawTokens = text.split(/[\s\p{P}]+/u) | ||
// convert to lowercase and remove empty characters | ||
return rawTokens | ||
.map((token) => token.toLowerCase().replace(/\W+/g, '')) | ||
.filter((token) => token.length > 0) | ||
} | ||
|
||
function getTermValue( | ||
node: GraphNode, | ||
termType: BiGramTermType, | ||
attributeName: string | ||
): string | undefined { | ||
if (termType === 'name') { | ||
return node.getAttribute('name')?.value.literal | ||
} else if (termType === 'type') { | ||
return node.type | ||
} else if (termType === 'attribute') { | ||
return node.getAttribute(attributeName)?.value.literal | ||
} | ||
return undefined | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import type { EncoderParameters, TermDocumentMatrix } from './encoder' | ||
import type { ExtractedTerm, ModelTerms } from './term-extraction' | ||
|
||
export function createTermList(modelTerms: ModelTerms[], parameters: EncoderParameters): string[] { | ||
const termFrequencies = new Map<string, number>() | ||
|
||
modelTerms.forEach(({ terms }) => { | ||
terms.forEach((term) => { | ||
const currentFreq = termFrequencies.get(term.name) ?? 0 | ||
termFrequencies.set(term.name, currentFreq + term.occurences) | ||
}) | ||
}) | ||
|
||
return Array.from(termFrequencies.entries()) | ||
.filter(([_, frequency]) => frequency >= parameters.frequencyCutoff) | ||
.map(([term, _]) => term) | ||
} | ||
|
||
export function createTermDocumentMatrix(modelTerms: ModelTerms[], termList: string[], parameters: EncoderParameters): TermDocumentMatrix { | ||
const matrix: TermDocumentMatrix = {} | ||
|
||
modelTerms.forEach(({ modelId, terms }) => { | ||
const termCounts = countTerms(terms) | ||
const totalTerms = terms.reduce((sum, term) => sum + term.occurences, 0) | ||
|
||
matrix[modelId] = termList.map((term) => { | ||
const termCount = termCounts[term] ?? 0 | ||
const tf = computeTf(termCount, totalTerms, parameters) | ||
|
||
if (parameters.tfIdf) { | ||
const idf = computeIdf(term, modelTerms) | ||
return computeTfIdf(tf, idf) | ||
} | ||
return tf | ||
}) | ||
}) | ||
return matrix | ||
} | ||
|
||
function countTerms(terms: ExtractedTerm[]): Record<string, number> { | ||
return terms.reduce((counts, term) => { | ||
counts[term.name] = term.occurences | ||
return counts | ||
}, {} as Record<string, number>) | ||
} | ||
|
||
// TF = term count / total number of terms in the document (if normalized) | ||
// Note: some approaches use log(1 + TF), maybe support this too? | ||
function computeTf(termCount: number, totalTerms: number, parameters: EncoderParameters): number { | ||
return parameters.normalizeTf ? termCount / totalTerms : termCount | ||
} | ||
|
||
// IDF = log(number of documents / number of documents containing the term) | ||
function computeIdf(term: string, modelTerms: ModelTerms[]): number { | ||
const documentCount = modelTerms.length | ||
const documentFrequency = modelTerms.filter((model) => | ||
model.terms.some((t) => t.name === term), | ||
).length | ||
return Math.log(documentCount / (documentFrequency || 1)) | ||
} | ||
|
||
// TF-IDF = TF * IDF | ||
function computeTfIdf(tf: number, idf: number): number { | ||
return tf * idf | ||
} |