Add proper support for function arguments

This commit is contained in:
Harrison Lambeth 2025-01-26 14:43:51 -07:00
parent 7c65f31f46
commit 4f4605eff9
2 changed files with 76 additions and 37 deletions

View File

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import logging import logging
from parser import ExtractedSymbolType
from generators.base_generator import BaseGenerator from generators.base_generator import BaseGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,23 +62,12 @@ TYPE_MAPPING = {
'int32_t': 'c.int32_t', 'int32_t': 'c.int32_t',
'uintptr_t': 'rawptr', 'uintptr_t': 'rawptr',
'void': 'void', 'void': 'void',
'*Clay_RectangleElementConfig': '^RectangleElementConfig',
'*Clay_TextElementConfig': '^TextElementConfig',
'*Clay_ImageElementConfig': '^ImageElementConfig',
'*Clay_FloatingElementConfig': '^FloatingElementConfig',
'*Clay_CustomElementConfig': '^CustomElementConfig',
'*Clay_ScrollElementConfig': '^ScrollElementConfig',
'*Clay_BorderElementConfig': '^BorderElementConfig',
} }
STRUCT_TYPE_OVERRIDES = { STRUCT_TYPE_OVERRIDES = {
'Clay_Arena': { 'Clay_Arena': {
'nextAllocation': 'uintptr', 'nextAllocation': 'uintptr',
'capacity': 'uintptr', 'capacity': 'uintptr',
}, },
'Clay_ErrorHandler': {
'errorHandlerFunction': 'proc "c" (errorData: ErrorData)',
},
'Clay_SizingAxis': { 'Clay_SizingAxis': {
'size': 'SizingConstraints', 'size': 'SizingConstraints',
}, },
@ -108,7 +98,6 @@ FUNCTION_TYPE_OVERRIDES = {
'offset': '[^]u8', 'offset': '[^]u8',
}, },
'Clay_SetMeasureTextFunction': { 'Clay_SetMeasureTextFunction': {
'measureTextFunction': 'proc "c" (text: ^StringSlice, config: ^TextElementConfig, userData: uintptr) -> Dimensions',
'userData': 'uintptr', 'userData': 'uintptr',
}, },
'Clay_RenderCommandArray_Get': { 'Clay_RenderCommandArray_Get': {
@ -155,7 +144,20 @@ class OdinGenerator(BaseGenerator):
return base_name return base_name
raise ValueError(f'Unknown symbol: {symbol}') raise ValueError(f'Unknown symbol: {symbol}')
def resolve_binding_type(self, symbol: str, member: str | None, member_type: str | None, type_overrides: dict[str, dict[str, str]]) -> str | None: def format_type(self, type: ExtractedSymbolType) -> str:
if isinstance(type, str):
return type
parameter_strs = []
for param_name, param_type in type['params']:
parameter_strs.append(f"{param_name}: {self.format_type(param_type or 'unknown')}")
return_type_str = ''
if type['return_type'] is not None and type['return_type'] != 'void':
return_type_str = ' -> ' + self.format_type(type['return_type'])
return f"proc \"c\" ({', '.join(parameter_strs)}){return_type_str}"
def resolve_binding_type(self, symbol: str, member: str | None, member_type: ExtractedSymbolType | None, type_overrides: dict[str, dict[str, str]]) -> ExtractedSymbolType | None:
if isinstance(member_type, str):
if member_type in SYMBOL_COMPLETE_OVERRIDES: if member_type in SYMBOL_COMPLETE_OVERRIDES:
return SYMBOL_COMPLETE_OVERRIDES[member_type] return SYMBOL_COMPLETE_OVERRIDES[member_type]
if symbol in type_overrides and member in type_overrides[symbol]: if symbol in type_overrides and member in type_overrides[symbol]:
@ -169,6 +171,22 @@ class OdinGenerator(BaseGenerator):
if result: if result:
return f"^{result}" return f"^{result}"
return None return None
if member_type is None:
return None
resolved_parameters = []
for param_name, param_type in member_type['params']:
resolved_param = self.resolve_binding_type(symbol, param_name, param_type, type_overrides)
if resolved_param is None:
return None
resolved_parameters.append((param_name, resolved_param))
resolved_return_type = self.resolve_binding_type(symbol, None, member_type['return_type'], type_overrides)
if resolved_return_type is None:
return None
return {
"params": resolved_parameters,
"return_type": resolved_return_type,
}
def generate_structs(self) -> None: def generate_structs(self) -> None:
for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]): for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
@ -184,12 +202,15 @@ class OdinGenerator(BaseGenerator):
if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY: if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
array_size = len(members) array_size = len(members)
array_type = list(members.values())[0]['type'] first_elem = list(members.values())[0]
array_type = None
if 'type' in first_elem:
array_type = first_elem['type']
if array_type in TYPE_MAPPING: if array_type in TYPE_MAPPING:
array_binding_type = TYPE_MAPPING[array_type] array_binding_type = TYPE_MAPPING[array_type]
elif array_type and self.has_symbol(array_type): elif array_type and self.has_symbol(self.format_type(array_type)):
array_binding_type = self.get_symbol_name(array_type) array_binding_type = self.get_symbol_name(self.format_type(array_type))
else: else:
self._write('struct', f"// {struct} ({array_type}) - has no mapping") self._write('struct', f"// {struct} ({array_type}) - has no mapping")
continue continue
@ -221,6 +242,7 @@ class OdinGenerator(BaseGenerator):
if member_binding_type is None: if member_binding_type is None:
self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping") self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping")
continue continue
member_binding_type = self.format_type(member_binding_type)
self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})") self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
self._write('struct', "}") self._write('struct', "}")
self._write('struct', '') self._write('struct', '')
@ -272,6 +294,7 @@ class OdinGenerator(BaseGenerator):
if binding_return_type is None: if binding_return_type is None:
self._write(write_to, f" // {function} ({return_type}) - has no mapping") self._write(write_to, f" // {function} ({return_type}) - has no mapping")
continue continue
binding_return_type = self.format_type(binding_return_type)
skip = False skip = False
binding_params = [] binding_params = []
@ -282,6 +305,8 @@ class OdinGenerator(BaseGenerator):
binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES) binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
if binding_param_type is None: if binding_param_type is None:
skip = True skip = True
else:
binding_param_type = self.format_type(binding_param_type)
binding_params.append(f"{binding_param_name}: {binding_param_type}") binding_params.append(f"{binding_param_name}: {binding_param_type}")
if skip: if skip:
self._write(write_to, f" // {function} - has no mapping") self._write(write_to, f" // {function} - has no mapping")

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, TypedDict, NotRequired from typing import Optional, TypedDict, NotRequired, Union
from pycparser import c_ast, parse_file, preprocess_file from pycparser import c_ast, parse_file, preprocess_file
from pathlib import Path from pathlib import Path
import os import os
@ -9,22 +9,24 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ExtractedSymbolType = Union[str, "ExtractedFunction"]
class ExtractedStructAttributeUnion(TypedDict): class ExtractedStructAttributeUnion(TypedDict):
type: Optional[str] type: Optional[ExtractedSymbolType]
class ExtractedStructAttribute(TypedDict): class ExtractedStructAttribute(TypedDict):
type: Optional[str] type: NotRequired[ExtractedSymbolType]
union: Optional[dict[str, Optional[str]]] union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
class ExtractedStruct(TypedDict): class ExtractedStruct(TypedDict):
attrs: dict[str, ExtractedStructAttribute] attrs: dict[str, ExtractedStructAttribute]
is_union: NotRequired[bool] is_union: NotRequired[bool]
ExtractedEnum = dict[str, Optional[str]] ExtractedEnum = dict[str, Optional[str]]
ExtractedFunctionParam = tuple[str, Optional[str]] ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
class ExtractedFunction(TypedDict): class ExtractedFunction(TypedDict):
return_type: Optional[str] return_type: Optional["ExtractedSymbolType"]
params: list[ExtractedFunctionParam] params: list[ExtractedFunctionParam]
@dataclass @dataclass
@ -33,11 +35,21 @@ class ExtractedSymbols:
enums: dict[str, ExtractedEnum] enums: dict[str, ExtractedEnum]
functions: dict[str, ExtractedFunction] functions: dict[str, ExtractedFunction]
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]: def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals: if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
prefix = " ".join(node.quals) + " " + prefix prefix = " ".join(node.quals) + " " + prefix
if isinstance(node, c_ast.PtrDecl): if isinstance(node, c_ast.PtrDecl):
prefix = "*" + prefix prefix = "*" + prefix
if isinstance(node, c_ast.FuncDecl):
func: ExtractedFunction = {
'return_type': get_type_names(node.type),
'params': [],
}
for param in node.args.params:
if param.name is None:
continue
func['params'].append((param.name, get_type_names(param)))
return func
if hasattr(node, 'names'): if hasattr(node, 'names'):
return prefix + node.names[0] # type: ignore return prefix + node.names[0] # type: ignore
@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor):
if hasattr(node_type, "declname"): if hasattr(node_type, "declname"):
return_type = get_type_names(node_type.type) return_type = get_type_names(node_type.type)
if return_type is not None and is_pointer: if return_type is not None and isinstance(return_type, str) and is_pointer:
return_type = "*" + return_type return_type = "*" + return_type
func: ExtractedFunction = { func: ExtractedFunction = {
'return_type': return_type, 'return_type': return_type,
@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor):
def visit_Typedef(self, node: c_ast.Typedef): def visit_Typedef(self, node: c_ast.Typedef):
# node.show() # node.show()
if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls: if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
if node.name == "Clay_ErrorHandler":
logger.debug(node)
struct = {} struct = {}
for decl in node.type.type.decls: for decl in node.type.type.decls:
if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union): if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):