Skip to content

Commit

Permalink
Use C# 13 overload resolution attribute to improve friendly overloads…
Browse files Browse the repository at this point in the history
… experience
  • Loading branch information
AArnott committed Jan 22, 2025
1 parent 1beb6ff commit 519a407
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 2 deletions.
36 changes: 36 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Features.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,56 @@ public partial class Generator
private readonly bool canUseUnsafeSkipInit;
private readonly bool canUseUnmanagedCallersOnlyAttribute;
private readonly bool canUseSetLastPInvokeError;
private readonly bool overloadResolutionPriorityAttributePredefined;
private readonly bool unscopedRefAttributePredefined;
private readonly INamedTypeSymbol? runtimeFeatureClass;
private readonly bool generateSupportedOSPlatformAttributes;
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)
private readonly bool generateDefaultDllImportSearchPathsAttribute;
private readonly Dictionary<Feature, bool> supportedFeatures = new();

private void DeclareOverloadResolutionPriorityAttributeIfNecessary()
{
// This attribute may only be applied for C# 13 and later, or else C# errors out.
if (this.LanguageVersion < (LanguageVersion)1300)
{
throw new GenerationFailedException("The OverloadResolutionPriorityAttribute requires C# 13 or later.");
}

if (this.overloadResolutionPriorityAttributePredefined)
{
return;
}

// Always generate these in the context of the most common metadata so we don't emit it more than once.
if (!this.IsWin32Sdk)
{
this.MainGenerator.volatileCode.GenerationTransaction(() => this.MainGenerator.DeclareOverloadResolutionPriorityAttributeIfNecessary());
return;
}

const string name = "OverloadResolutionPriorityAttribute";
this.volatileCode.GenerateSpecialType(name, delegate
{
// This is a polyfill attribute, so never promote visibility to public.
if (!TryFetchTemplate(name, this, out CompilationUnitSyntax? compilationUnit))
{
throw new GenerationFailedException($"Failed to retrieve template: {name}");
}

MemberDeclarationSyntax templateNamespace = compilationUnit.Members.Single();
this.volatileCode.AddSpecialType(name, templateNamespace, topLevel: true);
});
}

private void DeclareUnscopedRefAttributeIfNecessary()
{
if (this.unscopedRefAttributePredefined)
{
return;
}

// Always generate these in the context of the most common metadata so we don't emit it more than once.
if (!this.IsWin32Sdk)
{
this.MainGenerator.volatileCode.GenerationTransaction(() => this.MainGenerator.DeclareUnscopedRefAttributeIfNecessary());
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,13 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
friendlyDeclaration = friendlyDeclaration.AddAttributeLists(AttributeList().AddAttributes(supportedOSPlatformAttribute));
}

// If we're using C# 13 or later, consider adding the overload resolution attribute if it would likely resolve ambiguities.
if (this.LanguageVersion >= (LanguageVersion)1300 && parameters.Count == externMethodDeclaration.ParameterList.Parameters.Count)
{
this.DeclareOverloadResolutionPriorityAttributeIfNecessary();
friendlyDeclaration = friendlyDeclaration.AddAttributeLists(AttributeList().AddAttributes(OverloadResolutionPriorityAttribute(1)));
}

friendlyDeclaration = friendlyDeclaration
.WithLeadingTrivia(leadingTrivia);

Expand Down
21 changes: 21 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Templates.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ private static bool TryFetchTemplate(string name, Generator? generator, [NotNull
}

member = generator?.ElevateVisibility(member) ?? member;

return true;
}

private static bool TryFetchTemplate(string name, Generator? generator, [NotNullWhen(true)] out CompilationUnitSyntax? compilationUnit)
{
string? template = FetchTemplateText(name);
if (template == null)
{
compilationUnit = null;
return false;
}

compilationUnit = SyntaxFactory.ParseCompilationUnit(template, options: generator?.parseOptions) ?? throw new GenerationFailedException($"Unable to parse compilation unit from a template: {name}");

// Strip out #if/#else/#endif trivia, which was already evaluated with the parse options we passed in.
if (generator?.parseOptions is not null)
{
compilationUnit = (CompilationUnitSyntax)compilationUnit.Accept(DirectiveTriviaRemover.Instance)!;
}

return true;
}

Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
this.canUseUnmanagedCallersOnlyAttribute = this.FindTypeSymbolsIfAlreadyAvailable("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute").Count > 0;
this.canUseSetLastPInvokeError = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshal")?.GetMembers("GetLastSystemError").IsEmpty is false;
this.unscopedRefAttributePredefined = this.FindTypeSymbolIfAlreadyAvailable("System.Diagnostics.CodeAnalysis.UnscopedRefAttribute") is not null;
this.overloadResolutionPriorityAttributePredefined = this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.CompilerServices.OverloadResolutionPriorityAttribute") is not null;
this.runtimeFeatureClass = (INamedTypeSymbol?)this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.CompilerServices.RuntimeFeature");
this.comIIDInterfacePredefined = this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{IComIIDGuidInterfaceName}") is not null;
this.getDelegateForFunctionPointerGenericExists = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers(nameof(Marshal.GetDelegateForFunctionPointer)).Any(m => m is IMethodSymbol { IsGenericMethod: true }) is true;
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ internal static class SimpleSyntaxFactory
internal static readonly IdentifierNameSyntax ComIIDGuidPropertyName = IdentifierName("Guid");
internal static readonly AttributeSyntax FieldOffsetAttributeSyntax = Attribute(IdentifierName("FieldOffset"));

internal static AttributeSyntax OverloadResolutionPriorityAttribute(int priority) => Attribute(ParseName("OverloadResolutionPriority")).AddArgumentListArguments(AttributeArgument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(priority))));

[return: NotNullIfNotNull("marshalAs")]
internal static AttributeSyntax? MarshalAs(MarshalAsAttribute? marshalAs, Generator.NativeArrayInfo? nativeArrayInfo)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace System.Runtime.CompilerServices
{
/// <summary>
/// Specifies the priority of a member in overload resolution.
/// When unspecified, the default priority is 0.
/// </summary>
[global::System.AttributeUsage(global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property, AllowMultiple = false, Inherited = false)]
internal sealed class OverloadResolutionPriorityAttribute : global::System.Attribute
{
/// <summary>
/// Initializes a new instance of the <see cref="OverloadResolutionPriorityAttribute"/> class.
/// </summary>
/// <param name="priority">The priority of the attributed member. Higher numbers are prioritized, lower numbers are deprioritized. 0 is the default if no attribute is present.</param>
public OverloadResolutionPriorityAttribute(int priority)
{
this.Priority = priority;
}

/// <summary>
/// The priority of the member.
/// </summary>
public int Priority { get; }
}
}
35 changes: 35 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,41 @@ public void ReferencesToStructWithFlexibleArrayAreAlwaysPointers()
Assert.Empty(this.FindGeneratedType("BITMAPINFO_unmanaged"));
}

[Theory, PairwiseData]
public void OverloadResolutionAttributeUsage(
bool useMatchingLanguageVersion,
[CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
{
// Set up the test under the appropriate TFM and either a matching language version or C# 13,
// which is the first version that supports the OverloadResolutionPriorityAttribute.
this.compilation = this.starterCompilations[tfm];
this.parseOptions = this.parseOptions.WithLanguageVersion(
useMatchingLanguageVersion ? (GetLanguageVersionForTfm(tfm) ?? LanguageVersion.Latest) : LanguageVersion.CSharp13);
this.generator = this.CreateGenerator();

this.GenerateApi("EnumDisplayMonitors");

// Emit usage that would be ambiguous without the OverloadResolutionPriorityAttribute.
this.compilation = this.AddCode("""
using Windows.Win32;
using Windows.Win32.Foundation;
using Windows.Win32.Graphics.Gdi;
class Foo
{
static void Use()
{
PInvoke.EnumDisplayMonitors(default, null, default, default);
}
}
""");

Func<Diagnostic, bool>? isAcceptable = this.parseOptions.LanguageVersion >= LanguageVersion.CSharp13
? null // C# 13 and later should not produce any diagnostics.
: diag => diag.Descriptor.Id == "CS0121";
this.AssertNoDiagnostics(this.compilation, logAllGeneratedCode: false, acceptable: isAcceptable);
}

private static AttributeSyntax? FindDllImportAttribute(SyntaxList<AttributeListSyntax> attributeLists) => attributeLists.SelectMany(al => al.Attributes).FirstOrDefault(a => a.Name.ToString() == "DllImport");

private IEnumerable<MethodDeclarationSyntax> GenerateMethod(string methodName)
Expand Down
20 changes: 18 additions & 2 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTestBase.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Diagnostics.CodeAnalysis;

public abstract class GeneratorTestBase : IDisposable, IAsyncLifetime
{
protected const string DefaultTFM = "netstandard2.0";
Expand Down Expand Up @@ -215,6 +217,20 @@ protected CSharpCompilation AddGeneratedCode(CSharpCompilation compilation, IGen
return compilation.AddSyntaxTrees(syntaxTrees);
}

/// <summary>
/// Adds a code file to a compilation.
/// </summary>
/// <param name="code">The syntax file to add.</param>
/// <param name="fileName">The name of the code file to add.</param>
/// <param name="compilation">The compilation to add to. When omitted, <see cref="GeneratorTestBase.compilation"/> is assumed.</param>
/// <returns>The modified compilation.</returns>
protected CSharpCompilation AddCode([StringSyntax("c#-test")] string code, string? fileName = null, CSharpCompilation? compilation = null)
{
compilation ??= this.compilation;
SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code, this.parseOptions, fileName ?? $"AdditionalCode{compilation.SyntaxTrees.Length + 1}.cs");
return compilation.AddSyntaxTrees(syntaxTree);
}

protected void CollectGeneratedCode(IGenerator generator) => this.compilation = this.AddGeneratedCode(this.compilation, generator);

protected IEnumerable<MethodDeclarationSyntax> FindGeneratedMethod(string name, Compilation? compilation = null) => (compilation ?? this.compilation).SyntaxTrees.SelectMany(st => st.GetRoot().DescendantNodes().OfType<MethodDeclarationSyntax>()).Where(md => md.Identifier.ValueText == name);
Expand Down Expand Up @@ -243,7 +259,7 @@ protected CSharpCompilation AddGeneratedCode(CSharpCompilation compilation, IGen

protected void AssertNoDiagnostics(bool logAllGeneratedCode = true) => this.AssertNoDiagnostics(this.compilation, logAllGeneratedCode);

protected void AssertNoDiagnostics(CSharpCompilation compilation, bool logAllGeneratedCode = true)
protected void AssertNoDiagnostics(CSharpCompilation compilation, bool logAllGeneratedCode = true, Func<Diagnostic, bool>? acceptable = null)
{
var diagnostics = FilterDiagnostics(compilation.GetDiagnostics());
this.logger.WriteLine($"{diagnostics.Length} diagnostics reported.");
Expand Down Expand Up @@ -274,7 +290,7 @@ protected void AssertNoDiagnostics(CSharpCompilation compilation, bool logAllGen
}
}

Assert.Empty(diagnostics);
Assert.Empty(acceptable is null ? diagnostics : diagnostics.Where(d => !acceptable(d)));
if (emitSuccessful.HasValue)
{
Assert.Empty(emitDiagnostics);
Expand Down

0 comments on commit 519a407

Please sign in to comment.