Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions SharpPluginLoader.Core/Memory/Hook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@

namespace SharpPluginLoader.Core.Memory
{
/// <summary>
/// Used to mark a method as a hook.
/// </summary>
/// <remarks>
/// Methods marked with this attribute must be in a class marked with <see cref="HookProviderAttribute"/>.
/// Additionally, the method must be static <b>unless</b> it is inside an <see cref="IPlugin"/> class.
/// </remarks>
[AttributeUsage(AttributeTargets.Method)]
public class HookAttribute : Attribute
{
public long Address { get; init; }
public string? Pattern { get; init; }
public int Offset { get; init; }
public bool Cache { get; init; }
}

/// <summary>
/// Used to mark a class as a hook provider. Classes marked with this attribute are allowed
/// to contain methods marked with <see cref="HookAttribute"/>. All hooks in the class will be
/// automatically registered when the plugin is loaded, and unregistered when the plugin is unloaded.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public class HookProviderAttribute : Attribute;

public static class Hook
{
/// <summary>
Expand All @@ -22,6 +46,9 @@ public static Hook<TFunction> Create<TFunction>(long address, TFunction hook)
/// Represents a native function hook.
/// </summary>
/// <typeparam name="TFunction">The type of the hooked function</typeparam>
/// <remarks>
/// Use <see cref="Hook.Create{TFunction}(long, TFunction)"/> to create a new hook.
/// </remarks>
public class Hook<TFunction> : IDisposable
{
/// <summary>
Expand Down
56 changes: 56 additions & 0 deletions SharpPluginLoader.HookGenerator/Diagnostics.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using Microsoft.CodeAnalysis;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace SharpPluginLoader.HookGenerator;

public enum DiagnosticCode
{
/// <summary>
/// Method with HookAttribute not inside a class marked with HookProviderAttribute.
/// </summary>
HSG001,
/// <summary>
/// HookAttribute with neither Address nor Pattern property.
/// </summary>
HSG002,
/// <summary>
/// HookProviderAttribute class is not partial.
/// </summary>
HSG003,
}

public static class Diagnostics
{
private static readonly Dictionary<DiagnosticCode, DiagnosticDescriptor> _diagnostics = new()
{
[DiagnosticCode.HSG001] = new DiagnosticDescriptor(
DiagnosticCode.HSG001.ToString(),
"HookGenerator",
"This method is marked with HookAttribute but is not inside a class marked with HookProviderAttribute."
"HookGenerator",
DiagnosticSeverity.Error,
true
),
[DiagnosticCode.HSG002] = new DiagnosticDescriptor(
DiagnosticCode.HSG002.ToString(),
"HookGenerator",
"HookAttribute must have either an Address or a Pattern property.",
"HookGenerator",
DiagnosticSeverity.Error,
true
),
[DiagnosticCode.HSG003] = new DiagnosticDescriptor(
DiagnosticCode.HSG003.ToString(),
"HookGenerator",
"HookProviderAttribute class must be partial.",
"HookGenerator",
DiagnosticSeverity.Error,
true
),
};

public static DiagnosticDescriptor GetDiagnostic(DiagnosticCode code) => _diagnostics[code];
}
22 changes: 22 additions & 0 deletions SharpPluginLoader.HookGenerator/HookMethod.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Microsoft.CodeAnalysis;
using System;
using System.Collections.Generic;
using System.Net;
using System.Text;

namespace SharpPluginLoader.HookGenerator;

public class HookCollection(INamedTypeSymbol containingType)
{
public INamedTypeSymbol ContainingType = containingType;
public List<HookMethod> Methods = [];
}

public readonly struct HookMethod(IMethodSymbol method, long address, string? pattern, int offset, bool cache)
{
public IMethodSymbol Method { get; } = method;
public long Address { get; } = address;
public string? Pattern { get; } = pattern;
public int Offset { get; } = offset;
public bool Cache { get; } = cache;
}
176 changes: 176 additions & 0 deletions SharpPluginLoader.HookGenerator/HookSourceGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;

namespace SharpPluginLoader.HookGenerator;

[Generator]
public class HookSourceGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var internalCalls = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)!)
.Where(static m => m is not null);

IncrementalValueProvider<(Compilation, ImmutableArray<MethodDeclarationSyntax>)> compilationAndInternalCalls
= context.CompilationProvider.Combine(internalCalls.Collect());

context.RegisterSourceOutput(compilationAndInternalCalls,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
}

public static void Execute(Compilation compilation, ImmutableArray<MethodDeclarationSyntax> methods,
SourceProductionContext context)
{
if (methods.IsDefaultOrEmpty)
return;

var distinctICalls = methods.Distinct();

var methodsToGenerate = GetHooksToGenerate(compilation, distinctICalls, context, context.CancellationToken);

if (methodsToGenerate.Count == 0)
return;

foreach (var hookCollection in methodsToGenerate.Values)
{
var source = SourceGenerationHelper.GenerateHookClass(hookCollection);
context.AddSource($"{hookCollection.ContainingType.Name}_Hooks.g.cs", SourceText.From(source, Encoding.UTF8));
}
}

public static Dictionary<INamedTypeSymbol, HookCollection> GetHooksToGenerate(Compilation compilation,
IEnumerable<MethodDeclarationSyntax> methods, SourceProductionContext context, CancellationToken ct)
{
Dictionary<INamedTypeSymbol, HookCollection> hooksToGenerate = [];

var hookAttribute = compilation.GetTypeByMetadataName(SourceGenerationHelper.HookAttributeName);
if (hookAttribute is null)
return hooksToGenerate;

var hookProviderAttribute = compilation.GetTypeByMetadataName(SourceGenerationHelper.HookProviderAttributeName);
if (hookProviderAttribute is null)
return hooksToGenerate;

foreach (var method in methods)
{
ct.ThrowIfCancellationRequested();

var semanticModel = compilation.GetSemanticModel(method.SyntaxTree);
if (semanticModel.GetDeclaredSymbol(method) is not IMethodSymbol methodSymbol)
continue;

// Skip generic methods and partial definitions
if (methodSymbol.IsGenericMethod || methodSymbol.IsPartialDefinition)
continue;

var attributeData = methodSymbol.GetAttributes().FirstOrDefault(
ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, hookAttribute));

if (attributeData is null)
continue;

// Check for HookProviderAttribute
var containingType = methodSymbol.ContainingType;
if (containingType is null)
continue;

if (containingType.GetAttributes().FirstOrDefault(
ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, hookProviderAttribute)) is null)
{
ReportDiagnostic(DiagnosticCode.HSG001, method.GetLocation(), context);
continue;
}

// Check for named arguments
long address = 0;
string? pattern = null;
var offset = 0;
var cache = true;

if (attributeData is not null)
{
foreach (var namedArg in attributeData.NamedArguments)
{
switch (namedArg.Key)
{
case SourceGenerationHelper.AddressPropertyName:
address = (long)namedArg.Value.Value!;
break;
case SourceGenerationHelper.PatternPropertyName:
pattern = (string?)namedArg.Value.Value!;
break;
case SourceGenerationHelper.OffsetPropertyName:
offset = (int)namedArg.Value.Value!;
break;
case SourceGenerationHelper.CachePropertyName:
cache = (bool)namedArg.Value.Value!;
break;
}
}
}

if (address == 0 && pattern is null)
{
ReportDiagnostic(DiagnosticCode.HSG002, method.GetLocation(), context);
continue;
}

if (hooksToGenerate.TryGetValue(containingType, out var hookCollection))
{
hookCollection.Methods.Add(new HookMethod(methodSymbol, address, pattern, offset, cache));
}
else
{
hooksToGenerate[containingType] = new HookCollection(containingType)
{
Methods = [new HookMethod(methodSymbol, address, pattern, offset, cache)]
};
}
}

return hooksToGenerate;
}

public static bool IsSyntaxTargetForGeneration(SyntaxNode syntax)
{
return syntax is MethodDeclarationSyntax { AttributeLists.Count: > 0 };
}

public static MethodDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var method = (MethodDeclarationSyntax)context.Node;

foreach (var attributeList in method.AttributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attribute).Symbol is not IMethodSymbol attributeSymbol)
continue;

var attributeType = attributeSymbol.ContainingType;
var fullName = attributeType?.ToDisplayString() ?? string.Empty;

if (fullName == SourceGenerationHelper.HookAttributeName)
return method;
}
}

return null;
}

private static void ReportDiagnostic(DiagnosticCode code, Location location, SourceProductionContext context)
{
var diagnostic = Diagnostic.Create(Diagnostics.GetDiagnostic(code), location);
context.ReportDiagnostic(diagnostic);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<LangVersion>12.0</LangVersion>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<IncludeBuildOutput>false</IncludeBuildOutput>
<Nullable>enable</Nullable>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
<IsRoslynComponent>true</IsRoslynComponent>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.8.0" />
</ItemGroup>

</Project>
Loading