diff --git a/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj b/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj
index 11638210..8fbe79ba 100644
--- a/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj
+++ b/SightKeeper.Application.Tests/SightKeeper.Application.Tests.csproj
@@ -12,7 +12,7 @@
-
+
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/SightKeeper.Application/SightKeeper.Application.csproj b/SightKeeper.Application/SightKeeper.Application.csproj
index de01b409..8167e620 100644
--- a/SightKeeper.Application/SightKeeper.Application.csproj
+++ b/SightKeeper.Application/SightKeeper.Application.csproj
@@ -10,10 +10,10 @@
-
-
+
+
-
+
diff --git a/SightKeeper.Data.Tests/DbContextTests.cs b/SightKeeper.Data.Tests/DbContextTests.cs
index 42885dea..cf7c886e 100644
--- a/SightKeeper.Data.Tests/DbContextTests.cs
+++ b/SightKeeper.Data.Tests/DbContextTests.cs
@@ -27,11 +27,13 @@ public void ShouldCreateSqLiteAppDbFileWithSomeData()
var itemClass = dataSet.CreateItemClass("Test item class", 0);
asset.CreateItem(itemClass, new Bounding(0, 0, 1, 1));
- DbWeightsDataAccess weightsDataAccess = new(dbContext);
- weightsDataAccess.CreateWeights(dataSet.Weights, Array.Empty(), Array.Empty(), Size.Medium,
+ DbWeightsDataAccess weightsDataAccessTests = new(dbContext);
+ var weights = weightsDataAccessTests.CreateWeights(dataSet.Weights, [0], [1], Size.Medium,
new WeightsMetrics(11, new LossMetrics(12, 13, 14)), Array.Empty());
dbContext.DataSets.Add(dataSet);
dbContext.SaveChanges();
+
+ weightsDataAccessTests.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
}
}
\ No newline at end of file
diff --git a/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs b/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs
new file mode 100644
index 00000000..968d7d7d
--- /dev/null
+++ b/SightKeeper.Data.Tests/DbWeightsDataAccessTests.cs
@@ -0,0 +1,33 @@
+using SightKeeper.Data.Services;
+using SightKeeper.Domain.Model.DataSets;
+using SightKeeper.Tests.Common;
+
+namespace SightKeeper.Data.Tests;
+
+public sealed class DbWeightsDataAccessTests : DbRelatedTests
+{
+ [Fact]
+ public void ShouldSaveAndLoadWeights()
+ {
+ var dataSet = DomainTestsHelper.NewDataSet;
+ var dbContext = DbContextFactory.CreateDbContext();
+ dbContext.Add(dataSet);
+ DbWeightsDataAccess weightsDataAccess = new(dbContext);
+ var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty());
+ dbContext.SaveChanges();
+ weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
+ weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1);
+ }
+
+ [Fact]
+ public void ShouldSaveAndLoadWeightsWithoutSaving()
+ {
+ var dataSet = DomainTestsHelper.NewDataSet;
+ var dbContext = DbContextFactory.CreateDbContext();
+ dbContext.Add(dataSet);
+ DbWeightsDataAccess weightsDataAccess = new(dbContext);
+ var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [0], [1], Size.Large, new WeightsMetrics(), Array.Empty());
+ weightsDataAccess.LoadWeightsONNXData(weights).Content.Single().Should().Be(0);
+ weightsDataAccess.LoadWeightsPTData(weights).Content.Single().Should().Be(1);
+ }
+}
\ No newline at end of file
diff --git a/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs b/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs
index 8fe5c5b4..cbee73f8 100644
--- a/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs
+++ b/SightKeeper.Data/Configuration/WeightsDataConfiguration.cs
@@ -9,7 +9,8 @@ public void Configure(EntityTypeBuilder builder)
{
builder.HasFlakeIdKey();
builder.ToTable("WeightsData");
- builder.ComplexProperty(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content"));
+ // probably builder.ComplexProperty can be used, but https://github.com/dotnet/efcore/issues/9849 or smthng (Works with SqLite, but won't with InMemory)
+ builder.OwnsOne(weightsData => weightsData.Data, subBuilder => subBuilder.Property(weights => weights.Content).HasColumnName("Content"));
builder.Property(weightsData => weightsData.Format);
builder.HasIndex("WeightsId", "Format").IsUnique();
}
diff --git a/SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs b/SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs
similarity index 99%
rename from SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs
rename to SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs
index 54921dcb..cf717ffa 100644
--- a/SightKeeper.Data/Migrations/20240312211014_Initial.Designer.cs
+++ b/SightKeeper.Data/Migrations/20240312223758_Initial.Designer.cs
@@ -12,14 +12,14 @@
namespace SightKeeper.Data.Migrations
{
[DbContext(typeof(AppDbContext))]
- [Migration("20240312211014_Initial")]
+ [Migration("20240312223758_Initial")]
partial class Initial
{
///
protected override void BuildTargetModel(ModelBuilder modelBuilder)
{
#pragma warning disable 612, 618
- modelBuilder.HasAnnotation("ProductVersion", "8.0.2");
+ modelBuilder.HasAnnotation("ProductVersion", "8.0.3");
modelBuilder.Entity("SightKeeper.Data.DbWeightsData", b =>
{
diff --git a/SightKeeper.Data/Migrations/20240312211014_Initial.cs b/SightKeeper.Data/Migrations/20240312223758_Initial.cs
similarity index 100%
rename from SightKeeper.Data/Migrations/20240312211014_Initial.cs
rename to SightKeeper.Data/Migrations/20240312223758_Initial.cs
diff --git a/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs b/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs
index 2f03bdd8..c0c60ccd 100644
--- a/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs
+++ b/SightKeeper.Data/Migrations/AppDbContextModelSnapshot.cs
@@ -16,7 +16,7 @@ partial class AppDbContextModelSnapshot : ModelSnapshot
protected override void BuildModel(ModelBuilder modelBuilder)
{
#pragma warning disable 612, 618
- modelBuilder.HasAnnotation("ProductVersion", "8.0.2");
+ modelBuilder.HasAnnotation("ProductVersion", "8.0.3");
modelBuilder.Entity("SightKeeper.Data.DbWeightsData", b =>
{
diff --git a/SightKeeper.Data/Services/DbWeightsDataAccess.cs b/SightKeeper.Data/Services/DbWeightsDataAccess.cs
index 2891dab0..15b37e57 100644
--- a/SightKeeper.Data/Services/DbWeightsDataAccess.cs
+++ b/SightKeeper.Data/Services/DbWeightsDataAccess.cs
@@ -1,4 +1,5 @@
using FlakeId;
+using Microsoft.EntityFrameworkCore;
using SightKeeper.Domain.Model.DataSets;
namespace SightKeeper.Data.Services;
@@ -8,25 +9,50 @@ public sealed class DbWeightsDataAccess : WeightsDataAccess
public DbWeightsDataAccess(AppDbContext dbContext)
{
_dbContext = dbContext;
+ _dbContext.SavedChanges += OnDbContextSavedChanges;
}
+
public override WeightsData LoadWeightsONNXData(Weights weights)
{
- throw new NotImplementedException();
+ return LoadWeightsData(weights, DbWeightsData.DataFormat.ONNX);
}
public override WeightsData LoadWeightsPTData(Weights weights)
{
- throw new NotImplementedException();
+ return LoadWeightsData(weights, DbWeightsData.DataFormat.PT);
}
protected override void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData)
{
- _dbContext.Add(new DbWeightsData(onnxData, weights, DbWeightsData.DataFormat.ONNX));
- _dbContext.Add(new DbWeightsData(ptData, weights, DbWeightsData.DataFormat.PT));
+ DbWeightsData onnxWeightsData = new(onnxData, weights, DbWeightsData.DataFormat.ONNX);
+ _dbContext.Add(onnxWeightsData);
+ _unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.ONNX), onnxData);
+ DbWeightsData ptWeightsData = new(ptData, weights, DbWeightsData.DataFormat.PT);
+ _dbContext.Add(ptWeightsData);
+ _unsavedWeightsData.Add((weights, DbWeightsData.DataFormat.PT), ptData);
}
protected override void RemoveWeightsData(Weights weights)
{
- throw new NotImplementedException();
+ _unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.ONNX));
+ _unsavedWeightsData.Remove((weights, DbWeightsData.DataFormat.PT));
}
private readonly AppDbContext _dbContext;
+ private readonly Dictionary<(Weights, DbWeightsData.DataFormat), WeightsData> _unsavedWeightsData = new();
+
+ private WeightsData LoadWeightsData(Weights weights, DbWeightsData.DataFormat format)
+ {
+ if (_unsavedWeightsData.TryGetValue((weights, format), out var unsavedWeightsData))
+ return unsavedWeightsData;
+ var weightsEntry = _dbContext.Entry(weights);
+ var weightsId = weightsEntry.Property("Id").CurrentValue;
+ return _dbContext.Set()
+ .AsNoTracking()
+ .Where(weightsData => EF.Property(weightsData.Weights, "Id") == weightsId && weightsData.Format == format)
+ .Select(weightsData => weightsData.Data)
+ .Single();
+ }
+ private void OnDbContextSavedChanges(object? sender, SavedChangesEventArgs e)
+ {
+ _unsavedWeightsData.Clear();
+ }
}
\ No newline at end of file
diff --git a/SightKeeper.Data/SightKeeper.Data.csproj b/SightKeeper.Data/SightKeeper.Data.csproj
index 1873fee8..43b6555d 100644
--- a/SightKeeper.Data/SightKeeper.Data.csproj
+++ b/SightKeeper.Data/SightKeeper.Data.csproj
@@ -9,11 +9,11 @@
-
-
-
-
-
+
+
+
+
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs b/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs
index 0cd4ac19..23e259f7 100644
--- a/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs
+++ b/SightKeeper.Domain.Model/DataSets/WeightsDataAccess.cs
@@ -3,7 +3,7 @@
namespace SightKeeper.Domain.Model.DataSets;
-public abstract class WeightsDataAccess
+public abstract class WeightsDataAccess : IDisposable
{
public IObservable WeightsCreated => _weightsCreated.AsObservable();
public IObservable WeightsRemoved => _weightsRemoved.AsObservable();
@@ -30,6 +30,12 @@ public void RemoveWeights(Weights weights)
RemoveWeightsData(weights);
_weightsRemoved.OnNext(weights);
}
+ public void Dispose()
+ {
+ _weightsCreated.Dispose();
+ _weightsRemoved.Dispose();
+ GC.SuppressFinalize(this);
+ }
protected abstract void SaveWeightsData(Weights weights, WeightsData onnxData, WeightsData ptData);
protected abstract void RemoveWeightsData(Weights weights);
diff --git a/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj b/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj
index 10d0abdf..1b6aff61 100644
--- a/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj
+++ b/SightKeeper.Domain.Model/SightKeeper.Domain.Model.csproj
@@ -13,7 +13,7 @@
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj b/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj
index d8e71c77..d64fde78 100644
--- a/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj
+++ b/SightKeeper.Services.Tests/SightKeeper.Services.Tests.csproj
@@ -14,7 +14,7 @@
-
+
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj b/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj
index ac39ec10..4eb89eb9 100644
--- a/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj
+++ b/SightKeeper.Services.Windows/SightKeeper.Services.Windows.csproj
@@ -15,7 +15,7 @@
-
+
diff --git a/SightKeeper.Services/Prediction/ONNXDetector.cs b/SightKeeper.Services/Prediction/ONNXDetector.cs
index f92bcd5f..f5d6247f 100644
--- a/SightKeeper.Services/Prediction/ONNXDetector.cs
+++ b/SightKeeper.Services/Prediction/ONNXDetector.cs
@@ -35,9 +35,15 @@ public Weights? Weights
private void SetWeights(Weights weights)
{
var weightsData = weightsDataAccess.LoadWeightsONNXData(weights);
- _predictor = new YoloV8(new ModelSelector(weightsData.Content), CreateMetadata(weights.Library.DataSet));
- _predictor.Parameters.Confidence = ProbabilityThreshold;
- _predictor.Parameters.IoU = IoU;
+ YoloV8Builder builder = new();
+ builder.UseOnnxModel(new BinarySelector(weightsData.Content));
+ builder.WithMetadata(CreateMetadata(weights.Library.DataSet));
+ builder.WithConfiguration(configuration =>
+ {
+ configuration.Confidence = ProbabilityThreshold;
+ configuration.IoU = IoU;
+ });
+ _predictor = builder.Build();
}
public float ProbabilityThreshold
@@ -47,7 +53,7 @@ public float ProbabilityThreshold
{
_probabilityThreshold = value;
if (_predictor != null)
- _predictor.Parameters.Confidence = value;
+ _predictor.Configuration.Confidence = value;
}
}
@@ -58,7 +64,7 @@ public float IoU
{
_iou = value;
if (_predictor != null)
- _predictor.Parameters.IoU = value;
+ _predictor.Configuration.IoU = value;
}
}
@@ -82,9 +88,9 @@ public async Task> DetectAsync(byte[] image, Cancel
return result;
}
- private float _probabilityThreshold = YoloV8Parameters.Default.Confidence;
- private float _iou = YoloV8Parameters.Default.IoU;
- private YoloV8? _predictor;
+ private float _probabilityThreshold = YoloV8Configuration.Default.Confidence;
+ private float _iou = YoloV8Configuration.Default.IoU;
+ private YoloV8Predictor? _predictor;
private Weights? _weights;
private Dictionary? _itemClasses;
diff --git a/SightKeeper.Services/SightKeeper.Services.csproj b/SightKeeper.Services/SightKeeper.Services.csproj
index e38da76d..b400172d 100644
--- a/SightKeeper.Services/SightKeeper.Services.csproj
+++ b/SightKeeper.Services/SightKeeper.Services.csproj
@@ -13,10 +13,10 @@
-
-
-
-
+
+
+
+