sys/sysmod/openai/
chat_history.rs

1//! OpenAI API の会話コンテキストのトークン数制限付き管理。
2
3use crate::sysmod::openai::{InputItem, Role, WebSearchCall};
4
5use anyhow::{Result, ensure};
6use std::collections::VecDeque;
7use tiktoken_rs::CoreBPE;
8
9use super::{InputContent, InputImageDetail};
10
11/// 会話履歴管理。
12pub struct ChatHistory {
13    /// トークナイザ。
14    core: CoreBPE,
15
16    /// トークン数。
17    total_token_limit: usize,
18    /// トークン数合計上限。
19    token_limit: usize,
20    /// 現在のトークン数合計。
21    token_count: usize,
22    /// 履歴データのキュー。
23    history: VecDeque<Element>,
24}
25
26/// 履歴データ。
27struct Element {
28    /// メッセージのリスト。
29    /// 削除は [Element] 単位で行われる。
30    items: Vec<InputItem>,
31    /// [Self::msg] の総トークン数。
32    token_count: usize,
33}
34
35impl ChatHistory {
36    /// コンストラクタ。
37    ///
38    /// * `model` - OpenAI API モデル名。
39    pub fn new(model: &str) -> Self {
40        let core = tiktoken_rs::get_bpe_from_model(model).unwrap();
41        let total_token_limit = tiktoken_rs::model::get_context_size(model);
42
43        Self {
44            core,
45            total_token_limit,
46            token_limit: total_token_limit,
47            token_count: 0,
48            history: Default::default(),
49        }
50    }
51
52    /// トークン数合計上限を減らす。
53    pub fn reserve_tokens(&mut self, token_count: usize) {
54        if self.token_limit < token_count {
55            panic!("Invalid reserve size");
56        }
57        self.token_limit -= token_count;
58    }
59
60    pub fn push_input_message(&mut self, role: Role, text: &str) -> Result<()> {
61        assert!(matches!(role, Role::Developer | Role::User));
62
63        let tokens = self.tokenize(text);
64        let token_count = tokens.len();
65
66        let item = InputItem::Message {
67            role,
68            content: vec![InputContent::InputText {
69                text: text.to_string(),
70            }],
71        };
72
73        self.push(vec![item], token_count)
74    }
75
76    pub fn push_input_and_images(
77        &mut self,
78        role: Role,
79        text: &str,
80        images: impl IntoIterator<Item = InputContent>,
81    ) -> Result<()> {
82        assert!(matches!(role, Role::Developer | Role::User));
83
84        let tokens = self.tokenize(text);
85        let mut token_count = tokens.len();
86
87        // content = [InputText, InputImage*]
88        let mut content = vec![InputContent::InputText {
89            text: text.to_string(),
90        }];
91
92        const IMAGE_TOKEN_LOW: usize = 85;
93        for image in images {
94            match &image {
95                InputContent::InputImage {
96                    image_url: _,
97                    detail,
98                } => {
99                    assert!(matches!(detail, InputImageDetail::Low));
100                }
101                _ => {
102                    panic!("Must be an InputImage");
103                }
104            }
105            content.push(image);
106            token_count += IMAGE_TOKEN_LOW;
107        }
108
109        let item = InputItem::Message { role, content };
110        self.push(vec![item], token_count)
111    }
112
113    pub fn push_output_message(&mut self, text: &str) -> Result<()> {
114        self.push_output_and_tools(Some(text), std::iter::empty())
115    }
116
117    pub fn push_output_and_tools(
118        &mut self,
119        text: Option<&str>,
120        web_search_ids: impl Iterator<Item = WebSearchCall>,
121    ) -> Result<()> {
122        let mut items = vec![];
123        let mut token_count = 0;
124
125        if let Some(text) = text {
126            let tokens = self.tokenize(text);
127
128            let content = vec![InputContent::OutputText {
129                text: text.to_string(),
130            }];
131            let item = InputItem::Message {
132                role: Role::Assistant,
133                content,
134            };
135            items.push(item);
136            token_count += tokens.len();
137        }
138
139        for wsc in web_search_ids {
140            let item = InputItem::WebSearchCall(wsc);
141            items.push(item);
142            // コンテキストウィンドウサイズには影響しないらしい
143            token_count += 0;
144        }
145
146        // 空なら追加せず成功とする
147        if !items.is_empty() {
148            self.push(items, token_count)
149        } else {
150            Ok(())
151        }
152    }
153
154    pub fn push_function(
155        &mut self,
156        call_id: &str,
157        name: &str,
158        arguments: &str,
159        output: &str,
160    ) -> Result<()> {
161        let item1 = InputItem::FunctionCall {
162            call_id: call_id.to_string(),
163            name: name.to_string(),
164            arguments: arguments.to_string(),
165        };
166        let item2 = InputItem::FunctionCallOutput {
167            call_id: call_id.to_string(),
168            output: output.to_string(),
169        };
170        // call_id も含めるべきかは不明。
171        let token_count = self.tokenize(name).len()
172            + self.tokenize(arguments).len()
173            + self.tokenize(output).len();
174
175        self.push(vec![item1, item2], token_count)
176    }
177
178    /// ヒストリの最後にエントリを追加する。
179    ///
180    /// 合計サイズを超えた場合、超えなくなるように先頭から削除する。
181    /// このエントリだけでサイズを超えてしまっている場合、エラー。
182    fn push(&mut self, items: Vec<InputItem>, token_count: usize) -> Result<()> {
183        ensure!(token_count <= self.token_limit, "Too long message");
184
185        self.history.push_back(Element { items, token_count });
186        self.token_count += token_count;
187
188        while self.token_count > self.token_limit {
189            let front = self.history.pop_front().unwrap();
190            self.token_count -= front.token_count;
191        }
192
193        Ok(())
194    }
195
196    /// 全履歴をクリアする。
197    pub fn clear(&mut self) {
198        self.history.clear();
199        self.token_count = 0;
200    }
201
202    /// 全履歴を走査するイテレータを返す。
203    pub fn iter(&self) -> impl Iterator<Item = &InputItem> {
204        self.history.iter().flat_map(|e| e.items.iter())
205    }
206
207    /// 履歴の数を返す。
208    pub fn len(&self) -> usize {
209        self.history.len()
210    }
211
212    /// 履歴のが空かどうかを返す。
213    pub fn is_empty(&self) -> bool {
214        self.history.len() == 0
215    }
216
217    /// トークン制限総量を返す。
218    pub fn get_total_limit(&self) -> usize {
219        self.total_token_limit
220    }
221
222    /// 現在のトークン数使用量を (usage / total) のタプルで返す。
223    pub fn usage(&self) -> (usize, usize) {
224        (self.token_count, self.token_limit)
225    }
226
227    /// 文章をトークン化する。
228    fn tokenize(&self, text: &str) -> Vec<u32> {
229        self.core.encode_with_special_tokens(text)
230    }
231
232    /// 文章のトークン数を数える。
233    pub fn token_count(&self, text: &str) -> usize {
234        self.tokenize(text).len()
235    }
236}
237
238#[cfg(test)]
239mod tests {
240    use super::*;
241
242    #[test]
243    fn token() {
244        let hist = ChatHistory::new("gpt-4o");
245        let count = hist.token_count("こんにちは、管理人形さん。");
246
247        // https://platform.openai.com/tokenizer
248        assert_eq!(7, count);
249    }
250}