Skip to content

Commit

Permalink
feat(bop): only compile templates once
Browse files Browse the repository at this point in the history
  • Loading branch information
DerYeger committed Sep 23, 2024
1 parent af1e743 commit ef9aa97
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 82 deletions.
12 changes: 8 additions & 4 deletions packages/encoder/bag-of-paths-encoder/src/bop-types.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import type { GraphEdge, GraphNode } from '@cm2ml/ir'

import type { StepWeighting, Template } from './templates/model'

export const pathWeightTypes = ['step-sum', 'length', 'step-product'] as const

export type PathWeight = typeof pathWeightTypes[number] | string & Record<never, never>
Expand All @@ -13,10 +17,10 @@ export interface PathParameters {
maxPaths: number
allowCycles: boolean
order: SortOrder
stepWeighting: readonly string[]
}

export interface BoPEncodingParameters {
nodeTemplates: readonly string[]
edgeTemplates: readonly string[]
/**
 * Bundle of already-compiled step weightings and node/edge templates.
 *
 * Produced once from the raw string parameters so that consumers receive
 * ready-to-use compiled objects instead of re-parsing the strings themselves.
 */
export interface CompiledTemplates {
/** Compiled step weightings. */
stepWeighting: StepWeighting[]
/** Compiled templates applied to graph nodes. */
nodeTemplates: Template<GraphNode>[]
/** Compiled templates applied to graph edges. */
edgeTemplates: Template<GraphEdge>[]
}
16 changes: 10 additions & 6 deletions packages/encoder/bag-of-paths-encoder/src/encoding.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import type { GraphEdge, GraphModel, GraphNode, ModelMember } from '@cm2ml/ir'
import { Stream } from '@yeger/streams'

import type { BoPEncodingParameters } from './bop-types'
import type { CompiledTemplates } from './bop-types'
import { MultiCache } from './multi-cache'
import type { PathData } from './paths'
import type { PathContext, Template } from './templates/model'
import { compileNodeTemplate, compileEdgeTemplate } from './templates/parser'

export type EncodedModelMember = string | null

Expand All @@ -19,14 +18,19 @@ export interface EncodedPath {
type NodeCache = MultiCache<string, GraphNode, readonly [number, EncodedModelMember]>
type EdgeCache = MultiCache<string, GraphEdge, EncodedModelMember>

export function encodePaths(paths: PathData[], model: GraphModel, parameters: BoPEncodingParameters): EncodedPath[] {
const compiledNodeTemplates = parameters.nodeTemplates.map((template) => compileNodeTemplate(template))
const compiledEdgeTemplates = parameters.edgeTemplates.map((template) => compileEdgeTemplate(template))
export function encodePaths(paths: PathData[], model: GraphModel, compiledTemplates: Omit<CompiledTemplates, 'stepWeighting'>): EncodedPath[] {
const nodeCache: NodeCache = new MultiCache()
const edgeCache: EdgeCache = new MultiCache()
const indexMap = new Map([...model.nodes].map((node, i) => [node, i]))
const getNodeIndex = (node: GraphNode) => indexMap.get(node)!
return paths.map((path) => encodePath(path, getNodeIndex, compiledNodeTemplates, compiledEdgeTemplates, nodeCache, edgeCache))
return paths.map((path) => encodePath(
path,
getNodeIndex,
compiledTemplates.nodeTemplates,
compiledTemplates.edgeTemplates,
nodeCache,
edgeCache,
))
}

function encodePath(path: PathData, getNodeIndex: (node: GraphNode) => number, nodeTemplates: Template<GraphNode>[], edgeTemplates: Template<GraphEdge>[], nodeCache: NodeCache, edgeCache: EdgeCache): EncodedPath {
Expand Down
83 changes: 51 additions & 32 deletions packages/encoder/bag-of-paths-encoder/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,60 @@
import type { GraphModel } from '@cm2ml/ir'
import { batchTryCatch, definePlugin } from '@cm2ml/plugin'
import { ExecutionError, batchTryCatch, compose, definePlugin, defineStructuredBatchPlugin } from '@cm2ml/plugin'
import { Stream } from '@yeger/streams'

import type { CompiledTemplates } from './bop-types'
import { pathWeightTypes, sortOrders } from './bop-types'
import { encodePaths } from './encoding'
import { collectPaths } from './paths'
import type { PruneMethod } from './prune'
import { pruneMethods, prunePaths } from './prune'
import { compileEdgeTemplate, compileNodeTemplate, compileStepWeighting } from './templates/parser'

export type { PathWeight } from './bop-types'
export { pathWeightTypes }
export type { EncodedModelMember, EncodedPath } from './encoding'

/**
 * Batch plugin that compiles the step weighting, node template, and edge
 * template strings exactly once per batch.
 *
 * Every batch item (including `ExecutionError` entries, which are passed
 * through untouched as `data`) is paired with the same compiled
 * `CompiledTemplates` object as `metadata`.
 */
const TemplateCompiler = defineStructuredBatchPlugin({
  name: 'template-compiler',
  parameters: {
    stepWeighting: {
      type: 'list<string>',
      unique: true,
      ordered: true,
      defaultValue: ['1'],
      description: 'Custom weighting strategies',
      group: 'Weighting',
      helpText: __GRAMMAR,
    },
    nodeTemplates: {
      type: 'list<string>',
      unique: true,
      ordered: true,
      defaultValue: ['{{name}}.{{type}}'],
      description: 'Template for encoding nodes of paths',
      group: 'Encoding',
      helpText: __GRAMMAR,
    },
    edgeTemplates: {
      type: 'list<string>',
      unique: true,
      ordered: true,
      defaultValue: ['{{tag}}'],
      description: 'Template for encoding edges of paths',
      group: 'Encoding',
      helpText: __GRAMMAR,
    },
  },
  invoke: (batch: (GraphModel | ExecutionError)[], parameters) => {
    // Call each compiler with the template string only. Passing the function
    // reference straight to `map` would also forward `(index, array)`, which
    // silently changes behavior if a compiler ever gains an optional parameter.
    const metadata: CompiledTemplates = {
      stepWeighting: parameters.stepWeighting.map((weighting) => compileStepWeighting(weighting)),
      nodeTemplates: parameters.nodeTemplates.map((template) => compileNodeTemplate(template)),
      edgeTemplates: parameters.edgeTemplates.map((template) => compileEdgeTemplate(template)),
    }
    // The same compiled object is shared by every item of the batch.
    return batch.map((data) => ({ data, metadata }))
  },
})

const PathBuilder = definePlugin({
name: 'path-builder',
parameters: {
Expand Down Expand Up @@ -54,44 +97,20 @@ const PathBuilder = definePlugin({
description: 'Ordering of paths according to their weight',
group: 'Weighting',
},
stepWeighting: {
type: 'list<string>',
unique: true,
ordered: true,
defaultValue: ['1'],
description: 'Custom weighting strategies',
group: 'Weighting',
helpText: __GRAMMAR,
},
pathWeight: {
type: 'string',
allowedValues: pathWeightTypes,
defaultValue: pathWeightTypes[0],
description: 'Weighting strategy for paths',
group: 'Weighting',
},
nodeTemplates: {
type: 'list<string>',
unique: true,
ordered: true,
defaultValue: ['{{name}}.{{type}}'],
description: 'Template for encoding nodes of paths',
group: 'Encoding',
helpText: __GRAMMAR,
},
edgeTemplates: {
type: 'list<string>',
unique: true,
ordered: true,
defaultValue: ['{{tag}}'],
description: 'Template for encoding edges of paths',
group: 'Encoding',
helpText: __GRAMMAR,
},
},
invoke: (data: GraphModel, parameters) => {
const rawPaths = collectPaths(data, parameters)
const encodedPaths = encodePaths(rawPaths, data, parameters)
invoke: ({ data, metadata }: { data: GraphModel | ExecutionError, metadata: CompiledTemplates }, parameters) => {
if (data instanceof ExecutionError) {
return data
}
const rawPaths = collectPaths(data, parameters, metadata.stepWeighting)
const encodedPaths = encodePaths(rawPaths, data, metadata)
const mapping = Stream
.from(data.nodes)
.map((node) => node.requireId())
Expand All @@ -106,4 +125,4 @@ const PathBuilder = definePlugin({
},
})

export const BagOfPathsEncoder = batchTryCatch(PathBuilder, 'bag-of-paths')
/**
 * Public bag-of-paths encoder plugin, registered under the name 'bag-of-paths'.
 *
 * Built by composing `TemplateCompiler` with the `batchTryCatch`-wrapped `PathBuilder`.
 * NOTE(review): presumably `compose` runs the template compiler first and feeds its
 * `{ data, metadata }` output into the path builder — confirm against `@cm2ml/plugin`.
 */
export const BagOfPathsEncoder = compose(TemplateCompiler, batchTryCatch(PathBuilder), 'bag-of-paths')
6 changes: 2 additions & 4 deletions packages/encoder/bag-of-paths-encoder/src/paths.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import type { PathParameters, PathWeight } from './bop-types'
import { validatePathParameters } from './bop-validationts'
import { MultiCache } from './multi-cache'
import type { PathContext, StepWeighting } from './templates/model'
import { compileStepWeighting } from './templates/parser'

export interface StepData {
node: GraphNode
Expand Down Expand Up @@ -33,16 +32,15 @@ function pathOrder(order: 'asc' | 'desc' | string) {

type WeightCache = MultiCache<string, GraphEdge, number>

export function collectPaths(model: GraphModel, parameters: PathParameters) {
export function collectPaths(model: GraphModel, parameters: PathParameters, stepWeightings: StepWeighting[]): Omit<PathData, 'encodedSteps'>[] {
validatePathParameters(parameters)
const compiledStepWeighting = parameters.stepWeighting.map((weighting) => compileStepWeighting(weighting))
const nodes = Stream.from(model.nodes)
const weightCache: WeightCache = new MultiCache()
const paths = nodes
.flatMap((node) => Path.from(node, parameters))
.filter((path) => path.steps.length >= parameters.minPathLength)
.map<Omit<PathData, 'encodedSteps'>>((path) => {
const stepWeights = getStepWeights(path.steps, compiledStepWeighting, weightCache)
const stepWeights = getStepWeights(path.steps, stepWeightings, weightCache)
return {
steps: [{ node: path.startNode, via: undefined }, ...path.steps.map((step) => ({ node: step.target, via: step }))],
stepWeights,
Expand Down
43 changes: 15 additions & 28 deletions packages/encoder/bag-of-paths-encoder/test/paths.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { describe, expect, it } from 'vitest'

import type { PathData } from '../src/paths'
import { collectPaths } from '../src/paths'
import { compileStepWeighting } from '../src/templates/parser'

import { createTestModel } from './test-utils'

Expand All @@ -16,11 +17,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths)).toMatchInlineSnapshot(`
[
{
Expand Down Expand Up @@ -87,11 +87,10 @@ describe('paths', () => {
allowCycles: true,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths).map(({ steps }) => steps)).toMatchInlineSnapshot(`
[
[
Expand Down Expand Up @@ -149,11 +148,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths).map(({ steps }) => steps)).toMatchInlineSnapshot(`
[
[
Expand Down Expand Up @@ -194,11 +192,10 @@ describe('paths', () => {
allowCycles: true,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths).map(({ steps }) => steps)).toMatchInlineSnapshot(`
[
[
Expand Down Expand Up @@ -257,11 +254,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: 1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths)).toMatchInlineSnapshot(`
[
{
Expand All @@ -286,11 +282,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 2,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths)).toMatchInlineSnapshot(`
[
{
Expand Down Expand Up @@ -327,11 +322,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 1,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths)).toMatchInlineSnapshot(`
[
{
Expand Down Expand Up @@ -374,11 +368,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 0,
maxPathLength: 0,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths)).toMatchInlineSnapshot(`
[
{
Expand Down Expand Up @@ -422,11 +415,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 2,
maxPathLength: 3,
stepWeighting: ['2'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('2')])
expect(snapshottify(paths).map(({ stepWeights }) => stepWeights)).toEqual([[2, 2]])
})

Expand All @@ -436,11 +428,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 2,
maxPathLength: 4,
stepWeighting: ['@path.step = 1 >>> 2', '@path.step = 2 >>> 2.5', '@path.step = 3 >>> -42', '@path.step = 4 >>> -7,7'],
pathWeight: 'length',
maxPaths: 1,
order: 'desc',
})
}, ['@path.step = 1 >>> 2', '@path.step = 2 >>> 2.5', '@path.step = 3 >>> -42', '@path.step = 4 >>> -7,7'].map(compileStepWeighting))
expect(snapshottify(paths).map(({ stepWeights }) => stepWeights)).toEqual([[2, 2.5, -42, -7.7]])
})

Expand All @@ -450,11 +441,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 2,
maxPathLength: 3,
stepWeighting: [],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [])
expect(snapshottify(paths).map(({ stepWeights }) => stepWeights)).toEqual([[1, 1]])
})
})
Expand All @@ -466,11 +456,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 3,
stepWeighting: ['1'],
pathWeight: 'length',
maxPaths: -1,
order: 'desc',
})
}, [compileStepWeighting('1')])
expect(snapshottify(paths).map(({ weight }) => weight)).toEqual([2, 1, 1])
})

Expand All @@ -480,11 +469,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 2,
stepWeighting: ['@source.id = a >>> 3', '2'],
pathWeight: 'step-product',
maxPaths: -1,
order: 'desc',
})
}, ['@source.id = a >>> 3', '2'].map(compileStepWeighting))
expect(snapshottify(paths).map(({ weight }) => weight)).toEqual([6, 3, 2])
})

Expand All @@ -494,11 +482,10 @@ describe('paths', () => {
allowCycles: false,
minPathLength: 1,
maxPathLength: 2,
stepWeighting: ['@source.id = a >>> 3', '2'],
pathWeight: 'step-sum',
maxPaths: -1,
order: 'desc',
})
}, ['@source.id = a >>> 3', '2'].map(compileStepWeighting))
expect(snapshottify(paths).map(({ weight }) => weight)).toEqual([5, 3, 2])
})
})
Expand Down
Loading

0 comments on commit ef9aa97

Please sign in to comment.