Skip to content

Commit

Permalink
feat: add BoolVector-related methods to Packet
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Dec 17, 2023
1 parent 23e021f commit 9795145
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace Mediapipe
{
[StructLayout(LayoutKind.Sequential)]
internal readonly struct StructArray
{
private readonly IntPtr _data;
private readonly int _size;

public void Dispose()
{
UnsafeNativeMethods.delete_array__Pf(_data);
}

public List<T> Copy<T>() where T : unmanaged
{
var data = new List<T>(_size);

CopyTo(data);
return data;
}

public void CopyTo<T>(List<T> data) where T : unmanaged
{
data.Clear();

unsafe
{
var ptr = (T*)_data;

for (var i = 0; i < _size; i++)
{
data.Add(*ptr++);
}
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System;
using System.Collections.Generic;

namespace Mediapipe
{
Expand Down Expand Up @@ -43,20 +44,52 @@ internal static Packet CreateEmpty()
/// </param>
public static Packet CreateForReference(IntPtr ptr) => new Packet(ptr, false);

/// <summary>
/// Create a bool Packet from a boolean.
/// </summary>
public static Packet CreateBool(bool value)
{
UnsafeNativeMethods.mp__MakeBoolPacket__b(value, out var ptr).Assert();

return new Packet(ptr, true);
}

/// <summary>
/// Create a bool Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet CreateBoolAt(bool value, long timestampMicrosec)
{
UnsafeNativeMethods.mp__MakeBoolPacket_At__b_ll(value, timestampMicrosec, out var ptr).Assert();

return new Packet(ptr, true);
}

/// <summary>
/// Create a bool vector Packet.
/// </summary>
public static Packet CreateBoolVector(bool[] value)
{
UnsafeNativeMethods.mp__MakeBoolVectorPacket__Pb_i(value, value.Length, out var ptr).Assert();

return new Packet(ptr, true);
}

/// <summary>
/// Create a bool vector Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet CreateBoolVectorAt(bool[] value, long timestampMicrosec)
{
UnsafeNativeMethods.mp__MakeBoolVectorPacket_At__Pb_i_ll(value, value.Length, timestampMicrosec, out var ptr).Assert();

return new Packet(ptr, true);
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a boolean.
/// </summary>
Expand All @@ -74,6 +107,32 @@ public bool GetBool()
return value;
}

/// <summary>
/// Get the content of a bool vector Packet as a <see cref="List{bool}"/>.
/// </summary>
public List<bool> GetBoolList()
{
var value = new List<bool>();
GetBoolList(value);

return value;
}

/// <summary>
/// Get the content of a bool vector Packet as a <see cref="List{bool}"/>.
/// </summary>
/// <param name="value">
/// The <see cref="List{bool}"/> to be filled with the content of the <see cref="Packet"/>.
/// </param>
public void GetBoolList(List<bool> value)
{
UnsafeNativeMethods.mp_Packet__GetBoolVector(mpPtr, out var structArray).Assert();
GC.KeepAlive(this);

structArray.CopyTo(value);
structArray.Dispose();
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a boolean.
/// </summary>
Expand All @@ -87,5 +146,19 @@ public void ValidateAsBool()
GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a std::vector&lt;bool&gt;.
/// </summary>
/// <exception cref="BadStatusException">
/// If the <see cref="Packet"/> doesn't contain std::vector&lt;bool&gt;.
/// </exception>
public void ValidateAsBoolVector()
{
UnsafeNativeMethods.mp_Packet__ValidateAsBoolVector(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ internal static partial class UnsafeNativeMethods
public static extern MpReturnCode mp_Packet__ValidateAsBool(IntPtr packet, out IntPtr status);
#endregion

#region BoolVector
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeBoolVectorPacket__Pb_i(bool[] value, int size, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeBoolVectorPacket_At__Pb_i_ll(bool[] value, int size, long timestampMicrosec, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetBoolVector(IntPtr packet, out StructArray value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsBoolVector(IntPtr packet, out IntPtr status);
#endregion

#region Float
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatPacket__f(float value, out IntPtr packet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,59 @@ public void CreateBoolAt_ShouldReturnNewBoolPacket(bool value)
}
#endregion

#region BoolVector
[Test]
public void CreateBoolVector_ShouldReturnNewBoolVectorPacket()
{
var value = new bool[] { true, false };
using var packet = Packet.CreateBoolVector(value);

Assert.DoesNotThrow(packet.ValidateAsBoolVector);

var result = packet.GetBoolList();
Assert.AreEqual(value.Length, result.Count);
for (var i = 0; i < value.Length; i++)
{
Assert.AreEqual(value[i], result[i]);
}

using var unsetTimestamp = Timestamp.Unset();
Assert.AreEqual(unsetTimestamp.Microseconds(), packet.TimestampMicroseconds());
}

[Test]
public void CreateBoolVectorAt_ShouldReturnNewBoolVectorPacket()
{
var value = new bool[] { true, false };
var timestamp = 1;
using var packet = Packet.CreateBoolVectorAt(value, timestamp);

Assert.DoesNotThrow(packet.ValidateAsBoolVector);

var result = packet.GetBoolList();
Assert.AreEqual(value.Length, result.Count);
for (var i = 0; i < value.Length; i++)
{
Assert.AreEqual(value[i], result[i]);
}
Assert.AreEqual(timestamp, packet.TimestampMicroseconds());
}
#endregion

#region #Validate
[Test]
public void ValidateAsBool_ShouldThrow_When_ValueIsNotSet()
{
using var packet = Packet.CreateEmpty();
_ = Assert.Throws<BadStatusException>(packet.ValidateAsBool);
}

[Test]
public void ValidateAsBoolVector_ShouldThrow_When_ValueIsNotSet()
{
using var packet = Packet.CreateEmpty();
_ = Assert.Throws<BadStatusException>(packet.ValidateAsBoolVector);
}
#endregion
}
}
20 changes: 20 additions & 0 deletions mediapipe_api/framework/packet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,26 @@ MpReturnCode mp_Packet__ValidateAsBool(mediapipe::Packet* packet, absl::Status**
CATCH_EXCEPTION
}

// BoolVectorPacket
MpReturnCode mp__MakeBoolVectorPacket__Pb_i(bool* value, int size, mediapipe::Packet** packet_out) {
return mp__MakeVectorPacket(value, size, packet_out);
}

MpReturnCode mp__MakeBoolVectorPacket_At__Pb_i_ll(bool* value, int size, int64 timestampMicrosec, mediapipe::Packet** packet_out) {
return mp__MakeVectorPacket_At(value, size, timestampMicrosec, packet_out);
}

MpReturnCode mp_Packet__GetBoolVector(mediapipe::Packet* packet, mp_api::StructArray<bool>* value_out) {
return mp_Packet__GetStructVector(packet, value_out);
}

MpReturnCode mp_Packet__ValidateAsBoolVector(mediapipe::Packet* packet, absl::Status** status_out) {
TRY
*status_out = new absl::Status{packet->ValidateAsType<std::vector<bool>>()};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

// FloatPacket
MpReturnCode mp__MakeFloatPacket__f(float value, mediapipe::Packet** packet_out) {
TRY
Expand Down
15 changes: 15 additions & 0 deletions mediapipe_api/framework/packet.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ MP_CAPI(MpReturnCode) mp__MakeBoolPacket_At__b_ll(bool value, int64 timestampMic
MP_CAPI(MpReturnCode) mp_Packet__GetBool(mediapipe::Packet* packet, bool* value_out);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsBool(mediapipe::Packet* packet, absl::Status** status_out);

// std::vector<bool>
MP_CAPI(MpReturnCode) mp__MakeBoolVectorPacket__Pb_i(bool* value, int size, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeBoolVectorPacket_At__Pb_i_ll(bool* value, int size, int64 timestampMicrosec, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp_Packet__GetBoolVector(mediapipe::Packet* packet, mp_api::StructArray<bool>* value_out);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsBoolVector(mediapipe::Packet* packet, absl::Status** status_out);

// float
MP_CAPI(MpReturnCode) mp__MakeFloatPacket__f(float value, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeFloatPacket_At__f_Rt(float value, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out);
Expand Down Expand Up @@ -146,6 +152,15 @@ inline MpReturnCode mp__MakeVectorPacket_At(const T* array, int size, mediapipe:
CATCH_EXCEPTION
}

template <typename T>
inline MpReturnCode mp__MakeVectorPacket_At(const T* array, int size, int64 timestampMicrosec, mediapipe::Packet** packet_out) {
TRY
std::vector<T> vector(array, array + size);
*packet_out = new mediapipe::Packet{mediapipe::MakePacket<std::vector<T>>(vector).At(mediapipe::Timestamp(timestampMicrosec))};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

template <typename T>
inline MpReturnCode mp_Packet__GetStructVector(mediapipe::Packet* packet, mp_api::StructArray<T>* value_out) {
TRY_ALL
Expand Down

0 comments on commit 9795145

Please sign in to comment.