From f05633cd21a18bebe391a3b00ff47eb0126d1460 Mon Sep 17 00:00:00 2001 From: Alexey Zimarev Date: Tue, 11 Jun 2024 21:59:17 +0200 Subject: [PATCH] Use code generator for cloning responses --- gen/SourceGenerator/Extensions.cs | 44 ++++++++ gen/SourceGenerator/ImmutableGenerator.cs | 11 +- .../InheritedCloneGenerator.cs | 105 ++++++++++++++++++ .../Properties/launchSettings.json | 9 ++ gen/SourceGenerator/SourceGenerator.csproj | 14 ++- .../Extensions/GenerateImmutableAttribute.cs | 9 +- .../Extensions/HttpResponseExtensions.cs | 7 +- src/RestSharp/Response/RestResponse.cs | 30 +---- src/RestSharp/Response/RestResponseBase.cs | 7 +- src/RestSharp/RestClient.Async.cs | 2 +- 10 files changed, 189 insertions(+), 49 deletions(-) create mode 100644 gen/SourceGenerator/Extensions.cs create mode 100644 gen/SourceGenerator/InheritedCloneGenerator.cs create mode 100644 gen/SourceGenerator/Properties/launchSettings.json diff --git a/gen/SourceGenerator/Extensions.cs b/gen/SourceGenerator/Extensions.cs new file mode 100644 index 000000000..12daff3e6 --- /dev/null +++ b/gen/SourceGenerator/Extensions.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation and Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace SourceGenerator; + +static class Extensions { + public static IEnumerable FindClasses(this Compilation compilation, Func predicate) + => compilation.SyntaxTrees + .Select(tree => compilation.GetSemanticModel(tree)) + .SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType()) + .Where(predicate); + + public static IEnumerable FindAnnotatedClass(this Compilation compilation, string attributeName, bool strict) { + return compilation.FindClasses( + syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(CheckAttribute)) + ); + + bool CheckAttribute(AttributeSyntax attr) { + var name = attr.Name.ToString(); + return strict ? name == attributeName : name.StartsWith(attributeName); + } + } + + public static IEnumerable GetBaseTypesAndThis(this ITypeSymbol type) { + var current = type; + + while (current != null) { + yield return current; + + current = current.BaseType; + } + } +} \ No newline at end of file diff --git a/gen/SourceGenerator/ImmutableGenerator.cs b/gen/SourceGenerator/ImmutableGenerator.cs index 24adff24d..3a5832536 100644 --- a/gen/SourceGenerator/ImmutableGenerator.cs +++ b/gen/SourceGenerator/ImmutableGenerator.cs @@ -13,12 +13,6 @@ // limitations under the License. // -using System.Text; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; - namespace SourceGenerator; [Generator] @@ -28,10 +22,7 @@ public void Initialize(GeneratorInitializationContext context) { } public void Execute(GeneratorExecutionContext context) { var compilation = context.Compilation; - var mutableClasses = compilation.SyntaxTrees - .Select(tree => compilation.GetSemanticModel(tree)) - .SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType()) - .Where(syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(attr => attr.Name.ToString() == "GenerateImmutable"))); + var mutableClasses = compilation.FindAnnotatedClass("GenerateImmutable", strict: true); foreach (var mutableClass in mutableClasses) { var immutableClass = GenerateImmutableClass(mutableClass, compilation); diff --git a/gen/SourceGenerator/InheritedCloneGenerator.cs b/gen/SourceGenerator/InheritedCloneGenerator.cs new file mode 100644 index 000000000..2bc635443 --- /dev/null +++ b/gen/SourceGenerator/InheritedCloneGenerator.cs @@ -0,0 +1,105 @@ +// Copyright (c) .NET Foundation and Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace SourceGenerator; + +[Generator] +public class InheritedCloneGenerator : ISourceGenerator { + const string AttributeName = "GenerateClone"; + + public void Initialize(GeneratorInitializationContext context) { } + + public void Execute(GeneratorExecutionContext context) { + var compilation = context.Compilation; + + var candidates = compilation.FindAnnotatedClass(AttributeName, false); + + foreach (var candidate in candidates) { + var semanticModel = compilation.GetSemanticModel(candidate.SyntaxTree); + var genericClassSymbol = semanticModel.GetDeclaredSymbol(candidate); + if (genericClassSymbol == null) continue; + + // Get the method name from the attribute Name argument + var attributeData = genericClassSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass?.Name == $"{AttributeName}Attribute"); + var methodName = (string)attributeData.NamedArguments.FirstOrDefault(arg => arg.Key == "Name").Value.Value; + + // Get the generic argument type where properties need to be copied from + var attributeSyntax = candidate.AttributeLists + .SelectMany(l => l.Attributes) + .FirstOrDefault(a => a.Name.ToString().StartsWith(AttributeName)); + if (attributeSyntax == null) continue; // This should never happen + + var typeArgumentSyntax = ((GenericNameSyntax)attributeSyntax.Name).TypeArgumentList.Arguments[0]; + var typeSymbol = (INamedTypeSymbol)semanticModel.GetSymbolInfo(typeArgumentSyntax).Symbol; + + var code = GenerateMethod(candidate, genericClassSymbol, typeSymbol, methodName); + context.AddSource($"{genericClassSymbol.Name}.Clone.g.cs", SourceText.From(code, Encoding.UTF8)); + } + } + + static string GenerateMethod( + TypeDeclarationSyntax classToExtendSyntax, + INamedTypeSymbol classToExtendSymbol, + INamedTypeSymbol classToClone, + string methodName + ) { + var namespaceName = classToExtendSymbol.ContainingNamespace.ToDisplayString(); + var className = classToExtendSyntax.Identifier.Text; + var genericTypeParameters = string.Join(", ", classToExtendSymbol.TypeParameters.Select(tp => tp.Name)); + var classDeclaration = classToExtendSymbol.TypeParameters.Length > 0 ? $"{className}<{genericTypeParameters}>" : className; + + var all = classToClone.GetBaseTypesAndThis(); + var props = all.SelectMany(x => x.GetMembers().OfType()).ToArray(); + var usings = classToExtendSyntax.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString()); + + var constructorParams = classToExtendSymbol.Constructors.First().Parameters.ToArray(); + var constructorArgs = string.Join(", ", constructorParams.Select(p => $"original.{GetPropertyName(p.Name, props)}")); + var constructorParamNames = constructorParams.Select(p => p.Name).ToArray(); + + var properties = props + // ReSharper disable once PossibleUnintendedLinearSearchInSet + .Where(prop => !constructorParamNames.Contains(prop.Name, StringComparer.OrdinalIgnoreCase) && prop.SetMethod != null) + .Select(prop => $" {prop.Name} = original.{prop.Name},") + .ToArray(); + + const string template = """ + {Usings} + + namespace {Namespace}; + + public partial class {ClassDeclaration} { + public static {ClassDeclaration} {MethodName}({OriginalClassName} original) + => new {ClassDeclaration}({ConstructorArgs}) { + {Properties} + }; + } + """; + + var code = template + .Replace("{Usings}", string.Join("\n", usings)) + .Replace("{Namespace}", namespaceName) + .Replace("{ClassDeclaration}", classDeclaration) + .Replace("{OriginalClassName}", classToClone.Name) + .Replace("{MethodName}", methodName) + .Replace("{ConstructorArgs}", constructorArgs) + .Replace("{Properties}", string.Join("\n", properties).TrimEnd(',')); + + return code; + + static string GetPropertyName(string parameterName, IPropertySymbol[] properties) { + var property = properties.FirstOrDefault(p => string.Equals(p.Name, parameterName, StringComparison.OrdinalIgnoreCase)); + return property?.Name ?? parameterName; + } + } +} \ No newline at end of file diff --git a/gen/SourceGenerator/Properties/launchSettings.json b/gen/SourceGenerator/Properties/launchSettings.json new file mode 100644 index 000000000..cd06df480 --- /dev/null +++ b/gen/SourceGenerator/Properties/launchSettings.json @@ -0,0 +1,9 @@ +{ + "$schema": "http://json.schemastore.org/launchsettings.json", + "profiles": { + "Generators": { + "commandName": "DebugRoslynComponent", + "targetProject": "../../src/RestSharp/RestSharp.csproj" + } + } +} \ No newline at end of file diff --git a/gen/SourceGenerator/SourceGenerator.csproj b/gen/SourceGenerator/SourceGenerator.csproj index b4aa74be5..2d26b1392 100644 --- a/gen/SourceGenerator/SourceGenerator.csproj +++ b/gen/SourceGenerator/SourceGenerator.csproj @@ -9,10 +9,18 @@ false - - + + - + + + + + + + + + diff --git a/src/RestSharp/Extensions/GenerateImmutableAttribute.cs b/src/RestSharp/Extensions/GenerateImmutableAttribute.cs index c4fe51817..172d7c24e 100644 --- a/src/RestSharp/Extensions/GenerateImmutableAttribute.cs +++ b/src/RestSharp/Extensions/GenerateImmutableAttribute.cs @@ -16,7 +16,12 @@ namespace RestSharp.Extensions; [AttributeUsage(AttributeTargets.Class)] -class GenerateImmutableAttribute : Attribute { } +class GenerateImmutableAttribute : Attribute; + +[AttributeUsage(AttributeTargets.Class)] +class GenerateCloneAttribute : Attribute where T : class { + public string? Name { get; set; } +}; [AttributeUsage(AttributeTargets.Property)] -class Exclude : Attribute { } +class Exclude : Attribute; diff --git a/src/RestSharp/Extensions/HttpResponseExtensions.cs b/src/RestSharp/Extensions/HttpResponseExtensions.cs index d501f5c5a..44497c91b 100644 --- a/src/RestSharp/Extensions/HttpResponseExtensions.cs +++ b/src/RestSharp/Extensions/HttpResponseExtensions.cs @@ -27,13 +27,12 @@ static class HttpResponseExtensions { : new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}"); #endif - public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) { + public static async Task GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) { var encodingString = response.Content.Headers.ContentType?.CharSet; var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding; using var reader = new StreamReader(new MemoryStream(bytes), encoding); - return reader.ReadToEnd(); - + return await reader.ReadToEndAsync(); Encoding TryGetEncoding(string es) { try { return Encoding.GetEncoding(es); @@ -69,4 +68,4 @@ Encoding TryGetEncoding(string es) { return original == null ? null : streamWriter(original); } } -} +} \ No newline at end of file diff --git a/src/RestSharp/Response/RestResponse.cs b/src/RestSharp/Response/RestResponse.cs index 7aba25740..e6d664e79 100644 --- a/src/RestSharp/Response/RestResponse.cs +++ b/src/RestSharp/Response/RestResponse.cs @@ -13,7 +13,6 @@ // limitations under the License. using System.Diagnostics; -using System.Net; using System.Text; using RestSharp.Extensions; @@ -25,34 +24,13 @@ namespace RestSharp; /// Container for data sent back from API including deserialized data /// /// Type of data to deserialize to -[DebuggerDisplay("{" + nameof(DebuggerDisplay) + "()}")] -public class RestResponse(RestRequest request) : RestResponse(request) { +[GenerateClone(Name = "FromResponse")] +[DebuggerDisplay($"{{{nameof(DebuggerDisplay)}()}}")] +public partial class RestResponse(RestRequest request) : RestResponse(request) { /// /// Deserialized entity data /// public T? Data { get; set; } - - public static RestResponse FromResponse(RestResponse response) - => new(response.Request) { - Content = response.Content, - ContentEncoding = response.ContentEncoding, - ContentHeaders = response.ContentHeaders, - ContentLength = response.ContentLength, - ContentType = response.ContentType, - Cookies = response.Cookies, - ErrorException = response.ErrorException, - ErrorMessage = response.ErrorMessage, - Headers = response.Headers, - IsSuccessStatusCode = response.IsSuccessStatusCode, - RawBytes = response.RawBytes, - ResponseStatus = response.ResponseStatus, - ResponseUri = response.ResponseUri, - RootElement = response.RootElement, - Server = response.Server, - StatusCode = response.StatusCode, - StatusDescription = response.StatusDescription, - Version = response.Version - }; } /// @@ -78,7 +56,7 @@ async Task GetDefaultResponse() { #endif var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false); - var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding); + var content = bytes == null ? null : await httpResponse.GetResponseString(bytes, encoding); return new RestResponse(request) { Content = content, diff --git a/src/RestSharp/Response/RestResponseBase.cs b/src/RestSharp/Response/RestResponseBase.cs index 0bbff2d55..96697c33f 100644 --- a/src/RestSharp/Response/RestResponseBase.cs +++ b/src/RestSharp/Response/RestResponseBase.cs @@ -13,7 +13,7 @@ // limitations under the License. using System.Diagnostics; -using System.Net; +// ReSharper disable PropertyCanBeMadeInitOnly.Global namespace RestSharp; @@ -65,12 +65,13 @@ protected RestResponseBase(RestRequest request) { public HttpStatusCode StatusCode { get; set; } /// - /// Whether or not the HTTP response status code indicates success + /// Whether the HTTP response status code indicates success /// public bool IsSuccessStatusCode { get; set; } /// - /// Whether or not the HTTP response status code indicates success and no other error occurred (deserialization, timeout, ...) + /// Whether the HTTP response status code indicates success and no other error occurred + /// (deserialization, timeout, ...) /// public bool IsSuccessful => IsSuccessStatusCode && ResponseStatus == ResponseStatus.Completed; diff --git a/src/RestSharp/RestClient.Async.cs b/src/RestSharp/RestClient.Async.cs index ce6706216..a7cd76ca1 100644 --- a/src/RestSharp/RestClient.Async.cs +++ b/src/RestSharp/RestClient.Async.cs @@ -43,7 +43,7 @@ public async Task ExecuteAsync(RestRequest request, CancellationTo /// [PublicAPI] public async Task DownloadStreamAsync(RestRequest request, CancellationToken cancellationToken = default) { - // Make sure we only read the headers so we can stream the content body efficiently + // Make sure we only read the headers, so we can stream the content body efficiently request.CompletionOption = HttpCompletionOption.ResponseHeadersRead; var response = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);