Skip to content

Commit

Permalink
Bind attr
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayPianikov committed Jul 26, 2024
1 parent bd99101 commit bab688a
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 200 deletions.
196 changes: 2 additions & 194 deletions src/Pure.DI.Core/Core/ApiInvocationProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ internal class ApiInvocationProcessor(
: IApiInvocationProcessor
{
private static readonly char[] TypeNamePartsSeparators = ['.'];
private MdTag? lastTag;

public void ProcessInvocation(
IMetadataVisitor metadataVisitor,
Expand Down Expand Up @@ -63,12 +62,10 @@ public void ProcessInvocation(
if (type is INamedTypeSymbol { TypeArguments.Length: 2, TypeArguments: [_, { } resultType] })
{
VisitFactory(metadataVisitor, semanticModel, resultType, lambdaExpression);
FinishBinding(metadataVisitor, invocation, semanticModel, resultType);
}
else
{
VisitFactory(metadataVisitor, semanticModel, type, lambdaExpression);
FinishBinding(metadataVisitor, invocation, semanticModel, type);
}

break;
Expand Down Expand Up @@ -103,7 +100,7 @@ public void ProcessInvocation(
case nameof(IBinding.Tags):
foreach (var tag in BuildTags(semanticModel, invocation.ArgumentList.Arguments))
{
metadataVisitor.VisitTag(RegisterTag(tag));
metadataVisitor.VisitTag(tag);
}

break;
Expand Down Expand Up @@ -192,7 +189,6 @@ public void ProcessInvocation(
case [{ Expression: LambdaExpressionSyntax lambdaExpression }]:
var factoryType = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, implementationTypeSyntax);
VisitFactory(metadataVisitor, semanticModel, factoryType, lambdaExpression);
FinishBinding(metadataVisitor, invocation, semanticModel, factoryType);
break;

case [{ Expression: LiteralExpressionSyntax { Token.Value: string sourceCodeStatement } }]:
Expand All @@ -201,13 +197,11 @@ public void ProcessInvocation(
.WithExpressionBody(SyntaxFactory.IdentifierName(sourceCodeStatement));
factoryType = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, implementationTypeSyntax);
VisitFactory(metadataVisitor, semanticModel, factoryType, lambda, true);
FinishBinding(metadataVisitor, invocation, semanticModel, factoryType);
break;

case []:
var implementationType = semantic.GetTypeSymbol<INamedTypeSymbol>(semanticModel, implementationTypeSyntax);
metadataVisitor.VisitImplementation(new MdImplementation(semanticModel, implementationTypeSyntax, implementationType));
FinishBinding(metadataVisitor, invocation, semanticModel, implementationType);
break;

default:
Expand Down Expand Up @@ -328,184 +322,6 @@ public void ProcessInvocation(
}
}

private void FinishBinding(
IMetadataVisitor metadataVisitor,
SyntaxNode source,
SemanticModel semanticModel,
ITypeSymbol type)
{
try
{
var membersToBind =
from member in type.GetMembers()
where member.DeclaredAccessibility >= Accessibility.Internal && !member.IsStatic && member.CanBeReferencedByName && member is IFieldSymbol or IPropertySymbol or IMethodSymbol
from attribute in member.GetAttributes()
where attribute.AttributeClass?.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat) == Names.BindAttributeName
select (attribute, member);

foreach (var (attribute, member) in membersToBind)
{
var values = arguments.GetArgs(attribute.ConstructorArguments, attribute.NamedArguments, "type", "lifetime", "tags");
ITypeSymbol? contractType = default;
if (values.Length > 0 && values[0].Value is ITypeSymbol newContractType)
{
contractType = newContractType;
}

const string ctxName = "ctx";
const string valueName = "value";
var instance = SyntaxFactory.IdentifierName(valueName);
ExpressionSyntax value = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
instance,
SyntaxFactory.IdentifierName(member.Name));

var position = 0;
var resolvers = new List<MdResolver>();
var block = new List<StatementSyntax>();
switch (member)
{
case IFieldSymbol fieldSymbol:
contractType ??= fieldSymbol.Type;
break;

case IPropertySymbol propertySymbol:
contractType ??= propertySymbol.Type;
break;

case IMethodSymbol methodSymbol:
contractType ??= methodSymbol.ReturnType;

// ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator
foreach (var parameter in methodSymbol.Parameters)
{
block.Add(SyntaxFactory.ExpressionStatement(Inject(parameter.Type, parameter.Name, resolvers, MdTag.ContextTag, ref position)));
}

var args = methodSymbol.Parameters
.Select(i => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(i.Name)))
.ToArray();

value = SyntaxFactory
.InvocationExpression(value)
.AddArgumentListArguments(args);

break;
}

if (contractType is null)
{
continue;
}

var lifetime = Lifetime.Transient;
if (values.Length > 1 && values[1].Value is int newLifetime)
{
lifetime = (Lifetime)newLifetime;
}

List<object?> tags;
if (values.Length > 2)
{
var tagsValue = values[2];
if (tagsValue.Kind == TypedConstantKind.Array)
{
if (!tagsValue.Values.IsDefaultOrEmpty)
{
tags = tagsValue.Values.Select(i => i.Value).ToList();
}
else
{
tags = [];
}
}
else
{
tags = values.Skip(2).Select(i => i.Value).ToList();
}
}
else
{
tags = [];
}

var contextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(ctxName));
block.Add(SyntaxFactory.ExpressionStatement(Inject(type, valueName, resolvers, lastTag?.Value, ref position)));
block.Add(SyntaxFactory.ReturnStatement(value));
var lambdaExpression = SyntaxFactory.SimpleLambdaExpression(contextParameter)
.WithBlock(SyntaxFactory.Block(block));

metadataVisitor.VisitContract(
new MdContract(
semanticModel,
source,
contractType,
ContractKind.Explicit,
ImmutableArray<MdTag>.Empty));

if (lifetime != Lifetime.Transient)
{
metadataVisitor.VisitLifetime(new MdLifetime(semanticModel, source, lifetime));
}

var tagPosition = 0;
foreach (var tag in tags)
{
metadataVisitor.VisitTag(new MdTag(tagPosition++, tag));
}

if (tags.Count == 0)
{
metadataVisitor.VisitTag(new MdTag(tagPosition, default));
}

metadataVisitor.VisitFactory(
new MdFactory(
semanticModel,
source,
contractType,
lambdaExpression,
contextParameter,
resolvers.ToImmutableArray(),
true));

continue;

InvocationExpressionSyntax Inject(ITypeSymbol injectedType, string injectedName, ICollection<MdResolver> resolversSet, object? tag, ref int curPosition)
{
resolversSet.Add(new MdResolver
{
SemanticModel = semanticModel,
Source = source,
ContractType = injectedType,
Tag = new MdTag(curPosition++, tag)
});

// MdTag.ContextTag

var valueDeclaration = SyntaxFactory.DeclarationExpression(
SyntaxFactory.ParseTypeName(injectedType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat)).WithTrailingTrivia(SyntaxFactory.Space),
SyntaxFactory.SingleVariableDesignation(SyntaxFactory.Identifier(injectedName)));

var valueArg =
SyntaxFactory.Argument(valueDeclaration)
.WithRefOrOutKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword));

return SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(ctxName),
SyntaxFactory.IdentifierName(nameof(IContext.Inject))))
.AddArgumentListArguments(valueArg);
}
}
}
finally
{
lastTag = default;
}
}

private bool TryGetAttributeType(
GenericNameSyntax genericName,
SemanticModel semanticModel,
Expand Down Expand Up @@ -561,7 +377,6 @@ private void VisitBind(
InvocationExpressionSyntax invocation,
GenericNameSyntax genericName)
{
lastTag = default;
var tags = BuildTags(semanticModel, invocation.ArgumentList.Arguments);
VisitBind(metadataVisitor, semanticModel, invocation, tags, genericName);
}
Expand Down Expand Up @@ -603,7 +418,6 @@ private void VisitArg(
var argType = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, argTypeSyntax);
metadataVisitor.VisitContract(new MdContract(semanticModel, invocation, argType, ContractKind.Explicit, tags.ToImmutableArray()));
metadataVisitor.VisitArg(new MdArg(semanticModel, argTypeSyntax, argType, name, kind, argComments));
FinishBinding(metadataVisitor, invocation, semanticModel, argType);
}
}

Expand Down Expand Up @@ -743,18 +557,12 @@ private IReadOnlyList<T> BuildConstantArgs<T>(
$"{a.Expression} must be a non-null value of type {typeof(T)}.", a.Expression.GetLocation(), LogId.ErrorInvalidMetadata))
.ToList();

private MdTag RegisterTag(in MdTag tag)
{
lastTag = tag;
return tag;
}

private ImmutableArray<MdTag> BuildTags(
SemanticModel semanticModel,
IEnumerable<ArgumentSyntax> args) =>
args
.SelectMany(t => semantic.GetConstantValues<object>(semanticModel, t.Expression))
.Select((tag, i) => RegisterTag(new MdTag(i, tag)))
.Select((tag, i) => new MdTag(i, tag))
.ToImmutableArray();

private static CompositionName CreateCompositionName(
Expand Down
Loading

0 comments on commit bab688a

Please sign in to comment.