Skip to content

Commit

Permalink
reverted KffFile back to using RandomAccessFile for file I/O, but now…
Browse files Browse the repository at this point in the history
… synchronizing access to its use (#605)
  • Loading branch information
drivenflywheel authored Oct 11, 2023
1 parent 76a4659 commit 6a0b8cc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 73 deletions.
104 changes: 38 additions & 66 deletions src/main/java/emissary/kff/KffFile.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
package emissary.kff;

import emissary.core.channels.FileChannelFactory;
import emissary.core.channels.SeekableByteChannelFactory;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.Nonnull;

/**
Expand All @@ -31,19 +26,19 @@
* </p>
*/
public class KffFile implements KffFilter {
private final Logger logger;
private static final Logger logger = LoggerFactory.getLogger(KffFile.class);

/** File containing SHA-1/CRC32 results of known files */
protected SeekableByteChannelFactory knownFileFactory;
protected RandomAccessFile knownFile;

/** Byte buffer that is mapped to the above file */
protected ByteBuffer mappedBuf;

/** Initial value of high index for binary search */
private int bSearchInitHigh;
private long bSearchInitHigh;

public static final int DEFAULT_RECORD_LENGTH = 24;
protected int recordLength;
protected final int recordLength;

/** String logical name for this filter */
protected String filterName = "UNKNOWN";
Expand All @@ -52,6 +47,8 @@ public class KffFile implements KffFilter {

protected String myPreferredAlgorithm = "SHA-1";

protected ReentrantLock reentrantLock = new ReentrantLock();

/**
* Creates a new instance of KffFile
*
Expand Down Expand Up @@ -81,16 +78,11 @@ public KffFile(String filename, String filterName, FilterType ftype, int recordL
this.filterName = filterName;
this.recordLength = recordLength;

// Set logger to run time class
logger = LoggerFactory.getLogger(this.getClass());

// Open file in read-only mode
knownFileFactory = FileChannelFactory.create(Paths.get(filename));
knownFile = new RandomAccessFile(filename, "r");

// Initial high value for binary search is the largest index
try (SeekableByteChannel sbc = knownFileFactory.create()) {
bSearchInitHigh = ((int) sbc.size() / recordLength) - 1;
}
// Initial high value for binary search is largest index
bSearchInitHigh = (knownFile.length() / (long) recordLength) - 1;

logger.debug("KFF File {} has {} records", filename, (bSearchInitHigh + 1));
}
Expand Down Expand Up @@ -147,26 +139,23 @@ public String getPreferredAlgorithm() {
private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) {

// Initialize indexes for binary search
int low = 0;
int high = bSearchInitHigh;
long low = 0;
long high = bSearchInitHigh;

/* Buffer to hold a record */
byte[] rec = new byte[recordLength];
ByteBuffer byteBuffer = ByteBuffer.wrap(rec);
// Search until the indexes cross
try (SeekableByteChannel knownFile = knownFileFactory.create()) {
while (low <= high) {
byteBuffer.clear();

reentrantLock.lock();
try {
// Search until the indexes cross
while (low <= high) {
// Calculate the midpoint
int mid = (low + high) >> 1;
long mid = (low + high) >> 1;

// Multiply the index by the record length to get the buffer position and read the record
knownFile.seek(recordLength * mid);
knownFile.readFully(rec);

knownFile.position(rec.length * (long) mid);
int count = IOUtils.read(knownFile, byteBuffer);
if (count != rec.length) {
logger.warn("Short read on KffFile at {} read {} expected {}", (recordLength * mid), count, recordLength);
return false;
}
// Compare the record with the target. Adjust the indexes accordingly.
int c = compare(rec, hash, crc);
if (c < 0) {
Expand All @@ -177,44 +166,48 @@ private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) {
return true;
}
}
} catch (EOFException e) {
// this shouldn't happen if we're synchronizing calls correctly
logger.warn("EOFException reading KffFile: {}", e.getLocalizedMessage());
} catch (IOException e) {
logger.warn("Exception reading KffFile", e);
} finally {
if (reentrantLock.isHeldByCurrentThread()) {
reentrantLock.unlock();
}
}

// not found
return false;
}

/**
* Compares the given hash/crc to the one in the record.
*
* @param rec bytes from the kff binary file, one record long
* @param record bytes from the kff binary file, one record long
* @param hash HASH to compare to record
* @param crc CRC to compare to record
* @return &lt;0 if given value is less than record, &gt;0 if given value is greater than record, 0 if they match
*/
private int compare(@Nonnull byte[] rec, @Nonnull byte[] hash, long crc) {
private int compare(@Nonnull byte[] record, @Nonnull byte[] hash, long crc) {
int i;

// Compare the HASHs first. We can't compare the bytes directly
// because a Java byte is signed and may generate the wrong
// result. We must convert to integers and then mask off the
// sign bits to get proper results.
// Compare the hashes first. We can't compare the bytes directly because a Java byte is signed and may generate the
// wrong result. We must convert to integers and then mask off the sign bits to get proper results.
for (i = 0; i < hash.length; i++) {
int ihash = hash[i] & 0xff;
int irec = rec[i] & 0xff;
int irec = record[i] & 0xff;
if (ihash < irec) {
return -1;
} else if (ihash > irec) {
return 1;
}
}

// If the HASHs match, check the CRCs.
// If the hashes match, check the CRCs.
if (crc != -1L) {
for (int j = 24; i < rec.length; i++, j -= 8) {
for (int j = 24; i < record.length; i++, j -= 8) {
int icrc = ((int) crc >> j) & 0xff;
int irec = rec[i] & 0xff;
int irec = record[i] & 0xff;
if (icrc < irec) {
return -1;
} else if (icrc > irec) {
Expand All @@ -234,25 +227,4 @@ public boolean check(String fname, ChecksumResults csum) throws Exception {
}
return binaryFileSearch(hash, csum.getCrc());
}

public static void main(String[] args) throws Exception {
KffChain kff = new KffChain();
KffFile kfile = new KffFile(args[0], "TEST", FilterType.Ignore);
kfile.setPreferredAlgorithm("SHA-1");
kff.addFilter(kfile);
kff.addAlgorithm("CRC32");
kff.addAlgorithm("SSDEEP");
kff.addAlgorithm("MD5");
kff.addAlgorithm("SHA-1");
kff.addAlgorithm("SHA-256");

for (int i = 1; i < args.length; i++) {
try (InputStream is = Files.newInputStream(Paths.get(args[i]))) {
byte[] buffer = IOUtils.toByteArray(is);

KffResult r = kff.check(args[i], buffer);
System.out.println(args[i] + ": " + r.isKnown() + " - " + r.getShaString() + " - " + r.getCrc32());
}
}
}
}
7 changes: 0 additions & 7 deletions src/test/java/emissary/kff/KffFileTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.stream.Collectors;

import static emissary.kff.KffFile.DEFAULT_RECORD_LENGTH;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -87,12 +86,6 @@ void testKffFileCheck() {
}
}

@Test
void testKffFileMain() {
String[] args = {resourcePath, resourcePath};
assertDoesNotThrow(() -> KffFile.main(args));
}

/**
* Tests concurrent {@link KffFile#check(String, ChecksumResults)} invocations to ensure that method's thread-safety
*/
Expand Down

0 comments on commit 6a0b8cc

Please sign in to comment.