sys/sysmod/openai/
function.rs

1//! OpenAI API - function.
2
3use super::basicfuncs;
4use crate::config;
5use crate::taskserver::Control;
6use anyhow::{Result, anyhow, bail};
7use log::{info, warn};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::future::Future;
11use std::ops::RangeBounds;
12use std::path::{Path, PathBuf};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17/// Function でもトークンを消費するが、算出方法がよく分からないので定数で確保する。
18/// トークン制限エラーが起きた場合、エラーメッセージ中に含まれていた気がするので
19/// それより大きめに確保する。
20pub const FUNCTION_TOKEN: usize = 800;
21
22// https://users.rust-lang.org/t/how-to-handle-a-vector-of-async-function-pointers/39804
23
24/// OpenAI API json 定義の再エクスポート。
25pub use super::Function;
26pub use super::ParameterElement;
27pub use super::Parameters;
28/// sync fn で、async fn に引数を引き渡して呼び出しその Future を返す関数型。
29pub type FuncBodyAsync<'a> = Pin<Box<dyn Future<Output = Result<String>> + Sync + Send + 'a>>;
30/// 関数の Rust 上での定義。
31///
32/// 引数は [BasicContext], T, [FuncArgs] で、返り値は文字列の async fn。
33pub type FuncBody<T> = dyn Fn(Arc<BasicContext>, T, &FuncArgs) -> FuncBodyAsync + Sync + Send;
34/// 引数。文字列から Json value へのマップ。
35pub type FuncArgs = HashMap<String, serde_json::value::Value>;
36
37/// 引数は JSON ソース文字列で与えられる。
38/// デシリアライズでパースするための構造体。
39#[derive(Default, Debug, Clone, Serialize, Deserialize)]
40pub struct Args {
41    #[serde(flatten)]
42    args: FuncArgs,
43}
44
45/// 標準で関数に提供されるコンテキスト情報。
46pub struct BasicContext {
47    /// システムハンドル。
48    pub ctrl: Control,
49    /// 永続化データストレージの場所。
50    pub storage_dir: Option<PathBuf>,
51    /// デバッグモード。
52    /// 標準関数から変更されるが、自動でトレースは行われない。
53    /// 関数呼び出し側で制御が必要。
54    pub debug_mode: AtomicBool,
55}
56
57/// OpenAI function の管理テーブル。
58///
59/// [BasicContext] は標準で関数に渡されるコンテキスト情報で、
60/// コンストラクタで初期化され、[Self] はそれへの参照を保持する。
61///
62/// *T* は追加のコンテキスト情報の型。
63/// 標準以外の関数を追加する場合に使用可能。
64/// [Self::call] に渡したのものがそのまま関数に引き渡される。
65pub struct FunctionTable<T> {
66    /// OpenAI API に渡すためのリスト。
67    function_list: Vec<Function>,
68    /// 関数名から Rust 関数へのマップ。
69    call_table: HashMap<String, Box<FuncBody<T>>>,
70    /// [BasicContext] への参照。
71    basic_context: Arc<BasicContext>,
72}
73
74impl<T: 'static> FunctionTable<T> {
75    pub fn new(ctrl: Control, storage_dir_name: Option<&str>) -> Self {
76        // openai config でディレクトリが指定されており、かつ、
77        // この関数にストレージディレクトリ名が指定されている場合、Some
78        let storage_dir = if let Some(storage_dir_name) = storage_dir_name {
79            let dir = config::get(|c| c.openai.storage_dir.clone());
80            if !dir.is_empty() {
81                Some(Path::new(&dir).join(storage_dir_name))
82            } else {
83                None
84            }
85        } else {
86            None
87        };
88        let basic_context = BasicContext {
89            ctrl,
90            storage_dir,
91            debug_mode: AtomicBool::new(false),
92        };
93
94        Self {
95            function_list: Default::default(),
96            call_table: Default::default(),
97            basic_context: Arc::new(basic_context),
98        }
99    }
100
101    pub fn basic_context(&self) -> &BasicContext {
102        &self.basic_context
103    }
104
105    pub fn debug_mode(&self) -> bool {
106        self.basic_context.debug_mode.load(Ordering::SeqCst)
107    }
108
109    /// OpenAI API に渡すためのリストを取得する。
110    pub fn function_list(&self) -> &[Function] {
111        &self.function_list
112    }
113
114    /// 関数一覧のヘルプ文字列を生成する。
115    pub fn create_help(&self) -> String {
116        let mut text = String::new();
117
118        let mut first = true;
119        for f in self.function_list.iter() {
120            if first {
121                first = false;
122            } else {
123                text.push('\n');
124            }
125
126            text.push_str(&f.name);
127
128            let mut params: Vec<_> = f.parameters.properties.keys().cloned().collect();
129            params.sort();
130            text.push_str(&format!("({})", params.join(", ")));
131
132            if let Some(desc) = &f.description {
133                text.push_str(&format!("\n    {desc}"));
134            }
135        }
136
137        text
138    }
139
140    /// 関数を呼び出す。
141    ///
142    /// 引数は json 文字列であり、OpenAI API からのデータをそのまま渡せる。
143    /// エラーも適切な文字列メッセージとして返す。
144    pub async fn call(&self, ctx: T, func_name: &str, args_json_str: &str) -> String {
145        info!("[openai-func] Call {func_name} {args_json_str}");
146
147        let res = {
148            let args = serde_json::from_str::<Args>(args_json_str)
149                .map_err(|err| anyhow!("Arguments parse error: {err}"));
150            match args {
151                Ok(args) => self.call_internal(ctx, func_name, &args.args).await,
152                Err(err) => Err(err),
153            }
154        };
155
156        match &res {
157            Ok(res) => {
158                info!("[openai-func] {func_name} returned: {res}");
159                res.to_string()
160            }
161            Err(err) => {
162                warn!("[openai-func] {func_name} failed: {err:#?}");
163                err.to_string()
164            }
165        }
166    }
167
168    /// [Self::call] の内部メイン処理。
169    async fn call_internal(&self, ctx: T, func_name: &str, args: &FuncArgs) -> Result<String> {
170        let func = self
171            .call_table
172            .get(func_name)
173            .ok_or_else(|| anyhow!("Error: Function {func_name} not found"))?;
174
175        // call body
176        let bctx = Arc::clone(&self.basic_context);
177        func(bctx, ctx, args)
178            .await
179            .map_err(|err| anyhow!("Error: {err}"))
180    }
181
182    /// 関数を登録する。
183    pub fn register_function(
184        &mut self,
185        function: Function,
186        body: impl Fn(Arc<BasicContext>, T, &FuncArgs) -> FuncBodyAsync + Send + Sync + 'static,
187    ) {
188        let name = function.name.clone();
189        self.function_list.push(function);
190        self.call_table.insert(name, Box::new(body));
191    }
192
193    /// [basicfuncs] 以下のすべての基本的な関数を登録する。
194    pub fn register_basic_functions(&mut self) {
195        basicfuncs::register_all(self);
196    }
197}
198
199/// args から引数名で文字列値を取得する。
200/// 見つからない、または型が違う場合、いい感じのエラーメッセージの [anyhow::Error] を返す。
201pub fn get_arg_str<'a>(args: &'a FuncArgs, name: &str) -> Result<&'a str> {
202    let value = args.get(&name.to_string());
203    let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
204    let value = value
205        .as_str()
206        .ok_or_else(|| anyhow!("Error: Argument {name} must be string"))?;
207
208    Ok(value)
209}
210
211/// args から引数名で bool を取得する。
212/// 見つからない、または型が違う場合、
213/// いい感じのエラーメッセージの [anyhow::Error] を返す。
214pub fn get_arg_bool(args: &FuncArgs, name: &str) -> Result<bool> {
215    let value = args.get(&name.to_string());
216    let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
217    let value = value
218        .as_bool()
219        .ok_or_else(|| anyhow!("Error: Argument {name} must be boolean"))?;
220
221    Ok(value)
222}
223
224/// args から引数名で bool を取得する。
225/// 見つからない場合は None を返す。
226/// 型が違う場合、
227/// いい感じのエラーメッセージの [anyhow::Error] を返す。
228pub fn get_arg_bool_opt(args: &FuncArgs, name: &str) -> Result<Option<bool>> {
229    if args.get(&name.to_string()).is_none() {
230        Ok(None)
231    } else {
232        get_arg_bool(args, name).map(Some)
233    }
234}
235
236/// args から引数名で i64 を取得する。
237/// 見つからない、または型が違う場合、または範囲外の場合、
238/// いい感じのエラーメッセージの [anyhow::Error] を返す。
239pub fn get_arg_i64(args: &FuncArgs, name: &str, range: impl RangeBounds<i64>) -> Result<i64> {
240    let value = args.get(&name.to_string());
241    let value = value.ok_or_else(|| anyhow!("Error: Argument {name} is required"))?;
242    let value = value
243        .as_i64()
244        .ok_or_else(|| anyhow!("Error: Argument {name} must be integer"))?;
245
246    if range.contains(&value) {
247        Ok(value)
248    } else {
249        bail!("Error: Out of range: {name}")
250    }
251}
252
253/// args から引数名で i64 を取得する。
254/// 見つからない場合は None を返す。
255/// 変換に失敗した場合、または範囲外の場合、
256/// いい感じのエラーメッセージの [anyhow::Error] を返す。
257pub fn get_arg_i64_opt(
258    args: &FuncArgs,
259    name: &str,
260    range: impl RangeBounds<i64>,
261) -> Result<Option<i64>> {
262    if args.get(&name.to_string()).is_none() {
263        Ok(None)
264    } else {
265        get_arg_i64(args, name, range).map(Some)
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use super::*;
272
273    #[test]
274    fn function_args() {
275        let mut args = FuncArgs::new();
276        args.insert("str".to_string(), "ok".into());
277        args.insert("bool_f".to_string(), false.into());
278        args.insert("bool_t".to_string(), true.into());
279        args.insert("int".to_string(), 42.into());
280
281        assert_eq!(get_arg_str(&args, "str").unwrap(), "ok");
282        assert!(
283            get_arg_str(&args, "not_found")
284                .unwrap_err()
285                .to_string()
286                .contains("required")
287        );
288
289        assert!(!get_arg_bool(&args, "bool_f",).unwrap());
290        assert!(get_arg_bool(&args, "bool_t",).unwrap());
291        assert!(
292            get_arg_bool(&args, "str")
293                .unwrap_err()
294                .to_string()
295                .contains("must be boolean")
296        );
297        assert!(
298            get_arg_bool(&args, "not_found")
299                .unwrap_err()
300                .to_string()
301                .contains("required")
302        );
303
304        assert_eq!(get_arg_bool_opt(&args, "bool_f").unwrap(), Some(false));
305        assert_eq!(get_arg_bool_opt(&args, "bool_t").unwrap(), Some(true));
306        assert!(
307            get_arg_bool_opt(&args, "str")
308                .unwrap_err()
309                .to_string()
310                .contains("must be boolean")
311        );
312        assert_eq!(get_arg_bool_opt(&args, "not_found").unwrap(), None);
313
314        assert_eq!(get_arg_i64(&args, "int", 1..=42).unwrap(), 42);
315        assert!(
316            get_arg_i64(&args, "str", 1..43)
317                .unwrap_err()
318                .to_string()
319                .contains("must be integer")
320        );
321        assert!(
322            get_arg_i64(&args, "int", 1..42)
323                .unwrap_err()
324                .to_string()
325                .contains("Out of range")
326        );
327        assert!(
328            get_arg_i64(&args, "not_found", 1..42)
329                .unwrap_err()
330                .to_string()
331                .contains("required")
332        );
333
334        assert_eq!(get_arg_i64_opt(&args, "int", 1..=42).unwrap(), Some(42));
335        assert_eq!(get_arg_i64_opt(&args, "int", 1..=42).unwrap(), Some(42));
336        assert!(
337            get_arg_i64_opt(&args, "str", 1..43)
338                .unwrap_err()
339                .to_string()
340                .contains("must be integer")
341        );
342        assert!(
343            get_arg_i64_opt(&args, "int", 1..42)
344                .unwrap_err()
345                .to_string()
346                .contains("Out of range")
347        );
348        assert_eq!(get_arg_i64_opt(&args, "not_found", 1..42).unwrap(), None);
349    }
350}