mirror of
https://github.com/nicbarker/clay.git
synced 2025-04-15 10:48:04 +00:00
Add proper support for function arguments
This commit is contained in:
parent
7c65f31f46
commit
4f4605eff9
@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from parser import ExtractedSymbolType
|
||||
from generators.base_generator import BaseGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -61,23 +62,12 @@ TYPE_MAPPING = {
|
||||
'int32_t': 'c.int32_t',
|
||||
'uintptr_t': 'rawptr',
|
||||
'void': 'void',
|
||||
|
||||
'*Clay_RectangleElementConfig': '^RectangleElementConfig',
|
||||
'*Clay_TextElementConfig': '^TextElementConfig',
|
||||
'*Clay_ImageElementConfig': '^ImageElementConfig',
|
||||
'*Clay_FloatingElementConfig': '^FloatingElementConfig',
|
||||
'*Clay_CustomElementConfig': '^CustomElementConfig',
|
||||
'*Clay_ScrollElementConfig': '^ScrollElementConfig',
|
||||
'*Clay_BorderElementConfig': '^BorderElementConfig',
|
||||
}
|
||||
STRUCT_TYPE_OVERRIDES = {
|
||||
'Clay_Arena': {
|
||||
'nextAllocation': 'uintptr',
|
||||
'capacity': 'uintptr',
|
||||
},
|
||||
'Clay_ErrorHandler': {
|
||||
'errorHandlerFunction': 'proc "c" (errorData: ErrorData)',
|
||||
},
|
||||
'Clay_SizingAxis': {
|
||||
'size': 'SizingConstraints',
|
||||
},
|
||||
@ -108,7 +98,6 @@ FUNCTION_TYPE_OVERRIDES = {
|
||||
'offset': '[^]u8',
|
||||
},
|
||||
'Clay_SetMeasureTextFunction': {
|
||||
'measureTextFunction': 'proc "c" (text: ^StringSlice, config: ^TextElementConfig, userData: uintptr) -> Dimensions',
|
||||
'userData': 'uintptr',
|
||||
},
|
||||
'Clay_RenderCommandArray_Get': {
|
||||
@ -155,20 +144,49 @@ class OdinGenerator(BaseGenerator):
|
||||
return base_name
|
||||
raise ValueError(f'Unknown symbol: {symbol}')
|
||||
|
||||
def resolve_binding_type(self, symbol: str, member: str | None, member_type: str | None, type_overrides: dict[str, dict[str, str]]) -> str | None:
|
||||
if member_type in SYMBOL_COMPLETE_OVERRIDES:
|
||||
return SYMBOL_COMPLETE_OVERRIDES[member_type]
|
||||
if symbol in type_overrides and member in type_overrides[symbol]:
|
||||
return type_overrides[symbol][member]
|
||||
if member_type in TYPE_MAPPING:
|
||||
return TYPE_MAPPING[member_type]
|
||||
if member_type and self.has_symbol(member_type):
|
||||
return self.get_symbol_name(member_type)
|
||||
if member_type and member_type.startswith('*'):
|
||||
result = self.resolve_binding_type(symbol, member, member_type[1:], type_overrides)
|
||||
if result:
|
||||
return f"^{result}"
|
||||
return None
|
||||
def format_type(self, type: ExtractedSymbolType) -> str:
|
||||
if isinstance(type, str):
|
||||
return type
|
||||
|
||||
parameter_strs = []
|
||||
for param_name, param_type in type['params']:
|
||||
parameter_strs.append(f"{param_name}: {self.format_type(param_type or 'unknown')}")
|
||||
return_type_str = ''
|
||||
if type['return_type'] is not None and type['return_type'] != 'void':
|
||||
return_type_str = ' -> ' + self.format_type(type['return_type'])
|
||||
return f"proc \"c\" ({', '.join(parameter_strs)}){return_type_str}"
|
||||
|
||||
def resolve_binding_type(self, symbol: str, member: str | None, member_type: ExtractedSymbolType | None, type_overrides: dict[str, dict[str, str]]) -> ExtractedSymbolType | None:
|
||||
if isinstance(member_type, str):
|
||||
if member_type in SYMBOL_COMPLETE_OVERRIDES:
|
||||
return SYMBOL_COMPLETE_OVERRIDES[member_type]
|
||||
if symbol in type_overrides and member in type_overrides[symbol]:
|
||||
return type_overrides[symbol][member]
|
||||
if member_type in TYPE_MAPPING:
|
||||
return TYPE_MAPPING[member_type]
|
||||
if member_type and self.has_symbol(member_type):
|
||||
return self.get_symbol_name(member_type)
|
||||
if member_type and member_type.startswith('*'):
|
||||
result = self.resolve_binding_type(symbol, member, member_type[1:], type_overrides)
|
||||
if result:
|
||||
return f"^{result}"
|
||||
return None
|
||||
if member_type is None:
|
||||
return None
|
||||
|
||||
resolved_parameters = []
|
||||
for param_name, param_type in member_type['params']:
|
||||
resolved_param = self.resolve_binding_type(symbol, param_name, param_type, type_overrides)
|
||||
if resolved_param is None:
|
||||
return None
|
||||
resolved_parameters.append((param_name, resolved_param))
|
||||
resolved_return_type = self.resolve_binding_type(symbol, None, member_type['return_type'], type_overrides)
|
||||
if resolved_return_type is None:
|
||||
return None
|
||||
return {
|
||||
"params": resolved_parameters,
|
||||
"return_type": resolved_return_type,
|
||||
}
|
||||
|
||||
def generate_structs(self) -> None:
|
||||
for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
|
||||
@ -184,12 +202,15 @@ class OdinGenerator(BaseGenerator):
|
||||
|
||||
if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
|
||||
array_size = len(members)
|
||||
array_type = list(members.values())[0]['type']
|
||||
first_elem = list(members.values())[0]
|
||||
array_type = None
|
||||
if 'type' in first_elem:
|
||||
array_type = first_elem['type']
|
||||
|
||||
if array_type in TYPE_MAPPING:
|
||||
array_binding_type = TYPE_MAPPING[array_type]
|
||||
elif array_type and self.has_symbol(array_type):
|
||||
array_binding_type = self.get_symbol_name(array_type)
|
||||
elif array_type and self.has_symbol(self.format_type(array_type)):
|
||||
array_binding_type = self.get_symbol_name(self.format_type(array_type))
|
||||
else:
|
||||
self._write('struct', f"// {struct} ({array_type}) - has no mapping")
|
||||
continue
|
||||
@ -221,6 +242,7 @@ class OdinGenerator(BaseGenerator):
|
||||
if member_binding_type is None:
|
||||
self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping")
|
||||
continue
|
||||
member_binding_type = self.format_type(member_binding_type)
|
||||
self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
|
||||
self._write('struct', "}")
|
||||
self._write('struct', '')
|
||||
@ -272,6 +294,7 @@ class OdinGenerator(BaseGenerator):
|
||||
if binding_return_type is None:
|
||||
self._write(write_to, f" // {function} ({return_type}) - has no mapping")
|
||||
continue
|
||||
binding_return_type = self.format_type(binding_return_type)
|
||||
|
||||
skip = False
|
||||
binding_params = []
|
||||
@ -282,6 +305,8 @@ class OdinGenerator(BaseGenerator):
|
||||
binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
|
||||
if binding_param_type is None:
|
||||
skip = True
|
||||
else:
|
||||
binding_param_type = self.format_type(binding_param_type)
|
||||
binding_params.append(f"{binding_param_name}: {binding_param_type}")
|
||||
if skip:
|
||||
self._write(write_to, f" // {function} - has no mapping")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, TypedDict, NotRequired
|
||||
from typing import Optional, TypedDict, NotRequired, Union
|
||||
from pycparser import c_ast, parse_file, preprocess_file
|
||||
from pathlib import Path
|
||||
import os
|
||||
@ -9,22 +9,24 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ExtractedSymbolType = Union[str, "ExtractedFunction"]
|
||||
|
||||
class ExtractedStructAttributeUnion(TypedDict):
|
||||
type: Optional[str]
|
||||
type: Optional[ExtractedSymbolType]
|
||||
|
||||
class ExtractedStructAttribute(TypedDict):
|
||||
type: Optional[str]
|
||||
union: Optional[dict[str, Optional[str]]]
|
||||
type: NotRequired[ExtractedSymbolType]
|
||||
union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
|
||||
|
||||
class ExtractedStruct(TypedDict):
|
||||
attrs: dict[str, ExtractedStructAttribute]
|
||||
is_union: NotRequired[bool]
|
||||
|
||||
ExtractedEnum = dict[str, Optional[str]]
|
||||
ExtractedFunctionParam = tuple[str, Optional[str]]
|
||||
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
|
||||
|
||||
class ExtractedFunction(TypedDict):
|
||||
return_type: Optional[str]
|
||||
return_type: Optional["ExtractedSymbolType"]
|
||||
params: list[ExtractedFunctionParam]
|
||||
|
||||
@dataclass
|
||||
@ -33,11 +35,21 @@ class ExtractedSymbols:
|
||||
enums: dict[str, ExtractedEnum]
|
||||
functions: dict[str, ExtractedFunction]
|
||||
|
||||
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]:
|
||||
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
|
||||
if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
|
||||
prefix = " ".join(node.quals) + " " + prefix
|
||||
if isinstance(node, c_ast.PtrDecl):
|
||||
prefix = "*" + prefix
|
||||
if isinstance(node, c_ast.FuncDecl):
|
||||
func: ExtractedFunction = {
|
||||
'return_type': get_type_names(node.type),
|
||||
'params': [],
|
||||
}
|
||||
for param in node.args.params:
|
||||
if param.name is None:
|
||||
continue
|
||||
func['params'].append((param.name, get_type_names(param)))
|
||||
return func
|
||||
|
||||
if hasattr(node, 'names'):
|
||||
return prefix + node.names[0] # type: ignore
|
||||
@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor):
|
||||
|
||||
if hasattr(node_type, "declname"):
|
||||
return_type = get_type_names(node_type.type)
|
||||
if return_type is not None and is_pointer:
|
||||
if return_type is not None and isinstance(return_type, str) and is_pointer:
|
||||
return_type = "*" + return_type
|
||||
func: ExtractedFunction = {
|
||||
'return_type': return_type,
|
||||
@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor):
|
||||
def visit_Typedef(self, node: c_ast.Typedef):
|
||||
# node.show()
|
||||
if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
|
||||
if node.name == "Clay_ErrorHandler":
|
||||
logger.debug(node)
|
||||
struct = {}
|
||||
for decl in node.type.type.decls:
|
||||
if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):
|
||||
|
Loading…
Reference in New Issue
Block a user