use core::{cell::Cell, fmt, ops::Deref};

use alloc::vec::Vec;

/// A half-open byte range `start..end` into a [`SourceCode`] string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
    /// Byte offset at which the range begins.
    pub start: usize,
    /// Byte offset one past the last byte of the range.
    pub end: usize,
}

impl Span {
    /// Builds a span from its two byte offsets. No validation is performed.
    pub fn new(start: usize, end: usize) -> Self {
        Span { start, end }
    }

    /// Returns the text of `source` covered by this span.
    ///
    /// # Panics
    /// Panics if `start..end` is not a valid string slice of `source` (out of bounds, reversed,
    /// or not on UTF-8 character boundaries).
    pub fn slice<'a>(&self, source: &'a SourceCode) -> &'a str {
        let Span { start, end } = *self;
        &source.code[start..end]
    }
}

/// A source code string whose length has been checked against a limit.
///
/// The type does not enforce one exact limit; it merely records that somebody made sure the string
/// is not longer than intended, so that the parser cannot be stalled for an unexpected amount of
/// time.
#[derive(Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct SourceCode {
    code: str,
}

impl SourceCode {
    /// Wraps `code` if it is at most `max_len` bytes long, otherwise returns `None`.
    pub fn limited_len(code: &str, max_len: usize) -> Option<&Self> {
        (code.len() <= max_len).then(|| Self::unlimited_len(code))
    }

    /// Wraps `code` without checking its length.
    pub fn unlimited_len(code: &str) -> &Self {
        let raw = code as *const str as *const Self;
        // SAFETY: SourceCode is a `repr(transparent)` wrapper around str, so both pointers share
        // the same layout and metadata, and the result borrows `code` for the same lifetime.
        unsafe { &*raw }
    }
}

/// Gives read-only access to the wrapped string.
impl Deref for SourceCode {
    type Target = str;

    fn deref(&self) -> &str {
        &self.code
    }
}

/// Index of a node inside an `Ast`'s node vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// ID of slot 0, where `Ast::new` stores the shared `Nil` node.
    pub const NIL: Self = Self(0);
}

/// What a [`Node`] represents. Atom text is not stored here; it is recovered from the node's span.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeKind {
    /// The empty list, written as `()`. Also terminates a chain of `List` cells.
    Nil,
    /// End of input; produced by `parse_expr` when there is nothing left to read.
    Eof,

    // Atoms
    /// An identifier.
    Ident,
    /// A decimal number literal, optionally with a fractional part.
    Number,

    /// A list cell: the first ID is the element, the second is the rest of the list
    /// (`Nil`, another `List`, or any other node, which is then written after ` . `).
    List(NodeId, NodeId),
    /// The root of a parsed program; refers to the head of the list of top-level expressions.
    Toplevel(NodeId),

    /// A parse error carrying a static, human-readable message.
    Error(&'static str),
}

/// A single syntax tree node: what it is and which part of the source it came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Node {
    /// Byte range of the source text this node covers.
    pub span: Span,
    /// The kind of node, including links to child nodes where applicable.
    pub kind: NodeKind,
}

/// Arena of syntax tree nodes addressed by [`NodeId`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ast {
    /// Node storage. `Ast::alloc` never grows it past its capacity, so the capacity chosen in
    /// `Ast::new` acts as the node limit.
    pub nodes: Vec<Node>,
}

/// Selects how much detail `Ast::write` emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstWriteMode {
    /// Write only the S-expression text.
    Compact,
    /// Additionally append `@start..end` after every node.
    Spans,
}

impl Ast {
    /// Creates an arena with room for `capacity` nodes and stores the shared `Nil` node in slot 0
    /// ([`NodeId::NIL`]).
    ///
    /// # Panics
    /// Panics if `capacity` is zero, because the nil node itself needs a slot.
    pub fn new(capacity: usize) -> Self {
        assert!(capacity >= 1, "there must be space for at least a nil node");

        let mut ast = Self {
            nodes: Vec::with_capacity(capacity),
        };

        ast.alloc(Node {
            span: Span::new(0, 0),
            kind: NodeKind::Nil,
        })
        .expect("capacity was asserted to be at least 1, so the nil node fits");

        ast
    }

    /// Stores `node` in the arena and returns its ID.
    ///
    /// # Errors
    /// Returns [`NodeAllocError`] once the node vector is filled to capacity; the vector is never
    /// grown.
    pub fn alloc(&mut self, node: Node) -> Result<NodeId, NodeAllocError> {
        if self.nodes.len() >= self.nodes.capacity() {
            return Err(NodeAllocError);
        }

        let index = self.nodes.len();
        self.nodes.push(node);
        Ok(NodeId(index))
    }

    /// Returns the node with the given ID.
    ///
    /// # Panics
    /// Panics if `node_id` does not refer to a node stored in this arena.
    pub fn get(&self, node_id: NodeId) -> &Node {
        &self.nodes[node_id.0]
    }

    /// Returns the node with the given ID for modification.
    ///
    /// # Panics
    /// Panics if `node_id` does not refer to a node stored in this arena.
    pub fn get_mut(&mut self, node_id: NodeId) -> &mut Node {
        &mut self.nodes[node_id.0]
    }

    /// Writes the tree rooted at `node_id` to `w` as an S-expression.
    ///
    /// `source` must be the code the tree was parsed from, since atom text is sliced out of it.
    /// With [`AstWriteMode::Spans`] every node is followed by `@start..end`.
    ///
    /// A `Toplevel` node writes its expressions separated by newlines. A `Toplevel` without any
    /// expressions (its child is `Nil`) writes nothing for the child, and any other non-list child
    /// (such as an error node) is written as-is.
    ///
    /// # Errors
    /// Propagates any error reported by `w`.
    pub fn write(
        &self,
        source: &SourceCode,
        node_id: NodeId,
        w: &mut dyn fmt::Write,
        mode: AstWriteMode,
    ) -> fmt::Result {
        // Writes the elements of a list chain starting at the cell `(head . tail)`.
        // `sep_element` goes between elements; `sep_tail` precedes a tail that is neither `Nil`
        // nor another list cell.
        // The argument count is inherent to threading the writer state through a free function.
        #[allow(clippy::too_many_arguments)]
        fn write_list(
            ast: &Ast,
            source: &SourceCode,
            w: &mut dyn fmt::Write,
            mode: AstWriteMode,
            mut head: NodeId,
            mut tail: NodeId,
            sep_element: &str,
            sep_tail: &str,
        ) -> fmt::Result {
            loop {
                write_rec(ast, source, w, mode, head)?;
                match ast.get(tail).kind {
                    NodeKind::Nil => break,
                    NodeKind::List(head2, tail2) => {
                        w.write_str(sep_element)?;
                        (head, tail) = (head2, tail2);
                    }
                    _ => {
                        w.write_str(sep_tail)?;
                        write_rec(ast, source, w, mode, tail)?;
                        break;
                    }
                }
            }
            Ok(())
        }

        // NOTE: Separated out to a separate function in case we ever want to introduce auto-indentation.
        fn write_rec(
            ast: &Ast,
            source: &SourceCode,
            w: &mut dyn fmt::Write,
            mode: AstWriteMode,
            node_id: NodeId,
        ) -> fmt::Result {
            let node = ast.get(node_id);
            match &node.kind {
                NodeKind::Nil => w.write_str("()")?,
                NodeKind::Eof => w.write_str("<eof>")?,
                NodeKind::Ident | NodeKind::Number => w.write_str(node.span.slice(source))?,

                NodeKind::List(head, tail) => {
                    w.write_char('(')?;
                    write_list(ast, source, w, mode, *head, *tail, " ", " . ")?;
                    w.write_char(')')?;
                }

                NodeKind::Toplevel(list) => match ast.get(*list).kind {
                    NodeKind::List(head, tail) => {
                        write_list(ast, source, w, mode, head, tail, "\n", " . ")?;
                    }
                    // `parse_toplevel` leaves the child at `NodeId::NIL` when the input holds no
                    // expressions; an empty program has nothing to write.
                    NodeKind::Nil => {}
                    // The child can also be a lone non-list node, e.g. the parser's shared error
                    // node when the arena ran out on the very first expression.
                    _ => write_rec(ast, source, w, mode, *list)?,
                },

                NodeKind::Error(message) => write!(w, "#error({message})")?,
            }

            if mode == AstWriteMode::Spans {
                write!(w, "@{}..{}", node.span.start, node.span.end)?;
            }

            Ok(())
        }

        write_rec(self, source, w, mode, node_id)
    }
}

/// Error returned by `Ast::alloc` when the node vector is already filled to capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeAllocError;

/// Parser state: the arena being filled plus a cursor into the input.
pub struct Parser<'a> {
    /// Arena that receives the parsed nodes.
    pub ast: Ast,
    /// The source code being parsed.
    input: &'a SourceCode,
    /// Byte offset of the current character in `input`.
    position: usize,
    /// Number of `current()` calls left before the parser is declared stuck; refilled by
    /// `advance()`. Kept in a `Cell` so that `current()` can take `&self`.
    fuel: Cell<usize>,
    /// Pre-allocated error node handed out by `alloc()` when the arena is full.
    alloc_error: NodeId,
}

impl<'a> Parser<'a> {
    /// How many times `current()` may be called in a row without an `advance()` in between
    /// before the parser panics with "parser is stuck".
    const FUEL: usize = 256;

    /// Creates a parser over `input` that allocates its nodes in `ast`.
    ///
    /// One node is allocated up front: the "program is too big" error that `alloc()` returns
    /// whenever the arena is full.
    ///
    /// # Panics
    /// Panics if `ast` has no free slot left for that error node.
    pub fn new(mut ast: Ast, input: &'a SourceCode) -> Self {
        let alloc_error = ast
            .alloc(Node {
                span: Span::new(0, 0),
                kind: NodeKind::Error("program is too big"),
            })
            .expect("there is not enough space in the arena for an error node");

        Self {
            ast,
            input,
            position: 0,
            fuel: Cell::new(Self::FUEL),
            alloc_error,
        }
    }

    /// Returns the character at the current position, or `'\0'` at the end of the input.
    ///
    /// Every call consumes one unit of fuel.
    ///
    /// # Panics
    /// Panics with "parser is stuck" when the fuel has run out, i.e. when the parser keeps
    /// looking at the input without advancing.
    #[track_caller]
    pub fn current(&self) -> char {
        assert_ne!(self.fuel.get(), 0, "parser is stuck");
        self.fuel.set(self.fuel.get() - 1);

        self.input[self.position..].chars().next().unwrap_or('\0')
    }

    /// Moves past the current character and refills the fuel.
    ///
    /// At the end of the input there is nothing to move past: the position is left unchanged and
    /// the fuel is not refilled. This keeps `position` within the input, so a later `current()`
    /// cannot slice out of bounds, and a caller that keeps advancing at the end eventually trips
    /// the "parser is stuck" assertion.
    ///
    /// # Panics
    /// Panics with "parser is stuck" if the fuel has run out (see [`Parser::current`]).
    pub fn advance(&mut self) {
        let c = self.current();
        // `current()` reports '\0' both for a literal NUL and for end of input, so the position
        // has to be compared against the length to tell the two apart.
        if self.position < self.input.len() {
            self.position += c.len_utf8();
            self.fuel.set(Self::FUEL);
        }
    }

    /// Stores `expr` in the arena and returns its ID, or the ID of the shared
    /// "program is too big" error node when the arena is full.
    pub fn alloc(&mut self, expr: Node) -> NodeId {
        self.ast.alloc(expr).unwrap_or(self.alloc_error)
    }
}

/// Advances the parser past any run of whitespace (space, tab, newline) and `;` line comments.
///
/// A comment extends up to, but not including, the next newline or the end of the input; the
/// newline itself is then consumed as ordinary whitespace.
pub fn skip_whitespace_and_comments(p: &mut Parser<'_>) {
    loop {
        let c = p.current();
        if matches!(c, ' ' | '\t' | '\n') {
            p.advance();
        } else if c == ';' {
            // Consume the rest of the line.
            loop {
                if p.current() == '\n' || p.current() == '\0' {
                    break;
                }
                p.advance();
            }
        } else {
            return;
        }
    }
}

/// Tells whether `c` is one of the ASCII digits `0` through `9`.
fn is_decimal_digit(c: char) -> bool {
    matches!(c, '0'..='9')
}

pub fn parse_number(p: &mut Parser<'_>) -> NodeKind {
    while is_decimal_digit(p.current()) {
        p.advance();
    }
    if p.current() == '.' {
        p.advance();
        if !is_decimal_digit(p.current()) {
            return NodeKind::Error("missing digits after decimal point '.' in number literal");
        }
        while is_decimal_digit(p.current()) {
            p.advance();
        }
    }

    NodeKind::Number
}

/// Tells whether `c` may appear in an identifier.
fn is_ident(c: char) -> bool {
    // The character set is kept deliberately small so the syntax has room to grow later:
    // - ASCII letters and digits are the obvious baseline,
    // - '-' and '_' serve as word separators, whichever you prefer,
    // - '+', '-', '*', '/', '^' spell arithmetic operators,
    // - '=', '!', '<', '>' spell comparison operators,
    // - '\' is for builtin string constants, such as \n.
    // Any other operator is generally clearer spelled as a word (such as `and` and `or`).
    const PUNCTUATION: &str = "-_+*/\\^!=<>";

    c.is_ascii_alphanumeric() || PUNCTUATION.contains(c)
}

/// Consumes identifier characters for as long as they last and reports `NodeKind::Ident`.
pub fn parse_ident(p: &mut Parser<'_>) -> NodeKind {
    loop {
        if !is_ident(p.current()) {
            break NodeKind::Ident;
        }
        p.advance();
    }
}

/// Incremental builder for a chain of `NodeKind::List` cells.
struct List {
    /// First cell of the chain, or `NodeId::NIL` while the list is still empty.
    head: NodeId,
    /// Most recently appended cell.
    tail: NodeId,
}

impl List {
    /// Starts an empty list.
    fn new() -> Self {
        let empty = NodeId::NIL;
        Self {
            head: empty,
            tail: empty,
        }
    }

    /// Appends `node` as the last element by allocating a new cell and linking the previous
    /// last cell to it.
    ///
    /// If the previous last cell is not a `List` (which happens once allocation has started
    /// returning the parser's error node), nothing is linked and the builder is left as it was.
    fn append(&mut self, p: &mut Parser<'_>, node: NodeId) {
        let element_span = p.ast.get(node).span;
        let cell = p.alloc(Node {
            span: element_span,
            kind: NodeKind::List(node, NodeId::NIL),
        });

        if self.head != NodeId::NIL {
            let last = p.ast.get_mut(self.tail);
            if let NodeKind::List(element, _) = last.kind {
                // Stretch the previous cell over the new element and point it at the new cell.
                last.span = Span::new(last.span.start, element_span.end);
                last.kind = NodeKind::List(element, cell);
            } else {
                return;
            }
        } else {
            // First element: the new cell is also the head.
            self.head = cell;
        }
        self.tail = cell;
    }
}

/// Parses a parenthesised list. The parser must be positioned on the opening `(`.
///
/// Returns the first cell of the list (a freshly allocated `Nil` node for `()`), with its span
/// set to cover everything from the opening to the closing parenthesis. If the input ends before
/// the list is closed, an `Error` node is returned instead.
pub fn parse_list(p: &mut Parser<'_>) -> NodeId {
    // An explicit loop is used here; this could have been a lot simpler if Rust supported tail
    // recursion.

    let open = p.position;

    // Step over the opening parenthesis and any trivia that follows it.
    p.advance();
    skip_whitespace_and_comments(p);

    let mut elements = List::new();
    loop {
        if p.current() == ')' {
            break;
        }
        if p.current() == '\0' {
            let unterminated = Span::new(open, p.position);
            return p.alloc(Node {
                span: unterminated,
                kind: NodeKind::Error("missing ')' to close '('"),
            });
        }

        let element = parse_expr(p);
        skip_whitespace_and_comments(p);
        elements.append(p, element);
    }
    // Step over the closing parenthesis.
    p.advance();

    let head = if elements.head != NodeId::NIL {
        elements.head
    } else {
        // An empty list gets its own Nil node, because the shared Nil with ID 0 must not be
        // modified by the span assignment below.
        p.alloc(Node {
            span: Span::new(0, 0),
            kind: NodeKind::Nil,
        })
    };

    // When the arena is exhausted, `head` may be the parser's shared error node; in that case it
    // is that node's span which gets rewritten here.
    let close = p.position;
    p.ast.get_mut(head).span = Span::new(open, close);

    head
}

/// Parses a single expression at the current position and returns its node.
///
/// Dispatches on the current character: end of input yields `Eof` (without consuming anything),
/// a digit starts a number, an identifier character starts an identifier, and `(` starts a list.
/// Any other character is consumed and reported as an "unexpected character" error.
pub fn parse_expr(p: &mut Parser<'_>) -> NodeId {
    let start = p.position;
    let c = p.current();

    let kind = if c == '\0' {
        NodeKind::Eof
    } else if is_decimal_digit(c) {
        parse_number(p)
    } else if is_ident(c) {
        // NOTE: Digits are tested first, which prevents identifiers from starting with a digit.
        parse_ident(p)
    } else if c == '(' {
        // Lists allocate and span their own node.
        return parse_list(p);
    } else {
        p.advance();
        NodeKind::Error("unexpected character")
    };

    let span = Span::new(start, p.position);
    p.alloc(Node { span, kind })
}

/// Parses expressions until the end of the input and wraps them in a `Toplevel` node.
///
/// The `Toplevel` node refers to the head of the list of parsed expressions; when no expression
/// was parsed, that reference is `NodeId::NIL`.
pub fn parse_toplevel(p: &mut Parser<'_>) -> NodeId {
    let start = p.position;
    let mut items = List::new();

    skip_whitespace_and_comments(p);
    loop {
        if p.current() == '\0' {
            break;
        }

        let item = parse_expr(p);
        skip_whitespace_and_comments(p);
        items.append(p, item);
    }

    let span = Span::new(start, p.position);
    p.alloc(Node {
        span,
        kind: NodeKind::Toplevel(items.head),
    })
}

#[cfg(test)]
mod tests {
    use core::error::Error;

    use alloc::{boxed::Box, string::String};

    use super::*;

    /// Runs the parser entry point `f` over `source` with a 16-node arena and asserts that the
    /// resulting tree, written in `AstWriteMode::Spans` mode, equals `expected`.
    #[track_caller]
    fn parse(
        f: fn(&mut Parser<'_>) -> NodeId,
        source: &str,
        expected: &str,
    ) -> Result<(), Box<dyn Error>> {
        let ast = Ast::new(16);
        let code = SourceCode::unlimited_len(source);
        let mut p = Parser::new(ast, code);
        let node = f(&mut p);
        let ast = p.ast;

        let mut s = String::new();
        ast.write(code, node, &mut s, AstWriteMode::Spans)?;

        assert_eq!(s, expected);

        Ok(())
    }

    /// Integer and fractional number literals.
    #[test]
    fn parse_number() -> Result<(), Box<dyn Error>> {
        parse(parse_expr, "123", "123@0..3")?;
        parse(parse_expr, "123.456", "123.456@0..7")?;
        Ok(())
    }

    /// Identifiers made of letters, digits and the permitted punctuation.
    #[test]
    fn parse_ident() -> Result<(), Box<dyn Error>> {
        parse(parse_expr, "abc", "abc@0..3")?;
        parse(parse_expr, "abcABC_01234", "abcABC_01234@0..12")?;
        parse(parse_expr, "+-*/\\^!=<>", "+-*/\\^!=<>@0..10")?;
        Ok(())
    }

    /// Empty, flat and nested lists.
    #[test]
    fn parse_list() -> Result<(), Box<dyn Error>> {
        parse(parse_expr, "()", "()@0..2")?;
        parse(parse_expr, "(a a)", "(a@1..2 a@3..4)@0..5")?;
        parse(parse_expr, "(a a a)", "(a@1..2 a@3..4 a@5..6)@0..7")?;
        parse(parse_expr, "(() ())", "(()@1..3 ()@4..6)@0..7")?;
        parse(
            parse_expr,
            "(nestedy (nest OwO))",
            "(nestedy@1..8 (nest@10..14 OwO@15..18)@9..19)@0..20",
        )?;
        Ok(())
    }

    /// Arena exhaustion: of the 16 slots, one holds the nil node and one the parser's error
    /// node, and every list element needs two nodes (the atom and its list cell), so only seven
    /// elements fit. Everything beyond that collapses into a single error node in tail position.
    #[test]
    fn oom() -> Result<(), Box<dyn Error>> {
        parse(parse_expr, "(a a a a a a a a)", "(a@1..2 a@3..4 a@5..6 a@7..8 a@9..10 a@11..12 a@13..14 . #error(program is too big)@0..0)@0..17")?;
        parse(parse_expr, "(a a a a a a a a a)", "(a@1..2 a@3..4 a@5..6 a@7..8 a@9..10 a@11..12 a@13..14 . #error(program is too big)@0..0)@0..19")?;
        parse(parse_expr, "(a a a a a a a a a a)", "(a@1..2 a@3..4 a@5..6 a@7..8 a@9..10 a@11..12 a@13..14 . #error(program is too big)@0..0)@0..21")?;
        parse(parse_expr, "(a a a a a a a a a a a)", "(a@1..2 a@3..4 a@5..6 a@7..8 a@9..10 a@11..12 a@13..14 . #error(program is too big)@0..0)@0..23")?;
        Ok(())
    }

    /// Several top-level expressions are written one per line, followed by the span of the
    /// whole input.
    #[test]
    fn toplevel() -> Result<(), Box<dyn Error>> {
        parse(
            parse_toplevel,
            r#"
                (hello world)
                (abc)
            "#,
            "(hello@18..23 world@24..29)@17..30\n(abc@48..51)@47..52@0..65",
        )?;
        Ok(())
    }
}