1use super::SystemModule;
4use crate::sysmod::openai::InputContent;
5use crate::sysmod::openai::InputItem;
6use crate::sysmod::openai::Role;
7use crate::taskserver::Control;
8use crate::{config, taskserver};
9use utils::graphics::FontRenderer;
10use utils::netutil;
11
12use anyhow::Result;
13use base64::{Engine as _, engine::general_purpose};
14use chrono::NaiveTime;
15use log::warn;
16use log::{debug, info};
17use rand::Rng;
18use reqwest::multipart;
19use serde::{Deserialize, Serialize};
20use std::collections::{BTreeMap, BTreeSet, HashMap};
21use std::fs;
22use std::time::{Duration, SystemTime, UNIX_EPOCH};
23
/// Font size passed to `FontRenderer::draw_multiline_text` when an over-long
/// reply is rendered as an image.
const LONG_TWEET_FONT_SIZE: u32 = 16;
/// Width of the image generated for an over-long reply (presumably pixels).
const LONG_TWEET_IMAGE_WIDTH: u32 = 640;
/// Text color of the long-reply image (presumably RGB).
const LONG_TWEET_FGCOLOR: (u8, u8, u8) = (255, 255, 255);
/// Background color of the long-reply image (presumably RGB).
const LONG_TWEET_BGCOLOR: (u8, u8, u8) = (0, 0, 0);

/// Timeout applied to every Twitter HTTP request.
const TIMEOUT: Duration = Duration::from_secs(20);

/// Maximum tweet length, counted in `char`s; longer texts are truncated
/// before posting.
pub const TWEET_LEN_MAX: usize = 140;
/// Maximum number of photos per tweet (not referenced in this section;
/// presumably enforced by callers — TODO confirm).
pub const LIMIT_PHOTO_COUNT: usize = 4;
/// Maximum photo size in bytes (not referenced in this section; presumably
/// enforced by callers — TODO confirm).
pub const LIMIT_PHOTO_SIZE: usize = 5_000_000;

/// v2 endpoint returning the authenticated user.
const URL_USERS_ME: &str = "https://api.twitter.com/2/users/me";
/// v2 endpoint resolving screen names (`usernames=a,b,...`) to users.
const URL_USERS_BY: &str = "https://api.twitter.com/2/users/by";
/// Maximum number of usernames sent to `URL_USERS_BY` in one request.
const LIMIT_USERS_BY: usize = 100;

/// Home timeline URL; a `format!` template whose `{}` is the user id.
/// (A macro because `format!` needs a literal format string.)
macro_rules! URL_USERS_TIMELINES_HOME {
    () => {
        "https://api.twitter.com/2/users/{}/timelines/reverse_chronological"
    };
}
/// User tweets URL; a `format!` template whose `{}` is the user id.
macro_rules! URL_USERS_TWEET {
    () => {
        "https://api.twitter.com/2/users/{}/tweets"
    };
}

/// v2 endpoint for posting a tweet.
const URL_TWEETS: &str = "https://api.twitter.com/2/tweets";

/// v1.1 media upload endpoint (multipart).
const URL_UPLOAD: &str = "https://upload.twitter.com/1.1/media/upload.json";
54
55#[derive(Clone, Debug, Serialize, Deserialize)]
56struct User {
57 id: String,
58 name: String,
59 username: String,
60}
61
62#[derive(Clone, Debug, Serialize, Deserialize)]
63struct UsersMe {
64 data: User,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68struct UsersBy {
69 data: Vec<User>,
70}
71
72#[derive(Clone, Debug, Serialize, Deserialize)]
73struct Mention {
74 start: u32,
75 end: u32,
76 username: String,
77}
78
79#[derive(Clone, Debug, Serialize, Deserialize)]
80struct HashTag {
81 start: u32,
82 end: u32,
83 tag: String,
84}
85
86#[derive(Default, Clone, Debug, Serialize, Deserialize)]
87struct Entities {
88 #[serde(default)]
89 mentions: Vec<Mention>,
90 #[serde(default)]
91 hashtags: Vec<HashTag>,
92}
93
94#[derive(Clone, Debug, Serialize, Deserialize)]
95struct Includes {
96 #[serde(default)]
97 users: Vec<User>,
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
101struct Tweet {
102 id: String,
103 text: String,
104 author_id: Option<String>,
105 edit_history_tweet_ids: Vec<String>,
106 #[serde(default)]
108 entities: Entities,
109}
110
111#[derive(Clone, Debug, Serialize, Deserialize)]
112struct Meta {
113 result_count: u64,
115 newest_id: Option<String>,
117 oldest_id: Option<String>,
119}
120
121#[derive(Clone, Debug, Serialize, Deserialize)]
122struct Timeline {
123 data: Vec<Tweet>,
124 includes: Option<Includes>,
126 meta: Meta,
127}
128
/// `reply` object of a `POST /2/tweets` request body.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct TweetParamReply {
    in_reply_to_tweet_id: String,
}

/// `poll` object of a `POST /2/tweets` request body.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct TweetParamPoll {
    duration_minutes: u32,
    options: Vec<String>,
}

/// `media` object of a `POST /2/tweets` request body; `None` fields are omitted.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct Media {
    #[serde(skip_serializing_if = "Option::is_none")]
    media_ids: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tagged_user_ids: Option<Vec<String>>,
}

/// JSON body of `POST /2/tweets`; `None` fields are omitted from the JSON.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
struct TweetParam {
    #[serde(skip_serializing_if = "Option::is_none")]
    poll: Option<TweetParamPoll>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reply: Option<TweetParamReply>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    media: Option<Media>,
}

/// Response of `POST /2/tweets`.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TweetResponse {
    data: TweetResponseData,
}

/// The created tweet as echoed back by `POST /2/tweets`.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TweetResponseData {
    id: String,
    text: String,
}

/// Response of the v1.1 media upload endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct UploadResponseData {
    media_id: u64,
    size: u64,
    expires_after_secs: u64,
}
179
/// `twitter` section of the application config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TwitterConfig {
    /// Start the timeline-check task in `on_start`.
    tlcheck_enabled: bool,
    /// Run the timeline check once instead of periodically.
    debug_exec_once: bool,
    /// Log tweets/uploads instead of sending them.
    fake_tweet: bool,
    /// OAuth 1.0a consumer key.
    consumer_key: String,
    /// OAuth 1.0a consumer secret.
    consumer_secret: String,
    /// OAuth 1.0a access token.
    access_token: String,
    /// OAuth 1.0a access token secret.
    access_secret: String,
    /// Hashtag that marks a tweet as a request for an AI reply.
    ai_hashtag: String,
    /// Font file used to render long replies as images; empty disables it.
    font_file: String,
    /// Auto-reply rules; defaults to the embedded `res/tlcheck.toml`.
    #[serde(default)]
    tlcheck: TimelineCheck,
    /// AI prompt settings; defaults to the embedded `res/openai_twitter.toml`.
    #[serde(default)]
    prompt: TwitterPrompt,
}
213
214impl Default for TwitterConfig {
215 fn default() -> Self {
216 Self {
217 tlcheck_enabled: false,
218 debug_exec_once: false,
219 fake_tweet: true,
220 consumer_key: "".to_string(),
221 consumer_secret: "".to_string(),
222 access_token: "".to_string(),
223 access_secret: "".to_string(),
224 ai_hashtag: "DollsAI".to_string(),
225 font_file: "".to_string(),
226 tlcheck: Default::default(),
227 prompt: Default::default(),
228 }
229 }
230}
231
232#[derive(Debug, Default, Clone, Serialize, Deserialize)]
234pub struct TimelineCheckRule {
235 pub user_names: Vec<String>,
237 pub patterns: Vec<(Vec<String>, Vec<String>)>,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct TimelineCheck {
251 pub rules: Vec<TimelineCheckRule>,
253}
254
255const DEFAULT_TLCHECK_TOML: &str =
257 include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/res/tlcheck.toml"));
258impl Default for TimelineCheck {
259 fn default() -> Self {
260 toml::from_str(DEFAULT_TLCHECK_TOML).unwrap()
261 }
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct TwitterPrompt {
267 pub pre: Vec<String>,
268}
269
270const DEFAULT_PROMPT_TOML: &str = include_str!(concat!(
272 env!("CARGO_MANIFEST_DIR"),
273 "/res/openai_twitter.toml"
274));
275impl Default for TwitterPrompt {
276 fn default() -> Self {
277 toml::from_str(DEFAULT_PROMPT_TOML).unwrap()
278 }
279}
280
/// Twitter system module: posts tweets and runs the timeline check.
pub struct Twitter {
    /// `twitter` config section, copied at construction.
    config: TwitterConfig,

    /// Wake-up times passed to `taskserver::spawn_periodic_task`.
    wakeup_list: Vec<NaiveTime>,

    /// Font for rendering long replies as images; `None` when `font_file` is empty.
    font: Option<FontRenderer>,

    /// Only tweets newer than this id are fetched by the timeline check.
    /// Initialized lazily by `get_since_id` and advanced after each reply.
    tl_check_since_id: Option<String>,
    /// Checked by `get_my_id` before calling `users/me`.
    my_user_cache: Option<User>,
    /// Screen name -> user, filled by `resolve_ids`.
    username_user_cache: HashMap<String, User>,
    /// User id -> screen name, filled by `resolve_ids`.
    id_username_cache: HashMap<String, String>,
}
304
/// A reply produced by the timeline check and posted by `twitter_task`.
struct Reply {
    /// Id of the tweet being replied to.
    to_tw_id: String,
    /// Author id of that tweet.
    to_user_id: String,
    /// Reply text.
    text: String,
    /// Post `text` as an image when it exceeds `TWEET_LEN_MAX` and a font is loaded.
    post_image_if_long: bool,
}
311
312impl Twitter {
313 pub fn new(wakeup_list: Vec<NaiveTime>) -> Result<Self> {
314 info!("[twitter] initialize");
315
316 let config = config::get(|cfg| cfg.twitter.clone());
317
318 let font = if !config.font_file.is_empty() {
319 let ttf_bin = fs::read(&config.font_file)?;
320 Some(FontRenderer::new(ttf_bin)?)
321 } else {
322 None
323 };
324
325 Ok(Twitter {
326 config,
327 wakeup_list,
328 font,
329 tl_check_since_id: None,
330 my_user_cache: None,
331 username_user_cache: HashMap::new(),
332 id_username_cache: HashMap::new(),
333 })
334 }
335
336 async fn twitter_task(&mut self, ctrl: &Control) -> Result<()> {
338 let me = self.get_my_id().await?;
340 info!("[tw-check] user_me: {me:?}");
341
342 let since_id = self.get_since_id().await?;
344 info!("[tw-check] since_id: {since_id}");
345
346 info!("[tw-check] get all user info from screen name");
348 let rules = self.config.tlcheck.rules.clone();
350 for rule in rules.iter() {
351 self.resolve_ids(&rule.user_names).await?;
352 }
353 info!(
354 "[tw-check] user id cache size: {}",
355 self.username_user_cache.len()
356 );
357
358 let tl = self.users_timelines_home(&me.id, &since_id).await?;
362 info!("{} tweets fetched", tl.data.len());
363
364 let mut reply_buf = self.create_reply_list(&tl, &me);
366 reply_buf.append(&mut self.create_ai_reply_list(ctrl, &tl, &me).await);
368
369 for Reply {
371 to_tw_id,
372 to_user_id,
373 text,
374 post_image_if_long,
375 } in reply_buf
376 {
377 let cur: u64 = self.tl_check_since_id.as_ref().unwrap().parse().unwrap();
381 let next: u64 = to_tw_id.parse().unwrap();
382 let max = cur.max(next);
383
384 let name = self.get_username_from_id(&to_user_id).unwrap();
385 info!("reply to: {name}");
386
387 if self.font.is_some() && post_image_if_long && text.chars().count() > TWEET_LEN_MAX {
389 let pngbin = self.font.as_ref().unwrap().draw_multiline_text(
390 LONG_TWEET_FGCOLOR,
391 LONG_TWEET_BGCOLOR,
392 &text,
393 LONG_TWEET_FONT_SIZE,
394 LONG_TWEET_IMAGE_WIDTH,
395 );
396 let media_id = self.media_upload(pngbin).await?;
397 self.tweet_custom("", Some(&to_tw_id), &[media_id]).await?;
398 } else {
399 self.tweet_custom(&text, Some(&to_tw_id), &[]).await?;
400 }
401
402 self.tl_check_since_id = Some(max.to_string());
404 }
405
406 Ok(())
421 }
422
423 fn create_reply_list(&self, tl: &Timeline, me: &User) -> Vec<Reply> {
425 let mut reply_buf = Vec::new();
426
427 for rule in self.config.tlcheck.rules.iter() {
428 let tliter = tl
430 .data
431 .iter()
432 .filter(|tw| tw.author_id.is_some())
434 .filter(|tw| *tw.author_id.as_ref().unwrap() != me.id)
436 .filter(|tw| {
438 !tw.entities
439 .hashtags
440 .iter()
441 .any(|v| v.tag == self.config.ai_hashtag)
442 });
443
444 for tw in tliter {
445 let user_match = rule.user_names.iter().any(|user_name| {
447 let user = self.get_user_from_username(user_name);
448 match user {
449 Some(user) => *tw.author_id.as_ref().unwrap() == user.id,
450 None => false,
452 }
453 });
454 if !user_match {
455 continue;
456 }
457 for (pats, msgs) in rule.patterns.iter() {
459 let match_hit = pats.iter().all(|pat| Self::pattern_match(pat, &tw.text));
461 if match_hit {
462 info!("FIND: {tw:?}");
463 let rnd_idx = rand::rng().random_range(0..msgs.len());
465 reply_buf.push(Reply {
466 to_tw_id: tw.id.clone(),
467 to_user_id: tw.author_id.as_ref().unwrap().clone(),
468 text: msgs[rnd_idx].clone(),
469 post_image_if_long: false,
470 });
471 break;
474 }
475 }
476 }
477 }
478
479 reply_buf
480 }
481
482 async fn create_ai_reply_list(&self, ctrl: &Control, tl: &Timeline, me: &User) -> Vec<Reply> {
484 let mut reply_buf = Vec::new();
485
486 let tliter = tl
487 .data
488 .iter()
489 .filter(|tw| tw.author_id.is_some())
491 .filter(|tw| *tw.author_id.as_ref().unwrap() != me.id)
493 .filter(|tw| {
495 tw.entities
496 .mentions
497 .iter()
498 .any(|v| v.username == me.username)
499 })
500 .filter(|tw| {
502 tw.entities
503 .hashtags
504 .iter()
505 .any(|v| v.tag == self.config.ai_hashtag)
506 });
507
508 for tw in tliter {
509 info!("FIND (AI): {tw:?}");
510
511 let user = Self::resolve_user(
512 tw.author_id.as_ref().unwrap(),
513 &tl.includes.as_ref().unwrap().users,
514 );
515 if user.is_none() {
516 warn!("User {} is not found", tw.author_id.as_ref().unwrap());
517 continue;
518 }
519
520 let system_msgs: Vec<_> = self
522 .config
523 .prompt
524 .pre
525 .iter()
526 .map(|text| {
527 let text = text.replace("${user}", &user.unwrap().name);
528 InputItem::Message {
529 role: Role::Developer,
530 content: vec![InputContent::InputText { text }],
531 }
532 })
533 .collect();
534
535 let mut main_msg = String::new();
536 for (ind, ch) in tw.text.chars().enumerate() {
538 let ind = ind as u32;
539 let mut deleted = false;
540 for m in tw.entities.mentions.iter() {
541 if (m.start..m.end).contains(&ind) {
542 deleted = true;
543 break;
544 }
545 }
546 for h in tw.entities.hashtags.iter() {
547 if (h.start..h.end).contains(&ind) {
548 deleted = true;
549 break;
550 }
551 }
552 if !deleted {
553 main_msg.push(ch);
554 }
555 }
556
557 let mut msgs = system_msgs.clone();
559 msgs.push(InputItem::Message {
560 role: Role::User,
561 content: vec![InputContent::InputText { text: main_msg }],
562 });
563
564 {
567 let mut ai = ctrl.sysmods().openai.lock().await;
568 match ai.chat(None, msgs).await {
569 Ok(resp) => reply_buf.push(Reply {
570 to_tw_id: tw.id.clone(),
571 to_user_id: tw.author_id.as_ref().unwrap().clone(),
572 text: resp.output_text(),
573 post_image_if_long: true,
574 }),
575 Err(e) => {
576 warn!("AI chat error: {e}");
577 }
578 }
579 }
580 }
581
582 reply_buf
583 }
584
585 fn resolve_user<'a>(id: &str, users: &'a [User]) -> Option<&'a User> {
586 users.iter().find(|&user| user.id == id)
587 }
588
589 #[allow(clippy::bool_to_int_with_if)]
593 fn pattern_match(pat: &str, text: &str) -> bool {
594 let count = pat.chars().count();
595 if count == 0 {
596 return false;
597 }
598 let match_start = pat.starts_with('^');
599 let match_end = pat.ends_with('$');
600 let begin = pat
601 .char_indices()
602 .nth(if match_start { 1 } else { 0 })
604 .unwrap_or((0, '\0'))
605 .0;
606 let end = pat
607 .char_indices()
608 .nth(if match_end { count - 1 } else { count })
609 .unwrap_or((pat.len(), '\0'))
610 .0;
611 let pat = &pat[begin..end];
612 if pat.is_empty() {
613 return false;
614 }
615
616 if match_start && match_end {
617 text == pat
618 } else if match_start {
619 text.starts_with(pat)
620 } else if match_end {
621 text.ends_with(pat)
622 } else {
623 text.contains(pat)
624 }
625 }
626
627 async fn get_since_id(&mut self) -> Result<String> {
629 let me = self.get_my_id().await?;
630 if self.tl_check_since_id.is_none() {
631 let usertw = self.users_tweets(&me.id).await?;
632 self.tl_check_since_id = Some(usertw.meta.newest_id.unwrap_or_else(|| "1".into()));
634 }
635
636 Ok(self.tl_check_since_id.clone().unwrap())
637 }
638
639 pub async fn tweet(&mut self, text: &str) -> Result<()> {
642 self.tweet_custom(text, None, &[]).await
643 }
644
645 pub async fn tweet_custom(
648 &mut self,
649 text: &str,
650 reply_to: Option<&str>,
651 media_ids: &[u64],
652 ) -> Result<()> {
653 let reply = reply_to.map(|id| TweetParamReply {
654 in_reply_to_tweet_id: id.to_string(),
655 });
656
657 let media_ids = if media_ids.is_empty() {
658 None
659 } else {
660 let media_ids: Vec<_> = media_ids.iter().map(|id| id.to_string()).collect();
661 Some(media_ids)
662 };
663 let media = media_ids.map(|media_ids| Media {
664 media_ids: Some(media_ids),
665 ..Default::default()
666 });
667
668 let param = TweetParam {
669 reply,
670 text: Some(text.to_string()),
671 media,
672 ..Default::default()
673 };
674
675 self.tweet_raw(param).await
676 }
677
678 async fn tweet_raw(&mut self, mut param: TweetParam) -> Result<()> {
680 self.get_since_id().await?;
682
683 if let Some(ref text) = param.text {
685 let len = text.chars().count();
686 if len > TWEET_LEN_MAX {
687 warn!("tweet length > {TWEET_LEN_MAX}: {len}");
688 warn!("before: {text}");
689 let text = Self::truncate_tweet_text(text).to_string();
690 warn!("after : {text}");
691 param.text = Some(text);
692 }
693 }
694
695 if !self.config.fake_tweet {
696 self.tweets_post(param).await?;
698
699 Ok(())
700 } else {
701 info!("fake tweet: {param:?}");
702
703 Ok(())
704 }
705 }
706
707 fn truncate_tweet_text(text: &str) -> &str {
709 let lastc = text.char_indices().nth(TWEET_LEN_MAX);
711
712 match lastc {
713 Some((ind, _)) => &text[0..ind],
715 None => text,
717 }
718 }
719
720 pub async fn media_upload<T: Into<reqwest::Body>>(&self, bin: T) -> Result<u64> {
723 if self.config.fake_tweet {
724 info!("fake upload");
725
726 return Ok(0);
727 }
728
729 info!("upload");
730 let part = multipart::Part::stream(bin);
731 let form = multipart::Form::new().part("media", part);
732
733 let resp = self
734 .http_oauth_post_multipart(URL_UPLOAD, &BTreeMap::new(), form)
735 .await?;
736 let json_str = netutil::check_http_resp(resp).await?;
737 let obj: UploadResponseData = netutil::convert_from_json(&json_str)?;
738 info!("upload OK: media_id={}", obj.media_id);
739
740 Ok(obj.media_id)
741 }
742
743 async fn twitter_task_entry(ctrl: Control) -> Result<()> {
748 let mut twitter = ctrl.sysmods().twitter.lock().await;
749 twitter.twitter_task(&ctrl).await
750 }
751
752 async fn get_my_id(&mut self) -> Result<User> {
755 if let Some(user) = &self.my_user_cache {
756 Ok(user.clone())
757 } else {
758 Ok(self.users_me().await?.data)
759 }
760 }
761
762 fn get_user_from_username(&self, name: &String) -> Option<&User> {
763 self.username_user_cache.get(name)
764 }
765
766 fn get_username_from_id(&self, id: &String) -> Option<&String> {
767 self.id_username_cache.get(id)
768 }
769
770 async fn resolve_ids(&mut self, user_names: &[String]) -> Result<()> {
777 let unknown_users: Vec<_> = user_names
779 .iter()
780 .filter_map(|user| {
781 if !self.username_user_cache.contains_key(user) {
782 Some(user.clone())
783 } else {
784 None
785 }
786 })
787 .collect();
788
789 let mut start = 0_usize;
791 while start < unknown_users.len() {
792 let end = std::cmp::min(unknown_users.len(), start + LIMIT_USERS_BY);
793 let request_users = &unknown_users[start..end];
794 let mut rest: BTreeSet<_> = request_users.iter().collect();
795
796 let result = self.users_by(request_users).await;
801 if let Err(e) = result {
802 if e.is::<serde_json::Error>() {
803 panic!("parse error {e:?}");
804 } else {
805 return Err(e);
806 }
807 }
808
809 for user in result?.data.iter() {
810 info!(
811 "[twitter] resolve username: {} => {}",
812 user.username, user.id
813 );
814 self.username_user_cache
815 .insert(user.username.clone(), user.clone());
816 self.id_username_cache
817 .insert(user.id.clone(), user.username.clone());
818 let removed = rest.remove(&user.username);
819 assert!(removed);
820 }
821 assert!(
822 rest.is_empty(),
823 "cannot resolved (account suspended?): {rest:?}"
824 );
825
826 start += LIMIT_USERS_BY;
827 }
828 assert_eq!(self.username_user_cache.len(), self.id_username_cache.len());
829
830 Ok(())
831 }
832
833 async fn users_me(&self) -> Result<UsersMe> {
834 let resp = self.http_oauth_get(URL_USERS_ME, &KeyValue::new()).await?;
835 let json_str = netutil::check_http_resp(resp).await?;
836 let obj: UsersMe = netutil::convert_from_json(&json_str)?;
837
838 Ok(obj)
839 }
840
841 async fn users_by(&self, users: &[String]) -> Result<UsersBy> {
842 if !(1..LIMIT_USERS_BY).contains(&users.len()) {
843 panic!("{} limit over: {}", URL_USERS_BY, users.len());
844 }
845 let users_str = users.join(",");
846 let resp = self
847 .http_oauth_get(
848 URL_USERS_BY,
849 &BTreeMap::from([("usernames".into(), users_str)]),
850 )
851 .await?;
852 let json_str = netutil::check_http_resp(resp).await?;
853 let obj: UsersBy = netutil::convert_from_json(&json_str)?;
854
855 Ok(obj)
856 }
857
858 async fn users_timelines_home(&self, id: &str, since_id: &str) -> Result<Timeline> {
859 let url = format!(URL_USERS_TIMELINES_HOME!(), id);
860 let param = KeyValue::from([
861 ("since_id".to_string(), since_id.to_string()),
862 ("exclude".to_string(), "retweets".to_string()),
863 ("expansions".to_string(), "author_id".to_string()),
864 ("tweet.fields".to_string(), "entities".to_string()),
865 ]);
866 let resp = self.http_oauth_get(&url, ¶m).await?;
867 let json_str = netutil::check_http_resp(resp).await?;
868 debug!("{json_str}");
869 let obj: Timeline = netutil::convert_from_json(&json_str)?;
870
871 Ok(obj)
872 }
873
874 async fn users_tweets(&self, id: &str) -> Result<Timeline> {
875 let url = format!(URL_USERS_TWEET!(), id);
876 let param = KeyValue::from([
877 ("exclude".into(), "retweets".into()),
879 ("max_results".into(), "100".into()),
881 ]);
882 let resp = self.http_oauth_get(&url, ¶m).await?;
883 let json_str = netutil::check_http_resp(resp).await?;
884 let obj: Timeline = netutil::convert_from_json(&json_str)?;
885
886 Ok(obj)
887 }
888
889 async fn tweets_post(&self, param: TweetParam) -> Result<TweetResponse> {
890 let resp = self
891 .http_oauth_post_json(URL_TWEETS, &KeyValue::new(), ¶m)
892 .await?;
893 let json_str = netutil::check_http_resp(resp).await?;
894 let obj: TweetResponse = netutil::convert_from_json(&json_str)?;
895
896 Ok(obj)
897 }
898
899 async fn http_oauth_get(
900 &self,
901 base_url: &str,
902 query_param: &KeyValue,
903 ) -> Result<reqwest::Response> {
904 let cf = &self.config;
905 let mut oauth_param = create_oauth_field(&cf.consumer_key, &cf.access_token);
906 let signature = create_signature(
907 "GET",
908 base_url,
909 &oauth_param,
910 query_param,
911 &KeyValue::new(),
912 &cf.consumer_secret,
913 &cf.access_secret,
914 );
915 oauth_param.insert("oauth_signature".into(), signature);
916
917 let (oauth_k, oauth_v) = create_http_oauth_header(&oauth_param);
918
919 let client = reqwest::Client::new();
920 let req = client
921 .get(base_url)
922 .timeout(TIMEOUT)
923 .query(&query_param)
924 .header(oauth_k, oauth_v);
925 let res = req.send().await?;
926
927 Ok(res)
928 }
929
930 async fn http_oauth_post_json<T: Serialize>(
931 &self,
932 base_url: &str,
933 query_param: &KeyValue,
934 body_param: &T,
935 ) -> Result<reqwest::Response> {
936 let json_str = serde_json::to_string(body_param).unwrap();
937 debug!("POST: {json_str}");
938
939 let client = reqwest::Client::new();
940 let req = self
941 .http_oauth_post(&client, base_url, query_param)
942 .header("Content-type", "application/json")
943 .body(json_str);
944 let resp = req.send().await?;
945
946 Ok(resp)
947 }
948
949 async fn http_oauth_post_multipart(
950 &self,
951 base_url: &str,
952 query_param: &KeyValue,
953 body: multipart::Form,
954 ) -> Result<reqwest::Response> {
955 let client = reqwest::Client::new();
956 let req = self
957 .http_oauth_post(&client, base_url, query_param)
958 .multipart(body);
959 let resp = req.send().await?;
960
961 Ok(resp)
962 }
963
964 fn http_oauth_post(
965 &self,
966 client: &reqwest::Client,
967 base_url: &str,
968 query_param: &KeyValue,
969 ) -> reqwest::RequestBuilder {
970 let cf = &self.config;
971 let mut oauth_param = create_oauth_field(&cf.consumer_key, &cf.access_token);
972 let signature = create_signature(
973 "POST",
974 base_url,
975 &oauth_param,
976 query_param,
977 &KeyValue::new(),
978 &cf.consumer_secret,
979 &cf.access_secret,
980 );
981 oauth_param.insert("oauth_signature".into(), signature);
982
983 let (oauth_k, oauth_v) = create_http_oauth_header(&oauth_param);
984
985 client
986 .post(base_url)
987 .timeout(TIMEOUT)
988 .query(query_param)
989 .header(oauth_k, oauth_v)
990 }
991}
992
993impl SystemModule for Twitter {
994 fn on_start(&mut self, ctrl: &Control) {
995 info!("[twitter] on_start");
996 if self.config.tlcheck_enabled {
997 if self.config.debug_exec_once {
998 taskserver::spawn_oneshot_task(ctrl, "tw-check", Twitter::twitter_task_entry);
999 } else {
1000 taskserver::spawn_periodic_task(
1001 ctrl,
1002 "tw-check",
1003 &self.wakeup_list,
1004 Twitter::twitter_task_entry,
1005 );
1006 }
1007 }
1008 }
1009}
1010
/// String map for OAuth, query, and body parameters; a `BTreeMap`, so
/// iteration is in key order.
type KeyValue = BTreeMap<String, String>;
1016
1017fn create_oauth_field(consumer_key: &str, access_token: &str) -> KeyValue {
1026 let mut param = KeyValue::new();
1027
1028 param.insert("oauth_consumer_key".into(), consumer_key.into());
1030
1031 let mut rng = rand::rng();
1036 let rnd32: [u8; 32] = rng.random();
1037 let rnd32_str = general_purpose::STANDARD.encode(rnd32);
1038 let mut nonce_str = "".to_string();
1039 for c in rnd32_str.chars() {
1040 if c.is_alphanumeric() {
1041 nonce_str.push(c);
1042 }
1043 }
1044 param.insert("oauth_nonce".into(), nonce_str);
1045
1046 param.insert("oauth_signature_method".to_string(), "HMAC-SHA1".into());
1052 let unix_epoch_sec = SystemTime::now()
1053 .duration_since(UNIX_EPOCH)
1054 .unwrap()
1055 .as_secs();
1056 param.insert("oauth_timestamp".into(), unix_epoch_sec.to_string());
1057 param.insert("oauth_token".into(), access_token.into());
1058 param.insert("oauth_version".into(), "1.0".into());
1059
1060 param
1061}
1062
1063fn create_signature(
1075 http_method: &str,
1076 base_url: &str,
1077 oauth_param: &KeyValue,
1078 query_param: &KeyValue,
1079 body_param: &KeyValue,
1080 consumer_secret: &str,
1081 token_secret: &str,
1082) -> String {
1083 let mut param = KeyValue::new();
1106 let encode_add = |param: &mut KeyValue, src: &KeyValue| {
1107 for (k, v) in src.iter() {
1108 let old = param.insert(netutil::percent_encode(k), netutil::percent_encode(v));
1109 if old.is_some() {
1110 panic!("duplicate key: {k}");
1111 }
1112 }
1113 };
1114 encode_add(&mut param, oauth_param);
1115 encode_add(&mut param, query_param);
1116 encode_add(&mut param, body_param);
1117
1118 let mut parameter_string = "".to_string();
1120 let mut is_first = true;
1121 for (k, v) in param {
1122 if is_first {
1123 is_first = false;
1124 } else {
1125 parameter_string.push('&');
1126 }
1127 parameter_string.push_str(&k);
1128 parameter_string.push('=');
1129 parameter_string.push_str(&v);
1130 }
1131
1132 let mut signature_base_string = "".to_string();
1143 signature_base_string.push_str(&http_method.to_ascii_uppercase());
1144 signature_base_string.push('&');
1145 signature_base_string.push_str(&netutil::percent_encode(base_url));
1146 signature_base_string.push('&');
1147 signature_base_string.push_str(&netutil::percent_encode(¶meter_string));
1148
1149 let mut signing_key = "".to_string();
1152 signing_key.push_str(consumer_secret);
1153 signing_key.push('&');
1154 signing_key.push_str(token_secret);
1155
1156 let result = netutil::hmac_sha1(signing_key.as_bytes(), signature_base_string.as_bytes());
1159
1160 general_purpose::STANDARD.encode(result.into_bytes())
1162}
1163
1164fn create_http_oauth_header(oauth_param: &KeyValue) -> (String, String) {
1168 let mut oauth_value = "OAuth ".to_string();
1169 {
1170 let v: Vec<_> = oauth_param
1171 .iter()
1172 .map(|(k, v)| {
1173 format!(
1174 r#"{}="{}""#,
1175 netutil::percent_encode(k),
1176 netutil::percent_encode(v)
1177 )
1178 })
1179 .collect();
1180 oauth_value.push_str(&v.join(", "));
1181 }
1182
1183 ("Authorization".into(), oauth_value)
1184}
1185
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded default TOML resources parse and are non-empty.
    #[test]
    fn parse_default_toml() {
        let tlcheck = TimelineCheck::default();
        assert!(!tlcheck.rules.is_empty());

        let prompt = TwitterPrompt::default();
        assert!(!prompt.pre.is_empty());
    }

    /// Truncation counts chars (not bytes) and keeps exactly TWEET_LEN_MAX.
    #[test]
    fn truncate_tweet_text() {
        // Exactly at the limit, mixing multi-byte and ASCII chars: unchanged.
        let exact = "あいうえおかきくけこ0123456789".repeat(7);
        assert_eq!(exact.chars().count(), TWEET_LEN_MAX);
        assert_eq!(Twitter::truncate_tweet_text(&exact), exact);

        // One char over: the trailing char is dropped.
        let over = format!("{exact}あ");
        assert_eq!(over.chars().count(), TWEET_LEN_MAX + 1);
        assert_eq!(Twitter::truncate_tweet_text(&over), exact);
    }

    /// `^`/`$` anchoring and substring semantics of rule patterns.
    #[test]
    fn tweet_pattern_match() {
        let text = "あいうえお";
        let cases: [(&str, bool); 20] = [
            ("あいうえお", true),
            ("^あいうえお", true),
            ("あいうえお$", true),
            ("^あいうえお$", true),
            ("あいう", true),
            ("^あいう", true),
            ("あいう$", false),
            ("^あいう$", false),
            ("うえお", true),
            ("^うえお", false),
            ("うえお$", true),
            ("^うえお$", false),
            ("いうえ", true),
            ("^いうえ", false),
            ("いうえ$", false),
            ("^いうえ$", false),
            ("", false),
            ("^", false),
            ("$", false),
            ("^$", false),
        ];
        for &(pat, expected) in cases.iter() {
            assert_eq!(Twitter::pattern_match(pat, text), expected, "pattern: {pat:?}");
        }
    }

    /// Known-answer signature vector (appears to be the sample from Twitter's
    /// "Creating a signature" guide).
    #[test]
    fn twitter_sample_signature() {
        let method = "POST";
        let url = "https://api.twitter.com/1.1/statuses/update.json";

        let oauth_param = KeyValue::from([
            (
                "oauth_consumer_key".to_string(),
                "xvz1evFS4wEEPTGEFPHBog".to_string(),
            ),
            (
                "oauth_nonce".to_string(),
                "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg".to_string(),
            ),
            ("oauth_signature_method".to_string(), "HMAC-SHA1".to_string()),
            ("oauth_timestamp".to_string(), "1318622958".to_string()),
            (
                "oauth_token".to_string(),
                "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb".to_string(),
            ),
            ("oauth_version".to_string(), "1.0".to_string()),
        ]);
        let query_param = KeyValue::from([("include_entities".to_string(), "true".to_string())]);
        let body_param = KeyValue::from([(
            "status".to_string(),
            "Hello Ladies + Gentlemen, a signed OAuth request!".to_string(),
        )]);

        let consumer_secret = "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw";
        let token_secret = "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE";

        let signature = create_signature(
            method,
            url,
            &oauth_param,
            &query_param,
            &body_param,
            consumer_secret,
            token_secret,
        );

        assert_eq!(signature, "hCtSmYh+iHYCEqBWrE7C7hYmtUk=");
    }
}