diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index a4c620a4..31e7a249 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -74,6 +74,8 @@ internal static SyntaxToken Token(SyntaxKind kind) internal static BlockSyntax Block(params StatementSyntax[] statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace); + internal static BlockSyntax Block(IEnumerable statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace); + internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression); internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList incrementors, StatementSyntax statement) @@ -100,10 +102,12 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static DeclarationExpressionSyntax DeclarationExpression(TypeSyntax type, VariableDesignationSyntax designation) => SyntaxFactory.DeclarationExpression(type, designation); - internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier) => SyntaxFactory.VariableDeclarator(identifier); + internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier, EqualsValueClauseSyntax? initializer = null) => SyntaxFactory.VariableDeclarator(identifier, argumentList: null, initializer: initializer); internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space))); + internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type, params VariableDeclaratorSyntax[] variables) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)), SeparatedList(variables)); + internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken)); internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name); @@ -190,7 +194,7 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static InitializerExpressionSyntax InitializerExpression(SyntaxKind kind, SeparatedSyntaxList expressions) => SyntaxFactory.InitializerExpression(kind, OpenBrace, expressions, CloseBrace); - internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(), null); + internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, SeparatedSyntaxList arguments = default) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null); internal static ArrayCreationExpressionSyntax ArrayCreationExpression(ArrayTypeSyntax type, InitializerExpressionSyntax? initializer = null) => SyntaxFactory.ArrayCreationExpression(Token(SyntaxKind.NewKeyword), type, initializer); @@ -295,7 +299,7 @@ internal static SyntaxList SingletonList(TNode node) internal static AttributeArgumentListSyntax AttributeArgumentList(SeparatedSyntaxList arguments = default) => SyntaxFactory.AttributeArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); - internal static AttributeListSyntax AttributeList() => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, SeparatedList(), TokenWithLineFeed(SyntaxKind.CloseBracketToken)); + internal static AttributeListSyntax AttributeList(params SeparatedSyntaxList attributes) => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, attributes, TokenWithLineFeed(SyntaxKind.CloseBracketToken)); internal static SyntaxList List() where TNode : SyntaxNode => SyntaxFactory.List(); @@ -305,7 +309,7 @@ internal static SyntaxList List(IEnumerable nodes) internal static ParameterListSyntax ParameterList() => SyntaxFactory.ParameterList(Token(SyntaxKind.OpenParenToken), SeparatedList(), Token(SyntaxKind.CloseParenToken)); - internal static ArgumentListSyntax ArgumentList(SeparatedSyntaxList arguments = default) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); + internal static ArgumentListSyntax ArgumentList(params SeparatedSyntaxList arguments) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)).WithLeadingTrivia(Space), right); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Features.cs b/src/Microsoft.Windows.CsWin32/Generator.Features.cs index 1e286683..a7591320 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Features.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Features.cs @@ -22,6 +22,7 @@ public partial class Generator private readonly bool unscopedRefAttributePredefined; private readonly bool canUseComVariant; private readonly bool canUseMemberFunctionCallingConvention; + private readonly bool canUseMarshalInitHandle; private readonly INamedTypeSymbol? runtimeFeatureClass; private readonly bool generateSupportedOSPlatformAttributes; private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838) diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index 6987a082..798705e5 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -378,19 +378,54 @@ private IEnumerable DeclareFriendlyOverload( .WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword))); // HANDLE SomeLocal; - leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type).AddVariables( - VariableDeclarator(typeDefHandleName.Identifier)))); + leadingStatements.Add( + LocalDeclarationStatement( + VariableDeclaration( + pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type, + VariableDeclarator(typeDefHandleName.Identifier)))); + + ArgumentSyntax ownsHandleArgument = Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)); + + if (this.canUseMarshalInitHandle) + { + // Some = new SafeHandle(default, ownsHandle: true); + leadingStatements.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + origName, + ObjectCreationExpression(safeHandleType, [Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), ownsHandleArgument])))); + + // Marshal.InitHandle(Some, SomeLocal); + trailingStatements.Add( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nameof(Marshal)), + IdentifierName("InitHandle")), + ArgumentList( + [ + Argument(origName), + Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), + ])))); + } + else + { + // Some = new SafeHandle(SomeLocal, ownsHandle: true); + trailingStatements.Add(ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + origName, + ObjectCreationExpression(safeHandleType).AddArgumentListArguments( + Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), + ownsHandleArgument)))); + } // Argument: &SomeLocal arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, typeDefHandleName)); - - // Some = new SafeHandle(SomeLocal, ownsHandle: true); - trailingStatements.Add(ExpressionStatement(AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - origName, - ObjectCreationExpression(safeHandleType).AddArgumentListArguments( - Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); } } else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod) @@ -1108,7 +1143,46 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) && returnTypeHandleInfo.Generator.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod) ? this.RequestSafeHandle(returnReleaseMethod) : null; - if ((returnSafeHandleType is object || minorSignatureChange) && !signatureChanged) + IdentifierNameSyntax resultLocal = IdentifierName("__result"); + + if (this.canUseMarshalInitHandle && returnSafeHandleType is not null) + { + IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle"); + + // SafeHandle __resultSafeHandle = new SafeHandle(default, ownsHandle: true); + leadingStatements.Add( + LocalDeclarationStatement( + VariableDeclaration( + returnSafeHandleType, + VariableDeclarator( + resultSafeHandleLocal.Identifier, + EqualsValueClause( + ObjectCreationExpression( + returnSafeHandleType, + [ + Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), + Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)) + ])))))); + + // Marshal.InitHandle(__resultSafeHandle, __result); + trailingStatements.Add( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nameof(Marshal)), + IdentifierName("InitHandle")), + ArgumentList( + [ + Argument(resultSafeHandleLocal), + Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), + ])))); + } + + if ((returnSafeHandleType is not null || minorSignatureChange) && !signatureChanged) { // The parameter types are all the same, but we need a friendly overload with a different return type. // Our only choice is to rename the friendly overload. @@ -1145,20 +1219,33 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) }) .WithArgumentList(FixTrivia(ArgumentList().AddArguments(arguments.ToArray()))); bool hasVoidReturn = externMethodReturnType is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.VoidKeyword } }; - BlockSyntax? body = Block().AddStatements(leadingStatements.ToArray()); - IdentifierNameSyntax resultLocal = IdentifierName("__result"); - if (returnSafeHandleType is object) + BlockSyntax? body = Block(leadingStatements); + if (returnSafeHandleType is not null) { - //// HANDLE result = invocation(); + // HANDLE result = invocation(); body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodReturnType) .AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation))))); body = body.AddStatements(trailingStatements.ToArray()); - //// return new SafeHandle(result, ownsHandle: true); - body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( - Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))); + ReturnStatementSyntax returnStatement; + if (this.canUseMarshalInitHandle) + { + // return __resultSafeHandle; + returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle")); + } + else + { + // return new SafeHandle(result, ownsHandle: true); + returnStatement = ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( + Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), + Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)))); + } + + body = body.AddStatements(returnStatement); } else if (hasVoidReturn) { diff --git a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs index ad7d77f4..5c300b8c 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs @@ -123,29 +123,32 @@ public partial class Generator VariableDeclarator(invalidValueFieldName.Identifier).WithInitializer(EqualsValueClause(invalidHandleIntPtr)))) .AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword))); + SyntaxToken visibilityModifier = TokenWithSpace(this.Visibility); + // public SafeHandle() : base(INVALID_HANDLE_VALUE, true) members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) - .AddModifiers(TokenWithSpace(this.Visibility)) + .AddModifiers(visibilityModifier) .WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments( Argument(invalidValueFieldName), Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression))))) .WithBody(Block())); // public SafeHandle(IntPtr preexistingHandle, bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) { this.SetHandle(preexistingHandle); } - const string preexistingHandleName = "preexistingHandle"; - const string ownsHandleName = "ownsHandle"; + IdentifierNameSyntax preexistingHandleName = IdentifierName("preexistingHandle"); + IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle"); members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) - .AddModifiers(TokenWithSpace(this.Visibility)) + .AddModifiers(visibilityModifier) .AddParameterListParameters( - Parameter(Identifier(preexistingHandleName)).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))), - Parameter(Identifier(ownsHandleName)).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) + Parameter(preexistingHandleName.Identifier).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))), + Parameter(ownsHandleName.Identifier) + .WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) .WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression)))) .WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments( Argument(invalidValueFieldName), - Argument(IdentifierName(ownsHandleName))))) + Argument(ownsHandleName)))) .WithBody(Block().AddStatements( ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("SetHandle"))) - .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName(preexistingHandleName))))))))); + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(preexistingHandleName)))))))); // public override bool IsInvalid => this.handle.ToInt64() == 0 || this.handle.ToInt64() == -1; ExpressionSyntax thisHandleToInt64 = InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, thisHandle, IdentifierName(nameof(IntPtr.ToInt64))), ArgumentList()); @@ -290,7 +293,7 @@ public partial class Generator IEnumerable xmlDocParameterTypes = releaseMethodSignature.ParameterTypes.Select(p => p.ToTypeSyntax(this.externSignatureTypeSettings, GeneratingElement.HelperClassMember, default).Type); ClassDeclarationSyntax safeHandleDeclaration = ClassDeclaration(Identifier(safeHandleClassName)) - .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword)) + .AddModifiers(visibilityModifier, TokenWithSpace(SyntaxKind.PartialKeyword)) .WithBaseList(BaseList(SingletonSeparatedList(SimpleBaseType(SafeHandleTypeSyntax)))) .AddMembers(members.ToArray()) .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index f037cded..6ef0cb0a 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -127,6 +127,7 @@ public Generator(string metadataLibraryPath, Docs? docs, IEnumerable add this.canUseIPropertyValue = this.compilation?.GetTypeByMetadataName("Windows.Foundation.IPropertyValue")?.DeclaredAccessibility == Accessibility.Public; this.canUseComVariant = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComVariant") is not null; this.canUseMemberFunctionCallingConvention = this.compilation?.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction") is not null; + this.canUseMarshalInitHandle = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers("InitHandle").Length > 0; if (this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.Versioning.SupportedOSPlatformAttribute") is { } attribute) { this.generateSupportedOSPlatformAttributes = true; diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index cb994e81..9fe899b2 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -619,4 +619,30 @@ public async Task TestComVariantReturnValue() var method = Assert.Single(methods, m => m.Identifier.Text == "GetCachedPropertyValue"); Assert.Contains("ComVariant", method.ReturnType.ToString()); } + + [Theory, CombinatorialData] // https://github.com/microsoft/CsWin32/issues/1430 + public void UseInitHandleApiWhenPossible( + [CombinatorialValues( + "SysAllocString", // Returns owning safe handle + "ShellExecute", // Returns non-owning safe handle + "OpenProcessToken")] // Returns owning safe handle as an out parameter + string api, + bool initHandleApiAvailable) + { + this.compilation = this.starterCompilations[initHandleApiAvailable ? "net8.0" : "net472"]; + this.GenerateApi(api); + + MethodDeclarationSyntax friendlyOverload = Assert.Single( + this.FindGeneratedMethod(api), + m => !m.AttributeLists.Any(al => al.Attributes.Any(a => a.Name.ToString() == "DllImport"))); + + if (initHandleApiAvailable) + { + Assert.Contains(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); + } + else + { + Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); + } + } }