Interprited_C/python-c-compialer/compialer.py
2025-03-30 20:23:07 -05:00

621 lines
22 KiB
Python

#!/usr/bin/env python3
import re
import sys
import os
import pickle
# Global file cache for error reporting: filename -> list of source lines.
# Written by preprocess() for every file it reads and by Lexer.tokenize()
# for files named in #line directives; read by print_error() to echo the
# offending source line under a diagnostic.
file_cache = {}
# ANSI color codes
# RED wraps the diagnostic header and the caret, CYAN the echoed source
# line, YELLOW the final success message; RESET ends the coloring.
RED = "\033[91m"
YELLOW = "\033[93m"
CYAN = "\033[96m"
RESET = "\033[0m"
def print_error(message, filename, line, col):
    """Abort the process with a colorized ``file:line:col: error:`` diagnostic.

    When ``filename`` is present in ``file_cache`` and ``line`` (1-based) is
    within that file, the offending source line is echoed with a caret under
    column ``col`` (1-based).

    This function never returns: the message is handed to ``sys.exit``.
    """
    err_msg = f"{RED}{filename}:{line}:{col}: error: {message}{RESET}"
    cached = file_cache.get(filename)
    if cached is not None and 1 <= line <= len(cached):
        src_line = cached[line - 1].rstrip("\n")
        # Build the caret padding from the text that precedes the column:
        # tabs are copied verbatim so the caret lines up however wide the
        # terminal renders a tab; every other character becomes one space.
        width = max(col - 1, 0)
        lead = src_line[:width]
        pad = "".join("\t" if ch == "\t" else " " for ch in lead)
        pad += " " * (width - len(lead))  # column past the end of the line
        err_msg += f"\n{CYAN}{src_line}{RESET}\n{pad}{RED}^{RESET}"
    sys.exit(err_msg)
# ------------------------
# Preprocessor (handles #include and emits #line directives)
# ------------------------
def preprocess(filename, included_files=None):
    """Return the text of ``filename`` with ``#include "..."`` lines expanded.

    The result starts with a ``#line 1 "<filename>"`` directive, and every
    expanded include is bracketed by ``#line`` directives so later stages can
    map each output line back to its original file and line number.

    Args:
        filename: Path of the file to read.
        included_files: Set of paths already expanded; a path found in it
            yields ``""`` so each file is included at most once. Created
            automatically for the top-level call.

    Include paths are resolved relative to the directory of the including
    file. A missing file or a malformed ``#include`` is reported through
    ``print_error``, which terminates the process.
    """
    if included_files is None:
        included_files = set()
    if filename in included_files:
        return ""
    if not os.path.exists(filename):
        print_error(f"File '{filename}' not found", filename, 0, 0)
    included_files.add(filename)
    with open(filename, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()
    file_cache[filename] = lines

    parts = [f'#line 1 "{filename}"\n']
    base_dir = os.path.dirname(filename)
    for idx, line in enumerate(lines, 1):
        stripped = line.strip()
        if not stripped.startswith("#include"):
            parts.append(line)
            continue
        m = re.match(r'#include\s+"([^"]+)"', stripped)
        if not m:
            # print_error does not return.
            print_error("Malformed #include directive", filename, idx, 1)
        full_path = os.path.join(base_dir, m.group(1))
        if full_path not in included_files and not os.path.exists(full_path):
            # Point at the #include line rather than at the missing file.
            print_error(f"File '{full_path}' not found", filename, idx, 1)
        parts.append(f'#line 1 "{full_path}"\n')
        included = preprocess(full_path, included_files)
        parts.append(included)
        if included and not included.endswith("\n"):
            # Without this, the directive below would be glued onto the last
            # line of an include that lacks a trailing newline and would no
            # longer be recognised as a directive.
            parts.append("\n")
        parts.append(f'#line {idx+1} "{filename}"\n')
    return "".join(parts)
# ------------------------
# Lexer (supports compound assignment operators)
# ------------------------
class Token:
    """One lexical token plus the source position it was read from.

    ``type`` is the token kind name, ``value`` its payload, and
    ``filename``/``line``/``column`` locate it for diagnostics.
    """

    def __init__(self, typ, value, line, column, filename):
        self.type, self.value = typ, value
        self.line, self.column = line, column
        self.filename = filename

    def __repr__(self):
        # Debug form: Token(<type>, <value>, <file>:<line>:<column>)
        return f"Token({self.type}, {self.value}, {self.filename}:{self.line}:{self.column})"
class Lexer:
    """Turn preprocessed source text into a flat list of ``Token`` objects.

    ``#line N "file"`` directives (lines starting with ``#line``) are consumed
    here to set the filename and line number recorded on subsequent tokens;
    they produce no tokens themselves.
    """

    # Line splitter: breaks on \n, \r\n or \r only and keeps the terminator.
    # str.splitlines() would also break on form feed / vertical tab, which
    # makes the running line counter drift away from the lines stored in
    # file_cache (those come from file.readlines()).
    _LINE_RE = re.compile(r'[^\r\n]*(?:\r\n|\n|\r)|[^\r\n]+')

    def __init__(self, text, filename="<stdin>"):
        """Prepare to tokenize ``text``; ``filename`` is used until a
        ``#line`` directive names another file."""
        self.lines = self._LINE_RE.findall(text)
        self.default_filename = filename
        self.tokens = []
        self.current_filename = filename
        self.current_line_number = 1
        # Order matters: alternatives are tried left to right, so longer
        # operators and FLOAT must precede their prefixes.
        self.token_specification = [
            # Compound assignments first:
            ('PLUSEQ', r'\+='),
            ('MINUSEQ', r'-='),
            ('MULEQ', r'\*='),
            ('DIVEQ', r'/='),
            ('FLOAT', r'\d+\.\d+'),
            ('NUMBER', r'\d+'),
            ('ID', r'[A-Za-z_][A-Za-z0-9_]*'),
            ('EQEQ', r'=='),
            ('ASSIGN', r'='),
            ('PLUS', r'\+'),
            ('MINUS', r'-'),
            ('MUL', r'\*'),
            ('DIV', r'/'),
            ('DOT', r'\.'),
            ('COMMA', r','),
            ('LPAREN', r'\('),
            ('RPAREN', r'\)'),
            ('LBRACE', r'\{'),
            ('RBRACE', r'\}'),
            ('SEMICOLON', r';'),
            # Whitespace, including a range of Unicode spaces. "\n" must be
            # listed: lines keep their terminator and '.' below never matches
            # a newline, so without it every terminated line would fail.
            ('SKIP', r'[ \t\r\n\f\v\u00A0\u2000-\u200B\u202F\u205F\u3000]+'),
            ('MISMATCH', r'.'),
        ]
        self.tok_regex = re.compile('|'.join('(?P<%s>%s)' % pair for pair in self.token_specification))

    def tokenize(self):
        """Tokenize all lines and return ``self.tokens``.

        Tokens are appended to ``self.tokens``. An unexpected character is
        reported through ``print_error``, which terminates the process.
        """
        for line in self.lines:
            if line.startswith("#line"):
                self._apply_line_directive(line)
                continue
            self._tokenize_line(line)
            self.current_line_number += 1
        return self.tokens

    def _apply_line_directive(self, line):
        """Update the current file/line from a ``#line N "file"`` directive.

        A line that starts with ``#line`` but does not match that shape is
        skipped without effect.
        """
        m = re.match(r'#line\s+(\d+)\s+"([^"]+)"', line)
        if not m:
            return
        self.current_line_number = int(m.group(1))
        self.current_filename = m.group(2)
        if self.current_filename in file_cache:
            return
        # Best effort: the cache only feeds error messages, so an unreadable
        # file just means diagnostics are shown without a source excerpt.
        try:
            with open(self.current_filename, "r", encoding="utf-8-sig") as f:
                file_cache[self.current_filename] = f.readlines()
        except (OSError, ValueError):
            file_cache[self.current_filename] = []

    def _tokenize_line(self, line):
        """Append the tokens of one source line to ``self.tokens``."""
        pos = 0
        end = len(line)
        while pos < end:
            column = pos + 1
            mo = self.tok_regex.match(line, pos)
            if mo is None:
                print_error("Lexer error", self.current_filename, self.current_line_number, column)
            kind = mo.lastgroup
            text = mo.group()
            if kind == 'MISMATCH':
                print_error(f"Unexpected character: {text}", self.current_filename, self.current_line_number, column)
            elif kind != 'SKIP':
                if kind == 'NUMBER':
                    value = int(text)
                elif kind == 'FLOAT':
                    value = float(text)
                else:
                    value = text
                self.tokens.append(
                    Token(kind, value, self.current_line_number, column, self.current_filename)
                )
            pos = mo.end()
# ------------------------
# AST Node Classes
# ------------------------
class Program:
    """Root AST node: the top-level declarations in source order."""

    def __init__(self, decls):
        # StructDecl / FunctionDecl nodes as collected by Parser.parse().
        self.decls = decls
class StructDecl:
    """AST node for a ``struct`` definition."""

    def __init__(self, name, fields):
        # fields: list of (type, name)
        self.name, self.fields = name, fields
class FunctionDecl:
    """AST node for a function definition (signature plus body)."""

    def __init__(self, ret_type, name, params, body):
        self.ret_type, self.name = ret_type, name
        # params: list of (type, name); body: CompoundStmt
        self.params, self.body = params, body
class VarDecl:
    """AST node for a local variable declaration."""

    def __init__(self, var_type, name, init):
        self.var_type, self.name = var_type, name
        # init: initializer expression, or None when none was written
        self.init = init
class IfStmt:
    """AST node for an ``if`` statement: a condition and one guarded statement."""

    def __init__(self, cond, then_stmt):
        self.cond, self.then_stmt = cond, then_stmt
class ReturnStmt:
    """AST node for ``return <expr>;``."""

    def __init__(self, expr):
        # The expression whose value is returned.
        self.expr = expr
class ExprStmt:
    """AST node for an expression used as a statement (``<expr>;``)."""

    def __init__(self, expr):
        # The wrapped expression node.
        self.expr = expr
class CompoundStmt:
    """AST node for a brace-delimited block of statements."""

    def __init__(self, stmts):
        # Statements in source order.
        self.stmts = stmts
class BinaryOp:
    """AST node for a binary operation ``left <op> right``."""

    def __init__(self, op, left, right):
        # op: '+', '-', '*', '/', '=='
        self.op = op
        self.left, self.right = left, right
class Assignment:
    """AST node for a plain or compound assignment expression."""

    def __init__(self, target, op, expr):
        # target: VarRef; op: '=', '+=', '-=', '*=', '/='
        self.target, self.op = target, op
        # Right-hand side expression.
        self.expr = expr
class NumberLiteral:
    """AST node for an integer literal."""

    def __init__(self, value):
        # Value taken from a NUMBER token.
        self.value = value
class FloatLiteral:
    """AST node for a floating-point literal."""

    def __init__(self, value):
        # Value taken from a FLOAT token.
        self.value = value
class VarRef:
    """AST node for a reference to a variable by name."""

    def __init__(self, name):
        # Identifier text of the referenced variable.
        self.name = name
class CallExpr:
    """AST node for a function call ``name(arg, ...)``."""

    def __init__(self, func_name, args):
        # args: argument expressions in call order.
        self.func_name, self.args = func_name, args
class FieldAccess:
    """AST node for a member access ``<expr>.<field>``."""

    def __init__(self, expr, field):
        # expr: the expression being accessed; field: the member name.
        self.expr, self.field = expr, field
class StructInit:
    """AST node for a brace initializer ``{e1, e2, ...}`` of a struct variable."""

    def __init__(self, type_name, elements):
        # type_name: declared type of the variable being initialized;
        # elements: initializer expressions in source order.
        self.type_name, self.elements = type_name, elements
# ------------------------
# Parser with Extended Grammar and Cool Error Reporting
# ------------------------
class Parser:
    """Recursive-descent parser that builds the AST from a token list.

    Grammar, as implemented below::

        program       := (struct_decl | function_decl)*
        struct_decl   := 'struct' ID '{' (ID ID (',' ID)* ';')* '}' ';'
                         ['typedef' ID ';']
        function_decl := ID ID '(' [ID ID (',' ID ID)*] ')' compound
        compound      := '{' statement* '}'
        statement     := 'return' expr ';'
                       | 'if' '(' expr ')' statement
                       | var_decl
                       | expr ';'
        var_decl      := ('int'|'float'|'Vec3') ID ['=' (struct_init | expr)] ';'
        struct_init   := '{' expr (',' expr)* '}'
        expr          := equality [('='|'+='|'-='|'*='|'/=') expr]
        equality      := additive ('==' additive)*
        additive      := term (('+'|'-') term)*
        term          := factor (('*'|'/') factor)*
        factor        := NUMBER | FLOAT | '(' expr ')'
                       | ID '(' [expr (',' expr)*] ')'
                       | ID ('.' ID)*

    All syntax errors, including running out of tokens, go through ``error``.
    """

    def __init__(self, tokens, filename="<stdin>"):
        self.tokens = tokens
        self.pos = 0
        self.filename = filename

    # -- token stream helpers -------------------------------------------

    def current(self):
        """Return the token at the cursor, or None when input is exhausted."""
        if self.pos < len(self.tokens):
            return self.tokens[self.pos]
        return None

    def _at(self, *types):
        """Return True if a current token exists and its type is in ``types``.

        Returns False at end of input, so callers never dereference None.
        """
        token = self.current()
        return token is not None and token.type in types

    def _eof_error(self, expected):
        """Report that input ended where ``expected`` was required.

        The error is positioned at the last token when there is one, so the
        diagnostic can show where the input stopped.
        """
        message = f"Unexpected end of file, expected {expected}"
        if self.tokens:
            last = self.tokens[-1]
            self.error(message, last.filename, last.line, last.column)
        else:
            self.error(message, self.filename, -1, -1)

    def consume(self, typ, value=None):
        """Return the current token and advance past it.

        The token must have type ``typ`` and, when ``value`` is given, that
        exact value; otherwise an error is reported.
        """
        token = self.current()
        if token is None:
            self._eof_error(typ)
        if token.type != typ or (value is not None and token.value != value):
            expected = f"'{value}'" if value else typ
            self.error(f"expected token {expected} but got '{token.value}'", token.filename, token.line, token.column)
        self.pos += 1
        return token

    def error(self, message, filename, line, column):
        """Report a syntax error via ``print_error``."""
        print_error(message, filename, line, column)

    # -- declarations ---------------------------------------------------

    def parse(self):
        """Parse the whole token stream and return a ``Program``."""
        decls = []
        while self.current() is not None:
            if self.current().value == 'struct':
                decls.append(self.parse_struct_decl())
            else:
                decls.append(self.parse_function_decl())
        return Program(decls)

    def parse_struct_decl(self):
        """Parse a struct definition, plus an optional ``typedef ID;`` after it."""
        self.consume('ID', 'struct')
        name = self.consume('ID').value
        self.consume('LBRACE')
        fields = []
        while not self._at('RBRACE'):
            if self.current() is None:
                self._eof_error('RBRACE')
            field_type = self.consume('ID').value
            field_names = [self.consume('ID').value]
            while self._at('COMMA'):
                self.consume('COMMA')
                field_names.append(self.consume('ID').value)
            self.consume('SEMICOLON')
            fields.extend((field_type, fname) for fname in field_names)
        self.consume('RBRACE')
        self.consume('SEMICOLON')
        if self.current() and self.current().value == 'typedef':
            # The typedef is accepted but its name is not recorded.
            self.consume('ID', 'typedef')
            self.consume('ID')
            self.consume('SEMICOLON')
        return StructDecl(name, fields)

    def parse_function_decl(self):
        """Parse ``ret_type name(params) { ... }`` into a ``FunctionDecl``."""
        ret_type = self.consume('ID').value
        name = self.consume('ID').value
        self.consume('LPAREN')
        params = []
        if not self._at('RPAREN'):
            while True:
                p_type = self.consume('ID').value
                p_name = self.consume('ID').value
                params.append((p_type, p_name))
                if not self._at('COMMA'):
                    break
                self.consume('COMMA')
        self.consume('RPAREN')
        body = self.parse_compound_stmt()
        return FunctionDecl(ret_type, name, params, body)

    # -- statements -----------------------------------------------------

    def parse_compound_stmt(self):
        """Parse ``{ statement* }`` into a ``CompoundStmt``."""
        self.consume('LBRACE')
        stmts = []
        while not self._at('RBRACE'):
            if self.current() is None:
                self._eof_error('RBRACE')
            stmts.append(self.parse_statement())
        self.consume('RBRACE')
        return CompoundStmt(stmts)

    def parse_statement(self):
        """Parse one statement, dispatching on the leading token."""
        token = self.current()
        if token is None:
            self._eof_error("statement")
        if token.value == 'return':
            return self.parse_return_stmt()
        if token.value == 'if':
            return self.parse_if_stmt()
        # Only these type names start a declaration.
        if token.type == 'ID' and token.value in ('int', 'float', 'Vec3'):
            return self.parse_var_decl()
        expr = self.parse_expression()
        self.consume('SEMICOLON')
        return ExprStmt(expr)

    def parse_return_stmt(self):
        """Parse ``return expr;``."""
        self.consume('ID', 'return')
        expr = self.parse_expression()
        self.consume('SEMICOLON')
        return ReturnStmt(expr)

    def parse_if_stmt(self):
        """Parse ``if (cond) statement`` (no ``else`` branch is parsed)."""
        self.consume('ID', 'if')
        self.consume('LPAREN')
        cond = self.parse_expression()
        self.consume('RPAREN')
        stmt = self.parse_statement()
        return IfStmt(cond, stmt)

    def parse_var_decl(self):
        """Parse ``type name [= initializer];`` into a ``VarDecl``."""
        var_type = self.consume('ID').value
        name = self.consume('ID').value
        init = None
        if self._at('ASSIGN'):
            self.consume('ASSIGN')
            if self._at('LBRACE'):
                init = self.parse_struct_initializer(var_type)
            else:
                init = self.parse_expression()
        self.consume('SEMICOLON')
        return VarDecl(var_type, name, init)

    def parse_struct_initializer(self, type_name):
        """Parse ``{ expr, ... }`` (at least one element) into a ``StructInit``."""
        self.consume('LBRACE')
        elements = [self.parse_expression()]
        while self._at('COMMA'):
            self.consume('COMMA')
            elements.append(self.parse_expression())
        self.consume('RBRACE')
        return StructInit(type_name, elements)

    # -- expressions ----------------------------------------------------

    # Support assignments (including compound assignments)
    def parse_assignment(self):
        """Parse an assignment (right-associative) or fall through to equality."""
        left = self.parse_equality()
        if self._at('ASSIGN', 'PLUSEQ', 'MINUSEQ', 'MULEQ', 'DIVEQ'):
            op_token = self.consume(self.current().type)
            op = op_token.value
            right = self.parse_assignment()
            if not isinstance(left, VarRef):
                self.error("Invalid lvalue in assignment", op_token.filename, op_token.line, op_token.column)
            return Assignment(left, op, right)
        return left

    def parse_expression(self):
        """Parse a full expression (entry point of the expression grammar)."""
        return self.parse_assignment()

    def parse_equality(self):
        """Parse left-associative ``==`` chains."""
        node = self.parse_additive()
        while self._at('EQEQ'):
            op = self.consume('EQEQ').value
            right = self.parse_additive()
            node = BinaryOp(op, node, right)
        return node

    def parse_additive(self):
        """Parse left-associative ``+`` / ``-`` chains."""
        node = self.parse_term()
        while self._at('PLUS', 'MINUS'):
            op = self.consume(self.current().type).value
            right = self.parse_term()
            node = BinaryOp(op, node, right)
        return node

    def parse_term(self):
        """Parse left-associative ``*`` / ``/`` chains."""
        node = self.parse_factor()
        while self._at('MUL', 'DIV'):
            op = self.consume(self.current().type).value
            right = self.parse_factor()
            node = BinaryOp(op, node, right)
        return node

    def parse_factor(self):
        """Parse a literal, parenthesized expression, call, field access or variable."""
        token = self.current()
        if token is None:
            self._eof_error("expression")
        if token.type == 'NUMBER':
            self.consume('NUMBER')
            return NumberLiteral(token.value)
        if token.type == 'FLOAT':
            self.consume('FLOAT')
            return FloatLiteral(token.value)
        if token.type == 'LPAREN':
            self.consume('LPAREN')
            expr = self.parse_expression()
            self.consume('RPAREN')
            return expr
        if token.type == 'ID':
            id_token = self.consume('ID')
            if self._at('LPAREN'):
                self.consume('LPAREN')
                args = []
                if not self._at('RPAREN'):
                    args.append(self.parse_expression())
                    while self._at('COMMA'):
                        self.consume('COMMA')
                        args.append(self.parse_expression())
                self.consume('RPAREN')
                return CallExpr(id_token.value, args)
            # Zero or more ".field" suffixes wrap the variable reference.
            node = VarRef(id_token.value)
            while self._at('DOT'):
                self.consume('DOT')
                field_name = self.consume('ID').value
                node = FieldAccess(node, field_name)
            return node
        self.error("Unexpected token", token.filename, token.line, token.column)
# ------------------------
# Intermediate Code Generation (using integer opcodes)
# ------------------------
# Opcode numbering of the bytecode. Each instruction is a tuple whose first
# element is one of these values; operands, where present, follow it.
PUSH_INT = 0       # push integer constant                (PUSH_INT, value)
PUSH_FLOAT = 1     # push float constant                  (PUSH_FLOAT, value)
LOAD_LOCAL = 2     # load local variable (by index)       (LOAD_LOCAL, index)
STORE_LOCAL = 3    # store local variable (by index)      (STORE_LOCAL, index)
ADD_OP = 4         # emitted for '+' and '+='
SUB_OP = 5         # emitted for '-' and '-='
MUL_OP = 6         # emitted for '*' and '*='
DIV_OP = 7         # emitted for '/' and '/='
EQ_OP = 8          # emitted for '=='
JMP_IF_FALSE = 9   # (JMP_IF_FALSE, offset): offset = instructions in the guarded body
JMP = 10           # not emitted by the code generator in this file
CALL_OP = 11       # CALL_OP, function_index, num_args
RET_OP = 12        # (RET_OP,)
LOAD_FIELD = 13    # LOAD_FIELD, field_index
MAKE_STRUCT = 14   # MAKE_STRUCT, num_fields
class CodeGenerator:
    """Lower a ``Program`` AST into per-function lists of bytecode tuples.

    ``generate`` returns ``func_table``: function name -> dict with keys
    ``'code'`` (list of instruction tuples), ``'n_locals'`` and ``'params'``.
    A call refers to its callee by the callee's position in ``func_table``,
    which is the declaration order of the functions.
    """

    def __init__(self, program):
        self.program = program
        self.func_table = {}    # function name -> function object (dict)
        self.struct_table = {}  # struct name -> list of field names
        self.local_vars = {}    # variable name -> local index in current function
        self.next_local = 0
        self.struct_field_types = {}  # struct name -> {field name: declared type}
        self.local_types = {}         # variable name -> declared type, current function

    def generate(self):
        """Generate code for every function and return ``func_table``."""
        decls = self.program.decls
        for decl in decls:
            if isinstance(decl, StructDecl):
                self.struct_table[decl.name] = [fname for (_, fname) in decl.fields]
                self.struct_field_types[decl.name] = {
                    fname: ftype for (ftype, fname) in decl.fields
                }
        functions = [decl for decl in decls if isinstance(decl, FunctionDecl)]
        # Reserve every function's slot (in declaration order) before any body
        # is generated, so recursive calls and calls to functions defined
        # later in the file resolve to the right index.
        for func_decl in functions:
            self.func_table.setdefault(func_decl.name, None)
        for func_decl in functions:
            self.generate_function(func_decl)
        return self.func_table

    def new_local(self, name, var_type=None):
        """Allocate the next local slot for ``name`` and return its index.

        ``var_type`` (optional) records the declared type so field accesses
        on the variable can pick the matching struct layout.
        """
        idx = self.next_local
        self.local_vars[name] = idx
        self.local_types[name] = var_type
        self.next_local += 1
        return idx

    def get_local(self, name):
        """Return the slot index of ``name``.

        Raises:
            RuntimeError: if ``name`` was not declared in the current function.
        """
        if name not in self.local_vars:
            raise RuntimeError(f"Variable '{name}' not declared")
        return self.local_vars[name]

    def generate_function(self, func_decl):
        """Generate one function and store its object in ``func_table``."""
        self.local_vars = {}
        self.local_types = {}
        self.next_local = 0
        # Make the function visible to its own body (self-recursion) even
        # when this method is called without going through generate().
        self.func_table.setdefault(func_decl.name, None)
        for (ptype, pname) in func_decl.params:
            self.new_local(pname, ptype)
        code = []
        self.gen_statement(func_decl.body, code)
        # Trailing return for bodies that fall off the end.
        code.append((RET_OP,))
        func_obj = {'code': code, 'n_locals': self.next_local, 'params': [p[1] for p in func_decl.params]}
        self.func_table[func_decl.name] = func_obj

    def gen_statement(self, stmt, code):
        """Append the instructions for ``stmt`` to ``code``.

        Raises:
            RuntimeError: for a statement node type that is not handled.
        """
        if isinstance(stmt, CompoundStmt):
            for s in stmt.stmts:
                self.gen_statement(s, code)
        elif isinstance(stmt, ReturnStmt):
            self.gen_expr(stmt.expr, code)
            code.append((RET_OP,))
        elif isinstance(stmt, IfStmt):
            self.gen_expr(stmt.cond, code)
            # Emit a placeholder and patch it once the body length is known.
            jmp_false_index = len(code)
            code.append((JMP_IF_FALSE, None))
            self.gen_statement(stmt.then_stmt, code)
            offset = len(code) - jmp_false_index - 1
            code[jmp_false_index] = (JMP_IF_FALSE, offset)
        elif isinstance(stmt, VarDecl):
            # The slot is allocated before the initializer is generated.
            idx = self.new_local(stmt.name, stmt.var_type)
            if stmt.init is not None:
                self.gen_expr(stmt.init, code)
            else:
                code.append((PUSH_INT, 0))
            code.append((STORE_LOCAL, idx))
        elif isinstance(stmt, ExprStmt):
            self.gen_expr(stmt.expr, code)
        else:
            raise RuntimeError("Unknown statement type in code generation")

    def gen_expr(self, expr, code):
        """Append the instructions that evaluate ``expr`` to ``code``.

        Raises:
            RuntimeError: for undeclared variables, unknown functions, structs
                or fields, and unsupported operators or node types.
        """
        if isinstance(expr, NumberLiteral):
            code.append((PUSH_INT, expr.value))
        elif isinstance(expr, FloatLiteral):
            code.append((PUSH_FLOAT, expr.value))
        elif isinstance(expr, VarRef):
            code.append((LOAD_LOCAL, self.get_local(expr.name)))
        elif isinstance(expr, Assignment):
            self._gen_assignment(expr, code)
        elif isinstance(expr, BinaryOp):
            self._gen_binary(expr, code)
        elif isinstance(expr, CallExpr):
            self._gen_call(expr, code)
        elif isinstance(expr, FieldAccess):
            self._gen_field_access(expr, code)
        elif isinstance(expr, StructInit):
            for element in expr.elements:
                self.gen_expr(element, code)
            code.append((MAKE_STRUCT, len(expr.elements)))
        else:
            raise RuntimeError("Unsupported expression type in code generation")

    def _gen_assignment(self, expr, code):
        """Emit a plain or compound assignment; the stored value is reloaded."""
        local_index = self.get_local(expr.target.name)
        if expr.op == '=':
            self.gen_expr(expr.expr, code)
        else:
            compound_ops = {'+=': ADD_OP, '-=': SUB_OP, '*=': MUL_OP, '/=': DIV_OP}
            if expr.op not in compound_ops:
                raise RuntimeError("Unsupported assignment operator: " + expr.op)
            code.append((LOAD_LOCAL, local_index))
            self.gen_expr(expr.expr, code)
            code.append((compound_ops[expr.op],))
        code.append((STORE_LOCAL, local_index))
        # An assignment is an expression: load the variable back as its value.
        code.append((LOAD_LOCAL, local_index))

    def _gen_binary(self, expr, code):
        """Emit both operands (left first), then the operator instruction."""
        self.gen_expr(expr.left, code)
        self.gen_expr(expr.right, code)
        binary_ops = {'+': ADD_OP, '-': SUB_OP, '*': MUL_OP, '/': DIV_OP, '==': EQ_OP}
        if expr.op not in binary_ops:
            raise RuntimeError("Unsupported binary operator: " + expr.op)
        code.append((binary_ops[expr.op],))

    def _gen_call(self, expr, code):
        """Emit the arguments in order, then CALL_OP with the callee index."""
        for arg in expr.args:
            self.gen_expr(arg, code)
        if expr.func_name not in self.func_table:
            raise RuntimeError("Function not found: " + expr.func_name)
        func_index = list(self.func_table).index(expr.func_name)
        code.append((CALL_OP, func_index, len(expr.args)))

    def _gen_field_access(self, expr, code):
        """Emit the base expression, then LOAD_FIELD with the field's index."""
        self.gen_expr(expr.expr, code)
        fields = self.struct_table.get(self._owner_struct(expr.expr))
        if fields is None:
            # _owner_struct only returns an unknown name for the Vec3 fallback.
            raise RuntimeError("Struct Vec3 not defined")
        try:
            field_index = fields.index(expr.field)
        except ValueError:
            raise RuntimeError("Field not found: " + expr.field)
        code.append((LOAD_FIELD, field_index))

    def _static_type(self, expr):
        """Return the declared type name of ``expr``, or None if unknown."""
        if isinstance(expr, VarRef):
            return self.local_types.get(expr.name)
        if isinstance(expr, FieldAccess):
            owner = self._owner_struct(expr.expr)
            return self.struct_field_types.get(owner, {}).get(expr.field)
        return None

    def _owner_struct(self, expr):
        """Return the struct whose layout applies to a field access on ``expr``.

        Falls back to 'Vec3' when the type is unknown or is not a declared
        struct.
        """
        declared = self._static_type(expr)
        return declared if declared in self.struct_table else 'Vec3'
# ------------------------
# Main: Read source, compile, and save bytecode
# ------------------------
def main(argv=None):
    """Compile one source file and write its bytecode to ``program.bc``.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. The first argument is the source file, and
            ``main.c`` is used when none is given.

    The function table is written with ``pickle``. Errors raised by the code
    generator are reported as a colored one-line message via ``sys.exit``
    instead of a Python traceback.
    """
    args = sys.argv[1:] if argv is None else argv
    source_file = args[0] if args else "main.c"
    bytecode_file = "program.bc"
    source_code = preprocess(source_file)
    tokens = Lexer(source_code, source_file).tokenize()
    program = Parser(tokens, source_file).parse()
    try:
        func_table = CodeGenerator(program).generate()
    except RuntimeError as exc:
        # AST nodes carry no source position, so only the file is named.
        sys.exit(f"{RED}{source_file}: error: {exc}{RESET}")
    with open(bytecode_file, "wb") as f:
        pickle.dump(func_table, f)
    print(f"{YELLOW}Compilation successful! Bytecode saved to {bytecode_file}.{RESET}")


if __name__ == '__main__':
    main()