Skip to content

Commit 733d573

Browse files
neilruaro-cambeRuaro
authored and committed
fix: align CAMB AI endpoint responses with SDK expectations
Tested against the real camb-sdk Python package. Fixes:
- list-voices returns flat array (SDK expects list, not wrapped object)
- text-to-sound returns task_id JSON (SDK expects OrchestratorPipelineCallResult)
- translated-tts returns task_id JSON (SDK expects CreateTranslatedTtsOut)
- translation/stream returns JSON (SDK parses response as JSON)
- transcribe accepts media_url form field without requiring file upload
1 parent ac4e329 commit 733d573

6 files changed

Lines changed: 96 additions & 116 deletions

File tree

core/http/endpoints/cambai/sound_generation.go

Lines changed: 0 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,13 @@ package cambai
22

33
import (
44
"net/http"
5-
"path/filepath"
65

76
"github.com/google/uuid"
87
"github.com/labstack/echo/v4"
98
"github.com/mudler/LocalAI/core/backend"
109
"github.com/mudler/LocalAI/core/config"
1110
"github.com/mudler/LocalAI/core/http/middleware"
1211
"github.com/mudler/LocalAI/core/schema"
13-
"github.com/mudler/LocalAI/pkg/audio"
1412
"github.com/mudler/LocalAI/pkg/model"
1513
"github.com/mudler/xlog"
1614
)
@@ -30,46 +28,6 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
3028

3129
xlog.Debug("CAMB AI text-to-sound request received", "model", input.Model)
3230

33-
filePath, _, err := backend.SoundGeneration(
34-
input.Prompt, input.Duration, nil, nil,
35-
nil, nil,
36-
nil, "", "", nil, "",
37-
"", "",
38-
nil,
39-
ml, appConfig, *cfg)
40-
if err != nil {
41-
return err
42-
}
43-
44-
filePath, contentType := audio.NormalizeAudioFile(filePath)
45-
46-
taskID := uuid.New().String()
47-
48-
// Return audio file directly with task metadata headers
49-
c.Response().Header().Set("X-Task-ID", taskID)
50-
c.Response().Header().Set("X-Task-Status", "SUCCESS")
51-
if contentType != "" {
52-
c.Response().Header().Set("Content-Type", contentType)
53-
}
54-
return c.Attachment(filePath, filepath.Base(filePath))
55-
}
56-
}
57-
58-
// SoundGenerationAsyncEndpoint returns results in CAMB AI async task format.
59-
func SoundGenerationAsyncEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
60-
return func(c echo.Context) error {
61-
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.CambAITextToSoundRequest)
62-
if !ok {
63-
return echo.ErrBadRequest
64-
}
65-
66-
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
67-
if !ok || cfg == nil {
68-
return echo.ErrBadRequest
69-
}
70-
71-
xlog.Debug("CAMB AI text-to-sound async request received", "model", input.Model)
72-
7331
_, _, err := backend.SoundGeneration(
7432
input.Prompt, input.Duration, nil, nil,
7533
nil, nil,

core/http/endpoints/cambai/transcription.go

Lines changed: 60 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"path"
88
"path/filepath"
9+
"sync"
910

1011
"github.com/google/uuid"
1112
"github.com/labstack/echo/v4"
@@ -17,8 +18,11 @@ import (
1718
"github.com/mudler/xlog"
1819
)
1920

21+
var transcriptionTaskResults = sync.Map{}
22+
2023
// TranscriptionEndpoint handles CAMB AI transcription (POST /apis/transcribe).
21-
// Runs synchronously but returns results in CAMB AI's async task format.
24+
// The SDK sends multipart form with optional file upload and/or media_url.
25+
// Returns {"task_id": "..."} matching OrchestratorPipelineCallResult.
2226
func TranscriptionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
2327
return func(c echo.Context) error {
2428
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
@@ -32,54 +36,79 @@ func TranscriptionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
3236
if input != nil && input.LanguageID > 0 {
3337
language = schema.CambAILanguageCodeFromID(input.LanguageID)
3438
}
35-
36-
file, err := c.FormFile("file")
37-
if err != nil {
38-
return c.JSON(http.StatusBadRequest, schema.CambAIErrorResponse{
39-
Detail: "Audio file is required. Upload as multipart form field 'file'.",
40-
})
39+
// SDK sends language as multipart form field too
40+
if language == "" {
41+
if langField := c.FormValue("language"); langField != "" {
42+
language = langField
43+
}
4144
}
4245

43-
f, err := file.Open()
44-
if err != nil {
45-
return err
46-
}
47-
defer f.Close()
46+
// Try file upload first (field "file" or "media_file")
47+
var audioPath string
48+
for _, fieldName := range []string{"file", "media_file"} {
49+
file, err := c.FormFile(fieldName)
50+
if err != nil {
51+
continue
52+
}
4853

49-
dir, err := os.MkdirTemp("", "cambai-transcribe")
50-
if err != nil {
51-
return err
54+
f, err := file.Open()
55+
if err != nil {
56+
return err
57+
}
58+
defer f.Close()
59+
60+
dir, err := os.MkdirTemp("", "cambai-transcribe")
61+
if err != nil {
62+
return err
63+
}
64+
defer os.RemoveAll(dir)
65+
66+
dst := filepath.Join(dir, path.Base(file.Filename))
67+
dstFile, err := os.Create(dst)
68+
if err != nil {
69+
return err
70+
}
71+
72+
if _, err := io.Copy(dstFile, f); err != nil {
73+
dstFile.Close()
74+
return err
75+
}
76+
dstFile.Close()
77+
audioPath = dst
78+
break
5279
}
53-
defer os.RemoveAll(dir)
5480

55-
dst := filepath.Join(dir, path.Base(file.Filename))
56-
dstFile, err := os.Create(dst)
57-
if err != nil {
58-
return err
81+
// Fall back to media_url form field
82+
if audioPath == "" {
83+
mediaURL := c.FormValue("media_url")
84+
if mediaURL == "" {
85+
mediaURL = c.FormValue("audio_url")
86+
}
87+
if mediaURL != "" {
88+
audioPath = mediaURL
89+
}
5990
}
6091

61-
if _, err := io.Copy(dstFile, f); err != nil {
62-
xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
63-
return err
92+
if audioPath == "" {
93+
return c.JSON(http.StatusBadRequest, schema.CambAIErrorResponse{
94+
Detail: "Either a file upload or media_url is required.",
95+
})
6496
}
65-
dstFile.Close()
6697

67-
xlog.Debug("CAMB AI transcription request", "file", dst, "language", language)
98+
xlog.Debug("CAMB AI transcription request", "path", audioPath, "language", language)
6899

69-
tr, err := backend.ModelTranscription(dst, language, false, false, "", ml, *cfg, appConfig)
100+
tr, err := backend.ModelTranscription(audioPath, language, false, false, "", ml, *cfg, appConfig)
70101
if err != nil {
71102
return err
72103
}
73104

74105
taskID := uuid.New().String()
106+
transcriptionTaskResults.Store(taskID, tr.Text)
75107

76-
return c.JSON(http.StatusOK, schema.CambAITaskStatusResponse{
108+
return c.JSON(http.StatusOK, schema.CambAITaskResponse{
109+
TaskID: taskID,
77110
Status: "SUCCESS",
78111
RunID: taskID,
79-
Output: schema.CambAITranscriptionResponse{
80-
Text: tr.Text,
81-
Language: language,
82-
},
83112
})
84113
}
85114
}

core/http/endpoints/cambai/translation.go

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -95,31 +95,24 @@ func TranslationStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoad
9595
targetLang := schema.CambAILanguageCodeFromID(input.TargetLanguageID)
9696
prompt := buildTranslationPrompt(input.Text, sourceLang, targetLang)
9797

98-
c.Response().Header().Set("Content-Type", "text/plain; charset=utf-8")
99-
c.Response().Header().Set("Transfer-Encoding", "chunked")
100-
c.Response().Header().Set("Cache-Control", "no-cache")
101-
c.Response().Header().Set("Connection", "keep-alive")
102-
10398
fn, err := backend.ModelInference(
10499
context.Background(), prompt, nil, nil, nil, nil,
105-
ml, cfg, cl, appConfig,
106-
func(token string, _ backend.TokenUsage) bool {
107-
_, writeErr := c.Response().Write([]byte(token))
108-
if writeErr != nil {
109-
return true
110-
}
111-
c.Response().Flush()
112-
return true
113-
},
114-
"", "", nil, nil, nil,
100+
ml, cfg, cl, appConfig, nil, "", "", nil, nil, nil,
115101
)
116102
if err != nil {
117103
return err
118104
}
119105

120-
// Call fn to complete inference
121-
_, err = fn()
122-
return err
106+
resp, err := fn()
107+
if err != nil {
108+
return err
109+
}
110+
111+
return c.JSON(http.StatusOK, map[string]any{
112+
"translation": strings.TrimSpace(resp.Response),
113+
"source_language": input.SourceLanguageID,
114+
"target_language": input.TargetLanguageID,
115+
})
123116
}
124117
}
125118

@@ -178,14 +171,12 @@ func TranslatedTTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
178171
}
179172

180173
taskID := uuid.New().String()
174+
ttsTaskResults.Store(taskID, filePath)
181175

182-
return c.JSON(http.StatusOK, schema.CambAITaskStatusResponse{
176+
return c.JSON(http.StatusOK, schema.CambAITaskResponse{
177+
TaskID: taskID,
183178
Status: "SUCCESS",
184179
RunID: taskID,
185-
Output: map[string]string{
186-
"translation": translatedText,
187-
"audio_path": filePath,
188-
},
189180
})
190181
}
191182
}

core/http/endpoints/cambai/voice.go

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,18 +22,16 @@ func ListVoicesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
2222
voices := make([]schema.CambAIVoice, 0)
2323
for i, cfg := range ttsConfigs {
2424
voice := schema.CambAIVoice{
25-
VoiceID: i + 1,
26-
Name: cfg.Name,
25+
ID: i + 1,
26+
Name: cfg.Name,
2727
}
2828
if cfg.Voice != "" {
2929
voice.Name = fmt.Sprintf("%s (%s)", cfg.Name, cfg.Voice)
3030
}
3131
voices = append(voices, voice)
3232
}
3333

34-
return c.JSON(http.StatusOK, schema.CambAIListVoicesResponse{
35-
Voices: voices,
36-
})
34+
return c.JSON(http.StatusOK, voices)
3735
}
3836
}
3937

@@ -85,9 +83,8 @@ func CreateCustomVoiceEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoad
8583

8684
xlog.Info("Custom voice audio saved", "name", voiceName, "path", dstPath)
8785

88-
return c.JSON(http.StatusOK, schema.CambAIVoice{
86+
return c.JSON(http.StatusOK, schema.CambAICreateCustomVoiceResponse{
8987
VoiceID: 0,
90-
Name: voiceName,
9188
})
9289
}
9390
}

core/schema/cambai.go

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -147,16 +147,20 @@ type CambAITaskStatusResponse struct {
147147
}
148148

149149
type CambAIVoice struct {
150-
VoiceID int `json:"voice_id"`
151-
Name string `json:"voice_name"`
152-
Gender string `json:"gender,omitempty"`
153-
Age string `json:"age,omitempty"`
150+
ID int `json:"id"`
151+
Name string `json:"voice_name"`
152+
Gender string `json:"gender,omitempty"`
153+
Age string `json:"age,omitempty"`
154154
}
155155

156156
type CambAIListVoicesResponse struct {
157157
Voices []CambAIVoice `json:"voices"`
158158
}
159159

160+
type CambAICreateCustomVoiceResponse struct {
161+
VoiceID int `json:"voice_id"`
162+
}
163+
160164
type CambAIErrorResponse struct {
161165
Detail string `json:"detail"`
162166
}

tests/e2e/cambai_test.go

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ var _ = Describe("CAMB AI API Compatibility Tests", Label("CambAI"), func() {
160160
Expect(result.Output).ToNot(BeNil())
161161
})
162162

163-
It("should stream translation via /apis/translation/stream", func() {
163+
It("should translate via /apis/translation/stream", func() {
164164
body := `{
165165
"text": "Hello world",
166166
"source_language": 1,
@@ -176,9 +176,10 @@ var _ = Describe("CAMB AI API Compatibility Tests", Label("CambAI"), func() {
176176

177177
Expect(resp.StatusCode).To(Equal(200))
178178

179-
data, err := io.ReadAll(resp.Body)
179+
var result map[string]any
180+
err = json.NewDecoder(resp.Body).Decode(&result)
180181
Expect(err).ToNot(HaveOccurred())
181-
Expect(len(data)).To(BeNumerically(">", 0), "Stream should return some text")
182+
Expect(result["translation"]).ToNot(BeEmpty())
182183
})
183184
})
184185

@@ -196,11 +197,12 @@ var _ = Describe("CAMB AI API Compatibility Tests", Label("CambAI"), func() {
196197
defer resp.Body.Close()
197198

198199
Expect(resp.StatusCode).To(Equal(200))
199-
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"))
200200

201-
data, err := io.ReadAll(resp.Body)
201+
var taskResp schema.CambAITaskResponse
202+
err = json.NewDecoder(resp.Body).Decode(&taskResp)
202203
Expect(err).ToNot(HaveOccurred())
203-
Expect(len(data)).To(BeNumerically(">", 0))
204+
Expect(taskResp.TaskID).ToNot(BeEmpty())
205+
Expect(taskResp.Status).To(Equal("SUCCESS"))
204206
})
205207
})
206208

@@ -215,11 +217,10 @@ var _ = Describe("CAMB AI API Compatibility Tests", Label("CambAI"), func() {
215217

216218
Expect(resp.StatusCode).To(Equal(200))
217219

218-
var result schema.CambAIListVoicesResponse
220+
var result []schema.CambAIVoice
219221
err = json.NewDecoder(resp.Body).Decode(&result)
220222
Expect(err).ToNot(HaveOccurred())
221-
// voices list may be empty if no TTS models are flagged, but the endpoint should work
222-
Expect(result.Voices).ToNot(BeNil())
223+
Expect(result).ToNot(BeNil())
223224
})
224225
})
225226

0 commit comments

Comments
 (0)