diff --git a/src/main/java/emissary/kff/KffFile.java b/src/main/java/emissary/kff/KffFile.java index 9edbfbd65f..2d602c9c9d 100755 --- a/src/main/java/emissary/kff/KffFile.java +++ b/src/main/java/emissary/kff/KffFile.java @@ -1,18 +1,13 @@ package emissary.kff; -import emissary.core.channels.FileChannelFactory; -import emissary.core.channels.SeekableByteChannelFactory; - -import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.EOFException; import java.io.IOException; -import java.io.InputStream; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; -import java.nio.channels.SeekableByteChannel; -import java.nio.file.Files; -import java.nio.file.Paths; +import java.util.concurrent.locks.ReentrantLock; import javax.annotation.Nonnull; /** @@ -31,19 +26,19 @@ *

*/ public class KffFile implements KffFilter { - private final Logger logger; + private static final Logger logger = LoggerFactory.getLogger(KffFile.class); /** File containing SHA-1/CRC32 results of known files */ - protected SeekableByteChannelFactory knownFileFactory; + protected RandomAccessFile knownFile; /** Byte buffer that is mapped to the above file */ protected ByteBuffer mappedBuf; /** Initial value of high index for binary search */ - private int bSearchInitHigh; + private long bSearchInitHigh; public static final int DEFAULT_RECORD_LENGTH = 24; - protected int recordLength; + protected final int recordLength; /** String logical name for this filter */ protected String filterName = "UNKNOWN"; @@ -52,6 +47,8 @@ public class KffFile implements KffFilter { protected String myPreferredAlgorithm = "SHA-1"; + protected ReentrantLock reentrantLock = new ReentrantLock(); + /** * Creates a new instance of KffFile * @@ -81,16 +78,11 @@ public KffFile(String filename, String filterName, FilterType ftype, int recordL this.filterName = filterName; this.recordLength = recordLength; - // Set logger to run time class - logger = LoggerFactory.getLogger(this.getClass()); - // Open file in read-only mode - knownFileFactory = FileChannelFactory.create(Paths.get(filename)); + knownFile = new RandomAccessFile(filename, "r"); - // Initial high value for binary search is the largest index - try (SeekableByteChannel sbc = knownFileFactory.create()) { - bSearchInitHigh = ((int) sbc.size() / recordLength) - 1; - } + // Initial high value for binary search is largest index + bSearchInitHigh = (knownFile.length() / (long) recordLength) - 1; logger.debug("KFF File {} has {} records", filename, (bSearchInitHigh + 1)); } @@ -147,26 +139,23 @@ public String getPreferredAlgorithm() { private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) { // Initialize indexes for binary search - int low = 0; - int high = bSearchInitHigh; + long low = 0; + long high = bSearchInitHigh; /* Buffer to hold a record */ byte[] rec = new byte[recordLength]; - ByteBuffer byteBuffer = ByteBuffer.wrap(rec); - // Search until the indexes cross - try (SeekableByteChannel knownFile = knownFileFactory.create()) { - while (low <= high) { - byteBuffer.clear(); + reentrantLock.lock(); + try { + // Search until the indexes cross + while (low <= high) { // Calculate the midpoint - int mid = (low + high) >> 1; + long mid = (low + high) >> 1; + + // Multiply the index by the record length to get the buffer position and read the record + knownFile.seek(recordLength * mid); + knownFile.readFully(rec); - knownFile.position(rec.length * (long) mid); - int count = IOUtils.read(knownFile, byteBuffer); - if (count != rec.length) { - logger.warn("Short read on KffFile at {} read {} expected {}", (recordLength * mid), count, recordLength); - return false; - } // Compare the record with the target. Adjust the indexes accordingly. int c = compare(rec, hash, crc); if (c < 0) { @@ -177,10 +166,16 @@ private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) { return true; } } + } catch (EOFException e) { + // this shouldn't happen if we're synchronizing calls correctly + logger.warn("EOFException reading KffFile: {}", e.getLocalizedMessage()); } catch (IOException e) { logger.warn("Exception reading KffFile", e); + } finally { + if (reentrantLock.isHeldByCurrentThread()) { + reentrantLock.unlock(); + } } - // not found return false; } @@ -188,21 +183,19 @@ private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) { /** * Compares the given hash/crc to the one in the record. * - * @param rec bytes from the kff binary file, one record long + * @param record bytes from the kff binary file, one record long * @param hash HASH to compare to record * @param crc CRC to compare to record * @return <0 if given value is less than record, >0 if given value is greater than record, 0 if they match */ - private int compare(@Nonnull byte[] rec, @Nonnull byte[] hash, long crc) { + private int compare(@Nonnull byte[] record, @Nonnull byte[] hash, long crc) { int i; - // Compare the HASHs first. We can't compare the bytes directly - // because a Java byte is signed and may generate the wrong - // result. We must convert to integers and then mask off the - // sign bits to get proper results. + // Compare the hashes first. We can't compare the bytes directly because a Java byte is signed and may generate the + // wrong result. We must convert to integers and then mask off the sign bits to get proper results. for (i = 0; i < hash.length; i++) { int ihash = hash[i] & 0xff; - int irec = rec[i] & 0xff; + int irec = record[i] & 0xff; if (ihash < irec) { return -1; } else if (ihash > irec) { @@ -210,11 +203,11 @@ private int compare(@Nonnull byte[] rec, @Nonnull byte[] hash, long crc) { } } - // If the HASHs match, check the CRCs. + // If the hashes match, check the CRCs. if (crc != -1L) { - for (int j = 24; i < rec.length; i++, j -= 8) { + for (int j = 24; i < record.length; i++, j -= 8) { int icrc = ((int) crc >> j) & 0xff; - int irec = rec[i] & 0xff; + int irec = record[i] & 0xff; if (icrc < irec) { return -1; } else if (icrc > irec) { @@ -234,25 +227,4 @@ public boolean check(String fname, ChecksumResults csum) throws Exception { } return binaryFileSearch(hash, csum.getCrc()); } - - public static void main(String[] args) throws Exception { - KffChain kff = new KffChain(); - KffFile kfile = new KffFile(args[0], "TEST", FilterType.Ignore); - kfile.setPreferredAlgorithm("SHA-1"); - kff.addFilter(kfile); - kff.addAlgorithm("CRC32"); - kff.addAlgorithm("SSDEEP"); - kff.addAlgorithm("MD5"); - kff.addAlgorithm("SHA-1"); - kff.addAlgorithm("SHA-256"); - - for (int i = 1; i < args.length; i++) { - try (InputStream is = Files.newInputStream(Paths.get(args[i]))) { - byte[] buffer = IOUtils.toByteArray(is); - - KffResult r = kff.check(args[i], buffer); - System.out.println(args[i] + ": " + r.isKnown() + " - " + r.getShaString() + " - " + r.getCrc32()); - } - } - } } diff --git a/src/test/java/emissary/kff/KffFileTest.java b/src/test/java/emissary/kff/KffFileTest.java index 86f2ac96bb..4d61738279 100644 --- a/src/test/java/emissary/kff/KffFileTest.java +++ b/src/test/java/emissary/kff/KffFileTest.java @@ -29,7 +29,6 @@ import java.util.stream.Collectors; import static emissary.kff.KffFile.DEFAULT_RECORD_LENGTH; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -87,12 +86,6 @@ void testKffFileCheck() { } } - @Test - void testKffFileMain() { - String[] args = {resourcePath, resourcePath}; - assertDoesNotThrow(() -> KffFile.main(args)); - } - /** * Tests concurrent {@link KffFile#check(String, ChecksumResults)} invocations to ensure that method's thread-safety */