diff --git a/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj b/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj index 11638210..8fbe79ba 100644 --- a/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj +++ b/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj @@ -12,7 +12,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/SightKeeper.Application/SightKeeper.Application.csproj b/SightKeeper.Application/SightKeeper.Application.csproj index de01b409..8167e620 100644 --- a/SightKeeper.Application/SightKeeper.Application.csproj +++ b/SightKeeper.Application/SightKeeper.Application.csproj @@ -10,10 +10,10 @@ - - + + - + diff --git a/SightKeeper.Data.Tests/DbContextTests.cs b/SightKeeper.Data.Tests/DbContextTests.cs index 42885dea..cf7c886e 100644 --- a/SightKeeper.Data.Tests/DbContextTests.cs +++ b/SightKeeper.Data.Tests/DbContextTests.cs @@ -27,11 +27,13 @@ public void ShouldCreateSqLiteAppDbFileWithSomeData() var itemClass = dataSet.CreateItemClass("Test item class", 0); asset.CreateItem(itemClass, new Bounding(0, 0, 1, 1)); - DbWeightsDataAccess weightsDataAccess = new(dbContext); - weightsDataAccess.CreateWeights(dataSet.Weights, Array.Empty(), Array.Empty(), Size.Medium, + DbWeightsDataAccess weightsDataAccessTests = new(dbContext); + var weights = weightsDataAccessTests.CreateWeights(dataSet.Weights, [0], [1], Size.Medium, new WeightsMetrics(11, new LossMetrics(12, 13, 14)), Array.Empty()); dbContext.DataSets.Add(dataSet); dbContext.SaveChanges(); + + weightsDataAccessTests.LoadWeightsONNXData(weights).Content.Single().Should().Be(0); } } \ No newline at end of file diff --git a/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs b/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs new file mode 100644 index 00000000..968d7d7d --- /dev/null +++ b/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs @@ -0,0 +1,33 @@ +using SightKeeper.Data.Services; +using SightKeeper.Domain.Model.DataSets; +using SightKeeper.Tests.Common; + +namespace SightKeeper.Data.Tests; + +public sealed class DbWeightsDataAccessTests : DbRelatedTests +{ + [Fact] + public void ShouldSaveAndLoadWeights() + { + var dataSet = DomainTestsHelper.NewDataSet; + var dbContext = DbContextFactory.CreateDbContext(); + dbContext.Add(dataSet); + DbWeightsDataAccess weightsDataAccess = new(dbContext); + var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty()); + dbContext.SaveChanges(); + weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0); + weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1); + } + + [Fact] + public void ShouldSaveAndLoadWeightsWithoutSaving() + { + var dataSet = DomainTestsHelper.NewDataSet; + var dbContext = DbContextFactory.CreateDbContext(); + dbContext.Add(dataSet); + DbWeightsDataAccess weightsDataAccess = new(dbContext); + var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty()); + weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0); + weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1); + } +} \ No newline at end of file diff --git a/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs b/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs index 8fe5c5b4..cbee73f8 100644 --- a/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs +++ b/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs @@ -9,7 +9,8 @@ public void Configure(EntityTypeBuilder builder) { builder.HasFlakeIdKey(); builder.ToTable("WeightsData"); - builder.ComplexProperty(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content")); + // probably builder.ComplexProperty can be used, but https://github.com/dotnet/efcore/issues/9849 or smthng (Works with SqLite, but won't with InMemory) + builder.OwnsOne(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content")); builder.Property(weightsData => weightsData.Format); builder.HasIndex("WeightsId", "Format").IsUnique(); } diff --git a/SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs b/SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs similarity index 99% rename from SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs rename to SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs index 54921dcb..cf717ffa 100644 --- a/SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs +++ b/SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs @@ -12,14 +12,14 @@ namespace SightKeeper.Data.Migrations { [DbContext(typeof(AppDbContext))] - [Migration("20240312211014_Initial")] + [Migration("20240312223758_Initial")] partial class Initial { /// protected override void BuildTargetModel(ModelBuilder modelBuilder) { #pragma warning disable 612, 618 - modelBuilder.HasAnnotation("ProductVersion", "8.0.2"); + modelBuilder.HasAnnotation("ProductVersion", "8.0.3"); modelBuilder.Entity("SightKeeper.Data.DbWeightsData", b => { diff --git a/SightKeeper.Data/Migrations/20240312211014_Initial.cs b/SightKeeper.Data/Migrations/20240312223758_Initial.cs similarity index 100% rename from SightKeeper.Data/Migrations/20240312211014_Initial.cs rename to SightKeeper.Data/Migrations/20240312223758_Initial.cs diff --git a/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs b/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs index 2f03bdd8..c0c60ccd 100644 --- a/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs +++ b/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs @@ -16,7 +16,7 @@ partial class AppDbContextModelSnapshot : ModelSnapshot protected override void BuildModel(ModelBuilder modelBuilder) { #pragma warning disable 612, 618 - modelBuilder.HasAnnotation("ProductVersion", "8.0.2"); + modelBuilder.HasAnnotation("ProductVersion", "8.0.3"); modelBuilder.Entity("SightKeeper.Data.DbWeightsData", b => { diff --git a/SightKeeper.Data/Services/DbWeightsDataAccess.cs b/SightKeeper.Data/Services/DbWeightsDataAccess.cs index 2891dab0..15b37e57 100644 --- a/SightKeeper.Data/Services/DbWeightsDataAccess.cs +++ b/SightKeeper.Data/Services/DbWeightsDataAccess.cs @@ -1,4 +1,5 @@ using FlakeId; +using Microsoft.EntityFrameworkCore; using SightKeeper.Domain.Model.DataSets; namespace SightKeeper.Data.Services; @@ -8,25 +9,50 @@ public sealed class DbWeightsDataAccess : WeightsDataAccess public DbWeightsDataAccess(AppDbContext dbContext) { _dbContext = dbContext; + _dbContext.SavedChanges += OnDbContextSavedChanges; } + public override WeightsData LoadWeightsONNXData(Weights weights) { - throw new NotImplementedException(); + return LoadWeightsData(weights, DbWeightsData.DataFormat.ONNX); } public override WeightsData LoadWeightsPTData(Weights weights) { - throw new NotImplementedException(); + return LoadWeightsData(weights, DbWeightsData.DataFormat.PT); } protected override void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData) { - _dbContext.Add(new DbWeightsData(onnxData, weights, DbWeightsData.DataFormat.ONNX)); - _dbContext.Add(new DbWeightsData(ptData, weights, DbWeightsData.DataFormat.PT)); + DbWeightsData onnxWeightsData = new(onnxData, weights, DbWeightsData.DataFormat.ONNX); + _dbContext.Add(onnxWeightsData); + _unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.ONNX), onnxData); + DbWeightsData ptWeightsData = new(ptData, weights, DbWeightsData.DataFormat.PT); + _dbContext.Add(ptWeightsData); + _unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.PT), ptData); } protected override void RemoveWeightsData(Weights weights) { - throw new NotImplementedException(); + _unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.ONNX)); + _unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.PT)); } private readonly AppDbContext _dbContext; + private readonly Dictionary<(Weights, DbWeightsData.DataFormat), WeightsData> _unsavedWeightsData = new(); + + private WeightsData LoadWeightsData(Weights weights, DbWeightsData.DataFormat format) + { + if (_unsavedWeightsData.TryGetValue((weights, format), out var unsavedWeightsData)) + return unsavedWeightsData; + var weightsEntry = _dbContext.Entry(weights); + var weightsId = weightsEntry.Property("Id").CurrentValue; + return _dbContext.Set() + .AsNoTracking() + .Where(weightsData => EF.Property(weightsData.Weights, "Id") == weightsId && weightsData.Format == format) + .Select(weightsData => weightsData.Data) + .Single(); + } + private void OnDbContextSavedChanges(object? sender, SavedChangesEventArgs e) + { + _unsavedWeightsData.Clear(); + } } \ No newline at end of file diff --git a/SightKeeper.Data/SightKeeper.Data.csproj b/SightKeeper.Data/SightKeeper.Data.csproj index 1873fee8..43b6555d 100644 --- a/SightKeeper.Data/SightKeeper.Data.csproj +++ b/SightKeeper.Data/SightKeeper.Data.csproj @@ -9,11 +9,11 @@ - - - - - + + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs b/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs index 0cd4ac19..23e259f7 100644 --- a/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs +++ b/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs @@ -3,7 +3,7 @@ namespace SightKeeper.Domain.Model.DataSets; -public abstract class WeightsDataAccess +public abstract class WeightsDataAccess : IDisposable { public IObservable WeightsCreated => _weightsCreated.AsObservable(); public IObservable WeightsRemoved => _weightsRemoved.AsObservable(); @@ -30,6 +30,12 @@ public void RemoveWeights(Weights weights) RemoveWeightsData(weights); _weightsRemoved.OnNext(weights); } + public void Dispose() + { + _weightsCreated.Dispose(); + _weightsRemoved.Dispose(); + GC.SuppressFinalize(this); + } protected abstract void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData); protected abstract void RemoveWeightsData(Weights weights); diff --git a/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj b/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj index 10d0abdf..1b6aff61 100644 --- a/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj +++ b/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj @@ -13,7 +13,7 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj b/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj index d8e71c77..d64fde78 100644 --- a/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj +++ b/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj @@ -14,7 +14,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj b/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj index ac39ec10..4eb89eb9 100644 --- a/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj +++ b/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj @@ -15,7 +15,7 @@ - + diff --git a/SightKeeper.Services/Prediction/ONNXDetector.cs b/SightKeeper.Services/Prediction/ONNXDetector.cs index f92bcd5f..f5d6247f 100644 --- a/SightKeeper.Services/Prediction/ONNXDetector.cs +++ b/SightKeeper.Services/Prediction/ONNXDetector.cs @@ -35,9 +35,15 @@ public Weights? Weights private void SetWeights(Weights weights) { var weightsData = weightsDataAccess.LoadWeightsONNXData(weights); - _predictor = new YoloV8(new ModelSelector(weightsData.Content), CreateMetadata(weights.Library.DataSet)); - _predictor.Parameters.Confidence = ProbabilityThreshold; - _predictor.Parameters.IoU = IoU; + YoloV8Builder builder = new(); + builder.UseOnnxModel(new BinarySelector(weightsData.Content)); + builder.WithMetadata(CreateMetadata(weights.Library.DataSet)); + builder.WithConfiguration(configuration => + { + configuration.Confidence = ProbabilityThreshold; + configuration.IoU = IoU; + }); + _predictor = builder.Build(); } public float ProbabilityThreshold @@ -47,7 +53,7 @@ public float ProbabilityThreshold { _probabilityThreshold = value; if (_predictor != null) - _predictor.Parameters.Confidence = value; + _predictor.Configuration.Confidence = value; } } @@ -58,7 +64,7 @@ public float IoU { _iou = value; if (_predictor != null) - _predictor.Parameters.IoU = value; + _predictor.Configuration.IoU = value; } } @@ -82,9 +88,9 @@ public async Task> DetectAsync(byte[] image, Cancel return result; } - private float _probabilityThreshold = YoloV8Parameters.Default.Confidence; - private float _iou = YoloV8Parameters.Default.IoU; - private YoloV8? _predictor; + private float _probabilityThreshold = YoloV8Configuration.Default.Confidence; + private float _iou = YoloV8Configuration.Default.IoU; + private YoloV8Predictor? _predictor; private Weights? _weights; private Dictionary? _itemClasses; diff --git a/SightKeeper.Services/SightKeeper.Services.csproj b/SightKeeper.Services/SightKeeper.Services.csproj index e38da76d..b400172d 100644 --- a/SightKeeper.Services/SightKeeper.Services.csproj +++ b/SightKeeper.Services/SightKeeper.Services.csproj @@ -13,10 +13,10 @@ - - - - + + + +