Skip to content

Commit

Permalink
fix(raw-graph): do not encode edges for matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
DerYeger committed Oct 27, 2024
1 parent 811ef34 commit 060aa39
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions packages/encoder/graph-encoder/src/edge-encoder.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FeatureContext } from '@cm2ml/feature-encoder'
import type { FeatureContext, FeatureVector } from '@cm2ml/feature-encoder'
import type { GraphEdge, GraphModel } from '@cm2ml/ir'
import { defineStructuredPlugin } from '@cm2ml/plugin'
import { Stream } from '@yeger/streams'
Expand Down Expand Up @@ -29,8 +29,8 @@ export const EdgeEncoder = defineStructuredPlugin({
const { staticData, getNodeFeatureVector, getEdgeFeatureVector } = features
const sortedIds = getSortedIds(data)
const sortedEdges = Array.from(data.edges).sort(createEdgeSorter(sortedIds))
const edgeEncoder = format === 'list' ? encodeAsSparseList : encodeAsAdjacencyMatrix
const edgeEncoding = edgeEncoder(new Set(sortedEdges), sortedIds, weighted)
const edgeEncoder = format === 'list' ? encodeAsAdjacencyList : encodeAsAdjacencyMatrix
const edgeEncoding = edgeEncoder(new Set(sortedEdges), sortedIds, weighted, getEdgeFeatureVector)

const nodeFeatureVectors = Stream
.from(sortedIds)
Expand All @@ -39,12 +39,10 @@ export const EdgeEncoder = defineStructuredPlugin({
.map(getNodeFeatureVector)
.toArray()

const edgeFeatureVectors = sortedEdges.map(getEdgeFeatureVector)
return {
data: {
...edgeEncoding,
nodeFeatureVectors,
edgeFeatureVectors,
},
metadata: staticData,
}
Expand All @@ -61,14 +59,16 @@ function getSortedIds(model: GraphModel) {

export type AdjacencyList = [number, number][] | [number, number, number][]

function encodeAsSparseList(
function encodeAsAdjacencyList(
edges: ReadonlySet<GraphEdge>,
sortedIds: string[],
weighted: boolean,
getEdgeFeatureVector: (edge: GraphEdge) => FeatureVector,
) {
const list = new Array<
readonly [number, number] | readonly [number, number, number]
>()
const edgeFeatureVectors: FeatureVector[] = []
const indexRecord = createIndexRecord(sortedIds)
edges.forEach((edge) => {
const sourceId = edge.source.requireId()
Expand All @@ -85,16 +85,18 @@ function encodeAsSparseList(
? ([
sourceIndex,
targetIndex,
getWeightedValue(edge, edges),
getWeightedValue(edge),
] as const)
: ([sourceIndex, targetIndex] as const)
list.push(entry)
edgeFeatureVectors.push(getEdgeFeatureVector(edge))
})
sortAdjacencyList(list as AdjacencyList)
return {
format: 'list' as const,
list: list as AdjacencyList,
nodes: sortedIds,
edgeFeatureVectors,
}
}

Expand Down Expand Up @@ -129,7 +131,7 @@ function encodeAsAdjacencyMatrix(
) {
const matrix = createAdjacencyMatrix(sortedIds.length)
fillAdjacencyMatrix(matrix, edges, sortedIds, weighted)
return { format: 'matrix' as const, matrix, nodes: sortedIds }
return { format: 'matrix' as const, matrix, nodes: sortedIds, edgeFeatureVectors: [] }
}

function createAdjacencyMatrix(size: number): AdjacencyMatrix {
Expand Down Expand Up @@ -158,36 +160,27 @@ function fillAdjacencyMatrix(
if (targetIndex === undefined) {
throw new Error(`Target node ${targetId} not in model.`)
}
const value = weighted ? getWeightedValue(edge, edges) : 1
const value = weighted ? getWeightedValue(edge) : 1
matrix[sourceIndex]![targetIndex] = value
})
}

function getWeightedValue(
edge: GraphEdge,
edges: ReadonlySet<GraphEdge>,
) {
if (edges.size === edge.model.edges.size) {
return 1 / edge.target.incomingEdges.size
}
const relevantIncomingEdges = Stream.from(edge.target.incomingEdges)
.map((incomingEdge) => (edges.has(incomingEdge) ? 1 : 0))
.sum()
return 1 / relevantIncomingEdges
function getWeightedValue(edge: GraphEdge) {
return 1 / edge.target.incomingEdges.size
}

function createEdgeSorter(sortedIds: string[]) {
return (a: GraphEdge, b: GraphEdge) => {
const sourceIndexA = sortedIds.indexOf(a.source.id!)
const sourceIndexB = sortedIds.indexOf(b.source.id!)
const sourceIndexA = sortedIds.indexOf(a.source.requireId())
const sourceIndexB = sortedIds.indexOf(b.source.requireId())
if (sourceIndexA < sourceIndexB) {
return -1
}
if (sourceIndexA > sourceIndexB) {
return 1
}
const targetIndexA = sortedIds.indexOf(a.target.id!)
const targetIndexB = sortedIds.indexOf(b.target.id!)
const targetIndexA = sortedIds.indexOf(a.target.requireId())
const targetIndexB = sortedIds.indexOf(b.target.requireId())
if (targetIndexA < targetIndexB) {
return -1
}
Expand Down

0 comments on commit 060aa39

Please sign in to comment.