From 6031b253677b27d9f379b9942a5a169f855654e4 Mon Sep 17 00:00:00 2001 From: Ahmed Sobhi Date: Sun, 21 Aug 2016 01:43:09 +0100 Subject: [PATCH] Make Trie Serializable - Base implementation on https://github.com/robert-bor/aho-corasick/pull/11 - Update pull request code to work on latest ahocorasick code - Omit usage of static fields in serialization code - Reformat/cleanup general code --- pom.xml | 16 +- .../org/ahocorasick/interval/Interval.java | 27 ++- .../ahocorasick/interval/IntervalNode.java | 31 ++- .../ahocorasick/interval/IntervalTree.java | 6 +- .../ahocorasick/interval/Intervalable.java | 6 +- src/main/java/org/ahocorasick/trie/Emit.java | 21 +- .../java/org/ahocorasick/trie/MatchToken.java | 2 +- src/main/java/org/ahocorasick/trie/State.java | 212 ++++++++++++++++-- src/main/java/org/ahocorasick/trie/Token.java | 2 +- src/main/java/org/ahocorasick/trie/Trie.java | 158 +++++++------ .../java/org/ahocorasick/trie/TrieConfig.java | 47 +++- .../trie/handler/DefaultEmitHandler.java | 2 +- .../ahocorasick/interval/IntervalTest.java | 10 +- .../interval/IntervalTreeTest.java | 6 +- .../IntervalableComparatorByPositionTest.java | 2 +- .../IntervalableComparatorBySizeTest.java | 4 +- .../java/org/ahocorasick/trie/StateTest.java | 21 +- .../java/org/ahocorasick/trie/TrieTest.java | 46 +++- 18 files changed, 464 insertions(+), 155 deletions(-) diff --git a/pom.xml b/pom.xml index 9d964e4..6cb2ef9 100644 --- a/pom.xml +++ b/pom.xml @@ -60,6 +60,18 @@ test + + + org.apache.commons + commons-lang3 + 3.4 + + + + com.google.guava + guava + 19.0 + @@ -78,8 +90,8 @@ org.apache.maven.plugins maven-compiler-plugin - 1.7 - 1.7 + 1.8 + 1.8 diff --git a/src/main/java/org/ahocorasick/interval/Interval.java b/src/main/java/org/ahocorasick/interval/Interval.java index c43dd7c..65f156a 100644 --- a/src/main/java/org/ahocorasick/interval/Interval.java +++ b/src/main/java/org/ahocorasick/interval/Interval.java @@ -1,6 +1,9 @@ package org.ahocorasick.interval; -public class Interval implements Intervalable { +import java.io.IOException; +import java.io.Serializable; + +public class Interval implements Intervalable, Serializable { private int start; private int end; @@ -10,21 +13,24 @@ public Interval(final int start, final int end) { this.end = end; } + @Override public int getStart() { return this.start; } + @Override public int getEnd() { return this.end; } + @Override public int size() { return end - start + 1; } public boolean overlapsWith(Interval other) { return this.start <= other.getEnd() && - this.end >= other.getStart(); + this.end >= other.getStart(); } public boolean overlapsWith(int point) { @@ -36,9 +42,9 @@ public boolean equals(Object o) { if (!(o instanceof Intervalable)) { return false; } - Intervalable other = (Intervalable)o; + Intervalable other = (Intervalable) o; return this.start == other.getStart() && - this.end == other.getEnd(); + this.end == other.getEnd(); } @Override @@ -51,7 +57,7 @@ public int compareTo(Object o) { if (!(o instanceof Intervalable)) { return -1; } - Intervalable other = (Intervalable)o; + Intervalable other = (Intervalable) o; int comparison = this.start - other.getStart(); return comparison != 0 ? comparison : this.end - other.getEnd(); } @@ -61,4 +67,15 @@ public String toString() { return this.start + ":" + this.end; } + protected void writeObject(java.io.ObjectOutputStream stream) + throws IOException { + stream.writeInt(start); + stream.writeInt(end); + } + + protected void readObject(java.io.ObjectInputStream stream) + throws IOException, ClassNotFoundException, IllegalAccessException, NoSuchFieldException { + this.start = stream.readInt(); + this.end = stream.readInt(); + } } diff --git a/src/main/java/org/ahocorasick/interval/IntervalNode.java b/src/main/java/org/ahocorasick/interval/IntervalNode.java index 11db0ae..706b41a 100644 --- a/src/main/java/org/ahocorasick/interval/IntervalNode.java +++ b/src/main/java/org/ahocorasick/interval/IntervalNode.java @@ -3,6 +3,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; public class IntervalNode { @@ -10,14 +11,14 @@ private enum Direction { LEFT, RIGHT } private IntervalNode left = null; private IntervalNode right = null; - private int point; - private List intervals = new ArrayList(); + private final int point; + private final List intervals = new ArrayList<>(); public IntervalNode(List intervals) { this.point = determineMedian(intervals); - List toLeft = new ArrayList(); - List toRight = new ArrayList(); + List toLeft = new ArrayList<>(); + List toRight = new ArrayList<>(); for (Intervalable interval : intervals) { if (interval.getEnd() < this.point) { @@ -37,7 +38,7 @@ public IntervalNode(List intervals) { } } - public int determineMedian(List intervals) { + private int determineMedian(List intervals) { int start = -1; int end = -1; for (Intervalable interval : intervals) { @@ -55,7 +56,7 @@ public int determineMedian(List intervals) { public List findOverlaps(Intervalable interval) { - List overlaps = new ArrayList(); + List overlaps = new ArrayList<>(); if (this.point < interval.getStart()) { // Tends to the right addToOverlaps(interval, overlaps, findOverlappingRanges(this.right, interval)); @@ -72,25 +73,21 @@ public List findOverlaps(Intervalable interval) { return overlaps; } - protected void addToOverlaps(Intervalable interval, List overlaps, List newOverlaps) { - for (Intervalable currentInterval : newOverlaps) { - if (!currentInterval.equals(interval)) { - overlaps.add(currentInterval); - } - } + private void addToOverlaps(Intervalable interval, List overlaps, List newOverlaps) { + overlaps.addAll(newOverlaps.stream().filter(currentInterval -> !currentInterval.equals(interval)).collect(Collectors.toList())); } - protected List checkForOverlapsToTheLeft(Intervalable interval) { + private List checkForOverlapsToTheLeft(Intervalable interval) { return checkForOverlaps(interval, Direction.LEFT); } - protected List checkForOverlapsToTheRight(Intervalable interval) { + private List checkForOverlapsToTheRight(Intervalable interval) { return checkForOverlaps(interval, Direction.RIGHT); } - protected List checkForOverlaps(Intervalable interval, Direction direction) { + private List checkForOverlaps(Intervalable interval, Direction direction) { - List overlaps = new ArrayList(); + List overlaps = new ArrayList<>(); for (Intervalable currentInterval : this.intervals) { switch (direction) { case LEFT : @@ -109,7 +106,7 @@ protected List checkForOverlaps(Intervalable interval, Direction d } - protected List findOverlappingRanges(IntervalNode node, Intervalable interval) { + private List findOverlappingRanges(IntervalNode node, Intervalable interval) { if (node != null) { return node.findOverlaps(interval); } diff --git a/src/main/java/org/ahocorasick/interval/IntervalTree.java b/src/main/java/org/ahocorasick/interval/IntervalTree.java index 40eeb0b..ca73c27 100644 --- a/src/main/java/org/ahocorasick/interval/IntervalTree.java +++ b/src/main/java/org/ahocorasick/interval/IntervalTree.java @@ -18,7 +18,7 @@ public List removeOverlaps(List intervals) { // Sort the intervals on size, then left-most position Collections.sort(intervals, new IntervalableComparatorBySize()); - Set removeIntervals = new TreeSet(); + Set removeIntervals = new TreeSet<>(); for (Intervalable interval : intervals) { // If the interval was already removed, ignore it @@ -31,9 +31,7 @@ public List removeOverlaps(List intervals) { } // Remove all intervals that were overlapping - for (Intervalable removeInterval : removeIntervals) { - intervals.remove(removeInterval); - } + removeIntervals.forEach(intervals::remove); // Sort the intervals, now on left-most position only Collections.sort(intervals, new IntervalableComparatorByPosition()); diff --git a/src/main/java/org/ahocorasick/interval/Intervalable.java b/src/main/java/org/ahocorasick/interval/Intervalable.java index 286a232..baab8a6 100644 --- a/src/main/java/org/ahocorasick/interval/Intervalable.java +++ b/src/main/java/org/ahocorasick/interval/Intervalable.java @@ -2,8 +2,8 @@ public interface Intervalable extends Comparable { - public int getStart(); - public int getEnd(); - public int size(); + int getStart(); + int getEnd(); + int size(); } diff --git a/src/main/java/org/ahocorasick/trie/Emit.java b/src/main/java/org/ahocorasick/trie/Emit.java index 60c1f9e..4617c58 100644 --- a/src/main/java/org/ahocorasick/trie/Emit.java +++ b/src/main/java/org/ahocorasick/trie/Emit.java @@ -3,7 +3,11 @@ import org.ahocorasick.interval.Interval; import org.ahocorasick.interval.Intervalable; -public class Emit extends Interval implements Intervalable { +import java.io.IOException; +import java.io.Serializable; +import java.lang.reflect.Field; + +public class Emit extends Interval implements Intervalable, Serializable { private final String keyword; @@ -21,4 +25,19 @@ public String toString() { return super.toString() + "=" + this.keyword; } + @Override + protected void writeObject(java.io.ObjectOutputStream stream) + throws IOException { + super.writeObject(stream); + stream.writeUTF(keyword); + } + + @Override + protected void readObject(java.io.ObjectInputStream stream) + throws IOException, ClassNotFoundException, IllegalAccessException, NoSuchFieldException { + Field f = this.getClass().getDeclaredField("keyword"); + super.readObject(stream); + f.setAccessible(true); + f.set(this, stream.readUTF()); + } } diff --git a/src/main/java/org/ahocorasick/trie/MatchToken.java b/src/main/java/org/ahocorasick/trie/MatchToken.java index c2615dc..9d91693 100644 --- a/src/main/java/org/ahocorasick/trie/MatchToken.java +++ b/src/main/java/org/ahocorasick/trie/MatchToken.java @@ -2,7 +2,7 @@ public class MatchToken extends Token { - private Emit emit; + private final Emit emit; public MatchToken(String fragment, Emit emit) { super(fragment); diff --git a/src/main/java/org/ahocorasick/trie/State.java b/src/main/java/org/ahocorasick/trie/State.java index 5220f72..ae1bfa4 100644 --- a/src/main/java/org/ahocorasick/trie/State.java +++ b/src/main/java/org/ahocorasick/trie/State.java @@ -1,53 +1,67 @@ package org.ahocorasick.trie; +import com.google.common.base.Objects; + +import java.io.*; +import java.lang.reflect.Field; import java.util.*; +import static com.google.common.base.Preconditions.checkNotNull; + /** *

- * A state has various important tasks it must attend to: + * A state has various important tasks it must attend to: *

- * + *

*

    - *
  • success; when a character points to another state, it must return that state
  • - *
  • failure; when a character has no matching state, the algorithm must be able to fall back on a - * state with less depth
  • - *
  • emits; when this state is passed and keywords have been matched, the matches must be - * 'emitted' so that they can be used later on.
  • + *
  • success; when a character points to another state, it must return that state
  • + *
  • failure; when a character has no matching state, the algorithm must be able to fall back on a + * state with less depth
  • + *
  • emits; when this state is passed and keywords have been matched, the matches must be + * 'emitted' so that they can be used later on.
  • *
- * *

- * The root state is special in the sense that it has no failure state; it cannot fail. If it 'fails' - * it will still parse the next character and start from the root node. This ensures that the algorithm - * always runs. All other states always have a fail state. + *

+ * The root state is special in the sense that it has no failure state; it cannot fail. If it 'fails' + * it will still parse the next character and start from the root node. This ensures that the algorithm + * always runs. All other states always have a fail state. *

* * @author Robert Bor */ -public class State { +public class State implements Serializable { - /** effective the size of the keyword */ + /** + * effective the size of the keyword + */ private final int depth; - /** only used for the root state to refer to itself in case no matches have been found */ + /** + * only used for the root state to refer to itself in case no matches have been found + */ private final State rootState; /** * referred to in the white paper as the 'goto' structure. From a state it is possible to go * to other states, depending on the character passed. */ - private Map success = new HashMap(); + private Map success = new TreeMap<>(); - /** if no matching states are found, the failure state will be returned */ + /** + * if no matching states are found, the failure state will be returned + */ private State failure = null; - /** whenever this state is reached, it will emit the matches keywords for future reference */ + /** + * whenever this state is reached, it will emit the matches keywords for future reference + */ private Set emits = null; public State() { this(0); } - public State(int depth) { + private State(int depth) { this.depth = depth; this.rootState = depth == 0 ? this : null; } @@ -64,14 +78,14 @@ public State nextState(Character character) { return nextState(character, false); } - public State nextStateIgnoreRootState(Character character) { + private State nextStateIgnoreRootState(Character character) { return nextState(character, true); } public State addState(Character character) { State nextState = nextStateIgnoreRootState(character); if (nextState == null) { - nextState = new State(this.depth+1); + nextState = new State(this.depth + 1); this.success.put(character, nextState); } return nextState; @@ -89,13 +103,11 @@ public void addEmit(String keyword) { } public void addEmit(Collection emits) { - for (String emit : emits) { - addEmit(emit); - } + emits.forEach(this::addEmit); } public Collection emit() { - return this.emits == null ? Collections. emptyList() : this.emits; + return this.emits == null ? Collections.emptyList() : this.emits; } public State failure() { @@ -114,4 +126,156 @@ public Collection getTransitions() { return this.success.keySet(); } + private void writeObject(StateObjectOutputStream stream) + throws IOException { + stream.writeInt(this.depth); + stream.writeInt(this.success.size()); + for (Map.Entry e : this.success.entrySet()) { + stream.writeObject(e.getKey()); + + Integer reference = stream.objectToReference.get(e.getValue()); + if (reference == null) { + stream.objectToReference.put(e.getValue(), stream.incrementAndGetReferenceCount()); + stream.writeInt(0); + stream.writeInt(stream.getReferenceCount()); + stream.writeObject(e.getValue()); + } else { + stream.writeInt(reference); + } + } + stream.writeObject(this.emits); + } + + private void writeObject(ObjectOutputStream stream) + throws IOException { + if (stream instanceof StateObjectOutputStream) { + writeObject((StateObjectOutputStream) stream); + } else { + // this is the root state + IdentityHashMap objectToReference = new IdentityHashMap<>(); + writeObject(new StateObjectOutputStream(stream, objectToReference, 1)); + } + } + + private void readObject(StateObjectInputStream stream) + throws IOException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException { + + // Use reflection to modify final field + Field f = this.getClass().getDeclaredField("depth"); + f.setAccessible(true); + f.set(this, checkNotNull(stream.readInt())); + + f = this.getClass().getDeclaredField("rootState"); + f.setAccessible(true); + f.set(this, (depth == 0) ? this : null); + int successSize = checkNotNull(stream.readInt()); + success = new TreeMap<>(); + for (int i = 0; i < successSize; i++) { + Character character = checkNotNull((Character) stream.readObject()); + Integer reference = checkNotNull(stream.readInt()); + State treeState; + if (reference == 0) { + Integer referenceID = checkNotNull(stream.readInt()); + treeState = checkNotNull((State) stream.readObject()); + stream.getReferenceToObject().put(referenceID, treeState); + } else { + try { + treeState = checkNotNull((State) stream.getReferenceToObject().get(reference)); + } catch (NullPointerException e) { + throw new RuntimeException("reference=" + reference + ", " + stream.getReferenceToObject().size(), e); + } + } + success.put(character, treeState); + } + emits = (TreeSet) stream.readObject(); + } + + private void readObject(ObjectInputStream stream) + throws IOException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException { + if (stream instanceof StateObjectInputStream) { + readObject((StateObjectInputStream) stream); + } else { + // this is the root state + IdentityHashMap referenceToObject = new IdentityHashMap<>(); + readObject(new StateObjectInputStream(stream, referenceToObject)); + // failure was not serialized/deserialized as it complicates logic, let's just reconstruct it + Trie.constructFailureStates(this); + } + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + State state = (State) other; + return equals(this, state, new IdentityHashMap<>()); + } + + private boolean equals(State self, State other, IdentityHashMap equalityReferenceMap) { + if (self == other) return true; + return !(self == null || other == null) && + self.depth == other.depth && + ((self.depth > 0 && self.rootState == null && other.rootState == null) || + (self.depth == 0 && self == self.rootState && other == other.rootState)) && + gotoEquality(self.success, other.success, equalityReferenceMap) && + equals(self.failure, other.failure, equalityReferenceMap) && Objects.equal(self.emits, other.emits); + } + + private boolean gotoEquality(Map mine, Map other, + IdentityHashMap equalityReferenceMap) { + if (mine.size() != other.size()) return false; + Iterator otherEntrySet = other.entrySet().iterator(); + for (Map.Entry e : mine.entrySet()) { + Map.Entry otherE = (Map.Entry) otherEntrySet.next(); + + if (!e.getKey().equals(otherE.getKey())) return false; + Integer reference = equalityReferenceMap.get(e.getValue()); + if (reference == null) { + equalityReferenceMap.put(e.getValue(), equalityReferenceMap.size() + 1); + if (!equals(e.getValue(), otherE.getValue(), equalityReferenceMap)) return false; + } + } + return true; + } + + private class StateObjectOutputStream extends ObjectOutputStream { + + private final IdentityHashMap objectToReference; + private int referenceCount; + + StateObjectOutputStream(ObjectOutputStream out, IdentityHashMap objectToReference, + int referenceCount) + throws IOException { + super(out); + this.objectToReference = objectToReference; + this.referenceCount = referenceCount; + } + + IdentityHashMap getObjectToReference() { + return objectToReference; + } + + int getReferenceCount() { + return referenceCount; + } + + int incrementAndGetReferenceCount() { + referenceCount += 1; + return referenceCount; + } + } + + class StateObjectInputStream extends ObjectInputStream { + private final IdentityHashMap referenceToObject; + + StateObjectInputStream(ObjectInputStream in, IdentityHashMap referenceToObject) + throws IOException { + super(in); + this.referenceToObject = referenceToObject; + } + + IdentityHashMap getReferenceToObject() { + return referenceToObject; + } + } } diff --git a/src/main/java/org/ahocorasick/trie/Token.java b/src/main/java/org/ahocorasick/trie/Token.java index 65c1fac..2e4c72f 100644 --- a/src/main/java/org/ahocorasick/trie/Token.java +++ b/src/main/java/org/ahocorasick/trie/Token.java @@ -2,7 +2,7 @@ public abstract class Token { - private String fragment; + private final String fragment; public Token(String fragment) { this.fragment = fragment; diff --git a/src/main/java/org/ahocorasick/trie/Trie.java b/src/main/java/org/ahocorasick/trie/Trie.java index 0b62c82..4500931 100644 --- a/src/main/java/org/ahocorasick/trie/Trie.java +++ b/src/main/java/org/ahocorasick/trie/Trie.java @@ -1,22 +1,28 @@ package org.ahocorasick.trie; +import com.google.common.base.Objects; import org.ahocorasick.interval.IntervalTree; import org.ahocorasick.interval.Intervalable; import org.ahocorasick.trie.handler.DefaultEmitHandler; import org.ahocorasick.trie.handler.EmitHandler; +import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Queue; import java.util.concurrent.LinkedBlockingDeque; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkNotNull; /** - * * Based on the Aho-Corasick white paper, Bell technologies: http://cr.yp.to/bib/1975/aho.pdf + * * @author Robert Bor */ -public class Trie { +public class Trie implements Serializable { private TrieConfig trieConfig; @@ -62,11 +68,11 @@ public Collection tokenize(String text) { } private Token createFragment(Emit emit, String text, int lastCollectedPosition) { - return new FragmentToken(text.substring(lastCollectedPosition+1, emit == null ? text.length() : emit.getStart())); + return new FragmentToken(text.substring(lastCollectedPosition + 1, emit == null ? text.length() : emit.getStart())); } private Token createMatch(Emit emit, String text) { - return new MatchToken(text.substring(emit.getStart(), emit.getEnd()+1), emit); + return new MatchToken(text.substring(emit.getStart(), emit.getEnd() + 1), emit); } @SuppressWarnings("unchecked") @@ -85,17 +91,17 @@ public Collection parseText(CharSequence text) { } if (!trieConfig.isAllowOverlaps()) { - IntervalTree intervalTree = new IntervalTree((List)(List)collectedEmits); + IntervalTree intervalTree = new IntervalTree((List) (List) collectedEmits); intervalTree.removeOverlaps((List) (List) collectedEmits); } return collectedEmits; } - public boolean containsMatch(CharSequence text) { - Emit firstMatch = firstMatch(text); - return firstMatch != null; - } + public boolean containsMatch(CharSequence text) { + Emit firstMatch = firstMatch(text); + return firstMatch != null; + } public void parseText(CharSequence text, EmitHandler emitHandler) { State currentState = this.rootState; @@ -112,72 +118,63 @@ public void parseText(CharSequence text, EmitHandler emitHandler) { } - public Emit firstMatch(CharSequence text) { - if (!trieConfig.isAllowOverlaps()) { - // Slow path. Needs to find all the matches to detect overlaps. - Collection parseText = parseText(text); - if (parseText != null && !parseText.isEmpty()) { - return parseText.iterator().next(); - } - } else { - // Fast path. Returns first match found. - State currentState = this.rootState; + public Emit firstMatch(CharSequence text) { + if (!trieConfig.isAllowOverlaps()) { + // Slow path. Needs to find all the matches to detect overlaps. + Collection parseText = parseText(text); + if (parseText != null && !parseText.isEmpty()) { + return parseText.iterator().next(); + } + } else { + // Fast path. Returns first match found. + State currentState = this.rootState; for (int position = 0; position < text.length(); position++) { Character character = text.charAt(position); - if (trieConfig.isCaseInsensitive()) { - character = Character.toLowerCase(character); - } - currentState = getState(currentState, character); - Collection emitStrs = currentState.emit(); - if (emitStrs != null && !emitStrs.isEmpty()) { - for (String emitStr : emitStrs) { - final Emit emit = new Emit(position - emitStr.length() + 1, position, emitStr); - if (trieConfig.isOnlyWholeWords()) { - if (!isPartialMatch(text, emit)) { - return emit; - } - } else { - return emit; - } - } - } - } - } - return null; - } - - private boolean isPartialMatch(CharSequence searchText, Emit emit) { - return (emit.getStart() != 0 && - Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) || - (emit.getEnd() + 1 != searchText.length() && - Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1))); - } - - private void removePartialMatches(CharSequence searchText, List collectedEmits) { - List removeEmits = new ArrayList<>(); - for (Emit emit : collectedEmits) { - if (isPartialMatch(searchText, emit)) { - removeEmits.add(emit); - } - } - for (Emit removeEmit : removeEmits) { - collectedEmits.remove(removeEmit); - } - } + if (trieConfig.isCaseInsensitive()) { + character = Character.toLowerCase(character); + } + currentState = getState(currentState, character); + Collection emitStrs = currentState.emit(); + if (emitStrs != null && !emitStrs.isEmpty()) { + for (String emitStr : emitStrs) { + final Emit emit = new Emit(position - emitStr.length() + 1, position, emitStr); + if (trieConfig.isOnlyWholeWords()) { + if (!isPartialMatch(text, emit)) { + return emit; + } + } else { + return emit; + } + } + } + } + } + return null; + } + + private boolean isPartialMatch(CharSequence searchText, Emit emit) { + return (emit.getStart() != 0 && + Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) || + (emit.getEnd() + 1 != searchText.length() && + Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1))); + } + + private void removePartialMatches(CharSequence searchText, List collectedEmits) { + List removeEmits = collectedEmits.stream().filter(emit -> isPartialMatch(searchText, emit)).collect(Collectors.toList()); + removeEmits.forEach(collectedEmits::remove); + } private void removePartialMatchesWhiteSpaceSeparated(CharSequence searchText, List collectedEmits) { long size = searchText.length(); List removeEmits = new ArrayList<>(); for (Emit emit : collectedEmits) { if ((emit.getStart() == 0 || Character.isWhitespace(searchText.charAt(emit.getStart() - 1))) && - (emit.getEnd() + 1 == size || Character.isWhitespace(searchText.charAt(emit.getEnd() + 1)))) { + (emit.getEnd() + 1 == size || Character.isWhitespace(searchText.charAt(emit.getEnd() + 1)))) { continue; } removeEmits.add(emit); } - for (Emit removeEmit : removeEmits) { - collectedEmits.remove(removeEmit); - } + removeEmits.forEach(collectedEmits::remove); } private State getState(State currentState, Character character) { @@ -190,11 +187,15 @@ private State getState(State currentState, Character character) { } private void constructFailureStates() { + constructFailureStates(this.rootState); + } + + static void constructFailureStates(State rootState) { Queue queue = new LinkedBlockingDeque<>(); // First, set the fail state of all depth 1 states to the root state - for (State depthOneState : this.rootState.getStates()) { - depthOneState.setFailure(this.rootState); + for (State depthOneState : rootState.getStates()) { + depthOneState.setFailure(rootState); queue.add(depthOneState); } @@ -229,17 +230,38 @@ private boolean storeEmits(int position, State currentState, EmitHandler emitHan return emitted; } + private void writeObject(java.io.ObjectOutputStream stream) + throws IOException { + stream.writeObject(trieConfig); + stream.writeObject(rootState); + } + + private void readObject(java.io.ObjectInputStream stream) + throws IOException, ClassNotFoundException { + trieConfig = checkNotNull((TrieConfig) stream.readObject()); + rootState = checkNotNull((State) stream.readObject()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Trie trie = (Trie) o; + return Objects.equal(trieConfig, trie.trieConfig) && Objects.equal(rootState, trie.rootState); + } + public static TrieBuilder builder() { return new TrieBuilder(); } public static class TrieBuilder { - private TrieConfig trieConfig = new TrieConfig(); + private final TrieConfig trieConfig = new TrieConfig(); - private Trie trie = new Trie(trieConfig); + private final Trie trie = new Trie(trieConfig); - private TrieBuilder() {} + private TrieBuilder() { + } public TrieBuilder caseInsensitive() { this.trieConfig.setCaseInsensitive(true); diff --git a/src/main/java/org/ahocorasick/trie/TrieConfig.java b/src/main/java/org/ahocorasick/trie/TrieConfig.java index f9f0125..9170a2d 100644 --- a/src/main/java/org/ahocorasick/trie/TrieConfig.java +++ b/src/main/java/org/ahocorasick/trie/TrieConfig.java @@ -1,6 +1,9 @@ package org.ahocorasick.trie; -public class TrieConfig { +import java.io.IOException; +import java.io.Serializable; + +public class TrieConfig implements Serializable { private boolean allowOverlaps = true; @@ -12,9 +15,13 @@ public class TrieConfig { private boolean stopOnHit = false; - public boolean isStopOnHit() { return stopOnHit; } + public boolean isStopOnHit() { + return stopOnHit; + } - public void setStopOnHit(boolean stopOnHit) { this.stopOnHit = stopOnHit; } + public void setStopOnHit(boolean stopOnHit) { + this.stopOnHit = stopOnHit; + } public boolean isAllowOverlaps() { return allowOverlaps; @@ -32,7 +39,9 @@ public void setOnlyWholeWords(boolean onlyWholeWords) { this.onlyWholeWords = onlyWholeWords; } - public boolean isOnlyWholeWordsWhiteSpaceSeparated() { return onlyWholeWordsWhiteSpaceSeparated; } + public boolean isOnlyWholeWordsWhiteSpaceSeparated() { + return onlyWholeWordsWhiteSpaceSeparated; + } public void setOnlyWholeWordsWhiteSpaceSeparated(boolean onlyWholeWordsWhiteSpaceSeparated) { this.onlyWholeWordsWhiteSpaceSeparated = onlyWholeWordsWhiteSpaceSeparated; @@ -45,4 +54,34 @@ public boolean isCaseInsensitive() { public void setCaseInsensitive(boolean caseInsensitive) { this.caseInsensitive = caseInsensitive; } + + private void writeObject(java.io.ObjectOutputStream stream) + throws IOException { + stream.writeBoolean(allowOverlaps); + stream.writeBoolean(onlyWholeWords); + stream.writeBoolean(onlyWholeWordsWhiteSpaceSeparated); + stream.writeBoolean(caseInsensitive); + stream.writeBoolean(stopOnHit); + } + + private void readObject(java.io.ObjectInputStream stream) + throws IOException, ClassNotFoundException { + this.allowOverlaps = stream.readBoolean(); + this.onlyWholeWords = stream.readBoolean(); + onlyWholeWordsWhiteSpaceSeparated = stream.readBoolean(); + this.caseInsensitive = stream.readBoolean(); + this.stopOnHit = stream.readBoolean(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrieConfig that = (TrieConfig) o; + return allowOverlaps == that.allowOverlaps && + onlyWholeWords == that.onlyWholeWords && + onlyWholeWordsWhiteSpaceSeparated == that.onlyWholeWordsWhiteSpaceSeparated && + caseInsensitive == that.caseInsensitive && + stopOnHit == that.stopOnHit; + } } diff --git a/src/main/java/org/ahocorasick/trie/handler/DefaultEmitHandler.java b/src/main/java/org/ahocorasick/trie/handler/DefaultEmitHandler.java index 656d1e2..7c2ce3b 100644 --- a/src/main/java/org/ahocorasick/trie/handler/DefaultEmitHandler.java +++ b/src/main/java/org/ahocorasick/trie/handler/DefaultEmitHandler.java @@ -7,7 +7,7 @@ public class DefaultEmitHandler implements EmitHandler { - private List emits = new ArrayList<>(); + private final List emits = new ArrayList<>(); @Override public void emit(Emit emit) { diff --git a/src/test/java/org/ahocorasick/interval/IntervalTest.java b/src/test/java/org/ahocorasick/interval/IntervalTest.java index e61bad7..fdcddaa 100644 --- a/src/test/java/org/ahocorasick/interval/IntervalTest.java +++ b/src/test/java/org/ahocorasick/interval/IntervalTest.java @@ -2,11 +2,11 @@ import org.junit.Test; -import java.util.*; +import java.util.Iterator; +import java.util.Set; +import java.util.TreeSet; -import static junit.framework.Assert.assertEquals; -import static junit.framework.Assert.assertFalse; -import static junit.framework.Assert.assertTrue; +import static junit.framework.Assert.*; public class IntervalTest { @@ -44,7 +44,7 @@ public void pointDoesNotOverlap() { @Test public void comparable() { - Set intervals = new TreeSet(); + Set intervals = new TreeSet<>(); intervals.add(new Interval(4, 6)); intervals.add(new Interval(2, 7)); intervals.add(new Interval(3, 4)); diff --git a/src/test/java/org/ahocorasick/interval/IntervalTreeTest.java b/src/test/java/org/ahocorasick/interval/IntervalTreeTest.java index f4a7f57..1d9f75b 100644 --- a/src/test/java/org/ahocorasick/interval/IntervalTreeTest.java +++ b/src/test/java/org/ahocorasick/interval/IntervalTreeTest.java @@ -12,7 +12,7 @@ public class IntervalTreeTest { @Test public void findOverlaps() { - List intervals = new ArrayList(); + List intervals = new ArrayList<>(); intervals.add(new Interval(0, 2)); intervals.add(new Interval(1, 3)); intervals.add(new Interval(2, 4)); @@ -30,7 +30,7 @@ public void findOverlaps() { @Test public void removeOverlaps() { - List intervals = new ArrayList(); + List intervals = new ArrayList<>(); intervals.add(new Interval(0, 2)); intervals.add(new Interval(4, 5)); intervals.add(new Interval(2, 10)); @@ -43,7 +43,7 @@ public void removeOverlaps() { } - protected void assertOverlap(Intervalable interval, int expectedStart, int expectedEnd) { + private void assertOverlap(Intervalable interval, int expectedStart, int expectedEnd) { assertEquals(expectedStart, interval.getStart()); assertEquals(expectedEnd, interval.getEnd()); } diff --git a/src/test/java/org/ahocorasick/interval/IntervalableComparatorByPositionTest.java b/src/test/java/org/ahocorasick/interval/IntervalableComparatorByPositionTest.java index a6f1017..6703adc 100644 --- a/src/test/java/org/ahocorasick/interval/IntervalableComparatorByPositionTest.java +++ b/src/test/java/org/ahocorasick/interval/IntervalableComparatorByPositionTest.java @@ -12,7 +12,7 @@ public class IntervalableComparatorByPositionTest { @Test public void sortOnPosition() { - List intervals = new ArrayList(); + List intervals = new ArrayList<>(); intervals.add(new Interval(4,5)); intervals.add(new Interval(1,4)); intervals.add(new Interval(3,8)); diff --git a/src/test/java/org/ahocorasick/interval/IntervalableComparatorBySizeTest.java b/src/test/java/org/ahocorasick/interval/IntervalableComparatorBySizeTest.java index 208cf3d..6b531e3 100644 --- a/src/test/java/org/ahocorasick/interval/IntervalableComparatorBySizeTest.java +++ b/src/test/java/org/ahocorasick/interval/IntervalableComparatorBySizeTest.java @@ -12,7 +12,7 @@ public class IntervalableComparatorBySizeTest { @Test public void sortOnSize() { - List intervals = new ArrayList(); + List intervals = new ArrayList<>(); intervals.add(new Interval(4,5)); intervals.add(new Interval(1,4)); intervals.add(new Interval(3,8)); @@ -24,7 +24,7 @@ public void sortOnSize() { @Test public void sortOnSizeThenPosition() { - List intervals = new ArrayList(); + List intervals = new ArrayList<>(); intervals.add(new Interval(4,7)); intervals.add(new Interval(2,5)); Collections.sort(intervals, new IntervalableComparatorBySize()); diff --git a/src/test/java/org/ahocorasick/trie/StateTest.java b/src/test/java/org/ahocorasick/trie/StateTest.java index 2a64370..f3b4b8e 100644 --- a/src/test/java/org/ahocorasick/trie/StateTest.java +++ b/src/test/java/org/ahocorasick/trie/StateTest.java @@ -1,8 +1,10 @@ package org.ahocorasick.trie; -import org.ahocorasick.trie.State; +import org.apache.commons.lang3.SerializationUtils; import org.junit.Test; +import java.io.Serializable; + import static junit.framework.Assert.assertEquals; public class StateTest { @@ -11,9 +13,9 @@ public class StateTest { public void constructSequenceOfCharacters() { State rootState = new State(); rootState - .addState('a') - .addState('b') - .addState('c'); + .addState('a') + .addState('b') + .addState('c'); State currentState = rootState.nextState('a'); assertEquals(1, currentState.getDepth()); currentState = currentState.nextState('b'); @@ -22,4 +24,15 @@ public void constructSequenceOfCharacters() { assertEquals(3, currentState.getDepth()); } + @Test + public void testSerialization() { + State rootState = new State(); + rootState.addState('a') + .addState('b') + .addState('c'); + Trie.constructFailureStates(rootState); + + Serializable copy = SerializationUtils.clone(rootState); + assertEquals(copy, rootState); + } } diff --git a/src/test/java/org/ahocorasick/trie/TrieTest.java b/src/test/java/org/ahocorasick/trie/TrieTest.java index 6a620f0..aa324eb 100644 --- a/src/test/java/org/ahocorasick/trie/TrieTest.java +++ b/src/test/java/org/ahocorasick/trie/TrieTest.java @@ -1,6 +1,7 @@ package org.ahocorasick.trie; import org.ahocorasick.trie.handler.EmitHandler; +import org.apache.commons.lang3.SerializationUtils; import org.junit.Test; import java.util.ArrayList; @@ -21,6 +22,7 @@ public void keywordAndTextAreTheSame() { Collection emits = trie.parseText("abc"); Iterator iterator = emits.iterator(); checkEmit(iterator.next(), 0, 2, "abc"); + checkSerialization(trie); } @Test @@ -30,6 +32,7 @@ public void keywordAndTextAreTheSameFirstMatch() { .build(); Emit firstMatch = trie.firstMatch("abc"); checkEmit(firstMatch, 0, 2, "abc"); + checkSerialization(trie); } @Test @@ -40,6 +43,7 @@ public void textIsLongerThanKeyword() { Collection emits = trie.parseText(" abc"); Iterator iterator = emits.iterator(); checkEmit(iterator.next(), 1, 3, "abc"); + checkSerialization(trie); } @Test @@ -61,6 +65,7 @@ public void variousKeywordsOneMatch() { Collection emits = trie.parseText("bcd"); Iterator iterator = emits.iterator(); checkEmit(iterator.next(), 0, 2, "bcd"); + checkSerialization(trie); } @Test @@ -72,6 +77,7 @@ public void variousKeywordsFirstMatch() { .build(); Emit firstMatch = trie.firstMatch("bcd"); checkEmit(firstMatch, 0, 2, "bcd"); + checkSerialization(trie); } @Test @@ -88,6 +94,7 @@ public void ushersTestAndStopOnHit() { Iterator iterator = emits.iterator(); checkEmit(iterator.next(), 2, 3, "he"); checkEmit(iterator.next(), 1, 3, "she"); + checkSerialization(trie); } @Test @@ -104,6 +111,7 @@ public void ushersTest() { checkEmit(iterator.next(), 2, 3, "he"); checkEmit(iterator.next(), 1, 3, "she"); checkEmit(iterator.next(), 2, 5, "hers"); + checkSerialization(trie); } @Test @@ -121,6 +129,7 @@ public void ushersTestWithCapitalKeywords() { checkEmit(iterator.next(), 2, 3, "he"); checkEmit(iterator.next(), 1, 3, "she"); checkEmit(iterator.next(), 2, 5, "hers"); + checkSerialization(trie); } @Test @@ -133,6 +142,7 @@ public void ushersTestFirstMatch() { .build(); Emit firstMatch = trie.firstMatch("ushers"); checkEmit(firstMatch, 2, 3, "he"); + checkSerialization(trie); } @Test @@ -143,15 +153,9 @@ public void ushersTestByCallback() { .addKeyword("she") .addKeyword("he") .build(); - + checkSerialization(trie); final List emits = new ArrayList<>(); - EmitHandler emitHandler = new EmitHandler() { - - @Override - public void emit(Emit emit) { - emits.add(emit); - } - }; + EmitHandler emitHandler = emits::add; trie.parseText("ushers", emitHandler); assertEquals(3, emits.size()); // she @ 3, he @ 3, hers @ 5 Iterator iterator = emits.iterator(); @@ -168,6 +172,7 @@ public void misleadingTest() { Collection emits = trie.parseText("h he her hers"); Iterator iterator = emits.iterator(); checkEmit(iterator.next(), 9, 12, "hers"); + checkSerialization(trie); } @Test @@ -177,6 +182,7 @@ public void misleadingTestFirstMatch() { .build(); Emit firstMatch = trie.firstMatch("h he her hers"); checkEmit(firstMatch, 9, 12, "hers"); + checkSerialization(trie); } @Test @@ -193,6 +199,7 @@ public void recipes() { checkEmit(iterator.next(), 18, 25, "tomatoes"); checkEmit(iterator.next(), 40, 43, "veal"); checkEmit(iterator.next(), 51, 58, "broccoli"); + checkSerialization(trie); } @Test @@ -206,6 +213,7 @@ public void recipesFirstMatch() { Emit firstMatch = trie.firstMatch("2 cauliflowers, 3 tomatoes, 4 slices of veal, 100g broccoli"); checkEmit(firstMatch, 2, 12, "cauliflower"); + checkSerialization(trie); } @Test @@ -223,6 +231,7 @@ public void longAndShortOverlappingMatch() { checkEmit(iterator.next(), 0, 7, "hehehehe"); checkEmit(iterator.next(), 8, 9, "he"); checkEmit(iterator.next(), 2, 9, "hehehehe"); + checkSerialization(trie); } @Test @@ -238,6 +247,7 @@ public void nonOverlapping() { // With overlaps: ab@1, ab@3, ababc@4, cba@6, ab@7 checkEmit(iterator.next(), 0, 4, "ababc"); checkEmit(iterator.next(), 6, 7, "ab"); + checkSerialization(trie); } @Test @@ -250,6 +260,7 @@ public void nonOverlappingFirstMatch() { Emit firstMatch = trie.firstMatch("ababcbab"); checkEmit(firstMatch, 0, 4, "ababc"); + checkSerialization(trie); } @Test @@ -260,6 +271,7 @@ public void containsMatch() { .addKeyword("ababc") .build(); assertTrue(trie.containsMatch("ababcbab")); + checkSerialization(trie); } @Test @@ -278,6 +290,7 @@ public void startOfChurchillSpeech() { .build(); Collection emits = trie.parseText("Turning"); assertEquals(2, emits.size()); + checkSerialization(trie); } @Test @@ -289,6 +302,7 @@ public void partialMatch() { Collection emits = trie.parseText("sugarcane sugarcane sugar canesugar"); // left, middle, right test assertEquals(1, emits.size()); // Match must not be made checkEmit(emits.iterator().next(), 20, 24, "sugar"); + checkSerialization(trie); } @Test @@ -300,6 +314,7 @@ public void partialMatchFirstMatch() { Emit firstMatch = trie.firstMatch("sugarcane sugarcane sugar canesugar"); // left, middle, right test checkEmit(firstMatch, 20, 24, "sugar"); + checkSerialization(trie); } @Test @@ -319,6 +334,7 @@ public void tokenizeFullSentence() { assertEquals(" from the rear, ", tokensIt.next().getFragment()); assertEquals("Gamma", tokensIt.next().getFragment()); assertEquals(" in reserve", tokensIt.next().getFragment()); + checkSerialization(trie); } @Test @@ -336,6 +352,7 @@ public void bug5InGithubReportedByXCurry() { checkEmit(it.next(), 8, 11, "once"); checkEmit(it.next(), 13, 17, "again"); checkEmit(it.next(), 19, 23, "börkü"); + checkSerialization(trie); } @Test @@ -353,6 +370,7 @@ public void caseInsensitive() { checkEmit(it.next(), 8, 11, "once"); checkEmit(it.next(), 13, 17, "again"); checkEmit(it.next(), 19, 23, "börkü"); + checkSerialization(trie); } @Test @@ -366,6 +384,7 @@ public void caseInsensitiveFirstMatch() { Emit firstMatch = trie.firstMatch("TurninG OnCe AgAiN BÖRKÜ"); checkEmit(firstMatch, 0, 6, "turning"); + checkSerialization(trie); } @Test @@ -377,6 +396,7 @@ public void tokenizeTokensInSequence() { .build(); Collection tokens = trie.tokenize("Alpha Beta Gamma"); assertEquals(5, tokens.size()); + checkSerialization(trie); } // Test offered by XCurry, https://github.com/robert-bor/aho-corasick/issues/7 @@ -386,6 +406,7 @@ public void zeroLengthTestBug7InGithubReportedByXCurry() { .addKeyword("") .build(); trie.tokenize("Try a natural lip and subtle bronzer to keep all the focus on those big bright eyes with NARS Eyeshadow Duo in Rated R And the winner is... Boots No7 Advanced Renewal Anti-ageing Glycolic Peel Kit ($25 amazon.com) won most-appealing peel."); + checkSerialization(trie); } // Test offered by dwyerk, https://github.com/robert-bor/aho-corasick/issues/8 @@ -400,6 +421,7 @@ public void unicodeIssueBug8ReportedByDwyerk() { assertEquals(1, emits.size()); Iterator it = emits.iterator(); checkEmit(it.next(), 5, 8, "this"); + checkSerialization(trie); } @Test @@ -413,6 +435,7 @@ public void unicodeIssueBug8ReportedByDwyerkFirstMatch() { assertEquals("THIS", target.substring(5, 9)); // Java does it the right way Emit firstMatch = trie.firstMatch(target); checkEmit(firstMatch, 5, 8, "this"); + checkSerialization(trie); } @Test @@ -421,9 +444,10 @@ public void partialMatchWhiteSpaces() { .onlyWholeWordsWhiteSpaceSeparated() .addKeyword("#sugar-123") .build(); - Collection < Emit > emits = trie.parseText("#sugar-123 #sugar-1234"); // left, middle, right test + Collection emits = trie.parseText("#sugar-123 #sugar-1234"); // left, middle, right test assertEquals(1, emits.size()); // Match must not be made checkEmit(emits.iterator().next(), 0, 9, "#sugar-123"); + checkSerialization(trie); } private void checkEmit(Emit next, int expectedStart, int expectedEnd, String expectedKeyword) { @@ -432,4 +456,8 @@ private void checkEmit(Emit next, int expectedStart, int expectedEnd, String exp assertEquals(expectedKeyword, next.getKeyword()); } + private void checkSerialization(Trie trie) { + Trie clonedTrie = SerializationUtils.clone(trie); + assertEquals(trie, clonedTrie); + } }