From 4f4605eff9c553ed51f29f8611908e88c10ad10e Mon Sep 17 00:00:00 2001 From: Harrison Lambeth Date: Sun, 26 Jan 2025 14:43:51 -0700 Subject: [PATCH] Add proper support for function arguments --- generator/generators/odin_generator.py | 83 +++++++++++++++++--------- generator/parser.py | 30 +++++++--- 2 files changed, 76 insertions(+), 37 deletions(-) diff --git a/generator/generators/odin_generator.py b/generator/generators/odin_generator.py index 12abe13..ef9ad6d 100644 --- a/generator/generators/odin_generator.py +++ b/generator/generators/odin_generator.py @@ -1,6 +1,7 @@ from pathlib import Path import logging +from parser import ExtractedSymbolType from generators.base_generator import BaseGenerator logger = logging.getLogger(__name__) @@ -61,23 +62,12 @@ TYPE_MAPPING = { 'int32_t': 'c.int32_t', 'uintptr_t': 'rawptr', 'void': 'void', - - '*Clay_RectangleElementConfig': '^RectangleElementConfig', - '*Clay_TextElementConfig': '^TextElementConfig', - '*Clay_ImageElementConfig': '^ImageElementConfig', - '*Clay_FloatingElementConfig': '^FloatingElementConfig', - '*Clay_CustomElementConfig': '^CustomElementConfig', - '*Clay_ScrollElementConfig': '^ScrollElementConfig', - '*Clay_BorderElementConfig': '^BorderElementConfig', } STRUCT_TYPE_OVERRIDES = { 'Clay_Arena': { 'nextAllocation': 'uintptr', 'capacity': 'uintptr', }, - 'Clay_ErrorHandler': { - 'errorHandlerFunction': 'proc "c" (errorData: ErrorData)', - }, 'Clay_SizingAxis': { 'size': 'SizingConstraints', }, @@ -108,7 +98,6 @@ FUNCTION_TYPE_OVERRIDES = { 'offset': '[^]u8', }, 'Clay_SetMeasureTextFunction': { - 'measureTextFunction': 'proc "c" (text: ^StringSlice, config: ^TextElementConfig, userData: uintptr) -> Dimensions', 'userData': 'uintptr', }, 'Clay_RenderCommandArray_Get': { @@ -155,20 +144,49 @@ class OdinGenerator(BaseGenerator): return base_name raise ValueError(f'Unknown symbol: {symbol}') - def resolve_binding_type(self, symbol: str, member: str | None, member_type: str | None, type_overrides: dict[str, dict[str, str]]) -> str | None: - if member_type in SYMBOL_COMPLETE_OVERRIDES: - return SYMBOL_COMPLETE_OVERRIDES[member_type] - if symbol in type_overrides and member in type_overrides[symbol]: - return type_overrides[symbol][member] - if member_type in TYPE_MAPPING: - return TYPE_MAPPING[member_type] - if member_type and self.has_symbol(member_type): - return self.get_symbol_name(member_type) - if member_type and member_type.startswith('*'): - result = self.resolve_binding_type(symbol, member, member_type[1:], type_overrides) - if result: - return f"^{result}" - return None + def format_type(self, type: ExtractedSymbolType) -> str: + if isinstance(type, str): + return type + + parameter_strs = [] + for param_name, param_type in type['params']: + parameter_strs.append(f"{param_name}: {self.format_type(param_type or 'unknown')}") + return_type_str = '' + if type['return_type'] is not None and type['return_type'] != 'void': + return_type_str = ' -> ' + self.format_type(type['return_type']) + return f"proc \"c\" ({', '.join(parameter_strs)}){return_type_str}" + + def resolve_binding_type(self, symbol: str, member: str | None, member_type: ExtractedSymbolType | None, type_overrides: dict[str, dict[str, str]]) -> ExtractedSymbolType | None: + if isinstance(member_type, str): + if member_type in SYMBOL_COMPLETE_OVERRIDES: + return SYMBOL_COMPLETE_OVERRIDES[member_type] + if symbol in type_overrides and member in type_overrides[symbol]: + return type_overrides[symbol][member] + if member_type in TYPE_MAPPING: + return TYPE_MAPPING[member_type] + if member_type and self.has_symbol(member_type): + return self.get_symbol_name(member_type) + if member_type and member_type.startswith('*'): + result = self.resolve_binding_type(symbol, member, member_type[1:], type_overrides) + if result: + return f"^{result}" + return None + if member_type is None: + return None + + resolved_parameters = [] + for param_name, param_type in member_type['params']: + resolved_param = self.resolve_binding_type(symbol, param_name, param_type, type_overrides) + if resolved_param is None: + return None + resolved_parameters.append((param_name, resolved_param)) + resolved_return_type = self.resolve_binding_type(symbol, None, member_type['return_type'], type_overrides) + if resolved_return_type is None: + return None + return { + "params": resolved_parameters, + "return_type": resolved_return_type, + } def generate_structs(self) -> None: for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]): @@ -184,12 +202,15 @@ class OdinGenerator(BaseGenerator): if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY: array_size = len(members) - array_type = list(members.values())[0]['type'] + first_elem = list(members.values())[0] + array_type = None + if 'type' in first_elem: + array_type = first_elem['type'] if array_type in TYPE_MAPPING: array_binding_type = TYPE_MAPPING[array_type] - elif array_type and self.has_symbol(array_type): - array_binding_type = self.get_symbol_name(array_type) + elif array_type and self.has_symbol(self.format_type(array_type)): + array_binding_type = self.get_symbol_name(self.format_type(array_type)) else: self._write('struct', f"// {struct} ({array_type}) - has no mapping") continue @@ -221,6 +242,7 @@ class OdinGenerator(BaseGenerator): if member_binding_type is None: self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping") continue + member_binding_type = self.format_type(member_binding_type) self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})") self._write('struct', "}") self._write('struct', '') @@ -272,6 +294,7 @@ class OdinGenerator(BaseGenerator): if binding_return_type is None: self._write(write_to, f" // {function} ({return_type}) - has no mapping") continue + binding_return_type = self.format_type(binding_return_type) skip = False binding_params = [] @@ -282,6 +305,8 @@ class OdinGenerator(BaseGenerator): binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES) if binding_param_type is None: skip = True + else: + binding_param_type = self.format_type(binding_param_type) binding_params.append(f"{binding_param_name}: {binding_param_type}") if skip: self._write(write_to, f" // {function} - has no mapping") diff --git a/generator/parser.py b/generator/parser.py index 3b5e218..8346b2d 100644 --- a/generator/parser.py +++ b/generator/parser.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, TypedDict, NotRequired +from typing import Optional, TypedDict, NotRequired, Union from pycparser import c_ast, parse_file, preprocess_file from pathlib import Path import os @@ -9,22 +9,24 @@ import logging logger = logging.getLogger(__name__) +ExtractedSymbolType = Union[str, "ExtractedFunction"] + class ExtractedStructAttributeUnion(TypedDict): - type: Optional[str] + type: Optional[ExtractedSymbolType] class ExtractedStructAttribute(TypedDict): - type: Optional[str] - union: Optional[dict[str, Optional[str]]] + type: NotRequired[ExtractedSymbolType] + union: NotRequired[dict[str, Optional[ExtractedSymbolType]]] class ExtractedStruct(TypedDict): attrs: dict[str, ExtractedStructAttribute] is_union: NotRequired[bool] ExtractedEnum = dict[str, Optional[str]] -ExtractedFunctionParam = tuple[str, Optional[str]] +ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]] class ExtractedFunction(TypedDict): - return_type: Optional[str] + return_type: Optional["ExtractedSymbolType"] params: list[ExtractedFunctionParam] @dataclass @@ -33,11 +35,21 @@ class ExtractedSymbols: enums: dict[str, ExtractedEnum] functions: dict[str, ExtractedFunction] -def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]: +def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]: if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals: prefix = " ".join(node.quals) + " " + prefix if isinstance(node, c_ast.PtrDecl): prefix = "*" + prefix + if isinstance(node, c_ast.FuncDecl): + func: ExtractedFunction = { + 'return_type': get_type_names(node.type), + 'params': [], + } + for param in node.args.params: + if param.name is None: + continue + func['params'].append((param.name, get_type_names(param))) + return func if hasattr(node, 'names'): return prefix + node.names[0] # type: ignore @@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor): if hasattr(node_type, "declname"): return_type = get_type_names(node_type.type) - if return_type is not None and is_pointer: + if return_type is not None and isinstance(return_type, str) and is_pointer: return_type = "*" + return_type func: ExtractedFunction = { 'return_type': return_type, @@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor): def visit_Typedef(self, node: c_ast.Typedef): # node.show() if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls: + if node.name == "Clay_ErrorHandler": + logger.debug(node) struct = {} for decl in node.type.type.decls: if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):