import pygame
import sys
import random

# Window configuration for the visualisation.
WINDOW_SIZE = (800, 600)
WINDOW_TITLE = "Simple 2D BVH Visualization"

# pygame must be initialised before the display, clock and font are created.
pygame.init()
screen = pygame.display.set_mode(WINDOW_SIZE)
pygame.display.set_caption(WINDOW_TITLE)
# Frame limiter for the main loop and the default font (size 24) for HUD text.
clock, font = pygame.time.Clock(), pygame.font.Font(None, 24)

# Bounding-box helpers.
# A rectangle is (x, y, w, h); a bbox is (min_x, min_y, max_x, max_y).
def rect_to_bbox(rect):
    """Convert an ``(x, y, w, h)`` rectangle to ``(min_x, min_y, max_x, max_y)``."""
    left, top, width, height = rect
    right = left + width
    bottom = top + height
    return (left, top, right, bottom)

def union_bbox(bbox1, bbox2):
    """Return the smallest ``(min_x, min_y, max_x, max_y)`` box enclosing both inputs."""
    min_x, min_y = min(bbox1[0], bbox2[0]), min(bbox1[1], bbox2[1])
    max_x, max_y = max(bbox1[2], bbox2[2]), max(bbox1[3], bbox2[3])
    return (min_x, min_y, max_x, max_y)

# Node of the bounding volume hierarchy.
class BVHNode:
    """A single node of a binary BVH.

    Internal nodes carry two child subtrees; leaf nodes carry the original
    rectangle in ``obj`` and leave ``obj`` as ``None`` otherwise.
    """

    def __init__(self, bbox, left=None, right=None, obj=None):
        # (min_x, min_y, max_x, max_y) box for this node
        self.bbox = bbox
        # child subtrees (BVHNode), or None when absent
        self.left, self.right = left, right
        # the stored rectangle for leaves, None for internal nodes
        self.obj = obj

# Recursive BVH build function.
# A single object becomes a leaf node. Otherwise the objects are ordered by
# the x-coordinate of their centres, split at the median, built into left and
# right subtrees, and wrapped in a node whose bbox is the union of both.
def build_bvh(objects):
    """Build a binary BVH over ``(x, y, w, h)`` rectangles.

    The input sequence is NOT modified: ordering happens on a copy, so callers
    can keep drawing/iterating their own list in its original order.

    Args:
        objects: sequence (list, tuple, ...) of ``(x, y, w, h)`` rectangles.

    Returns:
        The root ``BVHNode``, or ``None`` if *objects* is empty. Leaves store
        the rectangle in ``obj``; internal nodes have ``obj=None`` and a bbox
        equal to the union of their children's bboxes.
    """
    if not objects:
        return None
    if len(objects) == 1:
        return BVHNode(rect_to_bbox(objects[0]), obj=objects[0])

    # sorted() instead of list.sort(): the original reordered the caller's
    # list in place (e.g. the module-level `rectangles`) and rejected tuples.
    ordered = sorted(objects, key=lambda rect: rect[0] + rect[2] / 2)
    mid = len(ordered) // 2
    left = build_bvh(ordered[:mid])
    right = build_bvh(ordered[mid:])
    return BVHNode(union_bbox(left.bbox, right.bbox), left, right)

# Generate a list of random rectangles that fit inside the window.
def generate_rectangles(num_rects, width=800, height=600, margin=50):
    """Return *num_rects* random ``(x, y, w, h)`` rectangles inside the window.

    Each side length is drawn from [20, 100]. The position is then chosen so
    the whole rectangle stays at least *margin* pixels from every window edge.

    Args:
        num_rects: how many rectangles to generate (0 gives an empty list).
        width: window width in pixels (default matches the 800x600 window).
        height: window height in pixels.
        margin: minimum gap in pixels between a rectangle and the window border.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples.

    Raises:
        ValueError: from ``random.randint`` if the window is too small to fit
            a rectangle of the drawn size plus the margins.
    """
    rects = []
    for _ in range(num_rects):
        w = random.randint(20, 100)
        h = random.randint(20, 100)
        # Pick the corner after the size so the far edge respects the margin;
        # the old fixed ranges (50..750, 50..550) let rectangles extend up to
        # 100 px past the right/bottom of the screen.
        x = random.randint(margin, width - margin - w)
        y = random.randint(margin, height - margin - h)
        rects.append((x, y, w, h))
    return rects

# Whether the main loop draws the BVH overlay; flipped with the D key.
show_bvh = True

# Initial scene: how many rectangles to spawn, the rectangles, and their BVH.
num_rects = 20
rectangles = generate_rectangles(num_rects)
bvh_root = build_bvh(rectangles)

# Draw every BVH node's bbox as a 1-pixel outline.
def draw_bvh(node, surface):
    """Outline all nodes of the BVH rooted at *node* onto *surface*.

    Leaves (nodes with an ``obj``) are drawn green, internal nodes red.
    Nodes are visited parent first, then the left subtree, then the right
    subtree; *node* may be ``None``, in which case nothing is drawn.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current is None:
            continue
        min_x, min_y, max_x, max_y = current.bbox
        outline = pygame.Rect(min_x, min_y, max_x - min_x, max_y - min_y)
        is_leaf = current.obj is not None
        pygame.draw.rect(surface, (0, 255, 0) if is_leaf else (255, 0, 0), outline, 1)
        # Push right before left so the left subtree is popped (drawn) first,
        # reproducing the recursive parent/left/right order.
        pending.append(current.right)
        pending.append(current.left)

# Main loop.
# The instruction text never changes, so render it once here instead of
# producing a new text Surface on every frame.
debug_text = font.render("Press D to toggle BVH overlay, R to regenerate rectangles", True, (255, 255, 255))

running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN:
            # D toggles the BVH overlay.
            if event.key == pygame.K_d:
                show_bvh = not show_bvh
            # R replaces the scene with fresh rectangles and rebuilds the BVH.
            elif event.key == pygame.K_r:
                rectangles = generate_rectangles(num_rects)
                bvh_root = build_bvh(rectangles)

    screen.fill((30, 30, 30))

    # Draw each rectangle (the "objects") in white.
    for rect in rectangles:
        pygame.draw.rect(screen, (255, 255, 255), rect, 2)

    # BVH outlines go on top of the objects when enabled; build_bvh returns
    # None for an empty scene.
    if show_bvh and bvh_root is not None:
        draw_bvh(bvh_root, screen)

    # Instruction text, pre-rendered above the loop.
    screen.blit(debug_text, (10, 10))

    pygame.display.flip()
    clock.tick(60)  # limit the loop to at most 60 frames per second

pygame.quit()
sys.exit()