import pygame
import math
import sys

# ---------- Configuration ----------
WIDTH = 800        # window width in pixels
HEIGHT = 600       # window height in pixels
VOXEL_SIZE = 10    # edge length of one grid cell, in pixels

# ---------- Global Counters ----------
# Running totals of ray-AABB intersection tests, summed over every ray cast
# in the current frame; the main loop zeroes both at the start of each frame.
total_bvh_checks = total_grid_checks = 0

# ---------- Data Structures ----------

class AABB:
    """Axis-Aligned Bounding Box for a voxel or a BVH node.

    The box spans ``[x, x + w]`` horizontally and ``[y, y + h]`` vertically,
    with ``(x, y)`` its top-left corner.
    """

    def __init__(self, x, y, w, h):
        self.x = x  # top-left x
        self.y = y  # top-left y
        self.w = w  # width
        self.h = h  # height

    def __repr__(self):
        return f"{type(self).__name__}({self.x!r}, {self.y!r}, {self.w!r}, {self.h!r})"

    def union(self, other):
        """Return the smallest AABB that contains both ``self`` and ``other``."""
        x_min = min(self.x, other.x)
        y_min = min(self.y, other.y)
        x_max = max(self.x + self.w, other.x + other.w)
        y_max = max(self.y + self.h, other.y + other.h)
        return AABB(x_min, y_min, x_max - x_min, y_max - y_min)

    @staticmethod
    def _slab(lo, hi, origin, direction):
        """Return the ``(t_enter, t_exit)`` interval where a 1-D ray lies in
        ``[lo, hi]``, or ``None`` if it never does.

        A zero direction is handled exactly rather than approximated: the
        ray is then inside the slab for every ``t`` or for none.
        """
        if direction == 0:
            if lo <= origin <= hi:
                return (-math.inf, math.inf)
            return None
        t_a = (lo - origin) / direction
        t_b = (hi - origin) / direction
        return (t_a, t_b) if t_a <= t_b else (t_b, t_a)

    def intersect_ray(self, origin, direction):
        """
        Ray-AABB intersection using the slab method.

        ``t`` is measured in multiples of ``direction`` (a true distance when
        ``direction`` has unit length). Returns the ``t`` at which the ray
        enters the box -- ``0.0`` if ``origin`` is already inside it -- or
        ``None`` if the ray misses.

        A box whose exit point is at or before the origin (``t_exit <= 0``)
        lies behind the ray and counts as a miss. Otherwise a voxel that
        merely shares a corner or edge with the origin, but sits behind it,
        would be reported as the nearest hit.
        """
        x_span = self._slab(self.x, self.x + self.w, origin[0], direction[0])
        if x_span is None:
            return None
        y_span = self._slab(self.y, self.y + self.h, origin[1], direction[1])
        if y_span is None:
            return None

        tmin = max(x_span[0], y_span[0])
        tmax = min(x_span[1], y_span[1])

        if tmax <= 0 or tmin > tmax:
            return None  # box is behind the origin, or the slabs don't overlap
        # A ray starting inside the box reports 0 rather than a negative distance.
        return max(tmin, 0.0)

class BVHNode:
    """One node of the binary bounding-volume hierarchy.

    ``aabb`` is the node's bounding box. A leaf stores the ``(i, j)`` grid
    coordinate of its single voxel in ``voxel``; an interior node keeps
    ``voxel`` as ``None`` and holds its subtrees in ``left`` and ``right``.
    ``color`` is the outline colour used when the node is drawn (red for
    every node).
    """

    def __init__(self, aabb, voxel=None, left=None, right=None):
        self.aabb, self.voxel = aabb, voxel
        self.left, self.right = left, right
        self.color = (255, 0, 0)

def build_bvh(voxels):
    """
    Build a BVH recursively from a list of voxels.

    Each voxel is a tuple ``(i, j, aabb)``: its grid coordinate and its
    bounding box. Interior nodes are split at the median voxel centre along
    the longer axis of their combined box (the x axis on ties). The input
    list is not modified.

    Returns the root ``BVHNode``, or ``None`` when ``voxels`` is empty.
    """
    if not voxels:
        return None
    if len(voxels) == 1:
        i, j, aabb = voxels[0]
        return BVHNode(aabb, voxel=(i, j))

    # Compute the box enclosing every voxel in this subtree.
    combined = voxels[0][2]
    for _, _, aabb in voxels[1:]:
        combined = combined.union(aabb)

    # Choose the split axis once instead of re-testing it for every item.
    split_on_y = combined.w < combined.h

    def center(item):
        box = item[2]
        return box.y + box.h / 2 if split_on_y else box.x + box.w / 2

    # sorted() instead of list.sort(): the top-level call receives the
    # caller's own list, which must not be reordered as a side effect.
    ordered = sorted(voxels, key=center)
    mid = len(ordered) // 2

    left_node = build_bvh(ordered[:mid])
    right_node = build_bvh(ordered[mid:])
    return BVHNode(combined, left=left_node, right=right_node)

def bvh_raycast(node, ray_origin, ray_dir, counters):
    """
    Cast a ray through the BVH rooted at ``node`` and find the nearest voxel.

    Every node that is reached costs one ray-box test, tallied in
    ``counters['bvh']``. Subtrees whose box the ray misses are skipped.
    Leaves report their voxel directly, with no extra per-voxel test.

    Returns ``(hit_voxel, t_hit)`` for the closest hit, or ``None``.
    """
    if node is None:
        return None

    counters['bvh'] += 1
    entry = node.aabb.intersect_ray(ray_origin, ray_dir)
    if entry is None:
        return None  # prune: nothing below this box can be hit

    if node.voxel is not None:
        # build_bvh gives a leaf the voxel's own AABB, so this entry is the hit.
        return node.voxel, entry

    left_hit = bvh_raycast(node.left, ray_origin, ray_dir, counters)
    right_hit = bvh_raycast(node.right, ray_origin, ray_dir, counters)

    if not left_hit:
        return right_hit
    if not right_hit:
        return left_hit
    # The left hit wins only when strictly closer; ties go to the right child.
    return left_hit if left_hit[1] < right_hit[1] else right_hit

def grid_raycast(voxels, ray_origin, ray_dir, counters):
    """
    Brute-force raycast: test the ray against every filled voxel in turn.

    ``voxels`` is keyed by ``(i, j)`` grid coordinate. One ray-box test per
    voxel is tallied in ``counters['grid']``. On equal distances, the voxel
    met first in ``voxels`` iteration order is kept.

    Returns ``(hit_voxel, t_hit)`` for the closest hit, or ``None``.
    """
    best = None
    for cell in voxels:
        col, row = cell
        box = AABB(col * VOXEL_SIZE, row * VOXEL_SIZE, VOXEL_SIZE, VOXEL_SIZE)
        counters['grid'] += 1
        dist = box.intersect_ray(ray_origin, ray_dir)
        if dist is None:
            continue
        if best is None or dist < best[1]:
            best = ((col, row), dist)
    return best

def draw_bvh(node, surface):
    """Outline every box of the BVH rooted at ``node`` onto ``surface``.

    Nodes are visited depth-first, parent before children and left subtree
    before right. Each is drawn as a 1-pixel outline in the node's ``color``.
    A ``None`` root draws nothing.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current is None:
            continue
        box = current.aabb
        outline = pygame.Rect(box.x, box.y, box.w, box.h)
        pygame.draw.rect(surface, current.color, outline, 1)
        # Push right first so the left subtree is popped (drawn) first.
        stack.append(current.right)
        stack.append(current.left)

# ---------- Main Program ----------

# Initialise pygame's modules before creating the window, clock and font.
pygame.init()
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("2D Voxel Editor with BVH & Multi-Ray Debug")
clock = pygame.time.Clock()               # paces the main loop via clock.tick()
font = pygame.font.SysFont("Arial", 16)   # debug-overlay text

# Filled cells of the sparse voxel grid, keyed by (i, j) grid coordinate.
# Only key membership matters; the stored value is always True.
voxels = {}

def add_voxel(i, j):
    """Mark grid cell ``(i, j)`` as filled; re-adding a filled cell is harmless."""
    voxels[i, j] = True

def remove_voxel(i, j):
    """Clear grid cell ``(i, j)``; clearing an empty cell does nothing."""
    voxels.pop((i, j), None)

# Control flag for debug mode (toggled with F3).
debug_mode = True

# Frame-rate cap passed to clock.tick(); called without an argument tick()
# never waits, so the loop would run flat out and saturate a CPU core.
MAX_FPS = 60

# Ray-fan settings. They do not change between frames, so they are set up
# once here rather than inside the main loop.
center = (WIDTH / 2, HEIGHT / 2)
num_rays = 25
fov_degrees = 25  # total field-of-view in degrees.
if num_rays > 1:
    half_fov = math.radians(fov_degrees / 2)
    angle_step = math.radians(fov_degrees) / (num_rays - 1)
else:
    # A single ray points straight at the mouse (and avoids dividing by zero).
    half_fov = 0.0
    angle_step = 0.0
# Length used only for drawing; max(WIDTH, HEIGHT) always reaches the
# window edge from the centre, whatever the direction.
ray_length = max(WIDTH, HEIGHT)

running = True
mouse_down = False
mouse_button = None

while running:
    # Reset the total counters each frame.
    total_bvh_checks = 0
    total_grid_checks = 0

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_F3:
                debug_mode = not debug_mode

        elif event.type == pygame.MOUSEBUTTONDOWN:
            # Only left (1) and right (3) edit voxels. Other buttons -- e.g.
            # the scroll wheel, which pygame reports as buttons 4/5 -- must
            # not hijack the active button or cancel an ongoing drag.
            if event.button in (1, 3):
                mouse_down = True
                mouse_button = event.button  # 1=left, 3=right

        elif event.type == pygame.MOUSEBUTTONUP:
            # Releasing some other button does not end the current drag.
            if event.button == mouse_button:
                mouse_down = False
                mouse_button = None

    # Process mouse dragging to add or remove voxels.
    if mouse_down:
        mx, my = pygame.mouse.get_pos()
        grid_x = mx // VOXEL_SIZE
        grid_y = my // VOXEL_SIZE
        if mouse_button == 1:
            add_voxel(grid_x, grid_y)
        elif mouse_button == 3:
            remove_voxel(grid_x, grid_y)

    # Rebuild the BVH from the current voxel set (it may change every frame).
    voxel_list = [
        (i, j, AABB(i * VOXEL_SIZE, j * VOXEL_SIZE, VOXEL_SIZE, VOXEL_SIZE))
        for (i, j) in voxels
    ]
    bvh_root = build_bvh(voxel_list)

    # Base direction of the fan: from the screen centre to the mouse pointer.
    mouse_pos = pygame.mouse.get_pos()
    base_angle = math.atan2(mouse_pos[1] - center[1], mouse_pos[0] - center[0])

    # Prepare lists to store results for each ray.
    ray_lines = []        # (start, end) of each ray.
    bvh_ray_hits = []     # hit voxel from BVH per ray.
    grid_ray_hits = []    # hit voxel from grid per ray.
    ray_counters = []     # each element is a tuple (bvh_checks, grid_checks) for that ray.

    for i in range(num_rays):
        # Rays are spread evenly across the FOV, centred on base_angle.
        angle = base_angle - half_fov + i * angle_step
        ray_dir = (math.cos(angle), math.sin(angle))
        ray_end = (center[0] + ray_dir[0] * ray_length,
                   center[1] + ray_dir[1] * ray_length)

        # Separate counters for this ray, so per-ray costs can be shown.
        counters = {'bvh': 0, 'grid': 0}

        bvh_result = bvh_raycast(bvh_root, center, ray_dir, counters)
        grid_result = grid_raycast(voxels, center, ray_dir, counters)

        ray_lines.append((center, ray_end))
        bvh_ray_hits.append(bvh_result)
        grid_ray_hits.append(grid_result)
        ray_counters.append((counters['bvh'], counters['grid']))

        total_bvh_checks += counters['bvh']
        total_grid_checks += counters['grid']

    # ---------- Rendering ----------
    screen.fill((30, 30, 30))

    # Draw voxels as filled white squares.
    for (i, j) in voxels:
        rect = pygame.Rect(i * VOXEL_SIZE, j * VOXEL_SIZE, VOXEL_SIZE, VOXEL_SIZE)
        pygame.draw.rect(screen, (255, 255, 255), rect)

    if debug_mode:
        # Draw the BVH bounding boxes in red.
        if bvh_root:
            draw_bvh(bvh_root, screen)

        # Draw each ray (green lines).
        for ray_line in ray_lines:
            pygame.draw.line(screen, (0, 255, 0), ray_line[0], ray_line[1], 1)

        # For each ray, if it hit a voxel via the BVH, highlight it with a blue border.
        for hit in bvh_ray_hits:
            if hit:
                hit_voxel, _ = hit
                i, j = hit_voxel
                hit_rect = pygame.Rect(i * VOXEL_SIZE, j * VOXEL_SIZE, VOXEL_SIZE, VOXEL_SIZE)
                pygame.draw.rect(screen, (0, 0, 255), hit_rect, 3)

        # Debug overlay. The ray/FOV line is derived from the settings above
        # so it stays accurate if they are changed.
        debug_text = [
            "DEBUG MODE (F3 toggles)",
            "Left-drag: add voxel, Right-drag: remove voxel",
            f"{num_rays} Rays from center in {fov_degrees}° FOV (green)",
            f"Total BVH Intersection Checks: {total_bvh_checks}",
            f"Total Grid Intersection Checks: {total_grid_checks}",
        ]
        # Per-ray breakdown: one line per ray with its own test counts.
        debug_text.extend(
            f"Ray {idx}: BVH={bvh_count} Grid={grid_count}"
            for idx, (bvh_count, grid_count) in enumerate(ray_counters)
        )

        for idx, line in enumerate(debug_text):
            text_surf = font.render(line, True, (255, 255, 0))
            screen.blit(text_surf, (5, 5 + idx * 18))

    pygame.display.flip()
    clock.tick(MAX_FPS)

pygame.quit()
sys.exit()