#!/usr/bin/env python3
import re
import sys
import os
import pickle

# Source lines of every file read so far, keyed by file name. print_error
# looks a diagnostic's file up here to print the offending line under it.
file_cache = dict()

# ANSI escape sequences used to colour terminal output.
RED, YELLOW, CYAN, RESET = "\x1b[91m", "\x1b[93m", "\x1b[96m", "\x1b[0m"

def print_error(message, filename, line, col):
    """Emit a coloured compiler diagnostic and terminate the process.

    When ``filename`` is in ``file_cache`` and ``line`` is a valid 1-based
    line of it, the source line is echoed with a caret under column ``col``.

    Args:
        message: Human-readable description of the problem.
        filename: File the diagnostic refers to (also the ``file_cache`` key).
        line: 1-based line number; out-of-range values (e.g. 0 or -1)
            suppress the source excerpt.
        col: 1-based column counted in characters (a tab counts as one,
            matching how the lexer computes columns).

    Raises:
        SystemExit: always -- ``sys.exit`` prints the message to stderr and
            exits with status 1.
    """
    err_msg = f"{RED}{filename}:{line}:{col}: error: {message}{RESET}"
    src_lines = file_cache.get(filename)
    if src_lines and 1 <= line <= len(src_lines):
        src_line = src_lines[line - 1].rstrip("\n")
        width = max(col - 1, 0)
        before = src_line[:width]
        # Reproduce tabs from the source in the caret padding so the caret
        # lands under the right character whatever the terminal's tab width.
        pad = "".join("\t" if ch == "\t" else " " for ch in before)
        # Columns past the end of the line (e.g. at end of line) use spaces.
        pad += " " * (width - len(before))
        err_msg += f"\n{CYAN}{src_line}{RESET}\n{pad}{RED}^{RESET}"
    sys.exit(err_msg)

# ------------------------
# Preprocessor (handles #include and emits #line directives)
# ------------------------
def preprocess(filename, included_files=None):
    """Recursively expand ``#include "file"`` directives into one source string.

    The text of each file is preceded by a ``#line 1 "<file>"`` marker. After
    an included file, the includer's position is restored with
    ``#line <next line> "<includer>"`` so the lexer can attribute every token
    to its real file and line. Every file read is stored in ``file_cache``.

    Args:
        filename: Path of the file to expand. Include paths are resolved
            relative to the directory of the including file.
        included_files: Set shared across the recursion recording files that
            were already expanded (by normalized absolute path). A file is
            expanded at most once; later includes of it expand to nothing.
            Created when omitted.

    Returns:
        The expanded source text.

    Exits via ``print_error`` when a file does not exist or an ``#include``
    line is not of the form ``#include "name"``.
    """
    if included_files is None:
        included_files = set()
    # Key the include guard on the normalized absolute path so that different
    # spellings of one file ("a.h", "./a.h", "sub/../a.h") are recognised.
    # With raw strings, a self-include via "./a.h" produced "././a.h", ... and
    # recursed until RecursionError.
    key = os.path.normcase(os.path.abspath(filename))
    if key in included_files or filename in included_files:
        return ""
    if not os.path.exists(filename):
        print_error(f"File '{filename}' not found", filename, 0, 0)
    included_files.add(key)
    with open(filename, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()
    file_cache[filename] = lines
    base_dir = os.path.dirname(filename)
    parts = [f'#line 1 "{filename}"\n']
    for idx, line in enumerate(lines, 1):
        stripped = line.strip()
        if not stripped.startswith("#include"):
            parts.append(line)
            continue
        col = len(line) - len(line.lstrip()) + 1  # column of the '#'
        m = re.match(r'#include\s+"([^"]+)"', stripped)
        if not m:
            print_error("Malformed #include directive", filename, idx, col)
        full_path = os.path.normpath(os.path.join(base_dir, m.group(1)))
        # Report a missing header at the #include that names it, not at the
        # (non-existent) header itself.
        inc_key = os.path.normcase(os.path.abspath(full_path))
        if inc_key not in included_files and not os.path.exists(full_path):
            print_error(f"File '{full_path}' not found", filename, idx, col)
        included = preprocess(full_path, included_files)
        parts.append(f'#line 1 "{full_path}"\n')
        parts.append(included)
        # If the included file's last line has no trailing newline, the
        # following #line directive would be glued onto it and the lexer would
        # reject the '#' as an unexpected character.
        if included and not included.endswith("\n"):
            parts.append("\n")
        # Resume numbering at the line after the #include directive.
        parts.append(f'#line {idx + 1} "{filename}"\n')
    return "".join(parts)
# ------------------------
# Lexer (supports compound assignment operators)
# ------------------------
class Token:
    """One lexeme together with the source position it came from.

    ``type`` is the token kind name (e.g. ``'ID'``, ``'NUMBER'``), ``value``
    its payload, and ``filename``/``line``/``column`` locate it (1-based).
    """

    def __init__(self, typ, value, line, column, filename):
        self.type, self.value = typ, value
        self.line, self.column, self.filename = line, column, filename

    def __repr__(self):
        where = f"{self.filename}:{self.line}:{self.column}"
        return "Token({}, {}, {})".format(self.type, self.value, where)

class Lexer:
    """Convert preprocessed source text into a list of ``Token`` objects.

    The input is expected to be ``preprocess`` output: ordinary source lines
    interleaved with ``#line N "file"`` directives, which reset the file name
    and line number recorded on subsequent tokens.
    """

    def __init__(self, text, filename="<stdin>"):
        # Split only on "\n": that is how file.readlines() (which fills
        # file_cache for error excerpts) breaks lines. str.splitlines() would
        # also break on \v, \f, \x1c-\x1e, \x85, \u2028 and \u2029, shifting
        # every later line number away from what the error excerpt shows.
        self.lines = text.split("\n")
        self.default_filename = filename
        self.tokens = []
        self.current_filename = filename
        self.current_line_number = 1
        self.token_specification = [
            # Compound assignments first:
            ('PLUSEQ',   r'\+='),
            ('MINUSEQ',  r'-='),
            ('MULEQ',    r'\*='),
            ('DIVEQ',    r'/='),
            ('FLOAT',    r'\d+\.\d+'),
            ('NUMBER',   r'\d+'),
            ('ID',       r'[A-Za-z_][A-Za-z0-9_]*'),
            ('EQEQ',     r'=='),
            ('ASSIGN',   r'='),
            ('PLUS',     r'\+'),
            ('MINUS',    r'-'),
            ('MUL',      r'\*'),
            ('DIV',      r'/'),
            ('DOT',      r'\.'),
            ('COMMA',    r','),
            ('LPAREN',   r'\('),
            ('RPAREN',   r'\)'),
            ('LBRACE',   r'\{'),
            ('RBRACE',   r'\}'),
            ('SEMICOLON',r';'),
            # Whitespace, including common Unicode spaces. "\n" must be listed:
            # MISMATCH's "." does not match a newline, so a line terminator
            # would otherwise match no alternative at all ("Lexer error").
            ('SKIP',     r'[ \t\r\n\f\v\u00A0\u2000-\u200B\u202F\u205F\u3000]+'),
            ('MISMATCH', r'.'),
        ]
        self.tok_regex = re.compile('|'.join('(?P<%s>%s)' % pair for pair in self.token_specification))

    def tokenize(self):
        """Tokenize ``self.lines``; append to and return ``self.tokens``.

        Whitespace is dropped, NUMBER/FLOAT values are converted to
        ``int``/``float``, and all other kinds keep their source text.
        Exits via ``print_error`` on a character no pattern accepts.
        """
        # Kinds whose token value is simply the matched text.
        text_kinds = {'EQEQ', 'ASSIGN', 'PLUS', 'MINUS', 'MUL', 'DIV', 'DOT', 'COMMA',
                      'LPAREN', 'RPAREN', 'LBRACE', 'RBRACE', 'SEMICOLON',
                      'PLUSEQ', 'MINUSEQ', 'MULEQ', 'DIVEQ'}
        for line in self.lines:
            if line.startswith("#line"):
                m = re.match(r'#line\s+(\d+)\s+"([^"]+)"', line)
                if m:
                    self.current_line_number = int(m.group(1))
                    self.current_filename = m.group(2)
                    if self.current_filename not in file_cache:
                        # Best effort: an unreadable file only means later
                        # diagnostics for it come without a source excerpt.
                        try:
                            with open(self.current_filename, "r", encoding="utf-8-sig") as f:
                                file_cache[self.current_filename] = f.readlines()
                        except (OSError, ValueError):
                            file_cache[self.current_filename] = []
                    # A directive is not a source line: don't advance the counter.
                    continue
            pos = 0
            line_len = len(line)
            while pos < line_len:
                mo = self.tok_regex.match(line, pos)
                if not mo:
                    print_error("Lexer error", self.current_filename, self.current_line_number, pos+1)
                kind = mo.lastgroup
                value = mo.group()
                column = pos + 1
                if kind == 'SKIP':
                    pass
                elif kind == 'NUMBER':
                    self.tokens.append(Token('NUMBER', int(value), self.current_line_number, column, self.current_filename))
                elif kind == 'FLOAT':
                    self.tokens.append(Token('FLOAT', float(value), self.current_line_number, column, self.current_filename))
                elif kind == 'ID' or kind in text_kinds:
                    self.tokens.append(Token(kind, value, self.current_line_number, column, self.current_filename))
                elif kind == 'MISMATCH':
                    print_error(f"Unexpected character: {value}", self.current_filename, self.current_line_number, column)
                pos = mo.end()
            self.current_line_number += 1
        return self.tokens


# ------------------------
# AST Node Classes
# ------------------------
class Program:
    """Root of the AST: the top-level declarations in source order."""

    def __init__(self, decls):
        self.decls = decls  # StructDecl / FunctionDecl nodes


class StructDecl:
    """A ``struct Name { type a, b; ... };`` declaration."""

    def __init__(self, name, fields):
        # fields: (type, name) pairs in declaration order
        self.name, self.fields = name, fields


class FunctionDecl:
    """A function definition: return type, name, parameters and body."""

    def __init__(self, ret_type, name, params, body):
        self.ret_type, self.name = ret_type, name
        # params: (type, name) pairs; body: the CompoundStmt of the function
        self.params, self.body = params, body


class VarDecl:
    """A local variable declaration with an optional initializer."""

    def __init__(self, var_type, name, init):
        # init: an expression node, a StructInit, or None when absent
        self.var_type, self.name, self.init = var_type, name, init

class IfStmt:
    """``if (cond) stmt`` -- the grammar has no ``else`` branch."""

    def __init__(self, cond, then_stmt):
        self.cond, self.then_stmt = cond, then_stmt


class ReturnStmt:
    """``return expr;``"""

    def __init__(self, expr):
        self.expr = expr


class ExprStmt:
    """An expression evaluated for its effect, terminated by ``;``."""

    def __init__(self, expr):
        self.expr = expr


class CompoundStmt:
    """A ``{ ... }`` block holding statements in source order."""

    def __init__(self, stmts):
        self.stmts = stmts

class BinaryOp:
    """``left op right`` where op is one of '+', '-', '*', '/', '=='."""

    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right


class Assignment:
    """``target op expr`` with op in '=', '+=', '-=', '*=', '/='.

    The parser only builds this with a VarRef target.
    """

    def __init__(self, target, op, expr):
        self.target, self.op, self.expr = target, op, expr


class NumberLiteral:
    """An integer constant."""

    def __init__(self, value):
        self.value = value


class FloatLiteral:
    """A floating-point constant."""

    def __init__(self, value):
        self.value = value


class VarRef:
    """A reference to a variable by name."""

    def __init__(self, name):
        self.name = name


class CallExpr:
    """A call ``func_name(args...)``."""

    def __init__(self, func_name, args):
        self.func_name, self.args = func_name, args


class FieldAccess:
    """``expr.field`` struct member access."""

    def __init__(self, expr, field):
        self.expr, self.field = expr, field


class StructInit:
    """A brace initializer ``{a, b, ...}`` for a variable of type ``type_name``."""

    def __init__(self, type_name, elements):
        self.type_name, self.elements = type_name, elements

# ------------------------
# Recursive-descent parser with source-located error reporting
# ------------------------
class Parser:
    """Recursive-descent parser that turns a token list into a ``Program`` AST.

    Syntax errors are reported through ``print_error`` (which exits). They
    point at the offending token, or at the last token when input ends early.
    """

    # Type names that always start a variable declaration statement. Names of
    # structs declared earlier in the input are accepted too (parse_statement).
    BUILTIN_TYPES = ('int', 'float', 'Vec3')

    # Token kinds that make an expression an assignment.
    ASSIGN_OPS = ('ASSIGN', 'PLUSEQ', 'MINUSEQ', 'MULEQ', 'DIVEQ')

    def __init__(self, tokens, filename="<stdin>"):
        self.tokens = tokens
        self.pos = 0
        self.filename = filename
        self.struct_names = set()  # names from struct declarations parsed so far

    def current(self):
        """Return the token at the cursor, or None at end of input."""
        if self.pos < len(self.tokens):
            return self.tokens[self.pos]
        return None

    def _peek_type(self, offset=0):
        """Return the type of the token ``offset`` places ahead, or None past the end."""
        idx = self.pos + offset
        if idx < len(self.tokens):
            return self.tokens[idx].type
        return None

    def _eof_error(self, expected):
        """Report premature end of input, located at the last token if there is one."""
        message = f"Unexpected end of file, expected {expected}"
        if self.tokens:
            last = self.tokens[-1]
            self.error(message, last.filename, last.line, last.column)
        else:
            self.error(message, self.filename, -1, -1)

    def consume(self, typ, value=None):
        """Consume and return the current token.

        The token must have type ``typ`` (and value ``value`` when given);
        otherwise a syntax error is reported.
        """
        token = self.current()
        if token is None:
            self._eof_error(f"'{value}'" if value is not None else typ)
        if token.type != typ or (value is not None and token.value != value):
            expected = f"'{value}'" if value is not None else typ
            self.error(f"expected token {expected} but got '{token.value}'", token.filename, token.line, token.column)
        self.pos += 1
        return token

    def error(self, message, filename, line, column):
        print_error(message, filename, line, column)

    def parse(self):
        """Parse the whole token stream into a Program of struct/function declarations."""
        decls = []
        while self.current() is not None:
            if self.current().value == 'struct':
                decls.append(self.parse_struct_decl())
            else:
                decls.append(self.parse_function_decl())
        return Program(decls)

    def parse_struct_decl(self):
        """``struct Name { type a, b; ... };`` optionally followed by ``typedef X;``."""
        self.consume('ID', 'struct')
        name = self.consume('ID').value
        self.consume('LBRACE')
        fields = []
        # Stopping at EOF too lets consume('RBRACE') report the missing brace.
        while self._peek_type() not in ('RBRACE', None):
            field_type = self.consume('ID').value
            field_names = [self.consume('ID').value]
            while self._peek_type() == 'COMMA':
                self.consume('COMMA')
                field_names.append(self.consume('ID').value)
            self.consume('SEMICOLON')
            fields.extend((field_type, fname) for fname in field_names)
        self.consume('RBRACE')
        self.consume('SEMICOLON')
        nxt = self.current()
        if nxt is not None and nxt.value == 'typedef':
            self.consume('ID', 'typedef')
            self.consume('ID')
            self.consume('SEMICOLON')
        self.struct_names.add(name)
        return StructDecl(name, fields)

    def parse_function_decl(self):
        """``ret_type name(type p, ...) { ... }``."""
        ret_type = self.consume('ID').value
        name = self.consume('ID').value
        self.consume('LPAREN')
        params = []
        if self._peek_type() != 'RPAREN':
            while True:
                p_type = self.consume('ID').value
                p_name = self.consume('ID').value
                params.append((p_type, p_name))
                if self._peek_type() == 'COMMA':
                    self.consume('COMMA')
                else:
                    break
        self.consume('RPAREN')
        body = self.parse_compound_stmt()
        return FunctionDecl(ret_type, name, params, body)

    def parse_compound_stmt(self):
        """``{ stmt* }``."""
        self.consume('LBRACE')
        stmts = []
        # Stopping at EOF too lets consume('RBRACE') report the missing brace.
        while self._peek_type() not in ('RBRACE', None):
            stmts.append(self.parse_statement())
        self.consume('RBRACE')
        return CompoundStmt(stmts)

    def parse_statement(self):
        """Dispatch on the leading token: return, if, declaration, or expression."""
        token = self.current()
        if token is None:
            self._eof_error("statement")
        if token.value == 'return':
            return self.parse_return_stmt()
        if token.value == 'if':
            return self.parse_if_stmt()
        if token.type == 'ID':
            is_builtin = token.value in self.BUILTIN_TYPES
            # A declared struct name starts a declaration only when a variable
            # name follows, so a same-named call or variable still parses.
            is_struct = token.value in self.struct_names and self._peek_type(1) == 'ID'
            if is_builtin or is_struct:
                return self.parse_var_decl()
        expr = self.parse_expression()
        self.consume('SEMICOLON')
        return ExprStmt(expr)

    def parse_return_stmt(self):
        self.consume('ID', 'return')
        expr = self.parse_expression()
        self.consume('SEMICOLON')
        return ReturnStmt(expr)

    def parse_if_stmt(self):
        """``if (cond) stmt`` (no else branch)."""
        self.consume('ID', 'if')
        self.consume('LPAREN')
        cond = self.parse_expression()
        self.consume('RPAREN')
        stmt = self.parse_statement()
        return IfStmt(cond, stmt)

    def parse_var_decl(self):
        """``type name [= expr | = {e, ...}];``."""
        var_type = self.consume('ID').value
        name = self.consume('ID').value
        init = None
        if self._peek_type() == 'ASSIGN':
            self.consume('ASSIGN')
            if self._peek_type() == 'LBRACE':
                init = self.parse_struct_initializer(var_type)
            else:
                init = self.parse_expression()
        self.consume('SEMICOLON')
        return VarDecl(var_type, name, init)

    def parse_struct_initializer(self, type_name):
        """``{expr, expr, ...}`` (at least one element)."""
        self.consume('LBRACE')
        elements = [self.parse_expression()]
        while self._peek_type() == 'COMMA':
            self.consume('COMMA')
            elements.append(self.parse_expression())
        self.consume('RBRACE')
        return StructInit(type_name, elements)

    # Support assignments (including compound assignments)
    def parse_assignment(self):
        """Right-associative assignment; only a plain variable may be assigned to."""
        left = self.parse_equality()
        if self._peek_type() in self.ASSIGN_OPS:
            op_token = self.consume(self._peek_type())
            right = self.parse_assignment()
            if not isinstance(left, VarRef):
                self.error("Invalid lvalue in assignment", op_token.filename, op_token.line, op_token.column)
            return Assignment(left, op_token.value, right)
        return left

    def parse_expression(self):
        return self.parse_assignment()

    def parse_equality(self):
        node = self.parse_additive()
        while self._peek_type() == 'EQEQ':
            op = self.consume('EQEQ').value
            node = BinaryOp(op, node, self.parse_additive())
        return node

    def parse_additive(self):
        node = self.parse_term()
        while self._peek_type() in ('PLUS', 'MINUS'):
            op = self.consume(self._peek_type()).value
            node = BinaryOp(op, node, self.parse_term())
        return node

    def parse_term(self):
        node = self.parse_factor()
        while self._peek_type() in ('MUL', 'DIV'):
            op = self.consume(self._peek_type()).value
            node = BinaryOp(op, node, self.parse_factor())
        return node

    def _parse_call_args(self):
        """Parse ``( [expr {, expr}] )`` and return the argument expressions."""
        self.consume('LPAREN')
        args = []
        if self._peek_type() != 'RPAREN':
            args.append(self.parse_expression())
            while self._peek_type() == 'COMMA':
                self.consume('COMMA')
                args.append(self.parse_expression())
        self.consume('RPAREN')
        return args

    def parse_factor(self):
        """Literal, parenthesized expression, call, or variable with optional ``.field`` chain."""
        token = self.current()
        if token is None:
            self._eof_error("expression")
            return None
        if token.type == 'NUMBER':
            self.consume('NUMBER')
            return NumberLiteral(token.value)
        if token.type == 'FLOAT':
            self.consume('FLOAT')
            return FloatLiteral(token.value)
        if token.type == 'LPAREN':
            self.consume('LPAREN')
            expr = self.parse_expression()
            self.consume('RPAREN')
            return expr
        if token.type == 'ID':
            name = self.consume('ID').value
            if self._peek_type() == 'LPAREN':
                return CallExpr(name, self._parse_call_args())
            node = VarRef(name)
            while self._peek_type() == 'DOT':
                self.consume('DOT')
                node = FieldAccess(node, self.consume('ID').value)
            return node
        self.error(f"Unexpected token '{token.value}'", token.filename, token.line, token.column)

# ------------------------
# Intermediate Code Generation (using integer opcodes)
# ------------------------
# Opcode numbers, assigned consecutively from 0. Each emitted instruction is a
# tuple whose first element is one of these values.
(
    PUSH_INT,      # 0: push integer constant
    PUSH_FLOAT,    # 1: push float constant
    LOAD_LOCAL,    # 2: load local variable (by index)
    STORE_LOCAL,   # 3: store local variable (by index)
    ADD_OP,        # 4
    SUB_OP,        # 5
    MUL_OP,        # 6
    DIV_OP,        # 7
    EQ_OP,         # 8
    JMP_IF_FALSE,  # 9
    JMP,           # 10
    CALL_OP,       # 11: CALL_OP, function_index, num_args
    RET_OP,        # 12
    LOAD_FIELD,    # 13: LOAD_FIELD, field_index
    MAKE_STRUCT,   # 14: MAKE_STRUCT, num_fields
) = range(15)

class CodeGenerator:
    """Lower a ``Program`` AST into per-function lists of opcode tuples.

    ``generate()`` returns ``func_table``, which maps each function name to a
    dict with keys ``'code'`` (list of opcode tuples), ``'n_locals'`` (number
    of local slots, parameters first) and ``'params'`` (parameter names).
    ``CALL_OP`` identifies its callee by position in ``func_table``'s
    insertion order.

    Raises RuntimeError for semantic errors (undeclared variable, unknown
    function/struct/field, unsupported operator or node).
    """

    def __init__(self, program):
        self.program = program
        self.func_table = {}          # function name -> function object (dict)
        self.struct_table = {}        # struct name -> list of field names
        self.struct_field_types = {}  # struct name -> list of field types (parallel to struct_table)
        self.func_indices = {}        # function name -> CALL_OP index, fixed before any codegen
        self.local_vars = {}          # variable name -> local index in current function
        self.local_types = {}         # variable name -> declared type in current function
        self.next_local = 0

    def generate(self):
        """Compile every function in the program and return ``func_table``."""
        funcs = []
        for decl in self.program.decls:
            if isinstance(decl, StructDecl):
                self.struct_table[decl.name] = [fname for (_, fname) in decl.fields]
                self.struct_field_types[decl.name] = [ftype for (ftype, _) in decl.fields]
            elif isinstance(decl, FunctionDecl):
                funcs.append(decl)
        # Assign call indices up front, in the same order func_table will get
        # its keys (first definition of each name), so recursive calls and
        # calls to functions defined later resolve. Previously the index was
        # searched in the partially built func_table, and such calls failed
        # with "Function not found".
        self.func_indices = {name: i for i, name in enumerate(dict.fromkeys(d.name for d in funcs))}
        for decl in funcs:
            self.generate_function(decl)
        return self.func_table

    def new_local(self, name, var_type=None):
        """Allocate the next local slot for ``name`` (optionally recording its type)."""
        idx = self.next_local
        self.local_vars[name] = idx
        self.local_types[name] = var_type
        self.next_local += 1
        return idx

    def get_local(self, name):
        """Return the slot of local ``name``; raise RuntimeError if undeclared."""
        if name not in self.local_vars:
            raise RuntimeError(f"Variable '{name}' not declared")
        return self.local_vars[name]

    def _function_index(self, name):
        """Return the CALL_OP index of function ``name``; raise RuntimeError if unknown."""
        if name in self.func_indices:
            return self.func_indices[name]
        # Fallback for callers that drive generate_function() directly.
        for i, fname in enumerate(self.func_table):
            if fname == name:
                return i
        raise RuntimeError("Function not found: " + name)

    def _struct_type_of(self, node):
        """Best-effort static struct type of ``node``.

        Returns ``'Vec3'``, the layout this generator always used before,
        when the type cannot be determined or is not a declared struct.
        """
        type_name = None
        if isinstance(node, VarRef):
            type_name = self.local_types.get(node.name)
        elif isinstance(node, FieldAccess):
            outer = self._struct_type_of(node.expr)
            fields = self.struct_table.get(outer)
            types = self.struct_field_types.get(outer)
            if fields is not None and types is not None and node.field in fields:
                type_name = types[fields.index(node.field)]
        return type_name if type_name in self.struct_table else 'Vec3'

    def generate_function(self, func_decl):
        """Compile one FunctionDecl and store it in ``func_table``."""
        self.local_vars = {}
        self.local_types = {}
        self.next_local = 0
        for ptype, pname in func_decl.params:
            self.new_local(pname, ptype)
        code = []
        self.gen_statement(func_decl.body, code)
        # Implicit return for a body that falls off the end.
        code.append((RET_OP,))
        self.func_table[func_decl.name] = {
            'code': code,
            'n_locals': self.next_local,
            'params': [pname for (_, pname) in func_decl.params],
        }

    def gen_statement(self, stmt, code):
        """Append the code for one statement node to ``code``."""
        if isinstance(stmt, CompoundStmt):
            for s in stmt.stmts:
                self.gen_statement(s, code)
        elif isinstance(stmt, ReturnStmt):
            self.gen_expr(stmt.expr, code)
            code.append((RET_OP,))
        elif isinstance(stmt, IfStmt):
            self.gen_expr(stmt.cond, code)
            jmp_false_index = len(code)
            code.append((JMP_IF_FALSE, None))  # patched once the branch length is known
            self.gen_statement(stmt.then_stmt, code)
            # Offset = number of instructions in the then-branch, i.e. relative
            # to the instruction after the jump (presumably how the VM applies it).
            code[jmp_false_index] = (JMP_IF_FALSE, len(code) - jmp_false_index - 1)
        elif isinstance(stmt, VarDecl):
            # The slot is allocated before the initializer is compiled, so the
            # initializer already sees the new name (as in C's `int x = x;`).
            idx = self.new_local(stmt.name, stmt.var_type)
            if stmt.init is not None:
                self.gen_expr(stmt.init, code)
            else:
                code.append((PUSH_INT, 0))  # locals without an initializer start as 0
            code.append((STORE_LOCAL, idx))
        elif isinstance(stmt, ExprStmt):
            # NOTE(review): the opcode set has no POP, so the expression's value
            # appears to stay on the operand stack; confirm the VM tolerates it.
            self.gen_expr(stmt.expr, code)
        else:
            raise RuntimeError("Unknown statement type in code generation")

    def gen_expr(self, expr, code):
        """Append code that pushes the value of ``expr`` to ``code``."""
        if isinstance(expr, NumberLiteral):
            code.append((PUSH_INT, expr.value))
        elif isinstance(expr, FloatLiteral):
            code.append((PUSH_FLOAT, expr.value))
        elif isinstance(expr, VarRef):
            code.append((LOAD_LOCAL, self.get_local(expr.name)))
        elif isinstance(expr, Assignment):
            local_index = self.get_local(expr.target.name)
            if expr.op == '=':
                self.gen_expr(expr.expr, code)
            else:
                opcode = {'+=': ADD_OP, '-=': SUB_OP, '*=': MUL_OP, '/=': DIV_OP}.get(expr.op)
                if opcode is None:
                    raise RuntimeError("Unsupported assignment operator: " + expr.op)
                code.append((LOAD_LOCAL, local_index))
                self.gen_expr(expr.expr, code)
                code.append((opcode,))
            # Store, then reload so the assignment itself yields a value (a = b = c).
            code.append((STORE_LOCAL, local_index))
            code.append((LOAD_LOCAL, local_index))
        elif isinstance(expr, BinaryOp):
            self.gen_expr(expr.left, code)
            self.gen_expr(expr.right, code)
            opcode = {'+': ADD_OP, '-': SUB_OP, '*': MUL_OP, '/': DIV_OP, '==': EQ_OP}.get(expr.op)
            if opcode is None:
                raise RuntimeError("Unsupported binary operator: " + expr.op)
            code.append((opcode,))
        elif isinstance(expr, CallExpr):
            for arg in expr.args:
                self.gen_expr(arg, code)
            code.append((CALL_OP, self._function_index(expr.func_name), len(expr.args)))
        elif isinstance(expr, FieldAccess):
            self.gen_expr(expr.expr, code)
            struct_name = self._struct_type_of(expr.expr)
            fields = self.struct_table.get(struct_name)
            if fields is None:
                raise RuntimeError(f"Struct {struct_name} not defined")
            try:
                field_index = fields.index(expr.field)
            except ValueError:
                raise RuntimeError("Field not found: " + expr.field) from None
            code.append((LOAD_FIELD, field_index))
        elif isinstance(expr, StructInit):
            for element in expr.elements:
                self.gen_expr(element, code)
            code.append((MAKE_STRUCT, len(expr.elements)))
        else:
            raise RuntimeError("Unsupported expression type in code generation")

# ------------------------
# Main: Read source, compile, and save bytecode
# ------------------------
if __name__ == '__main__':
    # Usage: <script> [source_file [bytecode_file]]
    # Defaults are unchanged: main.c -> program.bc.
    source_file = sys.argv[1] if len(sys.argv) >= 2 else "main.c"
    bytecode_file = sys.argv[2] if len(sys.argv) >= 3 else "program.bc"
    source_code = preprocess(source_file)
    tokens = Lexer(source_code, source_file).tokenize()
    program = Parser(tokens, source_file).parse()
    try:
        func_table = CodeGenerator(program).generate()
    except RuntimeError as exc:
        # Semantic errors (undeclared variable, unknown function/field, ...)
        # carry no source position; report them as a diagnostic instead of
        # letting a Python traceback escape.
        sys.exit(f"{RED}error: {exc}{RESET}")
    # NOTE: the bytecode file is a pickle. Whatever loads it should only read
    # files produced by this compiler, since unpickling can execute code.
    with open(bytecode_file, "wb") as f:
        pickle.dump(func_table, f)
    print(f"{YELLOW}Compilation successful! Bytecode saved to {bytecode_file}.{RESET}")