From b9f61c30aa1a76af631df86e236513a6a711e2d5 Mon Sep 17 00:00:00 2001 From: YeongJae Min Date: Fri, 19 Jun 2026 10:14:41 +0900 Subject: [PATCH] Use BeanInstanceSupplier with matching arguments BeanFactory#getBean(..., args) previously skipped instance suppliers whenever explicit arguments were provided. In AOT/native mode, this could make a prototype bean with runtime constructor arguments use a different instantiation path and miss generated post-processing. Allow BeanInstanceSupplier to opt in when the explicit arguments exactly match the selected constructor or factory method, while keeping the regular bean factory path for other argument shapes. Closes gh-36649 Signed-off-by: YeongJae Min --- .../factory/aot/BeanInstanceSupplier.java | 64 +++++++- .../AbstractAutowireCapableBeanFactory.java | 37 ++++- .../support/DefaultListableBeanFactory.java | 11 ++ .../factory/support/InstanceSupplier.java | 40 +++++ .../aot/BeanInstanceSupplierTests.java | 80 +++++++++ .../support/BeanFactorySupplierTests.java | 154 ++++++++++++++++++ .../support/InstanceSupplierTests.java | 43 +++++ .../ApplicationContextAotGeneratorTests.java | 42 +++++ 8 files changed, 461 insertions(+), 10 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java index 48a1fcca45ce..20afadba61c5 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java @@ -189,23 +189,52 @@ public BeanInstanceSupplier withShortcut(String... beanNames) { @Override public T get(RegisteredBean registeredBean) { Assert.notNull(registeredBean, "'registeredBean' must not be null"); + return get(registeredBean, (AutowiredArguments) null); + } + + @Override + public T get(RegisteredBean registeredBean, @Nullable Object @Nullable ... args) { + Assert.notNull(registeredBean, "'registeredBean' must not be null"); + return get(registeredBean, (args != null ? AutowiredArguments.of(args) : null)); + } + + @Override + public boolean supportsExplicitArguments(@Nullable Object @Nullable ... args) { + return (this.generatorWithoutArguments == null && this.lookup.supportsArguments(args)); + } + + @SuppressWarnings("unchecked") + private T get(RegisteredBean registeredBean, @Nullable AutowiredArguments explicitArguments) { if (this.generatorWithoutArguments != null) { + Assert.isTrue(explicitArguments == null || explicitArguments.toArray().length == 0, + "Explicit arguments are not supported"); Executable executable = getFactoryMethodForGenerator(); return invokeBeanSupplier(executable, () -> this.generatorWithoutArguments.apply(registeredBean)); } else if (this.generatorWithArguments != null) { Executable executable = getFactoryMethodForGenerator(); - AutowiredArguments arguments = resolveArguments(registeredBean, - executable != null ? executable : this.lookup.get(registeredBean)); + Executable argumentsExecutable = (executable != null ? executable : this.lookup.get(registeredBean)); + AutowiredArguments arguments = (explicitArguments != null ? explicitArguments : + resolveArguments(registeredBean, argumentsExecutable)); + validateArgumentCount(argumentsExecutable, arguments); return invokeBeanSupplier(executable, () -> this.generatorWithArguments.apply(registeredBean, arguments)); } else { Executable executable = this.lookup.get(registeredBean); - @Nullable Object[] arguments = resolveArguments(registeredBean, executable).toArray(); + @Nullable Object[] arguments = (explicitArguments != null ? + explicitArguments.toArray() : resolveArguments(registeredBean, executable).toArray()); + validateArgumentCount(executable, AutowiredArguments.of(arguments)); return invokeBeanSupplier(executable, () -> (T) instantiate(registeredBean, executable, arguments)); } } + private void validateArgumentCount(Executable executable, AutowiredArguments arguments) { + int argumentCount = arguments.toArray().length; + int parameterCount = executable.getParameterCount(); + Assert.isTrue(argumentCount == parameterCount, + () -> "Incorrect number of arguments: expected " + parameterCount + " but got " + argumentCount); + } + @Override public @Nullable Method getFactoryMethod() { // Cached factory method retrieval for qualifier introspection etc. @@ -383,6 +412,25 @@ private static String toCommaSeparatedNames(Class... parameterTypes) { abstract static class ExecutableLookup { abstract Executable get(RegisteredBean registeredBean); + + abstract Class[] getParameterTypes(); + + boolean supportsArguments(@Nullable Object @Nullable [] args) { + if (args == null) { + return true; + } + Class[] parameterTypes = getParameterTypes(); + if (args.length != parameterTypes.length) { + return false; + } + for (int i = 0; i < args.length; i++) { + @Nullable Object arg = args[i]; + if (arg == null || ClassUtils.resolvePrimitiveIfNecessary(parameterTypes[i]) != arg.getClass()) { + return false; + } + } + return true; + } } @@ -409,6 +457,11 @@ public Executable get(RegisteredBean registeredBean) { } } + @Override + Class[] getParameterTypes() { + return this.parameterTypes; + } + @Override public String toString() { return "Constructor with parameter types [%s]".formatted(toCommaSeparatedNames(this.parameterTypes)); @@ -451,6 +504,11 @@ Method get() { return method; } + @Override + Class[] getParameterTypes() { + return this.parameterTypes; + } + @Override public String toString() { return "Factory method '%s' with parameter types [%s] declared on %s".formatted( diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java index 95bd2e43fa8c..a654d1b04650 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java @@ -1181,11 +1181,9 @@ protected BeanWrapper createBeanInstance(String beanName, RootBeanDefinition mbd "Bean class isn't public, and non-public access not allowed: " + beanClass.getName()); } - if (args == null) { - Supplier instanceSupplier = mbd.getInstanceSupplier(); - if (instanceSupplier != null) { - return obtainFromSupplier(instanceSupplier, beanName, mbd); - } + Supplier instanceSupplier = mbd.getInstanceSupplier(); + if (instanceSupplier != null && supportsInstanceSupplierWithArguments(instanceSupplier, args)) { + return obtainFromSupplier(instanceSupplier, beanName, mbd, args); } if (mbd.getFactoryMethodName() != null) { @@ -1229,19 +1227,28 @@ protected BeanWrapper createBeanInstance(String beanName, RootBeanDefinition mbd return instantiateBean(beanName, mbd); } + private boolean supportsInstanceSupplierWithArguments( + Supplier instanceSupplier, @Nullable Object @Nullable [] args) { + + return (args == null || (instanceSupplier instanceof InstanceSupplier supplier && + supplier.supportsExplicitArguments(args))); + } + /** * Obtain a bean instance from the given supplier. * @param supplier the configured supplier * @param beanName the corresponding bean name * @return a BeanWrapper for the new instance */ - private BeanWrapper obtainFromSupplier(Supplier supplier, String beanName, RootBeanDefinition mbd) { + private BeanWrapper obtainFromSupplier(Supplier supplier, String beanName, RootBeanDefinition mbd, + @Nullable Object @Nullable [] args) { + String outerBean = this.currentlyCreatedBean.get(); this.currentlyCreatedBean.set(beanName); Object instance; try { - instance = obtainInstanceFromSupplier(supplier, beanName, mbd); + instance = obtainInstanceFromSupplier(supplier, beanName, mbd, args); } catch (Throwable ex) { if (ex instanceof BeanCreationException bce && beanName.equals(bce.getBeanName())) { @@ -1283,6 +1290,22 @@ private BeanWrapper obtainFromSupplier(Supplier supplier, String beanName, Ro return supplier.get(); } + /** + * Obtain a bean instance from the given supplier. + * @param supplier the configured supplier + * @param beanName the corresponding bean name + * @param mbd the bean definition for the bean + * @param args explicit arguments passed in programmatically via the getBean method, + * or {@code null} if none + * @return the bean instance (possibly {@code null}) + * @since 7.1 + */ + protected @Nullable Object obtainInstanceFromSupplier(Supplier supplier, String beanName, RootBeanDefinition mbd, + @Nullable Object @Nullable [] args) throws Exception { + + return obtainInstanceFromSupplier(supplier, beanName, mbd); + } + /** * Overridden in order to implicitly register the currently created bean as * dependent on further beans getting programmatically retrieved during a diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java index a2a57bcb8191..66eea76d748e 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java @@ -1026,6 +1026,17 @@ protected boolean isBeanEligibleForMetadataCaching(String beanName) { return super.obtainInstanceFromSupplier(supplier, beanName, mbd); } + @Override + protected @Nullable Object obtainInstanceFromSupplier(Supplier supplier, String beanName, RootBeanDefinition mbd, + @Nullable Object @Nullable [] args) throws Exception { + + if (args != null && supplier instanceof InstanceSupplier instanceSupplier) { + RegisteredBean registeredBean = RegisteredBean.of(this, beanName, mbd); + return instanceSupplier.get(registeredBean, args); + } + return obtainInstanceFromSupplier(supplier, beanName, mbd); + } + @Override protected void cacheMergedBeanDefinition(RootBeanDefinition mbd, String beanName) { super.cacheMergedBeanDefinition(mbd, beanName); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java index b74220fb17a8..ba6e7fa7e56b 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java @@ -22,6 +22,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; import org.springframework.util.function.ThrowingBiFunction; import org.springframework.util.function.ThrowingSupplier; @@ -54,6 +55,37 @@ default T getWithException() { */ T get(RegisteredBean registeredBean) throws Exception; + /** + * Get the supplied instance using the specified explicit arguments. + * @param registeredBean the registered bean requesting the instance + * @param args the explicit arguments to use + * @return the supplied instance + * @throws Exception on error + * @since 7.1 + * @see #supportsExplicitArguments(Object...) + */ + default T get(RegisteredBean registeredBean, @Nullable Object @Nullable ... args) throws Exception { + if (!ObjectUtils.isEmpty(args)) { + throw new UnsupportedOperationException("Retrieval with arguments not supported - " + + "for custom InstanceSupplier classes, implement get(RegisteredBean, Object...) for your purposes"); + } + return get(registeredBean); + } + + /** + * Return whether this supplier supports the specified explicit arguments. + *

The bean factory calls this method before using + * {@link #get(RegisteredBean, Object...)} for explicit arguments. Custom + * implementations that override {@code get(RegisteredBean, Object...)} + * should return {@code true} for supported argument arrangements. + * @param args the explicit arguments to check + * @return {@code true} if this supplier supports explicit arguments + * @since 7.1 + */ + default boolean supportsExplicitArguments(@Nullable Object @Nullable ... args) { + return false; + } + /** * Return the factory method that this supplier uses to create the * instance, or {@code null} if it is not known or this supplier uses @@ -83,6 +115,14 @@ public V get(RegisteredBean registeredBean) throws Exception { return after.applyWithException(registeredBean, InstanceSupplier.this.get(registeredBean)); } @Override + public V get(RegisteredBean registeredBean, @Nullable Object @Nullable ... args) throws Exception { + return after.applyWithException(registeredBean, InstanceSupplier.this.get(registeredBean, args)); + } + @Override + public boolean supportsExplicitArguments(@Nullable Object @Nullable ... args) { + return InstanceSupplier.this.supportsExplicitArguments(args); + } + @Override public @Nullable Method getFactoryMethod() { return InstanceSupplier.this.getFactoryMethod(); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java index 560cf37054c2..5f85d1d05a68 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java @@ -216,6 +216,86 @@ void getWithGeneratorCallsBiFunction() { assertThat(((AutowiredArguments) result.get(0)).toArray()).containsExactly("1"); } + @Test + void getWithGeneratorAndExplicitArgumentsCallsBiFunction() { + BeanRegistrar registrar = new BeanRegistrar(SingleArgConstructor.class); + RegisteredBean registerBean = registrar.registerBean(this.beanFactory); + List result = new ArrayList<>(); + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class) + .withGenerator((registeredBean, args) -> result.add(args)); + resolver.get(registerBean, "test"); + assertThat(result).hasSize(1); + assertThat(((AutowiredArguments) result.get(0)).toArray()).containsExactly("test"); + } + + @Test + void getWithGeneratorAndIncorrectExplicitArgumentCountThrowsException() { + BeanRegistrar registrar = new BeanRegistrar(SingleArgConstructor.class); + RegisteredBean registerBean = registrar.registerBean(this.beanFactory); + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class) + .withGenerator((registeredBean, args) -> args.get(0)); + assertThatIllegalArgumentException() + .isThrownBy(() -> resolver.get(registerBean, "test", "extra")) + .withMessage("Incorrect number of arguments: expected 1 but got 2"); + } + + @Test + void getWithNoGeneratorAndExplicitArgumentsUsesReflection() { + BeanRegistrar registrar = new BeanRegistrar(SingleArgConstructor.class); + RegisteredBean registerBean = registrar.registerBean(this.beanFactory); + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class); + assertThat(resolver.get(registerBean, "test").getString()).isEqualTo("test"); + } + + @Test + void supportsExplicitArgumentsReturnsTrue() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(); + assertThat(resolver.supportsExplicitArguments()).isTrue(); + } + + @Test + void supportsExplicitArgumentsWhenArgumentCountMatchesReturnsTrue() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class); + assertThat(resolver.supportsExplicitArguments("test")).isTrue(); + } + + @Test + void supportsExplicitArgumentsWhenArgumentCountDoesNotMatchReturnsFalse() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class); + assertThat(resolver.supportsExplicitArguments("test", "extra")).isFalse(); + } + + @Test + void supportsExplicitArgumentsWhenArgumentTypeDoesNotMatchReturnsFalse() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class); + assertThat(resolver.supportsExplicitArguments(1)).isFalse(); + } + + @Test + void supportsExplicitArgumentsWhenArgumentTypeIsAssignableButNotExactReturnsFalse() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(Object.class); + assertThat(resolver.supportsExplicitArguments("test")).isFalse(); + } + + @Test + void supportsExplicitArgumentsWhenArgumentIsNullReturnsFalse() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(String.class); + assertThat(resolver.supportsExplicitArguments((Object) null)).isFalse(); + } + + @Test + void supportsExplicitArgumentsWhenPrimitiveParameterMatchesWrapperReturnsTrue() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor(int.class); + assertThat(resolver.supportsExplicitArguments(1)).isTrue(); + } + + @Test + void supportsExplicitArgumentsWhenUsingGeneratorWithoutArgumentsReturnsFalse() { + BeanInstanceSupplier resolver = BeanInstanceSupplier.forConstructor() + .withGenerator(registeredBean -> "test"); + assertThat(resolver.supportsExplicitArguments()).isFalse(); + } + @Test void getWithGeneratorCallsFunction() { BeanRegistrar registrar = new BeanRegistrar(SingleArgConstructor.class); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/support/BeanFactorySupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/support/BeanFactorySupplierTests.java index 22a0af8a752e..7b2b5bbaaec1 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/support/BeanFactorySupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/support/BeanFactorySupplierTests.java @@ -17,10 +17,13 @@ package org.springframework.beans.factory.support; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.aot.BeanInstanceSupplier; import org.springframework.util.function.ThrowingSupplier; import static org.assertj.core.api.Assertions.assertThat; @@ -64,6 +67,115 @@ void getBeanWhenUsingInstanceSupplier() { assertThat(beanFactory.getBean("test")).isEqualTo("I am bean test of class java.lang.String"); } + @Test + void getBeanWhenUsingInstanceSupplierUsesObtainInstanceFromSupplierExtensionPoint() { + AtomicBoolean called = new AtomicBoolean(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory() { + @Override + protected Object obtainInstanceFromSupplier(Supplier supplier, String beanName, RootBeanDefinition mbd) + throws Exception { + + called.set(true); + return super.obtainInstanceFromSupplier(supplier, beanName, mbd); + } + }; + RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); + beanDefinition.setInstanceSupplier(InstanceSupplier.of(registeredBean -> "I am supplied")); + beanFactory.registerBeanDefinition("test", beanDefinition); + assertThat(beanFactory.getBean("test")).isEqualTo("I am supplied"); + assertThat(called).isTrue(); + } + + @Test + void getBeanWithExplicitArgumentsWhenUsingSupportingInstanceSupplier() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(Object.class); + beanDefinition.setInstanceSupplier(new InstanceSupplier<>() { + @Override + public Object get(RegisteredBean registeredBean) { + return "I am supplied"; + } + @Override + public Object get(RegisteredBean registeredBean, Object... args) { + return args[0]; + } + @Override + public boolean supportsExplicitArguments(Object... args) { + return true; + } + }); + beanFactory.registerBeanDefinition("test", beanDefinition); + assertThat(beanFactory.getBean("test", "I am supplied with an argument")) + .isEqualTo("I am supplied with an argument"); + } + + @Test + void getBeanWithEmptyExplicitArgumentsWhenUsingSupportingInstanceSupplier() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(Object.class); + beanDefinition.setInstanceSupplier(new InstanceSupplier<>() { + @Override + public Object get(RegisteredBean registeredBean) { + return "I am supplied"; + } + @Override + public Object get(RegisteredBean registeredBean, Object... args) { + return args.length; + } + @Override + public boolean supportsExplicitArguments(Object... args) { + return true; + } + }); + beanFactory.registerBeanDefinition("test", beanDefinition); + assertThat(beanFactory.getBean("test", new Object[0])).isEqualTo(0); + } + + @Test + void getBeanWithEmptyExplicitArgumentsWhenUsingNonSupportingInstanceSupplier() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); + beanDefinition.setInstanceSupplier(InstanceSupplier.of(registeredBean -> "I am supplied")); + beanFactory.registerBeanDefinition("test", beanDefinition); + assertThat(beanFactory.getBean("test", new Object[0])).isEqualTo(""); + } + + @Test + void getBeanWithExplicitArgumentsWhenUsingUnsupportedBeanInstanceSupplierArgumentsUsesRegularInstantiation() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(MultiConstructorBean.class); + beanDefinition.setInstanceSupplier(BeanInstanceSupplier.forConstructor(String.class) + .withGenerator((registeredBean, args) -> new MultiConstructorBean(args.get(0), null))); + beanFactory.registerBeanDefinition("test", beanDefinition); + MultiConstructorBean bean = beanFactory.getBean(MultiConstructorBean.class, "test", 1); + assertThat(bean.name).isEqualTo("test"); + assertThat(bean.counter).isEqualTo(1); + } + + @Test + void getBeanWithExplicitArgumentsWhenUsingUnsupportedBeanInstanceSupplierArgumentTypesUsesRegularInstantiation() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(MultiConstructorBean.class); + beanDefinition.setInstanceSupplier(BeanInstanceSupplier.forConstructor(String.class) + .withGenerator((registeredBean, args) -> new MultiConstructorBean(args.get(0), null))); + beanFactory.registerBeanDefinition("test", beanDefinition); + MultiConstructorBean bean = beanFactory.getBean(MultiConstructorBean.class, 1); + assertThat(bean.name).isNull(); + assertThat(bean.counter).isEqualTo(1); + } + + @Test + void getBeanWithExplicitArgumentsWhenUsingAssignableBeanInstanceSupplierArgumentsUsesRegularInstantiation() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(OverloadedConstructorBean.class); + beanDefinition.setInstanceSupplier(BeanInstanceSupplier.forConstructor(Object.class) + .withGenerator((registeredBean, args) -> new OverloadedConstructorBean(args.get(0), "supplier"))); + beanFactory.registerBeanDefinition("test", beanDefinition); + OverloadedConstructorBean bean = beanFactory.getBean(OverloadedConstructorBean.class, "test"); + assertThat(bean.argument).isEqualTo("test"); + assertThat(bean.constructor).isEqualTo("string"); + } + @Test void getBeanWithInnerBeanUsingInstanceSupplier() { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @@ -124,4 +236,46 @@ void getBeanWhenUsingThrowableSupplierThatThrowsRuntimeException() { .withCauseInstanceOf(IllegalStateException.class); } + + static class MultiConstructorBean { + + final String name; + + final Integer counter; + + MultiConstructorBean(String name) { + this(name, null); + } + + MultiConstructorBean(Integer counter) { + this(null, counter); + } + + MultiConstructorBean(String name, Integer counter) { + this.name = name; + this.counter = counter; + } + } + + + static class OverloadedConstructorBean { + + final Object argument; + + final String constructor; + + OverloadedConstructorBean(Object argument) { + this(argument, "object"); + } + + OverloadedConstructorBean(String argument) { + this(argument, "string"); + } + + OverloadedConstructorBean(Object argument, String constructor) { + this.argument = argument; + this.constructor = constructor; + } + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java index bd1ddb133e07..b181f1b5f392 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java @@ -23,6 +23,7 @@ import org.springframework.util.function.ThrowingBiFunction; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; @@ -56,6 +57,26 @@ void getReturnsResult() throws Exception { assertThat(supplier.get(this.registeredBean)).isEqualTo("test"); } + @Test + void getWithEmptyExplicitArgumentsReturnsResult() throws Exception { + InstanceSupplier supplier = registeredBean -> "test"; + assertThat(supplier.get(this.registeredBean, new Object[0])).isEqualTo("test"); + } + + @Test + void getWithExplicitArgumentsWhenNotSupportedThrowsException() { + InstanceSupplier supplier = registeredBean -> "test"; + assertThatExceptionOfType(UnsupportedOperationException.class) + .isThrownBy(() -> supplier.get(this.registeredBean, "test")) + .withMessageContaining("Retrieval with arguments not supported"); + } + + @Test + void supportsExplicitArgumentsReturnsFalse() { + InstanceSupplier supplier = registeredBean -> "test"; + assertThat(supplier.supportsExplicitArguments()).isFalse(); + } + @Test void andThenWhenFunctionIsNullThrowsException() { InstanceSupplier supplier = registeredBean -> "test"; @@ -72,6 +93,28 @@ void andThenAppliesFunctionToObtainResult() throws Exception { assertThat(supplier.get(this.registeredBean)).isEqualTo("test-bean"); } + @Test + void andThenAppliesFunctionToObtainResultWithExplicitArguments() throws Exception { + InstanceSupplier supplier = new InstanceSupplier<>() { + @Override + public String get(RegisteredBean registeredBean) { + return "bean"; + } + @Override + public String get(RegisteredBean registeredBean, Object... args) { + return (String) args[0]; + } + @Override + public boolean supportsExplicitArguments(Object... args) { + return true; + } + }; + supplier = supplier.andThen( + (registeredBean, string) -> registeredBean.getBeanName() + "-" + string); + assertThat(supplier.supportsExplicitArguments()).isTrue(); + assertThat(supplier.get(this.registeredBean, "bean")).isEqualTo("test-bean"); + } + @Test void andThenWhenInstanceSupplierHasFactoryMethod() throws Exception { Method factoryMethod = getClass().getDeclaredMethod("andThenWhenInstanceSupplierHasFactoryMethod"); diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index 96afaa990ec6..b5f3d1929183 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -39,6 +39,8 @@ import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; import org.springframework.beans.factory.aot.AotProcessingException; import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; @@ -293,6 +295,25 @@ void processAheadOfTimeWhenHasAutowiringOnUnresolvedGeneric() { }); } + @Test // gh-36649 + void processAheadOfTimeWhenPrototypeBeanHasAutowiringAndRuntimeConstructorArguments() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + registerBeanPostProcessor(applicationContext, + AnnotationConfigUtils.AUTOWIRED_ANNOTATION_PROCESSOR_BEAN_NAME, AutowiredAnnotationBeanPostProcessor.class); + RootBeanDefinition beanDefinition = new RootBeanDefinition(AutowiredPrototypeComponent.class); + beanDefinition.setScope(BeanDefinition.SCOPE_PROTOTYPE); + applicationContext.registerBeanDefinition("autowiredPrototypeComponent", beanDefinition); + + testCompiledResult(applicationContext, (initializer, compiled) -> { + GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer); + Object argument = new Object(); + AutowiredPrototypeComponent bean = + freshApplicationContext.getBean(AutowiredPrototypeComponent.class, argument); + assertThat(bean.getConstructorArgument()).isSameAs(argument); + assertThat(bean.getBeanFactory()).isSameAs(freshApplicationContext.getBeanFactory()); + }); + } + @Test void processAheadOfTimeWhenHasLazyAutowiringOnField() { testAutowiredComponent(LazyAutowiredFieldComponent.class, (bean, generationContext) -> { @@ -873,4 +894,25 @@ public Object reimplement(Object obj, Method method, Object[] args) throws Throw } } + + public static class AutowiredPrototypeComponent { + + private final Object constructorArgument; + + @Autowired + public BeanFactory beanFactory; + + public AutowiredPrototypeComponent(Object constructorArgument) { + this.constructorArgument = constructorArgument; + } + + public Object getConstructorArgument() { + return this.constructorArgument; + } + + public BeanFactory getBeanFactory() { + return this.beanFactory; + } + } + }