This repository has been archived by the owner on Jan 21, 2024. It is now read-only.

Merge pull request #6 from SalesforceLabs/guspatch
Performance improvements and more customization
iskander-m authored Mar 31, 2021
2 parents 3e5bdb3 + b392f1f commit aa160b8
Showing 31 changed files with 914 additions and 118 deletions.
36 changes: 34 additions & 2 deletions force-app/main/algorithms/classes/ClusterAlgorithmRunner.cls
@@ -17,11 +17,21 @@ public abstract with sharing class ClusterAlgorithmRunner {
private static ClusterJobState jobState; //TODO: get rid of this or store somewhere else
public List<ClusterAlgorithmStep> steps;
public ClusterDistanceCacheManager distanceCacheManager;
private ClusterRecordPreprocessor recordPreprocessor;

public ClusterAlgorithmRunner() {
this.steps = new List<ClusterAlgorithmStep>();
this.objectFactory = new ClusterObjectFactory();
}

protected virtual void initRecordPreprocessor() {
this.recordPreprocessor = this.objectFactory.createRecordPreprocessor();
}

public void setRecordPreprocessor(ClusterRecordPreprocessor recordPreprocessor) {
this.recordPreprocessor = recordPreprocessor;
}

public abstract void initializeDistanceCache();

public virtual ClusterAlgorithmRunner.ModelValidationResult validateModel(ClusterModelWrapper model) {
@@ -74,6 +84,7 @@ public abstract with sharing class ClusterAlgorithmRunner {
}

public virtual void init(ClusterModelWrapper model) {
this.initRecordPreprocessor();
Boolean hasJobOutput = false;
for (Integer i=0; i<model.fields.size(); i++) {
if (model.fields[i].distanceType == ClusterConstants.FIELDTYPE_OUTPUT) {
@@ -88,7 +99,7 @@
}
public abstract ClusterJobState getJobState();
public abstract ClusterJobState createJobState();
private ClusterObjectFactory objectFactory{ get; set; }
public ClusterObjectFactory objectFactory{ get; set; }

public virtual void start() {
try {
@@ -225,8 +236,29 @@ public abstract with sharing class ClusterAlgorithmRunner {
if (predictRecords.size() != 1) {
throw new ClusterException('Record with id ' + externalRecordId + ' was not found');
}
return this.getDataPoint(predictRecords.get(0));
}

public ClusterDataPoint getDataPoint(SObject record) {
if (this.recordPreprocessor != null) {
List<SObject> records = new List<SObject>();
records.add(record);
this.preprocessSObjects(records, false);
}
ClusterSObjectProcessor objectProcessor = this.getSObjectProcessor();
return objectProcessor.processSObject(predictRecords.get(0));
return objectProcessor.processSObject(record);
}

public void preprocessSObjects(List<SObject> records, boolean isLearning) {
if (this.recordPreprocessor != null) {
ClusterRecordPreprocessorParameters parameters = new ClusterRecordPreprocessorParameters();
ClusterJobState state = this.getJobState();
parameters.setModelId(state.model.modelId);
parameters.setModelName(state.model.name);
parameters.setJobId(state.clusterJob.Id);
parameters.setIsLearning(isLearning);
this.recordPreprocessor.processRecords(records, parameters);
}
}

public abstract Double calculateDistance(Object[] currentObject, Object[] centroid);
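The two new public members above form an injection seam: a preprocessor can be handed to the runner directly instead of being resolved from configuration. A minimal sketch, assuming "runner" is an already-initialized concrete ClusterAlgorithmRunner and MyPreprocessor is a hypothetical ClusterRecordPreprocessor implementation:

// Both names are assumptions: "runner" is an initialized concrete ClusterAlgorithmRunner,
// and MyPreprocessor is a hypothetical ClusterRecordPreprocessor implementation.
runner.setRecordPreprocessor(new MyPreprocessor());

// getDataPoint() now routes single records through the same preprocessing hook
// that ClusterPrepareDataStep uses for whole batch scopes (isLearning = false here).
Account acc = [SELECT Id, Name, Description FROM Account LIMIT 1];
ClusterDataPoint dp = runner.getDataPoint(acc);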
150 changes: 93 additions & 57 deletions force-app/main/algorithms/classes/ClusterKNNPredictor.cls

Large diffs are not rendered by default.

@@ -12,6 +12,7 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
Map<String,ClusterWordPreprocessor> wordPreprocessors;
Boolean useCompression;
private Integer minWordCount;
private Integer maxWordBagSize;

public ClusterLongTextFieldValueProcessor(ClusterJobState state) {
this.jobState = state;
@@ -30,6 +31,7 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
if (this.minWordCount == null) {
this.minWordCount = ClusterConstants.getMinTfIdfWordCount();
}
this.maxWordBagSize = ClusterConstants.getTFIDFWordBagSize();
}

public void setCompression(Boolean compression) {
@@ -74,10 +76,10 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
}
String language = fieldDesc.language != null ? fieldDesc.language.toLowerCase() : ClusterConstants.LANGUAGE_NONE;
if (this.useCompression) {
return calculateCompressedTF(text, wordMap, wordList, this.wordPreprocessors.get(language), this.minWordCount);
return calculateCompressedTF(text, wordMap, wordList, this.wordPreprocessors.get(language), this.minWordCount, this.maxWordBagSize);
}
else {
Double[] tf = calculateTF(text, wordMap, wordList, this.wordPreprocessors.get(language), this.minWordCount);
Double[] tf = calculateTF(text, wordMap, wordList, this.wordPreprocessors.get(language), this.minWordCount, this.maxWordBagSize);
return tf;
}
}
@@ -86,27 +88,33 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
return text==null ? null : text.split(WORD_SPLIT_REGEX);
}

private static Integer prepareWordList(String text, Map<String, Integer> wordMap, List<String> wordList, Map<String, Integer> currentWordMap, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount) {
if (text == '') {
return null;
private static Integer prepareWordList(String text, Map<String, Integer> wordMap, List<String> wordList, Map<String, Integer> currentWordMap, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount, Integer maxWordBagSize) {
if (text.isWhitespace()) {
return 0;
}
//Removing html tags and breaking into words
String[] words = splitText(text.stripHtmlTags());
String strippedText = text.stripHtmlTags();
if (strippedText.isWhitespace()) {
strippedText = text;
}
String[] words = splitText(strippedText);

//Unfiltered word map
Map<String, Integer> tempWordMap = new Map<String, Integer>();
Integer wordsSize = words.size();

for (Integer i = 0; i < words.size(); i++) {
for (Integer i = 0; i < wordsSize; i++) {
String currWord = words[i];
//Skip empty and single character words
if (words[i].length() < 2) {
if (currWord.length() < 2) {
continue;
}
String token = wordPreprocessor != null ? wordPreprocessor.preprocess(words[i]) : words[i];
String token = wordPreprocessor != null ? wordPreprocessor.preprocess(currWord) : currWord;
//Skip words that were filtered out by the preprocessor
if (token == null) {
continue;
}
addWordToMap(token, tempWordMap, 1, null);
addWordToMap(token, tempWordMap, 1, null, maxWordBagSize);
}

//Adding words and updating counts in aggregated structures
@@ -119,7 +127,7 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
if (currentTokenCount >= currMinWordCount) {
currentWordMap.put(currentToken, currentTokenCount);
numTokens += currentTokenCount;
addWordToMap(currentToken, wordMap, 1, wordList); // Adding 1 here because for IDF we need to calculate the number of documents containing this term
addWordToMap(currentToken, wordMap, 1, wordList, maxWordBagSize); // Adding 1 here because for IDF we need to calculate the number of documents containing this term
}
}
if (numTokens == 0) {
@@ -135,11 +143,11 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
return numTokens;
}

public static Double[] calculateTF(String text, Map<String, Integer> wordMap, List<String> wordList, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount) {
public static Double[] calculateTF(String text, Map<String, Integer> wordMap, List<String> wordList, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount, Integer maxWordBagSize) {
//This will contain word counts for the current document
Map<String, Integer> currentWordMap = new Map<String,Integer>();

Integer numTokens = prepareWordList(text, wordMap, wordList, currentWordMap, wordPreprocessor, minWordCount);
Integer numTokens = prepareWordList(text, wordMap, wordList, currentWordMap, wordPreprocessor, minWordCount, maxWordBagSize);
//Calculating tf for the text
Double[] tf = new Double[wordList.size()];
for (Integer i=0; i<wordList.size(); i++) {
Expand All @@ -155,14 +163,15 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
return tf;
}

public static ClusterCompressedDoubleArray calculateCompressedTF(String text, Map<String, Integer> wordMap, List<String> wordList, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount) {
public static ClusterCompressedDoubleArray calculateCompressedTF(String text, Map<String, Integer> wordMap, List<String> wordList, ClusterWordPreprocessor wordPreprocessor, Integer minWordCount, Integer maxWordBagSize) {
//This will contain word counts for the current document
Map<String, Integer> currentWordMap = new Map<String,Integer>();

Integer numTokens = prepareWordList(text, wordMap, wordList, currentWordMap, wordPreprocessor, minWordCount);
Integer numTokens = prepareWordList(text, wordMap, wordList, currentWordMap, wordPreprocessor, minWordCount, maxWordBagSize);
//Calculating tf for the text
ClusterCompressedDoubleArray tf = new ClusterCompressedDoubleArray();
for (Integer i=0; i<wordList.size(); i++) {
Integer wordListSize = wordList.size();
for (Integer i=0; i<wordListSize; i++) {
String currentToken = wordList.get(i);
Integer wordCount = currentWordMap.get(currentToken);
if (wordCount != null && numTokens > 0) {
@@ -175,10 +184,10 @@ public with sharing class ClusterLongTextFieldValueProcessor implements ClusterF
return tf;
}

private static void addWordToMap(String word, Map<String, Integer> wordMap, Integer count, List<String> wordList) {
private static void addWordToMap(String word, Map<String, Integer> wordMap, Integer count, List<String> wordList, Integer maxWordBagSize) {
Integer currentCount = wordMap.get(word);
if (currentCount == null) {
if ((wordList == null) || (wordList.size() < ClusterConstants.MAX_TFIDF_WORDBAG_SIZE)) {
if ((wordList == null) || (wordList.size() < maxWordBagSize)) {
wordMap.put(word, count);
//Also adding new word to the list
if (wordList != null) {
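The word-bag cap that used to be read from the ClusterConstants.MAX_TFIDF_WORDBAG_SIZE constant is now threaded through as the maxWordBagSize parameter, so callers of the static helpers choose it explicitly. A minimal anonymous-Apex sketch against the new calculateTF signature (the sample text and the cap value are made up; passing null skips word preprocessing):

Map<String, Integer> wordMap = new Map<String, Integer>(); // document frequencies shared across records
List<String> wordList = new List<String>();                // the word bag, capped at maxWordBagSize entries
Double[] tf = ClusterLongTextFieldValueProcessor.calculateTF(
    'Clustering groups similar records together',
    wordMap,
    wordList,
    null,   // no ClusterWordPreprocessor: tokens are used as-is
    1,      // minWordCount: keep tokens that occur at least once
    1000);  // maxWordBagSize: hard cap on distinct words (value assumed)
System.debug(wordList.size() + ' words in the bag, ' + tf.size() + ' term frequencies');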
33 changes: 32 additions & 1 deletion force-app/main/algorithms/classes/ClusterObjectFactory.cls
@@ -3,7 +3,7 @@
*
* @author: Iskander Mukhamedgaliyev
*/
public with sharing class ClusterObjectFactory {
public virtual with sharing class ClusterObjectFactory {
public virtual ClusterSObjectProcessor createSObjectProcessor(ClusterJobState state) {
return this.createSObjectProcessor(state, true);
}
@@ -18,4 +18,35 @@ public with sharing class ClusterObjectFactory {
return new ClusterKNNPredictor(runner);
}

public virtual ClusterRecordPreprocessor createRecordPreprocessor() {
String customPreprocessorApexClassName = ClusterConstants.getApexRecordPreprocessorClassName();
return this.createRecordPreprocessor(customPreprocessorApexClassName);
}

public virtual ClusterRecordPreprocessor createRecordPreprocessor(String customPreprocessorApexClassName) {
ClusterRecordPreprocessor recordPreprocessor;
if (String.isNotBlank(customPreprocessorApexClassName)) {
String[] classNameSplit = customPreprocessorApexClassName.split('\\.');
String namespace;
String className;
if (classNameSplit.size() >= 2) {
namespace = classNameSplit[0];
className = classNameSplit[1];
if (classNameSplit.size() == 3) {
className += '.' + classNameSplit[2];
}
}
else {
namespace = '';
className = customPreprocessorApexClassName;
}
Type customApexPreprocessorType = Type.forName(namespace, className);
recordPreprocessor = (ClusterRecordPreprocessor)customApexPreprocessorType.newInstance();
}
else {
recordPreprocessor = null;
}
return recordPreprocessor;
}

}
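Making the factory virtual gives a second extension point besides the configured Apex class name: a subclass can construct its preprocessor directly. A sketch under that assumption (the subclass and MyPreprocessor are invented names):

// Hypothetical factory subclass; MyPreprocessor is assumed to implement ClusterRecordPreprocessor.
public with sharing class MyClusterObjectFactory extends ClusterObjectFactory {
    public override ClusterRecordPreprocessor createRecordPreprocessor() {
        // Bypass ClusterConstants.getApexRecordPreprocessorClassName() and hard-wire the preprocessor
        return new MyPreprocessor();
    }
}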
3 changes: 3 additions & 0 deletions force-app/main/algorithms/classes/ClusterPrepareDataStep.cls
@@ -40,6 +40,9 @@ public with sharing class ClusterPrepareDataStep extends ClusterBatchBase implem
//Somebody could provide a SOQL query from an object which is different from the model object
//This is not allowed
ClusterAccessCheck.checkSObjectReadPermission(scope.get(0), model.objectName);

//Preprocessing records if custom preprocessor is specified
this.runner.preprocessSObjects(scope, true);
}
for (Integer sindex = 0; sindex < scopeSize; sindex++){
SObject record = scope[sindex];
@@ -0,0 +1,8 @@
/*
* Allows implementing custom Apex record preprocessors
*
* @author: Iskander Mukhamedgaliyev
*/
global interface ClusterRecordPreprocessor {
void processRecords(List<SObject> records, ClusterRecordPreprocessorParameters parameters);
}
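A minimal hypothetical implementation of the new global interface (the class name, field, and logic are invented): it normalizes a long-text field in memory before the field processors tokenize it, and only during the learning pass:

// Hypothetical implementation; the class and field names are examples only.
global with sharing class CaseDescriptionPreprocessor implements ClusterRecordPreprocessor {
    global void processRecords(List<SObject> records, ClusterRecordPreprocessorParameters parameters) {
        // Only touch records while the clustering job is learning, not during predictions
        if (parameters.getIsLearning() != true) {
            return;
        }
        for (SObject record : records) {
            String description = (String)record.get('Description');
            if (description != null) {
                record.put('Description', description.trim().toLowerCase());
            }
        }
    }
}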
@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<ApexClass xmlns="http://soap.sforce.com/2006/04/metadata">
<apiVersion>50.0</apiVersion>
<status>Active</status>
</ApexClass>
@@ -0,0 +1,48 @@
/*
* Parameters for ClusterRecordPreprocessor
*
* @author: Iskander Mukhamedgaliyev
*/
global inherited sharing class ClusterRecordPreprocessorParameters {
private Id modelId;
private String modelName;
private Id jobId;
private Boolean isLearning;

public ClusterRecordPreprocessorParameters() {

}

global Id getModelId() {
return this.modelId;
}

public void setModelId(Id modelId) {
this.modelId = modelId;
}

global String getModelName() {
return this.modelName;
}

public void setModelName(String modelName) {
this.modelName = modelName;
}

global Id getJobId() {
return this.jobId;
}

public void setJobId(Id jobId) {
this.jobId = jobId;
}

global Boolean getIsLearning() {
return this.isLearning;
}

public void setIsLearning(Boolean isLearning) {
this.isLearning = isLearning;
}

}
@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<ApexClass xmlns="http://soap.sforce.com/2006/04/metadata">
<apiVersion>50.0</apiVersion>
<status>Active</status>
</ApexClass>
12 changes: 10 additions & 2 deletions force-app/main/api/classes/ClusterApi.cls
@@ -62,6 +62,15 @@ global with sharing class ClusterApi {
return this.runner.getPredictor().getNearestNeighborsFromDb(recordId, runner.getJobState().clusterJob.id, numNeighbors);
}

/**
* Deletes previously calculated nearest neighbors from the database
* @param recordId
* Id of the source record
*/
global void deleteNearestNeighborsFromDb(Id recordId) {
this.checkRunnerInitialized();
this.runner.getPredictor().deleteNearestNeighbors(recordId, runner.getJobState().clusterJob.id);
}

/**
* Finds (calculates) nearest neighbors and stores calculations in the db
@@ -106,8 +115,7 @@
*/
global ClusterDataPoint convertToDataPoint(SObject record) {
this.checkRunnerInitialized();
ClusterSObjectProcessor objectProcessor = this.runner.getSObjectProcessor();
return objectProcessor.processSObject(record);
return this.runner.getDataPoint(record);
}

}
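A hedged usage sketch of the two API methods touched above; how the ClusterApi instance is initialized for a model and job is outside this diff, so that step is only a placeholder comment, and the Account query is illustrative:

ClusterApi api = new ClusterApi(); // assumes the existing no-arg constructor
// ... initialize the api for a completed clustering job here (not part of this commit) ...
Account acc = [SELECT Id, Name, Description FROM Account LIMIT 1];
// convertToDataPoint now delegates to runner.getDataPoint, so a configured custom preprocessor runs too
ClusterDataPoint dataPoint = api.convertToDataPoint(acc);
// New method: removes this record's stored nearest-neighbor calculations
api.deleteNearestNeighborsFromDb(acc.Id);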

