diff --git a/src/Core/ExtensionsManager.cs b/src/Core/ExtensionsManager.cs index a2831daf6..2553dfa44 100644 --- a/src/Core/ExtensionsManager.cs +++ b/src/Core/ExtensionsManager.cs @@ -4,6 +4,7 @@ using SwarmUI.Utils; using System.IO; using System.Reflection; +using System.Runtime.Loader; namespace SwarmUI.Core; @@ -15,11 +16,40 @@ public class ExtensionsManager /// Hashset of folder names of all extensions currently loaded. public HashSet LoadedExtensionFolders = []; + /// Hashset of dependency names that are considered "core" and should not be loaded from the extension's folder. + public HashSet CoreDependencyNames = []; + /// Simple holder of information about extensions available online. public record class ExtensionInfo(string Name, string Author, string License, string Description, string URL, string OldURL, string[] Tags, string[] FolderNames) { } + private class SwarmExtensionLoadContext(ExtensionsManager manager, string name, string extensionDir) : AssemblyLoadContext(name, isCollectible: false) + { + /// Host wins, then we probe the extension's folder for private deps. + protected override Assembly Load(AssemblyName name) + { + string dependency = Path.Combine(extensionDir, name.Name + ".dll"); + try + { + Default.LoadFromAssemblyName(name); + // We only get here if host successfully loads the assembly. + if (File.Exists(dependency) && !manager.CoreDependencyNames.Contains(name.Name)) + { + Logs.Warning($"Extension {Name} ships {name.Name}.dll but host already has it loaded; using host copy."); + } + return null; + } + catch (FileNotFoundException) { } + if (!File.Exists(dependency)) + { + return null; + } + Logs.Debug($"Extension {Name} loading private dep {name.Name} from {dependency}"); + return LoadFromAssemblyPath(dependency); + } + } + public static HtmlString HtmlTags(string[] tags) { return new(tags.Select(t => @@ -45,6 +75,11 @@ public static HtmlString HtmlTags(string[] tags) /// List of known online available extensions. public List KnownExtensions = []; + private Assembly LoadInExtensionContext(string dllName, string targetPath) + { + return new SwarmExtensionLoadContext(this, dllName, Path.GetDirectoryName(targetPath)).LoadFromAssemblyPath(targetPath); + } + public static string ReferenceCsproj = """ @@ -58,6 +93,11 @@ public static HtmlString HtmlTags(string[] tags) /// Initial call that prepares the extensions list. public async Task PrepExtensions() { + CoreDependencyNames = + [ + .. AssemblyLoadContext.Default.Assemblies.Select(a => a.GetName().Name).Where(n => n is not null), + .. Directory.EnumerateFiles(AppContext.BaseDirectory, "*.dll").Select(Path.GetFileNameWithoutExtension) + ]; await BuildPublicExtensionList(); string[] builtins = [.. Directory.EnumerateDirectories("./src/BuiltinExtensions").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))]; string[] extras = Directory.Exists("./src/Extensions") ? [.. Directory.EnumerateDirectories("./src/Extensions/").Select(s => "src/" + s.Replace('\\', '/').AfterLast("/src/"))] : []; @@ -181,7 +221,7 @@ public async Task BuildExtension(string folder, string projFile) if (File.Exists(target) && !Program.IsDevMode) { Logs.Debug($"Don't need to rebuild extension {projFile}, already built."); - return Assembly.LoadFile(Path.GetFullPath(target)); + return LoadInExtensionContext(dllName, Path.GetFullPath(target)); } Logs.Debug($"Building extension project: {projFile}..."); string buildParam = $"-p:BaseIntermediateOutputPath={Path.GetFullPath($"./src/obj/extensions/{dllName}/")};TargetName={dllName}-{hash}"; @@ -195,7 +235,7 @@ public async Task BuildExtension(string folder, string projFile) { Logs.Debug($"Successful build output for extension project {projFile}:\n{output}"); } - return Assembly.LoadFile(Path.GetFullPath(target)); + return LoadInExtensionContext(dllName, Path.GetFullPath(target)); } public void PrepExtension(Type extType, bool isCore, string[] possible)