diff --git a/src/Core/ExtensionsManager.cs b/src/Core/ExtensionsManager.cs
index a2831daf6..2553dfa44 100644
--- a/src/Core/ExtensionsManager.cs
+++ b/src/Core/ExtensionsManager.cs
@@ -4,6 +4,7 @@
using SwarmUI.Utils;
using System.IO;
using System.Reflection;
+using System.Runtime.Loader;
namespace SwarmUI.Core;
@@ -15,11 +16,40 @@ public class ExtensionsManager
/// Hashset of folder names of all extensions currently loaded.
public HashSet LoadedExtensionFolders = [];
+ /// Hashset of dependency names that are considered "core" and should not be loaded from the extension's folder.
+ public HashSet CoreDependencyNames = [];
+
/// Simple holder of information about extensions available online.
public record class ExtensionInfo(string Name, string Author, string License, string Description, string URL, string OldURL, string[] Tags, string[] FolderNames)
{
}
+ private class SwarmExtensionLoadContext(ExtensionsManager manager, string name, string extensionDir) : AssemblyLoadContext(name, isCollectible: false)
+ {
+ /// Host wins, then we probe the extension's folder for private deps.
+ protected override Assembly Load(AssemblyName name)
+ {
+ string dependency = Path.Combine(extensionDir, name.Name + ".dll");
+ try
+ {
+ Default.LoadFromAssemblyName(name);
+ // We only get here if host successfully loads the assembly.
+ if (File.Exists(dependency) && !manager.CoreDependencyNames.Contains(name.Name))
+ {
+ Logs.Warning($"Extension {Name} ships {name.Name}.dll but host already has it loaded; using host copy.");
+ }
+ return null;
+ }
+ catch (FileNotFoundException) { }
+ if (!File.Exists(dependency))
+ {
+ return null;
+ }
+ Logs.Debug($"Extension {Name} loading private dep {name.Name} from {dependency}");
+ return LoadFromAssemblyPath(dependency);
+ }
+ }
+
public static HtmlString HtmlTags(string[] tags)
{
return new(tags.Select(t =>
@@ -45,6 +75,11 @@ public static HtmlString HtmlTags(string[] tags)
/// List of known online available extensions.
public List KnownExtensions = [];
+ private Assembly LoadInExtensionContext(string dllName, string targetPath)
+ {
+ return new SwarmExtensionLoadContext(this, dllName, Path.GetDirectoryName(targetPath)).LoadFromAssemblyPath(targetPath);
+ }
+
public static string ReferenceCsproj =
"""
@@ -58,6 +93,11 @@ public static HtmlString HtmlTags(string[] tags)
/// Initial call that prepares the extensions list.
public async Task PrepExtensions()
{
+ CoreDependencyNames =
+ [
+ .. AssemblyLoadContext.Default.Assemblies.Select(a => a.GetName().Name).Where(n => n is not null),
+ .. Directory.EnumerateFiles(AppContext.BaseDirectory, "*.dll").Select(Path.GetFileNameWithoutExtension)
+ ];
await BuildPublicExtensionList();
string[] builtins = [.. Directory.EnumerateDirectories("./src/BuiltinExtensions").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))];
string[] extras = Directory.Exists("./src/Extensions") ? [.. Directory.EnumerateDirectories("./src/Extensions/").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))] : [];
@@ -181,7 +221,7 @@ public async Task BuildExtension(string folder, string projFile)
if (File.Exists(target) && !Program.IsDevMode)
{
Logs.Debug($"Don't need to rebuild extension {projFile}, already built.");
- return Assembly.LoadFile(Path.GetFullPath(target));
+ return LoadInExtensionContext(dllName, Path.GetFullPath(target));
}
Logs.Debug($"Building extension project: {projFile}...");
string buildParam = $"-p:BaseIntermediateOutputPath={Path.GetFullPath($"./src/obj/extensions/{dllName}/")};TargetName={dllName}-{hash}";
@@ -195,7 +235,7 @@ public async Task BuildExtension(string folder, string projFile)
{
Logs.Debug($"Successful build output for extension project {projFile}:\n{output}");
}
- return Assembly.LoadFile(Path.GetFullPath(target));
+ return LoadInExtensionContext(dllName, Path.GetFullPath(target));
}
public void PrepExtension(Type extType, bool isCore, string[] possible)