import pygame
import sys
import math
import random

# -----------------------------
# Pygame and world initialization
# -----------------------------
# All pygame modules must be initialised before the display, clock, or font are created.
pygame.init()

# Window size in pixels, and the edge length of one square voxel cell.
SCREEN_WIDTH = 800
SCREEN_HEIGHT = 600
BLOCK_SIZE = 40

# Main window surface and its title.
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("2D Voxel Game with Advanced Debug Visualization")

# Frame-rate limiter (also the source of the FPS readout) and the overlay text font.
clock = pygame.time.Clock()
font = pygame.font.Font(None, 20)

# -----------------------------
# Global debug options
# -----------------------------
# Overlay switches, flipped at runtime by the keyboard handler in the main loop.
debug_mode = False  # F3: show the full debug panel
show_bvh = False    # D: draw the BVH bounding boxes

# Ray-casting parameters for the lighting pass.
num_rays = 720       # Intended number of lighting rays (also reported in the debug panel)
max_distance = 1000  # Distance assigned to a ray that hits nothing

# Running total of ray/box intersection tests; the main loop zeroes it every frame.
ray_intersect_count = 0

# -----------------------------
# Create a simple voxel world
# -----------------------------
# Build a grid of blocks (some cells are solid, some are empty)
def generate_blocks():
    """Create a random layout of solid blocks on the screen grid.

    The screen is divided into BLOCK_SIZE x BLOCK_SIZE cells. Cells are visited
    column by column (x outer, y inner) and each one independently becomes
    solid with a 20% probability.

    Returns:
        A list of pygame.Rect, one per solid cell, in visiting order.
    """
    n_cols = SCREEN_WIDTH // BLOCK_SIZE
    n_rows = SCREEN_HEIGHT // BLOCK_SIZE
    return [
        pygame.Rect(col * BLOCK_SIZE, row * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        for col in range(n_cols)
        for row in range(n_rows)
        if random.random() < 0.2  # exactly one draw per cell, as in a plain nested loop
    ]

# Current world layout (list of solid-block rects); replaced wholesale when R is pressed.
blocks = generate_blocks()

# -----------------------------
# BVH Data Structures and Build
# -----------------------------
# BVH Node: holds a bounding box (min_x, min_y, max_x, max_y) and either children or a solid block.
class BVHNode:
    """One node of the bounding volume hierarchy built over the solid blocks.

    A node is treated as a leaf when ``block`` is not None; otherwise it is an
    internal node whose ``left`` / ``right`` children are searched.
    """

    def __init__(self, bbox, left=None, right=None, block=None):
        # Child subtrees (both None by default).
        self.left, self.right = left, right
        # The actual block (pygame.Rect) for leaf nodes, None for internal nodes.
        self.block = block
        # Axis-aligned bounds as a tuple: (min_x, min_y, max_x, max_y).
        self.bbox = bbox

def rect_to_bbox(rect):
    """Convert a rect-like object into a (min_x, min_y, max_x, max_y) tuple.

    Only the ``left``, ``top``, ``right`` and ``bottom`` attributes are read.
    """
    min_x, min_y = rect.left, rect.top
    max_x, max_y = rect.right, rect.bottom
    return (min_x, min_y, max_x, max_y)

def union_bbox(b1, b2):
    """Return the smallest (min_x, min_y, max_x, max_y) box enclosing b1 and b2.

    Both arguments are indexed as (min_x, min_y, max_x, max_y).
    """
    # Indices 0 and 1 are the minimum corner, indices 2 and 3 the maximum corner.
    lows = tuple(min(b1[i], b2[i]) for i in (0, 1))
    highs = tuple(max(b1[i], b2[i]) for i in (2, 3))
    return lows + highs

def build_bvh(block_list):
    """Build a BVH over the given blocks using a median split on center x.

    The blocks are ordered by ``centerx`` and halved recursively until each
    leaf holds exactly one block. Every internal node's bbox is the union of
    its two children's bboxes.

    The caller's list is NOT modified: a sorted copy is made once up front.
    (Sorting in place here would silently reorder the shared world list.)

    Args:
        block_list: Sequence of rect-like blocks (objects exposing ``centerx``
            plus the attributes read by ``rect_to_bbox``).

    Returns:
        The root BVHNode, or None when ``block_list`` is empty.
    """
    if not block_list:
        return None

    # Sort once. Any slice of a sorted list is itself sorted, so the recursion
    # below never needs to sort again; the stable sort keeps ties in input order.
    ordered = sorted(block_list, key=lambda rect: rect.centerx)

    def _build(sorted_blocks):
        """Build the subtree for a non-empty, x-sorted slice of blocks."""
        if len(sorted_blocks) == 1:
            only = sorted_blocks[0]
            return BVHNode(rect_to_bbox(only), block=only)
        # len >= 2 guarantees mid >= 1, so both halves are non-empty.
        mid = len(sorted_blocks) // 2
        left = _build(sorted_blocks[:mid])
        right = _build(sorted_blocks[mid:])
        return BVHNode(union_bbox(left.bbox, right.bbox), left, right)

    return _build(ordered)

# Root of the BVH over the current blocks (None if there are none); rebuilt whenever R regenerates the world.
bvh_root = build_bvh(blocks)

# -----------------------------
# BVH Statistics Functions
# -----------------------------
def get_bvh_stats(node):
    """Return (node_count, depth) for the given BVH.

    A missing tree (None) yields (0, 0). A node whose ``block`` is set is a
    leaf: it counts as one node and its children are not visited.
    """
    node_count = 0
    deepest = 0
    # Explicit stack of (node, level) pairs; the root sits at level 1.
    pending = [(node, 1)]
    while pending:
        current, level = pending.pop()
        if current is None:
            continue
        node_count += 1
        if level > deepest:
            deepest = level
        # Only internal nodes are expanded.
        if current.block is None:
            pending.append((current.left, level + 1))
            pending.append((current.right, level + 1))
    return (node_count, deepest)

# -----------------------------
# Ray-AABB Intersection (Slab Method)
# -----------------------------
def ray_intersect_aabb(origin, direction, bbox):
    """Intersect a 2D ray with an axis-aligned box using the slab method.

    Each call also increments the global ``ray_intersect_count``.

    Args:
        origin: (x, y) start of the ray.
        direction: (dx, dy) direction of the ray.
        bbox: (min_x, min_y, max_x, max_y).

    Returns:
        The ray parameter t of the hit, or None on a miss. When the origin is
        outside the box this is the entry distance; when the entry lies behind
        the origin (t < 0) the exit distance is returned instead.
    """
    global ray_intersect_count
    ray_intersect_count += 1

    t_enter, t_exit = -math.inf, math.inf
    ox, oy = origin
    dx, dy = direction

    # One pass per axis: (origin coord, direction coord, bbox low index, bbox high index).
    for start, step, lo_idx, hi_idx in ((ox, dx, 0, 2), (oy, dy, 1, 3)):
        if step == 0:
            # Ray is parallel to this slab: it can only hit if it starts inside it.
            if bbox[lo_idx] <= start <= bbox[hi_idx]:
                continue
            return None
        t_a = (bbox[lo_idx] - start) / step
        t_b = (bbox[hi_idx] - start) / step
        t_enter = max(t_enter, min(t_a, t_b))
        t_exit = min(t_exit, max(t_a, t_b))

    # A hit needs a non-empty [t_enter, t_exit] interval that is not entirely behind the origin.
    if not (t_exit >= t_enter and t_exit >= 0):
        return None
    if t_enter >= 0:
        return t_enter
    return t_exit

# -----------------------------
# BVH Ray Casting
# -----------------------------
def ray_cast_bvh(node, origin, direction):
    """Find the closest block hit by a ray within the BVH rooted at ``node``.

    Args:
        node: BVHNode to search, or None for an empty tree.
        origin: (x, y) start of the ray.
        direction: (dx, dy) direction of the ray.

    Returns:
        A ``(t, block)`` pair for the hit with the smallest ray parameter t as
        reported by ``ray_intersect_aabb``, or ``(None, None)`` on a miss.
        When both subtrees report the same t, the right subtree's block wins.
    """
    if node is None:
        return None, None

    # Leaf: test the block itself, once. Testing node.bbox first as well would
    # duplicate the work for leaves whose bbox is the block's own rect.
    if node.block is not None:
        t_hit = ray_intersect_aabb(origin, direction, rect_to_bbox(node.block))
        if t_hit is None:
            return None, None
        return t_hit, node.block

    # Internal node: skip the whole subtree when the ray misses its bounds.
    if ray_intersect_aabb(origin, direction, node.bbox) is None:
        return None, None

    t_left, block_left = ray_cast_bvh(node.left, origin, direction)
    t_right, block_right = ray_cast_bvh(node.right, origin, direction)

    if t_left is None:
        # Either the right hit, or (None, None) when both subtrees missed.
        return t_right, block_right
    if t_right is None or t_left < t_right:
        return t_left, block_left
    return t_right, block_right

# -----------------------------
# Ray-Based Lighting Function
# -----------------------------
# Cast many rays from the light source (controlled by the mouse) and collect hit points.
# The light starts at the screen center and is moved to the cursor on every MOUSEMOTION event.
light_source = (SCREEN_WIDTH // 2, SCREEN_HEIGHT // 2)

def cast_light_rays(light_pos):
    """Cast ``num_rays`` evenly spaced rays from ``light_pos`` and collect their end points.

    Rays are spread uniformly over a full turn, starting at angle 0 and going
    in order of increasing angle. A ray that hits no block is given the
    length ``max_distance``.

    Args:
        light_pos: (x, y) position of the light.

    Returns:
        A list of ``num_rays`` (x, y) points, one per ray, in angular order.
        The list is empty when ``num_rays`` is not positive.
    """
    points = []
    if num_rays <= 0:
        return points

    for index in range(num_rays):
        # Derive the angle from the ray index so the configured ray count is
        # honoured exactly (a fixed one-degree step would always cast 360 rays).
        rad = math.tau * index / num_rays
        direction = (math.cos(rad), math.sin(rad))
        t, _ = ray_cast_bvh(bvh_root, light_pos, direction)
        if t is None:
            t = max_distance
        hit_x = light_pos[0] + direction[0] * t
        hit_y = light_pos[1] + direction[1] * t
        points.append((hit_x, hit_y))
    return points

# -----------------------------
# Advanced Debug Menu Drawing
# -----------------------------
def draw_debug_menu(surface, fps):
    """Draw the semi-transparent debug panel in the top-left corner of ``surface``.

    The panel lists the frame rate, block and BVH statistics, the light
    position, ray-casting counters and the available key toggles.

    Args:
        surface: Target surface the panel is blitted onto at (10, 10).
        fps: Frames-per-second value to display.
    """
    line_height = 22
    padding = 5

    node_total, tree_depth = get_bvh_stats(bvh_root)
    lines = (
        "DEBUG MODE ACTIVE",
        f"FPS: {fps:.1f}",
        f"Blocks: {len(blocks)}",
        f"BVH Nodes: {node_total}",
        f"BVH Depth: {tree_depth}",
        f"Light Source: {light_source}",
        f"Rays Cast: {num_rays}",
        f"Ray Intersection Tests (this frame): {ray_intersect_count}",
        f"Avg Tests per Ray: {ray_intersect_count / num_rays:.2f}",
        "Toggles: F3 - Debug, D - BVH overlay, R - Regenerate World",
    )

    # Black backdrop, 300 px wide, tall enough for every line plus a small margin.
    panel = pygame.Surface((300, len(lines) * line_height + 10))
    panel.set_alpha(180)
    panel.fill((0, 0, 0))

    # White text, one row per line.
    for row, text in enumerate(lines):
        rendered = font.render(text, True, (255, 255, 255))
        panel.blit(rendered, (padding, padding + row * line_height))

    surface.blit(panel, (10, 10))

# -----------------------------
# BVH Visualization (recursive drawing)
# -----------------------------
def draw_bvh(node, surface):
    """Outline every bounding box of the BVH rooted at ``node`` on ``surface``.

    Leaf boxes (nodes carrying a block) are drawn in green, internal boxes in
    red, each as a 1-pixel outline. Nodes are drawn parent first, then the
    left subtree, then the right subtree.
    """
    leaf_color = (0, 255, 0)
    inner_color = (255, 0, 0)

    stack = [node]
    while stack:
        current = stack.pop()
        if current is None:
            continue
        min_x, min_y, max_x, max_y = current.bbox
        outline = pygame.Rect(min_x, min_y, max_x - min_x, max_y - min_y)
        if current.block is not None:
            color = leaf_color
        else:
            color = inner_color
        pygame.draw.rect(surface, color, outline, 1)
        # Push right before left so the left subtree is popped (and drawn) first.
        stack.append(current.right)
        stack.append(current.left)

# -----------------------------
# Main Game Loop
# -----------------------------
# Overlay surface for the lighting pass. It is allocated once and reused:
# every frame repaints it completely (fill + polygon), so creating a fresh
# screen-sized surface per frame would only add allocation work.
light_surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
light_surface.set_alpha(180)

running = True
while running:
    # Reset ray intersection counter each frame.
    ray_intersect_count = 0

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN:
            # Toggle full debug mode with F3.
            if event.key == pygame.K_F3:
                debug_mode = not debug_mode
            # Toggle BVH overlay with D.
            elif event.key == pygame.K_d:
                show_bvh = not show_bvh
            # Regenerate the world and rebuild the BVH with R.
            elif event.key == pygame.K_r:
                blocks = generate_blocks()
                bvh_root = build_bvh(blocks)
        # Update light source with mouse motion.
        elif event.type == pygame.MOUSEMOTION:
            light_source = event.pos

    # Clear screen.
    screen.fill((30, 30, 30))

    # Draw voxel blocks (solid cells).
    for block in blocks:
        pygame.draw.rect(screen, (100, 100, 100), block)

    # Optionally draw the BVH overlay.
    if show_bvh and bvh_root:
        draw_bvh(bvh_root, screen)

    # Cast light rays from the light source; the hit points form the lit polygon.
    light_polygon = cast_light_rays(light_source)
    # Repaint the reusable lighting surface: black everywhere, lit area in pale yellow.
    light_surface.fill((0, 0, 0))
    pygame.draw.polygon(light_surface, (255, 255, 200), light_polygon)
    # Additive blend, so the black background leaves the scene unchanged.
    screen.blit(light_surface, (0, 0), special_flags=pygame.BLEND_ADD)

    # Draw the light source.
    pygame.draw.circle(screen, (255, 255, 0), light_source, 5)

    # Draw advanced debug overlay if debug mode is enabled.
    if debug_mode:
        fps = clock.get_fps()
        draw_debug_menu(screen, fps)

    pygame.display.flip()
    # Cap the loop at 60 frames per second.
    clock.tick(60)

pygame.quit()
sys.exit()