Skip to content

Commit

Permalink
Support for generic attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayPianikov committed Jun 5, 2024
1 parent f895f30 commit 7f4a437
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 19 deletions.
39 changes: 35 additions & 4 deletions readme/custom-attributes.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,45 @@ class MyTagAttribute(object tag) : Attribute;
| AttributeTargets.Field)]
class MyTypeAttribute(Type type) : Attribute;

[AttributeUsage(
AttributeTargets.Parameter
| AttributeTargets.Property
| AttributeTargets.Field)]
class MyGenericTypeAttribute<T> : Attribute;

interface IPerson;

class Person([MyTag("NikName")] string name) : IPerson
{
private object? _state;

[MyOrdinal(1)]
[MyType(typeof(int))]
internal object Id = "";

public override string ToString() => $"{Id} {name}";
[MyOrdinal(2)]
public void Initialize([MyGenericType<Uri>] object state) =>
_state = state;

public override string ToString() => $"{Id} {name} {_state}";
}

DI.Setup(nameof(PersonComposition))
.TagAttribute<MyTagAttribute>()
.OrdinalAttribute<MyOrdinalAttribute>()
.TypeAttribute<MyTypeAttribute>()
.TypeAttribute<MyGenericTypeAttribute<TT>>()
.Arg<int>("personId")
.Bind<string>("NikName").To(_ => "Nik")
.Bind<IPerson>().To<Person>()
.Bind().To(_ => new Uri("https://github.com/DevTeam/Pure.DI"))
.Bind("NikName").To(_ => "Nik")
.Bind().To<Person>()

// Composition root
.Root<IPerson>("Person");

var composition = new PersonComposition(personId: 123);
var person = composition.Person;
person.ToString().ShouldBe("123 Nik");
person.ToString().ShouldBe("123 Nik https://github.com/DevTeam/Pure.DI");
```

The following partial class will be generated:
Expand Down Expand Up @@ -82,9 +96,11 @@ partial class PersonComposition
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get
{
Uri transientUri2 = new Uri("https://github.com/DevTeam/Pure.DI");
string transientString1 = "Nik";
Person transientPerson0 = new Person(transientString1);
transientPerson0.Id = _argPersonId;
transientPerson0.Initialize(transientUri2);
return transientPerson0;
}
}
Expand All @@ -100,17 +116,32 @@ classDiagram
+IPerson Person
}
class Int32
Uri --|> ISpanFormattable
Uri --|> IFormattable
Uri --|> ISerializable
class Uri
class String
Person --|> IPerson
class Person {
+Person(String name)
~Object Id
+Initialize(Object state) : Void
}
class ISpanFormattable {
<<interface>>
}
class IFormattable {
<<interface>>
}
class ISerializable {
<<interface>>
}
class IPerson {
<<interface>>
}
PersonComposition ..> Person : IPerson Person
Person *-- String : "NikName" String
Person o-- Int32 : Argument "personId"
Person *-- Uri : Uri
```

12 changes: 9 additions & 3 deletions src/Pure.DI.Core/Core/ApiInvocationProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,23 +222,29 @@ public void ProcessInvocation(
case nameof(IConfiguration.TypeAttribute):
if (genericName.TypeArgumentList.Arguments is [{ } typeAttributeType])
{
metadataVisitor.VisitTypeAttribute(new MdTypeAttribute(semanticModel, invocation.ArgumentList, semanticModel.GetTypeSymbol<ITypeSymbol>(typeAttributeType, cancellationToken), BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
var type = semanticModel.GetTypeSymbol<INamedTypeSymbol>(typeAttributeType, cancellationToken);
if (type.IsGenericType)
{
type = type.ConstructUnboundGenericType();
}

metadataVisitor.VisitTypeAttribute(new MdTypeAttribute(semanticModel, invocation.ArgumentList, type, BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
}

break;

case nameof(IConfiguration.TagAttribute):
if (genericName.TypeArgumentList.Arguments is [{ } tagAttributeType])
{
metadataVisitor.VisitTagAttribute(new MdTagAttribute(semanticModel, invocation.ArgumentList, semanticModel.GetTypeSymbol<ITypeSymbol>(tagAttributeType, cancellationToken), BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
metadataVisitor.VisitTagAttribute(new MdTagAttribute(semanticModel, invocation.ArgumentList, semanticModel.GetTypeSymbol<INamedTypeSymbol>(tagAttributeType, cancellationToken), BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
}

break;

case nameof(IConfiguration.OrdinalAttribute):
if (genericName.TypeArgumentList.Arguments is [{ } ordinalAttributeType])
{
metadataVisitor.VisitOrdinalAttribute(new MdOrdinalAttribute(semanticModel, invocation.ArgumentList, semanticModel.GetTypeSymbol<ITypeSymbol>(ordinalAttributeType, cancellationToken), BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
metadataVisitor.VisitOrdinalAttribute(new MdOrdinalAttribute(semanticModel, invocation.ArgumentList, semanticModel.GetTypeSymbol<INamedTypeSymbol>(ordinalAttributeType, cancellationToken), BuildConstantArgs<object>(semanticModel, invocation.ArgumentList.Arguments) is [int positionVal] ? positionVal : 0));
}

break;
Expand Down
10 changes: 8 additions & 2 deletions src/Pure.DI.Core/Core/CompilationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@ namespace Pure.DI.Core;
internal static class CompilationExtensions
{
[SuppressMessage("ReSharper", "HeapView.ClosureAllocation")]
public static IReadOnlyList<AttributeData> GetAttributes(this ISymbol member, ITypeSymbol attributeType) =>
public static IReadOnlyList<AttributeData> GetAttributes(this ISymbol member, INamedTypeSymbol attributeType) =>
member
.GetAttributes()
.Where(attr => attr.AttributeClass != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType))
.Where(attr => attr.AttributeClass != null && SymbolEqualityComparer.Default.Equals(GetUnboundTypeSymbol(attr.AttributeClass), attributeType))
.ToArray();

public static LanguageVersion GetLanguageVersion(this Compilation compilation) =>
compilation is CSharpCompilation sharpCompilation
? sharpCompilation.LanguageVersion
: LanguageVersion.Default;

private static INamedTypeSymbol? GetUnboundTypeSymbol(INamedTypeSymbol? typeSymbol) =>
typeSymbol is null
? typeSymbol : typeSymbol.IsGenericType
? typeSymbol.ConstructUnboundGenericType()
: typeSymbol;
}
16 changes: 14 additions & 2 deletions src/Pure.DI.Core/Core/ImplementationDependencyNodeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// ReSharper disable ClassNeverInstantiated.Global
namespace Pure.DI.Core;

using ITypeSymbol = Microsoft.CodeAnalysis.ITypeSymbol;

internal sealed class ImplementationDependencyNodeBuilder(
ILogger<ImplementationDependencyNodeBuilder> logger,
IBuilder<DpImplementation, IEnumerable<DpImplementation>> implementationVariantsBuilder)
Expand Down Expand Up @@ -224,8 +226,18 @@ private T GetAttribute<TMdAttribute, T>(
switch (attributeData.Count)
{
case 1:
var args = attributeData[0].ConstructorArguments;
if (attribute.ArgumentPosition > args.Length)
var attr = attributeData[0];
if (typeof(ITypeSymbol).IsAssignableFrom(typeof(T)) && attr.AttributeClass is { IsGenericType: true, TypeArguments.Length: > 0 } attributeClass)
{
if (attribute.ArgumentPosition < attributeClass.TypeArguments.Length
&& attributeClass.TypeArguments[attribute.ArgumentPosition] is { } typeSymbol)
{
return (T)typeSymbol;
}
}

var args = attr.ConstructorArguments;
if (attribute.ArgumentPosition >= args.Length)
{
logger.CompileError($"The argument position {attribute.ArgumentPosition.ToString()} of attribute {attribute.Source} is out of range [0..{args.Length.ToString()}].", attribute.Source.GetLocation(), LogId.ErrorInvalidMetadata);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/Models/IMdAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ internal interface IMdAttribute

SyntaxNode Source { get; }

ITypeSymbol AttributeType { get; }
INamedTypeSymbol AttributeType { get; }

int ArgumentPosition { get; }
}
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/Models/MdOrdinalAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Pure.DI.Core.Models;
internal readonly record struct MdOrdinalAttribute(
SemanticModel SemanticModel,
SyntaxNode Source,
ITypeSymbol AttributeType,
INamedTypeSymbol AttributeType,
int ArgumentPosition) : IMdAttribute
{
public override string ToString() => $".OrdinalAttribute<{AttributeType}>({(ArgumentPosition != 0 ? ArgumentPosition.ToString() : string.Empty)})";
Expand Down
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/Models/MdTagAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Pure.DI.Core.Models;
internal readonly record struct MdTagAttribute(
SemanticModel SemanticModel,
SyntaxNode Source,
ITypeSymbol AttributeType,
INamedTypeSymbol AttributeType,
int ArgumentPosition) : IMdAttribute
{
public override string ToString() => $".TagAttribute<{AttributeType}>({(ArgumentPosition != 0 ? ArgumentPosition.ToString() : string.Empty)})";
Expand Down
2 changes: 1 addition & 1 deletion src/Pure.DI.Core/Core/Models/MdTypeAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Pure.DI.Core.Models;
internal readonly record struct MdTypeAttribute(
SemanticModel SemanticModel,
SyntaxNode Source,
ITypeSymbol AttributeType,
INamedTypeSymbol AttributeType,
int ArgumentPosition) : IMdAttribute
{
public override string ToString() => $".TypeAttribute<{AttributeType}>({(ArgumentPosition != 0 ? ArgumentPosition.ToString() : string.Empty)})";
Expand Down
1 change: 1 addition & 0 deletions src/Pure.DI.Core/Pure.DI.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<DesignTime>True</DesignTime>
<DependentUpon>GenericTypeArguments.g.tt</DependentUpon>
</Compile>
<Compile Remove="Components\Api.CSharp11.g.cs" />
</ItemGroup>

</Project>
152 changes: 152 additions & 0 deletions tests/Pure.DI.IntegrationTests/TypeAttributeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
namespace Pure.DI.IntegrationTests;

public class TypeAttributeTests
{
[Fact]
public async Task ShouldSupportTypeAttribute()
{
// Given

// When
var result = await """
using System;
using Pure.DI;
namespace Sample
{
interface IDependency { }
class AbcDependency : IDependency { }
class XyzDependency : IDependency { }
interface IService
{
IDependency Dependency1 { get; }
IDependency Dependency2 { get; }
}
class Service : IService
{
public Service(
[Type(typeof(AbcDependency))] IDependency dependency1,
[Type(typeof(XyzDependency))] IDependency dependency2)
{
Dependency1 = dependency1;
Dependency2 = dependency2;
}
public IDependency Dependency1 { get; }
public IDependency Dependency2 { get; }
}
partial class Composition
{
private static void SetupComposition()
{
DI.Setup(nameof(Composition))
.Bind<IService>().To<Service>()
// Composition root
.Root<IService>("Root");
}
}
public class Program
{
public static void Main()
{
var composition = new Composition();
var service = composition.Root;
Console.WriteLine(service.Dependency1);
Console.WriteLine(service.Dependency2);
}
}
}
""".RunAsync();

// Then
result.Success.ShouldBeTrue(result);
result.StdOut.ShouldBe(["Sample.AbcDependency", "Sample.XyzDependency"], result);
}

#if ROSLYN4_8_OR_GREATER
[Fact]
public async Task ShouldSupportGenericTypeAttribute()
{
// Given

// When
var result = await """
using System;
using Pure.DI;
namespace Sample
{
[AttributeUsage(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = false)]
internal class TypeAttribute<T> : Attribute
{
}
interface IDependency { }
class AbcDependency : IDependency { }
class XyzDependency : IDependency { }
interface IService
{
IDependency Dependency1 { get; }
IDependency Dependency2 { get; }
}
class Service : IService
{
public Service(
[Type<AbcDependency>] IDependency dependency1,
[Type<XyzDependency>] IDependency dependency2)
{
Dependency1 = dependency1;
Dependency2 = dependency2;
}
public IDependency Dependency1 { get; }
public IDependency Dependency2 { get; }
}
partial class Composition
{
private static void SetupComposition()
{
DI.Setup(nameof(Composition))
.TypeAttribute<TypeAttribute<TT>>()
.Bind<IService>().To<Service>()
// Composition root
.Root<IService>("Root");
}
}
public class Program
{
public static void Main()
{
var composition = new Composition();
var service = composition.Root;
Console.WriteLine(service.Dependency1);
Console.WriteLine(service.Dependency2);
}
}
}
""".RunAsync(new Options(LanguageVersion.CSharp11));

// Then
result.Success.ShouldBeTrue(result);
result.StdOut.ShouldBe(["Sample.AbcDependency", "Sample.XyzDependency"], result);
}
#endif
}
Loading

0 comments on commit 7f4a437

Please sign in to comment.