sys/sysmod/openai/basicfuncs/storage.rs

1use crate::sysmod::openai::{
2    Function, ParameterElement, ParameterType, Parameters,
3    function::{FuncArgs, FunctionTable, get_arg_i64_opt, get_arg_str},
4};
5use anyhow::{Result, ensure};
6use serde::{Deserialize, Serialize};
7use std::{
8    collections::{BTreeMap, HashMap, VecDeque},
9    path::{Path, PathBuf},
10};
11
/// Name of the file, inside the storage directory, that holds every user's notes.
const NOTE_FILE_NAME: &str = "note.json";
/// Maximum number of notes kept per user; `note_save` evicts the oldest beyond this.
const NOTE_COUNT_MAX: usize = 8;
/// Maximum length of a single note's content.
const NOTE_LENGTH_MAX: usize = 256;
/// [NOTE_LENGTH_MAX] converted to `i64`.
const NOTE_LENGTH_MAX_I64: i64 = NOTE_LENGTH_MAX as i64;
16
17/// このモジュールの関数をすべて登録する。
18///
19/// [FunctionTable::basic_context] に [super::super::function::BasicContext] が
20/// 設定されている場合のみ登録される。
21pub fn register_all<T: 'static>(func_table: &mut FunctionTable<T>) {
22    if func_table.basic_context().storage_dir.is_some() {
23        register_load(func_table);
24        register_save(func_table);
25        register_delete(func_table);
26    }
27}
28
29/// ストレージからノートを読み込む。
30async fn load(storage_dir: PathBuf, args: &FuncArgs) -> Result<String> {
31    let user = get_arg_str(args, "user")?.to_string();
32
33    tokio::fs::create_dir_all(&storage_dir).await?;
34    let path = storage_dir.join(NOTE_FILE_NAME);
35    let json = {
36        let _lock = rlock_file().await;
37
38        let note = load_file(&path).await.unwrap_or_default();
39        note.map.get(&user).map_or_else(
40            || serde_json::to_string(&VecDeque::<String>::new()),
41            serde_json::to_string,
42        )?
43    };
44
45    Ok(json)
46}
47
48fn register_load<T: 'static>(func_table: &mut FunctionTable<T>) {
49    let mut properties = HashMap::new();
50    properties.insert(
51        "user".to_string(),
52        ParameterElement {
53            type_: vec![ParameterType::String],
54            description: Some("user name".to_string()),
55            ..Default::default()
56        },
57    );
58
59    func_table.register_function(
60        Function {
61            name: "note_load".to_string(),
62            description: Some("Load note from permanent storage".to_string()),
63            parameters: Parameters {
64                properties,
65                required: vec!["user".to_string()],
66                ..Default::default()
67            },
68            ..Default::default()
69        },
70        |bctx, _ctx, args| {
71            let storage_dir = bctx.storage_dir.as_ref().unwrap().clone();
72            Box::pin(load(storage_dir, args))
73        },
74    );
75}
76
77/// ストレージにノートを保存する。
78async fn save(storage_dir: PathBuf, args: &FuncArgs) -> Result<String> {
79    let user = get_arg_str(args, "user")?.to_string();
80    let content = get_arg_str(args, "content")?;
81    ensure!(
82        content.len() <= NOTE_LENGTH_MAX,
83        "content length must be less than {}",
84        NOTE_LENGTH_MAX
85    );
86    // タイムスタンプ付与
87    let elem = NoteElement {
88        datetime: chrono::Local::now().to_rfc3339(),
89        content: content.to_string(),
90    };
91
92    tokio::fs::create_dir_all(&storage_dir).await?;
93    let mut deleted = vec![];
94    let path = storage_dir.join(NOTE_FILE_NAME);
95    {
96        let _lock = wlock_file().await;
97
98        let mut note = load_file(&path).await.unwrap_or_default();
99        if !note.map.contains_key(&user) {
100            note.map.insert(user.clone(), VecDeque::new());
101        }
102        let list = note.map.get_mut(&user).unwrap();
103        list.push_back(elem);
104        while list.len() > NOTE_COUNT_MAX {
105            deleted.push(list.pop_front().unwrap());
106        }
107
108        save_file(&path, &note).await?;
109    }
110
111    #[derive(Serialize)]
112    struct FuncResult {
113        result: &'static str,
114        deleted: Vec<NoteElement>,
115    }
116    let result = FuncResult {
117        result: "OK",
118        deleted,
119    };
120
121    Ok(serde_json::to_string(&result)?)
122}
123
124fn register_save<T: 'static>(func_table: &mut FunctionTable<T>) {
125    let mut properties = HashMap::new();
126    properties.insert(
127        "user".to_string(),
128        ParameterElement {
129            type_: vec![ParameterType::String],
130            description: Some("user name".to_string()),
131            ..Default::default()
132        },
133    );
134    properties.insert(
135        "content".to_string(),
136        ParameterElement {
137            type_: vec![ParameterType::String],
138            description: Some("data to be saved".to_string()),
139            ..Default::default()
140        },
141    );
142
143    func_table.register_function(
144        Function {
145            name: "note_save".to_string(),
146            description: Some(format!("Save note to permanent storage. If {NOTE_COUNT_MAX} files already exist, the oldest one will be deleted.")),
147            parameters: Parameters {
148                properties,
149                required: vec!["user".to_string(), "content".to_string()],
150                ..Default::default()
151            },
152            ..Default::default()
153        },
154        |bctx, _ctx, args| {
155            let storage_dir = bctx.storage_dir.as_ref().unwrap().clone();
156            Box::pin(save(storage_dir, args))
157        },
158    );
159}
160
161/// ストレージからノートを部分削除する
162async fn delete(storage_dir: PathBuf, args: &FuncArgs) -> Result<String> {
163    let user = get_arg_str(args, "user")?.to_string();
164    let index = get_arg_i64_opt(args, "index", 0..NOTE_LENGTH_MAX_I64)?;
165
166    tokio::fs::create_dir_all(&storage_dir).await?;
167    let mut deleted = vec![];
168    let path = storage_dir.join(NOTE_FILE_NAME);
169    {
170        let _lock = wlock_file().await;
171
172        let mut note = load_file(&path).await.unwrap_or_default();
173        if !note.map.contains_key(&user) {
174            note.map.insert(user.clone(), VecDeque::new());
175        }
176        let list = note.map.get_mut(&user).unwrap();
177
178        if let Some(index) = index {
179            if let Some(elem) = list.remove(index as usize) {
180                deleted.push(elem);
181            }
182        } else {
183            while let Some(elem) = list.pop_front() {
184                deleted.push(elem);
185            }
186        }
187
188        save_file(&path, &note).await?;
189    }
190
191    #[derive(Serialize)]
192    struct FuncResult {
193        result: &'static str,
194        deleted: Vec<NoteElement>,
195    }
196    let result_str = if deleted.is_empty() {
197        "Error: No data deleted"
198    } else {
199        "OK"
200    };
201    let result = FuncResult {
202        result: result_str,
203        deleted,
204    };
205
206    Ok(serde_json::to_string(&result)?)
207}
208
209fn register_delete<T: 'static>(func_table: &mut FunctionTable<T>) {
210    let mut properties = HashMap::new();
211    properties.insert(
212        "user".to_string(),
213        ParameterElement {
214            type_: vec![ParameterType::String],
215            description: Some("user name".to_string()),
216            ..Default::default()
217        },
218    );
219    properties.insert(
220        "index".to_string(),
221        ParameterElement {
222            type_: vec![ParameterType::Integer, ParameterType::Null],
223            description: Some(format!(
224                "Data index to be deleted ({} <= index <= {}). If omitted, all data will be deleted.",
225                0,
226                NOTE_LENGTH_MAX_I64 - 1
227            )),
228            //minumum: Some(0),
229            //maximum: Some(NOTE_LENGTH_MAX_I64 - 1),
230            ..Default::default()
231        },
232    );
233
234    func_table.register_function(
235        Function {
236            name: "note_delete".to_string(),
237            description: Some("Delete note".to_string()),
238            parameters: Parameters {
239                properties,
240                required: vec!["user".to_string(), "index".to_string()],
241                ..Default::default()
242            },
243            ..Default::default()
244        },
245        |bctx, _ctx, args| {
246            let storage_dir = bctx.storage_dir.as_ref().unwrap().clone();
247            Box::pin(delete(storage_dir, args))
248        },
249    );
250}
251
252// -----------------------------------------------------------------------------
253
254#[derive(Debug, Clone, Default, Serialize, Deserialize)]
255struct AssistantNote {
256    #[serde(flatten)]
257    map: BTreeMap<String, VecDeque<NoteElement>>,
258}
259
260#[derive(Debug, Clone, Default, Serialize, Deserialize)]
261struct NoteElement {
262    datetime: String,
263    content: String,
264}
265
266static LOCK: tokio::sync::RwLock<()> = tokio::sync::RwLock::const_new(());
267
268async fn rlock_file() -> tokio::sync::RwLockReadGuard<'static, ()> {
269    LOCK.read().await
270}
271
272async fn wlock_file() -> tokio::sync::RwLockWriteGuard<'static, ()> {
273    LOCK.write().await
274}
275
276async fn load_file(path: impl AsRef<Path>) -> Result<AssistantNote> {
277    assert!(LOCK.try_write().is_err());
278
279    let src = tokio::fs::read_to_string(path).await?;
280    let note: AssistantNote = serde_json::from_str(&src)?;
281
282    Ok(note)
283}
284
285async fn save_file(path: impl AsRef<Path>, note: &AssistantNote) -> Result<()> {
286    assert!(LOCK.try_write().is_err());
287
288    let mut src = serde_json::to_string_pretty(note)?;
289    src.push('\n');
290    tokio::fs::write(path, src.as_bytes()).await?;
291
292    Ok(())
293}
294
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    /// A note file holding one user with no notes parses into a one-entry map.
    #[tokio::test]
    async fn assistant_note_parse() -> Result<()> {
        let f = tempfile::NamedTempFile::new()?;

        let json_src = r#"{"user1": []}"#;
        writeln!(f.as_file(), "{json_src}")?;

        {
            // load_file asserts that the file lock is held.
            let _lock = rlock_file().await;
            let note = load_file(f.path()).await?;
            assert_eq!(note.map.len(), 1);
            assert!(note.map.contains_key("user1"));
            assert!(note.map["user1"].is_empty());
        }

        Ok(())
    }

    /// Notes written by `save_file` are read back unchanged by `load_file`.
    #[tokio::test]
    async fn assistant_note_roundtrip() -> Result<()> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join(NOTE_FILE_NAME);

        let mut note = AssistantNote::default();
        note.map
            .entry("user1".to_string())
            .or_default()
            .push_back(NoteElement {
                datetime: "2024-01-01T00:00:00+09:00".to_string(),
                content: "hello".to_string(),
            });

        // save_file/load_file assert that the file lock is held.
        let _lock = wlock_file().await;
        save_file(&path, &note).await?;
        let loaded = load_file(&path).await?;

        let list = &loaded.map["user1"];
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].datetime, "2024-01-01T00:00:00+09:00");
        assert_eq!(list[0].content, "hello");

        Ok(())
    }
}