clay/generator/parser.py
Harrison Lambeth 38bb241ced small fixes
2025-02-10 23:14:44 -07:00

173 lines
6.4 KiB
Python

from dataclasses import dataclass
from typing import Optional, TypedDict, NotRequired, Union
from pycparser import c_ast, parse_file, preprocess_file
from pathlib import Path
import os
import json
import shutil
import logging
# Module-level logger; parse_headers reports its progress stages through it.
logger = logging.getLogger(__name__)
# A C type as produced by get_type_names: a type-name string (pointer/qualifier
# prefixes included), or an ExtractedFunction record for function types.
ExtractedSymbolType = Union[str, "ExtractedFunction"]
class ExtractedStructAttributeUnion(TypedDict):
    """Typed shape of an entry holding a single ``type``.

    NOTE(review): nothing in this part of the file references this type --
    ExtractedStructAttribute types its ``union`` member as a plain dict instead.
    Confirm whether it is still used elsewhere before relying on it.
    """
    type: Optional[ExtractedSymbolType]
class ExtractedStructAttribute(TypedDict):
    """One struct/union member as recorded by Visitor.

    Exactly one key is present: ``type`` for an ordinary member, or ``union`` for
    a member whose type is an inline union, mapping each union field name to its
    type.
    """
    # May be None: get_type_names returns None for types it cannot name.
    type: NotRequired[Optional[ExtractedSymbolType]]
    union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
class ExtractedStruct(TypedDict):
    """A struct or union definition extracted from the headers."""
    # Member name -> member description.
    attrs: dict[str, ExtractedStructAttribute]
    # True when the typedef'd aggregate is a union; only set on entries recorded
    # from a typedef, not on ones recorded from a bare struct tag.
    is_union: NotRequired[bool]
# Enumerator name -> string form of its explicit value, or None when the
# enumerator has no explicit value.
ExtractedEnum = dict[str, Optional[str]]
# (parameter name, parameter type) pair from a function signature.
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
class ExtractedFunction(TypedDict):
    """A function signature; also used to describe function-pointer types."""
    # None when the return type could not be resolved to a name.
    return_type: Optional["ExtractedSymbolType"]
    # Named parameters in declaration order; unnamed parameters are omitted.
    params: list[ExtractedFunctionParam]
@dataclass
class ExtractedSymbols:
    """Everything parse_headers extracted, each mapping keyed by C name."""
    # Struct/union name (typedef name or struct tag) -> definition.
    structs: dict[str, ExtractedStruct]
    # Typedef'd enum name -> its enumerators.
    enums: dict[str, ExtractedEnum]
    # Function (or function-pointer declarator) name -> signature.
    functions: dict[str, ExtractedFunction]
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
    """Describe the C type of ``node`` as a string, or as a function record.

    Walks down the declarator chain (``Decl`` -> ``PtrDecl`` / ``TypeDecl`` -> ...),
    building up ``prefix`` on the way:

    * every ``PtrDecl`` prepends ``"*"``;
    * every ``TypeDecl`` with qualifiers prepends them, space-separated
      (so ``const char *`` becomes ``"const *char"``).

    The walk stops at the first node carrying ``names`` (an ``IdentifierType``),
    whose full space-joined type name is appended to the prefix, so
    ``unsigned int`` stays ``"unsigned int"``.

    A ``FuncDecl`` (function or function-pointer type) is returned as an
    ExtractedFunction record instead; any prefix accumulated above it is dropped.

    Args:
        node: pycparser node to describe (typically a ``Decl`` or ``Typename``).
        prefix: Pointer/qualifier prefix accumulated by the recursive walk.

    Returns:
        The type string, a function record, or None when the chain ends without a
        named type (e.g. at a ``Struct``/``Union``/``Enum`` node, which has neither
        ``names`` nor ``type``).
    """
    if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
        prefix = " ".join(node.quals) + " " + prefix
    if isinstance(node, c_ast.PtrDecl):
        prefix = "*" + prefix
    if isinstance(node, c_ast.FuncDecl):
        func: ExtractedFunction = {
            'return_type': get_type_names(node.type),
            'params': [],
        }
        # ``args`` is None for an empty parameter list, e.g. ``void (*cb)()``.
        params = node.args.params if node.args is not None else []
        for param in params:
            # Unnamed params (``void`` in ``f(void)``, unnamed prototype params) have
            # ``name`` None; ``...`` (EllipsisParam) has no ``name`` attribute at all.
            name = getattr(param, 'name', None)
            if name is None:
                continue
            func['params'].append((name, get_type_names(param)))
        return func
    if hasattr(node, 'names'):
        # Multi-word types (``unsigned int``, ``long double``) span several names.
        return prefix + " ".join(node.names)  # type: ignore
    elif hasattr(node, 'type'):
        return get_type_names(node.type, prefix)  # type: ignore
    return None
class Visitor(c_ast.NodeVisitor):
    """Collect struct, union, enum and function declarations from a pycparser AST.

    Call ``visit(ast)`` once; results accumulate in:

    * ``structs`` -- typedef'd structs/unions and tagged struct definitions, by name;
    * ``enums`` -- typedef'd enums, by typedef name;
    * ``functions`` -- every named ``FuncDecl`` reached, by declarator name. Because
      the whole tree is traversed, this also picks up named function-pointer
      declarators (struct fields, parameters, typedefs), not only prototypes.
    """

    def __init__(self):
        super().__init__()
        self.structs: dict[str, ExtractedStruct] = {}
        self.enums: dict[str, ExtractedEnum] = {}
        self.functions: dict[str, ExtractedFunction] = {}

    @staticmethod
    def _extract_params(node: c_ast.FuncDecl) -> list[ExtractedFunctionParam]:
        """Return the named parameters of ``node`` as (name, type) pairs, in order."""
        # ``args`` is None for an empty parameter list, e.g. ``void f()``.
        if node.args is None:
            return []
        # Unnamed params (``void`` in ``f(void)``, unnamed prototype params) have
        # ``name`` None, and ``...`` (EllipsisParam) has no ``name`` attribute at
        # all; both are skipped.
        return [
            (param.name, get_type_names(param))
            for param in node.args.params
            if getattr(param, 'name', None) is not None
        ]

    @staticmethod
    def _extract_attrs(decls: list[c_ast.Decl]) -> dict[str, ExtractedStructAttribute]:
        """Describe each member of a struct/union body, keyed by member name."""
        attrs: dict[str, ExtractedStructAttribute] = {}
        for decl in decls:
            # ``union { ... } name;`` members nest as Decl -> TypeDecl -> Union.
            member_type = getattr(getattr(decl, 'type', None), 'type', None)
            # A union referenced without a body (``union Foo u;``) has ``decls``
            # None; record it as a plain ``type`` entry instead of iterating None.
            if isinstance(member_type, c_ast.Union) and member_type.decls is not None:
                attrs[decl.name] = {
                    'union': {field.name: get_type_names(field) for field in member_type.decls},
                }
            else:
                attrs[decl.name] = {'type': get_type_names(decl)}
        return attrs

    @staticmethod
    def _enum_value(expr: Optional[c_ast.Node]) -> Optional[str]:
        """Return an enumerator's explicit value as C source text, or None if absent."""
        if expr is None:
            return None
        if isinstance(expr, c_ast.Constant):
            return expr.value
        # Initializers such as ``1 << 3``, ``-1`` or ``OTHER_ENUMERATOR`` are not
        # Constants and have no ``value`` attribute; render them back to C text.
        from pycparser import c_generator
        return c_generator.CGenerator().visit(expr)

    def visit_FuncDecl(self, node: c_ast.FuncDecl):
        """Record a named function declarator in ``self.functions``, then recurse."""
        declarator = node.type
        returns_pointer = False
        # ``T *f(...)`` nests the name one level deeper: FuncDecl -> PtrDecl -> TypeDecl.
        if isinstance(declarator, c_ast.PtrDecl):
            declarator = declarator.type
            returns_pointer = True
        # Abstract declarators (e.g. an unnamed ``int (*)(int)`` parameter) have a
        # ``declname`` of None; there is no name to key them by.
        name = getattr(declarator, 'declname', None)
        if name is not None:
            return_type = get_type_names(declarator.type)
            if returns_pointer and isinstance(return_type, str):
                return_type = "*" + return_type
            self.functions[name] = {
                'return_type': return_type,
                'params': self._extract_params(node),
            }
        self.generic_visit(node)

    def visit_Struct(self, node: c_ast.Struct):
        """Record a tagged struct definition in ``self.structs``, then recurse."""
        # For ``typedef struct Foo {...} Foo;`` visit_Typedef has already recorded
        # the richer entry (with ``is_union``) and reaches this node through
        # generic_visit; don't let the struct-tag entry overwrite it.
        if node.name and node.decls and node.name not in self.structs:
            self.structs[node.name] = {
                'attrs': self._extract_attrs(node.decls),
            }
        self.generic_visit(node)

    def visit_Typedef(self, node: c_ast.Typedef):
        """Record typedef'd structs/unions and enums under the typedef name, then recurse."""
        # For ``typedef struct/union/enum {...} Name;`` the chain is
        # Typedef -> TypeDecl -> Struct/Union/Enum.
        target = getattr(node.type, 'type', None)
        if getattr(target, 'decls', None):
            self.structs[node.name] = {
                'attrs': self._extract_attrs(target.decls),
                'is_union': isinstance(target, c_ast.Union),
            }
        # ``values`` is None for a typedef of a forward-referenced enum.
        if isinstance(target, c_ast.Enum) and target.values is not None:
            self.enums[node.name] = {
                enumerator.name: self._enum_value(enumerator.value)
                for enumerator in target.values.enumerators
            }
        self.generic_visit(node)
def parse_headers(input_files: list[Path], tmp_dir: Path) -> ExtractedSymbols:
    """Merge, preprocess and parse Clay headers, returning the extracted symbols.

    The input files are concatenated into ``tmp_dir/merged_clay.h`` with every
    ``#include`` directive and the ``#define CLAY_IMPLEMENTATION`` line removed,
    run through the system ``cpp`` (output saved as ``tmp_dir/clay.preprocessed.h``),
    then parsed with pycparser and walked by Visitor.

    Args:
        input_files: Header/source files to merge, in order. Anything they would
            normally ``#include`` must itself be listed here, since includes are
            stripped.
        tmp_dir: Directory for the intermediate files; created if missing.

    Returns:
        The structs, enums and functions found in the merged headers.
    """
    import re

    tmp_dir = Path(tmp_dir)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    merged_path = tmp_dir / 'merged_clay.h'
    preprocessed_path = tmp_dir / 'clay.preprocessed.h'
    # -nostdinc: no system headers are wanted (includes are stripped anyway).
    # __attribute__(x) is GCC syntax pycparser cannot parse, so define it away.
    cpp_args = ["-nostdinc", "-D__attribute__(x)=", "-E"]
    # Matches ``#include``, ``  #include`` and ``#  include`` (also ``#include_next``).
    include_directive = re.compile(rb"^\s*#\s*include")

    # Make a new clay.h that combines the provided input files, so that we can add
    # bindings for customized structs. The merge is a byte-for-byte copy, so the
    # inputs' encoding never has to be guessed.
    with open(merged_path, 'wb') as merged:
        for input_file in input_files:
            with open(input_file, 'rb') as source:
                for line in source:
                    # Ignore includes, as they should be manually included in input_files.
                    if include_directive.match(line):
                        continue
                    # Ignore the CLAY_IMPLEMENTATION define, because we only want to parse
                    # the public api code. This is helpful so that the user can provide their
                    # implementation code, which will contain any custom extensions.
                    if b"#define CLAY_IMPLEMENTATION" in line:
                        continue
                    merged.write(line)
            # Keep a file whose last line has no trailing newline from being glued
            # onto the first line of the next file.
            merged.write(b"\n")

    # Preprocess the file
    logger.info("Preprocessing file")
    preprocessed = preprocess_file(merged_path, cpp_path="cpp", cpp_args=cpp_args)  # type: ignore
    # Default encoding on purpose: parse_file below reads it back with the same default.
    with open(preprocessed_path, 'w') as f:
        f.write(preprocessed)

    # Parse the file
    logger.info("Parsing file")
    ast = parse_file(preprocessed_path, use_cpp=False)  # type: ignore

    # Extract symbols
    visitor = Visitor()
    visitor.visit(ast)
    return ExtractedSymbols(
        structs=visitor.structs,
        enums=visitor.enums,
        functions=visitor.functions,
    )