import pygame
import sys
import math
import random

# -----------------------------
# World and Pygame initialization
# -----------------------------
# Initialise all pygame modules before the window and font below are created.
pygame.init()
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600  # Window (viewport) size in pixels
# World size in pixels; larger than the window, so the camera follows the player.
WORLD_WIDTH, WORLD_HEIGHT = 2000, 2000  # Larger world
BLOCK_SIZE = 40  # Side length of one grid cell / block, in world pixels

screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("2D Voxel Game with Debug View Frustum & BVH")
clock = pygame.time.Clock()  # Caps the frame rate and supplies the FPS shown in the debug menu
font = pygame.font.Font(None, 20)  # Default pygame font at size 20; used for the debug menu text

# -----------------------------
# Global Options
# -----------------------------
debug_mode = False         # Toggle debug overlay (F3)
FOV = 90                   # Field of view in degrees; the ray fan is aimed at the mouse cursor
num_rays = 360              # Number of rays cast within the FOV each frame
max_distance = 1000        # Maximum ray distance in world pixels; farther hits are ignored
ray_intersect_count = 0    # Ray/AABB tests performed this frame; reset at the top of each frame

# -----------------------------
# Voxel World Generation (each block gets a random color)
# -----------------------------
def generate_blocks():
    """Create a fresh random world of coloured blocks.

    The world is treated as a grid of ``BLOCK_SIZE`` x ``BLOCK_SIZE`` cells,
    and each cell independently receives a block with probability 0.2.

    Returns:
        A list of ``(pygame.Rect, (r, g, b))`` tuples, one per placed block.
        Each colour channel is drawn from ``random.randint(50, 255)``.
    """
    col_count = WORLD_WIDTH // BLOCK_SIZE
    row_count = WORLD_HEIGHT // BLOCK_SIZE
    world_blocks = []
    for col in range(col_count):
        for row in range(row_count):
            if random.random() >= 0.2:
                continue  # This cell stays empty.
            cell = pygame.Rect(col * BLOCK_SIZE, row * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
            tint = tuple(random.randint(50, 255) for _ in range(3))
            world_blocks.append((cell, tint))
    return world_blocks

# Initial world; replaced with a newly generated one when the player presses R.
blocks = generate_blocks()

# -----------------------------
# BVH Data Structures and Build
# -----------------------------
class BVHNode:
    """One node of the bounding volume hierarchy over the world's blocks.

    Leaves carry a single ``block``. Internal nodes carry ``left``/``right``
    children and leave ``block`` as None.
    """

    def __init__(self, bbox, left=None, right=None, block=None):
        # Axis-aligned bounds as (min_x, min_y, max_x, max_y).
        self.bbox = bbox
        # Child subtrees; None when absent (e.g. on leaves).
        self.left, self.right = left, right
        # Leaf payload: the (pygame.Rect, color) tuple, or None for internal nodes.
        self.block = block

def rect_to_bbox(block):
    """Return ``(left, top, right, bottom)`` for a block or a bare rect.

    *block* may be a ``(rect, color)`` tuple, in which case its first element
    is used, or a rect-like object exposing ``left``/``top``/``right``/``bottom``.
    """
    rect = block
    if isinstance(block, tuple):
        rect = block[0]
    return rect.left, rect.top, rect.right, rect.bottom

def union_bbox(b1, b2):
    """Return the smallest ``(min_x, min_y, max_x, max_y)`` box enclosing *b1* and *b2*."""
    (a_min_x, a_min_y, a_max_x, a_max_y) = b1
    (b_min_x, b_min_y, b_max_x, b_max_y) = b2
    return (
        min(a_min_x, b_min_x),
        min(a_min_y, b_min_y),
        max(a_max_x, b_max_x),
        max(a_max_y, b_max_y),
    )

def build_bvh(block_list):
    """Recursively build a bounding volume hierarchy over *block_list*.

    Each level splits the blocks at the median of their centres. The split
    uses whichever axis (x or y) has the larger spread of centres, which keeps
    node boxes roughly square. Always splitting on one fixed axis would give
    long, thin boxes that rays cut through without being culled.

    Args:
        block_list: Sequence of ``(pygame.Rect, color)`` tuples. It is not
            modified.

    Returns:
        The root ``BVHNode``, or None when *block_list* is empty. Leaves hold
        one block each, with a bbox equal to that block's rect.
    """
    if not block_list:
        return None
    if len(block_list) == 1:
        only = block_list[0]
        return BVHNode(rect_to_bbox(only), block=only)

    centers_x = [blk[0].centerx for blk in block_list]
    centers_y = [blk[0].centery for blk in block_list]
    spread_x = max(centers_x) - min(centers_x)
    spread_y = max(centers_y) - min(centers_y)
    axis_attr = "centery" if spread_y > spread_x else "centerx"

    # sorted() rather than list.sort(): do not reorder the caller's list.
    ordered = sorted(block_list, key=lambda blk: getattr(blk[0], axis_attr))
    mid = len(ordered) // 2
    left = build_bvh(ordered[:mid])
    right = build_bvh(ordered[mid:])
    # With len >= 2 both halves are non-empty, so both children exist.
    return BVHNode(union_bbox(left.bbox, right.bbox), left, right)

# BVH over the current blocks; every ray cast goes through it. It is rebuilt
# whenever the world is regenerated.
bvh_root = build_bvh(blocks)

# -----------------------------
# BVH Statistics Functions
# -----------------------------
def get_bvh_stats(node):
    """Return ``(node_count, depth)`` for the BVH subtree rooted at *node*.

    An empty subtree (None) is ``(0, 0)``. Any node carrying a block counts as
    a leaf, ``(1, 1)``, without looking at its children. An internal node adds
    one to the summed counts and one to the deeper child's depth.
    """
    if node is None:
        return 0, 0
    if node.block is not None:
        return 1, 1
    child_stats = [get_bvh_stats(child) for child in (node.left, node.right)]
    node_count = 1 + sum(count for count, _ in child_stats)
    depth = 1 + max(child_depth for _, child_depth in child_stats)
    return node_count, depth

# -----------------------------
# Ray-AABB Intersection (Slab Method)
# -----------------------------
def ray_intersect_aabb(origin, direction, bbox):
    """Intersect a 2D ray with an axis-aligned box using the slab method.

    Every call increments the global ``ray_intersect_count``, which the debug
    menu displays.

    Args:
        origin: Ray start ``(x, y)``.
        direction: Ray direction ``(dx, dy)``.
        bbox: Box as ``(min_x, min_y, max_x, max_y)``.

    Returns:
        The ray parameter ``t`` at which the ray meets the box boundary:
        - the entry value when it is non-negative (origin outside or on the box);
        - otherwise the exit value (origin inside the box).

        Returns None when the ray misses the box or the box lies entirely
        behind the origin.
    """
    global ray_intersect_count
    ray_intersect_count += 1

    t_enter, t_exit = -math.inf, math.inf
    # Clip against the x slab (axis 0), then the y slab (axis 1).
    for axis in (0, 1):
        start = origin[axis]
        step = direction[axis]
        low, high = bbox[axis], bbox[axis + 2]
        if step == 0:
            # Parallel to this slab: only a ray starting inside it can hit.
            if not (low <= start <= high):
                return None
            continue
        t_low = (low - start) / step
        t_high = (high - start) / step
        t_enter = max(t_enter, min(t_low, t_high))
        t_exit = min(t_exit, max(t_low, t_high))

    if t_exit < t_enter or t_exit < 0:
        return None
    return t_enter if t_enter >= 0 else t_exit

# -----------------------------
# BVH Ray Casting
# -----------------------------
def _ray_entry_distance(origin, bbox, t_hit):
    """Return the ray parameter at which a ray from *origin* enters *bbox*.

    *t_hit* is the non-None value ``ray_intersect_aabb`` returned for this
    box. That value is the exit distance when the origin lies inside the box;
    such a box is entered at t = 0. In every other case *t_hit* already is
    the entry distance.
    """
    ox, oy = origin
    min_x, min_y, max_x, max_y = bbox
    if min_x <= ox <= max_x and min_y <= oy <= max_y:
        return 0.0
    return t_hit


def ray_cast_bvh(node, origin, direction):
    """Find the closest block hit by a ray, using the BVH rooted at *node*.

    The traversal is iterative and visits the nearest box first. Once a hit
    is known, any subtree whose box the ray enters no earlier than that hit
    is skipped, since nothing inside it can be closer.

    Args:
        node: Root ``BVHNode`` (or None).
        origin: Ray start ``(x, y)``.
        direction: Ray direction ``(dx, dy)``.

    Returns:
        ``(t, block)`` for the nearest hit, or ``(None, None)`` if nothing is
        hit. ``t`` is the ray parameter and ``block`` is the leaf's
        ``(pygame.Rect, color)`` tuple.
    """
    if node is None:
        return None, None
    t_root = ray_intersect_aabb(origin, direction, node.bbox)
    if t_root is None:
        return None, None

    best_t, best_block = None, None
    # Stack entries: (entry distance, ray_intersect_aabb result, node).
    stack = [(_ray_entry_distance(origin, node.bbox, t_root), t_root, node)]
    while stack:
        entry_t, t_box, current = stack.pop()
        if best_t is not None and entry_t >= best_t:
            continue  # Nothing in this box can beat the current best hit.
        if current.block is not None:
            # build_bvh gives each leaf a bbox equal to its block's rect, so
            # the box test already is the block test; no second test needed.
            if best_t is None or t_box < best_t:
                best_t, best_block = t_box, current.block
            continue
        candidates = []
        for child in (current.left, current.right):
            if child is None:
                continue
            t_child = ray_intersect_aabb(origin, direction, child.bbox)
            if t_child is not None:
                entry = _ray_entry_distance(origin, child.bbox, t_child)
                candidates.append((entry, t_child, child))
        # Push the farther child first so the nearer one is explored first.
        candidates.sort(key=lambda item: item[0], reverse=True)
        stack.extend(candidates)
    return best_t, best_block

# -----------------------------
# Single Ray Casting (No Bouncing)
# -----------------------------
def cast_ray_single(origin, direction):
    """Cast one ray (no reflections) through the global BVH.

    Returns:
        ``((x, y), color)``, where ``(x, y)`` is the world-space hit point and
        ``color`` is the block's colour (white if the hit object is not a
        ``(rect, color)`` tuple). Returns None when nothing is hit within
        ``max_distance``.
    """
    t, hit_block = ray_cast_bvh(bvh_root, origin, direction)
    missed = t is None or t > max_distance
    if missed:
        return None
    hit_x = origin[0] + direction[0] * t
    hit_y = origin[1] + direction[1] * t
    if isinstance(hit_block, tuple):
        block_color = hit_block[1]
    else:
        block_color = (255, 255, 255)
    return ((hit_x, hit_y), block_color)

# -----------------------------
# Persistent Light Map Setup (World-Sized Surface with Per-Pixel Alpha)
# -----------------------------
# World-sized light map with per-pixel alpha (pygame.SRCALPHA). Each frame the
# ray hit points are painted into it and the whole map is dimmed, so lit
# pixels fade over time. It is blitted to the screen shifted by the camera.
persistent_light_map = pygame.Surface((WORLD_WIDTH, WORLD_HEIGHT), pygame.SRCALPHA)
persistent_light_map.fill((0, 0, 0, 255))  # Start fully black (opaque)

def fade_light_map():
    """Dim the whole persistent light map by one step.

    All four RGBA channels are multiplied by about 254/255, using
    ``BLEND_RGBA_MULT`` with a fill value of 254. Repeated every frame, this
    makes painted pixels fade gradually.

    NOTE(review): because pygame's multiply is integer arithmetic, its
    rounding may stop low channel values from ever reaching 0. Confirm this
    if a complete fade is required.
    """
    fade_color = (254, 254, 254, 254)
    persistent_light_map.fill(fade_color, special_flags=pygame.BLEND_RGBA_MULT)

# -----------------------------
# Update the Light Map from the Current Light Source within a FOV
# -----------------------------
def update_light_map(light_pos, base_angle):
    """Cast the FOV ray fan from *light_pos* and paint each hit pixel.

    Ray ``i`` points at ``base_angle - FOV/2 + i * FOV/num_rays`` degrees.
    Each ray's hit point is written into ``persistent_light_map`` in the hit
    block's colour.

    NOTE: the fan therefore covers ``[base_angle - FOV/2, base_angle + FOV/2 -
    FOV/num_rays]``. The right-hand FOV edge itself is not sampled.
    """
    step = FOV / num_rays
    start_angle = base_angle - FOV / 2
    for ray_index in range(num_rays):
        theta = math.radians(start_angle + ray_index * step)
        hit = cast_ray_single(light_pos, (math.cos(theta), math.sin(theta)))
        if hit is None:
            continue
        (hit_x, hit_y), color = hit
        persistent_light_map.set_at((int(hit_x), int(hit_y)), color)

# -----------------------------
# Debug Draw Functions
# -----------------------------
def draw_bvh(node, surface, cam_offset):
    """Outline every BVH node's box on *surface*, shifted by *cam_offset*.

    Leaves are drawn in green and internal nodes in red, each as a 1-pixel
    outline. Nodes are drawn in pre-order: node, then its left subtree, then
    its right subtree.
    """
    off_x, off_y = cam_offset
    pending = [node]
    while pending:
        current = pending.pop()
        if current is None:
            continue
        min_x, min_y, max_x, max_y = current.bbox
        outline = pygame.Rect(min_x - off_x, min_y - off_y, max_x - min_x, max_y - min_y)
        is_leaf = current.block is not None
        pygame.draw.rect(surface, (0, 255, 0) if is_leaf else (255, 0, 0), outline, 1)
        # Push right before left so the left subtree is drawn first.
        pending.append(current.right)
        pending.append(current.left)

def draw_view_frustum(light_pos, base_angle, cam_offset):
    """Outline the FOV wedge on the screen (debug overlay).

    Two cyan boundary rays of length ``max_distance`` are drawn at
    ``base_angle -/+ FOV/2``. A 1-pixel triangle outline joining the apex and
    both ray ends is drawn after them.

    NOTE(review): pygame.draw does not alpha-blend, so the 50 alpha component
    of the polygon colour likely has no visible effect. Confirm before
    relying on it.
    """
    cam_x, cam_y = cam_offset
    light_x, light_y = light_pos

    def to_screen(world_x, world_y):
        return (world_x - cam_x, world_y - cam_y)

    apex = to_screen(light_x, light_y)
    edge_ends = []
    for edge_deg in (base_angle - FOV / 2, base_angle + FOV / 2):
        theta = math.radians(edge_deg)
        edge_ends.append(to_screen(light_x + math.cos(theta) * max_distance,
                                   light_y + math.sin(theta) * max_distance))
    left_end, right_end = edge_ends

    for end in edge_ends:
        pygame.draw.line(screen, (0, 255, 255), apex, end, 1)
    pygame.draw.polygon(screen, (0, 255, 255, 50), [apex, left_end, right_end], 1)

def debug_draw_rays(light_pos, base_angle, cam_offset):
    """Draw every FOV ray as a white line on the screen (debug overlay).

    Each ray is re-cast with ``cast_ray_single``, using the same angle
    formula as the light-map pass. A ray ends at its hit point, or at
    ``max_distance`` when nothing is hit.

    Re-casting would otherwise add a second full set of ray/box tests to the
    global ``ray_intersect_count``. The debug menu divides that counter by
    ``num_rays`` for its "Avg Tests per Ray" figure, so the figure would be
    doubled. The counter is therefore restored afterwards, so the menu
    reports only the light-map pass.

    Args:
        light_pos: Ray origin in world coordinates.
        base_angle: Direction the FOV fan is aimed at, in degrees.
        cam_offset: World position of the screen's top-left corner.
    """
    global ray_intersect_count
    saved_intersect_count = ray_intersect_count

    angle_step = FOV / num_rays
    # All rays share the same origin, so its screen position is loop-invariant.
    start = (light_pos[0] - cam_offset[0], light_pos[1] - cam_offset[1])
    for i in range(num_rays):
        angle = base_angle - FOV / 2 + i * angle_step
        rad = math.radians(angle)
        direction = (math.cos(rad), math.sin(rad))
        result = cast_ray_single(light_pos, direction)
        if result is not None:
            end_world = result[0]
        else:
            # No hit; extend ray to max_distance.
            end_world = (light_pos[0] + direction[0] * max_distance,
                         light_pos[1] + direction[1] * max_distance)
        end = (end_world[0] - cam_offset[0], end_world[1] - cam_offset[1])
        pygame.draw.line(screen, (255, 255, 255), start, end, 1)

    ray_intersect_count = saved_intersect_count

# -----------------------------
# Advanced Debug Menu (No Background)
# -----------------------------
def draw_debug_menu(surface, fps):
    """Render the debug text overlay (no background) at the top-left of *surface*.

    The first line is drawn at (10, 10). Each following line goes 2 pixels
    below the previous one. The BVH statistics are recomputed by walking the
    whole tree on every call.

    Args:
        surface: Target surface, normally the screen.
        fps: Frame rate to display.
    """
    node_total, tree_depth = get_bvh_stats(bvh_root)
    center_x, center_y = player_rect.center
    menu_lines = (
        "DEBUG MODE ACTIVE",
        f"FPS: {fps:.1f}",
        f"Blocks: {len(blocks)}",
        f"BVH Nodes: {node_total}",
        f"BVH Depth: {tree_depth}",
        f"Player Pos: ({center_x}, {center_y})",
        f"Rays Cast (FOV): {num_rays}",
        f"Ray Intersection Tests: {ray_intersect_count}",
        f"Avg Tests per Ray: {ray_intersect_count / num_rays:.2f}",
        "Toggles: F3 - Debug, R - Regenerate World",
    )
    y = 10
    for text in menu_lines:
        rendered = font.render(text, True, (255, 255, 255))
        surface.blit(rendered, (10, y))
        y += rendered.get_height() + 2

# -----------------------------
# Player Setup and Collision
# -----------------------------
player_speed = 4  # Pixels moved per frame along each axis while a movement key is held
# 20x20 player hitbox in world coordinates, with its top-left corner at the
# world centre. Its centre is used as the light source / ray origin.
player_rect = pygame.Rect(WORLD_WIDTH // 2, WORLD_HEIGHT // 2, 20, 20)

def move_player(dx, dy):
    """Move ``player_rect`` by (dx, dy), resolving block collisions per axis.

    The horizontal step is applied and resolved before the vertical one, so
    a blocked horizontal move does not cancel the vertical move. On contact,
    the player's leading edge is snapped flush against the block's facing
    edge. The global ``player_rect`` is updated in place.
    """
    # Horizontal step.
    player_rect.x += dx
    for solid, _color in blocks:
        if not player_rect.colliderect(solid):
            continue
        if dx > 0:
            player_rect.right = solid.left
        elif dx < 0:
            player_rect.left = solid.right

    # Vertical step.
    player_rect.y += dy
    for solid, _color in blocks:
        if not player_rect.colliderect(solid):
            continue
        if dy > 0:
            player_rect.bottom = solid.top
        elif dy < 0:
            player_rect.top = solid.bottom

# -----------------------------
# Camera Setup
# -----------------------------
def get_camera_offset():
    """Return the world position of the screen's top-left corner.

    The camera is centred on the player, so this is the player's centre
    minus half the screen size (integer division).
    """
    half_width = SCREEN_WIDTH // 2
    half_height = SCREEN_HEIGHT // 2
    center_x, center_y = player_rect.center
    return center_x - half_width, center_y - half_height

# -----------------------------
# Main Game Loop
# -----------------------------
def _blocks_clear_of_player(block_list):
    """Return the blocks in *block_list* that do not overlap ``player_rect``.

    ``generate_blocks`` does not know where the player is, so a new world can
    place a block on top of them. An overlapping player is snapped to that
    block's edge on their next move, and while the player's centre (the ray
    origin) lies inside the block, every ray starts inside it.
    """
    return [blk for blk in block_list if not player_rect.colliderect(blk[0])]


# The initial world is generated before the player is placed, so the spawn
# point may be occupied. Drop any such block and rebuild the BVH.
blocks = _blocks_clear_of_player(blocks)
bvh_root = build_bvh(blocks)

running = True
while running:
    ray_intersect_count = 0  # Reset each frame

    # Handle window and keyboard events.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_F3:
                debug_mode = not debug_mode
            elif event.key == pygame.K_r:
                # New random world, kept clear of the player's current position.
                blocks = _blocks_clear_of_player(generate_blocks())
                bvh_root = build_bvh(blocks)

    # Process player movement. If opposite keys are both held, right/down win.
    keys = pygame.key.get_pressed()
    dx = dy = 0
    if keys[pygame.K_LEFT] or keys[pygame.K_a]:
        dx = -player_speed
    if keys[pygame.K_RIGHT] or keys[pygame.K_d]:
        dx = player_speed
    if keys[pygame.K_UP] or keys[pygame.K_w]:
        dy = -player_speed
    if keys[pygame.K_DOWN] or keys[pygame.K_s]:
        dy = player_speed
    move_player(dx, dy)

    # Light source is the player's center.
    light_source = player_rect.center
    camera_offset = get_camera_offset()

    # Aim the FOV at the mouse cursor, converted to world coordinates.
    mouse_screen = pygame.mouse.get_pos()
    mouse_world = (mouse_screen[0] + camera_offset[0], mouse_screen[1] + camera_offset[1])
    delta_x = mouse_world[0] - light_source[0]
    delta_y = mouse_world[1] - light_source[1]
    base_angle = math.degrees(math.atan2(delta_y, delta_x))

    # Fade the persistent light map, then paint this frame's ray hits into it.
    fade_light_map()
    update_light_map(light_source, base_angle)

    # Clear the screen and draw the light map shifted by the camera.
    screen.fill((30, 30, 30))
    screen.blit(persistent_light_map, (-camera_offset[0], -camera_offset[1]))

    # Draw the player.
    pygame.draw.rect(screen, (255, 255, 0),
                     (player_rect.x - camera_offset[0], player_rect.y - camera_offset[1],
                      player_rect.width, player_rect.height))

    if debug_mode:
        # BVH boxes, FOV wedge, individual rays, then the text menu on top.
        draw_bvh(bvh_root, screen, camera_offset)
        draw_view_frustum(light_source, base_angle, camera_offset)
        debug_draw_rays(light_source, base_angle, camera_offset)
        fps = clock.get_fps()
        draw_debug_menu(screen, fps)

    pygame.display.flip()
    clock.tick(60)  # Cap the frame rate at 60 FPS.

pygame.quit()
sys.exit()