# Mirror of https://github.com/nicbarker/clay.git
# Synced 2025-04-15 10:48:04 +00:00 (173 lines, 6.4 KiB, Python)
from dataclasses import dataclass
|
|
from typing import Optional, TypedDict, NotRequired, Union
|
|
from pycparser import c_ast, parse_file, preprocess_file
|
|
from pathlib import Path
|
|
import os
|
|
import json
|
|
import shutil
|
|
import logging
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# The description of one C type produced by get_type_names(): either a type
# name with pointer markers / qualifiers folded into the string (e.g. "*char"),
# or a full ExtractedFunction signature when the type is a function type.
ExtractedSymbolType = Union[str, 'ExtractedFunction']
|
|
class ExtractedStructAttributeUnion(TypedDict):
    """Shape of a single extracted union field.

    Not referenced by the rest of the visible module: union members are stored
    as a plain ``dict[str, Optional[ExtractedSymbolType]]`` under
    ``ExtractedStructAttribute['union']`` instead.
    """

    # None when get_type_names() could not resolve the field's type.
    type: Optional[ExtractedSymbolType]
|
|
class ExtractedStructAttribute(TypedDict):
    """Description of one struct/union member.

    Exactly one of the two keys is present: ``type`` for an ordinary member,
    or ``union`` for a member declared with an inline ``union { ... }`` body,
    mapping each union field name to its extracted type.
    """

    type: NotRequired[ExtractedSymbolType]
    union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
|
|
class ExtractedStruct(TypedDict):
    """A struct or union definition: its members keyed by member name."""

    attrs: dict[str, ExtractedStructAttribute]
    # Set (True for unions) on entries recorded from a ``typedef``; absent on
    # entries recorded straight from a tagged ``struct`` definition.
    is_union: NotRequired[bool]
|
|
# Enum constant name -> initializer text, or None when the constant has no
# explicit initializer.
ExtractedEnum = dict[str, Optional[str]]

# One function parameter as a (name, type) pair; the type is None when it
# could not be resolved.
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
|
|
class ExtractedFunction(TypedDict):
    """A function signature: return type plus its named parameters, in order."""

    # None when the return type could not be resolved.
    return_type: Optional["ExtractedSymbolType"]
    # Unnamed parameters (such as a lone ``void``) are not listed.
    params: list[ExtractedFunctionParam]
|
|
@dataclass
class ExtractedSymbols:
    """Declarations collected from the parsed headers, each keyed by C name."""

    # Struct and union definitions.
    structs: dict[str, ExtractedStruct]
    # Enum definitions (recorded from typedefs).
    enums: dict[str, ExtractedEnum]
    # Function declarations; may also contain named function-pointer declarators.
    functions: dict[str, ExtractedFunction]
|
|
def get_type_names(node: c_ast.Node, prefix: str = "") -> Optional[ExtractedSymbolType]:
    """Describe the C type rooted at ``node`` as a string or a function signature.

    Walks down the declarator chain (``Decl`` -> ``PtrDecl`` -> ``TypeDecl`` ->
    ``IdentifierType`` ...) accumulating a textual prefix: every ``PtrDecl``
    prepends ``"*"`` and a qualified ``TypeDecl`` prepends its qualifiers.
    For ``const char *`` this yields ``"const *char"``.

    Args:
        node: pycparser AST node to describe.
        prefix: Text accumulated from outer declarators; callers normally omit it.

    Returns:
        The prefix plus the base type name(s) once a node with ``names`` (an
        ``IdentifierType``) is reached. Multi-word types such as
        ``unsigned int`` keep all their words.
        An ``ExtractedFunction`` when a ``FuncDecl`` is reached, e.g. for a
        function-pointer member; any pointer prefix gathered so far is dropped.
        ``None`` when the chain ends in a node with neither ``names`` nor
        ``type``, such as a struct/union/enum referenced by tag.
    """
    if isinstance(node, c_ast.TypeDecl) and node.quals:
        prefix = " ".join(node.quals) + " " + prefix

    if isinstance(node, c_ast.PtrDecl):
        prefix = "*" + prefix

    if isinstance(node, c_ast.FuncDecl):
        func: ExtractedFunction = {
            'return_type': get_type_names(node.type),
            'params': [],
        }
        # ``args`` is None for an unprototyped declarator such as ``void f()``.
        params = node.args.params if node.args is not None else []
        for param in params:
            # Unnamed parameters (the ``void`` in ``f(void)``) and ``...``
            # (an EllipsisParam, which has no ``name`` attribute) are skipped.
            name = getattr(param, 'name', None)
            if name is None:
                continue
            func['params'].append((name, get_type_names(param)))
        return func

    if hasattr(node, 'names'):
        # IdentifierType splits multi-word specifiers ("unsigned", "int");
        # keep them all rather than just the first word.
        return prefix + " ".join(node.names)  # type: ignore
    if hasattr(node, 'type'):
        return get_type_names(node.type, prefix)  # type: ignore
    return None
|
|
class Visitor(c_ast.NodeVisitor):
    """Collect struct, enum and function declarations from a pycparser AST.

    Call ``visit(ast)`` on a parsed translation unit; the results accumulate in
    ``self.structs``, ``self.enums`` and ``self.functions``, keyed by C name.
    """

    def __init__(self):
        super().__init__()
        self.structs: dict[str, ExtractedStruct] = {}
        self.enums: dict[str, ExtractedEnum] = {}
        self.functions: dict[str, ExtractedFunction] = {}

    @staticmethod
    def _extract_params(func_decl: c_ast.FuncDecl) -> list[ExtractedFunctionParam]:
        """Return ``(name, type)`` for each named parameter of ``func_decl``."""
        # ``args`` is None for an unprototyped declarator such as ``void f()``.
        if func_decl.args is None:
            return []
        params: list[ExtractedFunctionParam] = []
        for param in func_decl.args.params:
            # Unnamed parameters (the ``void`` in ``f(void)``) and ``...``
            # (an EllipsisParam, which has no ``name`` attribute) are skipped.
            name = getattr(param, 'name', None)
            if name is not None:
                params.append((name, get_type_names(param)))
        return params

    @staticmethod
    def _extract_attrs(decls: list[c_ast.Node]) -> dict[str, ExtractedStructAttribute]:
        """Describe each member declaration of a struct/union body.

        A member whose type is an inline ``union { ... }`` is expanded into
        ``{'union': {field: type}}``. Every other member becomes
        ``{'type': ...}`` via get_type_names().
        """
        attrs: dict[str, ExtractedStructAttribute] = {}
        for decl in decls:
            inner = getattr(getattr(decl, 'type', None), 'type', None)
            # A union referenced only by tag has decls None: nothing to expand.
            if isinstance(inner, c_ast.Union) and inner.decls is not None:
                attrs[decl.name] = {
                    'union': {field.name: get_type_names(field) for field in inner.decls},
                }
            else:
                attrs[decl.name] = {'type': get_type_names(decl)}
        return attrs

    @staticmethod
    def _extract_enum(enum_node: c_ast.Enum) -> ExtractedEnum:
        """Map each enumerator to its initializer text, or None if it has none."""
        values: ExtractedEnum = {}
        for enumerator in enum_node.values.enumerators:
            if enumerator.value is None:
                values[enumerator.name] = None
            elif isinstance(enumerator.value, c_ast.Constant):
                values[enumerator.name] = enumerator.value.value
            else:
                # Initializers such as ``-1``, ``1 << 2`` or ``OTHER_VALUE`` are
                # not Constant nodes and have no ``.value``; render them back
                # to C source text instead of crashing.
                from pycparser import c_generator
                values[enumerator.name] = c_generator.CGenerator().visit(enumerator.value)
        return values

    def visit_FuncDecl(self, node: c_ast.FuncDecl):
        """Record a named function declarator with its return type and parameters.

        This fires for every FuncDecl reached in the tree. Named function-pointer
        declarators (struct members, typedefs, parameters) are therefore
        recorded as well.
        """
        # Peel pointer declarators off the return type. ``char **f(void)`` is
        # FuncDecl -> PtrDecl -> PtrDecl -> TypeDecl(declname='f').
        node_type = node.type
        pointer_depth = 0
        while isinstance(node_type, c_ast.PtrDecl):
            node_type = node_type.type
            pointer_depth += 1

        # Anything else here (e.g. a FuncDecl, for a function returning a
        # function pointer) is not handled, and its children are not visited.
        if not hasattr(node_type, "declname"):
            return

        # Abstract declarators (e.g. an unnamed function-pointer parameter)
        # have declname None and cannot be keyed by name.
        if node_type.declname is not None:
            return_type = get_type_names(node_type.type)
            if isinstance(return_type, str) and pointer_depth:
                return_type = "*" * pointer_depth + return_type
            self.functions[node_type.declname] = {
                'return_type': return_type,
                'params': self._extract_params(node),
            }
        self.generic_visit(node)

    def visit_Struct(self, node: c_ast.Struct):
        """Record a tagged struct definition under its tag name."""
        # ``typedef struct Tag {...} Tag;`` has already been recorded by
        # visit_Typedef, including ``is_union``, before generic_visit reaches
        # this Struct. Re-recording it here would discard that entry.
        if node.name and node.decls and node.name not in self.structs:
            self.structs[node.name] = {'attrs': self._extract_attrs(node.decls)}
        self.generic_visit(node)

    def visit_Typedef(self, node: c_ast.Typedef):
        """Record struct/union and enum bodies introduced by a typedef."""
        target = getattr(node.type, 'type', None)
        if getattr(target, 'decls', None):
            self.structs[node.name] = {
                'attrs': self._extract_attrs(target.decls),
                'is_union': isinstance(target, c_ast.Union),
            }
        # An enum referenced without a body (``typedef enum Tag Alias;``) has
        # ``values`` set to None, so there is nothing to record for it.
        if isinstance(target, c_ast.Enum) and target.values is not None:
            self.enums[node.name] = self._extract_enum(target)
        self.generic_visit(node)
|
|
|
|
def parse_headers(input_files: list[Path], tmp_dir: Path) -> ExtractedSymbols:
    """Preprocess and parse the given headers and return the declarations they contain.

    The input files are concatenated into ``tmp_dir / 'merged_clay.h'``.
    ``#include`` directives and the ``#define CLAY_IMPLEMENTATION`` line are
    dropped during the merge. The result is run through ``cpp`` (looked up on
    PATH) into ``tmp_dir / 'clay.preprocessed.h'``, parsed with pycparser and
    walked with ``Visitor``.

    Args:
        input_files: Headers to merge, in order. Their ``#include`` lines are
            dropped, so any needed dependency must be listed here explicitly.
        tmp_dir: Existing directory that receives the intermediate files.

    Returns:
        The structs, enums and functions that were found.
    """
    # -nostdinc: don't search system include dirs (includes are stripped anyway).
    # -D__attribute__(x)=: erase GCC attributes, which pycparser cannot parse.
    cpp_args = ["-nostdinc", "-D__attribute__(x)=", "-E"]
    merged_path = tmp_dir / 'merged_clay.h'
    preprocessed_path = tmp_dir / 'clay.preprocessed.h'

    # Make a new clay.h that combines the provided input files, so that we can
    # add bindings for customized structs.
    with open(merged_path, 'w') as merged:
        for input_file in input_files:
            last_line = "\n"
            with open(input_file, 'r') as source:
                for line in source:
                    last_line = line
                    # Ignore includes, including indented and "# include"
                    # spellings; they should be listed in input_files instead.
                    directive = line.lstrip()
                    if directive.startswith("#") and directive[1:].lstrip().startswith("include"):
                        continue
                    # Ignore the CLAY_IMPLEMENTATION define, because we only want
                    # to parse the public API. This lets the user pass their
                    # implementation code, which contains any custom extensions.
                    if "#define CLAY_IMPLEMENTATION" in line:
                        continue
                    merged.write(line)
            # A file that doesn't end in a newline would otherwise have its last
            # line glued onto the first line of the next file.
            if not last_line.endswith("\n"):
                merged.write("\n")

    logger.info("Preprocessing file")
    preprocessed = preprocess_file(merged_path, cpp_path="cpp", cpp_args=cpp_args)  # type: ignore
    with open(preprocessed_path, 'w') as out:
        out.write(preprocessed)

    logger.info("Parsing file")
    ast = parse_file(preprocessed_path, use_cpp=False)  # type: ignore

    visitor = Visitor()
    visitor.visit(ast)
    return ExtractedSymbols(
        structs=visitor.structs,
        enums=visitor.enums,
        functions=visitor.functions,
    )