From f0df76bd2f3ab5d09814ad77d07d3c2a01dab14e Mon Sep 17 00:00:00 2001 From: Ned Wolpert Date: Mon, 9 Jun 2025 07:43:29 -0700 Subject: [PATCH 1/2] Added tests and updated dispatcher to make it easier to create new callback execution patterns --- .../smr/dispatcher/BaseDispatcher.java | 106 ++++++++++++++++++ .../smr/dispatcher/SynchronousDispatcher.java | 91 +-------------- .../dispatcher/SynchronousDispatcherTest.java | 33 +++++- 3 files changed, 144 insertions(+), 86 deletions(-) create mode 100644 smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java diff --git a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java new file mode 100644 index 0000000..0799d02 --- /dev/null +++ b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java @@ -0,0 +1,106 @@ +package com.codeheadsystems.smr.dispatcher; + +import com.codeheadsystems.smr.Callback; +import com.codeheadsystems.smr.Context; +import com.codeheadsystems.smr.Dispatcher; +import com.codeheadsystems.smr.ImmutableCallback; +import com.codeheadsystems.smr.Phase; +import com.codeheadsystems.smr.State; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class BaseDispatcher implements Dispatcher { + protected static final Logger log = LoggerFactory.getLogger(SynchronousDispatcher.class); + protected final Map>[]> callbackMap; + + public BaseDispatcher(final Set states) { + log.info("BaseDispatcher()"); + this.callbackMap = states.stream() + .collect(HashMap::new, (map, state) -> map.put(state, buildList()), HashMap::putAll); + } + + @Override + public void enable(final State state, + final Phase phase, + final Consumer contextConsumer) { + log.trace("enable({}, {}, {})", state, phase, contextConsumer); + callbackMap.get(state)[phase.ordinal()].add(contextConsumer); + } + + @Override + public void disable(final State state, + final Phase phase, + final Consumer contextConsumer) { + log.trace("disable({}, {}, {})", state, phase, contextConsumer); + callbackMap.get(state)[phase.ordinal()].remove(contextConsumer); + } + + /** + * TODO: This method needs to be handled with care. Need to consider if we want to 1) back out of events if + * things failed, 2) keep it simple but incomplete, 3) allow for various implementations. (Most likely). + * + * @param context that has state being changed. + * @param currentState the from state. + * @param newState the too state. + */ + @Override + public void handleTransitionEvent(final Context context, + final State currentState, + final State newState) { + log.trace("handleTransitionEvent({}, {}, {})", context, currentState, newState); + dispatchCallbacks(context, currentState, Phase.EXIT); + final State previousState = changeState(context, currentState, newState); + if (!previousState.equals(currentState)) { + log.warn("handleTransitionEvent:state: {} != {}", previousState, currentState); + } + dispatchCallbacks(context, newState, Phase.ENTER); + } + + @Override + public State changeState(final Context context, final State currentState, final State newState) { + return context.reference().getAndSet(newState); + } + + @Override + public void dispatchCallbacks(final Context context, + final State currentState, + final Phase phase) { + log.trace("dispatchCallbacks({}, {}, {})", context, currentState, phase); + final Set>[] callbacks = callbackMap.get(currentState); + final Set> phasedCallbacks = callbacks[phase.ordinal()]; + if (phasedCallbacks.isEmpty()) { + return; + } + final Callback callback = ImmutableCallback.builder() + .context(context) + .state(currentState) + .phase(phase) + .build(); + executeCallbacks(phasedCallbacks, callback); + } + + // Implement this for any logic you need. + abstract protected void executeCallbacks(Set> phasedCallbacks, final Callback callback); + + @Override + public void executeCallback(final Consumer consumer, + final Callback callback) { + try { + consumer.accept(callback); + } catch (RuntimeException e) { + log.error("dispatchCallbacks:error: {}", consumer, e); + } + } + + @SuppressWarnings("unchecked") + protected Set>[] buildList() { + return Arrays.stream(Phase.values()) + .map(event -> new HashSet>()).toArray(Set[]::new); + } +} diff --git a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcher.java b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcher.java index f2071d4..d650742 100644 --- a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcher.java +++ b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcher.java @@ -1,29 +1,15 @@ package com.codeheadsystems.smr.dispatcher; import com.codeheadsystems.smr.Callback; -import com.codeheadsystems.smr.Context; -import com.codeheadsystems.smr.Dispatcher; -import com.codeheadsystems.smr.ImmutableCallback; import com.codeheadsystems.smr.Phase; import com.codeheadsystems.smr.State; -import com.codeheadsystems.smr.StateMachineDefinition; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * The type Synchronous dispatcher. */ -public class SynchronousDispatcher implements Dispatcher { - - private static final Logger log = LoggerFactory.getLogger(SynchronousDispatcher.class); - - private final Map>[]> callbackMap; +public class SynchronousDispatcher extends BaseDispatcher { /** * Instantiates a new Synchronous dispatcher. @@ -31,82 +17,17 @@ public class SynchronousDispatcher implements Dispatcher { * @param states the states */ public SynchronousDispatcher(final Set states) { + super(states); log.info("SynchronousDispatcher()"); - this.callbackMap = states.stream() - .collect(HashMap::new, (map, state) -> map.put(state, buildList()), HashMap::putAll); - } - - @Override - public void enable(final State state, - final Phase phase, - final Consumer contextConsumer) { - log.trace("enable({}, {}, {})", state, phase, contextConsumer); - callbackMap.get(state)[phase.ordinal()].add(contextConsumer); - } - - @Override - public void disable(final State state, - final Phase phase, - final Consumer contextConsumer) { - log.trace("disable({}, {}, {})", state, phase, contextConsumer); - callbackMap.get(state)[phase.ordinal()].remove(contextConsumer); - } - - /** - * TODO: This method needs to be handled with care. Need to consider if we want to 1) back out of events if - * things failed, 2) keep it simple but incomplete, 3) allow for various implementations. (Most likely). - * - * @param context that has state being changed. - * @param currentState the from state. - * @param newState the too state. - */ - @Override - public void handleTransitionEvent(final Context context, - final State currentState, - final State newState) { - log.trace("handleTransitionEvent({}, {}, {})", context, currentState, newState); - dispatchCallbacks(context, currentState, Phase.EXIT); - final State previousState = changeState(context, currentState, newState); - if (!previousState.equals(currentState)) { - log.warn("handleTransitionEvent:state: {} != {}", previousState, currentState); - } - dispatchCallbacks(context, newState, Phase.ENTER); - } - - @Override - public State changeState(final Context context, final State currentState, final State newState) { - return context.reference().getAndSet(newState); } + // A synchronized execution of the callbacks. Basic. @Override - public void dispatchCallbacks(final Context context, - final State currentState, - final Phase phase) { - log.trace("dispatchCallbacks({}, {}, {})", context, currentState, phase); - final Set>[] callbacks = callbackMap.get(currentState); - final Callback callback = ImmutableCallback.builder() - .context(context) - .state(currentState) - .phase(phase) - .build(); - callbacks[phase.ordinal()].forEach(consumer -> { + protected void executeCallbacks(final Set> phasedCallbacks, + final Callback callback) { + phasedCallbacks.forEach(consumer -> { executeCallback(consumer, callback); }); } - @Override - public void executeCallback(final Consumer consumer, final Callback callback) { - try { - consumer.accept(callback); - } catch (RuntimeException e) { - log.error("dispatchCallbacks:error: {}", consumer, e); - } - } - - @SuppressWarnings("unchecked") - private Set>[] buildList() { - return Arrays.stream(Phase.values()) - .map(event -> new HashSet>()).toArray(Set[]::new); - } - } diff --git a/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java b/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java index 3999b1c..3eab7c7 100644 --- a/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java +++ b/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java @@ -2,6 +2,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -28,6 +29,7 @@ class SynchronousDispatcherTest extends TestBase { @Mock private Consumer consumer; + @Mock private Consumer secondConsumer; @Mock private Context context; @Captor private ArgumentCaptor callback; @@ -54,6 +56,36 @@ void dispatchCallbacks_noException() { assertThat(callback.getValue().phase()).isEqualTo(Phase.TICK); } + /** + * Dispatch callbacks no exception. + */ + @Test + void dispatchCallbacks_disable() { + dispatcher.enable(ONE, Phase.TICK, consumer); + dispatcher.disable(ONE, Phase.TICK, consumer); + dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); + verify(consumer, never()).accept(callback.capture()); + } + + @Test + void dispatchCallbacks_multipleConsumers() { + dispatcher.enable(ONE, Phase.TICK, consumer); + dispatcher.enable(ONE, Phase.TICK, secondConsumer); + dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); + verify(consumer).accept(callback.capture()); + verify(secondConsumer).accept(callback.capture()); + } + + @Test + void dispatchCallbacks_multipleConsumersOneDisabled() { + dispatcher.enable(ONE, Phase.TICK, consumer); + dispatcher.enable(ONE, Phase.TICK, secondConsumer); + dispatcher.disable(ONE, Phase.TICK, consumer); + dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); + verify(consumer, never()).accept(callback.capture()); + verify(secondConsumer).accept(callback.capture()); + } + /** * Dispatch callbacks with exception. */ @@ -103,5 +135,4 @@ void handleTransitionEvent_noFailWhenContextHasWrongeState() { assertThat(callback.getAllValues().get(1).phase()).isEqualTo(Phase.ENTER); } - } \ No newline at end of file From 632e661de65cfb36761d6bb765a010875d4ffc67 Mon Sep 17 00:00:00 2001 From: Ned Wolpert Date: Mon, 9 Jun 2025 19:13:01 -0700 Subject: [PATCH 2/2] Added asynch dispatcher and turned the context into a domain object --- .../java/com/codeheadsystems/smr/Context.java | 31 ++++---- .../com/codeheadsystems/smr/Dispatcher.java | 15 +--- .../com/codeheadsystems/smr/StateMachine.java | 11 +-- .../dispatcher/AsynchronousDispatcher.java | 75 +++++++++++++++++++ .../smr/dispatcher/BaseDispatcher.java | 4 +- .../com/codeheadsystems/smr/TestBase.java | 2 +- ...ispatcherTest.java => DispatcherTest.java} | 64 +++++++++------- 7 files changed, 137 insertions(+), 65 deletions(-) create mode 100644 smr/src/main/java/com/codeheadsystems/smr/dispatcher/AsynchronousDispatcher.java rename smr/src/test/java/com/codeheadsystems/smr/dispatcher/{SynchronousDispatcherTest.java => DispatcherTest.java} (74%) diff --git a/smr/src/main/java/com/codeheadsystems/smr/Context.java b/smr/src/main/java/com/codeheadsystems/smr/Context.java index 1a19c4e..9e2414a 100644 --- a/smr/src/main/java/com/codeheadsystems/smr/Context.java +++ b/smr/src/main/java/com/codeheadsystems/smr/Context.java @@ -1,30 +1,28 @@ package com.codeheadsystems.smr; +import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; /** * You can have many contexts for a single state machine. And the * state machine manages this one context. */ -@FunctionalInterface public interface Context { /** - * Reference atomic reference. + * State state. * - * @return the atomic reference + * @return the state */ - AtomicReference reference(); + State state(); /** - * State state. + * Sets the current state to the new state, returning the existing state, if any. * - * @return the state + * @param newState for the context. + * @return the old state */ - default State state() { - return reference().get(); - } - + Optional setState(State newState); /** * You can extend this to generate your own context easily enough. */ @@ -33,7 +31,7 @@ abstract class Impl implements Context { /** * The State. */ - protected final AtomicReference state; + private final AtomicReference reference; /** * Instantiates a new . @@ -41,12 +39,17 @@ abstract class Impl implements Context { * @param initialState the initial state */ public Impl(State initialState) { - this.state = new AtomicReference<>(initialState); + this.reference = new AtomicReference<>(initialState); + } + + @Override + public State state() { + return reference.get(); } @Override - public AtomicReference reference() { - return state; + public Optional setState(State newState) { + return Optional.ofNullable(reference.getAndSet(newState)); } } diff --git a/smr/src/main/java/com/codeheadsystems/smr/Dispatcher.java b/smr/src/main/java/com/codeheadsystems/smr/Dispatcher.java index e9f5473..81290fd 100644 --- a/smr/src/main/java/com/codeheadsystems/smr/Dispatcher.java +++ b/smr/src/main/java/com/codeheadsystems/smr/Dispatcher.java @@ -29,19 +29,6 @@ void disable(State state, Phase phase, Consumer contextConsumer); - /** - * Wrapper to use the event unaware method in case you want to make decisions based on the - * event. This is not recommended as it breaks the state machine pattern, but you do you. - * - * @param context that holds onto the current state. - * @param currentState expected current state. - * @param newState new state to call. - * @param event that caused the transition. - */ - default void handleTransitionEvent(Context context, State currentState, State newState, Event event) { - handleTransitionEvent(context, currentState, newState); - } - /** * Macro method that handles the full state change and callback execution. * @@ -51,6 +38,8 @@ default void handleTransitionEvent(Context context, State currentState, State ne */ void handleTransitionEvent(Context context, State currentState, State newState); + // ---- Below here exist for decorators ---- + /** * Does the state change. * diff --git a/smr/src/main/java/com/codeheadsystems/smr/StateMachine.java b/smr/src/main/java/com/codeheadsystems/smr/StateMachine.java index 3b4424e..fd85cca 100644 --- a/smr/src/main/java/com/codeheadsystems/smr/StateMachine.java +++ b/smr/src/main/java/com/codeheadsystems/smr/StateMachine.java @@ -50,15 +50,6 @@ public static StateMachine.Builder builder() { return new StateMachine.Builder(); } - /** - * The current state of the state machine. - * - * @return the current state. - */ - public State state() { - return state.get(); - } - /** * Get the states that are valid for the current state machine. * @@ -133,7 +124,7 @@ public State dispatch(final Event event) { final Optional optionalNewState = definition.forEvent(currentState, event); if (optionalNewState.isPresent()) { final State newState = optionalNewState.get(); - dispatcher.handleTransitionEvent(this, currentState, newState, event); + dispatcher.handleTransitionEvent(this, currentState, newState); return newState; } else { log.warn("No transition for event {} from state {}", event, currentState); diff --git a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/AsynchronousDispatcher.java b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/AsynchronousDispatcher.java new file mode 100644 index 0000000..ace2512 --- /dev/null +++ b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/AsynchronousDispatcher.java @@ -0,0 +1,75 @@ +package com.codeheadsystems.smr.dispatcher; + +import com.codeheadsystems.smr.Callback; +import com.codeheadsystems.smr.State; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The type asynchronous dispatcher. + */ +public class AsynchronousDispatcher extends BaseDispatcher { + + protected static final Logger log = LoggerFactory.getLogger(AsynchronousDispatcher.class); + private final Executor executor; + + public static Builder builder() { + return new Builder(); + } + + /** + * Instantiates a new Synchronous dispatcher. + * + * @param states the states + * @param executor + */ + private AsynchronousDispatcher(final Set states, final Executor executor) { + super(states); + this.executor = executor; + log.info("SynchronousDispatcher()"); + } + + // A synchronized execution of the callbacks. Basic. + @Override + protected void executeCallbacks(final Set> phasedCallbacks, + final Callback callback) { + phasedCallbacks.stream() + .map(callbackConsumer -> + CompletableFuture.runAsync(() -> executeCallback(callbackConsumer, callback), executor)) + .forEach(CompletableFuture::join); + } + + public static class Builder { + + private Executor executor; + private Set states; + + public Builder withExecutor(Executor executor) { + this.executor = executor; + return this; + } + + public Builder withStates(Set states) { + this.states = states; + return this; + } + + public AsynchronousDispatcher build() { + if (states == null) { + throw new IllegalArgumentException("states must not be null"); + } + Executor localExecutor = executor; + if (localExecutor == null) { + localExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + } + return new AsynchronousDispatcher(states, localExecutor); + } + + } + +} diff --git a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java index 0799d02..1efe2dd 100644 --- a/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java +++ b/smr/src/main/java/com/codeheadsystems/smr/dispatcher/BaseDispatcher.java @@ -16,7 +16,7 @@ import org.slf4j.LoggerFactory; public abstract class BaseDispatcher implements Dispatcher { - protected static final Logger log = LoggerFactory.getLogger(SynchronousDispatcher.class); + protected static final Logger log = LoggerFactory.getLogger(BaseDispatcher.class); protected final Map>[]> callbackMap; public BaseDispatcher(final Set states) { @@ -64,7 +64,7 @@ public void handleTransitionEvent(final Context context, @Override public State changeState(final Context context, final State currentState, final State newState) { - return context.reference().getAndSet(newState); + return context.setState(newState).orElse(null); } @Override diff --git a/smr/src/test/java/com/codeheadsystems/smr/TestBase.java b/smr/src/test/java/com/codeheadsystems/smr/TestBase.java index 754fdb1..491910b 100644 --- a/smr/src/test/java/com/codeheadsystems/smr/TestBase.java +++ b/smr/src/test/java/com/codeheadsystems/smr/TestBase.java @@ -38,7 +38,7 @@ public class TestBase { /** * The State machine definition. */ - protected StateMachineDefinition stateMachineDefinition = StateMachineDefinition.builder() + public static StateMachineDefinition stateMachineDefinition = StateMachineDefinition.builder() .addState(ONE).addState(TWO).addState(THREE) .setInitialState(ONE) .addTransition(ONE, TO_TWO, TWO) diff --git a/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java b/smr/src/test/java/com/codeheadsystems/smr/dispatcher/DispatcherTest.java similarity index 74% rename from smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java rename to smr/src/test/java/com/codeheadsystems/smr/dispatcher/DispatcherTest.java index 3eab7c7..cd2400d 100644 --- a/smr/src/test/java/com/codeheadsystems/smr/dispatcher/SynchronousDispatcherTest.java +++ b/smr/src/test/java/com/codeheadsystems/smr/dispatcher/DispatcherTest.java @@ -11,12 +11,16 @@ import com.codeheadsystems.smr.Context; import com.codeheadsystems.smr.Dispatcher; import com.codeheadsystems.smr.Phase; +import com.codeheadsystems.smr.State; import com.codeheadsystems.smr.TestBase; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -26,28 +30,32 @@ * The type Synchronous dispatcher test. */ @ExtendWith(MockitoExtension.class) -class SynchronousDispatcherTest extends TestBase { +class DispatcherTest extends TestBase { @Mock private Consumer consumer; @Mock private Consumer secondConsumer; - @Mock private Context context; + private Context context; @Captor private ArgumentCaptor callback; - private Dispatcher dispatcher; - - /** - * Sets up. - */ @BeforeEach void setUp() { - dispatcher = new SynchronousDispatcher(stateMachineDefinition.states()); + context = new Context.Impl(ONE) { + }; + } + + static Stream dispatchers() { + return Stream.of( + Arguments.of(new SynchronousDispatcher(stateMachineDefinition.states())), + Arguments.of(AsynchronousDispatcher.builder().withStates(stateMachineDefinition.states()).build()) + ); } /** * Dispatch callbacks no exception. */ - @Test - void dispatchCallbacks_noException() { + @ParameterizedTest + @MethodSource("dispatchers") + void dispatchCallbacks_noException(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.TICK, consumer); dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); verify(consumer).accept(callback.capture()); @@ -59,16 +67,18 @@ void dispatchCallbacks_noException() { /** * Dispatch callbacks no exception. */ - @Test - void dispatchCallbacks_disable() { + @ParameterizedTest + @MethodSource("dispatchers") + void dispatchCallbacks_disable(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.TICK, consumer); dispatcher.disable(ONE, Phase.TICK, consumer); dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); verify(consumer, never()).accept(callback.capture()); } - @Test - void dispatchCallbacks_multipleConsumers() { + @ParameterizedTest + @MethodSource("dispatchers") + void dispatchCallbacks_multipleConsumers(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.TICK, consumer); dispatcher.enable(ONE, Phase.TICK, secondConsumer); dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); @@ -76,8 +86,9 @@ void dispatchCallbacks_multipleConsumers() { verify(secondConsumer).accept(callback.capture()); } - @Test - void dispatchCallbacks_multipleConsumersOneDisabled() { + @ParameterizedTest + @MethodSource("dispatchers") + void dispatchCallbacks_multipleConsumersOneDisabled(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.TICK, consumer); dispatcher.enable(ONE, Phase.TICK, secondConsumer); dispatcher.disable(ONE, Phase.TICK, consumer); @@ -89,8 +100,9 @@ void dispatchCallbacks_multipleConsumersOneDisabled() { /** * Dispatch callbacks with exception. */ - @Test - void dispatchCallbacks_withException() { + @ParameterizedTest + @MethodSource("dispatchers") + void dispatchCallbacks_withException(Dispatcher dispatcher) { doThrow(new RuntimeException("test")).when(consumer).accept(callback.capture()); dispatcher.enable(ONE, Phase.TICK, consumer); dispatcher.dispatchCallbacks(context, ONE, Phase.TICK); @@ -102,11 +114,12 @@ void dispatchCallbacks_withException() { /** * Handle transition event. */ - @Test - void handleTransitionEvent() { + @ParameterizedTest + @MethodSource("dispatchers") + void handleTransitionEvent(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.EXIT, consumer); dispatcher.enable(TWO, Phase.ENTER, consumer); - when(context.reference()).thenReturn(new AtomicReference<>(ONE)); + context.setState(ONE); dispatcher.handleTransitionEvent(context, ONE, TWO); verify(consumer, times(2)).accept(callback.capture()); assertThat(callback.getAllValues().get(0).context()).isEqualTo(context); @@ -120,11 +133,12 @@ void handleTransitionEvent() { /** * Handle transition event no fail when context has wronge state. */ - @Test - void handleTransitionEvent_noFailWhenContextHasWrongeState() { + @ParameterizedTest + @MethodSource("dispatchers") + void handleTransitionEvent_noFailWhenContextHasWrongeState(Dispatcher dispatcher) { dispatcher.enable(ONE, Phase.EXIT, consumer); dispatcher.enable(TWO, Phase.ENTER, consumer); - when(context.reference()).thenReturn(new AtomicReference<>(THREE)); + context.setState(THREE); dispatcher.handleTransitionEvent(context, ONE, TWO); verify(consumer, times(2)).accept(callback.capture()); assertThat(callback.getAllValues().get(0).context()).isEqualTo(context);