1use std::{collections::VecDeque, str::Chars};
4
5use anyhow::{Result, anyhow, bail, ensure};
6
/// Source location of a token: 1-based line and 1-based column, counted in `char`s.
///
/// The parser uses `0:0` for its synthetic end-of-input token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Position {
    line: u32,
    column: u32,
}

impl std::fmt::Display for Position {
    /// Formats the position as `line:column`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}:{}", self.line, self.column)
    }
}
19
/// Kinds of lexical token recognised by the lexer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Token {
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `+` (binary addition or unary plus)
    Add,
    /// `-` (binary subtraction or unary minus)
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Rem,
    /// Decimal integer literal.
    Integer(i64),
    /// End-of-input sentinel: appended by the parser, never produced by the lexer.
    Eof,
}
32
/// A token paired with the position of its first character in the source.
pub type TokenWithPos = (Token, Position);
34
35struct Lexer<'a> {
36 iter: Chars<'a>,
37 buf: Option<char>,
38 pos: Position,
39 newline: bool,
40}
41
42impl Lexer<'_> {
43 fn new(src: &str) -> Lexer<'_> {
44 Lexer {
45 iter: src.chars(),
46 buf: None,
47 pos: Position { line: 1, column: 0 },
48 newline: false,
49 }
50 }
51
52 fn peekc(&mut self) -> Option<char> {
54 if self.buf.is_none() {
55 let c = self.iter.next();
56 self.buf = c;
57 }
58
59 self.buf
60 }
61
62 fn getc(&mut self) -> Option<char> {
64 let c = if self.buf.is_some() {
65 let c = self.buf;
66 self.buf = None;
67 c
68 } else {
69 self.iter.next()
70 };
71
72 if self.newline {
73 self.pos.column = 0;
74 self.pos.line += 1;
75 self.newline = false;
76 }
77
78 if let Some(c) = c {
79 if c == '\n' {
80 self.newline = true;
81 }
82 self.pos.column += 1;
83 }
84
85 c
86 }
87
88 fn next_token(&mut self) -> Result<Option<TokenWithPos>> {
90 loop {
92 let c = self.peekc();
93 if let Some(c) = c {
94 if !c.is_ascii_whitespace() {
95 break;
96 }
97 } else {
98 return Ok(None);
100 }
101 self.getc();
103 }
104
105 let c = self.getc().unwrap();
106 let pos = self.pos;
107
108 let matched = match c {
109 '(' => Some((Token::LParen, pos)),
110 ')' => Some((Token::RParen, pos)),
111
112 '+' => Some((Token::Add, pos)),
113 '-' => Some((Token::Sub, pos)),
114 '*' => Some((Token::Mul, pos)),
115 '/' => Some((Token::Div, pos)),
116 '%' => Some((Token::Rem, pos)),
117
118 _ => None,
119 };
120 if matched.is_some() {
121 return Ok(matched);
122 }
123
124 let range = '0'..='9';
126 if range.contains(&c) {
127 let mut str = String::from(c);
128 loop {
129 if let Some(c) = self.peekc()
130 && range.contains(&c)
131 {
132 str.push(self.getc().unwrap());
133 continue;
134 }
135 break;
136 }
137 if let Ok(n) = str.parse::<i64>() {
138 return Ok(Some((Token::Integer(n), pos)));
139 } else {
140 bail!("{}:{} Invalid number", pos.line, pos.column);
141 };
142 }
143
144 bail!("{}:{} Invalid character", pos.line, pos.column);
145 }
146}
147
148pub fn lexical_analyze(src: &str) -> Result<Vec<TokenWithPos>> {
149 let src = src.to_owned();
150 let mut lexer = Lexer::new(&src);
151 let mut result = Vec::new();
152
153 while let Some(tok) = lexer.next_token()? {
154 result.push(tok);
155 }
156
157 Ok(result)
158}
159
160pub enum Ast {
163 Operation(Operation),
164 Literal(RuntimeValue),
165}
166
167pub enum Operation {
168 Minus(Box<Ast>),
169 Add(Box<Ast>, Box<Ast>),
170 Sub(Box<Ast>, Box<Ast>),
171 Mul(Box<Ast>, Box<Ast>),
172 Div(Box<Ast>, Box<Ast>),
173 Rem(Box<Ast>, Box<Ast>),
174}
175
/// A value produced by evaluating an expression tree.
#[derive(Debug, Clone, PartialEq)]
pub enum RuntimeValue {
    /// A signed 64-bit integer.
    Integer(i64),
}

impl std::fmt::Display for RuntimeValue {
    /// Formats the inner value.
    ///
    /// Delegating to the inner value's `Display` avoids an intermediate
    /// `String` and honors the caller's width/fill/sign flags.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            RuntimeValue::Integer(n) => std::fmt::Display::fmt(n, f),
        }
    }
}
189
190struct Parser {
191 src: VecDeque<TokenWithPos>,
192}
193
194impl Parser {
195 fn new(src: Vec<TokenWithPos>) -> Self {
196 let mut src = VecDeque::from(src);
197 src.push_back((Token::Eof, Position { line: 0, column: 0 }));
198 Self { src }
199 }
200
201 fn peek(&self) -> &TokenWithPos {
202 self.src.front().unwrap()
203 }
204
205 fn get(&mut self) -> TokenWithPos {
206 self.src.pop_front().unwrap()
207 }
208
209 fn parse_formula(&mut self) -> Result<Ast> {
211 let root = self.parse_term();
212 let (next, pos) = self.get();
213 ensure!(matches!(next, Token::Eof), "Invalid token {}", pos);
214
215 root
216 }
217
218 fn parse_expr(&mut self) -> Result<Ast> {
220 self.parse_term()
221 }
222
223 fn parse_term(&mut self) -> Result<Ast> {
225 let mut lh = self.parse_factor()?;
226 while let (Token::Add | Token::Sub, _) = self.peek() {
227 let plh = Box::new(lh);
228 let (op, _) = self.get();
229 let prh = Box::new(self.parse_factor()?);
230 lh = match op {
231 Token::Add => Ast::Operation(Operation::Add(plh, prh)),
232 Token::Sub => Ast::Operation(Operation::Sub(plh, prh)),
233 _ => panic!("logic error"),
234 };
235 }
236
237 Ok(lh)
238 }
239
240 fn parse_factor(&mut self) -> Result<Ast> {
242 let mut lh = self.parse_unary()?;
243 while let (Token::Mul | Token::Div | Token::Rem, _) = self.peek() {
244 let plh = Box::new(lh);
245 let (op, _) = self.get();
246 let prh = Box::new(self.parse_unary()?);
247 lh = match op {
248 Token::Mul => Ast::Operation(Operation::Mul(plh, prh)),
249 Token::Div => Ast::Operation(Operation::Div(plh, prh)),
250 Token::Rem => Ast::Operation(Operation::Rem(plh, prh)),
251 _ => panic!("logic error"),
252 };
253 }
254
255 Ok(lh)
256 }
257
258 fn parse_unary(&mut self) -> Result<Ast> {
260 let mut minus: bool = false;
261 while let (Token::Add | Token::Sub, _) = self.peek() {
262 let (op, _) = self.get();
263 match op {
264 Token::Add => {}
265 Token::Sub => {
266 minus = !minus;
267 }
268 _ => panic!("logic error"),
269 };
270 }
271
272 let operand = self.parse_primary()?;
273 if !minus {
274 Ok(operand)
275 } else {
276 Ok(Ast::Operation(Operation::Minus(Box::new(operand))))
277 }
278 }
279
280 fn parse_primary(&mut self) -> Result<Ast> {
282 let (tok, _) = self.peek();
283 match *tok {
284 Token::LParen => {
285 assert!(matches!(self.get(), (Token::LParen, _)),);
286 let ast = self.parse_expr()?;
287 ensure!(matches!(self.get(), (Token::RParen, _)), "RPAREN required");
288
289 Ok(ast)
290 }
291 Token::Integer(n) => {
292 assert!(matches!(self.get(), (Token::Integer(_), _)),);
293
294 Ok(Ast::Literal(RuntimeValue::Integer(n)))
295 }
296 _ => {
297 bail!("Parse error");
298 }
299 }
300 }
301}
302
303pub fn parse_formula(src: Vec<TokenWithPos>) -> Result<Ast> {
304 let mut parser = Parser::new(src);
305 let root = parser.parse_formula()?;
306
307 Ok(root)
308}
309
310impl RuntimeValue {
313 fn minus(self) -> Result<Self> {
314 match self {
315 Self::Integer(n) => n
316 .checked_neg()
317 .ok_or_else(|| anyhow!("overflow"))
318 .map(Self::Integer),
319 }
320 }
321
322 fn add(self, rh: Self) -> Result<Self> {
323 match self {
324 Self::Integer(a) => match rh {
325 Self::Integer(b) => a
326 .checked_add(b)
327 .ok_or_else(|| anyhow!("overflow"))
328 .map(Self::Integer),
329 },
330 }
331 }
332
333 fn sub(self, rh: Self) -> Result<Self> {
334 match self {
335 Self::Integer(a) => match rh {
336 Self::Integer(b) => a
337 .checked_sub(b)
338 .ok_or_else(|| anyhow!("overflow"))
339 .map(Self::Integer),
340 },
341 }
342 }
343
344 fn mul(self, rh: Self) -> Result<Self> {
345 match self {
346 Self::Integer(a) => match rh {
347 Self::Integer(b) => a
348 .checked_mul(b)
349 .ok_or_else(|| anyhow!("overflow"))
350 .map(Self::Integer),
351 },
352 }
353 }
354
355 fn div(self, rh: Self) -> Result<Self> {
356 match self {
357 Self::Integer(a) => match rh {
358 Self::Integer(b) => a
359 .checked_div_euclid(b)
360 .ok_or_else(|| {
361 anyhow!(if b == 0 {
362 "division by zero"
363 } else {
364 "overflow"
365 })
366 })
367 .map(Self::Integer),
368 },
369 }
370 }
371
372 fn rem(self, rh: Self) -> Result<Self> {
373 match self {
374 Self::Integer(a) => match rh {
375 Self::Integer(b) => a
376 .checked_rem_euclid(b)
377 .ok_or_else(|| anyhow!("overflow"))
378 .map(Self::Integer),
379 },
380 }
381 }
382}
383
384fn evaluate_operation(op: Operation) -> Result<RuntimeValue> {
385 match op {
386 Operation::Minus(operand) => {
387 let v = evaluate(*operand)?;
388 v.minus()
389 }
390 Operation::Add(lh, rh) => evaluate(*lh)?.add(evaluate(*rh)?),
391 Operation::Sub(lh, rh) => evaluate(*lh)?.sub(evaluate(*rh)?),
392 Operation::Mul(lh, rh) => evaluate(*lh)?.mul(evaluate(*rh)?),
393 Operation::Div(lh, rh) => evaluate(*lh)?.div(evaluate(*rh)?),
394 Operation::Rem(lh, rh) => evaluate(*lh)?.rem(evaluate(*rh)?),
395 }
396}
397
398pub fn evaluate(ast: Ast) -> Result<RuntimeValue> {
399 match ast {
400 Ast::Literal(v) => Ok(v),
401 Ast::Operation(op) => evaluate_operation(op),
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the whole pipeline (lex, parse, evaluate) on `src`.
    fn eval_str(src: &str) -> Result<RuntimeValue> {
        let toks = lexical_analyze(src)?;
        let root = parse_formula(toks)?;
        evaluate(root)
    }

    #[test]
    fn lex() {
        let src = "\n(\n123 + 456\n)";
        let toks = lexical_analyze(src).unwrap();
        assert_eq!(5, toks.len());

        assert_eq!((Token::LParen, Position { line: 2, column: 1 }), toks[0]);
        assert_eq!(
            (Token::Integer(123), Position { line: 3, column: 1 }),
            toks[1]
        );
        assert_eq!((Token::Add, Position { line: 3, column: 5 }), toks[2]);
        assert_eq!(
            (Token::Integer(456), Position { line: 3, column: 7 }),
            toks[3]
        );
        assert_eq!((Token::RParen, Position { line: 4, column: 1 }), toks[4]);
    }

    #[test]
    fn parse_eval() {
        let src = "\n(((1 + 2) * 3) - --(1 + 2 * 3))\n";
        assert_eq!(RuntimeValue::Integer(2), eval_str(src).unwrap());
    }

    #[test]
    fn operator_precedence() {
        assert_eq!(RuntimeValue::Integer(7), eval_str("2 * 3 + 4 % 3").unwrap());
        assert_eq!(RuntimeValue::Integer(-1), eval_str("1 - 2").unwrap());
    }

    #[test]
    fn lex_reports_invalid_character_position() {
        let err = lexical_analyze("1 +\n $").unwrap_err();
        assert_eq!("2:2 Invalid character", err.to_string());
    }

    #[test]
    fn lex_rejects_integer_overflow() {
        let max = lexical_analyze("9223372036854775807").unwrap();
        assert_eq!(Token::Integer(i64::MAX), max[0].0);

        let err = lexical_analyze("9223372036854775808").unwrap_err();
        assert_eq!("1:1 Invalid number", err.to_string());
    }

    #[test]
    fn euclidean_division_and_remainder() {
        assert_eq!(RuntimeValue::Integer(-4), eval_str("-7 / 2").unwrap());
        assert_eq!(RuntimeValue::Integer(1), eval_str("-7 % 2").unwrap());
    }

    #[test]
    fn arithmetic_errors() {
        assert_eq!(
            "division by zero",
            eval_str("1 / 0").unwrap_err().to_string()
        );
        assert_eq!(
            "overflow",
            eval_str("9223372036854775807 + 1").unwrap_err().to_string()
        );
        assert_eq!(
            "overflow",
            eval_str("-9223372036854775807 - 2").unwrap_err().to_string()
        );
    }
}