diff --git a/src/memos/api/config.py b/src/memos/api/config.py index b90df51b2..c62cd3b08 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -381,7 +381,7 @@ def get_reranker_config() -> dict[str, Any]: "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, - "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), + "headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, @@ -407,6 +407,7 @@ def get_embedder_config() -> dict[str, Any]: "provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"), "api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"), "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), + "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), }, } diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 70095a194..d88b6005e 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig): embedding_dims: int | None = Field( default=None, description="Number of dimensions for the embedding" ) + headers_extra: dict[str, Any] | None = Field( + default=None, + description="Extra headers for the embedding model, only for universal_api backend", + ) class OllamaEmbedderConfig(BaseEmbedderConfig): diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 583a02acb..f39ffaa58 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -16,7 +16,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig): self.config = config if self.provider == "openai": - self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url) + self.client = OpenAIClient( + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.headers_extra if config.headers_extra else None, + ) elif self.provider == "azure": self.client = AzureClient( azure_endpoint=config.base_url, diff --git a/tests/configs/test_embedder.py b/tests/configs/test_embedder.py index 8201f9bd8..10572f33e 100644 --- a/tests/configs/test_embedder.py +++ b/tests/configs/test_embedder.py @@ -17,7 +17,7 @@ def test_base_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims"], + optional_fields=["embedding_dims", "headers_extra"], ) check_config_instantiation_valid( @@ -36,7 +36,7 @@ def test_ollama_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "api_base"], + optional_fields=["embedding_dims", "headers_extra", "api_base"], ) check_config_instantiation_valid( diff --git a/tests/embedders/test_universal_api.py b/tests/embedders/test_universal_api.py index e4ebb7019..fd61b3e9a 100644 --- a/tests/embedders/test_universal_api.py +++ b/tests/embedders/test_universal_api.py @@ -28,8 +28,7 @@ def test_embed_single_text(self, mock_openai_client): # Assert OpenAIClient was created with proper args mock_openai_client.assert_called_once_with( - api_key="fake-api-key", - base_url="https://api.openai.com/v1", + api_key="fake-api-key", base_url="https://api.openai.com/v1", default_headers=None ) # Assert embeddings.create called with correct params