From 8364df9c6a2513d7b74fb6517f354db8a93a7897 Mon Sep 17 00:00:00 2001 From: wenchengyao Date: Mon, 21 Oct 2024 14:27:40 +0800 Subject: [PATCH] fix(reasoner): fix PER_NODE_LIMIT (#368) Co-authored-by: peilong --- .../local/main/KgReasonerTopKFilmTest.java | 13 ++++++++++ .../reasoner/pattern/PatternMatcher.java | 24 ++++++++++++------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java index 2ba74a0b3..7186bf7c7 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java @@ -457,4 +457,17 @@ private void doTest14() { Assert.assertEquals("root", result.get(0)[0]); Assert.assertEquals("L1_1_star", result.get(0)[1]); } + + @Test + public void test15() { + FileMutex.runTestWithMutex(this::doTest15); + } + + private void doTest15() { + String dsl = + "match (s:Film)-[p:starOfFilm|directOfFilm PER_NODE_LIMIT 1]->(o:FilmStar|FilmDirector) where s.id = 'root' return s.id, o.id"; + List result = runTestResult(dsl); + Assert.assertEquals(2, result.size()); + Assert.assertEquals(2, result.get(0).length); + } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/pattern/PatternMatcher.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/pattern/PatternMatcher.java index 06c3bdd9c..dd6dfff47 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/pattern/PatternMatcher.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/pattern/PatternMatcher.java @@ -256,9 +256,6 @@ public boolean test(IEdge e) { + JSON.toJSONString(dstVertexRuleList)); } } - if (patternConnection.limit() != null && patternConnection.limit() > 0) { - limit = new Long(patternConnection.limit()); - } List> validEdges = matchEdges( vertexContext, willMatchEdgeList, patternConnection, pattern, edgeRuleMap, limit); @@ -299,10 +296,15 @@ private List> matchEdges( Connection patternConnection, Pattern pattern, Map> edgeRuleMap, - Long limit) { + Long totalLimit) { ArrayList> result = new ArrayList<>(); - long oneTypeEdgeCount = 0; + Map edgeTypeCountMap = new HashMap<>(); + Long totalCount = 0L; for (IEdge edge : edgeList) { + String edgeType = edge.getType(); + if (!edgeTypeCountMap.containsKey(edgeType)) { + edgeTypeCountMap.put(edgeType, 0L); + } if (!isEdgeMatch( vertexContext, edge, @@ -311,11 +313,17 @@ private List> matchEdges( edgeRuleMap.get(patternConnection.alias()))) { continue; } - oneTypeEdgeCount++; - if (null != limit && oneTypeEdgeCount > limit) { - // reach max path limit + totalCount = totalCount + 1; + if (null != totalLimit && totalCount > totalLimit) { break; } + long currentEdgeTypeCount = edgeTypeCountMap.get(edgeType) + 1; + edgeTypeCountMap.put(edgeType, currentEdgeTypeCount); + if (null != patternConnection.limit() + && patternConnection.limit() > 0 + && currentEdgeTypeCount > patternConnection.limit()) { + continue; + } result.add(edge); } result.trimToSize();