From fc924b8e437006a44e7f150abeb716602120145b Mon Sep 17 00:00:00 2001 From: Arthur Matsur Date: Tue, 13 Jan 2026 11:34:47 +0300 Subject: [PATCH 1/5] Use `Marshal.InitHandle` API to avoid memory leak when OOM happens --- .../FastSyntaxFactory.cs | 6 +- .../Generator.Features.cs | 1 + .../Generator.FriendlyOverloads.cs | 109 +++++++++++++++--- src/Microsoft.Windows.CsWin32/Generator.cs | 1 + .../CsWin32GeneratorTests.cs | 25 ++++ 5 files changed, 122 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index a4c620a4..94b066e3 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -74,6 +74,8 @@ internal static SyntaxToken Token(SyntaxKind kind) internal static BlockSyntax Block(params StatementSyntax[] statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace); + internal static BlockSyntax Block(IEnumerable statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace); + internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression); internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList incrementors, StatementSyntax statement) @@ -100,10 +102,12 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static DeclarationExpressionSyntax DeclarationExpression(TypeSyntax type, VariableDesignationSyntax designation) => SyntaxFactory.DeclarationExpression(type, designation); - internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier) => SyntaxFactory.VariableDeclarator(identifier); + internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier, EqualsValueClauseSyntax? initializer = null) => SyntaxFactory.VariableDeclarator(identifier, argumentList: null, initializer: initializer); internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space))); + internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type, params VariableDeclaratorSyntax[] variables) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)), SeparatedList(variables)); + internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken)); internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Features.cs b/src/Microsoft.Windows.CsWin32/Generator.Features.cs index 1e286683..a7591320 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Features.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Features.cs @@ -22,6 +22,7 @@ public partial class Generator private readonly bool unscopedRefAttributePredefined; private readonly bool canUseComVariant; private readonly bool canUseMemberFunctionCallingConvention; + private readonly bool canUseMarshalInitHandle; private readonly INamedTypeSymbol? runtimeFeatureClass; private readonly bool generateSupportedOSPlatformAttributes; private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838) diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index 6987a082..a8da0d3b 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -378,19 +378,49 @@ private IEnumerable DeclareFriendlyOverload( .WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword))); // HANDLE SomeLocal; - leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type).AddVariables( - VariableDeclarator(typeDefHandleName.Identifier)))); + leadingStatements.Add( + LocalDeclarationStatement( + VariableDeclaration( + pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type, + VariableDeclarator(typeDefHandleName.Identifier)))); + + if (this.canUseMarshalInitHandle && !doNotRelease) + { + // Some = new SafeHandle(); + leadingStatements.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + origName, + ObjectCreationExpression(safeHandleType)))); + + // Marshal.InitHandle(Some, SomeLocal); + trailingStatements.Add( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nameof(Marshal)), + IdentifierName("InitHandle")), + ArgumentList( + [ + Argument(origName), + Argument(typeDefHandleName), + ])))); + } + else + { + // Some = new SafeHandle(SomeLocal, ownsHandle: true); + trailingStatements.Add(ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + origName, + ObjectCreationExpression(safeHandleType).AddArgumentListArguments( + Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), + Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); + } // Argument: &SomeLocal arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, typeDefHandleName)); - - // Some = new SafeHandle(SomeLocal, ownsHandle: true); - trailingStatements.Add(ExpressionStatement(AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - origName, - ObjectCreationExpression(safeHandleType).AddArgumentListArguments( - Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); } } else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod) @@ -1108,7 +1138,38 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) && returnTypeHandleInfo.Generator.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod) ? this.RequestSafeHandle(returnReleaseMethod) : null; - if ((returnSafeHandleType is object || minorSignatureChange) && !signatureChanged) + IdentifierNameSyntax resultLocal = IdentifierName("__result"); + + if (this.canUseMarshalInitHandle && returnSafeHandleType is not null) + { + IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle"); + + // SafeHandle __resultSafeHandle = new SafeHandle(); + leadingStatements.Add( + LocalDeclarationStatement( + VariableDeclaration( + returnSafeHandleType, + VariableDeclarator( + resultSafeHandleLocal.Identifier, + EqualsValueClause( + ObjectCreationExpression(returnSafeHandleType)))))); + + // Marshal.InitHandle(__resultSafeHandle, __result); + trailingStatements.Add( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nameof(Marshal)), + IdentifierName("InitHandle")), + ArgumentList( + [ + Argument(resultSafeHandleLocal), + Argument(resultLocal), + ])))); + } + + if ((returnSafeHandleType is not null || minorSignatureChange) && !signatureChanged) { // The parameter types are all the same, but we need a friendly overload with a different return type. // Our only choice is to rename the friendly overload. @@ -1145,20 +1206,30 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) }) .WithArgumentList(FixTrivia(ArgumentList().AddArguments(arguments.ToArray()))); bool hasVoidReturn = externMethodReturnType is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.VoidKeyword } }; - BlockSyntax? body = Block().AddStatements(leadingStatements.ToArray()); - IdentifierNameSyntax resultLocal = IdentifierName("__result"); - if (returnSafeHandleType is object) + BlockSyntax? body = Block(leadingStatements); + if (returnSafeHandleType is not null) { - //// HANDLE result = invocation(); + // HANDLE result = invocation(); body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodReturnType) .AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation))))); body = body.AddStatements(trailingStatements.ToArray()); - //// return new SafeHandle(result, ownsHandle: true); - body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( - Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))); + ReturnStatementSyntax returnStatement; + if (this.canUseMarshalInitHandle) + { + // return __resultSafeHandle; + returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle")); + } + else + { + // return new SafeHandle(result, ownsHandle: true); + returnStatement = ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( + Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), + Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))); + } + + body = body.AddStatements(returnStatement); } else if (hasVoidReturn) { diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index f037cded..6ef0cb0a 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -127,6 +127,7 @@ public Generator(string metadataLibraryPath, Docs? docs, IEnumerable add this.canUseIPropertyValue = this.compilation?.GetTypeByMetadataName("Windows.Foundation.IPropertyValue")?.DeclaredAccessibility == Accessibility.Public; this.canUseComVariant = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComVariant") is not null; this.canUseMemberFunctionCallingConvention = this.compilation?.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction") is not null; + this.canUseMarshalInitHandle = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers("InitHandle").Length > 0; if (this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.Versioning.SupportedOSPlatformAttribute") is { } attribute) { this.generateSupportedOSPlatformAttributes = true; diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index cb994e81..60a758ef 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -619,4 +619,29 @@ public async Task TestComVariantReturnValue() var method = Assert.Single(methods, m => m.Identifier.Text == "GetCachedPropertyValue"); Assert.Contains("ComVariant", method.ReturnType.ToString()); } + + [Theory, CombinatorialData] // https://github.com/microsoft/CsWin32/issues/1430 + public void UseInitHandleApiWhenPossible( + [CombinatorialValues( + "SysAllocString", // Returns safe handle directly + "OpenProcessToken")] // Returns safe handle as an out parameter + string api, + bool initHandleApiAvailable) + { + this.compilation = this.starterCompilations[initHandleApiAvailable ? "net8.0" : "net472"]; + this.GenerateApi(api); + + MethodDeclarationSyntax friendlyOverload = Assert.Single( + this.FindGeneratedMethod(api), + m => !m.AttributeLists.Any(al => al.Attributes.Any(a => a.Name.ToString() == "DllImport"))); + + if (initHandleApiAvailable) + { + Assert.Contains(friendlyOverload.DescendantNodes(), n => n.ToString() == "InitHandle"); + } + else + { + Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n.ToString() == "InitHandle"); + } + } } From b08e97a763b3f26196e2e86daeec6ec3e5b1a0ca Mon Sep 17 00:00:00 2001 From: Arthur Matsur Date: Wed, 14 Jan 2026 10:00:33 +0300 Subject: [PATCH 2/5] Use pattern matching --- test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index 60a758ef..45ea7f96 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -637,11 +637,11 @@ public void UseInitHandleApiWhenPossible( if (initHandleApiAvailable) { - Assert.Contains(friendlyOverload.DescendantNodes(), n => n.ToString() == "InitHandle"); + Assert.Contains(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); } else { - Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n.ToString() == "InitHandle"); + Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); } } } From 10bf83b937cc04a452ce0dc3fa921bedaa3e1bf9 Mon Sep 17 00:00:00 2001 From: Arthur Matsur Date: Wed, 14 Jan 2026 10:28:41 +0300 Subject: [PATCH 3/5] Fix obtaining `IntPtr` from handles baked by non-`IntPtr` values --- src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index a8da0d3b..8caef292 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -405,7 +405,7 @@ private IEnumerable DeclareFriendlyOverload( ArgumentList( [ Argument(origName), - Argument(typeDefHandleName), + Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), ])))); } else @@ -1165,7 +1165,7 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) ArgumentList( [ Argument(resultSafeHandleLocal), - Argument(resultLocal), + Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), ])))); } From d743bb8b910fb00df66f3ca49573b479a4f1e19b Mon Sep 17 00:00:00 2001 From: Arthur Matsur Date: Wed, 14 Jan 2026 12:25:16 +0300 Subject: [PATCH 4/5] Use `InitHandle` for non-owning handles when possible --- .../FastSyntaxFactory.cs | 4 +- .../Generator.FriendlyOverloads.cs | 42 ++++++++++--- .../Generator.Handle.cs | 59 +++++++++++++++---- .../CsWin32GeneratorTests.cs | 20 ++++++- 4 files changed, 102 insertions(+), 23 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 94b066e3..6d0846ba 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -194,7 +194,7 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static InitializerExpressionSyntax InitializerExpression(SyntaxKind kind, SeparatedSyntaxList expressions) => SyntaxFactory.InitializerExpression(kind, OpenBrace, expressions, CloseBrace); - internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(), null); + internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, params SeparatedSyntaxList arguments) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null); internal static ArrayCreationExpressionSyntax ArrayCreationExpression(ArrayTypeSyntax type, InitializerExpressionSyntax? initializer = null) => SyntaxFactory.ArrayCreationExpression(Token(SyntaxKind.NewKeyword), type, initializer); @@ -299,7 +299,7 @@ internal static SyntaxList SingletonList(TNode node) internal static AttributeArgumentListSyntax AttributeArgumentList(SeparatedSyntaxList arguments = default) => SyntaxFactory.AttributeArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); - internal static AttributeListSyntax AttributeList() => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, SeparatedList(), TokenWithLineFeed(SyntaxKind.CloseBracketToken)); + internal static AttributeListSyntax AttributeList(params SeparatedSyntaxList attributes) => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, attributes, TokenWithLineFeed(SyntaxKind.CloseBracketToken)); internal static SyntaxList List() where TNode : SyntaxNode => SyntaxFactory.List(); diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index 8caef292..8861275d 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -384,15 +384,22 @@ private IEnumerable DeclareFriendlyOverload( pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type, VariableDeclarator(typeDefHandleName.Identifier)))); - if (this.canUseMarshalInitHandle && !doNotRelease) + ArgumentSyntax ownsHandleArgument = Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)); + + var isBclSafeParameterHandle = BclInteropSafeHandles.ContainsKey(outReleaseMethod); + + if (this.canUseMarshalInitHandle && (!doNotRelease || !isBclSafeParameterHandle)) { - // Some = new SafeHandle(); + // Some = new SafeHandle(ownsHandle: true); leadingStatements.Add( ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, origName, - ObjectCreationExpression(safeHandleType)))); + ObjectCreationExpression(safeHandleType, !isBclSafeParameterHandle ? [ownsHandleArgument] : [])))); // Marshal.InitHandle(Some, SomeLocal); trailingStatements.Add( @@ -416,7 +423,7 @@ private IEnumerable DeclareFriendlyOverload( origName, ObjectCreationExpression(safeHandleType).AddArgumentListArguments( Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); + ownsHandleArgument)))); } // Argument: &SomeLocal @@ -1140,11 +1147,23 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) IdentifierNameSyntax resultLocal = IdentifierName("__result"); - if (this.canUseMarshalInitHandle && returnSafeHandleType is not null) + var isBclSafeReturnHandle = returnSafeHandleType is not null && BclInteropSafeHandles.ContainsValue(returnSafeHandleType); + + if (this.canUseMarshalInitHandle && + returnSafeHandleType is not null && + (!doNotRelease || !isBclSafeReturnHandle)) { IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle"); - // SafeHandle __resultSafeHandle = new SafeHandle(); + SeparatedSyntaxList constructorParameters = !isBclSafeReturnHandle ? + [ + Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)) + ] : []; + + // SafeHandle __resultSafeHandle = new SafeHandle(ownsHandle: true); leadingStatements.Add( LocalDeclarationStatement( VariableDeclaration( @@ -1152,7 +1171,9 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) VariableDeclarator( resultSafeHandleLocal.Identifier, EqualsValueClause( - ObjectCreationExpression(returnSafeHandleType)))))); + ObjectCreationExpression( + returnSafeHandleType, + constructorParameters)))))); // Marshal.InitHandle(__resultSafeHandle, __result); trailingStatements.Add( @@ -1216,7 +1237,7 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) body = body.AddStatements(trailingStatements.ToArray()); ReturnStatementSyntax returnStatement; - if (this.canUseMarshalInitHandle) + if (this.canUseMarshalInitHandle && (!doNotRelease || !isBclSafeReturnHandle)) { // return __resultSafeHandle; returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle")); @@ -1226,7 +1247,10 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) // return new SafeHandle(result, ownsHandle: true); returnStatement = ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)), - Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))); + Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)))); } body = body.AddStatements(returnStatement); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs index ad7d77f4..a646f368 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.ComponentModel; + namespace Microsoft.Windows.CsWin32; public partial class Generator @@ -123,29 +125,66 @@ public partial class Generator VariableDeclarator(invalidValueFieldName.Identifier).WithInitializer(EqualsValueClause(invalidHandleIntPtr)))) .AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword))); + SyntaxToken visibilityModifier = TokenWithSpace(this.Visibility); + + IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle"); + ParameterSyntax ownsHandleParameter = Parameter(ownsHandleName.Identifier) + .WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) + .WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))); + + // [EditorBrowsable(EditorBrowsableState.Advanced)] + // public SafeHandle(bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) + members.Add( + ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) + .AddModifiers(visibilityModifier) + .AddAttributeLists( + AttributeList( + [ + Attribute(ParseName($"global::{typeof(EditorBrowsableAttribute).FullName}")) + .AddArgumentListArguments( + AttributeArgument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseName($"global::{typeof(EditorBrowsableState).FullName}"), + IdentifierName(nameof(EditorBrowsableState.Advanced))))) + ])) + .AddParameterListParameters(ownsHandleParameter) + .WithInitializer( + ConstructorInitializer( + SyntaxKind.BaseConstructorInitializer, + ArgumentList( + [ + Argument(invalidValueFieldName), + Argument(ownsHandleName), + ]))) + .WithBody(Block()) + .WithLeadingTrivia(ParseLeadingTrivia($""" +/// +/// This constructor is intended to be used when the handle is initialized with {(this.canUseMarshalInitHandle ? "" : "Marshal.InitHandle")} API after creation +/// {'\n'} +"""))); + // public SafeHandle() : base(INVALID_HANDLE_VALUE, true) members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) - .AddModifiers(TokenWithSpace(this.Visibility)) + .AddModifiers(visibilityModifier) .WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments( Argument(invalidValueFieldName), Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression))))) .WithBody(Block())); // public SafeHandle(IntPtr preexistingHandle, bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) { this.SetHandle(preexistingHandle); } - const string preexistingHandleName = "preexistingHandle"; - const string ownsHandleName = "ownsHandle"; + IdentifierNameSyntax preexistingHandleName = IdentifierName("preexistingHandle"); members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) - .AddModifiers(TokenWithSpace(this.Visibility)) + .AddModifiers(visibilityModifier) .AddParameterListParameters( - Parameter(Identifier(preexistingHandleName)).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))), - Parameter(Identifier(ownsHandleName)).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) - .WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression)))) + Parameter(preexistingHandleName.Identifier).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))), + ownsHandleParameter) .WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments( Argument(invalidValueFieldName), - Argument(IdentifierName(ownsHandleName))))) + Argument(ownsHandleName)))) .WithBody(Block().AddStatements( ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("SetHandle"))) - .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName(preexistingHandleName))))))))); + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(preexistingHandleName)))))))); // public override bool IsInvalid => this.handle.ToInt64() == 0 || this.handle.ToInt64() == -1; ExpressionSyntax thisHandleToInt64 = InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, thisHandle, IdentifierName(nameof(IntPtr.ToInt64))), ArgumentList()); @@ -290,7 +329,7 @@ public partial class Generator IEnumerable xmlDocParameterTypes = releaseMethodSignature.ParameterTypes.Select(p => p.ToTypeSyntax(this.externSignatureTypeSettings, GeneratingElement.HelperClassMember, default).Type); ClassDeclarationSyntax safeHandleDeclaration = ClassDeclaration(Identifier(safeHandleClassName)) - .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword)) + .AddModifiers(visibilityModifier, TokenWithSpace(SyntaxKind.PartialKeyword)) .WithBaseList(BaseList(SingletonSeparatedList(SimpleBaseType(SafeHandleTypeSyntax)))) .AddMembers(members.ToArray()) .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)) diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index 45ea7f96..0241413f 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -623,8 +623,9 @@ public async Task TestComVariantReturnValue() [Theory, CombinatorialData] // https://github.com/microsoft/CsWin32/issues/1430 public void UseInitHandleApiWhenPossible( [CombinatorialValues( - "SysAllocString", // Returns safe handle directly - "OpenProcessToken")] // Returns safe handle as an out parameter + "SysAllocString", // Returns owning custom safe handle + "ShellExecute", // Returns non-owning custom safe handle + "OpenProcessToken")] // Returns owning BCL safe handle as an out parameter string api, bool initHandleApiAvailable) { @@ -644,4 +645,19 @@ public void UseInitHandleApiWhenPossible( Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); } } + + [Fact] + public void UnableToUseInitHandleDueToBclSafeHandleLimitations() + { + // Although Marshal.InitHandle API is available we cannot use it + // since the returned handle must be non-owning and BCL handle + // which is used as a safe handle for HANDLE doesn't expose functionality + // to create non-owning handle and then initialize it + this.compilation = this.starterCompilations["net8.0"]; + this.GenerateApi("GetProcessHeap"); + + MethodDeclarationSyntax friendlyOverload = Assert.Single(this.FindGeneratedMethod("GetProcessHeap_SafeHandle")); + + Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); + } } From c220ab822423a380ead67587a67bb3f2e10955d1 Mon Sep 17 00:00:00 2001 From: Arthur Matsur Date: Thu, 15 Jan 2026 10:37:55 +0300 Subject: [PATCH 5/5] Simplify --- .../FastSyntaxFactory.cs | 4 +- .../Generator.FriendlyOverloads.cs | 34 ++++++-------- .../Generator.Handle.cs | 44 ++----------------- .../CsWin32GeneratorTests.cs | 21 ++------- 4 files changed, 22 insertions(+), 81 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 6d0846ba..31e7a249 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -194,7 +194,7 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static InitializerExpressionSyntax InitializerExpression(SyntaxKind kind, SeparatedSyntaxList expressions) => SyntaxFactory.InitializerExpression(kind, OpenBrace, expressions, CloseBrace); - internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, params SeparatedSyntaxList arguments) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null); + internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, SeparatedSyntaxList arguments = default) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null); internal static ArrayCreationExpressionSyntax ArrayCreationExpression(ArrayTypeSyntax type, InitializerExpressionSyntax? initializer = null) => SyntaxFactory.ArrayCreationExpression(Token(SyntaxKind.NewKeyword), type, initializer); @@ -309,7 +309,7 @@ internal static SyntaxList List(IEnumerable nodes) internal static ParameterListSyntax ParameterList() => SyntaxFactory.ParameterList(Token(SyntaxKind.OpenParenToken), SeparatedList(), Token(SyntaxKind.CloseParenToken)); - internal static ArgumentListSyntax ArgumentList(SeparatedSyntaxList arguments = default) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); + internal static ArgumentListSyntax ArgumentList(params SeparatedSyntaxList arguments) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)).WithLeadingTrivia(Space), right); diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index 8861275d..798705e5 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -389,17 +389,15 @@ private IEnumerable DeclareFriendlyOverload( refKindKeyword: default, LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)); - var isBclSafeParameterHandle = BclInteropSafeHandles.ContainsKey(outReleaseMethod); - - if (this.canUseMarshalInitHandle && (!doNotRelease || !isBclSafeParameterHandle)) + if (this.canUseMarshalInitHandle) { - // Some = new SafeHandle(ownsHandle: true); + // Some = new SafeHandle(default, ownsHandle: true); leadingStatements.Add( ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, origName, - ObjectCreationExpression(safeHandleType, !isBclSafeParameterHandle ? [ownsHandleArgument] : [])))); + ObjectCreationExpression(safeHandleType, [Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), ownsHandleArgument])))); // Marshal.InitHandle(Some, SomeLocal); trailingStatements.Add( @@ -1147,23 +1145,11 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource) IdentifierNameSyntax resultLocal = IdentifierName("__result"); - var isBclSafeReturnHandle = returnSafeHandleType is not null && BclInteropSafeHandles.ContainsValue(returnSafeHandleType); - - if (this.canUseMarshalInitHandle && - returnSafeHandleType is not null && - (!doNotRelease || !isBclSafeReturnHandle)) + if (this.canUseMarshalInitHandle && returnSafeHandleType is not null) { IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle"); - SeparatedSyntaxList constructorParameters = !isBclSafeReturnHandle ? - [ - Argument( - NameColon(IdentifierName("ownsHandle")), - refKindKeyword: default, - LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)) - ] : []; - - // SafeHandle __resultSafeHandle = new SafeHandle(ownsHandle: true); + // SafeHandle __resultSafeHandle = new SafeHandle(default, ownsHandle: true); leadingStatements.Add( LocalDeclarationStatement( VariableDeclaration( @@ -1173,7 +1159,13 @@ returnSafeHandleType is not null && EqualsValueClause( ObjectCreationExpression( returnSafeHandleType, - constructorParameters)))))); + [ + Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), + Argument( + NameColon(IdentifierName("ownsHandle")), + refKindKeyword: default, + LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)) + ])))))); // Marshal.InitHandle(__resultSafeHandle, __result); trailingStatements.Add( @@ -1237,7 +1229,7 @@ returnSafeHandleType is not null && body = body.AddStatements(trailingStatements.ToArray()); ReturnStatementSyntax returnStatement; - if (this.canUseMarshalInitHandle && (!doNotRelease || !isBclSafeReturnHandle)) + if (this.canUseMarshalInitHandle) { // return __resultSafeHandle; returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle")); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs index a646f368..5c300b8c 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System.ComponentModel; - namespace Microsoft.Windows.CsWin32; public partial class Generator @@ -127,43 +125,6 @@ public partial class Generator SyntaxToken visibilityModifier = TokenWithSpace(this.Visibility); - IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle"); - ParameterSyntax ownsHandleParameter = Parameter(ownsHandleName.Identifier) - .WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) - .WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))); - - // [EditorBrowsable(EditorBrowsableState.Advanced)] - // public SafeHandle(bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) - members.Add( - ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) - .AddModifiers(visibilityModifier) - .AddAttributeLists( - AttributeList( - [ - Attribute(ParseName($"global::{typeof(EditorBrowsableAttribute).FullName}")) - .AddArgumentListArguments( - AttributeArgument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseName($"global::{typeof(EditorBrowsableState).FullName}"), - IdentifierName(nameof(EditorBrowsableState.Advanced))))) - ])) - .AddParameterListParameters(ownsHandleParameter) - .WithInitializer( - ConstructorInitializer( - SyntaxKind.BaseConstructorInitializer, - ArgumentList( - [ - Argument(invalidValueFieldName), - Argument(ownsHandleName), - ]))) - .WithBody(Block()) - .WithLeadingTrivia(ParseLeadingTrivia($""" -/// -/// This constructor is intended to be used when the handle is initialized with {(this.canUseMarshalInitHandle ? "" : "Marshal.InitHandle")} API after creation -/// {'\n'} -"""))); - // public SafeHandle() : base(INVALID_HANDLE_VALUE, true) members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) .AddModifiers(visibilityModifier) @@ -174,11 +135,14 @@ public partial class Generator // public SafeHandle(IntPtr preexistingHandle, bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) { this.SetHandle(preexistingHandle); } IdentifierNameSyntax preexistingHandleName = IdentifierName("preexistingHandle"); + IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle"); members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier) .AddModifiers(visibilityModifier) .AddParameterListParameters( Parameter(preexistingHandleName.Identifier).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))), - ownsHandleParameter) + Parameter(ownsHandleName.Identifier) + .WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))) + .WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression)))) .WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments( Argument(invalidValueFieldName), Argument(ownsHandleName)))) diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index 0241413f..9fe899b2 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -623,9 +623,9 @@ public async Task TestComVariantReturnValue() [Theory, CombinatorialData] // https://github.com/microsoft/CsWin32/issues/1430 public void UseInitHandleApiWhenPossible( [CombinatorialValues( - "SysAllocString", // Returns owning custom safe handle - "ShellExecute", // Returns non-owning custom safe handle - "OpenProcessToken")] // Returns owning BCL safe handle as an out parameter + "SysAllocString", // Returns owning safe handle + "ShellExecute", // Returns non-owning safe handle + "OpenProcessToken")] // Returns owning safe handle as an out parameter string api, bool initHandleApiAvailable) { @@ -645,19 +645,4 @@ public void UseInitHandleApiWhenPossible( Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); } } - - [Fact] - public void UnableToUseInitHandleDueToBclSafeHandleLimitations() - { - // Although Marshal.InitHandle API is available we cannot use it - // since the returned handle must be non-owning and BCL handle - // which is used as a safe handle for HANDLE doesn't expose functionality - // to create non-owning handle and then initialize it - this.compilation = this.starterCompilations["net8.0"]; - this.GenerateApi("GetProcessHeap"); - - MethodDeclarationSyntax friendlyOverload = Assert.Single(this.FindGeneratedMethod("GetProcessHeap_SafeHandle")); - - Assert.DoesNotContain(friendlyOverload.DescendantNodes(), n => n is IdentifierNameSyntax { Identifier.Text: "InitHandle" }); - } }