Skip to content

Commit

Permalink
feat(bop): implement path encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
DerYeger committed Sep 6, 2024
1 parent fbb7b69 commit 22c631b
Show file tree
Hide file tree
Showing 12 changed files with 413 additions and 281 deletions.
28 changes: 28 additions & 0 deletions packages/encoder/bag-of-paths-encoder/src/encode-paths.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import type { GraphModel, GraphNode } from '@cm2ml/ir'

import type { PathData } from './paths'
import type { Template } from './template'
import { compileTemplate } from './template'

/**
 * Encodes every given path as a string.
 *
 * The model's nodes are copied into an array so that the steps of a path can be used as
 * indices, and each raw template is compiled exactly once before any path is encoded.
 *
 * @param paths - The paths to encode.
 * @param model - The model whose nodes the path steps refer to.
 * @param templates - Raw node templates; they are compiled and tried in the given order.
 * @returns One encoded string per entry of `paths`, in the same order.
 */
export function encodePaths(paths: PathData[], model: GraphModel, templates: readonly string[]) {
  const nodes = Array.from(model.nodes)
  const compiledTemplates: Template[] = []
  for (const rawTemplate of templates) {
    compiledTemplates.push(compileTemplate(rawTemplate))
  }
  return paths.map((currentPath) => encodePath(currentPath, nodes, compiledTemplates))
}

/**
 * Encodes a single path by rendering each of its nodes with the first applicable template.
 *
 * Nodes for which no template is applicable are omitted from the result.
 *
 * @param path - The path whose `steps` are indices into `nodes`.
 * @param nodes - The nodes of the model, addressed by index.
 * @param templates - Compiled templates, tried in order for every node.
 * @returns The rendered nodes joined by ` -> `.
 * @throws Error if a step does not resolve to a node.
 */
function encodePath(path: PathData, nodes: GraphNode[], templates: Template[]) {
  const encodedSteps: string[] = []
  for (const nodeIndex of path.steps) {
    const node = nodes[nodeIndex]
    if (!node) {
      throw new Error(`Node index out-of-bounds. This is an internal error.`)
    }
    const template = templates.find((candidate) => candidate.isApplicable(node))
    if (template === undefined) {
      // No template matches this node, so it contributes nothing to the encoding.
      continue
    }
    const encodedNode = template.apply(node)
    if (encodedNode === undefined) {
      continue
    }
    encodedSteps.push(encodedNode)
  }
  // TODO/Jan: Also encode the edges
  return encodedSteps.join(' -> ')
}
44 changes: 12 additions & 32 deletions packages/encoder/bag-of-paths-encoder/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@ import { batchTryCatch, compose, definePlugin } from '@cm2ml/plugin'
import { Stream } from '@yeger/streams'

import { pathWeightTypes, sortOrders, stepWeightTypes } from './bop-types'
import { encodeNode, nodeEncodingTypes } from './node-encodings'
import type { PathData } from './paths'
import { encodePaths } from './encode-paths'
import { collectPaths } from './paths'

export type { PathData } from './paths'
export { stepWeightTypes, pathWeightTypes }
export type { PathWeight, StepWeight } from './bop-types'
export { nodeEncodingTypes }
export type { NodeEncodingType, NodeEncoding, PathCounts } from './node-encodings'

const PathBuilder = definePlugin({
name: 'path-builder',
Expand Down Expand Up @@ -69,51 +66,34 @@ const PathBuilder = definePlugin({
description: 'Ordering of paths according to their weight',
group: 'Filtering',
},
nodeEncoding: {
nodeTemplates: {
type: 'list<string>',
unique: true,
allowedValues: nodeEncodingTypes,
defaultValue: [],
description: 'Encodings to apply to nodes',
ordered: true,
defaultValue: [
'type="Model"->{<<}{{name}}{>>}',
'{{name}}{ : }{{type}}',
],
description: 'Template for encoding nodes of paths',
group: 'Encoding',
},
},
invoke: ({ data, metadata: featureContext }: { data: GraphModel, metadata: FeatureContext }, parameters) => {
invoke: ({ data, metadata }: { data: GraphModel, metadata: FeatureContext }, parameters) => {
const paths = collectPaths(data, parameters)
const mapping = Stream
.from(data.nodes)
.map((node) => node.requireId())
.toArray()
const nodeData = Array.from(data.nodes).map((node, nodeIndex) => [node, getRelevantPaths(nodeIndex, paths)] as const)
const longestPathLength = paths.reduce((max, path) => Math.max(max, path.steps.length), 0)
const highestPathCount = nodeData.reduce((max, [_, paths]) => Math.max(max, paths.size), 0)
const additionalMetadata = parameters.nodeEncoding.includes('features') ? { nodeFeatures: featureContext.nodeFeatures, edgeFeatures: featureContext.edgeFeatures } : {}
return {
data: {
paths,
encodedPaths: encodePaths(paths, data, parameters.nodeTemplates),
mapping,
nodes: parameters.nodeEncoding.length > 0
? nodeData.map(([node, paths], nodeIndex) =>
encodeNode({
nodeIndex,
node,
featureContext,
paths,
parameters,
longestPathLength,
highestPathCount,
},
),
)
: undefined,
},
metadata: { ...additionalMetadata, idAttribute: data.metamodel.idAttribute, typeAttributes: data.metamodel.typeAttributes },
metadata: { ...metadata, idAttribute: data.metamodel.idAttribute, typeAttributes: data.metamodel.typeAttributes },
}
},
})

function getRelevantPaths(nodeIndex: number, paths: PathData[]) {
return Stream.from(paths).filter((path) => path.steps[0] === nodeIndex || path.steps.at(-1) === nodeIndex).toSet()
}

// TODO/Jan: Remove feature encoder
export const BagOfPathsEncoder = compose(FeatureEncoder, batchTryCatch(PathBuilder), 'bag-of-paths')
85 changes: 85 additions & 0 deletions packages/encoder/bag-of-paths-encoder/src/mapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import type { ModelMember } from '@cm2ml/ir'

/** Character that separates the selector keyword from its argument. */
const separator = '@'

/** Prefix that marks a selector string as an attribute selector (`attribute@`). */
export const attributeKeyword = 'attribute' + separator

/**
 * Selector that reads the literal value of a named attribute from a model member.
 */
class AttributeValueSelector {
  /** Name of the attribute whose value is selected. */
  public readonly attribute: string

  public constructor(attribute: string) {
    this.attribute = attribute
  }

  /**
   * Looks up the configured attribute on the element.
   *
   * @returns The attribute's `value.literal`, or `undefined` if the lookup yields nothing.
   */
  public select(element: ModelMember) {
    const found = element.getAttribute(this.attribute)
    return found?.value.literal
  }

  /**
   * Parses a selector of the form `attribute@<name>`.
   *
   * @throws Error if the keyword prefix is missing or the name contains the separator.
   */
  public static fromString(selector: string) {
    const hasKeyword = selector.startsWith(attributeKeyword)
    if (!hasKeyword) {
      throw new Error(`Invalid selector: ${selector}`)
    }
    const attribute = selector.slice(attributeKeyword.length)
    const containsSeparator = attribute.includes(separator)
    if (containsSeparator) {
      throw new Error(`Invalid attribute: ${attribute}`)
    }
    return new AttributeValueSelector(attribute)
  }
}

/** Keys that are read directly as properties of a model member. */
const directAccessKeys = ['id', 'type', 'tag', 'name'] as const

/** Union of the entries of `directAccessKeys`. */
type DirectAccessKey = typeof directAccessKeys[number]

/** Type guard that checks whether a string is one of the direct access keys. */
function isDirectAccessKey(key: string): key is DirectAccessKey {
  return directAccessKeys.some((candidate) => candidate === key)
}

/**
 * Selector that reads one of the direct access keys as a property of a model member.
 */
class DirectSelector {
  /** The property that is read from the element. */
  public readonly key: DirectAccessKey

  public constructor(key: DirectAccessKey) {
    this.key = key
  }

  /** Returns the value of the configured property of the element. */
  public select(element: ModelMember) {
    const { key } = this
    return element[key]
  }

  /**
   * Creates a selector if the string is a direct access key.
   *
   * @returns The selector, or `null` if the string is not a direct access key.
   */
  public static fromString(selector: string) {
    return isDirectAccessKey(selector) ? new DirectSelector(selector) : null
  }
}

/**
 * Parses a selector string.
 *
 * Direct access keys take precedence; every other string is parsed as an attribute selector,
 * which throws if the string is not a valid attribute selector either.
 */
export function getSelector(selector: string) {
  const directSelector = DirectSelector.fromString(selector)
  if (directSelector !== null) {
    return directSelector
  }
  return AttributeValueSelector.fromString(selector)
}

/** Any selector that `getSelector` can produce. */
export type Selector = ReturnType<typeof getSelector>

/**
 * Filter that matches model members for which a selected value equals a target.
 *
 * The target is either a string literal or a second selector that is evaluated on the same
 * element.
 */
class EqualityFilter {
  public constructor(public readonly selector: Selector, public readonly comparator: '=', public readonly target: string | Selector) { }

  /**
   * Checks whether the selected value of the element is strictly equal to the target.
   */
  public matches(element: ModelMember) {
    const selected = this.selector.select(element)
    const resolvedTarget = typeof this.target === 'string' ? this.target : this.target.select(element)
    return selected === resolvedTarget
  }

  /**
   * Parses a filter of the form `<selector>=<target>`.
   *
   * The filter is split at the first `=` only, so a double-quoted literal target may itself
   * contain `=` (e.g. `name="a=b"`). A target that is not enclosed in double quotes is
   * parsed as a selector and must not contain a further `=`.
   *
   * @throws Error if there is no `=`, if an unquoted target contains `=`, or if either
   * selector is invalid (the latter is raised by `getSelector`).
   */
  public static fromString(filter: string) {
    const comparatorIndex = filter.indexOf('=')
    if (comparatorIndex === -1) {
      throw new Error(`Invalid filter: ${filter}`)
    }
    const rawSelector = filter.slice(0, comparatorIndex)
    const rawTarget = filter.slice(comparatorIndex + 1)
    // Require both an opening and a closing quote; a lone `"` is not a literal.
    const isLiteral = rawTarget.length >= 2 && rawTarget.startsWith('"') && rawTarget.endsWith('"')
    if (!isLiteral && rawTarget.includes('=')) {
      // More than one comparator outside of a quoted literal is ambiguous.
      throw new Error(`Invalid filter: ${filter}`)
    }
    // getSelector throws for invalid selectors, so no null check is required here.
    const selector = getSelector(rawSelector)
    if (isLiteral) {
      return new EqualityFilter(selector, '=', rawTarget.slice(1, -1))
    }
    return new EqualityFilter(selector, '=', getSelector(rawTarget))
  }
}

/**
 * Parses a filter string into a filter.
 *
 * Equality filters are the only kind of filter that is created here.
 */
export function getFilter(filter: string) {
  const equalityFilter = EqualityFilter.fromString(filter)
  return equalityFilter
}
74 changes: 0 additions & 74 deletions packages/encoder/bag-of-paths-encoder/src/node-encodings.ts

This file was deleted.

64 changes: 64 additions & 0 deletions packages/encoder/bag-of-paths-encoder/src/template.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import type { ModelMember } from '@cm2ml/ir'
import { Stream } from '@yeger/streams'

import { getFilter, getSelector } from './mapper'

/**
 * A compiled template that renders model members as strings.
 */
export interface Template {
  /** Whether this template may be used for the given model member. */
  isApplicable: (modelMember: ModelMember) => boolean
  /** Renders the given model member as a string. */
  apply: (modelMember: ModelMember) => string
}

/**
 * Compiles a raw template string of the form `[<filter>->]<segments>`.
 *
 * @param rawInput - The raw template, optionally prefixed by a filter and `->`.
 * @returns A template that is applicable to members matching the filter (or to every member
 * if there is no filter) and that renders a member by concatenating all segment results
 * that are not null or undefined.
 * @throws Error if the template part is missing or cannot be split into segments.
 */
export function compileTemplate(rawInput: string): Template {
  const { rawFilter, rawTemplate } = extractFilter(rawInput)
  if (!rawTemplate) {
    throw new Error(`Invalid template: ${rawInput}`)
  }
  const segments = getSegments(rawTemplate)
  if (!segments) {
    throw new Error(`Invalid template: ${rawInput}`)
  }
  const replacers = segments.map(parseSegment)
  // The filter is parsed after the segments, matching the order in which errors surface.
  const filter = rawFilter ? getFilter(rawFilter) : undefined

  function isApplicable(modelMember: ModelMember) {
    return filter === undefined || filter.matches(modelMember)
  }

  function apply(modelMember: ModelMember) {
    return Stream
      .from(replacers)
      .map((replace) => replace(modelMember))
      .filterNonNull()
      .join('')
  }

  return { isApplicable, apply }
}

/**
 * Splits a raw template at `->` into its optional filter and its template part.
 *
 * @returns `{ rawTemplate }` if there is no `->`, or `{ rawFilter, rawTemplate }` if there
 * is exactly one.
 * @throws Error if the input contains more than one `->`.
 */
function extractFilter(template: string) {
  const parts = template.split('->')
  switch (parts.length) {
    case 0:
      return {}
    case 1:
      return { rawTemplate: parts[0]! }
    case 2:
      return { rawFilter: parts[0]!, rawTemplate: parts[1]! }
    default:
      throw new Error(`Invalid template: ${template}`)
  }
}

/**
 * Splits a template of the form `{segment}{segment}...` into the contents of its segments.
 *
 * Only the outermost pair of braces of each segment is removed, i.e. `{{name}}` yields the
 * segment `{name}`.
 *
 * @param template - The template without its filter part.
 * @returns The segment contents, or `null` if the template is not enclosed in braces.
 * Previously such input silently lost its first and last character.
 */
export function getSegments(template: string) {
  const isEnclosed = template.length >= 2 && template.startsWith('{') && template.endsWith('}')
  if (!isEnclosed) {
    return null
  }
  // Dropping the outer braces first leaves `}{` as the only boundary between segments.
  return template.slice(1, -1).split('}{')
}

/**
 * Turns a template segment into a function that produces the segment's text for a member.
 *
 * A segment enclosed in braces is treated as a selector; its selected value is returned,
 * with array values being concatenated without a separator. Every other segment is
 * returned verbatim.
 */
export function parseSegment(segment: string) {
  const isSelectorSegment = segment.startsWith('{') && segment.endsWith('}')
  if (!isSelectorSegment) {
    return () => segment
  }
  const selector = getSelector(segment.slice(1, -1))
  return (modelMember: ModelMember): string | undefined => {
    const selected = selector.select(modelMember)
    return Array.isArray(selected) ? selected.join('') : selected
  }
}
10 changes: 10 additions & 0 deletions packages/encoder/bag-of-paths-encoder/test/template.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { describe, expect, it } from 'vitest'

import { getSegments } from '../src/template'

// Tests for splitting templates into segments (the suite previously carried the label of the mapper module).
describe('template', () => {
  it('can parse a template', () => {
    const result = getSegments('{a}{b}')
    expect(result).toEqual(['a', 'b'])
  })

  it('can parse a template with a single segment', () => {
    const result = getSegments('{a}')
    expect(result).toEqual(['a'])
  })

  it('keeps the inner braces of selector segments', () => {
    // Only the outermost braces delimit a segment; the inner pair marks a selector.
    const result = getSegments('{{name}}{ : }{{type}}')
    expect(result).toEqual(['{name}', ' : ', '{type}'])
  })
})
2 changes: 1 addition & 1 deletion packages/encoder/bag-of-paths-encoder/test/test-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export function createTestModel(nodes: string[], edges: [string, string][] | (re
const graphNode = model.addNode('node')
graphNode.id = id
graphNode.type = 'node'
graphNode.parent = root
root.addChild(graphNode)
})
edges.forEach(([sourceId, targetId]) => {
const source = model.getNodeById(sourceId)
Expand Down
25 changes: 19 additions & 6 deletions packages/framework/plugin/src/parameters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export type ListParameter = ParameterBase & Readonly<
* If false, the items will be sorted.
*/
readonly ordered?: boolean
} & ({ readonly unique?: false, readonly ordered?: boolean } | { readonly unique: true, readonly ordered?: false })
} & ({ readonly unique?: boolean, readonly ordered?: false } | { readonly unique: true, readonly ordered: true })
>

export type Parameter = PrimitiveParameter | ListParameter
Expand All @@ -62,13 +62,26 @@ function getZodValidator(parameter: Parameter) {
case 'list<string>': {
const baseArray = parameter.allowedValues && parameter.allowedValues.length > 0 ? z.array(z.enum(parameter.allowedValues as [string, ...string[]])) : z.array(z.string())
return baseArray.default([...parameter.defaultValue]).transform((value) => {
if (parameter.unique) {
return [...new Set(value)]
if (!parameter.ordered && !parameter.unique) {
return value.toSorted()
}
if (!parameter.ordered) {
return value.sort()
if (parameter.ordered && !parameter.unique) {
throw new Error('Ordered lists must be unique')
}
return value
if (parameter.ordered && parameter.unique) {
const visited = new Set<string>()
return value.filter((item) => {
if (visited.has(item)) {
return false
}
visited.add(item)
return true
})
}
if (!parameter.ordered && parameter.unique) {
return [...new Set(value)].toSorted()
}
throw new Error('Invalid list configuration. This is an internal error.')
})
}
}
Expand Down
Loading

0 comments on commit 22c631b

Please sign in to comment.