1use super::SystemModule;
4use super::openai::{InputContent, ParameterType};
5use super::openai::{
6 ParameterElement,
7 chat_history::ChatHistory,
8 function::{self, BasicContext, FuncArgs, FunctionTable},
9};
10use crate::config;
11use crate::sysmod::openai::{Function, Parameters, function::FUNCTION_TOKEN};
12use crate::taskserver::{self, Control};
13
14use anyhow::{Result, anyhow, bail, ensure};
15use log::info;
16use reqwest::{Client, StatusCode};
17use serde::{Deserialize, Serialize};
18use std::vec;
19use std::{
20 collections::HashMap,
21 fmt::Debug,
22 sync::Arc,
23 time::{Duration, Instant},
24};
25
/// Total per-request timeout applied to every HTTP call of the LINE client
/// (configured on the `reqwest::Client` in `Line::new`).
const TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum length, in UTF-16 code units, of one text chunk produced by
/// `split_message`.
///
/// NOTE(review): presumably the LINE 5000-character text limit minus a
/// 128-unit safety margin — confirm.
const MSG_SPLIT_LEN: usize = 5000 - 128;
31
32#[derive(Default, Debug, Clone, Serialize, Deserialize)]
34pub struct LineConfig {
35 enabled: bool,
37 token: String,
39 pub channel_secret: String,
41 pub id_name_map: HashMap<String, String>,
43 #[serde(default)]
45 pub prompt: LinePrompt,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct LinePrompt {
50 pub instructions: Vec<String>,
52 pub each: Vec<String>,
54 pub history_timeout_min: u32,
56 pub timeout_msg: String,
58 pub ratelimit_msg: String,
60 pub quota_msg: String,
62 pub error_msg: String,
64}
65
66const DEFAULT_TOML: &str =
68 include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/res/openai_line.toml"));
69impl Default for LinePrompt {
70 fn default() -> Self {
71 toml::from_str(DEFAULT_TOML).unwrap()
72 }
73}
74
/// Per-call context handed to the LINE-specific function-call handlers.
pub struct FunctionContext {
    /// Destination ID to which asynchronous results (e.g. the image produced
    /// by `draw`) are pushed.
    pub reply_to: String,
}
79
80pub struct Line {
84 pub config: LineConfig,
86 client: reqwest::Client,
88
89 pub image_buffer: HashMap<String, Vec<InputContent>>,
90 pub chat_history: Option<ChatHistory>,
92 pub history_timeout: Option<Instant>,
94 pub func_table: Option<FunctionTable<FunctionContext>>,
96}
97
98impl Line {
99 pub fn new() -> Result<Self> {
101 info!("[line] initialize");
102 let config = config::get(|cfg| cfg.line.clone());
103 let client = Client::builder().timeout(TIMEOUT).build()?;
104
105 Ok(Self {
106 config,
107 client,
108 image_buffer: Default::default(),
109 chat_history: None,
110 history_timeout: None,
111 func_table: None,
112 })
113 }
114
115 async fn init_openai(&mut self, ctrl: &Control) {
116 if self.chat_history.is_some() && self.func_table.is_some() {
117 return;
118 }
119
120 let (model_info, reserved) = {
123 let openai = ctrl.sysmods().openai.lock().await;
124
125 (
126 openai.model_info_offline(),
127 openai.get_output_reserved_token(),
128 )
129 };
130
131 let mut chat_history = ChatHistory::new(model_info.name);
132 assert!(chat_history.get_total_limit() == model_info.context_window);
133 let pre_token: usize = self
134 .config
135 .prompt
136 .instructions
137 .iter()
138 .map(|text| chat_history.token_count(text))
139 .sum();
140 let reserved = FUNCTION_TOKEN + pre_token + reserved;
141 chat_history.reserve_tokens(reserved);
142 info!("[line] OpenAI token limit");
143 info!("[line] {:6} total", model_info.context_window);
144 info!("[line] {reserved:6} reserved");
145 info!("[line] {:6} chat history", chat_history.usage().1);
146
147 let mut func_table = FunctionTable::new(Arc::clone(ctrl), Some("line"));
148 func_table.register_basic_functions();
149 register_draw_picture(&mut func_table);
150
151 let _ = self.chat_history.insert(chat_history);
152 let _ = self.func_table.insert(func_table);
153 }
154
155 pub async fn chat_history(&mut self, ctrl: &Control) -> &ChatHistory {
156 self.init_openai(ctrl).await;
157 self.chat_history.as_ref().unwrap()
158 }
159
160 pub async fn chat_history_mut(&mut self, ctrl: &Control) -> &mut ChatHistory {
161 self.init_openai(ctrl).await;
162 self.chat_history.as_mut().unwrap()
163 }
164
165 pub async fn func_table(&mut self, ctrl: &Control) -> &FunctionTable<FunctionContext> {
166 self.init_openai(ctrl).await;
167 self.func_table.as_ref().unwrap()
168 }
169
170 }
177
178impl SystemModule for Line {
179 fn on_start(&mut self, _ctrl: &Control) {
180 info!("[line] on_start");
181 }
182}
183
184#[derive(Debug, Serialize, Deserialize)]
187#[serde(tag = "type")]
188struct ErrorResp {
189 message: String,
190 details: Option<Vec<Detail>>,
191}
192
193#[derive(Debug, Serialize, Deserialize)]
194struct Detail {
195 message: Option<String>,
196 property: Option<String>,
197}
198
199#[derive(Debug, Serialize, Deserialize)]
200#[serde(rename_all = "camelCase")]
201pub struct ProfileResp {
202 pub display_name: String,
203 pub user_id: String,
204 pub language: Option<String>,
205 pub picture_url: Option<String>,
206 pub status_message: Option<String>,
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "camelCase")]
211#[serde_with::skip_serializing_none]
212struct ReplyReq {
213 reply_token: String,
214 messages: Vec<Message>,
216 notification_disabled: Option<bool>,
217}
218
219#[derive(Debug, Serialize, Deserialize)]
220#[serde(rename_all = "camelCase")]
221pub struct ReplyResp {
222 sent_messages: Vec<SentMessage>,
223}
224
225#[derive(Debug, Serialize, Deserialize)]
226#[serde(rename_all = "camelCase")]
227#[serde_with::skip_serializing_none]
228struct PushReq {
229 to: String,
230 messages: Vec<Message>,
232 notification_disabled: Option<bool>,
233 custom_aggregation_units: Option<Vec<String>>,
234}
235
236#[derive(Debug, Serialize, Deserialize)]
237#[serde(rename_all = "camelCase")]
238pub struct PushResp {
239 sent_messages: Vec<SentMessage>,
240}
241
242#[derive(Debug, Serialize, Deserialize)]
243#[serde(rename_all = "camelCase")]
244struct SentMessage {
245 id: String,
246 quote_token: Option<String>,
247}
248
249#[derive(Debug, Serialize, Deserialize)]
250#[serde(rename_all = "camelCase")]
251#[serde(tag = "type")]
252#[serde_with::skip_serializing_none]
253enum Message {
254 #[serde(rename_all = "camelCase")]
255 Text { text: String },
256 #[serde(rename_all = "camelCase")]
257 Image {
258 original_content_url: String,
263 preview_image_url: String,
268 },
269}
270
/// URL of the profile endpoint for `user_id`.
fn url_profile(user_id: &str) -> String {
    ["https://api.line.me/v2/bot/profile/", user_id].concat()
}
274
/// URL of the member-profile endpoint for `user_id` within group `group_id`.
fn url_group_profile(group_id: &str, user_id: &str) -> String {
    format!(
        "https://api.line.me/v2/bot/group/{}/member/{}",
        group_id, user_id
    )
}
278
/// URL of the content-download endpoint for message `message_id`
/// (served from the `api-data` host).
fn url_content(message_id: &str) -> String {
    String::from("https://api-data.line.me/v2/bot/message/") + message_id + "/content"
}
282
/// Endpoint for replying to an event with a reply token.
const URL_REPLY: &str = "https://api.line.me/v2/bot/message/reply";
/// Endpoint for pushing messages to a destination ID.
const URL_PUSH: &str = "https://api.line.me/v2/bot/message/push";
285
286impl Line {
287 pub async fn get_profile(&self, user_id: &str) -> Result<ProfileResp> {
288 self.get_auth_json(&url_profile(user_id)).await
289 }
290
291 pub async fn get_group_profile(&self, group_id: &str, user_id: &str) -> Result<ProfileResp> {
292 self.get_auth_json(&url_group_profile(group_id, user_id))
293 .await
294 }
295
296 pub async fn get_content(&self, message_id: &str) -> Result<Vec<u8>> {
313 let bin = loop {
314 let (status, bin) = self.get_auth_bin(&url_content(message_id)).await?;
315 match status {
316 StatusCode::OK => {
317 info!("OK ({} bytes)", bin.len());
318 break bin;
319 }
320 StatusCode::ACCEPTED => {
321 info!("202: Not ready yet");
323 tokio::time::sleep(Duration::from_secs(5)).await;
324 }
325 _ => {
326 bail!("Invalid status: {status}");
327 }
328 }
329 };
330
331 Ok(bin)
332 }
333
334 pub fn postpone_timeout(&mut self) {
335 let now = Instant::now();
336 let timeout = now + Duration::from_secs(self.config.prompt.history_timeout_min as u64 * 60);
337 self.history_timeout = Some(timeout);
338 }
339
340 pub async fn check_history_timeout(&mut self, ctrl: &Control) {
342 let now = Instant::now();
343
344 if let Some(timeout) = self.history_timeout
345 && now > timeout
346 {
347 self.image_buffer.clear();
348 self.chat_history_mut(ctrl).await.clear();
349 self.history_timeout = None;
350 }
351 }
352
353 pub async fn reply(&self, reply_token: &str, text: &str) -> Result<ReplyResp> {
356 let texts = [text];
357
358 self.reply_multi(reply_token, &texts).await
359 }
360
361 pub async fn reply_multi(&self, reply_token: &str, texts: &[&str]) -> Result<ReplyResp> {
369 let mut messages = Vec::new();
370 for text in texts {
371 ensure!(!text.is_empty(), "text must not be empty");
372 let splitted = split_message(text);
373 messages.extend(splitted.iter().map(|&chunk| Message::Text {
374 text: chunk.to_string(),
375 }));
376 }
377 ensure!(messages.len() <= 5, "text too long: {}", texts.len());
378
379 let req = ReplyReq {
380 reply_token: reply_token.to_string(),
381 messages,
382 notification_disabled: None,
383 };
384 let resp = self.post_auth_json(URL_REPLY, &req).await?;
385 info!("{resp:?}");
386
387 Ok(resp)
388 }
389
390 #[allow(unused)]
394 pub async fn push_message(&self, to: &str, text: &str) -> Result<ReplyResp> {
395 ensure!(!text.is_empty(), "text must not be empty");
396
397 let messages: Vec<_> = split_message(text)
398 .iter()
399 .map(|&chunk| Message::Text {
400 text: chunk.to_string(),
401 })
402 .collect();
403 ensure!(messages.len() <= 5, "text too long: {}", text.len());
404
405 let req = PushReq {
406 to: to.to_string(),
407 messages,
408 notification_disabled: None,
409 custom_aggregation_units: None,
410 };
411 let resp = self.post_auth_json(URL_PUSH, &req).await?;
412 info!("{resp:?}");
413
414 Ok(resp)
415 }
416
417 pub async fn push_image_message(&self, to: &str, url: &str) -> Result<ReplyResp> {
419 let messages = vec![Message::Image {
420 original_content_url: url.to_string(),
421 preview_image_url: url.to_string(),
422 }];
423
424 let req = PushReq {
425 to: to.to_string(),
426 messages,
427 notification_disabled: None,
428 custom_aggregation_units: None,
429 };
430 let resp = self.post_auth_json(URL_PUSH, &req).await?;
431 info!("{resp:?}");
432
433 Ok(resp)
434 }
435
436 async fn check_resp_json<'a, T>(resp: reqwest::Response) -> Result<T>
446 where
447 T: for<'de> Deserialize<'de>,
448 {
449 let status = resp.status();
451 let body = resp.text().await?;
452
453 if status.is_success() {
454 Ok(serde_json::from_reader::<_, T>(body.as_bytes())?)
455 } else {
456 match serde_json::from_str::<ErrorResp>(&body) {
457 Ok(obj) => bail!("{status}: {:?}", obj),
458 Err(json_err) => bail!("{status} - {json_err}: {body}"),
459 }
460 }
461 }
462
463 async fn get_auth_json<'a, T>(&self, url: &str) -> Result<T>
464 where
465 T: for<'de> Deserialize<'de>,
466 {
467 info!("[line] GET {url}");
468 let token = &self.config.token;
469 let resp = self
470 .client
471 .get(url)
472 .header("Authorization", format!("Bearer {token}"))
473 .send()
474 .await?;
475
476 Self::check_resp_json(resp).await
477 }
478
479 async fn post_auth_json<T, R>(&self, url: &str, body: &T) -> Result<R>
480 where
481 T: Serialize + Debug,
482 R: for<'de> Deserialize<'de>,
483 {
484 info!("[line] POST {url} {body:?}");
485 let token = &self.config.token;
486 let resp = self
487 .client
488 .post(url)
489 .header("Authorization", format!("Bearer {token}"))
490 .json(body)
491 .send()
492 .await?;
493
494 Self::check_resp_json(resp).await
495 }
496
497 async fn get_auth_bin(&self, url: &str) -> Result<(StatusCode, Vec<u8>)> {
498 info!("[line] GET {url}");
499 let token = &self.config.token;
500
501 let resp = self
502 .client
503 .get(url)
504 .header("Authorization", format!("Bearer {token}"))
505 .send()
506 .await?;
507
508 let status = resp.status();
509 if status.is_success() {
510 let body = resp.bytes().await?.to_vec();
511
512 Ok((status, body))
513 } else {
514 let body = resp.text().await?;
515
516 match serde_json::from_str::<ErrorResp>(&body) {
517 Ok(obj) => bail!("{status}: {:?}", obj),
518 Err(json_err) => bail!("{status} - {json_err}: {body}"),
519 }
520 }
521 }
522}
523
524fn split_message(text: &str) -> Vec<&str> {
525 let mut result = Vec::new();
527 let mut s = 0;
529 let mut len = 0;
531 for (ind, c) in text.char_indices() {
532 let clen = c.len_utf16();
534 if len + clen > MSG_SPLIT_LEN {
536 result.push(&text[s..ind]);
537 s = ind;
538 len = 0;
539 }
540 len += clen;
541 }
542 if len > 0 {
543 result.push(&text[s..]);
544 }
545
546 result
547}
548
549fn register_draw_picture(func_table: &mut FunctionTable<FunctionContext>) {
550 let mut properties = HashMap::new();
551 properties.insert(
552 "keywords".to_string(),
553 ParameterElement {
554 type_: vec![ParameterType::String],
555 description: Some("Keywords for drawing. They must be in English.".to_string()),
556 ..Default::default()
557 },
558 );
559 func_table.register_function(
560 Function {
561 name: "draw".to_string(),
562 description: Some("Draw a picture".to_string()),
563 parameters: Parameters {
564 properties,
565 required: vec!["keywords".to_string()],
566 ..Default::default()
567 },
568 ..Default::default()
569 },
570 move |bctx, ctx, args| Box::pin(draw_picture(bctx, ctx, args)),
571 );
572}
573
574async fn draw_picture(
575 bctx: Arc<BasicContext>,
576 ctx: FunctionContext,
577 args: &FuncArgs,
578) -> Result<String> {
579 let keywords = function::get_arg_str(args, "keywords")?.to_string();
580
581 let ctrl = bctx.ctrl.clone();
582 taskserver::spawn_oneshot_fn(&ctrl, "line_draw_picture", async move {
583 let url = {
584 let mut ai = bctx.ctrl.sysmods().openai.lock().await;
585
586 ai.generate_image(&keywords, 1)
587 .await?
588 .pop()
589 .ok_or_else(|| anyhow!("parse error"))?
590 };
591 {
592 let line = bctx.ctrl.sysmods().line.lock().await;
593 line.push_image_message(&ctx.reply_to, &url).await?;
594 }
595 Ok(())
596 });
597
598 Ok("Accepted. The result will be automatially posted later. Assistant should not draw for now.".to_string())
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunking around the `MSG_SPLIT_LEN` boundary, including a surrogate
    /// pair that would straddle it.
    #[test]
    fn split_long_message() {
        let mut text = String::new();
        assert!(split_message(&text).is_empty());

        // Exactly MSG_SPLIT_LEN ASCII letters (A..Z repeating) fit in one chunk.
        text.extend((0..MSG_SPLIT_LEN).map(|i| char::from(b'A' + (i % 26) as u8)));
        let chunks = split_message(&text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);

        // One more unit overflows: the extra character goes to a second chunk.
        text.push('0');
        let chunks = split_message(&text);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], &text[..MSG_SPLIT_LEN]);
        assert_eq!(chunks[1], &text[MSG_SPLIT_LEN..]);

        // Leave MSG_SPLIT_LEN - 1 units, then add a 2-unit emoji: it cannot
        // fit, so it must start the second chunk whole.
        text.pop();
        text.pop();
        text.push('😀');
        let chunks = split_message(&text);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], &text[..MSG_SPLIT_LEN - 1]);
        assert_eq!(chunks[1], &text[MSG_SPLIT_LEN - 1..]);
    }
}
632}