1use super::basicfuncs;
4use crate::config;
5use crate::taskserver::Control;
6use anyhow::{Result, anyhow, bail};
7use log::{info, warn};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::future::Future;
11use std::ops::RangeBounds;
12use std::path::{Path, PathBuf};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17pub const FUNCTION_TOKEN: usize = 800;
21
22pub use super::Function;
26pub use super::ParameterElement;
27pub use super::Parameters;
28pub type FuncBodyAsync<'a> = Pin<Box<dyn Future<Output = Result<String>> + Sync + Send + 'a>>;
30pub type FuncBody<T> = dyn Fn(Arc<BasicContext>, T, &FuncArgs) -> FuncBodyAsync + Sync + Send;
34pub type FuncArgs = HashMap<String, serde_json::value::Value>;
36
37#[derive(Default, Debug, Clone, Serialize, Deserialize)]
40pub struct Args {
41 #[serde(flatten)]
42 args: FuncArgs,
43}
44
/// State shared (through an `Arc`) with every function body invoked via a
/// [`FunctionTable`].
pub struct BasicContext {
    /// Control handle supplied to `FunctionTable::new`.
    /// NOTE(review): presumably used by bodies to talk to the task server — confirm.
    pub ctrl: Control,
    /// `<openai.storage_dir>/<storage_dir_name>` when `FunctionTable::new` was
    /// given a name and the configured directory is non-empty; otherwise `None`.
    /// The directory is not created or checked for existence there.
    pub storage_dir: Option<PathBuf>,
    /// Debug flag, initialised to `false` and read with `SeqCst` by
    /// `FunctionTable::debug_mode`.
    /// NOTE(review): whoever sets it is not visible in this file.
    pub debug_mode: AtomicBool,
}
56
/// Registry of callable functions, generic over a per-call context type `T`
/// that is handed to each function body.
pub struct FunctionTable<T> {
    /// Function definitions in registration order (exposed by `function_list`).
    function_list: Vec<Function>,
    /// Function bodies keyed by function name; looked up by `call`.
    call_table: HashMap<String, Box<FuncBody<T>>>,
    /// Context whose `Arc` is cloned into every function call.
    basic_context: Arc<BasicContext>,
}
73
74impl<T: 'static> FunctionTable<T> {
75 pub fn new(ctrl: Control, storage_dir_name: Option<&str>) -> Self {
76 let storage_dir = if let Some(storage_dir_name) = storage_dir_name {
79 let dir = config::get(|c| c.openai.storage_dir.clone());
80 if !dir.is_empty() {
81 Some(Path::new(&dir).join(storage_dir_name))
82 } else {
83 None
84 }
85 } else {
86 None
87 };
88 let basic_context = BasicContext {
89 ctrl,
90 storage_dir,
91 debug_mode: AtomicBool::new(false),
92 };
93
94 Self {
95 function_list: Default::default(),
96 call_table: Default::default(),
97 basic_context: Arc::new(basic_context),
98 }
99 }
100
101 pub fn basic_context(&self) -> &BasicContext {
102 &self.basic_context
103 }
104
105 pub fn debug_mode(&self) -> bool {
106 self.basic_context.debug_mode.load(Ordering::SeqCst)
107 }
108
109 pub fn function_list(&self) -> &[Function] {
111 &self.function_list
112 }
113
114 pub fn create_help(&self) -> String {
116 let mut text = String::new();
117
118 let mut first = true;
119 for f in self.function_list.iter() {
120 if first {
121 first = false;
122 } else {
123 text.push('\n');
124 }
125
126 text.push_str(&f.name);
127
128 let mut params: Vec<_> = f.parameters.properties.keys().cloned().collect();
129 params.sort();
130 text.push_str(&format!("({})", params.join(", ")));
131
132 if let Some(desc) = &f.description {
133 text.push_str(&format!("\n {desc}"));
134 }
135 }
136
137 text
138 }
139
140 pub async fn call(&self, ctx: T, func_name: &str, args_json_str: &str) -> String {
145 info!("[openai-func] Call {func_name} {args_json_str}");
146
147 let res = {
148 let args = serde_json::from_str::<Args>(args_json_str)
149 .map_err(|err| anyhow!("Arguments parse error: {err}"));
150 match args {
151 Ok(args) => self.call_internal(ctx, func_name, &args.args).await,
152 Err(err) => Err(err),
153 }
154 };
155
156 match &res {
157 Ok(res) => {
158 info!("[openai-func] {func_name} returned: {res}");
159 res.to_string()
160 }
161 Err(err) => {
162 warn!("[openai-func] {func_name} failed: {err:#?}");
163 err.to_string()
164 }
165 }
166 }
167
168 async fn call_internal(&self, ctx: T, func_name: &str, args: &FuncArgs) -> Result<String> {
170 let func = self
171 .call_table
172 .get(func_name)
173 .ok_or_else(|| anyhow!("Error: Function {func_name} not found"))?;
174
175 let bctx = Arc::clone(&self.basic_context);
177 func(bctx, ctx, args)
178 .await
179 .map_err(|err| anyhow!("Error: {err}"))
180 }
181
182 pub fn register_function(
184 &mut self,
185 function: Function,
186 body: impl Fn(Arc<BasicContext>, T, &FuncArgs) -> FuncBodyAsync + Send + Sync + 'static,
187 ) {
188 let name = function.name.clone();
189 self.function_list.push(function);
190 self.call_table.insert(name, Box::new(body));
191 }
192
193 pub fn register_basic_functions(&mut self) {
195 basicfuncs::register_all(self);
196 }
197}
198
199pub fn get_arg_str<'a>(args: &'a FuncArgs, name: &str) -> Result<&'a str> {
202 let value = args.get(&name.to_string());
203 let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
204 let value = value
205 .as_str()
206 .ok_or_else(|| anyhow!("Error: Argument {name} must be string"))?;
207
208 Ok(value)
209}
210
211pub fn get_arg_bool(args: &FuncArgs, name: &str) -> Result<bool> {
215 let value = args.get(&name.to_string());
216 let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
217 let value = value
218 .as_bool()
219 .ok_or_else(|| anyhow!("Error: Argument {name} must be boolean"))?;
220
221 Ok(value)
222}
223
224pub fn get_arg_bool_opt(args: &FuncArgs, name: &str) -> Result<Option<bool>> {
229 if args.get(&name.to_string()).is_none() {
230 Ok(None)
231 } else {
232 get_arg_bool(args, name).map(Some)
233 }
234}
235
236pub fn get_arg_i64(args: &FuncArgs, name: &str, range: impl RangeBounds<i64>) -> Result<i64> {
240 let value = args.get(&name.to_string());
241 let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
242 let value = value
243 .as_i64()
244 .ok_or_else(|| anyhow!("Error: Argument {name} must be integer"))?;
245
246 if range.contains(&value) {
247 Ok(value)
248 } else {
249 bail!("Error: Out of range: {name}")
250 }
251}
252
253pub fn get_arg_i64_opt(
258 args: &FuncArgs,
259 name: &str,
260 range: impl RangeBounds<i64>,
261) -> Result<Option<i64>> {
262 if args.get(&name.to_string()).is_none() {
263 Ok(None)
264 } else {
265 get_arg_i64(args, name, range).map(Some)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every `get_arg_*` helper on present, mistyped, out-of-range
    /// and missing arguments.
    #[test]
    fn function_args() {
        let mut args = FuncArgs::new();
        args.insert("str".to_string(), "ok".into());
        args.insert("bool_f".to_string(), false.into());
        args.insert("bool_t".to_string(), true.into());
        args.insert("int".to_string(), 42.into());

        assert_eq!(get_arg_str(&args, "str").unwrap(), "ok");
        assert!(
            get_arg_str(&args, "bool_t")
                .unwrap_err()
                .to_string()
                .contains("must be string")
        );
        assert!(
            get_arg_str(&args, "not_found")
                .unwrap_err()
                .to_string()
                .contains("required")
        );

        assert!(!get_arg_bool(&args, "bool_f").unwrap());
        assert!(get_arg_bool(&args, "bool_t").unwrap());
        assert!(
            get_arg_bool(&args, "str")
                .unwrap_err()
                .to_string()
                .contains("must be boolean")
        );
        assert!(
            get_arg_bool(&args, "not_found")
                .unwrap_err()
                .to_string()
                .contains("required")
        );

        assert_eq!(get_arg_bool_opt(&args, "bool_f").unwrap(), Some(false));
        assert_eq!(get_arg_bool_opt(&args, "bool_t").unwrap(), Some(true));
        assert!(
            get_arg_bool_opt(&args, "str")
                .unwrap_err()
                .to_string()
                .contains("must be boolean")
        );
        assert_eq!(get_arg_bool_opt(&args, "not_found").unwrap(), None);

        assert_eq!(get_arg_i64(&args, "int", 1..=42).unwrap(), 42);
        assert!(
            get_arg_i64(&args, "str", 1..43)
                .unwrap_err()
                .to_string()
                .contains("must be integer")
        );
        assert!(
            get_arg_i64(&args, "int", 1..42)
                .unwrap_err()
                .to_string()
                .contains("Out of range")
        );
        assert!(
            get_arg_i64(&args, "int", 43..)
                .unwrap_err()
                .to_string()
                .contains("Out of range")
        );
        assert!(
            get_arg_i64(&args, "not_found", 1..42)
                .unwrap_err()
                .to_string()
                .contains("required")
        );

        assert_eq!(get_arg_i64_opt(&args, "int", 1..=42).unwrap(), Some(42));
        assert_eq!(get_arg_i64_opt(&args, "int", ..).unwrap(), Some(42));
        assert!(
            get_arg_i64_opt(&args, "str", 1..43)
                .unwrap_err()
                .to_string()
                .contains("must be integer")
        );
        assert!(
            get_arg_i64_opt(&args, "int", 1..42)
                .unwrap_err()
                .to_string()
                .contains("Out of range")
        );
        assert_eq!(get_arg_i64_opt(&args, "not_found", 1..42).unwrap(), None);
    }
}