From e4655ea525719efdec2e54efe576e22b5038fd2d Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Wed, 9 Oct 2024 09:51:00 -0600 Subject: [PATCH] Preserve null terminating character in `ref Span` parameters Fixes #1295 --- .../Generator.FriendlyOverloads.cs | 7 ++++--- test/GenerationSandbox.Tests/BasicTests.cs | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index a0634618..e277b7c7 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -449,7 +449,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi // wstrParam1 arguments[param.SequenceNumber - 1] = Argument(localWstrName); - // if (buffer != null && buffer.LastIndexOf('\0') == -1) throw new ArgumentException("Required null terminator is missing.", "Param1"); + // if (Param1 != null && Param1.LastIndexOf('\0') == -1) throw new ArgumentException("Required null terminator is missing.", "Param1"); InvocationExpressionSyntax lastIndexOf = InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(MemoryExtensions.LastIndexOf))), ArgumentList().AddArguments(Argument(LiteralExpression(SyntaxKind.CharacterLiteralExpression, Literal('\0'))))); @@ -465,7 +465,8 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi leadingStatements.Add(LocalDeclarationStatement( VariableDeclaration(externParam.Type).AddVariables(VariableDeclarator(localWstrName.Identifier).WithInitializer(EqualsValueClause(localName))))); - // Param1 = Param1.Slice(0, wstrParam1.Length); + // Preserve the null terminator in the result, which contractually was included in the input. + // Param1 = Param1.Slice(0, wstrParam1.Length + 1); trailingStatements.Add(ExpressionStatement(AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, origName, @@ -473,7 +474,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(Span.Slice))), ArgumentList().AddArguments( Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))), - Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, localWstrName, IdentifierName("Length")))))))); + Argument(BinaryExpression(SyntaxKind.AddExpression, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, localWstrName, IdentifierName("Length")), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))))))))); } else if (isIn && isOptional && !isOut && isManagedParameterType && parameterTypeInfo is PointerTypeHandleInfo ptrInfo && ptrInfo.ElementType.IsValueType(parameterTypeSyntaxSettings) is true && this.canUseUnsafeAsRef) { diff --git a/test/GenerationSandbox.Tests/BasicTests.cs b/test/GenerationSandbox.Tests/BasicTests.cs index 89e1633d..beb76322 100644 --- a/test/GenerationSandbox.Tests/BasicTests.cs +++ b/test/GenerationSandbox.Tests/BasicTests.cs @@ -453,7 +453,7 @@ public void PathParseIconLocation_Friendly() sourceString.AsSpan().CopyTo(buffer); int result = PInvoke.PathParseIconLocation(ref buffer); Assert.Equal(3, result); - Assert.Equal("hi there", buffer.ToString()); + Assert.Equal("hi there\0", buffer.ToString()); } [Fact]