Skip to content

Commit

Permalink
fix(builder): update version 0.5.1 #andy (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
andylau-55 authored Nov 22, 2024
1 parent 7838257 commit 918a1fb
Show file tree
Hide file tree
Showing 27 changed files with 551 additions and 306 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
@AllArgsConstructor
public class LogicalPlan implements Serializable {

private static final long serialVersionUID = -4487139289740223319L;

/** DAG (Directed Acyclic Graph) of the logical execution plan. */
private final Graph<BaseLogicalNode<?>, DefaultEdge> dag;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
@AllArgsConstructor
public class PhysicalPlan implements Serializable {

private static final long serialVersionUID = -5866035535857620657L;

/** DAG (Directed Acyclic Graph) of the physical execution plan. */
private final Graph<BaseProcessor<?>, DefaultEdge> dag;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -41,14 +42,16 @@ public class LLMNlExtractProcessor extends BasePythonProcessor<LLMNlExtractNodeC

private ExecuteNode node;

private static final ThreadPoolExecutor executor =
new ThreadPoolExecutor(
30,
60,
60 * 60,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1000),
new ThreadPoolExecutor.CallerRunsPolicy());
private static final RejectedExecutionHandler handler =
(r, executor) -> {
try {
executor.getQueue().put(r);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
};

private static ThreadPoolExecutor executor;

public LLMNlExtractProcessor(String id, String name, LLMNlExtractNodeConfig config) {
super(id, name, config);
Expand All @@ -58,6 +61,16 @@ public LLMNlExtractProcessor(String id, String name, LLMNlExtractNodeConfig conf
public void doInit(BuilderContext context) throws BuilderException {
super.doInit(context);
this.node = context.getExecuteNodes().get(getId());
if (executor == null) {
executor =
new ThreadPoolExecutor(
context.getModelExecuteNum(),
context.getModelExecuteNum(),
60 * 60,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(100),
handler);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ public List<ChunkRecord.Chunk> readFile(String fileUrl) {
case "json":
className = "JSONReader";
break;
case "doc":
case "docx":
className = "DocxReader";
break;
}
node.addTraceLog("invoke chunk operator:%s", className);
pythonInterpreter.exec("from kag.builder.component.reader import " + className);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static List<BaseSPGRecord> convertNodes(
return records;
}

private static String labelPrefix(String namespace, String label) {
public static String labelPrefix(String namespace, String label) {
if (label.contains(DOT)) {
return label;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
@Accessors(chain = true)
public class BuilderContext implements Serializable {

private static final long serialVersionUID = 2446709406202543546L;

private long projectId;
private String project;
private String jobName;
Expand All @@ -41,6 +43,7 @@ public class BuilderContext implements Serializable {
private int batchSize = 1;
private int parallelism = 1;
private boolean enableLeadTo;
private Integer modelExecuteNum = 5;

private Map<String, ExecuteNode> executeNodes;
private String schemaUrl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

@Slf4j
public class Neo4jSinkWriter extends BaseSinkWriter<Neo4jSinkNodeConfig> {
Expand All @@ -53,8 +55,16 @@ public class Neo4jSinkWriter extends BaseSinkWriter<Neo4jSinkNodeConfig> {
private Neo4jStoreClient client;
private Project project;
private static final String DOT = ".";
ExecutorService nodeExecutor;
ExecutorService edgeExecutor;
ExecutorService executor;

RejectedExecutionHandler handler =
(r, executor) -> {
try {
executor.getQueue().put(r);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
};

public Neo4jSinkWriter(String id, String name, Neo4jSinkNodeConfig config) {
super(id, name, config);
Expand All @@ -69,22 +79,14 @@ public void doInit(BuilderContext context) throws BuilderException {
}
client = new Neo4jStoreClient(context.getGraphStoreUrl());
project = JSON.parseObject(context.getProject(), Project.class);
nodeExecutor =
new ThreadPoolExecutor(
NUM_THREADS,
NUM_THREADS,
2 * 60L,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1000),
new ThreadPoolExecutor.CallerRunsPolicy());
edgeExecutor =
executor =
new ThreadPoolExecutor(
NUM_THREADS,
NUM_THREADS,
2 * 60L,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1000),
new ThreadPoolExecutor.CallerRunsPolicy());
new LinkedBlockingQueue<>(100),
handler);
}

@Override
Expand Down Expand Up @@ -115,7 +117,7 @@ public void writeToNeo4j(SubGraphRecord subGraphRecord) {
try {
node.addTraceLog("Start Writer Nodes processor...");
List<Future<Void>> nodeFutures =
submitTasks(nodeExecutor, subGraphRecord.getResultNodes(), this::writeNode);
submitTasks(executor, subGraphRecord.getResultNodes(), this::writeNode);
awaitAllTasks(nodeFutures);
node.addTraceLog("Writer Nodes succeed");
} catch (InterruptedException | ExecutionException e) {
Expand All @@ -125,7 +127,7 @@ public void writeToNeo4j(SubGraphRecord subGraphRecord) {
try {
node.addTraceLog("Start Writer Edges processor...");
List<Future<Void>> edgeFutures =
submitTasks(edgeExecutor, subGraphRecord.getResultEdges(), this::writeEdge);
submitTasks(executor, subGraphRecord.getResultEdges(), this::writeEdge);
awaitAllTasks(edgeFutures);
node.addTraceLog("Writer Edges succeed");
} catch (InterruptedException | ExecutionException e) {
Expand Down Expand Up @@ -174,7 +176,10 @@ private void writeNode(SubGraphRecord.Node node) {
try {
Long statr = System.currentTimeMillis();
RecordAlterOperationEnum operation = context.getOperation();
if (node.getId() == null || node.getName() == null) {
if (StringUtils.isBlank(node.getId())
|| StringUtils.isBlank(node.getName())
|| StringUtils.isBlank(node.getLabel())) {
log.info(String.format("write Node ignore node:%s", JSON.toJSONString(node)));
return;
}
String label = labelPrefix(node.getLabel());
Expand Down Expand Up @@ -212,7 +217,10 @@ private void writeEdge(SubGraphRecord.Edge edge) {
try {
Long statr = System.currentTimeMillis();
RecordAlterOperationEnum operation = context.getOperation();
if (edge.getFrom() == null || edge.getTo() == null) {
if (StringUtils.isBlank(edge.getFrom())
|| StringUtils.isBlank(edge.getTo())
|| StringUtils.isBlank(edge.getLabel())) {
log.info(String.format("write Edge ignore edge:%s", JSON.toJSONString(edge)));
return;
}
List<EdgeRecord> edgeRecords = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static Driver getNeo4jDriver(String uri, String user, String password) {
Config config =
Config.builder()
.withMaxConnectionPoolSize(200)
.withMaxConnectionLifetime(2, TimeUnit.HOURS)
.withMaxConnectionLifetime(4, TimeUnit.HOURS)
.withMaxTransactionRetryTime(300, TimeUnit.SECONDS)
.withConnectionAcquisitionTimeout(300, TimeUnit.SECONDS)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package com.antgroup.openspg.common.util.neo4j;

import com.antgroup.openspg.common.util.Md5Utils;
import com.antgroup.openspg.common.util.tuple.Tuple2;
import com.antgroup.openspg.core.schema.model.predicate.IndexTypeEnum;
import com.antgroup.openspg.core.schema.model.predicate.Property;
Expand All @@ -21,6 +22,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -35,8 +37,6 @@
@Slf4j
public class Neo4jGraphUtils {

private static final String ALL_GRAPH = "allGraph";

private Driver driver;
private String database;
private Neo4jIndexUtils neo4jIndex;
Expand Down Expand Up @@ -127,22 +127,22 @@ public List<String> getAllLabels() {
return labels;
}

public void createAllGraph() {
public void createAllGraph(String allGraph) {
Session session = driver.session(SessionConfig.forDatabase(this.database));
String existsQuery =
String.format(
"CALL gds.graph.exists('%s') YIELD exists "
+ "WHERE exists "
+ "CALL gds.graph.drop('%s') YIELD graphName "
+ "RETURN graphName",
ALL_GRAPH, ALL_GRAPH);
allGraph, allGraph);

Result result = session.run(existsQuery);
ResultSummary summary = result.consume();
log.debug(
"create pagerank graph exists graph_name: {} database: {} succeed "
+ "executed: {} consumed: {}",
ALL_GRAPH,
allGraph,
database,
summary.resultAvailableAfter(TimeUnit.MILLISECONDS),
summary.resultConsumedAfter(TimeUnit.MILLISECONDS));
Expand All @@ -152,13 +152,33 @@ public void createAllGraph() {
"CALL gds.graph.project('%s','*','*') "
+ "YIELD graphName, nodeCount AS nodes, relationshipCount AS rels "
+ "RETURN graphName, nodes, rels",
ALL_GRAPH);
allGraph);

result = session.run(projectQuery);
summary = result.consume();
log.debug(
"create pagerank graph graph_name: {} database: {} succeed " + "executed: {} consumed: {}",
ALL_GRAPH,
allGraph,
database,
summary.resultAvailableAfter(TimeUnit.MILLISECONDS),
summary.resultConsumedAfter(TimeUnit.MILLISECONDS));
}

public void dropAllGraph(String allGraph) {
Session session = driver.session(SessionConfig.forDatabase(this.database));
String existsQuery =
String.format(
"CALL gds.graph.exists('%s') YIELD exists "
+ "WHERE exists "
+ "CALL gds.graph.drop('%s') YIELD graphName "
+ "RETURN graphName",
allGraph, allGraph);

Result result = session.run(existsQuery);
ResultSummary summary = result.consume();
log.debug(
"drop pagerank graph graph_name: {} database: {} succeed executed: {} consumed: {}",
allGraph,
database,
summary.resultAvailableAfter(TimeUnit.MILLISECONDS),
summary.resultConsumedAfter(TimeUnit.MILLISECONDS));
Expand All @@ -167,12 +187,18 @@ public void createAllGraph() {
public List<Map<String, Object>> getPageRankScores(
List<Map<String, String>> startNodes, String targetType) {
Session session = driver.session(SessionConfig.forDatabase(this.database));
createAllGraph();
return session.writeTransaction(tx -> getPageRankScores(tx, startNodes, targetType));
String allGraph = "allGraph_" + Md5Utils.md5Of(UUID.randomUUID().toString());
createAllGraph(allGraph);
try {
return session.writeTransaction(
tx -> getPageRankScores(tx, allGraph, startNodes, targetType));
} finally {
dropAllGraph(allGraph);
}
}

private List<Map<String, Object>> getPageRankScores(
Transaction tx, List<Map<String, String>> startNodes, String returnType) {
Transaction tx, String allGraph, List<Map<String, String>> startNodes, String returnType) {
List<String> matchClauses = new ArrayList<>();
List<String> matchIdentifiers = new ArrayList<>();

Expand Down Expand Up @@ -205,7 +231,7 @@ private List<Map<String, Object>> getPageRankScores(
+ "RETURN id(m) AS g_id, gds.util.asNode(nodeId).id AS id, score "
+ "ORDER BY score DESC",
matchQuery,
ALL_GRAPH,
allGraph,
matchIdentifierStr,
Neo4jCommonUtils.escapeNeo4jIdentifier(returnType));

Expand All @@ -231,6 +257,16 @@ public void createDatabase(String database) {
});
}

public void dropDatabase(String database) {
Session session = driver.session(SessionConfig.forDatabase(this.database));
session.writeTransaction(
tx -> {
tx.run(String.format("DROP DATABASE %s IF EXISTS", database));
tx.commit();
return null;
});
}

public void deleteAllData(String database) {
if (!this.database.equals(database)) {
throw new IllegalArgumentException(
Expand Down
5 changes: 3 additions & 2 deletions dev/release/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ services:
command: [
"java",
"-Dfile.encoding=UTF-8",
"-Xms4096m",
"-Xmx4096m",
"-Xms2048m",
"-Xmx8192m",
"-jar",
"arks-sofaboot-0.0.1-SNAPSHOT-executable.jar",
'--server.repository.impl.jdbc.host=mysql',
'--server.repository.impl.jdbc.password=openspg',
'--builder.model.execute.num=5',
'--cloudext.graphstore.url=neo4j://release-openspg-neo4j:7687?user=neo4j&password=neo4j@openspg&database=neo4j',
'--cloudext.searchengine.url=neo4j://release-openspg-neo4j:7687?user=neo4j&password=neo4j@openspg&database=neo4j'
]
Expand Down
2 changes: 1 addition & 1 deletion dev/release/mysql/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

FROM mariadb:10.5.8

ADD sql/initdb.sql /docker-entrypoint-initdb.d
ADD sql /docker-entrypoint-initdb.d

EXPOSE 3306

Expand Down
4 changes: 3 additions & 1 deletion dev/release/mysql/buildx-release-mysql.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# or implied.

docker buildx build -f Dockerfile --platform linux/arm64/v8,linux/amd64 --push \
-t openspg/openspg-mysql:0.5 \
-t spg-registry.cn-hangzhou.cr.aliyuncs.com/spg/openspg-mysql:0.5.1 \
-t spg-registry.cn-hangzhou.cr.aliyuncs.com/spg/openspg-mysql:latest \
-t openspg/openspg-mysql:0.5.1 \
-t openspg/openspg-mysql:latest \
.
Loading

0 comments on commit 918a1fb

Please sign in to comment.