Skip to content

Commit

Permalink
Implement ByteBufferSerialization to start getting rid of UnsafeSeria…
Browse files Browse the repository at this point in the history
…lization
  • Loading branch information
alexklibisz committed Nov 29, 2023
1 parent 8b7f42e commit d1c3e90
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.klibisz.elastiknn.jmhbenchmarks

import com.klibisz.elastiknn.api.Vec
import com.klibisz.elastiknn.storage.{ByteBufferSerialization, UnsafeSerialization}
import org.openjdk.jmh.annotations._

import scala.util.Random

/** Shared JMH state holding the inputs used by the vector serialization benchmarks.
  *
  * Both fields are computed once when the state object is constructed, so the
  * benchmark methods only measure the serialization call itself.
  */
@State(Scope.Benchmark)
class VectorSerializationBenchmarksState {

  // Seeded with a constant so every run generates the same vector.
  private implicit val rng: Random = new Random(0)

  /** Float values of a vector produced by `Vec.DenseFloat.random(999)`; input to the write benchmarks. */
  val floatsOriginal: Array[Float] = Vec.DenseFloat.random(999).values

  /** `floatsOriginal` serialized with `UnsafeSerialization.writeFloats`; input to the read benchmarks. */
  val floatsSerialized: Array[Byte] = UnsafeSerialization.writeFloats(floatsOriginal)
}

/** Throughput benchmarks comparing `UnsafeSerialization` and `ByteBufferSerialization`
  * for writing and reading float arrays.
  *
  * The JMH configuration (throughput mode, one fork, five 5-second warmup iterations and
  * five 5-second measurement iterations) is declared once on the class and therefore
  * applies to every `@Benchmark` method below.
  */
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
class VectorSerializationBenchmarks {

  /** Serializes the float array with `UnsafeSerialization`. */
  @Benchmark
  def writeFloatsUnsafe(state: VectorSerializationBenchmarksState): Array[Byte] =
    UnsafeSerialization.writeFloats(state.floatsOriginal)

  /** Serializes the float array with `ByteBufferSerialization`. */
  @Benchmark
  def writeFloatsByteBuffer(state: VectorSerializationBenchmarksState): Array[Byte] =
    ByteBufferSerialization.writeFloats(state.floatsOriginal)

  /** Deserializes the whole pre-serialized byte array with `UnsafeSerialization`. */
  @Benchmark
  def readFloatsUnsafe(state: VectorSerializationBenchmarksState): Array[Float] = {
    val bytes = state.floatsSerialized
    UnsafeSerialization.readFloats(bytes, 0, bytes.length)
  }

  /** Deserializes the whole pre-serialized byte array with `ByteBufferSerialization`. */
  @Benchmark
  def readFloatsByteBuffer(state: VectorSerializationBenchmarksState): Array[Float] = {
    val bytes = state.floatsSerialized
    ByteBufferSerialization.readFloats(bytes, 0, bytes.length)
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.klibisz.elastiknn.storage;

import scala.util.control.Exception;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * Serialization helpers for ints and float arrays built on {@link ByteBuffer}.
 *
 * All multi-byte values are written and read in {@link #byteOrder} (little-endian).
 */
public class ByteBufferSerialization {
    public static final int numBytesInInt = 4;
    public static final int numBytesInFloat = 4;

    public static final ByteOrder byteOrder = ByteOrder.LITTLE_ENDIAN;

    /**
     * Writes an int using a variable-length encoding chosen from its magnitude:
     * 1 byte when |i| &lt;= Byte.MAX_VALUE, 2 bytes when |i| &lt;= Short.MAX_VALUE,
     * otherwise 4 bytes.
     *
     * @param i the value to serialize
     * @return a new array of length 1, 2 or 4 holding the low-order bytes of {@code i}
     *         in little-endian order
     */
    public static byte[] writeInt(final int i) {
        // Math.abs(Integer.MIN_VALUE) overflows and stays negative, which would wrongly
        // select the 1-byte branch; widening to long first gives the true magnitude.
        final long a = Math.abs((long) i);
        if (a <= Byte.MAX_VALUE) {
            // An IntBuffer view over fewer than 4 bytes has capacity 0, so the value is
            // written directly as a single byte instead.
            return new byte[] {(byte) i};
        } else if (a <= Short.MAX_VALUE) {
            return ByteBuffer.allocate(2).order(byteOrder).putShort((short) i).array();
        } else {
            return ByteBuffer.allocate(numBytesInInt).order(byteOrder).putInt(i).array();
        }
    }

    /**
     * Writes a float array as consecutive 4-byte little-endian values.
     *
     * @param farr the floats to serialize
     * @return a new array of length {@code farr.length * numBytesInFloat}
     */
    public static byte[] writeFloats(final float[] farr) {
        ByteBuffer bb = ByteBuffer.allocate(farr.length * numBytesInFloat).order(byteOrder);
        bb.asFloatBuffer().put(farr);
        return bb.array();
    }

    /**
     * Reads floats from a region of a byte array.
     *
     * @param barr   the source bytes
     * @param offset index of the first byte to read
     * @param length number of bytes to read; {@code length / numBytesInFloat} floats are
     *               returned, so any trailing bytes beyond a multiple of 4 are ignored
     * @return a new array containing the decoded floats
     */
    public static float[] readFloats(final byte[] barr, int offset, int length) {
        float[] dst = new float[length / numBytesInFloat];
        ByteBuffer bb = ByteBuffer.wrap(barr, offset, length).order(byteOrder);
        bb.asFloatBuffer().get(dst);
        return dst;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.klibisz.elastiknn.storage

import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

/** Tests for `ByteBufferSerialization`: the variable-length int encoding, float array
  * round-trips, and byte-level compatibility with `UnsafeSerialization`.
  */
class ByteBufferSerializationSpec extends AnyFreeSpec with Matchers {

  "writeInt" - {
    "variable length encoding" in {
      ByteBufferSerialization.writeInt(127) should have length 1
      ByteBufferSerialization.writeInt(-127) should have length 1
      ByteBufferSerialization.writeInt(32767) should have length 2
      ByteBufferSerialization.writeInt(-32767) should have length 2
      // The extremes of the int range must take the full 4-byte encoding.
      ByteBufferSerialization.writeInt(Int.MaxValue) should have length 4
      ByteBufferSerialization.writeInt(Int.MinValue) should have length 4
    }
  }

  "writeFloats and readFloats" - {
    "example" in {
      val original = (-10 until 10).map(_.toFloat).toArray
      val serialized = ByteBufferSerialization.writeFloats(original)
      val deserialized = ByteBufferSerialization.readFloats(serialized, 0, serialized.length)
      deserialized.toList shouldBe original.toList
    }
    "example with offset and length" in {
      val (dropLeftFloats, dropRightFloats) = (2, 5)
      // Offsets and lengths are expressed in bytes, 4 per float.
      val (dropLeftBytes, dropRightBytes) = (dropLeftFloats * 4, dropRightFloats * 4)
      val original = (-10 until 10).map(_.toFloat).toArray
      val serialized = ByteBufferSerialization.writeFloats(original)
      val deserialized = ByteBufferSerialization.readFloats(serialized, dropLeftBytes, serialized.length - dropLeftBytes - dropRightBytes)
      deserialized.toList shouldBe original.drop(dropLeftFloats).dropRight(dropRightFloats).toList
    }
    "compatibility with UnsafeSerialization" in {
      val original = (-10 until 10).map(_.toFloat).toArray
      val unsafeSerialized = UnsafeSerialization.writeFloats(original)
      val byteBufferSerialized = ByteBufferSerialization.writeFloats(original)
      // Cross-read in both directions: each implementation decodes the other's bytes.
      val unsafeDeserializedFromByteBuffer = UnsafeSerialization.readFloats(byteBufferSerialized, 0, byteBufferSerialized.length)
      val byteBufferDeserializedFromUnsafe = ByteBufferSerialization.readFloats(unsafeSerialized, 0, unsafeSerialized.length)
      byteBufferSerialized.toList shouldBe unsafeSerialized.toList
      unsafeDeserializedFromByteBuffer.toList shouldBe original.toList
      byteBufferDeserializedFromUnsafe.toList shouldBe original.toList
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.klibisz.elastiknn

import java.nio.{ByteBuffer, ByteOrder}

/** Scratch program: copies ten ints into a heap `ByteBuffer` through an `IntBuffer`
  * view and prints the length of the backing byte array.
  */
object Playground extends App {

  // The ints 0 through 9.
  val arr = Array.range(0, 10)

  // Four bytes per int, in the platform's native byte order.
  val bb: ByteBuffer = {
    val capacityInBytes = arr.length * 4
    ByteBuffer.allocate(capacityInBytes).order(ByteOrder.nativeOrder())
  }

  bb.asIntBuffer().put(arr)

  println(bb.array().length)

}

0 comments on commit d1c3e90

Please sign in to comment.