import pygame
import sys
import math
import random

# -------------------------------
# Simulation Parameters
# -------------------------------
WIDTH = 800         # Window width in pixels
HEIGHT = 600        # Window height in pixels
grass_count = 3000  # Number of grass blades generated at start-up (a dense field)

# Wind settings; both can be changed at runtime with the arrow keys.
wind_amplitude = 20  # Peak horizontal displacement of a grass tip, in pixels
wind_speed = 2.0     # Multiplier applied to elapsed seconds inside the sway sine

# Whether the debug overlay is shown; F3 toggles it at runtime.
debug_mode = True

# -------------------------------
# Pygame Setup
# -------------------------------
pygame.init()  # Initialise pygame's modules first so the display and font calls below work.
screen = pygame.display.set_mode((WIDTH, HEIGHT))  # Main window surface; every frame is drawn onto it.
pygame.display.set_caption("Optimized Grass, Tree & Collision Simulation")
clock = pygame.time.Clock()  # Caps the frame rate in the main loop and supplies the FPS shown in the debug overlay.
font = pygame.font.SysFont("Arial", 18)  # Font used to render the debug overlay text.

# -------------------------------
# Collision System
# -------------------------------
class GrassCollider:
    """Circular obstacle that pushes nearby vegetation out of its way.

    Subclass this for any entity that should disturb grass blades, leaves or
    branch tips. The collider only describes a circle (centre ``pos`` and
    ``radius``); moving it is left entirely to the subclass or its owner.
    """

    def __init__(self, pos, radius, repulsion_strength=5):
        # Copy into a mutable [x, y] list so the owner can move it in place.
        self.pos = [*pos]
        self.radius = radius
        # Scale of the push applied to points that fall inside the circle.
        self.repulsion_strength = repulsion_strength

    def collide_point(self, x, y):
        """Return True if (x, y) lies strictly inside the circle."""
        off_x, off_y = x - self.pos[0], y - self.pos[1]
        limit = self.radius * self.radius
        return off_x * off_x + off_y * off_y < limit

# A sample moving entity: a bouncing ball that acts as a collider.
class Ball(GrassCollider):
    """A red ball that moves in a straight line and bounces off the play-area edges.

    As a GrassCollider it disturbs the vegetation it passes through; nothing
    pushes the ball itself.
    """

    def __init__(self, pos, radius, repulsion_strength=30, velocity=(3, 2)):
        super().__init__(pos, radius, repulsion_strength)
        # Per-frame displacement [vx, vy] in pixels. Copied into a list so a
        # bounce can change one component in place without touching the caller's value.
        self.velocity = list(velocity)

    def update(self, width=None, height=None):
        """Advance the ball by one frame and bounce it off the play-area edges.

        Args:
            width: right edge of the play area in pixels (defaults to WIDTH).
            height: bottom edge of the play area in pixels (defaults to HEIGHT).

        If the ball crosses an edge, it is clamped back inside and its velocity
        on that axis is set to point away from the edge. The original code only
        negated the velocity. That left a ball starting (or ending up) more than
        one step outside the bounds flipping direction every frame and stuck on
        the wall. Setting the sign explicitly avoids that.
        """
        if width is None:
            width = WIDTH
        if height is None:
            height = HEIGHT
        for axis, limit in ((0, width), (1, height)):
            self.pos[axis] += self.velocity[axis]
            low = self.radius            # centre can't go further left/up than this
            high = limit - self.radius   # ...or further right/down than this
            if self.pos[axis] < low:
                self.pos[axis] = low
                self.velocity[axis] = abs(self.velocity[axis])
            elif self.pos[axis] > high:
                self.pos[axis] = high
                self.velocity[axis] = -abs(self.velocity[axis])

    def draw(self, surface):
        """Draw the ball as a filled red circle at its current position."""
        pygame.draw.circle(surface, (255, 0, 0), (int(self.pos[0]), int(self.pos[1])), self.radius)

# -------------------------------
# Grass Blade Class
# -------------------------------
class GrassBlade:
    """A single blade of grass drawn as a swaying, collider-aware polyline.

    The blade is rooted at (x, base_y) and leans ``tilt`` radians from vertical.
    It bends horizontally with the wind, and the bend grows linearly from root to tip.
    """

    def __init__(self, x, base_y, length, phase, tilt):
        self.x = x              # Horizontal position of the root, in pixels
        self.base_y = base_y    # Vertical position of the root, in pixels
        self.length = length    # Blade length, in pixels
        self.phase = phase      # Per-blade phase offset so blades don't sway in lockstep
        self.tilt = tilt        # Static lean from vertical, in radians

    def draw(self, surface, t, wind_amp, wind_spd, colliders, segments=10):
        """Draw the blade onto ``surface`` as it looks at time ``t``.

        Args:
            surface: pygame surface to draw on.
            t: elapsed time in seconds.
            wind_amp: peak horizontal displacement of the tip, in pixels.
            wind_spd: multiplier applied to ``t`` inside the sway sine.
            colliders: list of GrassCollider objects; any blade point inside one
                is pushed directly away from its centre.
            segments: number of straight segments used to approximate the blade.
                Values below 1 are treated as 1.
        """
        segments = max(1, segments)
        # These depend only on the blade and the frame time, not on the segment,
        # so compute them once per blade instead of once per point.
        sin_tilt = math.sin(self.tilt)
        cos_tilt = math.cos(self.tilt)
        sway = math.sin(t * wind_spd + self.phase)

        points = []
        for i in range(segments + 1):
            progress = i / segments
            # Static position along the (slightly tilted) blade axis.
            base_x = self.x + progress * self.length * sin_tilt
            y_pos = self.base_y - progress * self.length * cos_tilt
            # Wind displacement grows linearly from 0 at the root to wind_amp at the tip.
            x_pos = base_x + wind_amp * progress * sway
            for collider in colliders:
                if collider.collide_point(x_pos, y_pos):
                    dx = x_pos - collider.pos[0]
                    dy = y_pos - collider.pos[1]
                    dist = math.hypot(dx, dy)
                    if dist > 0:  # a point exactly at the centre has no push direction
                        # NOTE(review): the push grows as 1/dist, so points very close
                        # to the centre can be flung far. The same formula is used for
                        # leaves and branch tips.
                        repulsion = collider.repulsion_strength / dist
                        x_pos += repulsion * (dx / dist)
                        y_pos += repulsion * (dy / dist)
            points.append((x_pos, y_pos))
        pygame.draw.aalines(surface, (34, 139, 34), False, points)

# Every blade in the field; filled once by init_grass() below.
grass_blades = []


def init_grass():
    """Append ``grass_count`` randomly shaped blades to ``grass_blades``.

    Blades are spread uniformly across the window width with roots in the
    bottom 20 pixels. Each is 50-100 px long, with a random sway phase and a
    random lean of up to +/-0.1 rad. The existing list is extended, not replaced.
    """
    # Call arguments are evaluated left to right, so the random draws happen
    # in the same order as assigning them one per statement.
    grass_blades.extend(
        GrassBlade(
            random.uniform(0, WIDTH),            # root x
            HEIGHT - random.uniform(0, 20),      # root y, near the bottom edge
            random.uniform(50, 100),             # length
            random.uniform(0, 2 * math.pi),      # sway phase
            random.uniform(-0.1, 0.1),           # lean, radians
        )
        for _ in range(grass_count)
    )


init_grass()

def draw_grass(surface, t, wind_amp, wind_spd, colliders):
    """Draw every blade in ``grass_blades`` as it looks at time ``t``.

    All arguments are forwarded unchanged to each blade's ``draw`` method.
    """
    for grass in grass_blades:
        grass.draw(surface, t, wind_amp, wind_spd, colliders)

# -------------------------------
# Tree and Leaves Drawing
# -------------------------------
# Number of leaves drawn so far in the current frame. The main loop resets it
# every frame and the debug overlay displays it.
leaf_count = 0


def draw_leaf(surface, x, y, colliders):
    """Draw one leaf at (x, y), displaced away from any collider it overlaps.

    Colliders are applied in order, each seeing the position left by the
    previous one. Also increments the global ``leaf_count``.
    """
    global leaf_count
    for body in colliders:
        if not body.collide_point(x, y):
            continue
        off_x = x - body.pos[0]
        off_y = y - body.pos[1]
        dist = math.hypot(off_x, off_y)
        if dist > 0:  # a leaf exactly at the centre has no push direction
            push = body.repulsion_strength / dist
            x += push * (off_x / dist)
            y += push * (off_y / dist)
    pygame.draw.circle(surface, (50, 205, 50), (int(x), int(y)), 5)
    leaf_count += 1

def draw_tree(surface, x, y, length, angle, level, t, wind_amp, wind_spd, colliders,
              *, branch_angle_deg=30, length_ratio=0.7, leaf_level=2):
    """Recursively draw a wind-swayed fractal tree with leaves.

    Each call draws one branch starting at (x, y), then recurses into two child
    branches that fork by +/- ``branch_angle_deg`` from the swayed direction.

    Args:
        surface: pygame surface to draw on.
        x, y: start point of this branch, in pixels.
        length: branch length in pixels.
        angle: branch direction in radians; 0 points straight up and positive
            values lean right.
        level: remaining recursion depth; also used as the line thickness.
        t: elapsed time in seconds.
        wind_amp: wind strength; the peak sway is ``wind_amp * 0.3`` degrees.
        wind_spd: multiplier applied to ``t`` inside the sway sine.
        colliders: list of GrassCollider objects that push branch tips away.
        branch_angle_deg: angle between a branch and each child, in degrees.
        length_ratio: child length as a fraction of the parent's length.
        leaf_level: branches at this level or below are drawn green and get a leaf.
    """
    # Stop when depth is exhausted (<= also guards against a negative start
    # level) or when branches get too short to see.
    if level <= 0 or length < 2:
        return

    # Wind-induced sway. The phase depends only on the level, so all branches
    # at the same depth sway together.
    sway = math.radians(wind_amp * 0.3 * math.sin(t * wind_spd + level))
    new_angle = angle + sway

    # Branch endpoint (screen y grows downwards, hence the minus).
    end_x = x + length * math.sin(new_angle)
    end_y = y - length * math.cos(new_angle)

    # Push the endpoint away from any collider it falls inside.
    for collider in colliders:
        if collider.collide_point(end_x, end_y):
            dx = end_x - collider.pos[0]
            dy = end_y - collider.pos[1]
            dist = math.hypot(dx, dy)
            if dist > 0:
                repulsion = collider.repulsion_strength / dist
                end_x += repulsion * (dx / dist)
                end_y += repulsion * (dy / dist)

    # Draw the branch: brown wood for the upper levels, green twigs near the tips.
    thickness = max(1, level)
    branch_color = (139, 69, 19) if level > leaf_level else (34, 139, 34)
    pygame.draw.line(surface, branch_color, (x, y), (end_x, end_y), thickness)

    if level <= leaf_level:
        draw_leaf(surface, end_x, end_y, colliders)

    # Recurse into the left and right children from the (possibly pushed) endpoint.
    child_length = length * length_ratio
    fork = math.radians(branch_angle_deg)
    for child_angle in (new_angle - fork, new_angle + fork):
        draw_tree(surface, end_x, end_y, child_length, child_angle, level - 1,
                  t, wind_amp, wind_spd, colliders,
                  branch_angle_deg=branch_angle_deg,
                  length_ratio=length_ratio,
                  leaf_level=leaf_level)

# -------------------------------
# Create Colliders
# -------------------------------
# The bouncing ball is the only collider; grass, leaves and branch tips all
# react to every entry in this list.
ball = Ball([400, 300], 20, repulsion_strength=30)
colliders = [ball]

# -------------------------------
# Main Loop
# -------------------------------
running = True
start_time = pygame.time.get_ticks() / 1000.0

# The ground strip and the tree layout never change, so build them once.
ground_rect = pygame.Rect(0, int(HEIGHT * 0.8), WIDTH, int(HEIGHT * 0.2))
tree_base_x = int(WIDTH * 0.2)
tree_base_y = int(HEIGHT * 0.8)
tree_length = 100
tree_angle = 0  # vertical
tree_levels = 7

while running:
    # Seconds since the simulation started; drives all wind motion.
    t = pygame.time.get_ticks() / 1000.0 - start_time

    # --- Input ---
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
            continue
        if event.type != pygame.KEYDOWN:
            continue
        # Up/Down change wind strength, Right/Left change wind speed, F3 toggles the overlay.
        if event.key == pygame.K_UP:
            wind_amplitude += 5
        elif event.key == pygame.K_DOWN:
            wind_amplitude = max(0, wind_amplitude - 5)
        elif event.key == pygame.K_RIGHT:
            wind_speed += 0.5
        elif event.key == pygame.K_LEFT:
            wind_speed = max(0.5, wind_speed - 0.5)
        elif event.key == pygame.K_F3:
            debug_mode = not debug_mode

    # --- Simulation ---
    # Only the ball moves on its own; vegetation reacts to it while being drawn.
    ball.update()

    # --- Rendering, back to front ---
    screen.fill((135, 206, 235))                           # sky
    pygame.draw.rect(screen, (85, 107, 47), ground_rect)   # ground
    draw_grass(screen, t, wind_amplitude, wind_speed, colliders)

    leaf_count = 0  # draw_leaf() increments this global once per leaf drawn
    draw_tree(
        screen, tree_base_x, tree_base_y, tree_length, tree_angle, tree_levels,
        t, wind_amplitude, wind_speed, colliders,
    )
    ball.draw(screen)

    if debug_mode:
        overlay_lines = (
            f"FPS: {clock.get_fps():.1f}",
            f"Grass Blades: {len(grass_blades)}",
            f"Leaves: {leaf_count}",
            f"Colliders: {len(colliders)}",
            f"Wind Amplitude: {wind_amplitude}",
            f"Wind Speed: {wind_speed}",
        )
        for row, text in enumerate(overlay_lines):
            label = font.render(text, True, (0, 0, 0))
            screen.blit(label, (10, 10 + row * 20))

    pygame.display.flip()
    clock.tick(60)  # limit the loop to at most 60 frames per second

pygame.quit()
sys.exit()