-
Notifications
You must be signed in to change notification settings - Fork 184
Expand file tree
/
Copy pathServiceCollectionChatClientExtensions.cs
More file actions
103 lines (91 loc) · 4.28 KB
/
ServiceCollectionChatClientExtensions.cs
File metadata and controls
103 lines (91 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
using Azure.AI.OpenAI;
using OllamaSharp;
using System.ClientModel;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using OpenAI;
using System.Data.Common;
using Microsoft.Extensions.Configuration;
namespace Microsoft.Extensions.Hosting;
/// <summary>
/// Extension methods that register an <see cref="IChatClient"/> backed by either Ollama or
/// OpenAI / Azure OpenAI in the dependency injection container.
/// </summary>
public static class ServiceCollectionChatClientExtensions
{
    /// <summary>
    /// Registers an Ollama-backed chat client whose endpoint is <c>http://{serviceName}</c>.
    /// </summary>
    /// <param name="hostBuilder">The host builder whose configuration and services are used.</param>
    /// <param name="serviceName">
    /// Used as the host name of the Ollama endpoint (presumably resolved via service discovery — confirm
    /// for your host) and as the configuration section that holds <c>LlmModelName</c>.
    /// </param>
    /// <param name="builder">Optional callback that adds middleware to the chat client pipeline.</param>
    /// <param name="modelName">
    /// The Ollama model to use. When <see langword="null"/>, it is read from <c>{serviceName}:LlmModelName</c>.
    /// </param>
    /// <returns>The service collection, for chaining.</returns>
    /// <exception cref="InvalidOperationException">
    /// No model name was passed and none was found in configuration.
    /// </exception>
    public static IServiceCollection AddOllamaChatClient(
        this IHostApplicationBuilder hostBuilder,
        string serviceName,
        Func<ChatClientBuilder, ChatClientBuilder>? builder = null,
        string? modelName = null)
    {
        if (modelName is null)
        {
            var configKey = $"{serviceName}:LlmModelName";
            modelName = hostBuilder.Configuration[configKey];
            if (string.IsNullOrEmpty(modelName))
            {
                throw new InvalidOperationException($"No {nameof(modelName)} was specified, and none could be found from configuration at '{configKey}'");
            }
        }

        return hostBuilder.Services.AddOllamaChatClient(
            modelName,
            new Uri($"http://{serviceName}"),
            builder);
    }

    /// <summary>
    /// Registers an Ollama-backed chat client that talks to <paramref name="uri"/>.
    /// </summary>
    /// <param name="services">The service collection to add the chat client to.</param>
    /// <param name="modelName">The Ollama model to use.</param>
    /// <param name="uri">
    /// The Ollama endpoint. Defaults to <c>http://localhost:11434</c>. It is only applied when the
    /// <see cref="HttpClient"/> in use does not already have a <see cref="HttpClient.BaseAddress"/>.
    /// </param>
    /// <param name="builder">Optional callback that adds middleware to the chat client pipeline.</param>
    /// <returns>The service collection, for chaining.</returns>
    public static IServiceCollection AddOllamaChatClient(
        this IServiceCollection services,
        string modelName,
        Uri? uri = null,
        Func<ChatClientBuilder, ChatClientBuilder>? builder = null)
    {
        uri ??= new Uri("http://localhost:11434");
        return services.AddChatClient(pipeline =>
        {
            builder?.Invoke(pipeline);

            // Temporary workaround for Ollama issues
            pipeline.UsePreventStreamingWithFunctions();

            // Prefer an HttpClient from the container (it may carry configured handlers); otherwise create one.
            var httpClient = pipeline.Services.GetService<HttpClient>() ?? new HttpClient();

            // The endpoint must reach the Ollama client through the HttpClient's BaseAddress. Previously
            // 'uri' was never applied, so the configured or default endpoint was silently ignored.
            // A client that already targets an address keeps it.
            // NOTE: HttpClient rejects property changes after its first request, so a shared,
            // already-used client without a BaseAddress would throw here.
            httpClient.BaseAddress ??= uri;

            return pipeline.Use(new OllamaApiClient(httpClient, modelName));
        });
    }

    /// <summary>
    /// Registers an OpenAI or Azure OpenAI chat client configured from the connection string named
    /// <paramref name="serviceName"/>.
    /// </summary>
    /// <remarks>
    /// The connection string is parsed for these keys:
    /// <list type="bullet">
    /// <item><description><c>Endpoint</c> (optional; when present, Azure OpenAI is used).</description></item>
    /// <item><description><c>Key</c> (required).</description></item>
    /// <item><description><c>Deployment</c> or <c>Model</c>, used when
    /// <paramref name="modelOrDeploymentName"/> is not supplied.</description></item>
    /// </list>
    /// </remarks>
    /// <param name="hostBuilder">The host builder whose configuration and services are used.</param>
    /// <param name="serviceName">Name of the connection string to read.</param>
    /// <param name="builder">Optional callback that adds middleware to the chat client pipeline.</param>
    /// <param name="modelOrDeploymentName">
    /// The model (OpenAI) or deployment (Azure OpenAI). Overrides the connection string values.
    /// </param>
    /// <returns>The service collection, for chaining.</returns>
    /// <exception cref="InvalidOperationException">
    /// The connection string is missing, has no <c>Key</c>, or no model/deployment could be determined.
    /// </exception>
    public static IServiceCollection AddOpenAIChatClient(
        this IHostApplicationBuilder hostBuilder,
        string serviceName,
        Func<ChatClientBuilder, ChatClientBuilder>? builder = null,
        string? modelOrDeploymentName = null)
    {
        // TODO: We would prefer to use Aspire.AI.OpenAI here, but it doesn't yet support the OpenAI v2 client.
        // So for now we access the connection string and set up a client manually.
        var connectionString = hostBuilder.Configuration.GetConnectionString(serviceName);
        if (string.IsNullOrWhiteSpace(connectionString))
        {
            throw new InvalidOperationException($"No connection string named '{serviceName}' was found. Ensure a corresponding Aspire service was registered.");
        }

        var connectionStringBuilder = new DbConnectionStringBuilder { ConnectionString = connectionString };

        // DbConnectionStringBuilder's indexer throws for absent keywords instead of returning null, so all
        // reads go through TryGetValue. Otherwise the optional Endpoint, the friendly missing-Key error,
        // and the Deployment -> Model fallback below could never take effect.
        var endpoint = GetConnectionStringValue(connectionStringBuilder, "endpoint");
        var apiKey = GetConnectionStringValue(connectionStringBuilder, "key")
            ?? throw new InvalidOperationException($"The connection string named '{serviceName}' does not specify a value for 'Key', but this is required.");
        modelOrDeploymentName ??= GetConnectionStringValue(connectionStringBuilder, "Deployment")
            ?? GetConnectionStringValue(connectionStringBuilder, "Model");
        if (string.IsNullOrWhiteSpace(modelOrDeploymentName))
        {
            throw new InvalidOperationException($"The connection string named '{serviceName}' does not specify a value for 'Deployment' or 'Model', and no value was passed for {nameof(modelOrDeploymentName)}.");
        }

        var endpointUri = endpoint is null ? null : new Uri(endpoint);
        return hostBuilder.Services.AddOpenAIChatClient(apiKey, modelOrDeploymentName, endpointUri, builder);
    }

    /// <summary>
    /// Registers an <see cref="OpenAIClient"/> singleton and a chat client built from it.
    /// </summary>
    /// <param name="services">The service collection to add the clients to.</param>
    /// <param name="apiKey">API key used to authenticate.</param>
    /// <param name="modelOrDeploymentName">The model (OpenAI) or deployment (Azure OpenAI) name.</param>
    /// <param name="endpoint">
    /// When <see langword="null"/>, a plain <see cref="OpenAIClient"/> is created. Otherwise an
    /// <see cref="AzureOpenAIClient"/> is created for this endpoint.
    /// </param>
    /// <param name="builder">Optional callback that adds middleware to the chat client pipeline.</param>
    /// <returns>The service collection, for chaining.</returns>
    public static IServiceCollection AddOpenAIChatClient(
        this IServiceCollection services,
        string apiKey,
        string modelOrDeploymentName,
        Uri? endpoint = null,
        Func<ChatClientBuilder, ChatClientBuilder>? builder = null)
    {
        return services
            .AddSingleton<OpenAIClient>(_ => endpoint is null
                ? new OpenAIClient(apiKey)
                : new AzureOpenAIClient(endpoint, new ApiKeyCredential(apiKey)))
            .AddChatClient(pipeline =>
            {
                builder?.Invoke(pipeline);
                var openAiClient = pipeline.Services.GetRequiredService<OpenAIClient>();
                return pipeline.Use(openAiClient.AsChatClient(modelOrDeploymentName));
            });
    }

    /// <summary>
    /// Reads <paramref name="keyword"/> from a parsed connection string without throwing.
    /// </summary>
    /// <param name="connectionString">The parsed connection string.</param>
    /// <param name="keyword">The keyword to look up.</param>
    /// <returns>
    /// The string value, or <see langword="null"/> when the keyword is absent or its value is empty or whitespace.
    /// </returns>
    private static string? GetConnectionStringValue(DbConnectionStringBuilder connectionString, string keyword)
        => connectionString.TryGetValue(keyword, out var value)
            && value is string text
            && !string.IsNullOrWhiteSpace(text)
                ? text
                : null;
}