Skip to content

Commit

Permalink
Use code generator for cloning responses (#2223)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyzimarev authored Jun 17, 2024
1 parent 6b0e036 commit 57f6b3a
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 49 deletions.
44 changes: 44 additions & 0 deletions gen/SourceGenerator/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace SourceGenerator;

static class Extensions {
public static IEnumerable<ClassDeclarationSyntax> FindClasses(this Compilation compilation, Func<ClassDeclarationSyntax, bool> predicate)
=> compilation.SyntaxTrees
.Select(tree => compilation.GetSemanticModel(tree))
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
.Where(predicate);

public static IEnumerable<ClassDeclarationSyntax> FindAnnotatedClass(this Compilation compilation, string attributeName, bool strict) {
return compilation.FindClasses(
syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(CheckAttribute))
);

bool CheckAttribute(AttributeSyntax attr) {
var name = attr.Name.ToString();
return strict ? name == attributeName : name.StartsWith(attributeName);
}
}

public static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(this ITypeSymbol type) {
var current = type;

while (current != null) {
yield return current;

current = current.BaseType;
}
}
}
11 changes: 1 addition & 10 deletions gen/SourceGenerator/ImmutableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
// limitations under the License.
//

using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace SourceGenerator;

[Generator]
Expand All @@ -28,10 +22,7 @@ public void Initialize(GeneratorInitializationContext context) { }
public void Execute(GeneratorExecutionContext context) {
var compilation = context.Compilation;

var mutableClasses = compilation.SyntaxTrees
.Select(tree => compilation.GetSemanticModel(tree))
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
.Where(syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(attr => attr.Name.ToString() == "GenerateImmutable")));
var mutableClasses = compilation.FindAnnotatedClass("GenerateImmutable", strict: true);

foreach (var mutableClass in mutableClasses) {
var immutableClass = GenerateImmutableClass(mutableClass, compilation);
Expand Down
105 changes: 105 additions & 0 deletions gen/SourceGenerator/InheritedCloneGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace SourceGenerator;

[Generator]
public class InheritedCloneGenerator : ISourceGenerator {
const string AttributeName = "GenerateClone";

public void Initialize(GeneratorInitializationContext context) { }

public void Execute(GeneratorExecutionContext context) {
var compilation = context.Compilation;

var candidates = compilation.FindAnnotatedClass(AttributeName, false);

foreach (var candidate in candidates) {
var semanticModel = compilation.GetSemanticModel(candidate.SyntaxTree);
var genericClassSymbol = semanticModel.GetDeclaredSymbol(candidate);
if (genericClassSymbol == null) continue;

// Get the method name from the attribute Name argument
var attributeData = genericClassSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass?.Name == $"{AttributeName}Attribute");
var methodName = (string)attributeData.NamedArguments.FirstOrDefault(arg => arg.Key == "Name").Value.Value;

// Get the generic argument type where properties need to be copied from
var attributeSyntax = candidate.AttributeLists
.SelectMany(l => l.Attributes)
.FirstOrDefault(a => a.Name.ToString().StartsWith(AttributeName));
if (attributeSyntax == null) continue; // This should never happen

var typeArgumentSyntax = ((GenericNameSyntax)attributeSyntax.Name).TypeArgumentList.Arguments[0];
var typeSymbol = (INamedTypeSymbol)semanticModel.GetSymbolInfo(typeArgumentSyntax).Symbol;

var code = GenerateMethod(candidate, genericClassSymbol, typeSymbol, methodName);
context.AddSource($"{genericClassSymbol.Name}.Clone.g.cs", SourceText.From(code, Encoding.UTF8));
}
}

static string GenerateMethod(
TypeDeclarationSyntax classToExtendSyntax,
INamedTypeSymbol classToExtendSymbol,
INamedTypeSymbol classToClone,
string methodName
) {
var namespaceName = classToExtendSymbol.ContainingNamespace.ToDisplayString();
var className = classToExtendSyntax.Identifier.Text;
var genericTypeParameters = string.Join(", ", classToExtendSymbol.TypeParameters.Select(tp => tp.Name));
var classDeclaration = classToExtendSymbol.TypeParameters.Length > 0 ? $"{className}<{genericTypeParameters}>" : className;

var all = classToClone.GetBaseTypesAndThis();
var props = all.SelectMany(x => x.GetMembers().OfType<IPropertySymbol>()).ToArray();
var usings = classToExtendSyntax.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString());

var constructorParams = classToExtendSymbol.Constructors.First().Parameters.ToArray();
var constructorArgs = string.Join(", ", constructorParams.Select(p => $"original.{GetPropertyName(p.Name, props)}"));
var constructorParamNames = constructorParams.Select(p => p.Name).ToArray();

var properties = props
// ReSharper disable once PossibleUnintendedLinearSearchInSet
.Where(prop => !constructorParamNames.Contains(prop.Name, StringComparer.OrdinalIgnoreCase) && prop.SetMethod != null)
.Select(prop => $" {prop.Name} = original.{prop.Name},")
.ToArray();

const string template = """
{Usings}
namespace {Namespace};
public partial class {ClassDeclaration} {
public static {ClassDeclaration} {MethodName}({OriginalClassName} original)
=> new {ClassDeclaration}({ConstructorArgs}) {
{Properties}
};
}
""";

var code = template
.Replace("{Usings}", string.Join("\n", usings))
.Replace("{Namespace}", namespaceName)
.Replace("{ClassDeclaration}", classDeclaration)
.Replace("{OriginalClassName}", classToClone.Name)
.Replace("{MethodName}", methodName)
.Replace("{ConstructorArgs}", constructorArgs)
.Replace("{Properties}", string.Join("\n", properties).TrimEnd(','));

return code;

static string GetPropertyName(string parameterName, IPropertySymbol[] properties) {
var property = properties.FirstOrDefault(p => string.Equals(p.Name, parameterName, StringComparison.OrdinalIgnoreCase));
return property?.Name ?? parameterName;
}
}
}
9 changes: 9 additions & 0 deletions gen/SourceGenerator/Properties/launchSettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"$schema": "http://json.schemastore.org/launchsettings.json",
"profiles": {
"Generators": {
"commandName": "DebugRoslynComponent",
"targetProject": "../../src/RestSharp/RestSharp.csproj"
}
}
}
14 changes: 11 additions & 3 deletions gen/SourceGenerator/SourceGenerator.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All"/>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All"/>
</ItemGroup>
<ItemGroup>
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false"/>
</ItemGroup>
<ItemGroup>

<Using Include="System.Text"/>
<Using Include="Microsoft.CodeAnalysis"/>
<Using Include="Microsoft.CodeAnalysis.CSharp"/>
<Using Include="Microsoft.CodeAnalysis.CSharp.Syntax"/>
<Using Include="Microsoft.CodeAnalysis.Text"/>
</ItemGroup>
</Project>
9 changes: 7 additions & 2 deletions src/RestSharp/Extensions/GenerateImmutableAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
namespace RestSharp.Extensions;

[AttributeUsage(AttributeTargets.Class)]
class GenerateImmutableAttribute : Attribute { }
class GenerateImmutableAttribute : Attribute;

[AttributeUsage(AttributeTargets.Class)]
class GenerateCloneAttribute<T> : Attribute where T : class {
public string? Name { get; set; }
};

[AttributeUsage(AttributeTargets.Property)]
class Exclude : Attribute { }
class Exclude : Attribute;
7 changes: 3 additions & 4 deletions src/RestSharp/Extensions/HttpResponseExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ static class HttpResponseExtensions {
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}");
#endif

public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
public static async Task<string> GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
var encodingString = response.Content.Headers.ContentType?.CharSet;
var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding;

using var reader = new StreamReader(new MemoryStream(bytes), encoding);
return reader.ReadToEnd();

return await reader.ReadToEndAsync();
Encoding TryGetEncoding(string es) {
try {
return Encoding.GetEncoding(es);
Expand Down Expand Up @@ -69,4 +68,4 @@ Encoding TryGetEncoding(string es) {
return original == null ? null : streamWriter(original);
}
}
}
}
30 changes: 4 additions & 26 deletions src/RestSharp/Response/RestResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

using System.Diagnostics;
using System.Net;
using System.Text;
using RestSharp.Extensions;

Expand All @@ -25,34 +24,13 @@ namespace RestSharp;
/// Container for data sent back from API including deserialized data
/// </summary>
/// <typeparam name="T">Type of data to deserialize to</typeparam>
[DebuggerDisplay("{" + nameof(DebuggerDisplay) + "()}")]
public class RestResponse<T>(RestRequest request) : RestResponse(request) {
[GenerateClone<RestResponse>(Name = "FromResponse")]
[DebuggerDisplay($"{{{nameof(DebuggerDisplay)}()}}")]
public partial class RestResponse<T>(RestRequest request) : RestResponse(request) {
/// <summary>
/// Deserialized entity data
/// </summary>
public T? Data { get; set; }

public static RestResponse<T> FromResponse(RestResponse response)
=> new(response.Request) {
Content = response.Content,
ContentEncoding = response.ContentEncoding,
ContentHeaders = response.ContentHeaders,
ContentLength = response.ContentLength,
ContentType = response.ContentType,
Cookies = response.Cookies,
ErrorException = response.ErrorException,
ErrorMessage = response.ErrorMessage,
Headers = response.Headers,
IsSuccessStatusCode = response.IsSuccessStatusCode,
RawBytes = response.RawBytes,
ResponseStatus = response.ResponseStatus,
ResponseUri = response.ResponseUri,
RootElement = response.RootElement,
Server = response.Server,
StatusCode = response.StatusCode,
StatusDescription = response.StatusDescription,
Version = response.Version
};
}

/// <summary>
Expand All @@ -78,7 +56,7 @@ async Task<RestResponse> GetDefaultResponse() {
#endif

var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding);
var content = bytes == null ? null : await httpResponse.GetResponseString(bytes, encoding);

return new RestResponse(request) {
Content = content,
Expand Down
7 changes: 4 additions & 3 deletions src/RestSharp/Response/RestResponseBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

using System.Diagnostics;
using System.Net;
// ReSharper disable PropertyCanBeMadeInitOnly.Global

namespace RestSharp;

Expand Down Expand Up @@ -65,12 +65,13 @@ protected RestResponseBase(RestRequest request) {
public HttpStatusCode StatusCode { get; set; }

/// <summary>
/// Whether or not the HTTP response status code indicates success
/// Whether the HTTP response status code indicates success
/// </summary>
public bool IsSuccessStatusCode { get; set; }

/// <summary>
/// Whether or not the HTTP response status code indicates success and no other error occurred (deserialization, timeout, ...)
/// Whether the HTTP response status code indicates success and no other error occurred
/// (deserialization, timeout, ...)
/// </summary>
public bool IsSuccessful => IsSuccessStatusCode && ResponseStatus == ResponseStatus.Completed;

Expand Down
2 changes: 1 addition & 1 deletion src/RestSharp/RestClient.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task<RestResponse> ExecuteAsync(RestRequest request, CancellationTo
/// <inheritdoc />
[PublicAPI]
public async Task<Stream?> DownloadStreamAsync(RestRequest request, CancellationToken cancellationToken = default) {
// Make sure we only read the headers so we can stream the content body efficiently
// Make sure we only read the headers, so we can stream the content body efficiently
request.CompletionOption = HttpCompletionOption.ResponseHeadersRead;
var response = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);

Expand Down

0 comments on commit 57f6b3a

Please sign in to comment.