Skip to content

Commit

Permalink
Bump dependencies; YoloV8 library v4 breaking changes adaptation; Add…
Browse files Browse the repository at this point in the history
… DbWeightsDataAccess tests; DbWeightsDataAccess before changes data loading capability;
  • Loading branch information
Neakita committed Mar 12, 2024
1 parent f688177 commit e09df02
Show file tree
Hide file tree
Showing 16 changed files with 110 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<PackageReference Include="FluentAssertions" Version="6.12.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Include="NSubstitute" Version="5.1.0" />
<PackageReference Include="Serilog.Sinks.Seq" Version="6.0.0" />
<PackageReference Include="Serilog.Sinks.Seq" Version="7.0.0" />
<PackageReference Include="xunit" Version="2.7.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.7">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
6 changes: 3 additions & 3 deletions SightKeeper.Application/SightKeeper.Application.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
<PackageReference Include="CommunityToolkit.Diagnostics" Version="8.2.2" />
<PackageReference Include="FluentValidation" Version="11.9.0" />
<PackageReference Include="Serilog" Version="3.1.1" />
<PackageReference Include="SerilogTimings" Version="3.0.1" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.2" />
<PackageReference Include="SerilogTimings" Version="3.1.0" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.3" />
<PackageReference Include="System.Reactive" Version="6.0.0" />
<PackageReference Include="YamlDotNet" Version="15.1.1" />
<PackageReference Include="YamlDotNet" Version="15.1.2" />
</ItemGroup>

<ItemGroup>
Expand Down
6 changes: 4 additions & 2 deletions SightKeeper.Data.Tests/DbContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ public void ShouldCreateSqLiteAppDbFileWithSomeData()
var itemClass = dataSet.CreateItemClass("Test item class", 0);
asset.CreateItem(itemClass, new Bounding(0, 0, 1, 1));

DbWeightsDataAccess weightsDataAccess = new(dbContext);
weightsDataAccess.CreateWeights(dataSet.Weights, Array.Empty<byte>(), Array.Empty<byte>(), Size.Medium,
DbWeightsDataAccess weightsDataAccessTests = new(dbContext);
var weights = weightsDataAccessTests.CreateWeights(dataSet.Weights, [0], [1], Size.Medium,
new WeightsMetrics(11, new LossMetrics(12, 13, 14)), Array.Empty<ItemClass>());

dbContext.DataSets.Add(dataSet);
dbContext.SaveChanges();

weightsDataAccessTests.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
}
}
33 changes: 33 additions & 0 deletions SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using SightKeeper.Data.Services;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Tests.Common;

namespace SightKeeper.Data.Tests;

public sealed class DbWeightsDataAccessTests : DbRelatedTests
{
[Fact]
public void ShouldSaveAndLoadWeights()
{
var dataSet = DomainTestsHelper.NewDataSet;
var dbContext = DbContextFactory.CreateDbContext();
dbContext.Add(dataSet);
DbWeightsDataAccess weightsDataAccess = new(dbContext);
var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty<ItemClass>());
dbContext.SaveChanges();
weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1);
}

[Fact]
public void ShouldSaveAndLoadWeightsWithoutSaving()
{
var dataSet = DomainTestsHelper.NewDataSet;
var dbContext = DbContextFactory.CreateDbContext();
dbContext.Add(dataSet);
DbWeightsDataAccess weightsDataAccess = new(dbContext);
var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty<ItemClass>());
weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1);
}
}
3 changes: 2 additions & 1 deletion SightKeeper.Data/Configuration/WeightsDataConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ public void Configure(EntityTypeBuilder<DbWeightsData> builder)
{
builder.HasFlakeIdKey();
builder.ToTable("WeightsData");
builder.ComplexProperty(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content"));
// probably builder.ComplexProperty can be used, but https://github.com/dotnet/efcore/issues/9849 or smthng (Works with SqLite, but won't with InMemory)
builder.OwnsOne(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content"));
builder.Property(weightsData => weightsData.Format);
builder.HasIndex("WeightsId", "Format").IsUnique();
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ partial class AppDbContextModelSnapshot : ModelSnapshot
protected override void BuildModel(ModelBuilder modelBuilder)
{
#pragma warning disable 612, 618
modelBuilder.HasAnnotation("ProductVersion", "8.0.2");
modelBuilder.HasAnnotation("ProductVersion", "8.0.3");

modelBuilder.Entity("SightKeeper.Data.DbWeightsData", b =>
{
Expand Down
36 changes: 31 additions & 5 deletions SightKeeper.Data/Services/DbWeightsDataAccess.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using FlakeId;
using Microsoft.EntityFrameworkCore;
using SightKeeper.Domain.Model.DataSets;

namespace SightKeeper.Data.Services;
Expand All @@ -8,25 +9,50 @@ public sealed class DbWeightsDataAccess : WeightsDataAccess
public DbWeightsDataAccess(AppDbContext dbContext)
{
_dbContext = dbContext;
_dbContext.SavedChanges += OnDbContextSavedChanges;
}

public override WeightsData LoadWeightsONNXData(Weights weights)
{
throw new NotImplementedException();
return LoadWeightsData(weights, DbWeightsData.DataFormat.ONNX);
}
public override WeightsData LoadWeightsPTData(Weights weights)
{
throw new NotImplementedException();
return LoadWeightsData(weights, DbWeightsData.DataFormat.PT);
}

protected override void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData)
{
_dbContext.Add(new DbWeightsData(onnxData, weights, DbWeightsData.DataFormat.ONNX));
_dbContext.Add(new DbWeightsData(ptData, weights, DbWeightsData.DataFormat.PT));
DbWeightsData onnxWeightsData = new(onnxData, weights, DbWeightsData.DataFormat.ONNX);
_dbContext.Add(onnxWeightsData);
_unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.ONNX), onnxData);
DbWeightsData ptWeightsData = new(ptData, weights, DbWeightsData.DataFormat.PT);
_dbContext.Add(ptWeightsData);
_unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.PT), ptData);
}
protected override void RemoveWeightsData(Weights weights)
{
throw new NotImplementedException();
_unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.ONNX));
_unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.PT));
}

private readonly AppDbContext _dbContext;
private readonly Dictionary<(Weights, DbWeightsData.DataFormat), WeightsData> _unsavedWeightsData = new();

private WeightsData LoadWeightsData(Weights weights, DbWeightsData.DataFormat format)
{
if (_unsavedWeightsData.TryGetValue((weights, format), out var unsavedWeightsData))
return unsavedWeightsData;
var weightsEntry = _dbContext.Entry(weights);
var weightsId = weightsEntry.Property<Id>("Id").CurrentValue;
return _dbContext.Set<DbWeightsData>()
.AsNoTracking()
.Where(weightsData => EF.Property<Id>(weightsData.Weights, "Id") == weightsId && weightsData.Format == format)
.Select(weightsData => weightsData.Data)
.Single();
}
private void OnDbContextSavedChanges(object? sender, SavedChangesEventArgs e)
{
_unsavedWeightsData.Clear();
}
}
10 changes: 5 additions & 5 deletions SightKeeper.Data/SightKeeper.Data.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
<ItemGroup>
<PackageReference Include="DynamicData" Version="8.3.27" />
<PackageReference Include="FlakeId" Version="1.1.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="8.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Abstractions" Version="8.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="8.0.2">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="8.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Abstractions" Version="8.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="8.0.3">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
Expand Down
8 changes: 7 additions & 1 deletion SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace SightKeeper.Domain.Model.DataSets;

public abstract class WeightsDataAccess
public abstract class WeightsDataAccess : IDisposable
{
public IObservable<Weights> WeightsCreated => _weightsCreated.AsObservable();
public IObservable<Weights> WeightsRemoved => _weightsRemoved.AsObservable();
Expand All @@ -30,6 +30,12 @@ public void RemoveWeights(Weights weights)
RemoveWeightsData(weights);
_weightsRemoved.OnNext(weights);
}
public void Dispose()
{
_weightsCreated.Dispose();
_weightsRemoved.Dispose();
GC.SuppressFinalize(this);
}

protected abstract void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData);
protected abstract void RemoveWeightsData(Weights weights);
Expand Down
2 changes: 1 addition & 1 deletion SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<ItemGroup>
<PackageReference Include="CommunityToolkit.Diagnostics" Version="8.2.2" />
<PackageReference Include="EmptyConstructor.Fody" Version="3.0.3" />
<PackageReference Include="Fody" Version="6.5.3">
<PackageReference Include="Fody" Version="6.8.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Include="NSubstitute" Version="5.1.0" />
<PackageReference Include="Serilog" Version="3.1.1" />
<PackageReference Include="Serilog.Sinks.Seq" Version="6.0.0" />
<PackageReference Include="Serilog.Sinks.Seq" Version="7.0.0" />
<PackageReference Include="xunit" Version="2.7.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.7">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<PackageReference Include="SharpDX" Version="4.2.0" />
<PackageReference Include="SharpDX.Direct3D11" Version="4.2.0" />
<PackageReference Include="SharpDX.DXGI" Version="4.2.0" />
<PackageReference Include="System.Drawing.Common" Version="8.0.2" />
<PackageReference Include="System.Drawing.Common" Version="8.0.3" />
</ItemGroup>

</Project>
22 changes: 14 additions & 8 deletions SightKeeper.Services/Prediction/ONNXDetector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ public Weights? Weights
private void SetWeights(Weights weights)
{
var weightsData = weightsDataAccess.LoadWeightsONNXData(weights);
_predictor = new YoloV8(new ModelSelector(weightsData.Content), CreateMetadata(weights.Library.DataSet));
_predictor.Parameters.Confidence = ProbabilityThreshold;
_predictor.Parameters.IoU = IoU;
YoloV8Builder builder = new();
builder.UseOnnxModel(new BinarySelector(weightsData.Content));
builder.WithMetadata(CreateMetadata(weights.Library.DataSet));
builder.WithConfiguration(configuration =>
{
configuration.Confidence = ProbabilityThreshold;
configuration.IoU = IoU;
});
_predictor = builder.Build();
}

public float ProbabilityThreshold
Expand All @@ -47,7 +53,7 @@ public float ProbabilityThreshold
{
_probabilityThreshold = value;
if (_predictor != null)
_predictor.Parameters.Confidence = value;
_predictor.Configuration.Confidence = value;
}
}

Expand All @@ -58,7 +64,7 @@ public float IoU
{
_iou = value;
if (_predictor != null)
_predictor.Parameters.IoU = value;
_predictor.Configuration.IoU = value;
}
}

Expand All @@ -82,9 +88,9 @@ public async Task<ImmutableList<DetectionItem>> DetectAsync(byte[] image, Cancel
return result;
}

private float _probabilityThreshold = YoloV8Parameters.Default.Confidence;
private float _iou = YoloV8Parameters.Default.IoU;
private YoloV8? _predictor;
private float _probabilityThreshold = YoloV8Configuration.Default.Confidence;
private float _iou = YoloV8Configuration.Default.IoU;
private YoloV8Predictor? _predictor;
private Weights? _weights;
private Dictionary<int, ItemClass>? _itemClasses;

Expand Down
8 changes: 4 additions & 4 deletions SightKeeper.Services/SightKeeper.Services.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

<ItemGroup>
<PackageReference Include="Autofac" Version="8.0.0" />
<PackageReference Include="SerilogTimings" Version="3.0.1" />
<PackageReference Include="SharpHook" Version="5.3.0" />
<PackageReference Include="SharpHook.Reactive" Version="5.3.0" />
<PackageReference Include="YoloV8" Version="3.1.1" />
<PackageReference Include="SerilogTimings" Version="3.1.0" />
<PackageReference Include="SharpHook" Version="5.3.1" />
<PackageReference Include="SharpHook.Reactive" Version="5.3.1" />
<PackageReference Include="YoloV8" Version="4.0.0" />
</ItemGroup>

</Project>

0 comments on commit e09df02

Please sign in to comment.