diff --git a/SharpPluginLoader.Core/Memory/Hook.cs b/SharpPluginLoader.Core/Memory/Hook.cs
index 1378880..6f1c63b 100644
--- a/SharpPluginLoader.Core/Memory/Hook.cs
+++ b/SharpPluginLoader.Core/Memory/Hook.cs
@@ -3,6 +3,30 @@
namespace SharpPluginLoader.Core.Memory
{
+ ///
+ /// Used to mark a method as a hook.
+ ///
+ ///
+ /// Methods marked with this attribute must be in a class marked with .
+ /// Additionally, the method must be static unless it is inside an class.
+ ///
+ [AttributeUsage(AttributeTargets.Method)]
+ public class HookAttribute : Attribute
+ {
+ public long Address { get; init; }
+ public string? Pattern { get; init; }
+ public int Offset { get; init; }
+ public bool Cache { get; init; }
+ }
+
+ ///
+ /// Used to mark a class as a hook provider. Classes marked with this attribute are allowed
+ /// to contain methods marked with . All hooks in the class will be
+ /// automatically registered when the plugin is loaded, and unregistered when the plugin is unloaded.
+ ///
+ [AttributeUsage(AttributeTargets.Class)]
+ public class HookProviderAttribute : Attribute;
+
public static class Hook
{
///
@@ -22,6 +46,9 @@ public static Hook Create(long address, TFunction hook)
/// Represents a native function hook.
///
/// The type of the hooked function
+ ///
+ /// Use to create a new hook.
+ ///
public class Hook : IDisposable
{
///
diff --git a/SharpPluginLoader.HookGenerator/Diagnostics.cs b/SharpPluginLoader.HookGenerator/Diagnostics.cs
new file mode 100644
index 0000000..fa84632
--- /dev/null
+++ b/SharpPluginLoader.HookGenerator/Diagnostics.cs
@@ -0,0 +1,56 @@
+using Microsoft.CodeAnalysis;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace SharpPluginLoader.HookGenerator;
+
+public enum DiagnosticCode
+{
+ ///
+ /// Method with HookAttribute not inside a class marked with HookProviderAttribute.
+ ///
+ HSG001,
+ ///
+ /// HookAttribute with neither Address nor Pattern property.
+ ///
+ HSG002,
+ ///
+ /// HookProviderAttribute class is not partial.
+ ///
+ HSG003,
+}
+
+public static class Diagnostics
+{
+ private static readonly Dictionary _diagnostics = new()
+ {
+ [DiagnosticCode.HSG001] = new DiagnosticDescriptor(
+ DiagnosticCode.HSG001.ToString(),
+ "HookGenerator",
+ "This method is marked with HookAttribute but is not inside a class marked with HookProviderAttribute."
+ "HookGenerator",
+ DiagnosticSeverity.Error,
+ true
+ ),
+ [DiagnosticCode.HSG002] = new DiagnosticDescriptor(
+ DiagnosticCode.HSG002.ToString(),
+ "HookGenerator",
+ "HookAttribute must have either an Address or a Pattern property.",
+ "HookGenerator",
+ DiagnosticSeverity.Error,
+ true
+ ),
+ [DiagnosticCode.HSG003] = new DiagnosticDescriptor(
+ DiagnosticCode.HSG003.ToString(),
+ "HookGenerator",
+ "HookProviderAttribute class must be partial.",
+ "HookGenerator",
+ DiagnosticSeverity.Error,
+ true
+ ),
+ };
+
+ public static DiagnosticDescriptor GetDiagnostic(DiagnosticCode code) => _diagnostics[code];
+}
diff --git a/SharpPluginLoader.HookGenerator/HookMethod.cs b/SharpPluginLoader.HookGenerator/HookMethod.cs
new file mode 100644
index 0000000..87b78dc
--- /dev/null
+++ b/SharpPluginLoader.HookGenerator/HookMethod.cs
@@ -0,0 +1,22 @@
+using Microsoft.CodeAnalysis;
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Text;
+
+namespace SharpPluginLoader.HookGenerator;
+
+public class HookCollection(INamedTypeSymbol containingType)
+{
+ public INamedTypeSymbol ContainingType = containingType;
+ public List Methods = [];
+}
+
+public readonly struct HookMethod(IMethodSymbol method, long address, string? pattern, int offset, bool cache)
+{
+ public IMethodSymbol Method { get; } = method;
+ public long Address { get; } = address;
+ public string? Pattern { get; } = pattern;
+ public int Offset { get; } = offset;
+ public bool Cache { get; } = cache;
+}
diff --git a/SharpPluginLoader.HookGenerator/HookSourceGenerator.cs b/SharpPluginLoader.HookGenerator/HookSourceGenerator.cs
new file mode 100644
index 0000000..1c98559
--- /dev/null
+++ b/SharpPluginLoader.HookGenerator/HookSourceGenerator.cs
@@ -0,0 +1,176 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Text;
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using System.Threading;
+
+namespace SharpPluginLoader.HookGenerator;
+
+[Generator]
+public class HookSourceGenerator : IIncrementalGenerator
+{
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ var internalCalls = context.SyntaxProvider
+ .CreateSyntaxProvider(
+ predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
+ transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)!)
+ .Where(static m => m is not null);
+
+ IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndInternalCalls
+ = context.CompilationProvider.Combine(internalCalls.Collect());
+
+ context.RegisterSourceOutput(compilationAndInternalCalls,
+ static (spc, source) => Execute(source.Item1, source.Item2, spc));
+ }
+
+ public static void Execute(Compilation compilation, ImmutableArray methods,
+ SourceProductionContext context)
+ {
+ if (methods.IsDefaultOrEmpty)
+ return;
+
+ var distinctICalls = methods.Distinct();
+
+ var methodsToGenerate = GetHooksToGenerate(compilation, distinctICalls, context, context.CancellationToken);
+
+ if (methodsToGenerate.Count == 0)
+ return;
+
+ foreach (var hookCollection in methodsToGenerate.Values)
+ {
+ var source = SourceGenerationHelper.GenerateHookClass(hookCollection);
+ context.AddSource($"{hookCollection.ContainingType.Name}_Hooks.g.cs", SourceText.From(source, Encoding.UTF8));
+ }
+ }
+
+ public static Dictionary GetHooksToGenerate(Compilation compilation,
+ IEnumerable methods, SourceProductionContext context, CancellationToken ct)
+ {
+ Dictionary hooksToGenerate = [];
+
+ var hookAttribute = compilation.GetTypeByMetadataName(SourceGenerationHelper.HookAttributeName);
+ if (hookAttribute is null)
+ return hooksToGenerate;
+
+ var hookProviderAttribute = compilation.GetTypeByMetadataName(SourceGenerationHelper.HookProviderAttributeName);
+ if (hookProviderAttribute is null)
+ return hooksToGenerate;
+
+ foreach (var method in methods)
+ {
+ ct.ThrowIfCancellationRequested();
+
+ var semanticModel = compilation.GetSemanticModel(method.SyntaxTree);
+ if (semanticModel.GetDeclaredSymbol(method) is not IMethodSymbol methodSymbol)
+ continue;
+
+ // Skip generic methods and partial definitions
+ if (methodSymbol.IsGenericMethod || methodSymbol.IsPartialDefinition)
+ continue;
+
+ var attributeData = methodSymbol.GetAttributes().FirstOrDefault(
+ ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, hookAttribute));
+
+ if (attributeData is null)
+ continue;
+
+ // Check for HookProviderAttribute
+ var containingType = methodSymbol.ContainingType;
+ if (containingType is null)
+ continue;
+
+ if (containingType.GetAttributes().FirstOrDefault(
+ ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, hookProviderAttribute)) is null)
+ {
+ ReportDiagnostic(DiagnosticCode.HSG001, method.GetLocation(), context);
+ continue;
+ }
+
+ // Check for named arguments
+ long address = 0;
+ string? pattern = null;
+ var offset = 0;
+ var cache = true;
+
+ if (attributeData is not null)
+ {
+ foreach (var namedArg in attributeData.NamedArguments)
+ {
+ switch (namedArg.Key)
+ {
+ case SourceGenerationHelper.AddressPropertyName:
+ address = (long)namedArg.Value.Value!;
+ break;
+ case SourceGenerationHelper.PatternPropertyName:
+ pattern = (string?)namedArg.Value.Value!;
+ break;
+ case SourceGenerationHelper.OffsetPropertyName:
+ offset = (int)namedArg.Value.Value!;
+ break;
+ case SourceGenerationHelper.CachePropertyName:
+ cache = (bool)namedArg.Value.Value!;
+ break;
+ }
+ }
+ }
+
+ if (address == 0 && pattern is null)
+ {
+ ReportDiagnostic(DiagnosticCode.HSG002, method.GetLocation(), context);
+ continue;
+ }
+
+ if (hooksToGenerate.TryGetValue(containingType, out var hookCollection))
+ {
+ hookCollection.Methods.Add(new HookMethod(methodSymbol, address, pattern, offset, cache));
+ }
+ else
+ {
+ hooksToGenerate[containingType] = new HookCollection(containingType)
+ {
+ Methods = [new HookMethod(methodSymbol, address, pattern, offset, cache)]
+ };
+ }
+ }
+
+ return hooksToGenerate;
+ }
+
+ public static bool IsSyntaxTargetForGeneration(SyntaxNode syntax)
+ {
+ return syntax is MethodDeclarationSyntax { AttributeLists.Count: > 0 };
+ }
+
+ public static MethodDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
+ {
+ var method = (MethodDeclarationSyntax)context.Node;
+
+ foreach (var attributeList in method.AttributeLists)
+ {
+ foreach (var attribute in attributeList.Attributes)
+ {
+ if (context.SemanticModel.GetSymbolInfo(attribute).Symbol is not IMethodSymbol attributeSymbol)
+ continue;
+
+ var attributeType = attributeSymbol.ContainingType;
+ var fullName = attributeType?.ToDisplayString() ?? string.Empty;
+
+ if (fullName == SourceGenerationHelper.HookAttributeName)
+ return method;
+ }
+ }
+
+ return null;
+ }
+
+ private static void ReportDiagnostic(DiagnosticCode code, Location location, SourceProductionContext context)
+ {
+ var diagnostic = Diagnostic.Create(Diagnostics.GetDiagnostic(code), location);
+ context.ReportDiagnostic(diagnostic);
+ }
+}
diff --git a/SharpPluginLoader.HookGenerator/SharpPluginLoader.HookGenerator.csproj b/SharpPluginLoader.HookGenerator/SharpPluginLoader.HookGenerator.csproj
new file mode 100644
index 0000000..b362c89
--- /dev/null
+++ b/SharpPluginLoader.HookGenerator/SharpPluginLoader.HookGenerator.csproj
@@ -0,0 +1,22 @@
+
+
+
+ netstandard2.0
+ 12.0
+ True
+ false
+ enable
+ true
+ true
+ True
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+
diff --git a/SharpPluginLoader.HookGenerator/SourceGenerationHelper.cs b/SharpPluginLoader.HookGenerator/SourceGenerationHelper.cs
new file mode 100644
index 0000000..6cae369
--- /dev/null
+++ b/SharpPluginLoader.HookGenerator/SourceGenerationHelper.cs
@@ -0,0 +1,99 @@
+using Microsoft.CodeAnalysis;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace SharpPluginLoader.HookGenerator;
+
+public static class SourceGenerationHelper
+{
+ public const string HookAttributeName = "SharpPluginLoader.Core.Memory.HookAttribute";
+ public const string HookProviderAttributeName = "SharpPluginLoader.Core.Memory.HookProviderAttribute";
+ public const string AddressPropertyName = "Address";
+ public const string PatternPropertyName = "Pattern";
+ public const string OffsetPropertyName = "Offset";
+ public const string CachePropertyName = "Cache";
+
+ private static StringBuilder? _sb;
+ private static int _indentLevel;
+ public static string GenerateHookClass(HookCollection hooks)
+ {
+ if (hooks.Methods.Count == 0)
+ return string.Empty;
+
+ _sb = new StringBuilder();
+ _indentLevel = 0;
+
+ var containingType = hooks.ContainingType;
+ var namespaceName = containingType.ContainingNamespace.ToDisplayString();
+ var className = containingType.Name;
+
+ AppendLine("#nullable enable");
+ AppendLine("using System;");
+ AppendLine("using System.Runtime.CompilerServices;");
+ AppendLine("using System.Runtime.InteropServices;");
+ AppendLine("using SharpPluginLoader.Core.Memory;");
+
+ var classDefSb = new StringBuilder();
+
+ classDefSb.Append(containingType.DeclaredAccessibility switch
+ {
+ Accessibility.Public => "public ",
+ Accessibility.Internal => "internal ",
+ Accessibility.Private => "private ",
+ Accessibility.Protected => "protected ",
+ Accessibility.ProtectedAndInternal => "protected internal ",
+ Accessibility.ProtectedOrInternal => "protected internal ",
+ _ => "internal "
+ });
+
+ if (containingType.IsStatic)
+ classDefSb.Append("static ");
+
+ classDefSb.Append("partial class ");
+
+ classDefSb.Append(className);
+
+ Append($$"""
+
+ namespace {{namespaceName}};
+
+ {{classDefSb}}
+ {
+
+ """);
+ Indent();
+
+ foreach (var method in hooks.Methods)
+ {
+
+ }
+
+ return _sb.ToString();
+ }
+
+ private static void Append(string value)
+ {
+ _sb!.Append(_indentLevel == 0 ? value : new string(' ', _indentLevel * 4) + value);
+ }
+
+ private static void AppendLine(string value)
+ {
+ _sb!.AppendLine(_indentLevel == 0 ? value : new string(' ', _indentLevel * 4) + value);
+ }
+
+ private static void AppendLine()
+ {
+ _sb!.AppendLine();
+ }
+
+ private static void Indent()
+ {
+ _indentLevel++;
+ }
+
+ private static void Unindent()
+ {
+ _indentLevel--;
+ }
+}
diff --git a/mhw-cs-plugin-loader.sln b/mhw-cs-plugin-loader.sln
index 09215d9..7e8f405 100644
--- a/mhw-cs-plugin-loader.sln
+++ b/mhw-cs-plugin-loader.sln
@@ -42,6 +42,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ExperimentalTesting.Native"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DtiHierarchyVisualizer", "Examples\DtiHierarchyVisualizer\DtiHierarchyVisualizer.csproj", "{803FDF24-F0D1-4643-B595-13FEEE59BE32}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SharpPluginLoader.HookGenerator", "SharpPluginLoader.HookGenerator\SharpPluginLoader.HookGenerator.csproj", "{FE138097-BA2A-420F-A157-E664B424FC95}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -394,6 +396,30 @@ Global
{803FDF24-F0D1-4643-B595-13FEEE59BE32}.RelWithDebInfo|x64.Build.0 = Release|Any CPU
{803FDF24-F0D1-4643-B595-13FEEE59BE32}.RelWithDebInfo|x86.ActiveCfg = Release|Any CPU
{803FDF24-F0D1-4643-B595-13FEEE59BE32}.RelWithDebInfo|x86.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|x64.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Debug|x86.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|Any CPU.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|Any CPU.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|x64.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|x64.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|x86.ActiveCfg = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.MinSizeRel|x86.Build.0 = Debug|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|Any CPU.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|x64.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|x64.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|x86.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.Release|x86.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|Any CPU.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|Any CPU.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|x64.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|x64.Build.0 = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|x86.ActiveCfg = Release|Any CPU
+ {FE138097-BA2A-420F-A157-E664B424FC95}.RelWithDebInfo|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -413,6 +439,7 @@ Global
{CBD13575-CA0C-4EB5-821D-847A6EE5581F} = {22614599-FA47-4E69-8DBB-A183F0A4AE25}
{E3BA4F2C-E79F-486E-8051-0835D8057C7D} = {93226706-71AC-41BC-A457-21D3CAA8C751}
{803FDF24-F0D1-4643-B595-13FEEE59BE32} = {93226706-71AC-41BC-A457-21D3CAA8C751}
+ {FE138097-BA2A-420F-A157-E664B424FC95} = {22614599-FA47-4E69-8DBB-A183F0A4AE25}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {623F8CBC-2C3B-488D-B7BB-0140A8201BB7}