sys/sysmod/openai/
chat_history.rs1use crate::sysmod::openai::{InputItem, Role, WebSearchCall};
4
5use anyhow::{Result, ensure};
6use std::collections::VecDeque;
7use tiktoken_rs::CoreBPE;
8
9use super::{InputContent, InputImageDetail};
10
11pub struct ChatHistory {
13 core: CoreBPE,
15
16 total_token_limit: usize,
18 token_limit: usize,
20 token_count: usize,
22 history: VecDeque<Element>,
24}
25
26struct Element {
28 items: Vec<InputItem>,
31 token_count: usize,
33}
34
35impl ChatHistory {
36 pub fn new(model: &str) -> Self {
40 let core = tiktoken_rs::get_bpe_from_model(model).unwrap();
41 let total_token_limit = tiktoken_rs::model::get_context_size(model);
42
43 Self {
44 core,
45 total_token_limit,
46 token_limit: total_token_limit,
47 token_count: 0,
48 history: Default::default(),
49 }
50 }
51
52 pub fn reserve_tokens(&mut self, token_count: usize) {
54 if self.token_limit < token_count {
55 panic!("Invalid reserve size");
56 }
57 self.token_limit -= token_count;
58 }
59
60 pub fn push_input_message(&mut self, role: Role, text: &str) -> Result<()> {
61 assert!(matches!(role, Role::Developer | Role::User));
62
63 let tokens = self.tokenize(text);
64 let token_count = tokens.len();
65
66 let item = InputItem::Message {
67 role,
68 content: vec![InputContent::InputText {
69 text: text.to_string(),
70 }],
71 };
72
73 self.push(vec![item], token_count)
74 }
75
76 pub fn push_input_and_images(
77 &mut self,
78 role: Role,
79 text: &str,
80 images: impl IntoIterator<Item = InputContent>,
81 ) -> Result<()> {
82 assert!(matches!(role, Role::Developer | Role::User));
83
84 let tokens = self.tokenize(text);
85 let mut token_count = tokens.len();
86
87 let mut content = vec![InputContent::InputText {
89 text: text.to_string(),
90 }];
91
92 const IMAGE_TOKEN_LOW: usize = 85;
93 for image in images {
94 match &image {
95 InputContent::InputImage {
96 image_url: _,
97 detail,
98 } => {
99 assert!(matches!(detail, InputImageDetail::Low));
100 }
101 _ => {
102 panic!("Must be an InputImage");
103 }
104 }
105 content.push(image);
106 token_count += IMAGE_TOKEN_LOW;
107 }
108
109 let item = InputItem::Message { role, content };
110 self.push(vec![item], token_count)
111 }
112
113 pub fn push_output_message(&mut self, text: &str) -> Result<()> {
114 self.push_output_and_tools(Some(text), std::iter::empty())
115 }
116
117 pub fn push_output_and_tools(
118 &mut self,
119 text: Option<&str>,
120 web_search_ids: impl Iterator<Item = WebSearchCall>,
121 ) -> Result<()> {
122 let mut items = vec![];
123 let mut token_count = 0;
124
125 if let Some(text) = text {
126 let tokens = self.tokenize(text);
127
128 let content = vec![InputContent::OutputText {
129 text: text.to_string(),
130 }];
131 let item = InputItem::Message {
132 role: Role::Assistant,
133 content,
134 };
135 items.push(item);
136 token_count += tokens.len();
137 }
138
139 for wsc in web_search_ids {
140 let item = InputItem::WebSearchCall(wsc);
141 items.push(item);
142 token_count += 0;
144 }
145
146 if !items.is_empty() {
148 self.push(items, token_count)
149 } else {
150 Ok(())
151 }
152 }
153
154 pub fn push_function(
155 &mut self,
156 call_id: &str,
157 name: &str,
158 arguments: &str,
159 output: &str,
160 ) -> Result<()> {
161 let item1 = InputItem::FunctionCall {
162 call_id: call_id.to_string(),
163 name: name.to_string(),
164 arguments: arguments.to_string(),
165 };
166 let item2 = InputItem::FunctionCallOutput {
167 call_id: call_id.to_string(),
168 output: output.to_string(),
169 };
170 let token_count = self.tokenize(name).len()
172 + self.tokenize(arguments).len()
173 + self.tokenize(output).len();
174
175 self.push(vec![item1, item2], token_count)
176 }
177
178 fn push(&mut self, items: Vec<InputItem>, token_count: usize) -> Result<()> {
183 ensure!(token_count <= self.token_limit, "Too long message");
184
185 self.history.push_back(Element { items, token_count });
186 self.token_count += token_count;
187
188 while self.token_count > self.token_limit {
189 let front = self.history.pop_front().unwrap();
190 self.token_count -= front.token_count;
191 }
192
193 Ok(())
194 }
195
196 pub fn clear(&mut self) {
198 self.history.clear();
199 self.token_count = 0;
200 }
201
202 pub fn iter(&self) -> impl Iterator<Item = &InputItem> {
204 self.history.iter().flat_map(|e| e.items.iter())
205 }
206
207 pub fn len(&self) -> usize {
209 self.history.len()
210 }
211
212 pub fn is_empty(&self) -> bool {
214 self.history.len() == 0
215 }
216
217 pub fn get_total_limit(&self) -> usize {
219 self.total_token_limit
220 }
221
222 pub fn usage(&self) -> (usize, usize) {
224 (self.token_count, self.token_limit)
225 }
226
227 fn tokenize(&self, text: &str) -> Vec<u32> {
229 self.core.encode_with_special_tokens(text)
230 }
231
232 pub fn token_count(&self, text: &str) -> usize {
234 self.tokenize(text).len()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the token count of a short Japanese sentence under the `gpt-4o`
    /// tokenizer.
    #[test]
    fn token() {
        let history = ChatHistory::new("gpt-4o");
        let greeting = "こんにちは、管理人形さん。";

        assert_eq!(7, history.token_count(greeting));
    }
}