mirror of
https://github.com/nicbarker/clay.git
synced 2025-04-19 04:38:01 +00:00
Add proper support for function arguments
This commit is contained in:
parent
7c65f31f46
commit
4f4605eff9
@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from parser import ExtractedSymbolType
|
||||||
from generators.base_generator import BaseGenerator
|
from generators.base_generator import BaseGenerator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -61,23 +62,12 @@ TYPE_MAPPING = {
|
|||||||
'int32_t': 'c.int32_t',
|
'int32_t': 'c.int32_t',
|
||||||
'uintptr_t': 'rawptr',
|
'uintptr_t': 'rawptr',
|
||||||
'void': 'void',
|
'void': 'void',
|
||||||
|
|
||||||
'*Clay_RectangleElementConfig': '^RectangleElementConfig',
|
|
||||||
'*Clay_TextElementConfig': '^TextElementConfig',
|
|
||||||
'*Clay_ImageElementConfig': '^ImageElementConfig',
|
|
||||||
'*Clay_FloatingElementConfig': '^FloatingElementConfig',
|
|
||||||
'*Clay_CustomElementConfig': '^CustomElementConfig',
|
|
||||||
'*Clay_ScrollElementConfig': '^ScrollElementConfig',
|
|
||||||
'*Clay_BorderElementConfig': '^BorderElementConfig',
|
|
||||||
}
|
}
|
||||||
STRUCT_TYPE_OVERRIDES = {
|
STRUCT_TYPE_OVERRIDES = {
|
||||||
'Clay_Arena': {
|
'Clay_Arena': {
|
||||||
'nextAllocation': 'uintptr',
|
'nextAllocation': 'uintptr',
|
||||||
'capacity': 'uintptr',
|
'capacity': 'uintptr',
|
||||||
},
|
},
|
||||||
'Clay_ErrorHandler': {
|
|
||||||
'errorHandlerFunction': 'proc "c" (errorData: ErrorData)',
|
|
||||||
},
|
|
||||||
'Clay_SizingAxis': {
|
'Clay_SizingAxis': {
|
||||||
'size': 'SizingConstraints',
|
'size': 'SizingConstraints',
|
||||||
},
|
},
|
||||||
@ -108,7 +98,6 @@ FUNCTION_TYPE_OVERRIDES = {
|
|||||||
'offset': '[^]u8',
|
'offset': '[^]u8',
|
||||||
},
|
},
|
||||||
'Clay_SetMeasureTextFunction': {
|
'Clay_SetMeasureTextFunction': {
|
||||||
'measureTextFunction': 'proc "c" (text: ^StringSlice, config: ^TextElementConfig, userData: uintptr) -> Dimensions',
|
|
||||||
'userData': 'uintptr',
|
'userData': 'uintptr',
|
||||||
},
|
},
|
||||||
'Clay_RenderCommandArray_Get': {
|
'Clay_RenderCommandArray_Get': {
|
||||||
@ -155,7 +144,20 @@ class OdinGenerator(BaseGenerator):
|
|||||||
return base_name
|
return base_name
|
||||||
raise ValueError(f'Unknown symbol: {symbol}')
|
raise ValueError(f'Unknown symbol: {symbol}')
|
||||||
|
|
||||||
def resolve_binding_type(self, symbol: str, member: str | None, member_type: str | None, type_overrides: dict[str, dict[str, str]]) -> str | None:
|
def format_type(self, type: ExtractedSymbolType) -> str:
|
||||||
|
if isinstance(type, str):
|
||||||
|
return type
|
||||||
|
|
||||||
|
parameter_strs = []
|
||||||
|
for param_name, param_type in type['params']:
|
||||||
|
parameter_strs.append(f"{param_name}: {self.format_type(param_type or 'unknown')}")
|
||||||
|
return_type_str = ''
|
||||||
|
if type['return_type'] is not None and type['return_type'] != 'void':
|
||||||
|
return_type_str = ' -> ' + self.format_type(type['return_type'])
|
||||||
|
return f"proc \"c\" ({', '.join(parameter_strs)}){return_type_str}"
|
||||||
|
|
||||||
|
def resolve_binding_type(self, symbol: str, member: str | None, member_type: ExtractedSymbolType | None, type_overrides: dict[str, dict[str, str]]) -> ExtractedSymbolType | None:
|
||||||
|
if isinstance(member_type, str):
|
||||||
if member_type in SYMBOL_COMPLETE_OVERRIDES:
|
if member_type in SYMBOL_COMPLETE_OVERRIDES:
|
||||||
return SYMBOL_COMPLETE_OVERRIDES[member_type]
|
return SYMBOL_COMPLETE_OVERRIDES[member_type]
|
||||||
if symbol in type_overrides and member in type_overrides[symbol]:
|
if symbol in type_overrides and member in type_overrides[symbol]:
|
||||||
@ -169,6 +171,22 @@ class OdinGenerator(BaseGenerator):
|
|||||||
if result:
|
if result:
|
||||||
return f"^{result}"
|
return f"^{result}"
|
||||||
return None
|
return None
|
||||||
|
if member_type is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
resolved_parameters = []
|
||||||
|
for param_name, param_type in member_type['params']:
|
||||||
|
resolved_param = self.resolve_binding_type(symbol, param_name, param_type, type_overrides)
|
||||||
|
if resolved_param is None:
|
||||||
|
return None
|
||||||
|
resolved_parameters.append((param_name, resolved_param))
|
||||||
|
resolved_return_type = self.resolve_binding_type(symbol, None, member_type['return_type'], type_overrides)
|
||||||
|
if resolved_return_type is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"params": resolved_parameters,
|
||||||
|
"return_type": resolved_return_type,
|
||||||
|
}
|
||||||
|
|
||||||
def generate_structs(self) -> None:
|
def generate_structs(self) -> None:
|
||||||
for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
|
for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
|
||||||
@ -184,12 +202,15 @@ class OdinGenerator(BaseGenerator):
|
|||||||
|
|
||||||
if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
|
if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
|
||||||
array_size = len(members)
|
array_size = len(members)
|
||||||
array_type = list(members.values())[0]['type']
|
first_elem = list(members.values())[0]
|
||||||
|
array_type = None
|
||||||
|
if 'type' in first_elem:
|
||||||
|
array_type = first_elem['type']
|
||||||
|
|
||||||
if array_type in TYPE_MAPPING:
|
if array_type in TYPE_MAPPING:
|
||||||
array_binding_type = TYPE_MAPPING[array_type]
|
array_binding_type = TYPE_MAPPING[array_type]
|
||||||
elif array_type and self.has_symbol(array_type):
|
elif array_type and self.has_symbol(self.format_type(array_type)):
|
||||||
array_binding_type = self.get_symbol_name(array_type)
|
array_binding_type = self.get_symbol_name(self.format_type(array_type))
|
||||||
else:
|
else:
|
||||||
self._write('struct', f"// {struct} ({array_type}) - has no mapping")
|
self._write('struct', f"// {struct} ({array_type}) - has no mapping")
|
||||||
continue
|
continue
|
||||||
@ -221,6 +242,7 @@ class OdinGenerator(BaseGenerator):
|
|||||||
if member_binding_type is None:
|
if member_binding_type is None:
|
||||||
self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping")
|
self._write('struct', f" // {binding_member_name} ({member_type}) - has no mapping")
|
||||||
continue
|
continue
|
||||||
|
member_binding_type = self.format_type(member_binding_type)
|
||||||
self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
|
self._write('struct', f" {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
|
||||||
self._write('struct', "}")
|
self._write('struct', "}")
|
||||||
self._write('struct', '')
|
self._write('struct', '')
|
||||||
@ -272,6 +294,7 @@ class OdinGenerator(BaseGenerator):
|
|||||||
if binding_return_type is None:
|
if binding_return_type is None:
|
||||||
self._write(write_to, f" // {function} ({return_type}) - has no mapping")
|
self._write(write_to, f" // {function} ({return_type}) - has no mapping")
|
||||||
continue
|
continue
|
||||||
|
binding_return_type = self.format_type(binding_return_type)
|
||||||
|
|
||||||
skip = False
|
skip = False
|
||||||
binding_params = []
|
binding_params = []
|
||||||
@ -282,6 +305,8 @@ class OdinGenerator(BaseGenerator):
|
|||||||
binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
|
binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
|
||||||
if binding_param_type is None:
|
if binding_param_type is None:
|
||||||
skip = True
|
skip = True
|
||||||
|
else:
|
||||||
|
binding_param_type = self.format_type(binding_param_type)
|
||||||
binding_params.append(f"{binding_param_name}: {binding_param_type}")
|
binding_params.append(f"{binding_param_name}: {binding_param_type}")
|
||||||
if skip:
|
if skip:
|
||||||
self._write(write_to, f" // {function} - has no mapping")
|
self._write(write_to, f" // {function} - has no mapping")
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, TypedDict, NotRequired
|
from typing import Optional, TypedDict, NotRequired, Union
|
||||||
from pycparser import c_ast, parse_file, preprocess_file
|
from pycparser import c_ast, parse_file, preprocess_file
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
import os
|
||||||
@ -9,22 +9,24 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ExtractedSymbolType = Union[str, "ExtractedFunction"]
|
||||||
|
|
||||||
class ExtractedStructAttributeUnion(TypedDict):
|
class ExtractedStructAttributeUnion(TypedDict):
|
||||||
type: Optional[str]
|
type: Optional[ExtractedSymbolType]
|
||||||
|
|
||||||
class ExtractedStructAttribute(TypedDict):
|
class ExtractedStructAttribute(TypedDict):
|
||||||
type: Optional[str]
|
type: NotRequired[ExtractedSymbolType]
|
||||||
union: Optional[dict[str, Optional[str]]]
|
union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
|
||||||
|
|
||||||
class ExtractedStruct(TypedDict):
|
class ExtractedStruct(TypedDict):
|
||||||
attrs: dict[str, ExtractedStructAttribute]
|
attrs: dict[str, ExtractedStructAttribute]
|
||||||
is_union: NotRequired[bool]
|
is_union: NotRequired[bool]
|
||||||
|
|
||||||
ExtractedEnum = dict[str, Optional[str]]
|
ExtractedEnum = dict[str, Optional[str]]
|
||||||
ExtractedFunctionParam = tuple[str, Optional[str]]
|
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
|
||||||
|
|
||||||
class ExtractedFunction(TypedDict):
|
class ExtractedFunction(TypedDict):
|
||||||
return_type: Optional[str]
|
return_type: Optional["ExtractedSymbolType"]
|
||||||
params: list[ExtractedFunctionParam]
|
params: list[ExtractedFunctionParam]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -33,11 +35,21 @@ class ExtractedSymbols:
|
|||||||
enums: dict[str, ExtractedEnum]
|
enums: dict[str, ExtractedEnum]
|
||||||
functions: dict[str, ExtractedFunction]
|
functions: dict[str, ExtractedFunction]
|
||||||
|
|
||||||
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]:
|
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
|
||||||
if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
|
if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
|
||||||
prefix = " ".join(node.quals) + " " + prefix
|
prefix = " ".join(node.quals) + " " + prefix
|
||||||
if isinstance(node, c_ast.PtrDecl):
|
if isinstance(node, c_ast.PtrDecl):
|
||||||
prefix = "*" + prefix
|
prefix = "*" + prefix
|
||||||
|
if isinstance(node, c_ast.FuncDecl):
|
||||||
|
func: ExtractedFunction = {
|
||||||
|
'return_type': get_type_names(node.type),
|
||||||
|
'params': [],
|
||||||
|
}
|
||||||
|
for param in node.args.params:
|
||||||
|
if param.name is None:
|
||||||
|
continue
|
||||||
|
func['params'].append((param.name, get_type_names(param)))
|
||||||
|
return func
|
||||||
|
|
||||||
if hasattr(node, 'names'):
|
if hasattr(node, 'names'):
|
||||||
return prefix + node.names[0] # type: ignore
|
return prefix + node.names[0] # type: ignore
|
||||||
@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor):
|
|||||||
|
|
||||||
if hasattr(node_type, "declname"):
|
if hasattr(node_type, "declname"):
|
||||||
return_type = get_type_names(node_type.type)
|
return_type = get_type_names(node_type.type)
|
||||||
if return_type is not None and is_pointer:
|
if return_type is not None and isinstance(return_type, str) and is_pointer:
|
||||||
return_type = "*" + return_type
|
return_type = "*" + return_type
|
||||||
func: ExtractedFunction = {
|
func: ExtractedFunction = {
|
||||||
'return_type': return_type,
|
'return_type': return_type,
|
||||||
@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor):
|
|||||||
def visit_Typedef(self, node: c_ast.Typedef):
|
def visit_Typedef(self, node: c_ast.Typedef):
|
||||||
# node.show()
|
# node.show()
|
||||||
if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
|
if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
|
||||||
|
if node.name == "Clay_ErrorHandler":
|
||||||
|
logger.debug(node)
|
||||||
struct = {}
|
struct = {}
|
||||||
for decl in node.type.type.decls:
|
for decl in node.type.type.decls:
|
||||||
if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):
|
if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):
|
||||||
|
Loading…
Reference in New Issue
Block a user