|
6 | 6 | import org.junit.jupiter.api.Test; |
7 | 7 | import org.junit.jupiter.params.ParameterizedTest; |
8 | 8 | import org.junit.jupiter.params.provider.EnumSource; |
| 9 | +import org.slf4j.Logger; |
| 10 | +import org.slf4j.LoggerFactory; |
9 | 11 |
|
10 | 12 | import java.util.Map; |
11 | 13 | import java.util.concurrent.CompletableFuture; |
12 | 14 | import java.util.concurrent.atomic.AtomicBoolean; |
| 15 | +import java.util.concurrent.atomic.AtomicReference; |
13 | 16 |
|
14 | 17 | import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_4O; |
15 | 18 | import static dev.ai4j.openai4j.chat.ChatCompletionTest.*; |
|
22 | 25 | import static java.util.concurrent.Executors.newSingleThreadExecutor; |
23 | 26 | import static java.util.concurrent.TimeUnit.SECONDS; |
24 | 27 | import static org.assertj.core.api.Assertions.assertThat; |
| 28 | +import static org.junit.jupiter.api.Assertions.fail; |
25 | 29 | import static org.junit.jupiter.params.provider.EnumSource.Mode.EXCLUDE; |
26 | 30 | import static org.junit.jupiter.params.provider.EnumSource.Mode.INCLUDE; |
27 | 31 |
|
28 | 32 | class ChatCompletionStreamingTest extends RateLimitAwareTest { |
29 | 33 |
|
| 34 | + private static final Logger log = LoggerFactory.getLogger(ChatCompletionStreamingTest.class); |
| 35 | + |
30 | 36 | private final OpenAiClient client = OpenAiClient.builder() |
31 | 37 | .baseUrl(System.getenv("OPENAI_BASE_URL")) |
32 | 38 | .openAiApiKey(System.getenv("OPENAI_API_KEY")) |
@@ -106,7 +112,6 @@ void testCustomizableApi(ChatCompletionModel model) throws Exception { |
106 | 112 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = { |
107 | 113 | "GPT_3_5_TURBO_0125", // don't have access to it yet |
108 | 114 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models |
109 | | - "GPT_4_0314", // Does not support tools/functions |
110 | 115 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format |
111 | 116 | }) |
112 | 117 | void testTools(ChatCompletionModel model) throws Exception { |
@@ -224,7 +229,6 @@ void testTools(ChatCompletionModel model) throws Exception { |
224 | 229 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = { |
225 | 230 | "GPT_3_5_TURBO_0125", // don't have access to it yet |
226 | 231 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models |
227 | | - "GPT_4_0314", // Does not support tools/functions |
228 | 232 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format |
229 | 233 | }) |
230 | 234 | void testFunctions(ChatCompletionModel model) throws Exception { |
@@ -322,7 +326,6 @@ void testFunctions(ChatCompletionModel model) throws Exception { |
322 | 326 | "GPT_3_5_TURBO_0125", // don't have access to it yet |
323 | 327 | "GPT_4_TURBO_PREVIEW", // keeps returning "felsius" as temp unit |
324 | 328 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models |
325 | | - "GPT_4_0314", // Does not support tools/functions |
326 | 329 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format |
327 | 330 | }) |
328 | 331 | void testToolChoice(ChatCompletionModel model) throws Exception { |
@@ -440,7 +443,6 @@ void testToolChoice(ChatCompletionModel model) throws Exception { |
440 | 443 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = { |
441 | 444 | "GPT_3_5_TURBO_0125", // don't have access to it yet |
442 | 445 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models |
443 | | - "GPT_4_0314", // Does not support tools/functions |
444 | 446 | "GPT_4_VISION_PREVIEW" |
445 | 447 | }) |
446 | 448 | void testFunctionChoice(ChatCompletionModel model) throws Exception { |
@@ -769,45 +771,27 @@ void testCancelStreamingAfterStreamingStarted() throws Exception { |
769 | 771 | .logStreamingResponses() |
770 | 772 | .build(); |
771 | 773 |
|
772 | | - AtomicBoolean streamingStarted = new AtomicBoolean(false); |
773 | | - AtomicBoolean streamingCancelled = new AtomicBoolean(false); |
774 | | - AtomicBoolean cancellationSucceeded = new AtomicBoolean(true); |
| 774 | + final AtomicBoolean streamingCancelled = new AtomicBoolean(false); |
| 775 | + final AtomicReference<ResponseHandle> atomicReference = new AtomicReference<>(); |
| 776 | + final CompletableFuture<Void> completableFuture = new CompletableFuture<>(); |
775 | 777 |
|
776 | 778 | ResponseHandle responseHandle = client.chatCompletion("Write a poem about AI in 10 words") |
777 | 779 | .onPartialResponse(partialResponse -> { |
778 | | - streamingStarted.set(true); |
779 | | - System.out.println("[[streaming started]]"); |
780 | | - if (streamingCancelled.get()) { |
781 | | - cancellationSucceeded.set(false); |
782 | | - System.out.println("[[cancellation failed]]"); |
| 780 | + if (! streamingCancelled.getAndSet(true)) { |
| 781 | + log.info("Executor thread {}", Thread.currentThread()); |
| 782 | + atomicReference.get().cancel(); |
| 783 | + completableFuture.complete(null); |
783 | 784 | } |
784 | 785 | }) |
785 | | - .onComplete(() -> { |
786 | | - cancellationSucceeded.set(false); |
787 | | - System.out.println("[[cancellation failed]]"); |
788 | | - }) |
789 | | - .onError(e -> { |
790 | | - cancellationSucceeded.set(false); |
791 | | - System.out.println("[[cancellation failed]]"); |
792 | | - }) |
| 786 | + .onComplete(() -> fail("Response completed")) |
| 787 | + .onError(e -> fail("Response errored")) |
793 | 788 | .execute(); |
794 | 789 |
|
795 | | - while (!streamingStarted.get()) { |
796 | | - Thread.sleep(10); |
797 | | - } |
| 790 | + log.info("Test thread {}", Thread.currentThread()); |
| 791 | + atomicReference.set(responseHandle); |
| 792 | + completableFuture.get(); |
798 | 793 |
|
799 | | - newSingleThreadExecutor().execute(() -> { |
800 | | - responseHandle.cancel(); |
801 | | - streamingCancelled.set(true); |
802 | | - System.out.println("[[streaming cancelled]]"); |
803 | | - }); |
804 | | - |
805 | | - while (!streamingCancelled.get()) { |
806 | | - Thread.sleep(10); |
807 | | - } |
808 | | - Thread.sleep(2000); |
809 | | - |
810 | | - assertThat(cancellationSucceeded).isTrue(); |
| 794 | + assertThat(streamingCancelled).isTrue(); |
811 | 795 | } |
812 | 796 |
|
813 | 797 | @Test |
|
0 commit comments