-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy path complex_function.rs
More file actions
131 lines (113 loc) · 4.1 KB
/
complex_function.rs
File metadata and controls
131 lines (113 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::{collections::VecDeque, env};
use gemini_rust::{
Content, FunctionCall, FunctionCallingMode, FunctionDeclaration, FunctionResponse, Gemini,
Part, Role, ThinkingConfig,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
// Command payload of a tool call: the program name plus its argument list.
// It is nested inside `RootCommander`, whose schema is handed to the model.
// The serde field names and the `///` doc comments below feed the derived
// JSON schema (schemars uses doc comments as descriptions), so editing them
// changes what the model is told.
#[derive(Deserialize, Serialize, Debug, JsonSchema)]
struct Command {
    /// The command to run
    command: String,
    /// The command arguments
    arguments: Vec<String>,
}
// Top-level argument object of the `execute_command` tool. Its schema is
// registered through `FunctionDeclaration::with_parameters`, and `main`
// deserializes each model function call's `args` back into this type.
// NOTE(review): the field is named `attempt` but described as "The current
// step number" — confirm the intended meaning. Renaming or rewording it
// changes the schema the model sees.
#[derive(Deserialize, Serialize, Debug, JsonSchema)]
struct RootCommander {
    /// The current step number (starts at 1)
    attempt: i64,
    /// The command to use
    command: Command,
}
// Result reported back to the model for each `execute_command` call. Its
// schema is declared through `FunctionDeclaration::with_response`, and
// values are sent via `FunctionResponse::from_schema`. Field docs below
// become schema descriptions, so treat them as part of the API contract.
#[derive(Deserialize, Serialize, Debug, JsonSchema)]
struct StatusResponse {
    /// The status of the operation
    status: bool,
    /// Additional details about the operation
    detail: String,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let api_key = env::var("GEMINI_API_KEY")?;
let client = Gemini::pro(api_key).expect("unable to create Gemini API client");
let commander_tool = FunctionDeclaration::new(
"execute_command",
"Execute a system command with parameters",
None,
)
.with_parameters::<RootCommander>()
.with_response::<StatusResponse>();
info!("Sending function response...");
let response = client
.generate_content()
.with_thinking_config(ThinkingConfig::dynamic_thinking())
.with_temperature(0.1)
.with_top_p(0.95)
.with_function(commander_tool.clone())
.with_function_calling_mode(FunctionCallingMode::Any)
.with_user_message(
"I need you to run a system command 'bleep' with parameters 'boop' and 'bop'.",
)
.execute()
.await?;
let contents = response
.candidates
.into_iter()
.map(|c| c.content)
.collect::<Vec<_>>();
let mut function_queue = VecDeque::<FunctionCall>::new();
for content in &contents {
if let Some(parts) = &content.parts {
for part in parts {
if let Part::FunctionCall { function_call, .. } = part {
function_queue.push_front(function_call.clone());
}
if let Part::FunctionResponse { function_response } = part {
if let Some(last_call) = function_queue.pop_front() {
if last_call.name != function_response.name {
warn!(
"Warning: Function response name '{}' does not match last function call name '{}'",
function_response.name, last_call.name
);
}
} else {
warn!(
"Warning: Function response name '{}' has no matching function call",
function_response.name
);
}
}
}
}
}
let mut reply = client.generate_content();
reply.contents.extend(contents);
for function_call in function_queue {
info!(
"Function call received: {} with args:\n{}",
function_call.name,
serde_json::to_string_pretty(&function_call.args)?
);
let result = serde_json::from_value::<RootCommander>(function_call.args.clone())?;
// Simulate command execution
let Command { command, arguments } = result.command;
let status = StatusResponse {
status: true,
detail: format!(
"Command '{command}' executed successfully with arguments: {arguments:?}"
),
};
let content = Content::function_response(FunctionResponse::from_schema(
function_call.name.clone(),
status,
)?)
.with_role(Role::User);
reply.contents.push(content);
}
info!("Sending function response...",);
let final_response = reply.execute().await?;
info!("Final response from model: {}", final_response.text(),);
Ok(())
}