Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/Core/ExtensionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using SwarmUI.Utils;
using System.IO;
using System.Reflection;
using System.Runtime.Loader;

namespace SwarmUI.Core;

Expand All @@ -15,11 +16,40 @@ public class ExtensionsManager
/// <summary>Hashset of folder names of all extensions currently loaded.</summary>
public HashSet<string> LoadedExtensionFolders = [];

/// <summary>Hashset of dependency names that are considered "core" and should not be loaded from the extension's folder.</summary>
public HashSet<string> CoreDependencyNames = [];

/// <summary>Simple holder of information about extensions available online.</summary>
public record class ExtensionInfo(string Name, string Author, string License, string Description, string URL, string OldURL, string[] Tags, string[] FolderNames)
{
}

private class SwarmExtensionLoadContext(ExtensionsManager manager, string name, string extensionDir) : AssemblyLoadContext(name, isCollectible: false)
{
/// <summary>Host wins, then we probe the extension's folder for private deps.</summary>
protected override Assembly Load(AssemblyName name)
{
string dependency = Path.Combine(extensionDir, name.Name + ".dll");
try
{
Default.LoadFromAssemblyName(name);
// We only get here if host successfully loads the assembly.
if (File.Exists(dependency) && !manager.CoreDependencyNames.Contains(name.Name))
{
Logs.Warning($"Extension {Name} ships {name.Name}.dll but host already has it loaded; using host copy.");
}
return null;
}
catch (FileNotFoundException) { }
if (!File.Exists(dependency))
{
return null;
}
Logs.Debug($"Extension {Name} loading private dep {name.Name} from {dependency}");
return LoadFromAssemblyPath(dependency);
}
}

public static HtmlString HtmlTags(string[] tags)
{
return new(tags.Select(t =>
Expand All @@ -45,6 +75,11 @@ public static HtmlString HtmlTags(string[] tags)
/// <summary>List of known online available extensions.</summary>
public List<ExtensionInfo> KnownExtensions = [];

private Assembly LoadInExtensionContext(string dllName, string targetPath)
{
return new SwarmExtensionLoadContext(this, dllName, Path.GetDirectoryName(targetPath)).LoadFromAssemblyPath(targetPath);
}

public static string ReferenceCsproj =
"""
<Project Sdk="Microsoft.NET.Sdk.Web">
Expand All @@ -58,6 +93,11 @@ public static HtmlString HtmlTags(string[] tags)
/// <summary>Initial call that prepares the extensions list.</summary>
public async Task PrepExtensions()
{
CoreDependencyNames =
[
.. AssemblyLoadContext.Default.Assemblies.Select(a => a.GetName().Name).Where(n => n is not null),
.. Directory.EnumerateFiles(AppContext.BaseDirectory, "*.dll").Select(Path.GetFileNameWithoutExtension)
];
await BuildPublicExtensionList();
string[] builtins = [.. Directory.EnumerateDirectories("./src/BuiltinExtensions").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))];
string[] extras = Directory.Exists("./src/Extensions") ? [.. Directory.EnumerateDirectories("./src/Extensions/").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))] : [];
Expand Down Expand Up @@ -181,7 +221,7 @@ public async Task<Assembly> BuildExtension(string folder, string projFile)
if (File.Exists(target) && !Program.IsDevMode)
{
Logs.Debug($"Don't need to rebuild extension {projFile}, already built.");
return Assembly.LoadFile(Path.GetFullPath(target));
return LoadInExtensionContext(dllName, Path.GetFullPath(target));
}
Logs.Debug($"Building extension project: {projFile}...");
string buildParam = $"-p:BaseIntermediateOutputPath={Path.GetFullPath($"./src/obj/extensions/{dllName}/")};TargetName={dllName}-{hash}";
Expand All @@ -195,7 +235,7 @@ public async Task<Assembly> BuildExtension(string folder, string projFile)
{
Logs.Debug($"Successful build output for extension project {projFile}:\n{output}");
}
return Assembly.LoadFile(Path.GetFullPath(target));
return LoadInExtensionContext(dllName, Path.GetFullPath(target));
}

public void PrepExtension(Type extType, bool isCore, string[] possible)
Expand Down
Loading