Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NewType.Generator/AliasAttributeSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ internal enum NewtypeOptions
/// <summary>Suppress implicit conversions and constructor forwarding.</summary>
Opaque = NoImplicitConversions | NoConstructorForwarding,
}

/// <summary>
/// Marks a partial type as a type alias for the specified type.
/// The source generator will generate implicit conversions, operator forwarding,
Expand All @@ -54,7 +54,7 @@ public newtypeAttribute() { }

/// <summary>Controls which features the generator emits.</summary>
public NewtypeOptions Options { get; set; }

/// <summary>
/// Overrides the MethodImplOptions applied to generated members.
/// Default is <see cref="global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining"/>.
Expand Down
162 changes: 124 additions & 38 deletions NewType.Generator/AliasCodeGenerator.cs

Large diffs are not rendered by default.

78 changes: 63 additions & 15 deletions NewType.Generator/AliasGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,30 @@ public class AliasGenerator : IIncrementalGenerator
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Register the attribute source
context.RegisterPostInitializationOutput(ctx => { ctx.AddSource("newtypeAttribute.g.cs", SourceText.From(NewtypeAttributeSource.Source, Encoding.UTF8)); });
context.RegisterPostInitializationOutput(ctx =>
{
ctx.AddSource("newtypeAttribute.g.cs", SourceText.From(NewtypeAttributeSource.Source, Encoding.UTF8));
ctx.AddSource("newtypeConstraintAttribute.g.cs",
SourceText.From(ConstraintAttributeSource.Source, Encoding.UTF8));
});

// Pipeline for generic [newtype<T>] attribute
var genericPipeline = context.SyntaxProvider
IncrementalValuesProvider<AliasModel> genericPipeline = context.SyntaxProvider
.ForAttributeWithMetadataName(
"newtype.newtypeAttribute`1",
predicate: static (node, _) => node is TypeDeclarationSyntax,
transform: static (ctx, _) => ExtractGenericModel(ctx))
.Where(static model => model is not null)
.Select(static (model, _) => model!.Value);
.Where(static model => model is not null)!;

context.RegisterSourceOutput(genericPipeline, static (spc, model) => GenerateAliasCode(spc, model));

// Pipeline for non-generic [newtype(typeof(T))] attribute
var nonGenericPipeline = context.SyntaxProvider
IncrementalValuesProvider<AliasModel> nonGenericPipeline = context.SyntaxProvider
.ForAttributeWithMetadataName(
"newtype.newtypeAttribute",
predicate: static (node, _) => node is TypeDeclarationSyntax,
transform: static (ctx, _) => ExtractNonGenericModel(ctx))
.Where(static model => model is not null)
.Select(static (model, _) => model!.Value);
.Where(static model => model is not null)!;

context.RegisterSourceOutput(nonGenericPipeline, static (spc, model) => GenerateAliasCode(spc, model));
}
Expand All @@ -45,14 +48,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
foreach (var attributeData in context.Attributes)
{
var attributeClass = attributeData.AttributeClass;
if (attributeClass is {IsGenericType: true} &&
if (attributeClass is { IsGenericType: true } &&
attributeClass.TypeArguments.Length == 1)
{
var aliasedType = attributeClass.TypeArguments[0];
var (options, methodImpl) = ExtractNamedArguments(attributeData);
return AliasModelExtractor.Extract(context, aliasedType, options, methodImpl);
var options = ExtractNamedArguments(attributeData);
return AliasModelExtractor.Extract(context, aliasedType, options);
}
}

return null;
}

Expand All @@ -63,10 +67,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
if (attributeData.ConstructorArguments.Length > 0 &&
attributeData.ConstructorArguments[0].Value is ITypeSymbol aliasedType)
{
var (options, methodImpl) = ExtractNamedArguments(attributeData);
return AliasModelExtractor.Extract(context, aliasedType, options, methodImpl);
var options = ExtractNamedArguments(attributeData);

return AliasModelExtractor.Extract(context, aliasedType, options);
}
}

return null;
}

Expand All @@ -75,7 +81,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
private const int DefaultOptions = 0;
private const int DefaultMethodImplAggressiveInlining = 256;

private static (int options, int methodImpl) ExtractNamedArguments(AttributeData attributeData)
private static ExtractedOptions ExtractNamedArguments(
AttributeData attributeData)
{
int options = DefaultOptions;
int methodImpl = DefaultMethodImplAggressiveInlining;
Expand All @@ -93,17 +100,58 @@ private static (int options, int methodImpl) ExtractNamedArguments(AttributeData
}
}

return (options, methodImpl);
return new ExtractedOptions(options, methodImpl);
}

private static void GenerateAliasCode(
SourceProductionContext context,
AliasModel model)
{
if (!model.ConstraintModel.Valid)
{
context.ReportDiagnostic(
Diagnostic.Create(
ValidatorInvalidDiagnostic, model.ConstraintModel.LocationInfo?.ToLocation(),
model.TypeName,
model.ConstraintModel.ValidationSymbolName ?? "Method",
model.AliasedTypeMinimalName
));

return;
}

if (model.ConstraintModel.Multiple)
{
context.ReportDiagnostic(
Diagnostic.Create(ValidatorMultipleDiagnostic, model.LocationInfo?.ToLocation())
);
return;
}

var generator = new AliasCodeGenerator(model);
var source = generator.Generate();

var fileName = $"{model.TypeDisplayString.Replace(".", "_").Replace("<", "_").Replace(">", "_")}.g.cs";
context.AddSource(fileName, SourceText.From(source, Encoding.UTF8));
}
}

private static readonly DiagnosticDescriptor ValidatorInvalidDiagnostic =
new(
id: "NEWTYPE001",
title: "Malformed validation method",
messageFormat: "Incorrectly formed validation method for type '{0}'. Expected signature: 'bool {1}({2})'.",
category: "Unknown",
DiagnosticSeverity.Error,
isEnabledByDefault: true
);

private static readonly DiagnosticDescriptor ValidatorMultipleDiagnostic =
new(
id: "NEWTYPE002",
title: "Multiple validators",
messageFormat: "Only a single validation method should be used",
category: "Unknown",
DiagnosticSeverity.Error,
isEnabledByDefault: true
);
}
19 changes: 17 additions & 2 deletions NewType.Generator/AliasModel.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
using System;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

namespace newtype.generator;

/// <summary>
/// Fully-extracted, equatable model representing a newtype alias.
/// Contains only strings, bools, plain enums, and EquatableArrays — no Roslyn symbols.
/// </summary>
internal readonly record struct AliasModel(
internal record AliasModel(
// Type being declared
string TypeName,
string Namespace,
Accessibility DeclaredAccessibility,
bool IsReadonly,
bool IsClass,
bool IsRecord,
bool IsRecordStruct,

// Location for messages
LocationInfo? LocationInfo,

// Aliased type
string AliasedTypeFullName,
Expand All @@ -41,6 +44,7 @@ internal readonly record struct AliasModel(
bool SuppressImplicitUnwrap,
bool SuppressConstructorForwarding,
int MethodImplValue,
ConstraintModel ConstraintModel,

// Members
EquatableArray<BinaryOperatorInfo> BinaryOperators,
Expand All @@ -52,6 +56,8 @@ internal readonly record struct AliasModel(
EquatableArray<ConstructorInfo> ForwardedConstructors
);

internal readonly record struct ExtractedOptions(int Options, int MethodImpl);

internal readonly record struct BinaryOperatorInfo(
string Name,
string LeftTypeFullName,
Expand Down Expand Up @@ -116,3 +122,12 @@ internal readonly record struct ConstructorParameterInfo(
bool IsParams,
string? DefaultValueLiteral
) : IEquatable<ConstructorParameterInfo>;

internal readonly record struct LocationInfo(
string FilePath,
TextSpan TextSpan,
LinePositionSpan LineSpan
) : IEquatable<LocationInfo>
{
public Location ToLocation() => Location.Create(FilePath, TextSpan, LineSpan);
}
90 changes: 80 additions & 10 deletions NewType.Generator/AliasModelExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace newtype.generator;

Expand All @@ -17,25 +18,23 @@ internal static class AliasModelExtractor
private const int OptionsNoImplicitUnwrap = 2;
private const int OptionsNoConstructorForwarding = 4;

public static AliasModel? Extract(
public static AliasModel? Extract(
GeneratorAttributeSyntaxContext context,
ITypeSymbol aliasedType,
int options,
int methodImpl)
ExtractedOptions allOptions)
{
var typeDecl = (TypeDeclarationSyntax)context.TargetNode;
var typeSymbol = (INamedTypeSymbol)context.TargetSymbol;

var typeName = typeSymbol.Name;
var ns = typeSymbol.ContainingNamespace;
var namespaceName = ns is {IsGlobalNamespace: false} ? ns.ToDisplayString() : "";
var namespaceName = ns is { IsGlobalNamespace: false } ? ns.ToDisplayString() : "";

var isReadonly = typeDecl.Modifiers.Any(SyntaxKind.ReadOnlyKeyword);
var isClass = typeDecl is ClassDeclarationSyntax
|| (typeDecl is RecordDeclarationSyntax rds
&& !rds.ClassOrStructKeyword.IsKind(SyntaxKind.StructKeyword));
var isRecord = typeDecl is RecordDeclarationSyntax;
var isRecordStruct = isRecord && !isClass;

var aliasedTypeFullName = aliasedType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var aliasedTypeMinimalName = aliasedType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat);
Expand All @@ -56,14 +55,16 @@ internal static class AliasModelExtractor

var typeDisplayString = typeSymbol.ToDisplayString();

var constraintModel = ExtractValidationMethod(typeSymbol, aliasedType);

return new AliasModel(
TypeName: typeName,
Namespace: namespaceName,
DeclaredAccessibility: typeSymbol.DeclaredAccessibility,
IsReadonly: isReadonly,
IsClass: isClass,
IsRecord: isRecord,
IsRecordStruct: isRecordStruct,
LocationInfo: ToLocationStruct(typeSymbol.Locations.FirstOrDefault()),
AliasedTypeFullName: aliasedTypeFullName,
AliasedTypeMinimalName: aliasedTypeMinimalName,
AliasedTypeSpecialType: aliasedType.SpecialType,
Expand All @@ -73,10 +74,11 @@ internal static class AliasModelExtractor
HasNativeEqualityOperator: hasNativeEquality,
TypeDisplayString: typeDisplayString,
HasStaticMemberCandidates: hasStaticMemberCandidates,
SuppressImplicitWrap: (options & OptionsNoImplicitWrap) != 0,
SuppressImplicitUnwrap: (options & OptionsNoImplicitUnwrap) != 0,
SuppressConstructorForwarding: (options & OptionsNoConstructorForwarding) != 0,
MethodImplValue: methodImpl,
SuppressImplicitWrap: (allOptions.Options & OptionsNoImplicitWrap) != 0,
SuppressImplicitUnwrap: (allOptions.Options & OptionsNoImplicitUnwrap) != 0,
SuppressConstructorForwarding: (allOptions.Options & OptionsNoConstructorForwarding) != 0,
MethodImplValue: allOptions.MethodImpl,
ConstraintModel: constraintModel,
BinaryOperators: binaryOperators,
UnaryOperators: unaryOperators,
StaticMembers: staticMembers,
Expand Down Expand Up @@ -355,6 +357,54 @@ private static string GetConstructorSignature(IMethodSymbol ctor)
}));
}

private static ConstraintModel ExtractValidationMethod(ITypeSymbol targetType,
ITypeSymbol aliasedType)
{
IMethodSymbol? validationMethod = null;
Location? location = null;
bool invalid = false;
bool multiple = false;
bool inRelease = false;

foreach (var method in targetType.GetMembers().OfType<IMethodSymbol>())
{
foreach (var attributeData in method.GetAttributes().Where(x => x is not null))
{
if (attributeData.AttributeClass!.Name == ConstraintAttributeSource.AttributeName)
{
foreach (var arg in attributeData.NamedArguments)
{
if (arg.Key == "IncludeInRelease")
{
inRelease = (bool)arg.Value.Value!;
}
}

// doesn't have to be static
var methodValid = method.ReturnType.SpecialType == SpecialType.System_Boolean &&
method.Parameters.Length == 1 &&
SymbolEqualityComparer.Default.Equals(
method.Parameters[0].Type,
aliasedType);

invalid |= !methodValid;

if (validationMethod == null)
{
validationMethod = method;
location = method.Locations[0];
}
else
{
multiple = true;
}
}
}
}

return new ConstraintModel(validationMethod?.Name, inRelease, !invalid, multiple, ToLocationStruct(location));
}

private static string FormatDefaultValue(IParameterSymbol param)
{
var value = param.ExplicitDefaultValue;
Expand Down Expand Up @@ -416,4 +466,24 @@ private static bool ImplementsInterface(ITypeSymbol type, string interfaceFullNa

return type.AllInterfaces.Any(i => i.ToDisplayString() == interfaceFullName);
}

private static LocationInfo? ToLocationStruct(Location? location) =>
location is not null && location.IsInSource ?
new LocationInfo(
location.SourceTree.FilePath,
location.SourceSpan,
new LinePositionSpan(
location.GetLineSpan().StartLinePosition,
location.GetLineSpan().EndLinePosition))
:null;
}

internal record ConstraintModel(
string? ValidationSymbolName,
bool InRelease,
bool Valid,
bool Multiple,
LocationInfo? LocationInfo)
{
public bool UseConstraints => ValidationSymbolName is not null && Valid && !Multiple;
};
Loading