Properly designate model state for actively training models when nodes crash or leave cluster #1317

Merged on Dec 12, 2023 (40 commits)
Changes from 20 commits
d9269b3
Initial implementation
ryanbogan Nov 6, 2023
945a4da
Fix compile errors for tests
ryanbogan Nov 9, 2023
b2fc712
Temporary tests
ryanbogan Nov 13, 2023
9e21f07
Ensure backwards compatibility and add zombie to model state enum
ryanbogan Nov 17, 2023
ad09839
Update current tests
ryanbogan Nov 17, 2023
0537111
Fix current integration tests
ryanbogan Nov 20, 2023
c25075f
Fix unit tests with new changes
ryanbogan Nov 20, 2023
a28ad42
Add unit tests
ryanbogan Nov 20, 2023
3f31741
Fix spotless
ryanbogan Nov 20, 2023
85ed0bf
Add changelog entry
ryanbogan Nov 20, 2023
f464d2e
Delete temporary test file
ryanbogan Nov 20, 2023
ba7d5f2
Remove temporary changes to build.gradle
ryanbogan Nov 20, 2023
91778e1
Add more backwards compatibility
ryanbogan Nov 21, 2023
de2c3aa
Attempt to fix bwc tests
ryanbogan Nov 21, 2023
14aa761
Fix spotless
ryanbogan Nov 21, 2023
62d0082
Remove star imports
ryanbogan Nov 21, 2023
47a3800
Add another unit test
ryanbogan Nov 21, 2023
c15dc9a
Modify unit test to increase coverage
ryanbogan Nov 21, 2023
c7e0dcf
Change unit test to increase coverage
ryanbogan Nov 21, 2023
25eab9c
Merge branch 'main' into model_stuck_train_state
ryanbogan Nov 21, 2023
66e787b
Add method description for clusterChanged
ryanbogan Nov 22, 2023
257623a
Address PR feedback
ryanbogan Nov 28, 2023
85635c3
Refactor into TrainingJobClusterStateListener
ryanbogan Nov 28, 2023
bbd3b47
Make node assignment final and added in the constructor of TrainingJob
ryanbogan Nov 29, 2023
ea73a16
Remove clusterService from TrainingJobRunner
ryanbogan Nov 29, 2023
012a76e
Address PR Feedback
ryanbogan Dec 1, 2023
613a28e
Add flag when node rejoins and check when serializing model
ryanbogan Dec 1, 2023
ac5df23
Address PR feedback
ryanbogan Dec 1, 2023
91ea4df
Merge branch 'main' into model_stuck_train_state
ryanbogan Dec 1, 2023
c7b6281
Address PR Feedback
ryanbogan Dec 4, 2023
bf20c77
Fix spotless
ryanbogan Dec 4, 2023
4148b28
Test new version check for StreamInput
ryanbogan Dec 5, 2023
c1bdac9
Remove check to test new method
ryanbogan Dec 5, 2023
fdabe9f
Add version check for stream input/output logic
ryanbogan Dec 5, 2023
3eb2375
Address PR Feedback
ryanbogan Dec 6, 2023
80574a2
Address PR Feedback
ryanbogan Dec 7, 2023
6f1a064
Address PR Feedback
ryanbogan Dec 7, 2023
bf4407d
Address PR Feedback
ryanbogan Dec 7, 2023
586797f
Address PR Feedback
ryanbogan Dec 7, 2023
b6b85a9
Merge branch main into model_stuck_train_state
ryanbogan Dec 7, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.11...2.x)
### Features
* Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
### Enhancements
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
@@ -257,6 +257,6 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep
}

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "");
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", "");
}
}
@@ -45,6 +45,7 @@ public class KNNConstants {
public static final String MODEL_TIMESTAMP = "timestamp";
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "node_assignment";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
@@ -35,9 +35,11 @@
public class IndexUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put("model_node_assignment", MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
}
};

1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
@@ -287,6 +287,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp());
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());
}
};

49 changes: 40 additions & 9 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
@@ -19,6 +19,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

@@ -34,6 +35,7 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT;

public class ModelMetadata implements Writeable, ToXContentObject {

@@ -47,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject {
final private String timestamp;
final private String description;
private String error;
private String nodeAssignment;

/**
* Constructor
@@ -64,6 +67,7 @@ public ModelMetadata(StreamInput in) throws IOException {
// which is checked in constructor and setters
this.description = in.readString();
this.error = in.readString();
this.nodeAssignment = in.readOptionalString();
}

/**
@@ -84,7 +88,8 @@ public ModelMetadata(
ModelState modelState,
String timestamp,
String description,
String error
String error,
String nodeAssignment
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
@@ -104,6 +109,7 @@ public ModelMetadata(
this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null");
this.description = Objects.requireNonNull(description, "description must not be null");
this.error = Objects.requireNonNull(error, "error must not be null");
this.nodeAssignment = nodeAssignment;
}

/**
@@ -169,6 +175,19 @@ public String getError() {
return error;
}

/**
* getter for model's node assignment
*
* @return nodeAssignment
*/
public String getNodeAssignment() {
return nodeAssignment;
}

public void setNodeAssignment(String nodeAssignment) {
this.nodeAssignment = nodeAssignment;
}

/**
* setter for model's state
*
@@ -197,7 +216,8 @@ public String toString() {
getState().toString(),
timestamp,
description,
error
error,
nodeAssignment
);
}

@@ -240,10 +260,10 @@ public int hashCode() {
public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);

if (modelMetadataArray.length != 7) {
if (modelMetadataArray.length != 8 && modelMetadataArray.length != 7) {
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\"."
);
}

@@ -254,8 +274,11 @@ public static ModelMetadata fromString(String modelMetadataString) {
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];

return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error);
if (modelMetadataArray.length == 8) {
String nodeAssignment = modelMetadataArray[7];
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, nodeAssignment);
}
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, "");
}
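
The 7-or-8 field handling in `fromString` can be illustrated in isolation. The following is a minimal sketch, not the plugin's code: the class `MetadataFields`, its `parse` method, and the `","` delimiter are assumptions made for illustration (the real `DELIMITER` constant is not shown in this diff). It pads a legacy 7-field string with an empty node assignment so callers always see 8 fields.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the field-count handling in ModelMetadata.fromString.
final class MetadataFields {
    // Assumption: the delimiter is a comma, as the error message format suggests.
    static final String DELIMITER = ",";

    // Splits a serialized metadata string; a legacy 7-field value gets an empty node assignment appended.
    static List<String> parse(String serialized) {
        // A limit of -1 keeps trailing empty fields such as an empty error string.
        String[] parts = serialized.split(DELIMITER, -1);
        if (parts.length != 7 && parts.length != 8) {
            throw new IllegalArgumentException("Expected 7 or 8 fields but found " + parts.length);
        }
        if (parts.length == 8) {
            return Arrays.asList(parts);
        }
        String[] padded = Arrays.copyOf(parts, 8);
        padded[7] = ""; // default node assignment for metadata written before the field existed
        return Arrays.asList(padded);
    }

    public static void main(String[] args) {
        System.out.println(parse("faiss,l2,4,created,2021-03-27,test model,"));
        System.out.println(parse("faiss,l2,4,training,2021-03-27,test model,,node-1"));
    }
}
```

Accepting both lengths keeps metadata strings written by older nodes readable, at the cost of treating a missing node assignment and an empty one as the same value.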

private static String objectToString(Object value) {
@@ -278,10 +301,11 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object engine = modelSourceMap.get(KNNConstants.KNN_ENGINE);
Object space = modelSourceMap.get(KNNConstants.METHOD_PARAMETER_SPACE_TYPE);
Object dimension = modelSourceMap.get(KNNConstants.DIMENSION);
Object state = modelSourceMap.get(KNNConstants.MODEL_STATE);
Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP);
Object state = modelSourceMap.get(MODEL_STATE);
Object timestamp = modelSourceMap.get(MODEL_TIMESTAMP);
Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);
Object nodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
@@ -290,7 +314,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
ModelState.getModelState(objectToString(state)),
objectToString(timestamp),
objectToString(description),
objectToString(error)
objectToString(error),
objectToString(nodeAssignment)
);
return modelMetadata;
}
@@ -304,6 +329,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(getTimestamp());
out.writeString(getDescription());
out.writeString(getError());
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion("model_node_assignment")) {
out.writeOptionalString(getNodeAssignment());
}
}
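
A version-gated optional field only round-trips if the reader applies the same gate and the same encoding as the writer. The sketch below illustrates that symmetry under stated assumptions: it uses plain `java.io` data streams rather than OpenSearch's `StreamOutput`/`StreamInput`, and the class `OptionalFieldCodec` and the `peerSupportsNodeAssignment` flag are hypothetical names standing in for the cluster version check.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical sketch; plain java.io streams stand in for StreamOutput and StreamInput.
final class OptionalFieldCodec {

    // Writes the error string, then the node assignment only when the peer understands the field.
    static byte[] write(String error, String nodeAssignment, boolean peerSupportsNodeAssignment) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(error);
            if (peerSupportsNodeAssignment) {
                out.writeBoolean(nodeAssignment != null); // presence flag, so null is representable
                if (nodeAssignment != null) {
                    out.writeUTF(nodeAssignment);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads with the same gate and encoding; returns {error, nodeAssignment}.
    // An absent or null node assignment is returned as the empty string.
    static String[] read(byte[] payload, boolean peerSupportsNodeAssignment) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            String error = in.readUTF();
            String nodeAssignment = "";
            if (peerSupportsNodeAssignment && in.readBoolean()) {
                nodeAssignment = in.readUTF();
            }
            return new String[] { error, nodeAssignment };
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        String[] fields = read(write("", "node-1", true), true);
        System.out.println("error='" + fields[0] + "', nodeAssignment='" + fields[1] + "'");
    }
}
```

If the writer skips the field but the reader still tries to read it (or the reverse), the stream is misaligned for every field that follows, which is why both sides take the same flag here.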

@Override
@@ -316,6 +344,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue());
builder.field(DIMENSION, getDimension());
builder.field(KNN_ENGINE, getKnnEngine().getName());
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion("model_node_assignment")) {
builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment());
}
return builder;
}
}
7 changes: 6 additions & 1 deletion src/main/java/org/opensearch/knn/indices/ModelState.java
@@ -23,7 +23,8 @@
public enum ModelState implements Writeable {
TRAINING("training"),
CREATED("created"),
FAILED("failed");
FAILED("failed"),
ZOMBIE("zombie");

private final String name;

@@ -79,6 +80,10 @@ public static ModelState getModelState(String name) {
return FAILED;
}

if (ZOMBIE.getName().equals(name)) {
return ZOMBIE;
}

throw new IllegalArgumentException("Unable to find model state: \"" + name + "\"");
}
}
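
For reference, the name lookup that `getModelState` performs, now including `zombie`, can be sketched as a standalone enum. This is an illustrative alternative, not the plugin's implementation: `SketchModelState` and `fromName` are hypothetical names, the `Writeable` serialization is omitted, and the lookup iterates over `values()` instead of using one `if` per constant.

```java
import java.util.Arrays;

// Hypothetical mirror of the ModelState enum; serialization is omitted.
enum SketchModelState {
    TRAINING("training"),
    CREATED("created"),
    FAILED("failed"),
    ZOMBIE("zombie");

    private final String name;

    SketchModelState(String name) {
        this.name = name;
    }

    String getName() {
        return name;
    }

    // Resolves a state from its serialized name, rejecting unknown names.
    static SketchModelState fromName(String name) {
        return Arrays.stream(values())
            .filter(state -> state.getName().equals(name))
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException("Unable to find model state: \"" + name + "\""));
    }

    public static void main(String[] args) {
        System.out.println(fromName("zombie"));
    }
}
```

Iterating over `values()` means a newly added constant is resolvable without touching the lookup, whereas the explicit chain in the diff needs one extra branch per state.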
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
@@ -199,7 +199,7 @@ public Collection<Object> createComponents(
KNNClusterUtil.instance().initialize(clusterService);
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
@@ -81,6 +81,7 @@ public TrainingJob(
ModelState.TRAINING,
ZonedDateTime.now(ZoneOffset.UTC).toString(),
description,
"",
""
),
null,