diff --git a/compialer.py b/compialer.py index e69de29..0a36e8c 100644 --- a/compialer.py +++ b/compialer.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python3 +import re +import sys +import os +import pickle + +# Global file cache for error reporting: filename -> list of source lines. +file_cache = {} + +# ANSI color codes +RED = "\033[91m" +YELLOW = "\033[93m" +CYAN = "\033[96m" +RESET = "\033[0m" + +def print_error(message, filename, line, col): + err_msg = f"{RED}{filename}:{line}:{col}: error: {message}{RESET}" + if filename in file_cache and 1 <= line <= len(file_cache[filename]): + src_line = file_cache[filename][line - 1].rstrip("\n") + err_msg += f"\n{CYAN}{src_line}{RESET}\n{' '*(col-1)}{RED}^{RESET}" + sys.exit(err_msg) + +# ------------------------ +# Preprocessor (handles #include and emits #line directives) +# ------------------------ +def preprocess(filename, included_files=None): + if included_files is None: + included_files = set() + if filename in included_files: + return "" + if not os.path.exists(filename): + print_error(f"File '{filename}' not found", filename, 0, 0) + included_files.add(filename) + with open(filename, "r", encoding="utf-8-sig") as f: + lines = f.readlines() + file_cache[filename] = lines + result = f'#line 1 "{filename}"\n' + for idx, line in enumerate(lines, 1): + stripped = line.strip() + if stripped.startswith("#include"): + m = re.match(r'#include\s+"([^"]+)"', stripped) + if m: + inc_filename = m.group(1) + base_dir = os.path.dirname(filename) + full_path = os.path.join(base_dir, inc_filename) + result += f'#line 1 "{full_path}"\n' + result += preprocess(full_path, included_files) + result += f'#line {idx+1} "{filename}"\n' + else: + print_error("Malformed #include directive", filename, idx, 1) + else: + result += line + return result +# ------------------------ +# Lexer (supports compound assignment operators) +# ------------------------ +class Token: + def __init__(self, typ, value, line, column, filename): + self.type = typ + self.value = value + self.line = line + self.column = column + self.filename = filename + def __repr__(self): + return f"Token({self.type}, {self.value}, {self.filename}:{self.line}:{self.column})" + +class Lexer: + def __init__(self, text, filename=""): + self.lines = text.splitlines(keepends=True) + self.default_filename = filename + self.tokens = [] + self.current_filename = filename + self.current_line_number = 1 + self.token_specification = [ + # Compound assignments first: + ('PLUSEQ', r'\+='), + ('MINUSEQ', r'-='), + ('MULEQ', r'\*='), + ('DIVEQ', r'/='), + ('FLOAT', r'\d+\.\d+'), + ('NUMBER', r'\d+'), + ('ID', r'[A-Za-z_][A-Za-z0-9_]*'), + ('EQEQ', r'=='), + ('ASSIGN', r'='), + ('PLUS', r'\+'), + ('MINUS', r'-'), + ('MUL', r'\*'), + ('DIV', r'/'), + ('DOT', r'\.'), + ('COMMA', r','), + ('LPAREN', r'\('), + ('RPAREN', r'\)'), + ('LBRACE', r'\{'), + ('RBRACE', r'\}'), + ('SEMICOLON',r';'), + # Updated SKIP pattern includes a wider range of whitespace characters. + ('SKIP', r'[ \t\r\f\v\u00A0\u2000-\u200B\u202F\u205F\u3000]+'), + ('MISMATCH', r'.'), + ] + self.tok_regex = re.compile('|'.join('(?P<%s>%s)' % pair for pair in self.token_specification)) + + def tokenize(self): + for line in self.lines: + if line.startswith("#line"): + m = re.match(r'#line\s+(\d+)\s+"([^"]+)"', line) + if m: + self.current_line_number = int(m.group(1)) + self.current_filename = m.group(2) + if self.current_filename not in file_cache: + try: + with open(self.current_filename, "r", encoding="utf-8-sig") as f: + file_cache[self.current_filename] = f.readlines() + except Exception: + file_cache[self.current_filename] = [] + continue + pos = 0 + line_len = len(line) + while pos < line_len: + mo = self.tok_regex.match(line, pos) + if not mo: + print_error("Lexer error", self.current_filename, self.current_line_number, pos+1) + kind = mo.lastgroup + value = mo.group() + column = pos + 1 + if kind == 'SKIP': + pass + elif kind == 'NUMBER': + tok = Token('NUMBER', int(value), self.current_line_number, column, self.current_filename) + self.tokens.append(tok) + elif kind == 'FLOAT': + tok = Token('FLOAT', float(value), self.current_line_number, column, self.current_filename) + self.tokens.append(tok) + elif kind == 'ID': + tok = Token('ID', value, self.current_line_number, column, self.current_filename) + self.tokens.append(tok) + elif kind in ('EQEQ','ASSIGN','PLUS','MINUS','MUL','DIV','DOT','COMMA','LPAREN','RPAREN','LBRACE','RBRACE','SEMICOLON', + 'PLUSEQ','MINUSEQ','MULEQ','DIVEQ'): + tok = Token(kind, value, self.current_line_number, column, self.current_filename) + self.tokens.append(tok) + elif kind == 'MISMATCH': + print_error(f"Unexpected character: {value}", self.current_filename, self.current_line_number, column) + pos = mo.end() + self.current_line_number += 1 + return self.tokens + + +# ------------------------ +# AST Node Classes +# ------------------------ +class Program: + def __init__(self, decls): + self.decls = decls + +class StructDecl: + def __init__(self, name, fields): + self.name = name + self.fields = fields # list of (type, name) + +class FunctionDecl: + def __init__(self, ret_type, name, params, body): + self.ret_type = ret_type + self.name = name + self.params = params # list of (type, name) + self.body = body # CompoundStmt + +class VarDecl: + def __init__(self, var_type, name, init): + self.var_type = var_type + self.name = name + self.init = init # expression or None + +class IfStmt: + def __init__(self, cond, then_stmt): + self.cond = cond + self.then_stmt = then_stmt + +class ReturnStmt: + def __init__(self, expr): + self.expr = expr + +class ExprStmt: + def __init__(self, expr): + self.expr = expr + +class CompoundStmt: + def __init__(self, stmts): + self.stmts = stmts + +class BinaryOp: + def __init__(self, op, left, right): + self.op = op # '+', '-', '*', '/', '==' + self.left = left + self.right = right + +class Assignment: + def __init__(self, target, op, expr): + self.target = target # VarRef + self.op = op # '=', '+=', '-=', '*=', '/=' + self.expr = expr + +class NumberLiteral: + def __init__(self, value): + self.value = value + +class FloatLiteral: + def __init__(self, value): + self.value = value + +class VarRef: + def __init__(self, name): + self.name = name + +class CallExpr: + def __init__(self, func_name, args): + self.func_name = func_name + self.args = args + +class FieldAccess: + def __init__(self, expr, field): + self.expr = expr + self.field = field + +class StructInit: + def __init__(self, type_name, elements): + self.type_name = type_name + self.elements = elements + +# ------------------------ +# Parser with Extended Grammar and Cool Error Reporting +# ------------------------ +class Parser: + def __init__(self, tokens, filename=""): + self.tokens = tokens + self.pos = 0 + self.filename = filename + + def current(self): + if self.pos < len(self.tokens): + return self.tokens[self.pos] + return None + + def consume(self, typ, value=None): + token = self.current() + if token is None: + self.error(f"Unexpected end of file, expected {typ}", self.filename, -1, -1) + if token.type != typ or (value is not None and token.value != value): + expected = f"'{value}'" if value else typ + self.error(f"expected token {expected} but got '{token.value}'", token.filename, token.line, token.column) + self.pos += 1 + return token + + def error(self, message, filename, line, column): + print_error(message, filename, line, column) + + def parse(self): + decls = [] + while self.current() is not None: + if self.current().value == 'struct': + decls.append(self.parse_struct_decl()) + else: + decls.append(self.parse_function_decl()) + return Program(decls) + + def parse_struct_decl(self): + self.consume('ID', 'struct') + name = self.consume('ID').value + self.consume('LBRACE') + fields = [] + while self.current().type != 'RBRACE': + field_type = self.consume('ID').value + field_names = [self.consume('ID').value] + while self.current().type == 'COMMA': + self.consume('COMMA') + field_names.append(self.consume('ID').value) + self.consume('SEMICOLON') + for fname in field_names: + fields.append((field_type, fname)) + self.consume('RBRACE') + self.consume('SEMICOLON') + if self.current() and self.current().value == 'typedef': + self.consume('ID', 'typedef') + self.consume('ID') + self.consume('SEMICOLON') + return StructDecl(name, fields) + + def parse_function_decl(self): + ret_type = self.consume('ID').value + name = self.consume('ID').value + self.consume('LPAREN') + params = [] + if self.current().type != 'RPAREN': + while True: + p_type = self.consume('ID').value + p_name = self.consume('ID').value + params.append((p_type, p_name)) + if self.current().type == 'COMMA': + self.consume('COMMA') + else: + break + self.consume('RPAREN') + body = self.parse_compound_stmt() + return FunctionDecl(ret_type, name, params, body) + + def parse_compound_stmt(self): + self.consume('LBRACE') + stmts = [] + while self.current().type != 'RBRACE': + stmts.append(self.parse_statement()) + self.consume('RBRACE') + return CompoundStmt(stmts) + + def parse_statement(self): + token = self.current() + if token.value == 'return': + return self.parse_return_stmt() + elif token.value == 'if': + return self.parse_if_stmt() + elif token.type == 'ID': + if token.value in ('int','float','Vec3'): + return self.parse_var_decl() + else: + expr = self.parse_expression() + self.consume('SEMICOLON') + return ExprStmt(expr) + else: + expr = self.parse_expression() + self.consume('SEMICOLON') + return ExprStmt(expr) + + def parse_return_stmt(self): + self.consume('ID', 'return') + expr = self.parse_expression() + self.consume('SEMICOLON') + return ReturnStmt(expr) + + def parse_if_stmt(self): + self.consume('ID', 'if') + self.consume('LPAREN') + cond = self.parse_expression() + self.consume('RPAREN') + stmt = self.parse_statement() + return IfStmt(cond, stmt) + + def parse_var_decl(self): + var_type = self.consume('ID').value + name = self.consume('ID').value + init = None + if self.current().type == 'ASSIGN': + self.consume('ASSIGN') + if self.current().type == 'LBRACE': + init = self.parse_struct_initializer(var_type) + else: + init = self.parse_expression() + self.consume('SEMICOLON') + return VarDecl(var_type, name, init) + + def parse_struct_initializer(self, type_name): + self.consume('LBRACE') + elements = [self.parse_expression()] + while self.current().type == 'COMMA': + self.consume('COMMA') + elements.append(self.parse_expression()) + self.consume('RBRACE') + return StructInit(type_name, elements) + + # Support assignments (including compound assignments) + def parse_assignment(self): + left = self.parse_equality() + if self.current() and self.current().type in ('ASSIGN','PLUSEQ','MINUSEQ','MULEQ','DIVEQ'): + op_token = self.consume(self.current().type) + op = op_token.value + right = self.parse_assignment() + if not isinstance(left, VarRef): + self.error("Invalid lvalue in assignment", op_token.filename, op_token.line, op_token.column) + return Assignment(left, op, right) + return left + + def parse_expression(self): + return self.parse_assignment() + + def parse_equality(self): + node = self.parse_additive() + while self.current() and self.current().type == 'EQEQ': + op = self.consume('EQEQ').value + right = self.parse_additive() + node = BinaryOp(op, node, right) + return node + + def parse_additive(self): + node = self.parse_term() + while self.current() and self.current().type in ('PLUS','MINUS'): + op = self.consume(self.current().type).value + right = self.parse_term() + node = BinaryOp(op, node, right) + return node + + def parse_term(self): + node = self.parse_factor() + while self.current() and self.current().type in ('MUL','DIV'): + op = self.consume(self.current().type).value + right = self.parse_factor() + node = BinaryOp(op, node, right) + return node + + def parse_factor(self): + token = self.current() + if token.type == 'NUMBER': + self.consume('NUMBER') + return NumberLiteral(token.value) + elif token.type == 'FLOAT': + self.consume('FLOAT') + return FloatLiteral(token.value) + elif token.type == 'LPAREN': + self.consume('LPAREN') + expr = self.parse_expression() + self.consume('RPAREN') + return expr + elif token.type == 'ID': + id_token = self.consume('ID') + if self.current() and self.current().type == 'LPAREN': + self.consume('LPAREN') + args = [] + if self.current().type != 'RPAREN': + args.append(self.parse_expression()) + while self.current().type == 'COMMA': + self.consume('COMMA') + args.append(self.parse_expression()) + self.consume('RPAREN') + return CallExpr(id_token.value, args) + elif self.current() and self.current().type == 'DOT': + node = VarRef(id_token.value) + while self.current() and self.current().type == 'DOT': + self.consume('DOT') + field_name = self.consume('ID').value + node = FieldAccess(node, field_name) + return node + else: + return VarRef(id_token.value) + else: + self.error("Unexpected token", token.filename, token.line, token.column) + +# ------------------------ +# Intermediate Code Generation (using integer opcodes) +# ------------------------ +PUSH_INT = 0 # push integer constant +PUSH_FLOAT = 1 # push float constant +LOAD_LOCAL = 2 # load local variable (by index) +STORE_LOCAL = 3 # store local variable (by index) +ADD_OP = 4 +SUB_OP = 5 +MUL_OP = 6 +DIV_OP = 7 +EQ_OP = 8 +JMP_IF_FALSE= 9 +JMP = 10 +CALL_OP = 11 # CALL_OP, function_index, num_args +RET_OP = 12 +LOAD_FIELD = 13 # LOAD_FIELD, field_index +MAKE_STRUCT = 14 # MAKE_STRUCT, num_fields + +class CodeGenerator: + def __init__(self, program): + self.program = program + self.func_table = {} # function name -> function object (dict) + self.struct_table = {} # struct name -> list of field names + self.local_vars = {} # variable name -> local index in current function + self.next_local = 0 + + def generate(self): + for decl in self.program.decls: + if isinstance(decl, StructDecl): + self.struct_table[decl.name] = [fname for (_, fname) in decl.fields] + for decl in self.program.decls: + if isinstance(decl, FunctionDecl): + self.generate_function(decl) + return self.func_table + + def new_local(self, name): + idx = self.next_local + self.local_vars[name] = idx + self.next_local += 1 + return idx + + def get_local(self, name): + if name not in self.local_vars: + raise RuntimeError(f"Variable '{name}' not declared") + return self.local_vars[name] + + def generate_function(self, func_decl): + self.local_vars = {} + self.next_local = 0 + for (ptype, pname) in func_decl.params: + self.new_local(pname) + code = [] + self.gen_statement(func_decl.body, code) + code.append((RET_OP,)) + func_obj = {'code': code, 'n_locals': self.next_local, 'params': [p[1] for p in func_decl.params]} + self.func_table[func_decl.name] = func_obj + + def gen_statement(self, stmt, code): + if isinstance(stmt, CompoundStmt): + for s in stmt.stmts: + self.gen_statement(s, code) + elif isinstance(stmt, ReturnStmt): + self.gen_expr(stmt.expr, code) + code.append((RET_OP,)) + elif isinstance(stmt, IfStmt): + self.gen_expr(stmt.cond, code) + jmp_false_index = len(code) + code.append((JMP_IF_FALSE, None)) + self.gen_statement(stmt.then_stmt, code) + offset = len(code) - jmp_false_index - 1 + code[jmp_false_index] = (JMP_IF_FALSE, offset) + elif isinstance(stmt, VarDecl): + idx = self.new_local(stmt.name) + if stmt.init: + self.gen_expr(stmt.init, code) + else: + code.append((PUSH_INT, 0)) + code.append((STORE_LOCAL, idx)) + elif isinstance(stmt, ExprStmt): + self.gen_expr(stmt.expr, code) + else: + raise RuntimeError("Unknown statement type in code generation") + + def gen_expr(self, expr, code): + if isinstance(expr, NumberLiteral): + code.append((PUSH_INT, expr.value)) + elif isinstance(expr, FloatLiteral): + code.append((PUSH_FLOAT, expr.value)) + elif isinstance(expr, VarRef): + idx = self.get_local(expr.name) + code.append((LOAD_LOCAL, idx)) + elif isinstance(expr, Assignment): + var_name = expr.target.name + local_index = self.get_local(var_name) + if expr.op == '=': + self.gen_expr(expr.expr, code) + elif expr.op == '+=': + code.append((LOAD_LOCAL, local_index)) + self.gen_expr(expr.expr, code) + code.append((ADD_OP,)) + elif expr.op == '-=': + code.append((LOAD_LOCAL, local_index)) + self.gen_expr(expr.expr, code) + code.append((SUB_OP,)) + elif expr.op == '*=': + code.append((LOAD_LOCAL, local_index)) + self.gen_expr(expr.expr, code) + code.append((MUL_OP,)) + elif expr.op == '/=': + code.append((LOAD_LOCAL, local_index)) + self.gen_expr(expr.expr, code) + code.append((DIV_OP,)) + else: + raise RuntimeError("Unsupported assignment operator: " + expr.op) + code.append((STORE_LOCAL, local_index)) + code.append((LOAD_LOCAL, local_index)) + elif isinstance(expr, BinaryOp): + self.gen_expr(expr.left, code) + self.gen_expr(expr.right, code) + if expr.op == '+': + code.append((ADD_OP,)) + elif expr.op == '-': + code.append((SUB_OP,)) + elif expr.op == '*': + code.append((MUL_OP,)) + elif expr.op == '/': + code.append((DIV_OP,)) + elif expr.op == '==': + code.append((EQ_OP,)) + else: + raise RuntimeError("Unsupported binary operator: " + expr.op) + elif isinstance(expr, CallExpr): + for arg in expr.args: + self.gen_expr(arg, code) + func_index = None + for i, (fname, fobj) in enumerate(self.func_table.items()): + if fname == expr.func_name: + func_index = i + break + if func_index is None: + raise RuntimeError("Function not found: " + expr.func_name) + code.append((CALL_OP, func_index, len(expr.args))) + elif isinstance(expr, FieldAccess): + self.gen_expr(expr.expr, code) + fields = self.struct_table.get('Vec3') + if fields is None: + raise RuntimeError("Struct Vec3 not defined") + try: + field_index = fields.index(expr.field) + except ValueError: + raise RuntimeError("Field not found: " + expr.field) + code.append((LOAD_FIELD, field_index)) + elif isinstance(expr, StructInit): + for element in expr.elements: + self.gen_expr(element, code) + code.append((MAKE_STRUCT, len(expr.elements))) + else: + raise RuntimeError("Unsupported expression type in code generation") + +# ------------------------ +# Main: Read source, compile, and save bytecode +# ------------------------ +if __name__ == '__main__': + if len(sys.argv) < 2: + source_file = "main.c" + else: + source_file = sys.argv[1] + bytecode_file = "program.bc" + source_code = preprocess(source_file) + lexer = Lexer(source_code, source_file) + tokens = lexer.tokenize() + parser = Parser(tokens, source_file) + program = parser.parse() + codegen = CodeGenerator(program) + func_table = codegen.generate() + with open(bytecode_file, "wb") as f: + pickle.dump(func_table, f) + print(f"{YELLOW}Compilation successful! Bytecode saved to {bytecode_file}.{RESET}") diff --git a/main.c b/main.c new file mode 100644 index 0000000..46550ab --- /dev/null +++ b/main.c @@ -0,0 +1,9 @@ +#include "std.h" + + +int main(int argc, char arg) +{ + Vec3 v = {0, 1, 0}; + int result = absVelocity(v); + return result; +} \ No newline at end of file diff --git a/runtime.py b/runtime.py index e69de29..4da7815 100644 --- a/runtime.py +++ b/runtime.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +import sys +import os +import pickle + +# These opcode definitions must match those in compiler.py. +PUSH_INT = 0 +PUSH_FLOAT = 1 +LOAD_LOCAL = 2 +STORE_LOCAL = 3 +ADD_OP = 4 +SUB_OP = 5 +MUL_OP = 6 +DIV_OP = 7 +EQ_OP = 8 +JMP_IF_FALSE= 9 +JMP = 10 +CALL_OP = 11 +RET_OP = 12 +LOAD_FIELD = 13 +MAKE_STRUCT = 14 + +class Runtime: + def __init__(self, func_table): + # Preserve insertion order by converting to a list. + self.func_table = list(func_table.values()) + + def run_function(self, func_index, args): + func = self.func_table[func_index] + locals_ = [0] * func['n_locals'] + for i, arg in enumerate(args): + locals_[i] = arg + stack = [] + ip = 0 + code = func['code'] + while ip < len(code): + instr = code[ip] + op = instr[0] + if op == PUSH_INT: + stack.append(instr[1]) + elif op == PUSH_FLOAT: + stack.append(instr[1]) + elif op == LOAD_LOCAL: + stack.append(locals_[instr[1]]) + elif op == STORE_LOCAL: + locals_[instr[1]] = stack.pop() + elif op == ADD_OP: + b = stack.pop(); a = stack.pop() + stack.append(a + b) + elif op == SUB_OP: + b = stack.pop(); a = stack.pop() + stack.append(a - b) + elif op == MUL_OP: + b = stack.pop(); a = stack.pop() + stack.append(a * b) + elif op == DIV_OP: + b = stack.pop(); a = stack.pop() + stack.append(a // b) + elif op == EQ_OP: + b = stack.pop(); a = stack.pop() + stack.append(1 if a == b else 0) + elif op == JMP_IF_FALSE: + offset = instr[1] + cond = stack.pop() + if cond == 0: + ip += offset + continue + elif op == JMP: + ip += instr[1] + continue + elif op == CALL_OP: + callee_index = instr[1] + num_args = instr[2] + call_args = [stack.pop() for _ in range(num_args)][::-1] + ret_val = self.run_function(callee_index, call_args) + stack.append(ret_val) + elif op == RET_OP: + return stack.pop() if stack else 0 + elif op == LOAD_FIELD: + struct_val = stack.pop() + field_index = instr[1] + stack.append(struct_val[field_index]) + elif op == MAKE_STRUCT: + n = instr[1] + fields = [stack.pop() for _ in range(n)][::-1] + stack.append(tuple(fields)) + else: + raise RuntimeError("Unknown opcode: " + str(op)) + ip += 1 + return 0 + +if __name__ == '__main__': + bytecode_file = "program.bc" + if not os.path.exists(bytecode_file): + sys.exit("Error: Bytecode file not found. Please compile the source first.") + with open(bytecode_file, "rb") as f: + func_table = pickle.load(f) + main_index = None + for idx, (fname, fobj) in enumerate(func_table.items()): + if fname == 'main': + main_index = idx + break + if main_index is None: + sys.exit("Error: main function not found in bytecode.") + runtime = Runtime(func_table) + result = runtime.run_function(main_index, [0, 0]) + print("Program finished with result:", result) diff --git a/std.h b/std.h new file mode 100644 index 0000000..fa1d142 --- /dev/null +++ b/std.h @@ -0,0 +1,18 @@ +struct Vec3 +{ + float x; + float y; + float z; +}; +typedef Vec3; + +float absVelocity(Vec3 vector) +{ + float final = 0; + + final = vector.x + vector.x; + final = vector.y + vector.y; + final = vector.z + vector.z; + + return final; +} \ No newline at end of file