
Commit

addressed PR comments
Aditya Pratap Singh committed Aug 21, 2024
1 parent 7883009 commit 7e3e556
Showing 3 changed files with 88 additions and 60 deletions.
@@ -261,17 +261,11 @@ public static class DagNode<T> {
private T value;
//List of parent Nodes that are dependencies of this Node.
private List<DagNode<T>> parentNodes;
@Setter
private boolean isFailedDag;

//Constructor
public DagNode(T value) {
this.value = value;
}
public DagNode(T value,boolean isFailedDag) {
this.value = value;
this.isFailedDag = isFailedDag;
}


public void addParentNode(DagNode<T> node) {
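Note on the hunk above: it removes the @Setter-backed isFailedDag field and the two-argument constructor from the nested DagNode<T> class (of Dag, per the org.apache.gobblin.service.modules.flowgraph.Dag import in the test below), so the failure flag no longer lives on the in-memory node at all. A minimal sketch of how a caller records a failure after this change; updateDagNode comes from MysqlDagStateStoreWithDagNodes in the next file, while the class and method names here are illustrative wiring only:

package org.apache.gobblin.service.modules.orchestration;

import java.io.IOException;

import org.apache.gobblin.service.modules.flowgraph.Dag;
import org.apache.gobblin.service.modules.spec.JobExecutionPlan;

// Illustrative sketch, not part of this commit: after this change DagNode has a single
// DagNode(T value) constructor and no isFailedDag field, so failure status is persisted
// through the state store rather than set on the node.
public class DagNodeFailureSketch {

  void markNodeFailed(MysqlDagStateStoreWithDagNodes dagStateStore,
      DagManager.DagId parentDagId, Dag.DagNode<JobExecutionPlan> dagNode) throws IOException {
    // The failed flag is written to the is_failed_dag column instead of mutating the node.
    dagStateStore.updateDagNode(parentDagId, dagNode, true);
  }
}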
@@ -77,18 +77,20 @@ public class MysqlDagStateStoreWithDagNodes implements DagStateStoreWithDagNodes
protected final GsonSerDe<List<JobExecutionPlan>> serDe;
private final JobExecutionPlanDagFactory jobExecPlanDagFactory;

protected static final String CREATE_TABLE_STATEMENT = "CREATE TABLE IF NOT EXISTS %s ("
+ "dag_node_id VARCHAR(" + ServiceConfigKeys.MAX_DAG_NODE_ID_LENGTH
+ ") CHARACTER SET latin1 COLLATE latin1_bin NOT NULL, " + "parent_dag_id VARCHAR("
+ ServiceConfigKeys.MAX_DAG_ID_LENGTH + ") NOT NULL, " + "dag_node JSON NOT NULL, "
+ "modified_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, "
+ "is_failed_dag INT NOT NULL DEFAULT 0, " + "PRIMARY KEY (dag_node_id), "
+ "UNIQUE INDEX dag_node_index (dag_node_id), " + "INDEX dag_index (parent_dag_id))";

protected static final String INSERT_STATEMENT = "INSERT INTO %s (dag_node_id, parent_dag_id, dag_node, is_failed_dag) "
+ "VALUES (?, ?, ?, ?) AS new ON DUPLICATE KEY UPDATE dag_node = new.dag_node, is_failed_dag = new.is_failed_dag";
protected static final String GET_DAG_NODES_STATEMENT = "SELECT dag_node,is_failed_dag FROM %s WHERE parent_dag_id = ?";
protected static final String GET_DAG_NODE_STATEMENT = "SELECT dag_node,is_failed_dag FROM %s WHERE dag_node_id = ?";
protected static final String CREATE_TABLE_STATEMENT =
"CREATE TABLE IF NOT EXISTS %s (" + "dag_node_id VARCHAR(" + ServiceConfigKeys.MAX_DAG_NODE_ID_LENGTH
+ ") CHARACTER SET latin1 COLLATE latin1_bin NOT NULL, " + "parent_dag_id VARCHAR("
+ ServiceConfigKeys.MAX_DAG_ID_LENGTH + ") NOT NULL, " + "dag_node JSON NOT NULL, "
+ "modified_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, "
+ "is_failed_dag TINYINT(1) DEFAULT 0, " + "PRIMARY KEY (dag_node_id), "
+ "UNIQUE INDEX dag_node_index (dag_node_id), " + "INDEX dag_index (parent_dag_id))";

protected static final String INSERT_STATEMENT =
"INSERT INTO %s (dag_node_id, parent_dag_id, dag_node, is_failed_dag) "
+ "VALUES (?, ?, ?, ?) AS new ON DUPLICATE KEY UPDATE dag_node = new.dag_node, is_failed_dag = new.is_failed_dag";
protected static final String GET_DAG_NODES_STATEMENT =
"SELECT dag_node FROM %s WHERE parent_dag_id = ?";
protected static final String GET_DAG_NODE_STATEMENT = "SELECT dag_node FROM %s WHERE dag_node_id = ?";
protected static final String DELETE_DAG_STATEMENT = "DELETE FROM %s WHERE parent_dag_id = ?";
private final ContextAwareCounter totalDagCount;
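Two details of the rewritten constants above are worth noting: is_failed_dag is now declared as TINYINT(1) DEFAULT 0 (MySQL's conventional boolean column) rather than INT NOT NULL DEFAULT 0, and the insert keeps the VALUES (?, ?, ?, ?) AS new ON DUPLICATE KEY UPDATE row-alias form, which requires MySQL 8.0.19 or newer. A hedged sketch of driving that upsert over plain JDBC; only the SQL and the column order come from the constants above, the class name and connection handling are stand-ins:

package org.apache.gobblin.service.modules.orchestration;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

// Hedged sketch, not the class's actual code path: binds the four upsert parameters in
// the column order declared by INSERT_STATEMENT (dag_node_id, parent_dag_id, dag_node,
// is_failed_dag).
public class DagNodeUpsertSketch {

  private static final String INSERT_STATEMENT =
      "INSERT INTO %s (dag_node_id, parent_dag_id, dag_node, is_failed_dag) "
          + "VALUES (?, ?, ?, ?) AS new ON DUPLICATE KEY UPDATE dag_node = new.dag_node, is_failed_dag = new.is_failed_dag";

  int upsertDagNode(Connection connection, String tableName, String dagNodeId,
      String parentDagId, String serializedDagNode, boolean isFailedDag) throws SQLException {
    try (PreparedStatement statement =
        connection.prepareStatement(String.format(INSERT_STATEMENT, tableName))) {
      statement.setString(1, dagNodeId);          // primary key, latin1_bin VARCHAR
      statement.setString(2, parentDagId);        // indexed via dag_index
      statement.setString(3, serializedDagNode);  // JSON column holding the serialized dag node payload
      statement.setBoolean(4, isFailedDag);       // stored in the TINYINT(1) column
      // With ON DUPLICATE KEY UPDATE, MySQL reports 1 affected row for a fresh insert
      // and 2 when an existing row is updated, so callers can distinguish the two cases.
      return statement.executeUpdate();
    }
  }
}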

@@ -103,7 +105,8 @@ public MysqlDagStateStoreWithDagNodes(Config config, Map<URI, TopologySpec> topo
DataSource dataSource = MysqlDataSourceFactory.get(config, SharedResourcesBrokerFactory.getImplicitBroker());

try (Connection connection = dataSource.getConnection();
PreparedStatement createStatement = connection.prepareStatement(String.format(CREATE_TABLE_STATEMENT, tableName))) {
PreparedStatement createStatement = connection.prepareStatement(
String.format(CREATE_TABLE_STATEMENT, tableName))) {
createStatement.executeUpdate();
connection.commit();
} catch (SQLException e) {
@@ -150,8 +153,7 @@ public boolean cleanUp(DagManager.DagId dagId) throws IOException {
return deleteStatement.executeUpdate() != 0;
} catch (SQLException e) {
throw new IOException(String.format("Failure deleting dag for %s", dagId), e);
}
}, true);
}}, true);
this.totalDagCount.dec();
return true;
}
@@ -165,7 +167,7 @@ public void cleanUp(String dagId) throws IOException {
@Override
public List<Dag<JobExecutionPlan>> getDags() throws IOException {
throw new NotSupportedException(getClass().getSimpleName() + " does not need this legacy API that originated with "
+ "the DagManager that is replaced by DagProcessingEngine"); }
+ "the DagManager that is replaced by DagProcessingEngine");}

@Override
public Dag<JobExecutionPlan> getDag(DagManager.DagId dagId) throws IOException {
@@ -189,19 +191,11 @@ private Dag<JobExecutionPlan> convertDagNodesIntoDag(Set<Dag.DagNode<JobExecutio
if (dagNodes.isEmpty()) {
return null;
}
Dag<JobExecutionPlan> dag = jobExecPlanDagFactory.createDag(dagNodes.stream().map(Dag.DagNode::getValue).collect(Collectors.toList()));

// if any node of the dag is failed it means that the dag has been marked as failed, update the is_failed_dag field of the dag and it's nodes as true
if (dag.getNodes().stream().anyMatch(Dag.DagNode::isFailedDag)) {
dag.setFailedDag(true);
dag.getNodes().forEach(node -> node.setFailedDag(true));
}
return dag;
return jobExecPlanDagFactory.createDag(dagNodes.stream().map(Dag.DagNode::getValue).collect(Collectors.toList()));
}

@Override
public int updateDagNode(DagManager.DagId parentDagId, Dag.DagNode<JobExecutionPlan> dagNode, boolean isFailedDag)
throws IOException {
public int updateDagNode(DagManager.DagId parentDagId, Dag.DagNode<JobExecutionPlan> dagNode, boolean isFailedDag) throws IOException {
String dagNodeId = dagNode.getValue().getId().toString();
return dbStatementExecutor.withPreparedStatement(String.format(INSERT_STATEMENT, tableName), insertStatement -> {
try {
@@ -212,25 +206,23 @@ public int updateDagNode(DagManager.DagId parentDagId, Dag.DagNode<JobExecutionP
return insertStatement.executeUpdate();
} catch (SQLException e) {
throw new IOException(String.format("Failure adding dag node for %s", dagNodeId), e);
}
}, true);
}}, true);
}

@Override
public Set<Dag.DagNode<JobExecutionPlan>> getDagNodes(DagManager.DagId parentDagId) throws IOException {
return dbStatementExecutor.withPreparedStatement(String.format(GET_DAG_NODES_STATEMENT, tableName),
getStatement -> {
getStatement.setString(1, parentDagId.toString());
HashSet<Dag.DagNode<JobExecutionPlan>> dagNodes = new HashSet<>();
try (ResultSet rs = getStatement.executeQuery()) {
while (rs.next()) {
dagNodes.add(new Dag.DagNode<>(this.serDe.deserialize(rs.getString(1)).get(0), rs.getBoolean(2)));
}
return dagNodes;
} catch (SQLException e) {
throw new IOException(String.format("Failure get dag nodes for dag %s", parentDagId), e);
}
}, true);
return dbStatementExecutor.withPreparedStatement(String.format(GET_DAG_NODES_STATEMENT, tableName), getStatement -> {
getStatement.setString(1, parentDagId.toString());
HashSet<Dag.DagNode<JobExecutionPlan>> dagNodes = new HashSet<>();
try (ResultSet rs = getStatement.executeQuery()) {
while (rs.next()) {
dagNodes.add(new Dag.DagNode<>(this.serDe.deserialize(rs.getString(1)).get(0)));
}
return dagNodes;
} catch (SQLException e) {
throw new IOException(String.format("Failure get dag nodes for dag %s", parentDagId), e);
}
}, true);
}
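Both read statements above now select only the dag_node column, so the persisted failure bit is no longer rehydrated onto the deserialized nodes (consistent with the field removal in the first file). A caller that still needs the flag has to read the column directly; a hedged sketch of such a lookup against the table layout declared in CREATE_TABLE_STATEMENT — the helper class and query string below are illustrative and not part of this change:

package org.apache.gobblin.service.modules.orchestration;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

// Illustrative helper, not part of this commit: reads the persisted failure flag for a
// single dag node straight from the is_failed_dag column.
public class DagNodeFailureLookupSketch {

  private static final String GET_IS_FAILED_STATEMENT =
      "SELECT is_failed_dag FROM %s WHERE dag_node_id = ?";

  boolean isDagNodeFailed(Connection connection, String tableName, String dagNodeId)
      throws SQLException {
    try (PreparedStatement statement =
        connection.prepareStatement(String.format(GET_IS_FAILED_STATEMENT, tableName))) {
      statement.setString(1, dagNodeId);
      try (ResultSet rs = statement.executeQuery()) {
        // A missing row is treated as "not failed" here; real callers may prefer Optional.
        return rs.next() && rs.getBoolean(1);
      }
    }
  }
}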

@Override
@@ -239,7 +231,7 @@ public Optional<Dag.DagNode<JobExecutionPlan>> getDagNode(DagNodeId dagNodeId) t
getStatement.setString(1, dagNodeId.toString());
try (ResultSet rs = getStatement.executeQuery()) {
if (rs.next()) {
return Optional.of(new Dag.DagNode<>(this.serDe.deserialize(rs.getString(1)).get(0), rs.getBoolean(2)));
return Optional.of(new Dag.DagNode<>(this.serDe.deserialize(rs.getString(1)).get(0)));
}
return Optional.empty();
} catch (SQLException e) {
@@ -17,10 +17,18 @@

package org.apache.gobblin.service.modules.orchestration;

import java.io.IOException;
import java.net.URI;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import lombok.extern.slf4j.Slf4j;

import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
@@ -36,17 +44,22 @@
import org.apache.gobblin.service.ExecutionStatus;
import org.apache.gobblin.service.modules.flowgraph.Dag;
import org.apache.gobblin.service.modules.spec.JobExecutionPlan;

import org.apache.gobblin.broker.SharedResourcesBrokerFactory;
import org.apache.gobblin.metastore.MysqlDataSourceFactory;
import org.apache.gobblin.util.DBStatementExecutor;

/**
* Mainly testing functionalities related to DagStateStore but not Mysql-related components.
*/
@Slf4j
public class MysqlDagStateStoreWithDagNodesTest {

private DagStateStore dagStateStore;

private static final String TEST_USER = "testUser";
private static ITestMetastoreDatabase testDb;
private DBStatementExecutor dbStatementExecutor;
private static final String GET_DAG_NODES_STATEMENT = "SELECT dag_node, is_failed_dag FROM %s WHERE parent_dag_id = ?";
private static final String tableName = "dag_node_state_store";

@BeforeClass
public void setUp() throws Exception {
@@ -63,6 +76,8 @@ public void setUp() throws Exception {
URI specExecURI = new URI(specExecInstance);
topologySpecMap.put(specExecURI, topologySpec);
this.dagStateStore = new MysqlDagStateStoreWithDagNodes(configBuilder.build(), topologySpecMap);
dbStatementExecutor = new DBStatementExecutor(
MysqlDataSourceFactory.get(configBuilder.build(), SharedResourcesBrokerFactory.getImplicitBroker()), log);
}

@AfterClass(alwaysRun = true)
Expand All @@ -74,7 +89,7 @@ public void tearDown() throws Exception {
}

@Test
public void testAddGetAndDeleteDag() throws Exception{
public void testAddGetAndDeleteDag() throws Exception {
Dag<JobExecutionPlan> originalDag1 = DagTestUtils.buildDag("random_1", 123L);
Dag<JobExecutionPlan> originalDag2 = DagTestUtils.buildDag("random_2", 456L);
DagManager.DagId dagId1 = DagManagerUtils.generateDagId(originalDag1);
@@ -140,26 +155,53 @@ public void testAddGetAndDeleteDag() throws Exception{

@Test
public void testMarkDagAsFailed() throws Exception {
//Set up initial conditions
// Set up initial conditions
Dag<JobExecutionPlan> dag = DagTestUtils.buildDag("test_dag", 789L);
DagManager.DagId dagId = DagManagerUtils.generateDagId(dag);

this.dagStateStore.writeCheckpoint(dag);
//Check Initial State
for (Dag.DagNode<JobExecutionPlan> node : dag.getNodes()) {
Assert.assertFalse(node.isFailedDag());

// Fetch all initial states into a list
List<Boolean> initialStates = fetchDagNodeStates(dagId.toString());

// Check Initial State
for (Boolean state : initialStates) {
Assert.assertFalse(state);
}
// Set the DAG as failed
dag.setFailedDag(true);
this.dagStateStore.writeCheckpoint(dag);

Dag<JobExecutionPlan> updatedDag = this.dagStateStore.getDag(dagId);
for (Dag.DagNode<JobExecutionPlan> node : updatedDag.getNodes()) {
Assert.assertTrue(node.isFailedDag());
}
// Fetch all states after marking the DAG as failed
List<Boolean> failedStates = fetchDagNodeStates(dagId.toString());

// Cleanup
// Check if all states are now true (indicating failure)
for (Boolean state : failedStates) {
Assert.assertTrue(state);
}
dagStateStore.cleanUp(dagId);
Assert.assertNull(this.dagStateStore.getDag(dagId));
}

private List<Boolean> fetchDagNodeStates(String dagId) throws IOException {
List<Boolean> states = new ArrayList<>();

dbStatementExecutor.withPreparedStatement(String.format(GET_DAG_NODES_STATEMENT, tableName), getStatement -> {

getStatement.setString(1, dagId.toString());

HashSet<Dag.DagNode<JobExecutionPlan>> dagNodes = new HashSet<>();

try (ResultSet rs = getStatement.executeQuery()) {
while (rs.next()) {
states.add(rs.getBoolean(2));
}
return dagNodes;
} catch (SQLException e) {
throw new IOException(String.format("Failure get dag nodes for dag %s", dagId), e);
}
}, true);

return states;
}
}
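Taken together with the assertions above, the contract this test pins down is: setting the in-memory Dag as failed and re-checkpointing it flips the is_failed_dag column for every node row, and verification reads the column directly because the flag no longer round-trips through Dag.DagNode. A compressed sketch of that flow, assuming the fields and helpers of this test class; the method name is hypothetical:

// Hedged sketch (not part of this commit): the same flow as testMarkDagAsFailed,
// relying on this test class's dagStateStore field and fetchDagNodeStates helper.
@Test
public void sketchFailedFlagRoundTrip() throws Exception {
  Dag<JobExecutionPlan> dag = DagTestUtils.buildDag("sketch_dag", 1L);
  DagManager.DagId dagId = DagManagerUtils.generateDagId(dag);

  this.dagStateStore.writeCheckpoint(dag);   // initial rows: is_failed_dag defaults to 0
  dag.setFailedDag(true);                    // mark the whole dag as failed in memory
  this.dagStateStore.writeCheckpoint(dag);   // per testMarkDagAsFailed, node rows now carry is_failed_dag = 1

  // The flag is not exposed via getDagNodes any more, so assert on the column itself.
  Assert.assertTrue(fetchDagNodeStates(dagId.toString()).stream().allMatch(Boolean::booleanValue));

  this.dagStateStore.cleanUp(dagId);
}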
