@@ -2,27 +2,65 @@ package com.javaaidev.openai
22
33import com.openai.client.OpenAIClient
44import com.openai.models.EmbeddingCreateParams
5+ import org.springframework.ai.chat.metadata.EmptyUsage
56import org.springframework.ai.document.Document
6- import org.springframework.ai.embedding.AbstractEmbeddingModel
7- import org.springframework.ai.embedding.Embedding
8- import org.springframework.ai.embedding.EmbeddingRequest
9- import org.springframework.ai.embedding.EmbeddingResponse
7+ import org.springframework.ai.embedding.*
8+ import org.springframework.ai.model.ModelOptionsUtils
109
11- class OpenAIEmbeddingModel (private val openAIClient : OpenAIClient ) : AbstractEmbeddingModel() {
10+ class OpenAIEmbeddingModel (
11+ private val openAIClient : OpenAIClient ,
12+ private val defaultOptions : OpenAIEmbeddingOptions ? = null ,
13+ ) :
14+ AbstractEmbeddingModel () {
1215 override fun call (request : EmbeddingRequest ): EmbeddingResponse {
1316 val paramsBuilder = EmbeddingCreateParams .builder()
1417 .inputOfArrayOfStrings(request.instructions)
15- request.options.model?.let {
18+
19+ val options = mergeOptions(request.options)
20+
21+ options.model?.let {
1622 paramsBuilder.model(it)
1723 }
18- request. options.dimensions?.let {
24+ options.dimensions?.let {
1925 paramsBuilder.dimensions(it.toLong())
2026 }
27+ options.encodingFormat?.let {
28+ paramsBuilder.encodingFormat(EmbeddingCreateParams .EncodingFormat .of(it))
29+ }
30+ options.user?.let {
31+ paramsBuilder.user(it)
32+ }
33+
2134 val response = openAIClient.embeddings().create(paramsBuilder.build())
2235 val embeddings = response.data().map { e ->
2336 Embedding (e.embedding().map { v -> v.toFloat() }.toFloatArray(), e.index().toInt())
2437 }
25- return EmbeddingResponse (embeddings)
38+ return EmbeddingResponse (embeddings, EmbeddingResponseMetadata (response.model(), EmptyUsage ()))
39+ }
40+
41+ private fun mergeOptions (runtimeOptions : EmbeddingOptions ? ): OpenAIEmbeddingOptions {
42+ val defaultOptions = this .defaultOptions ? : OpenAIEmbeddingOptions .builder().build()
43+ return ModelOptionsUtils .copyToTarget(
44+ runtimeOptions, EmbeddingOptions ::class .java,
45+ OpenAIEmbeddingOptions ::class .java
46+ )?.let { options ->
47+ OpenAIEmbeddingOptions .builder()
48+ .model(ModelOptionsUtils .mergeOption(options.model, defaultOptions.model))
49+ .dimensions(
50+ ModelOptionsUtils .mergeOption(
51+ options.dimensions,
52+ defaultOptions.dimensions
53+ )
54+ )
55+ .encodingFormat(
56+ ModelOptionsUtils .mergeOption(
57+ options.encodingFormat,
58+ defaultOptions.encodingFormat
59+ )
60+ )
61+ .user(ModelOptionsUtils .mergeOption(options.user, defaultOptions.user))
62+ .build()
63+ } ? : defaultOptions
2664 }
2765
2866 override fun embed (document : Document ): FloatArray {
0 commit comments