Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Symbols.Members;
using static Riok.Mapperly.Emit.Syntax.SyntaxFactoryHelper;
Expand All @@ -17,7 +18,12 @@ private EnsureCapacityMethodSetter() { }

public bool SupportsCoalesceAssignment => false;

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
public ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
)
{
if (baseAccess == null)
throw new ArgumentNullException(nameof(baseAccess));
Expand Down
16 changes: 12 additions & 4 deletions src/Riok.Mapperly/Descriptors/UnsafeAccess/UnsafeFieldAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public MethodDeclarationSyntax BuildAccessorMethod(SourceEmitterContext ctx)
return ctx.SyntaxFactory.PublicStaticExternMethod(returnType, methodName, parameters, [attribute]);
}

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false)
public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, INamedTypeSymbol? containingType = null, bool nullConditional = false)
{
if (baseAccess == null)
throw new ArgumentNullException(nameof(baseAccess));
Expand All @@ -54,7 +54,10 @@ public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullCondi
return InvocationWithoutIndention(method);
}

var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(symbol.ContainingType.TypeArguments));
// Use the passed containingType for type arguments if provided,
// otherwise fall back to the symbol's containing type.
var typeArgs = containingType?.TypeArguments ?? symbol.ContainingType.TypeArguments;
var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(typeArgs));
var invocation = InvocationExpression(MemberAccess(genericClassName, methodName))
.WithArgumentList(ArgumentListWithoutIndention([baseAccess]));

Expand All @@ -64,9 +67,14 @@ public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullCondi
return Conditional(IsNotNull(baseAccess), invocation, DefaultLiteral());
}

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
public ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
)
{
var access = BuildAccess(baseAccess);
var access = BuildAccess(baseAccess, containingType);
return Assignment(access, valueToAssign, coalesceAssignment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public MethodDeclarationSyntax BuildAccessorMethod(SourceEmitterContext ctx)
);
}

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false)
public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, INamedTypeSymbol? containingType = null, bool nullConditional = false)
{
if (baseAccess == null)
throw new ArgumentNullException(nameof(baseAccess));
Expand All @@ -57,7 +57,12 @@ public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullCondi
return InvocationWithoutIndention(method);
}

var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(symbol.ContainingType.TypeArguments));
// Use the passed containingType for type arguments if provided,
// otherwise fall back to the symbol's containing type.
// This is critical for inherited members where the cached symbol's
// type arguments may differ from the actual derived type being mapped.
var typeArgs = containingType?.TypeArguments ?? symbol.ContainingType.TypeArguments;
var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(typeArgs));
var invocation = InvocationExpression(MemberAccess(genericClassName, methodName))
.WithArgumentList(ArgumentListWithoutIndention([baseAccess]));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ public MethodDeclarationSyntax BuildAccessorMethod(SourceEmitterContext ctx)
);
}

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
public ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
)
{
if (baseAccess == null)
throw new ArgumentNullException(nameof(baseAccess));
Expand All @@ -63,8 +68,13 @@ public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, Expression
return InvocationWithoutIndention(MemberAccess(baseAccess, methodName), valueToAssign);
}

// Use the passed containingType for type arguments if provided,
// otherwise fall back to the symbol's containing type.
// This is critical for inherited members where the cached symbol's
// type arguments may differ from the actual derived type being mapped.
var typeArgs = containingType?.TypeArguments ?? symbol.ContainingType.TypeArguments;
var args = new[] { baseAccess, valueToAssign };
var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(symbol.ContainingType.TypeArguments));
var genericClassName = GenericName(className).WithTypeArgumentList(TypeArgumentList(typeArgs));
return InvocationExpression(MemberAccess(genericClassName, methodName)).WithArgumentList(ArgumentListWithoutIndention(args));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ public class ConstructorParameterMember(IParameterSymbol symbol, SymbolAccessor
public IMemberSetter BuildSetter(UnsafeAccessorContext ctx) =>
throw new InvalidOperationException($"Cannot create a setter for {nameof(ParameterSourceMember)}");

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false) => IdentifierName(Name);
public ExpressionSyntax BuildAccess(
ExpressionSyntax? baseAccess,
INamedTypeSymbol? containingType = null,
bool nullConditional = false
) => IdentifierName(Name);
}
9 changes: 7 additions & 2 deletions src/Riok.Mapperly/Symbols/Members/FieldMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,18 @@ public IMemberSetter BuildSetter(UnsafeAccessorContext ctx)
return ctx.GetOrBuildFieldGetter(this);
}

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
public ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
)
{
var targetMemberRef = BuildAccess(baseAccess);
return Assignment(targetMemberRef, valueToAssign, coalesceAssignment);
}

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false)
public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, INamedTypeSymbol? containingType = null, bool nullConditional = false)
{
if (baseAccess == null)
return SyntaxFactory.IdentifierName(Name);
Expand Down
3 changes: 2 additions & 1 deletion src/Riok.Mapperly/Symbols/Members/IMemberGetter.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Riok.Mapperly.Symbols.Members;

public interface IMemberGetter
{
ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false);
ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, INamedTypeSymbol? containingType = null, bool nullConditional = false);
}
8 changes: 7 additions & 1 deletion src/Riok.Mapperly/Symbols/Members/IMemberSetter.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Riok.Mapperly.Symbols.Members;
Expand All @@ -6,5 +7,10 @@ public interface IMemberSetter
{
bool SupportsCoalesceAssignment { get; }

ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false);
ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
);
}
10 changes: 5 additions & 5 deletions src/Riok.Mapperly/Symbols/Members/MemberPathGetter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static MemberPathGetter Build(SimpleMappingBuilderContext ctx, MemberPath
{
return path.AggregateWithPrevious(
baseAccess,
(expr, prevProp, prop) => prop.Getter.BuildAccess(expr, prevProp.Member?.IsNullable == true)
(expr, prevProp, prop) => prop.Getter.BuildAccess(expr, prop.Member.ContainingType, prevProp.Member?.IsNullable == true)
);
}

Expand All @@ -66,12 +66,12 @@ public static MemberPathGetter Build(SimpleMappingBuilderContext ctx, MemberPath
baseAccess,
(a, b) =>
b.Member.Type.IsNullableValueType()
? MemberAccess(b.Getter.BuildAccess(a), NullableValueProperty)
: b.Getter.BuildAccess(a)
? MemberAccess(b.Getter.BuildAccess(a, b.Member.ContainingType), NullableValueProperty)
: b.Getter.BuildAccess(a, b.Member.ContainingType)
);
}

return path.Aggregate(baseAccess, (a, b) => b.Getter.BuildAccess(a));
return path.Aggregate(baseAccess, (a, b) => b.Getter.BuildAccess(a, b.Member.ContainingType));
}

/// <summary>
Expand Down Expand Up @@ -99,7 +99,7 @@ private BinaryExpressionSyntax BuildNonNullCondition(ExpressionSyntax baseAccess
var conditions = new List<BinaryExpressionSyntax>();
foreach (var pathPart in nullablePath)
{
access = pathPart.Getter.BuildAccess(access);
access = pathPart.Getter.BuildAccess(access, pathPart.Member.ContainingType);

if (!pathPart.Member.IsNullable)
continue;
Expand Down
13 changes: 10 additions & 3 deletions src/Riok.Mapperly/Symbols/Members/MemberPathSetter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@ public class MemberPathSetter
private readonly NonEmptyMemberPath _memberPath;
private readonly MemberPathGetter _baseAccessGetter;
private readonly IMemberSetter _memberSetter;
private readonly IMappableMember _member;

private MemberPathSetter(NonEmptyMemberPath memberPath, MemberPathGetter baseAccessGetter, IMemberSetter memberSetter)
private MemberPathSetter(
NonEmptyMemberPath memberPath,
MemberPathGetter baseAccessGetter,
IMemberSetter memberSetter,
IMappableMember member
)
{
_memberPath = memberPath;
_baseAccessGetter = baseAccessGetter;
_memberSetter = memberSetter;
_member = member;
}

public bool SupportsCoalesceAssignment => _memberSetter.SupportsCoalesceAssignment;
Expand All @@ -30,13 +37,13 @@ public static MemberPathSetter Build(SimpleMappingBuilderContext ctx, NonEmptyMe
var objectPath = MemberPath.Create(path.RootType, path.ObjectPath.ToList());
var objectGetter = objectPath.BuildGetter(ctx);
var memberSetter = path.Member.BuildSetter(ctx.UnsafeAccessorContext);
return new MemberPathSetter(path, objectGetter, memberSetter);
return new MemberPathSetter(path, objectGetter, memberSetter, path.Member);
}

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
{
baseAccess = _baseAccessGetter.BuildAccess(baseAccess);
return _memberSetter.BuildAssignment(baseAccess, valueToAssign, coalesceAssignment);
return _memberSetter.BuildAssignment(baseAccess, valueToAssign, _member.ContainingType, coalesceAssignment);
}

public override bool Equals(object? obj)
Expand Down
6 changes: 5 additions & 1 deletion src/Riok.Mapperly/Symbols/Members/ParameterSourceMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ public class ParameterSourceMember(MethodParameter parameter) : IMappableMember,
public IMemberSetter BuildSetter(UnsafeAccessorContext ctx) =>
throw new InvalidOperationException($"Cannot create a setter for {nameof(ParameterSourceMember)}");

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false) => IdentifierName(Name);
public ExpressionSyntax BuildAccess(
ExpressionSyntax? baseAccess,
INamedTypeSymbol? containingType = null,
bool nullConditional = false
) => IdentifierName(Name);

public override bool Equals(object? obj)
{
Expand Down
9 changes: 7 additions & 2 deletions src/Riok.Mapperly/Symbols/Members/PropertyMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,20 @@ public IMemberSetter BuildSetter(UnsafeAccessorContext ctx)
return ctx.GetOrBuildPropertySetter(this);
}

public ExpressionSyntax BuildAssignment(ExpressionSyntax? baseAccess, ExpressionSyntax valueToAssign, bool coalesceAssignment = false)
public ExpressionSyntax BuildAssignment(
ExpressionSyntax? baseAccess,
ExpressionSyntax valueToAssign,
INamedTypeSymbol? containingType = null,
bool coalesceAssignment = false
)
{
Debug.Assert(CanSetDirectly);
ExpressionSyntax targetMember = baseAccess == null ? IdentifierName(Name) : MemberAccess(baseAccess, Name);

return Assignment(targetMember, valueToAssign, coalesceAssignment);
}

public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, bool nullConditional = false)
public ExpressionSyntax BuildAccess(ExpressionSyntax? baseAccess, INamedTypeSymbol? containingType = null, bool nullConditional = false)
{
Debug.Assert(CanGetDirectly);
if (baseAccess == null)
Expand Down
23 changes: 23 additions & 0 deletions test/Riok.Mapperly.Tests/Mapping/UnsafeAccessorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ public Task PrivatePropertyInGenericClassMultipleTypeParametersWithConstraints()
return TestHelper.VerifyGenerator(source);
}

[Fact]
public Task ProtectedPropertyInGenericBaseClassWithDifferentInstantiations()
{
var source = TestSourceBuilder.MapperWithBodyAndTypes(
"""
partial B1 Map1(A1 source);
partial B2 Map2(A2 source);
""",
TestSourceBuilderOptions.WithMemberVisibility(MemberVisibility.All),
"interface IA { }",
"interface IB : IA { }",
"interface IC : IA { }",
"class A<T> where T : IA { protected T _value { get; set; } }",
"class A1 : A<IB> { }",
"class A2 : A<IC> { }",
"class B<T> where T : IA { protected T _value { get; set; } }",
"class B1 : B<IB> { }",
"class B2 : B<IC> { }"
);

return TestHelper.VerifyGenerator(source);
}

[Fact]
public Task ProtectedProperty()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public partial class Mapper
partial global::B<float> Map(global::A<float> source)
{
var target = new global::B<float>();
BAccessor<int>.SetValue(target, AAccessor<int>.GetValue(source));
BAccessor<float>.SetValue(target, AAccessor<float>.GetValue(source));
return target;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//HintName: Mapper.g.cs
// <auto-generated />
#nullable enable
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
partial global::B1 Map1(global::A1 source)
{
var target = new global::B1();
BAccessor<global::IB>.SetValue(target, AAccessor<global::IB>.GetValue(source));
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
partial global::B2 Map2(global::A2 source)
{
var target = new global::B2();
BAccessor<global::IC>.SetValue(target, AAccessor<global::IC>.GetValue(source));
return target;
}
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
static file class AAccessor<T>
where T : global::IA
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
[global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Method, Name = "get__value")]
public static extern T GetValue(global::A<T> source);
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
static file class BAccessor<T>
where T : global::IA
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
[global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Method, Name = "set__value")]
public static extern void SetValue(global::B<T> target, T value);
}
Loading