Skip to content

Commit

Permalink
Allow weights creation directly in library with specified creation da…
Browse files Browse the repository at this point in the history
…te, move WeightsDataAccess from domain to application
  • Loading branch information
Neakita committed Aug 24, 2024
1 parent f5e86f8 commit 7cc053b
Show file tree
Hide file tree
Showing 28 changed files with 65 additions and 162 deletions.
1 change: 0 additions & 1 deletion SightKeeper.Application/Annotating/HotKeyScreenshoter.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using CommunityToolkit.Diagnostics;
using SharpHook.Native;
using SightKeeper.Application.Input;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Detector;

namespace SightKeeper.Application.Annotating;
Expand Down
1 change: 0 additions & 1 deletion SightKeeper.Application/Annotating/Screenshoter.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using SightKeeper.Domain.Model;
using SightKeeper.Domain.Model.DataSets.Detector;
using SightKeeper.Domain.Model.DataSets.Screenshots;
using SightKeeper.Domain.Services;

namespace SightKeeper.Application.Annotating;

Expand Down
3 changes: 1 addition & 2 deletions SightKeeper.Application/Annotating/StreamScreenshoter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Detector;
using SightKeeper.Domain.Model.DataSets.Detector;

namespace SightKeeper.Application.Annotating;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System.Reactive.Linq;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Detector;
using SightKeeper.Domain.Services;

namespace SightKeeper.Application.Prediction.Handling;

Expand Down
1 change: 0 additions & 1 deletion SightKeeper.Application/Prediction/ONNXDetector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Tags;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;
using SixLabors.ImageSharp;
using RectangleF = System.Drawing.RectangleF;
using Size = SixLabors.ImageSharp.Size;
Expand Down
1 change: 0 additions & 1 deletion SightKeeper.Application/Training/Trainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using SightKeeper.Application.Extensions;
using SightKeeper.Domain.Model.DataSets.Detector;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;

namespace SightKeeper.Application.Training;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,35 @@
using SightKeeper.Domain.Model.DataSets.Tags;
using SightKeeper.Domain.Model.DataSets.Weights;

namespace SightKeeper.Domain.Services;
namespace SightKeeper.Application;

public abstract class WeightsDataAccess
{
public Weights<TTag> CreateWeights<TTag>(
WeightsLibrary<TTag> library,
byte[] data,
DateTime creationDate,
ModelSize modelSize,
WeightsMetrics metrics,
IEnumerable<TTag> tags) where TTag : Tag, MinimumTagsCount
{
var weights = library.CreateWeights(modelSize, metrics, tags);
SaveWeightsData(weights, new WeightsData(data));
var weights = library.CreateWeights(creationDate, modelSize, metrics, tags);
SaveWeightsData(weights, data);
return weights;
}

public Weights<TTag, TKeyPointTag> CreateWeights<TTag, TKeyPointTag>(
WeightsLibrary<TTag, TKeyPointTag> library,
byte[] data,
DateTime creationDate,
ModelSize modelSize,
WeightsMetrics metrics,
IEnumerable<(TTag, IEnumerable<TKeyPointTag>)> tags)
where TTag : PoserTag
where TKeyPointTag : KeyPointTag<TTag>
{
var weights = library.CreateWeights(modelSize, metrics, tags);
SaveWeightsData(weights, new WeightsData(data));
var weights = library.CreateWeights(creationDate, modelSize, metrics, tags);
SaveWeightsData(weights, data);
return weights;
}

Expand All @@ -46,8 +48,8 @@ public void RemoveWeights<TTag, TKeyPointTag>(Weights<TTag, TKeyPointTag> weight
RemoveWeightsData(weights);
}

public abstract WeightsData LoadWeightsData(Weights weights);
public abstract byte[] LoadWeightsData(Weights weights);

protected abstract void SaveWeightsData(Weights weights, WeightsData data);
protected abstract void SaveWeightsData(Weights weights, byte[] data);
protected abstract void RemoveWeightsData(Weights weights);
}
2 changes: 1 addition & 1 deletion SightKeeper.Avalonia/DataSets/DataSetViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
using System.Collections.Generic;
using System.Reactive.Disposables;
using DynamicData;
using SightKeeper.Application;
using SightKeeper.Avalonia.ViewModels;
using SightKeeper.Domain.Model;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Tags;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;

namespace SightKeeper.Avalonia.DataSets;

Expand Down
1 change: 0 additions & 1 deletion SightKeeper.Avalonia/DataSets/DataSetsListViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using SightKeeper.Application.Extensions;
using SightKeeper.Avalonia.ViewModels;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Services;

namespace SightKeeper.Avalonia.DataSets;

Expand Down
1 change: 0 additions & 1 deletion SightKeeper.Avalonia/Setup/ServicesBootstrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using SightKeeper.Data.Binary.Formatters;
using SightKeeper.Data.Binary.Services;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Services;
using GamesDataAccess = SightKeeper.Application.Games.GamesDataAccess;

namespace SightKeeper.Avalonia.Setup;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Reactive.Subjects;
using CommunityToolkit.Diagnostics;
using SightKeeper.Application.Annotating;
using SightKeeper.Domain.Model.DataSets;
using SightKeeper.Domain.Model.DataSets.Detector;

namespace SightKeeper.Avalonia.ViewModels.Annotating;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
using System.Collections.Generic;
using System.Reactive.Disposables;
using DynamicData;
using SightKeeper.Application;
using SightKeeper.Avalonia.DataSets;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;

namespace SightKeeper.Avalonia.ViewModels.Annotating;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using DynamicData;
using SightKeeper.Application;
using SightKeeper.Avalonia.Dialogs;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;

namespace SightKeeper.Avalonia.ViewModels.Dialogs;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using CommunityToolkit.Diagnostics;
using SightKeeper.Data.Binary.DataSets;
using SightKeeper.Data.Binary.Services;
using SightKeeper.Domain.Model.DataSets.Classifier;
using SightKeeper.Domain.Model.DataSets.Screenshots;
using SightKeeper.Domain.Model.DataSets.Tags;
using SightKeeper.Domain.Model.DataSets.Weights;
using ClassifierAsset = SightKeeper.Data.Binary.DataSets.Classifier.ClassifierAsset;
using ClassifierDataSet = SightKeeper.Data.Binary.DataSets.Classifier.ClassifierDataSet;
using Screenshot = SightKeeper.Data.Binary.DataSets.Screenshot;
Expand Down Expand Up @@ -67,14 +64,6 @@ public Domain.Model.DataSets.Classifier.ClassifierDataSet ConvertBack(
private readonly ScreenshotsConverter _screenshotsConverter;
private readonly ClassifierAssetsConverter _assetsConverter;

[UnsafeAccessor(UnsafeAccessorKind.Method)]
private static extern Weights<TTag> CreateWeights<TTag>(
WeightsLibrary<TTag> library,
ModelSize size,
WeightsMetrics metrics,
IEnumerable<ClassifierTag> tags)
where TTag : Domain.Model.DataSets.Tags.Tag, MinimumTagsCount;

private static void AddTags(Domain.Model.DataSets.Classifier.ClassifierDataSet dataSet, ImmutableArray<Tag> tags, ReverseConversionSession session)
{
foreach (var rawTag in tags)
Expand Down Expand Up @@ -110,7 +99,7 @@ private void AddWeights(Domain.Model.DataSets.Classifier.ClassifierDataSet dataS
{
foreach (var rawWeights in raw)
{
var weights = CreateWeights(dataSet.Weights, rawWeights.Size, rawWeights.Metrics, rawWeights.Tags.Select(tagId => (ClassifierTag)session.Tags[tagId]));
var weights = dataSet.Weights.CreateWeights(rawWeights.CreationDate, rawWeights.Size, rawWeights.Metrics, rawWeights.Tags.Select(tagId => (ClassifierTag)session.Tags[tagId]));
_weightsDataAccess.AssociateId(weights, rawWeights.Id);
session.Weights.Add(rawWeights.Id, weights);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using CommunityToolkit.Diagnostics;
using SightKeeper.Data.Binary.DataSets;
using SightKeeper.Data.Binary.Services;
using SightKeeper.Domain.Model.DataSets.Detector;
using SightKeeper.Domain.Model.DataSets.Screenshots;
using SightKeeper.Domain.Model.DataSets.Tags;
using SightKeeper.Domain.Model.DataSets.Weights;
using DetectorAsset = SightKeeper.Data.Binary.DataSets.Detector.DetectorAsset;
using DetectorDataSet = SightKeeper.Data.Binary.DataSets.Detector.DetectorDataSet;
using Screenshot = SightKeeper.Data.Binary.DataSets.Screenshot;
Expand Down Expand Up @@ -69,14 +66,6 @@ internal Domain.Model.DataSets.Detector.DetectorDataSet ConvertBack(
private readonly WeightsConverter _weightsConverter;
private readonly DetectorAssetsConverter _assetsConverter;

[UnsafeAccessor(UnsafeAccessorKind.Method)]
private static extern Weights<TTag> CreateWeights<TTag>(
WeightsLibrary<TTag> library,
ModelSize size,
WeightsMetrics metrics,
IEnumerable<DetectorTag> tags)
where TTag : Domain.Model.DataSets.Tags.Tag, MinimumTagsCount;

private static void AddTags(Domain.Model.DataSets.Detector.DetectorDataSet dataSet, ImmutableArray<Tag> tags, ReverseConversionSession session)
{
foreach (var rawTag in tags)
Expand Down Expand Up @@ -113,7 +102,7 @@ private void AddWeights(Domain.Model.DataSets.Detector.DetectorDataSet dataSet,
{
foreach (var rawWeights in raw)
{
var weights = CreateWeights(dataSet.Weights, rawWeights.Size, rawWeights.Metrics, rawWeights.Tags.Select(tagId => (DetectorTag)session.Tags[tagId]));
var weights = dataSet.Weights.CreateWeights(rawWeights.CreationDate, rawWeights.Size, rawWeights.Metrics, rawWeights.Tags.Select(tagId => (DetectorTag)session.Tags[tagId]));
_weightsDataAccess.AssociateId(weights, rawWeights.Id);
session.Weights.Add(rawWeights.Id, weights);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using CommunityToolkit.Diagnostics;
using FlakeId;
using SightKeeper.Data.Binary.DataSets.Poser;
using SightKeeper.Data.Binary.Services;
using SightKeeper.Domain.Model.DataSets.Poser;
using SightKeeper.Domain.Model.DataSets.Poser2D;
using SightKeeper.Domain.Model.DataSets.Screenshots;
using SightKeeper.Domain.Model.DataSets.Weights;
using Poser2DAsset = SightKeeper.Data.Binary.DataSets.Poser2D.Poser2DAsset;
using Poser2DDataSet = SightKeeper.Data.Binary.DataSets.Poser2D.Poser2DDataSet;
using Poser2DTag = SightKeeper.Data.Binary.DataSets.Poser2D.Poser2DTag;
Expand Down Expand Up @@ -72,15 +69,6 @@ internal Domain.Model.DataSets.Poser2D.Poser2DDataSet ConvertBack(
private readonly WeightsConverter _weightsConverter;
private readonly Poser2DAssetsConverter _assetsConverter;

[UnsafeAccessor(UnsafeAccessorKind.Method)]
private static extern Weights<TTag, TKeyPointTag> CreateWeights<TTag, TKeyPointTag>(
WeightsLibrary<TTag, TKeyPointTag> library,
ModelSize size,
WeightsMetrics metrics,
ImmutableDictionary<Domain.Model.DataSets.Poser2D.Poser2DTag, ImmutableHashSet<KeyPointTag2D>> tags)
where TTag : PoserTag
where TKeyPointTag : KeyPointTag<TTag>;

private static void CreateTags(Domain.Model.DataSets.Poser2D.Poser2DDataSet dataSet, ImmutableArray<Poser2DTag> tags, ReverseConversionSession session)
{
foreach (var rawTag in tags)
Expand Down Expand Up @@ -136,18 +124,20 @@ private void CreateWeights(Domain.Model.DataSets.Poser2D.Poser2DDataSet dataSet,
{
foreach (var rawWeights in raw)
{
var weights = CreateWeights(dataSet.Weights, rawWeights.Size, rawWeights.Metrics, ConvertBack(rawWeights.Tags, session));
var weights = dataSet.Weights.CreateWeights(rawWeights.CreationDate, rawWeights.Size, rawWeights.Metrics, ConvertBack(rawWeights.Tags, session));
_weightsDataAccess.AssociateId(weights, rawWeights.Id);
session.Weights.Add(rawWeights.Id, weights);
}
}

private ImmutableDictionary<Domain.Model.DataSets.Poser2D.Poser2DTag, ImmutableHashSet<KeyPointTag2D>> ConvertBack(
private IEnumerable<(Domain.Model.DataSets.Poser2D.Poser2DTag, IEnumerable<KeyPointTag2D>)> ConvertBack(
ImmutableArray<(Id Id, ImmutableArray<Id> KeyPointIds)> tags,
ReverseConversionSession session)
{
return tags.ToImmutableDictionary(
t => (Domain.Model.DataSets.Poser2D.Poser2DTag)session.Tags[t.Id],
t => t.KeyPointIds.Select(id => (KeyPointTag2D)session.Tags[id]).ToImmutableHashSet());
foreach (var (tagId, keyPointIds) in tags)
{
yield return ((Domain.Model.DataSets.Poser2D.Poser2DTag)session.Tags[tagId],
keyPointIds.Select(id => (KeyPointTag2D)session.Tags[id]).ToImmutableList());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private static ImmutableArray<Id> Convert<TTag>(
}

private static ImmutableArray<(Id, ImmutableArray<Id>)> Convert<TTag, TKeyPointTag>(
IImmutableDictionary<TTag, IImmutableSet<TKeyPointTag>> tags,
ImmutableDictionary<TTag, ImmutableHashSet<TKeyPointTag>> tags,
ConversionSession session)
where TTag : PoserTag
where TKeyPointTag : KeyPointTag<TTag>
Expand Down
17 changes: 6 additions & 11 deletions SightKeeper.Data/Binary/Services/FileSystemWeightsDataAccess.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System.Runtime.CompilerServices;
using FlakeId;
using FlakeId;
using SightKeeper.Application;
using SightKeeper.Domain.Model.DataSets.Weights;
using SightKeeper.Domain.Services;

namespace SightKeeper.Data.Binary.Services;

Expand All @@ -13,13 +12,9 @@ public string DirectoryPath
set => _dataAccess.DirectoryPath = value;
}

public override WeightsData LoadWeightsData(Weights weights)
public override byte[] LoadWeightsData(Weights weights)
{
var data = _dataAccess.ReadAllBytes(weights);
return CreateWeightsData(data);

[UnsafeAccessor(UnsafeAccessorKind.Constructor)]
static extern WeightsData CreateWeightsData(byte[] content);
return _dataAccess.ReadAllBytes(weights);
}

public Id GetId(Weights weights)
Expand All @@ -32,9 +27,9 @@ public void AssociateId(Weights weights, Id id)
_dataAccess.AssociateId(weights, id);
}

protected override void SaveWeightsData(Weights weights, WeightsData data)
protected override void SaveWeightsData(Weights weights, byte[] data)
{
_dataAccess.WriteAllBytes(weights, data.Content);
_dataAccess.WriteAllBytes(weights, data);
}

protected override void RemoveWeightsData(Weights weights)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@ public void ShouldCreateWeights()
ClassifierDataSet dataSet = new();
var tag1 = dataSet.Tags.CreateTag("1");
var tag2 = dataSet.Tags.CreateTag("2");
SimpleWeightsDataAccess weightsDataAccess = new();
var weights = weightsDataAccess.CreateWeights(dataSet.Weights, [], ModelSize.Nano, new WeightsMetrics(), [tag1, tag2]);
var weights = dataSet.Weights.CreateWeights(DateTime.UtcNow, ModelSize.Nano, new WeightsMetrics(), [tag1, tag2]);
dataSet.Weights.Should().Contain(weights);
}

[Fact]
public void ShouldNotCreateWeightsWithNoTags()
{
ClassifierDataSet dataSet = new();
SimpleWeightsDataAccess weightsDataAccess = new();
Assert.ThrowsAny<Exception>(() => weightsDataAccess.CreateWeights(dataSet.Weights, [], ModelSize.Nano, new WeightsMetrics(), []));
Assert.ThrowsAny<Exception>(() => dataSet.Weights.CreateWeights(DateTime.UtcNow, ModelSize.Nano, new WeightsMetrics(), []));
dataSet.Weights.Should().BeEmpty();
}

Expand All @@ -30,8 +28,7 @@ public void ShouldNotCreateWeightsWithOneTag()
{
ClassifierDataSet dataSet = new();
var tag = dataSet.Tags.CreateTag("");
SimpleWeightsDataAccess weightsDataAccess = new();
Assert.ThrowsAny<Exception>(() => weightsDataAccess.CreateWeights(dataSet.Weights, [], ModelSize.Nano, new WeightsMetrics(), [tag]));
Assert.ThrowsAny<Exception>(() => dataSet.Weights.CreateWeights(DateTime.UtcNow, ModelSize.Nano, new WeightsMetrics(), [tag]));
dataSet.Weights.Should().BeEmpty();
}

Expand All @@ -41,8 +38,7 @@ public void ShouldNotCreateWeightsWithDuplicateTags()
ClassifierDataSet dataSet = new();
var tag1 = dataSet.Tags.CreateTag("1");
var tag2 = dataSet.Tags.CreateTag("2");
SimpleWeightsDataAccess weightsDataAccess = new();
Assert.ThrowsAny<Exception>(() => weightsDataAccess.CreateWeights(dataSet.Weights, [], ModelSize.Nano, new WeightsMetrics(), [tag1, tag1, tag2]));
Assert.ThrowsAny<Exception>(() => dataSet.Weights.CreateWeights(DateTime.UtcNow, ModelSize.Nano, new WeightsMetrics(), [tag1, tag1, tag2]));
dataSet.Weights.Should().BeEmpty();
}
}
Loading

0 comments on commit 7cc053b

Please sign in to comment.