from dataclasses import dataclass
from typing import Optional, TypedDict, NotRequired, Union
from pycparser import c_ast, parse_file, preprocess_file
from pathlib import Path
import os
import json
import shutil
import logging

logger = logging.getLogger(__name__)

# Description of a C type as produced by get_type_names(): either a type-name
# string (with "*" prepended per pointer level and any TypeDecl qualifiers
# prepended as well), or an ExtractedFunction when the type is a function.
ExtractedSymbolType = Union[str, "ExtractedFunction"]

class ExtractedStructAttributeUnion(TypedDict):
    """Type record for a union member: a single, possibly-None ``type`` key.

    NOTE(review): not used by the Visitor / parse_headers code in this module
    (ExtractedStructAttribute.union uses a plain dict) — confirm whether it is
    still needed elsewhere.
    """
    type: Optional[ExtractedSymbolType]
class ExtractedStructAttribute(TypedDict):
    """One field of an extracted struct or union.

    The Visitor sets exactly one of the two keys:
    ``type`` for an ordinary field, or ``union`` for a field whose type is an
    inline union, mapping each union member name to its type.
    """
    # Result of get_type_names() for the field; None when it cannot be described.
    type: NotRequired[Optional[ExtractedSymbolType]]
    # Inline union members: member name -> member type (possibly None).
    union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]

class ExtractedStruct(TypedDict):
    """Layout of an extracted struct or union."""
    # Field name -> field description, in declaration order.
    attrs: dict[str, ExtractedStructAttribute]
    # True when a typedef'd aggregate is a union. Only entries recorded from a
    # typedef carry this key; plain `struct Name {...}` entries omit it.
    is_union: NotRequired[bool]

# Enumerator name -> initializer value as a string, or None when the
# enumerator has no explicit initializer.
ExtractedEnum = dict[str, Optional[str]]
# (parameter name, parameter type) pair of an extracted function; the type is
# None when get_type_names() cannot describe it.
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]

class ExtractedFunction(TypedDict):
    """Signature of an extracted function or function type."""
    # Return type; None when get_type_names() cannot describe it.
    return_type: Optional["ExtractedSymbolType"]
    # Named parameters in declaration order; parameters without a name are omitted.
    params: list[ExtractedFunctionParam]

@dataclass
class ExtractedSymbols:
    """Declarations collected by parse_headers() from the merged headers."""
    # Struct/union layouts, keyed by struct tag or typedef name.
    structs: dict[str, ExtractedStruct]
    # Enums, keyed by typedef name.
    enums: dict[str, ExtractedEnum]
    # Function signatures, keyed by declarator name.
    functions: dict[str, ExtractedFunction]

def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
    """Describe the C type held by a pycparser declaration node.

    Recursively follows the node's ``.type`` chain and builds a textual prefix
    on the way down. Each ``PtrDecl`` prepends ``"*"``. Each qualified
    ``TypeDecl`` prepends its qualifiers followed by a space. Qualifiers on
    pointer declarators themselves are not inspected.

    Args:
        node: A pycparser node such as a ``Decl``, ``Typename``, ``TypeDecl``,
            ``PtrDecl`` or ``FuncDecl``.
        prefix: Text accumulated by outer recursion levels; callers normally
            leave it empty.

    Returns:
        * For a node with ``names`` (a type specifier): ``prefix`` followed by
          the first entry of ``names``. NOTE(review): multi-word specifiers
          such as ``unsigned int`` are likely truncated to their first word.
        * For a ``FuncDecl``: an :class:`ExtractedFunction` describing its
          return type and named parameters. The accumulated prefix (e.g. the
          ``*`` of a function pointer) is not included.
        * ``None`` when the chain ends in a node with neither ``names`` nor
          ``type`` (presumably struct/union/enum specifiers — confirm).
    """
    if isinstance(node, c_ast.TypeDecl) and getattr(node, 'quals', None):
        prefix = " ".join(node.quals) + " " + prefix
    if isinstance(node, c_ast.PtrDecl):
        prefix = "*" + prefix
    if isinstance(node, c_ast.FuncDecl):
        func: ExtractedFunction = {
            'return_type': get_type_names(node.type),
            'params': [],
        }
        # `args` can be None (pycparser uses that for an empty `()` list);
        # iterating it unguarded raised AttributeError.
        if node.args is not None:
            for param in node.args.params:
                # Skip unnamed parameters, and `...` (EllipsisParam), which has
                # no `name` attribute at all.
                name = getattr(param, 'name', None)
                if name is None:
                    continue
                func['params'].append((name, get_type_names(param)))
        return func

    if hasattr(node, 'names'):
        return prefix + node.names[0]  # type: ignore
    if hasattr(node, 'type'):
        return get_type_names(node.type, prefix)  # type: ignore
    return None

class Visitor(c_ast.NodeVisitor):
    """Collect struct, union, enum and function declarations from a pycparser AST.

    Call ``visit(ast)``. Results accumulate in ``structs``, ``enums`` and
    ``functions``, each keyed by C identifier.
    """

    def __init__(self):
        # Aggregate layouts, keyed by struct tag or typedef name.
        self.structs: dict[str, ExtractedStruct] = {}
        # Enumerators, keyed by typedef name.
        self.enums: dict[str, ExtractedEnum] = {}
        # Signatures, keyed by declarator name.
        self.functions: dict[str, ExtractedFunction] = {}

    def visit_FuncDecl(self, node: c_ast.FuncDecl):
        """Record the signature of a function declarator, then keep walking.

        NOTE(review): this runs for every FuncDecl in the tree. That presumably
        includes function-pointer declarators (struct fields, typedefs,
        parameters), which are then recorded under their declarator name too.
        Confirm whether callers filter these.
        """
        # Peel off every pointer level so that `T *f(...)` and `T **f(...)`
        # both reach the TypeDecl carrying the function's name. Previously
        # only one level was unwrapped and `T **f(...)` was silently skipped.
        node_type = node.type
        pointer_depth = 0
        while isinstance(node_type, c_ast.PtrDecl):
            node_type = node_type.type
            pointer_depth += 1

        if hasattr(node_type, "declname"):
            # Only node_type's inner type is described, so qualifiers on
            # node_type itself (e.g. a `const` return type) are not included.
            return_type = get_type_names(node_type.type)
            if isinstance(return_type, str) and pointer_depth:
                return_type = "*" * pointer_depth + return_type
            params = []
            # `args` can be None (empty `()` parameter list).
            if node.args is not None:
                for param in node.args.params:
                    # Skip unnamed parameters and `...` (no `name` attribute).
                    name = getattr(param, 'name', None)
                    if name is None:
                        continue
                    params.append((name, get_type_names(param)))
            func: ExtractedFunction = {
                'return_type': return_type,
                'params': params,
            }
            self.functions[node_type.declname] = func
        self.generic_visit(node)

    def visit_Struct(self, node: c_ast.Struct):
        """Record a named struct definition (one with a body), then keep walking."""
        # visit_Typedef stores its entry before generic_visit reaches the
        # struct node beneath it. So for `typedef struct Foo {...} Foo;` a
        # richer entry (inline unions, `is_union`) already exists under "Foo".
        # Keep it rather than overwriting it with this plain field listing.
        if node.name and node.decls and node.name not in self.structs:
            self.structs[node.name] = {
                'attrs': {
                    decl.name: {"type": get_type_names(decl)}
                    for decl in node.decls
                },
            }
        self.generic_visit(node)

    def visit_Typedef(self, node: c_ast.Typedef):
        """Record a typedef'd struct/union body or enum, then keep walking."""
        # The type node under the typedef's declarator, e.g. the Struct in
        # `typedef struct {...} Name;`. None when node.type has no `type`.
        target = getattr(node.type, 'type', None)

        if getattr(target, 'decls', None):
            attrs = {}
            for decl in target.decls:
                inner = getattr(getattr(decl, 'type', None), 'type', None)
                # An inline union member is expanded field by field. A union
                # referenced without a body (decls is None) falls through to
                # the ordinary path instead of failing on iteration.
                if isinstance(inner, c_ast.Union) and inner.decls is not None:
                    attrs[decl.name] = {
                        'union': {
                            field.name: get_type_names(field)
                            for field in inner.decls
                        },
                    }
                else:
                    attrs[decl.name] = {
                        "type": get_type_names(decl),
                    }
            self.structs[node.name] = {
                'attrs': attrs,
                'is_union': isinstance(target, c_ast.Union),
            }

        # `values` is None for a typedef of an enum whose body is elsewhere
        # (`typedef enum Foo Foo;`); iterating it used to crash.
        if isinstance(target, c_ast.Enum) and target.values is not None:
            self.enums[node.name] = {
                enumerator.name: self._enumerator_value(enumerator.value)
                for enumerator in target.values.enumerators
            }
        self.generic_visit(node)

    @staticmethod
    def _enumerator_value(value):
        """Return an enumerator's initializer as a string, or None if it has none."""
        if value is None:
            return None
        if isinstance(value, c_ast.Constant):
            return value.value
        # Non-literal initializers (pycparser parses e.g. `-1` as a UnaryOp,
        # `1 << 3` as a BinaryOp) have no `.value`; render them back to C text.
        from pycparser import c_generator
        return c_generator.CGenerator().visit(value)


def _merge_headers(input_files: list[Path], merged_path: Path) -> None:
    """Concatenate *input_files* into *merged_path*, dropping lines cpp must not see.

    Dropped lines:

    * ``#include`` directives, including indented ones and ``# include``. cpp
      runs with ``-nostdinc``, so anything needed must be supplied through
      *input_files* instead.
    * Any line containing ``#define CLAY_IMPLEMENTATION``. This keeps parsing
      to the public API even when a file also carries implementation code
      (e.g. the user's custom extensions).
    """
    # utf-8 + surrogateescape round-trips arbitrary bytes unchanged. The copy
    # therefore neither depends on the locale's default encoding nor fails on
    # non-UTF-8 bytes in comments.
    with open(merged_path, 'w', encoding='utf-8', errors='surrogateescape') as merged:
        for input_file in input_files:
            with open(input_file, 'r', encoding='utf-8', errors='surrogateescape') as header:
                for line in header:
                    directive = line.lstrip()
                    if directive.startswith('#') and directive[1:].lstrip().startswith('include'):
                        continue
                    if "#define CLAY_IMPLEMENTATION" in line:
                        continue
                    # A file's last line may lack a newline. Terminate it so it
                    # doesn't fuse with the next file's first line.
                    if not line.endswith('\n'):
                        line += '\n'
                    merged.write(line)


def parse_headers(input_files: list[Path], tmp_dir: Path) -> ExtractedSymbols:
    """Merge, preprocess and parse C headers, returning their declarations.

    The steps are:

    1. The inputs are merged into ``tmp_dir/merged_clay.h``.
    2. The merged file is run through the system ``cpp`` with ``-nostdinc``
       and ``__attribute__(x)`` defined away.
    3. The result is written to ``tmp_dir/clay.preprocessed.h``.
    4. That file is parsed with pycparser.

    Args:
        input_files: Headers to merge, in order. Their ``#include`` lines are
            dropped, so any declarations they depend on must also be provided
            through ``input_files``.
        tmp_dir: Directory for the intermediate files; created if missing.

    Returns:
        The structs, enums and functions collected by :class:`Visitor`.

    Errors from reading the inputs, running ``cpp`` or parsing are not caught
    here and propagate to the caller.
    """
    tmp_dir = Path(tmp_dir)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    merged_path = tmp_dir / 'merged_clay.h'
    preprocessed_path = tmp_dir / 'clay.preprocessed.h'

    # Make a new clay.h that combines the provided input files, so that we can
    # add bindings for customized structs.
    _merge_headers(input_files, merged_path)

    logger.info("Preprocessing file")
    cpp_args = ["-nostdinc", "-D__attribute__(x)=", "-E"]
    preprocessed = preprocess_file(str(merged_path), cpp_path="cpp", cpp_args=cpp_args) # type: ignore
    # The preprocessed text is written to disk and parsed from there.
    with open(preprocessed_path, 'w') as f:
        f.write(preprocessed)

    logger.info("Parsing file")
    ast = parse_file(str(preprocessed_path), use_cpp=False) # type: ignore

    visitor = Visitor()
    visitor.visit(ast)
    return ExtractedSymbols(
        structs=visitor.structs,
        enums=visitor.enums,
        functions=visitor.functions,
    )