mod basicfuncs;
pub mod chat_history;
pub mod function;

use std::collections::HashMap;
use std::fmt::Debug;
use std::io::Cursor;
use std::sync::LazyLock;
use std::time::{Duration, Instant, SystemTime};

use super::SystemModule;
use crate::config;
use crate::taskserver::Control;
use base64::{Engine, engine::general_purpose};
use utils::netutil::{self, HttpStatusError};

use anyhow::{Context, Result, anyhow, bail, ensure};
use chrono::TimeZone;
use log::{debug, info, warn};
use reqwest::Response;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;

const CONN_TIMEOUT: Duration = Duration::from_secs(10);
const TIMEOUT: Duration = Duration::from_secs(60);
const MODEL_INFO_UPDATE_INTERVAL: Duration = Duration::from_secs(24 * 3600);

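/// Returns the "retrieve model" endpoint URL for the given model name.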
fn url_model(model: &str) -> String {
    format!("https://api.openai.com/v1/models/{model}")
}
const URL_RESPONSE: &str = "https://api.openai.com/v1/responses";
const URL_IMAGE_GEN: &str = "https://api.openai.com/v1/images/generations";
const URL_AUDIO_SPEECH: &str = "https://api.openai.com/v1/audio/speech";

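/// Model information exposed to callers: statically known parameters plus
/// data fetched from the OpenAI API.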
#[derive(Debug, Clone, Serialize)]
pub struct ModelInfo {
    #[serde(flatten)]
    offline: OfflineModelInfo,
    #[serde(flatten)]
    online: OnlineModelInfo,
}

#[derive(Debug, Clone, Copy, Serialize)]
pub struct OfflineModelInfo {
    pub name: &'static str,
    pub context_window: usize,
    pub max_output_tokens: usize,
}

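/// Statically known parameters of the supported models.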
const MODEL_LIST: &[OfflineModelInfo] = &[
    OfflineModelInfo {
        name: "gpt-4o-mini",
        context_window: 128000,
        max_output_tokens: 4096,
    },
    OfflineModelInfo {
        name: "gpt-4o",
        context_window: 128000,
        max_output_tokens: 4096,
    },
    OfflineModelInfo {
        name: "gpt-4",
        context_window: 8192,
        max_output_tokens: 8192,
    },
    OfflineModelInfo {
        name: "gpt-4-turbo",
        context_window: 128000,
        max_output_tokens: 4096,
    },
];

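/// Headroom factor applied to a model's `max_output_tokens` when reserving
/// tokens for the reply (see [`OpenAi::get_output_reserved_token`]).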
const MAX_OUTPUT_TOKENS_FACTOR: f32 = 1.05;

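/// Cap on the reserved output size, as a fraction of the model's context window.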
const OUTPUT_RESERVED_RATIO: f32 = 0.2;

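/// Looks up the static info for `model` in [`MODEL_LIST`] via a lazily built map.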
fn get_offline_model_info(model: &str) -> Result<&OfflineModelInfo> {
    static MAP: LazyLock<HashMap<&str, &OfflineModelInfo>> = LazyLock::new(|| {
        let mut map = HashMap::new();
        for info in MODEL_LIST.iter() {
            map.insert(info.name, info);
        }

        map
    });

    MAP.get(model)
        .copied()
        .ok_or_else(|| anyhow!("Model not found: {model}"))
}

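/// Online model info together with the time it was fetched, so that it can be
/// refreshed after [`MODEL_INFO_UPDATE_INTERVAL`].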
#[derive(Debug, Clone, Serialize)]
pub struct CachedModelInfo {
    last_update: SystemTime,
    info: OnlineModelInfo,
}

#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct Model {
    id: String,
    created: u64,
    object: String,
    owned_by: String,
}

#[derive(Default, Clone, Debug, Serialize)]
pub struct OnlineModelInfo {
    created: String,
}

impl OnlineModelInfo {
    fn from(model: Model) -> Self {
        let dt_str = chrono::Local
            .timestamp_opt(model.created as i64, 0)
            .single()
            .map_or_else(|| "?".into(), |dt| dt.to_rfc3339());

        Self { created: dt_str }
    }
}

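/// Rate-limit snapshot parsed from the `x-ratelimit-*` response headers.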
#[derive(Debug, Clone, Copy)]
struct RateLimit {
    timestamp: Instant,
    limit_requests: u32,
    limit_tokens: u32,
    remaining_requests: u32,
    remaining_tokens: u32,
    reset_requests: Duration,
    reset_tokens: Duration,
}

#[derive(Debug, Clone, Copy)]
pub struct ExpectedRateLimit {
    pub limit_requests: u32,
    pub limit_tokens: u32,
    pub remaining_requests: u32,
    pub remaining_tokens: u32,
}

impl RateLimit {
    fn from(resp: &reqwest::Response) -> Result<Self> {
        let timestamp = Instant::now();
        let headers = resp.headers();

        let to_u32 = |key| -> Result<u32> {
            let s = headers
                .get(key)
                .ok_or_else(|| anyhow!("not found: {key}"))?
                .to_str()?;

            s.parse::<u32>()
                .with_context(|| format!("parse u32 failed: {s}"))
        };
        let to_secs_f64 = |key| -> Result<f64> {
            let s = headers
                .get(key)
                .ok_or_else(|| anyhow!("not found: {key}"))?
                .to_str()?;

            Self::to_secs_f64(s).with_context(|| format!("parse f64 secs failed: {s}"))
        };

        let limit_requests = to_u32("x-ratelimit-limit-requests")?;
        let limit_tokens = to_u32("x-ratelimit-limit-tokens")?;
        let remaining_requests = to_u32("x-ratelimit-remaining-requests")?;
        let remaining_tokens = to_u32("x-ratelimit-remaining-tokens")?;
        let reset_requests = to_secs_f64("x-ratelimit-reset-requests")?;
        let reset_tokens = to_secs_f64("x-ratelimit-reset-tokens")?;

        Ok(Self {
            timestamp,
            limit_requests,
            limit_tokens,
            remaining_requests,
            remaining_tokens,
            reset_requests: Duration::from_secs_f64(reset_requests),
            reset_tokens: Duration::from_secs_f64(reset_tokens),
        })
    }

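    /// Parses a duration string such as "6m0s", "1h2m3s", or "120ms" into seconds.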
    fn to_secs_f64(s: &str) -> Result<f64> {
        let mut sum = 0.0;

        let unit_to_scale = |unit: &str| -> Result<f64> {
            let scale = match unit {
                "ns" => 0.000_000_001,
                "us" => 0.000_001,
                "ms" => 0.001,
                "s" => 1.0,
                "m" => 60.0,
                "h" => 3600.0,
                "d" => 86400.0,
                _ => bail!("unknown unit: {unit}"),
            };
            Ok(scale)
        };

        let mut numbuf = String::new();
        let mut unitbuf = String::new();
        for c in s.chars() {
            match c {
                '0'..='9' | '.' => {
                    if !unitbuf.is_empty() {
                        let num = numbuf.parse::<f64>()?;
                        let scale = unit_to_scale(&unitbuf)?;
                        sum += num * scale;
                        numbuf.clear();
                        unitbuf.clear();
                    }
                    numbuf.push(c);
                }
                _ => {
                    unitbuf.push(c);
                }
            };
        }
        if !unitbuf.is_empty() {
            let num = numbuf.parse::<f64>()?;
            let scale = unit_to_scale(&unitbuf)?;
            sum += num * scale;
            numbuf.clear();
            unitbuf.clear();
        }
        ensure!(numbuf.is_empty(), "unexpected format: {s}");

        Ok(sum)
    }

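    /// Estimates the current remaining quota by linearly replenishing the last
    /// snapshot toward its limit over the reported reset interval.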
    fn calc_expected_current(&self) -> ExpectedRateLimit {
        let now = Instant::now();
        let elapsed_secs = (now - self.timestamp).as_secs_f64();
        debug!("{self:?}");
        debug!("elapsed_secs = {elapsed_secs}");

        let remaining_requests = if elapsed_secs >= self.reset_requests.as_secs_f64() {
            self.limit_requests
        } else {
            let vreq = (self.limit_requests - self.remaining_requests) as f64
                / self.reset_requests.as_secs_f64();
            let remaining_requests = self.remaining_requests as f64 + vreq * elapsed_secs;

            (remaining_requests as u32).min(self.limit_requests)
        };

        let remaining_tokens = if elapsed_secs >= self.reset_tokens.as_secs_f64() {
            self.limit_tokens
        } else {
            let vtok = (self.limit_tokens - self.remaining_tokens) as f64
                / self.reset_tokens.as_secs_f64();
            let remaining_tokens = self.remaining_tokens as f64 + vtok * elapsed_secs;

            (remaining_tokens as u32).min(self.limit_tokens)
        };

        ExpectedRateLimit {
            limit_requests: self.limit_requests,
            limit_tokens: self.limit_tokens,
            remaining_requests,
            remaining_tokens,
        }
    }
}

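/// An input item for the Responses API request body.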
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputItem {
    Message {
        role: Role,
        content: Vec<InputContent>,
    },
    WebSearchCall(WebSearchCall),
    FunctionCall {
        call_id: String,
        name: String,
        arguments: String,
    },
    FunctionCallOutput { call_id: String, output: String },
}

#[derive(Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent {
    InputText {
        text: String,
    },
    InputImage {
        image_url: String,
        detail: InputImageDetail,
    },
    OutputText {
        text: String,
    },
}

impl Debug for InputContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            InputContent::InputText { text } => write!(f, "InputText({text})"),
            InputContent::OutputText { text } => write!(f, "OutputText({text})"),
            InputContent::InputImage { image_url, detail } => write!(
                f,
                "InputImage(image_url: {} bytes, {detail:?})",
                image_url.len()
            ),
        }
    }
}

#[derive(Clone, Default, Debug, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum InputImageDetail {
    #[default]
    Auto,
    High,
    Low,
}

#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    #[default]
    Developer,
    User,
    Assistant,
}

#[skip_serializing_none]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    WebSearchPreview {
        search_context_size: Option<SearchContextSize>,
        user_location: Option<UserLocation>,
    },
    FileSearch {
        vector_store_ids: Vec<String>,
    },
    Function(Function),
}

#[derive(Default, Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchContextSize {
    Low,
    #[default]
    Medium,
    High,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserLocation {
    Approximate {
        #[serde(skip_serializing_if = "Option::is_none")]
        city: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        country: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        region: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        timezone: Option<String>,
    },
}

impl Default for UserLocation {
    fn default() -> Self {
        Self::Approximate {
            city: None,
            country: Some("JP".to_string()),
            region: None,
            timezone: Some("Asia/Tokyo".to_string()),
        }
    }
}

#[skip_serializing_none]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Parameters,
    pub strict: bool,
}

impl Default for Function {
    fn default() -> Self {
        Self {
            name: Default::default(),
            description: Default::default(),
            parameters: Default::default(),
            strict: true,
        }
    }
}

#[skip_serializing_none]
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ParameterElement {
    #[serde(rename = "type")]
    pub type_: Vec<ParameterType>,
    pub description: Option<String>,
    #[serde(rename = "enum")]
    pub enum_: Option<Vec<String>>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParameterType {
    Null,
    Boolean,
    Integer,
    Number,
    String,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Parameters {
    #[serde(rename = "type")]
    pub type_: String,
    pub properties: HashMap<String, ParameterElement>,
    pub required: Vec<String>,
    #[serde(rename = "additionalProperties")]
    pub additional_properties: bool,
}

impl Default for Parameters {
    fn default() -> Self {
        Self {
            type_: "object".to_string(),
            properties: Default::default(),
            required: Default::default(),
            additional_properties: false,
        }
    }
}

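/// Request body for POST /v1/responses. `None` fields are omitted from the JSON.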
#[skip_serializing_none]
#[derive(Default, Clone, Debug, Serialize)]
pub struct ResponseRequest {
    model: String,
    instructions: Option<String>,
    input: Vec<InputItem>,
    tools: Option<Vec<Tool>>,
    include: Option<Vec<String>>,
    max_output_tokens: Option<u64>,
    previous_response_id: Option<String>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    user: Option<String>,
}

#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
pub struct ResponseObject {
    id: String,
    created_at: u64,
    error: Option<ErrorObject>,
    instructions: Option<String>,
    max_output_tokens: Option<u64>,
    model: String,
    output: Vec<OutputElement>,
    previous_response_id: Option<String>,
    usage: Usage,
    user: Option<String>,
}

impl ResponseObject {
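    /// Concatenates the text of all output message elements into a single string.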
    pub fn output_text(&self) -> String {
        let mut buf = String::new();
        for elem in self.output.iter() {
            if let OutputElement::Message { content, .. } = elem {
                for cont in content.iter() {
                    if let OutputContent::OutputText { text } = cont {
                        buf.push_str(text);
                    }
                }
            }
        }

        buf
    }

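    /// Iterates over the web search calls contained in the output.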
    pub fn web_search_iter(&self) -> impl Iterator<Item = &WebSearchCall> {
        self.output.iter().filter_map(|elem| match elem {
            OutputElement::WebSearchCall(wsc) => Some(wsc),
            _ => None,
        })
    }

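    /// Iterates over the function calls contained in the output.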
    pub fn func_call_iter(&self) -> impl Iterator<Item = &FunctionCall> {
        self.output.iter().filter_map(|elem| match elem {
            OutputElement::FunctionCall(fc) => Some(fc),
            _ => None,
        })
    }
}

#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OutputElement {
    Message {
        id: String,
        role: Role,
        content: Vec<OutputContent>,
    },
    FunctionCall(FunctionCall),
    WebSearchCall(WebSearchCall),
}

#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
pub struct FunctionCall {
    pub id: String,
    pub call_id: String,
    pub name: String,
    pub arguments: String,
    pub status: String,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WebSearchCall {
    pub id: String,
    pub status: String,
}

#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OutputContent {
    OutputText {
        text: String,
    },
    Refusal {
        refusal: String,
    },
}

#[allow(dead_code)]
#[derive(Default, Clone, Debug, Deserialize)]
struct ErrorObject {
    code: String,
    message: String,
}

#[allow(dead_code)]
#[derive(Default, Clone, Debug, Deserialize)]
struct Usage {
    input_tokens: u32,
    input_tokens_details: InputTokensDetails,
    output_tokens: u32,
    output_tokens_details: OutputTokensDetails,
    total_tokens: u32,
}

#[allow(dead_code)]
#[derive(Default, Clone, Debug, Deserialize)]
struct InputTokensDetails {
    cached_tokens: u32,
}

#[allow(dead_code)]
#[derive(Default, Clone, Debug, Deserialize)]
struct OutputTokensDetails {
    reasoning_tokens: u32,
}

#[skip_serializing_none]
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct ImageGenRequest {
    prompt: String,
    n: Option<u8>,
    response_format: Option<ResponseFormat>,
    size: Option<ImageSize>,
    user: Option<String>,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseFormat {
    Url,
    B64Json,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ImageSize {
    #[serde(rename = "256x256")]
    X256,
    #[serde(rename = "512x512")]
    X512,
    #[serde(rename = "1024x1024")]
    X1024,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct ImageGenResponse {
    created: u64,
    data: Vec<Image>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct Image {
    b64_json: Option<String>,
    url: Option<String>,
}

#[skip_serializing_none]
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct SpeechRequest {
    model: SpeechModel,
    input: String,
    voice: SpeechVoice,
    response_format: Option<SpeechFormat>,
    speed: Option<f32>,
}

#[derive(Default, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SpeechModel {
    #[serde(rename = "tts-1")]
    #[default]
    Tts1,
    #[serde(rename = "tts-1-hd")]
    Tts1Hd,
}

pub const SPEECH_INPUT_MAX: usize = 4096;

#[derive(Default, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SpeechVoice {
    #[default]
    Alloy,
    Echo,
    Fable,
    Onyx,
    Nova,
    Shimmer,
}

#[derive(Default, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SpeechFormat {
    #[default]
    Mp3,
    Opus,
    Aac,
    Flac,
    Wav,
    Pcm,
}

pub const SPEECH_SPEED_MIN: f32 = 0.25;
pub const SPEECH_SPEED_MAX: f32 = 4.0;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiConfig {
    enabled: bool,
    api_key: String,
    pub model: String,
    pub storage_dir: String,
}

impl Default for OpenAiConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            api_key: "".to_string(),
            model: MODEL_LIST.first().unwrap().name.to_string(),
            storage_dir: "./aimemory".to_string(),
        }
    }
}

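/// OpenAI API client module.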
pub struct OpenAi {
    config: OpenAiConfig,
    client: reqwest::Client,

    model_name: &'static str,
    model_info_offline: OfflineModelInfo,
    model_info_online: Option<CachedModelInfo>,

    rate_limit: Option<RateLimit>,
}

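/// Coarse error classification for failed API calls; see [`OpenAi::error_kind`].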
pub enum OpenAiErrorKind {
    Fatal,
    Timeout,
    RateLimit,
    QuotaExceeded,
    HttpError(u16),
}

impl OpenAi {
    pub fn new() -> Result<Self> {
        info!("[openai] initialize");

        let config = config::get(|cfg| cfg.openai.clone());

        info!("[openai] OpenAI model list START");
        for info in MODEL_LIST.iter() {
            info!(
                "[openai] name: \"{}\", context_window: {}",
                info.name, info.context_window
            );
        }
        info!("[openai] OpenAI model list END");

        let info = get_offline_model_info(&config.model)?;
        info!(
            "[openai] selected: model: {}, token_limit: {}",
            info.name, info.context_window
        );

        if !config.storage_dir.is_empty() {
            info!("[openai] mkdir: {}", config.storage_dir);
            std::fs::create_dir_all(&config.storage_dir)?;
        }

        let client = reqwest::Client::builder()
            .connect_timeout(CONN_TIMEOUT)
            .timeout(TIMEOUT)
            .build()?;

        Ok(OpenAi {
            config: config.clone(),
            client,
            model_name: info.name,
            model_info_offline: *info,
            model_info_online: None,
            rate_limit: None,
        })
    }

    pub async fn model_info(&mut self) -> Result<ModelInfo> {
        let offline = self.model_info_offline();
        let online = self.model_info_online().await?;

        Ok(ModelInfo { offline, online })
    }

    pub fn model_info_offline(&self) -> OfflineModelInfo {
        self.model_info_offline
    }

    pub async fn model_info_online(&mut self) -> Result<OnlineModelInfo> {
        let cur = &self.model_info_online;
        let update = if let Some(info) = cur {
            let now = SystemTime::now();
            let elapsed = now.duration_since(info.last_update).unwrap_or_default();

            elapsed > MODEL_INFO_UPDATE_INTERVAL
        } else {
            true
        };

        if update {
            info!("[openai] update model info");
            let info = self.get_online_model_info().await?;
            let info = OnlineModelInfo::from(info);
            let newval = CachedModelInfo {
                last_update: SystemTime::now(),
                info: info.clone(),
            };
            let _ = self.model_info_online.insert(newval);

            Ok(info)
        } else {
            info!("[openai] skip updating model info");
            Ok(cur.as_ref().unwrap().info.clone())
        }
    }

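    /// Number of context-window tokens to reserve for the model's reply:
    /// `max_output_tokens` with a small headroom, capped at a fixed fraction
    /// of the context window.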
    pub fn get_output_reserved_token(&self) -> usize {
        let info = self.model_info_offline();
        let v1 = (info.max_output_tokens as f32 * MAX_OUTPUT_TOKENS_FACTOR) as usize;
        let v2 = (info.context_window as f32 * OUTPUT_RESERVED_RATIO) as usize;

        v1.min(v2)
    }

    async fn get_online_model_info(&self) -> Result<Model> {
        let key = &self.config.api_key;
        let model = self.model_name;

        info!("[openai] model request");
        self.check_enabled()?;

        let resp = self
            .client
            .get(url_model(model))
            .header("Authorization", format!("Bearer {key}"))
            .send()
            .await?;

        let json_str = netutil::check_http_resp(resp).await?;

        netutil::convert_from_json::<Model>(&json_str)
    }

    fn check_enabled(&self) -> Result<()> {
        if !self.config.enabled {
            warn!("[openai] skip because openai feature is disabled");
            bail!("openai is disabled");
        }

        Ok(())
    }

    pub fn error_kind(err: &anyhow::Error) -> OpenAiErrorKind {
        for cause in err.chain() {
            if let Some(req_err) = cause.downcast_ref::<reqwest::Error>()
                && req_err.is_timeout()
            {
                return OpenAiErrorKind::Timeout;
            }
            if let Some(http_err) = cause.downcast_ref::<HttpStatusError>() {
                if http_err.status == 429 {
                    let msg = http_err.body.to_ascii_lowercase();
                    if msg.contains("rate") && msg.contains("limit") {
                        return OpenAiErrorKind::RateLimit;
                    } else if msg.contains("quota") && msg.contains("billing") {
                        return OpenAiErrorKind::QuotaExceeded;
                    }
                } else {
                    return OpenAiErrorKind::HttpError(http_err.status);
                }
            }
        }

        OpenAiErrorKind::Fatal
    }

    fn log_header(resp: &reqwest::Response, key: &str) {
        if let Some(value) = resp.headers().get(key) {
            info!("[openai] {key}: {value:?}");
        } else {
            info!("[openai] not found: {key}");
        }
    }

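    /// POSTs `body` as JSON with the API key, logs diagnostic response headers,
    /// and records the rate-limit snapshot if the headers can be parsed.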
    async fn post_json(
        &mut self,
        url: &str,
        body: &(impl Serialize + std::fmt::Debug),
    ) -> Result<Response> {
        let key = &self.config.api_key;

        info!("[openai] post_json: {url}");
        info!("[openai] {body:?}");
        self.check_enabled()?;

        let resp = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {key}"))
            .json(body)
            .send()
            .await?;
        Self::log_header(&resp, "x-request-id");
        Self::log_header(&resp, "openai-organization");
        Self::log_header(&resp, "openai-processing-ms");
        Self::log_header(&resp, "openai-version");
        Self::log_header(&resp, "x-should-retry");

        match RateLimit::from(&resp) {
            Ok(rate_limit) => {
                info!("[openai] rate limit: {rate_limit:?}");
                self.rate_limit = Some(rate_limit);
            }
            Err(err) => {
                warn!("[openai] could not get rate limit: {err:#}");
            }
        }

        Ok(resp)
    }

    async fn post_json_text(
        &mut self,
        url: &str,
        body: &(impl Serialize + std::fmt::Debug),
    ) -> Result<String> {
        let resp = self.post_json(url, body).await?;
        let text = netutil::check_http_resp(resp).await?;
        info!("[openai] {text}");

        Ok(text)
    }

    async fn post_json_bin(
        &mut self,
        url: &str,
        body: &(impl Serialize + std::fmt::Debug),
    ) -> Result<Vec<u8>> {
        let resp = self.post_json(url, body).await?;
        let bin = netutil::check_http_resp_bin(resp).await?;
        info!("[openai] binary received: size={}", bin.len());

        Ok(bin)
    }

    pub fn get_expected_rate_limit(&self) -> Option<ExpectedRateLimit> {
        self.rate_limit
            .as_ref()
            .map(|rate_limit| rate_limit.calc_expected_current())
    }

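    /// Sends a chat request without tools. See [`Self::chat_with_tools`].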
    pub async fn chat(
        &mut self,
        instructions: Option<&str>,
        input: Vec<InputItem>,
    ) -> Result<ResponseObject> {
        self.chat_with_tools(instructions, input, &[]).await
    }

    pub async fn chat_with_tools(
        &mut self,
        instructions: Option<&str>,
        input: Vec<InputItem>,
        tools: &[Tool],
    ) -> Result<ResponseObject> {
        info!("[openai] chat request");

        let instructions = instructions.map(|s| s.to_string());

        let body = ResponseRequest {
            model: self.model_name.to_string(),
            instructions,
            input,
            tools: Some(tools.to_vec()),
            ..Default::default()
        };

        let json_str = self.post_json_text(URL_RESPONSE, &body).await?;
        let resp: ResponseObject = netutil::convert_from_json(&json_str)?;

        Ok(resp)
    }

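    /// Converts raw image bytes into an input item: downscales to at most
    /// 512x512, re-encodes as PNG, and embeds it as a base64 data URL.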
    pub fn to_image_input(bin: &[u8]) -> Result<InputContent> {
        const SIZE_LIMIT: u32 = 512;

        let mut img: image::DynamicImage =
            image::load_from_memory(bin).context("Load image error")?;
        if img.width() > SIZE_LIMIT || img.height() > SIZE_LIMIT {
            img = img.resize(SIZE_LIMIT, SIZE_LIMIT, image::imageops::FilterType::Nearest);
        }

        let mut output = Cursor::new(vec![]);
        img.write_to(&mut output, image::ImageFormat::Png)
            .context("Convert image error")?;
        let dst = output.into_inner();

        let base64 = general_purpose::STANDARD.encode(&dst);
        let image_url = format!("data:image/png;base64,{base64}");
        let input = InputContent::InputImage {
            image_url,
            detail: InputImageDetail::Low,
        };

        Ok(input)
    }

    pub async fn generate_image(&mut self, prompt: &str, n: u8) -> Result<Vec<String>> {
        info!("[openai] image gen request");

        let body = ImageGenRequest {
            prompt: prompt.to_string(),
            n: Some(n),
            size: Some(ImageSize::X256),
            ..Default::default()
        };

        let json_str = self.post_json_text(URL_IMAGE_GEN, &body).await?;
        let resp: ImageGenResponse = netutil::convert_from_json(&json_str)?;

        let mut result = Vec::new();
        for img in resp.data.iter() {
            let url = img.url.as_ref().ok_or_else(|| anyhow!("url is required"))?;
            result.push(url.to_string());
        }
        info!("[openai] image gen OK: {result:?}");

        Ok(result)
    }

    pub async fn text_to_speech(
        &mut self,
        model: SpeechModel,
        input: &str,
        voice: SpeechVoice,
        response_format: Option<SpeechFormat>,
        speed: Option<f32>,
    ) -> Result<Vec<u8>> {
        info!("[openai] create speech request");

        // The limit is defined in characters, so count chars rather than bytes.
        ensure!(
            input.chars().count() <= SPEECH_INPUT_MAX,
            "input length limit is {SPEECH_INPUT_MAX} characters"
        );
        if let Some(speed) = speed {
            ensure!(
                (SPEECH_SPEED_MIN..=SPEECH_SPEED_MAX).contains(&speed),
                "speed must be {SPEECH_SPEED_MIN} .. {SPEECH_SPEED_MAX}"
            );
        }

        let body = SpeechRequest {
            model,
            input: input.to_string(),
            voice,
            response_format,
            speed,
        };

        let bin = self.post_json_bin(URL_AUDIO_SPEECH, &body).await?;

        Ok(bin)
    }
}

impl SystemModule for OpenAi {
    fn on_start(&mut self, _ctrl: &Control) {
        info!("[openai] on_start");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use utils::netutil::HttpStatusError;

    #[test]
    fn test_parse_resettime() {
        const EPS: f64 = 1e-10;

        let s = "6m0s";
        let v = RateLimit::to_secs_f64(s).unwrap();
        assert_eq!(360.0, v);

        let s = "30.828s";
        let v = RateLimit::to_secs_f64(s).unwrap();
        assert!((30.828 - v).abs() < EPS);

        let s = "1h2m3s";
        let v = RateLimit::to_secs_f64(s).unwrap();
        assert_eq!((3600 + 120 + 3) as f64, v);

        let s = "120ms";
        let v = RateLimit::to_secs_f64(s).unwrap();
        assert!((0.120 - v).abs() < EPS);
    }

    #[tokio::test]
    #[serial(openai)]
    #[ignore]
    async fn simple_assistant() {
        let src = std::fs::read_to_string("../config.toml").unwrap();
        let _unset = config::set(toml::from_str(&src).unwrap());

        let mut ai = OpenAi::new().unwrap();
        // Japanese persona instructions: "Your name is 上海人形 (Shanghai Doll), and you
        // are Yappy's doll. You are the capable assistant of the Yappy household.
        // Yappy is male and works for a reputable company. He is also called yappy."
        let inst = concat!(
            "あなたの名前は上海人形で、あなたはやっぴーさんの人形です。あなたはやっぴー家の優秀なアシスタントです。",
            "やっぴーさんは男性で、ホワイト企業に勤めています。yappyという名前で呼ばれることもあります。"
        );
        // Japanese prompt: "Hello. Please tell me what you know."
        let input = vec![InputItem::Message {
            role: Role::User,
            content: vec![InputContent::InputText {
                text: "こんにちは。あなたの知っている情報を教えてください。".to_string(),
            }],
        }];
        match ai.chat(Some(inst), input).await {
            Ok(resp) => {
                println!("{resp:?}");
                println!("{}", resp.output_text());
            }
            Err(err) => {
                let err = err.downcast_ref::<HttpStatusError>().unwrap();
                println!("{err:#?}");
            }
        };
    }

    #[tokio::test]
    #[serial(openai)]
    #[ignore]
    async fn web_search() {
        let src = std::fs::read_to_string("../config.toml").unwrap();
        let _unset = config::set(toml::from_str(&src).unwrap());

        let mut ai = OpenAi::new().unwrap();
        // Japanese prompt: "Tell me today's latest news. One item is enough."
        let input = vec![InputItem::Message {
            role: Role::User,
            content: vec![InputContent::InputText {
                text: "今日の最新ニュースを教えてください。1つだけでいいです。".to_string(),
            }],
        }];
        let tools = [Tool::WebSearchPreview {
            search_context_size: Some(SearchContextSize::Low),
            user_location: Some(UserLocation::Approximate {
                city: None,
                country: Some("JP".to_string()),
                region: None,
                timezone: Some("Asia/Tokyo".to_string()),
            }),
        }];
        println!("{}", serde_json::to_string(&tools).unwrap());
        match ai.chat_with_tools(None, input, &tools).await {
            Ok(resp) => {
                println!("{resp:?}");
                println!("{}", resp.output_text());
            }
            Err(err) => {
                let err = err.downcast_ref::<HttpStatusError>().unwrap();
                println!("{err:#?}");
            }
        };
    }

    #[tokio::test]
    #[serial(openai)]
    #[ignore]
    async fn image_gen() -> Result<()> {
        let src = std::fs::read_to_string("../config.toml").unwrap();
        let _unset = config::set(toml::from_str(&src).unwrap());

        let mut ai = OpenAi::new().unwrap();
        // Japanese prompt: "A caretaker doll sitting on top of a Raspberry Pi"
        let res = ai
            .generate_image("Raspberry Pi の上に乗っている管理人形", 1)
            .await?;
        assert_eq!(1, res.len());

        Ok(())
    }

    #[tokio::test]
    #[serial(openai)]
    #[ignore]
    async fn test_to_speech() -> Result<()> {
        let src = std::fs::read_to_string("../config.toml").unwrap();
        let _unset = config::set(toml::from_str(&src).unwrap());

        let mut ai = OpenAi::new().unwrap();
        // Japanese input: "Hello, I am the caretaker doll."
        let res = ai
            .text_to_speech(
                SpeechModel::Tts1,
                "こんにちは、かんりにんぎょうです。",
                SpeechVoice::Nova,
                Some(SpeechFormat::Mp3),
                Some(1.0),
            )
            .await?;

        assert!(!res.is_empty());
        let size = res.len();
        const PATH: &str = "speech.mp3";
        std::fs::write(PATH, res)?;
        println!("Wrote to: {PATH} ({size} bytes)");

        Ok(())
    }
}