#include "Engine.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

#include <glm/gtc/type_ptr.hpp>

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

// Out-of-line definitions for Engine's static state. GL object handles start at 0
// ("no object") and are filled in later: window/framebuffer by Init(),
// colorTexture/depthRenderbuffer by ResizeFramebuffer(), cubeVAO/cubeVBO/shaderProgram
// by SetupScene(); rotationAngle is advanced by RenderScene().
GLFWwindow* Engine::window{nullptr};
GLuint Engine::framebuffer{0};
GLuint Engine::colorTexture{0};
GLuint Engine::depthRenderbuffer{0};
GLuint Engine::cubeVAO{0};
GLuint Engine::cubeVBO{0};
GLuint Engine::shaderProgram{0};
float Engine::rotationAngle{0.0f};
// Requested offscreen size; Init() passes these to ResizeFramebuffer(), which then
// overwrites them with the aspect-corrected size it actually allocates.
int Engine::fbWidth{640};
int Engine::fbHeight{400};

// Normal-map texture sampled by the cube shader (bound to texture unit 1 in RenderScene).
// NOTE(review): this has external linkage; make it `static` or an Engine member if no
// other translation unit refers to it.
GLuint normalMapTexture{0};

bool Engine::Init() {
    if (!glfwInit()) {
        std::cout << "Failed to initialize GLFW\n";
        return false;
    }
    glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
    glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);

    window = glfwCreateWindow(1280, 800, "Engine Window", nullptr, nullptr);
    if (!window) {
        std::cout << "Failed to create GLFW window\n";
        glfwTerminate();
        return false;
    }
    glfwMakeContextCurrent(window);

    glewExperimental = GL_TRUE;
    if (glewInit() != GLEW_OK) {
        std::cout << "Failed to initialize GLEW\n";
        return false;
    }

    int width, height;
    glfwGetFramebufferSize(window, &width, &height);
    glViewport(0, 0, width, height);

    // Create framebuffer.
    glGenFramebuffers(1, &framebuffer);
    ResizeFramebuffer(fbWidth, fbHeight);

    // Setup cube geometry, shaders, and load normal map.
    if (!SetupScene()) {
        std::cout << "Failed to set up scene\n";
        return false;
    }
    return true;
}

// Non-owning accessor for the engine's GLFW window. The pointer is nullptr before Init()
// has run, or if window creation failed there.
GLFWwindow* Engine::GetWindow() {
    GLFWwindow* const current = window;
    return current;
}

void Engine::ResizeFramebuffer(int width, int height) {
    // Avoid division by zero.
    if (height <= 0)
        height = 1;

    // Define the desired target aspect ratio (e.g., 16:9).
    const float targetAspect = 16.0f / 9.0f;
    float currentAspect = static_cast<float>(width) / static_cast<float>(height);
    
    // Adjust dimensions to maintain the target aspect ratio.
    int newWidth = width;
    int newHeight = height;
    if (currentAspect > targetAspect) {
        // Too wide: adjust width.
        newWidth = static_cast<int>(height * targetAspect);
    } else if (currentAspect < targetAspect) {
        // Too tall: adjust height.
        newHeight = static_cast<int>(width / targetAspect);
    }
    
    fbWidth = newWidth;
    fbHeight = newHeight;
    
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);

    // Delete old attachments if they exist.
    if (colorTexture) {
        glDeleteTextures(1, &colorTexture);
    }
    if (depthRenderbuffer) {
        glDeleteRenderbuffers(1, &depthRenderbuffer);
    }

    // Create color texture using GL_RGBA.
    glGenTextures(1, &colorTexture);
    glBindTexture(GL_TEXTURE_2D, colorTexture);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, fbWidth, fbHeight, 0, GL_RGBA, GL_UNSIGNED_BYTE, NULL);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, colorTexture, 0);

    // Create depth renderbuffer.
    glGenRenderbuffers(1, &depthRenderbuffer);
    glBindRenderbuffer(GL_RENDERBUFFER, depthRenderbuffer);
    glRenderbufferStorage(GL_RENDERBUFFER, GL_DEPTH24_STENCIL8, fbWidth, fbHeight);
    glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_DEPTH_STENCIL_ATTACHMENT, GL_RENDERBUFFER, depthRenderbuffer);

    GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER);
    if (status != GL_FRAMEBUFFER_COMPLETE) {
        std::cout << "Framebuffer is not complete! Status: " << status << std::endl;
    }
    glBindFramebuffer(GL_FRAMEBUFFER, 0);
}

GLuint Engine::CompileShader(const char* vertexSrc, const char* fragmentSrc) {
    GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
    glShaderSource(vertexShader, 1, &vertexSrc, nullptr);
    glCompileShader(vertexShader);
    int success;
    glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetShaderInfoLog(vertexShader, 512, nullptr, infoLog);
        std::cout << "Vertex shader compilation failed: " << infoLog << std::endl;
        return 0;
    }
    GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
    glShaderSource(fragmentShader, 1, &fragmentSrc, nullptr);
    glCompileShader(fragmentShader);
    glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetShaderInfoLog(fragmentShader, 512, nullptr, infoLog);
        std::cout << "Fragment shader compilation failed: " << infoLog << std::endl;
        return 0;
    }
    GLuint program = glCreateProgram();
    glAttachShader(program, vertexShader);
    glAttachShader(program, fragmentShader);
    glLinkProgram(program);
    glGetProgramiv(program, GL_LINK_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetProgramInfoLog(program, 512, nullptr, infoLog);
        std::cout << "Shader program linking failed: " << infoLog << std::endl;
        return 0;
    }
    glDeleteShader(vertexShader);
    glDeleteShader(fragmentShader);
    return program;
}

bool Engine::SetupScene() {
    // Updated cube vertices: each vertex has:
    // position (3), normal (3), texcoord (2), tangent (3) = 11 floats per vertex.
    float vertices[] = {
        // Front face (normal: 0,0,1), tangent: (1,0,0)
        // positions         normals       texcoords   tangent
        -0.5f, -0.5f,  0.5f,  0,0,1,       0,0,       1,0,0,
         0.5f, -0.5f,  0.5f,  0,0,1,       1,0,       1,0,0,
         0.5f,  0.5f,  0.5f,  0,0,1,       1,1,       1,0,0,
         0.5f,  0.5f,  0.5f,  0,0,1,       1,1,       1,0,0,
        -0.5f,  0.5f,  0.5f,  0,0,1,       0,1,       1,0,0,
        -0.5f, -0.5f,  0.5f,  0,0,1,       0,0,       1,0,0,

        // Back face (normal: 0,0,-1), tangent: (-1,0,0)
         0.5f, -0.5f, -0.5f,  0,0,-1,      0,0,      -1,0,0,
        -0.5f, -0.5f, -0.5f,  0,0,-1,      1,0,      -1,0,0,
        -0.5f,  0.5f, -0.5f,  0,0,-1,      1,1,      -1,0,0,
        -0.5f,  0.5f, -0.5f,  0,0,-1,      1,1,      -1,0,0,
         0.5f,  0.5f, -0.5f,  0,0,-1,      0,1,      -1,0,0,
         0.5f, -0.5f, -0.5f,  0,0,-1,      0,0,      -1,0,0,

        // Left face (normal: -1,0,0), tangent: (0,0,-1)
        -0.5f, -0.5f, -0.5f, -1,0,0,       0,0,       0,0,-1,
        -0.5f, -0.5f,  0.5f, -1,0,0,       1,0,       0,0,-1,
        -0.5f,  0.5f,  0.5f, -1,0,0,       1,1,       0,0,-1,
        -0.5f,  0.5f,  0.5f, -1,0,0,       1,1,       0,0,-1,
        -0.5f,  0.5f, -0.5f, -1,0,0,       0,1,       0,0,-1,
        -0.5f, -0.5f, -0.5f, -1,0,0,       0,0,       0,0,-1,

        // Right face (normal: 1,0,0), tangent: (0,0,1)
         0.5f, -0.5f,  0.5f,  1,0,0,       0,0,       0,0,1,
         0.5f, -0.5f, -0.5f,  1,0,0,       1,0,       0,0,1,
         0.5f,  0.5f, -0.5f,  1,0,0,       1,1,       0,0,1,
         0.5f,  0.5f, -0.5f,  1,0,0,       1,1,       0,0,1,
         0.5f,  0.5f,  0.5f,  1,0,0,       0,1,       0,0,1,
         0.5f, -0.5f,  0.5f,  1,0,0,       0,0,       0,0,1,

        // Top face (normal: 0,1,0), tangent: (1,0,0)
        -0.5f,  0.5f,  0.5f,  0,1,0,       0,0,       1,0,0,
         0.5f,  0.5f,  0.5f,  0,1,0,       1,0,       1,0,0,
         0.5f,  0.5f, -0.5f,  0,1,0,       1,1,       1,0,0,
         0.5f,  0.5f, -0.5f,  0,1,0,       1,1,       1,0,0,
        -0.5f,  0.5f, -0.5f,  0,1,0,       0,1,       1,0,0,
        -0.5f,  0.5f,  0.5f,  0,1,0,       0,0,       1,0,0,

        // Bottom face (normal: 0,-1,0), tangent: (1,0,0)
        -0.5f, -0.5f, -0.5f,  0,-1,0,      0,0,       1,0,0,
         0.5f, -0.5f, -0.5f,  0,-1,0,      1,0,       1,0,0,
         0.5f, -0.5f,  0.5f,  0,-1,0,      1,1,       1,0,0,
         0.5f, -0.5f,  0.5f,  0,-1,0,      1,1,       1,0,0,
        -0.5f, -0.5f,  0.5f,  0,-1,0,      0,1,       1,0,0,
        -0.5f, -0.5f, -0.5f,  0,-1,0,      0,0,       1,0,0
    };

    glGenVertexArrays(1, &cubeVAO);
    glGenBuffers(1, &cubeVBO);
    glBindVertexArray(cubeVAO);
    glBindBuffer(GL_ARRAY_BUFFER, cubeVBO);
    glBufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW);
    // Position attribute.
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, 11 * sizeof(float), (void*)0);
    glEnableVertexAttribArray(0);
    // Normal attribute.
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, 11 * sizeof(float), (void*)(3 * sizeof(float)));
    glEnableVertexAttribArray(1);
    // Texcoord attribute.
    glVertexAttribPointer(2, 2, GL_FLOAT, GL_FALSE, 11 * sizeof(float), (void*)(6 * sizeof(float)));
    glEnableVertexAttribArray(2);
    // Tangent attribute.
    glVertexAttribPointer(3, 3, GL_FLOAT, GL_FALSE, 11 * sizeof(float), (void*)(8 * sizeof(float)));
    glEnableVertexAttribArray(3);
    glBindVertexArray(0);

    // Vertex shader source.
    const char* vertexShaderSrc = R"(
        #version 330 core
        layout(location = 0) in vec3 aPos;
        layout(location = 1) in vec3 aNormal;
        layout(location = 2) in vec2 aTexCoords;
        layout(location = 3) in vec3 aTangent;

        uniform mat4 model;
        uniform mat4 view;
        uniform mat4 projection;

        out vec3 FragPos;
        out vec3 Normal;
        out vec2 TexCoords;
        out vec3 Tangent;

        void main() {
            FragPos = vec3(model * vec4(aPos, 1.0));
            Normal = mat3(transpose(inverse(model))) * aNormal;
            TexCoords = aTexCoords;
            Tangent = mat3(model) * aTangent; 
            gl_Position = projection * view * vec4(FragPos, 1.0);
        }
    )";

    // Fragment shader source with normal mapping.
    const char* fragmentShaderSrc = R"(
        #version 330 core
        out vec4 FragColor;

        in vec3 FragPos;
        in vec3 Normal;
        in vec2 TexCoords;
        in vec3 Tangent;

        uniform vec3 lightPos;
        uniform vec3 viewPos;
        uniform sampler2D normalMap;

        void main() {
            // Obtain the normal from the normal map in range [0,1] and remap to [-1,1].
            vec3 normMap = texture(normalMap, TexCoords).rgb;
            normMap = normalize(normMap);
            
            // Calculate TBN matrix. For simplicity, compute bitangent as cross(normal, tangent).
            vec3 N = normalize(Normal);
            vec3 T = normalize(Tangent);
            vec3 B = normalize(cross(N, T));
            mat3 TBN = mat3(T, B, N);
            vec3 perturbedNormal = normalize(TBN * normMap);

            // Ambient.
            float ambientStrength = 0.5;
            vec3 ambient = ambientStrength * vec3(1.0, 0.0, 1.0);

            // Diffuse.
            vec3 lightDir = normalize(lightPos - FragPos);
            float diff = max(dot(perturbedNormal, lightDir), 0.0);
            vec3 diffuse = diff * vec3(1.0, 1.0, 1.0);

            // Specular.
            float specularStrength = 0.2;
            vec3 viewDir = normalize(viewPos - FragPos);
            vec3 reflectDir = reflect(-lightDir, perturbedNormal);
            float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32);
            vec3 specular = specularStrength * spec * vec3(1.0);

            vec3 result = ambient + diffuse + specular;
            FragColor = vec4(result, 1.0);
        }
    )";

    shaderProgram = CompileShader(vertexShaderSrc, fragmentShaderSrc);
    if (shaderProgram == 0) {
        return false;
    }

    // Load normal map from "./normal.png"
    int texWidth, texHeight, texChannels;
    unsigned char* data = stbi_load("./normal.jpg", &texWidth, &texHeight, &texChannels, 0);
    if (!data) {
        std::cout << "Failed to load normal map." << std::endl;
        return false;
    }
    glGenTextures(1, &normalMapTexture);
    glBindTexture(GL_TEXTURE_2D, normalMapTexture);
    GLenum format = (texChannels == 3) ? GL_RGB : GL_RGBA;
    glTexImage2D(GL_TEXTURE_2D, 0, format, texWidth, texHeight, 0, format, GL_UNSIGNED_BYTE, data);
    glGenerateMipmap(GL_TEXTURE_2D);
    // Set texture parameters.
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
    stbi_image_free(data);

    // Bind the normal map to texture unit 1.
    glUseProgram(shaderProgram);
    glUniform1i(glGetUniformLocation(shaderProgram, "normalMap"), 1);

    return true;
}

// Renders one frame of the rotating, normal-mapped cube into the offscreen framebuffer.
// Uses the given camera matrices and eye position and a fixed light at (0, 20, 0).
// Returns the color attachment as an ImGui texture handle. Leaves framebuffer 0 and
// texture unit 0 active when it returns.
ImTextureID Engine::RenderScene(const glm::mat4 &view, const glm::mat4 &projection, const glm::vec3 &viewPos) {
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
    glViewport(0, 0, fbWidth, fbHeight);
    glEnable(GL_DEPTH_TEST);
    glClearColor(0.1f, 0.1f, 0.1f, 1.0f);
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

    glUseProgram(shaderProgram);
    glUniformMatrix4fv(glGetUniformLocation(shaderProgram, "view"), 1, GL_FALSE, glm::value_ptr(view));
    glUniformMatrix4fv(glGetUniformLocation(shaderProgram, "projection"), 1, GL_FALSE, glm::value_ptr(projection));
    glUniform3f(glGetUniformLocation(shaderProgram, "viewPos"), viewPos.x, viewPos.y, viewPos.z);

    // Advance by a fixed step per call, so spin speed follows the frame rate. Wrap into
    // [0, 2*pi): an unbounded float angle loses precision over long sessions and the
    // rotation becomes visibly jerky. Rotating by a multiple of 2*pi is a no-op.
    constexpr float kTwoPi = 6.28318530717958647692f;
    rotationAngle = std::fmod(rotationAngle + 0.001f, kTwoPi);
    glm::mat4 model = glm::rotate(glm::mat4(1.0f), rotationAngle, glm::vec3(1, 1, 0));
    glUniformMatrix4fv(glGetUniformLocation(shaderProgram, "model"), 1, GL_FALSE, glm::value_ptr(model));

    glUniform3f(glGetUniformLocation(shaderProgram, "lightPos"), 0.0f, 20.0f, 0.0f);

    // The shader's normalMap sampler reads texture unit 1 (set up in SetupScene).
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, normalMapTexture);

    glBindVertexArray(cubeVAO);
    glDrawArrays(GL_TRIANGLES, 0, 36);
    glBindVertexArray(0);

    // Restore the default active unit so later texture binds (e.g. UI code) hit unit 0.
    glActiveTexture(GL_TEXTURE0);

    glBindFramebuffer(GL_FRAMEBUFFER, 0);
    return (ImTextureID)(intptr_t)colorTexture;
}

// Returns the offscreen color attachment as an ImGui texture handle. ResizeFramebuffer()
// creates a new texture name each time it reallocates, so callers should re-fetch this
// value rather than cache it.
ImTextureID Engine::GetFinalRenderingTexture() {
    // Widen the GL name to pointer size first. The C-style cast then works whether
    // ImTextureID is a pointer or an integer type.
    const intptr_t textureHandle = static_cast<intptr_t>(colorTexture);
    return (ImTextureID)textureHandle;
}

// Releases every GL object created by Init(), ResizeFramebuffer() and SetupScene(),
// destroys the window and terminates GLFW. All handles are reset to 0/nullptr, so a
// repeated call does not delete stale names or destroy a dangling window pointer.
void Engine::Shutdown() {
    // GL calls need a context; Init() makes the window's context current, so only
    // touch GL objects when a window exists.
    if (window) {
        glDeleteVertexArrays(1, &cubeVAO);
        glDeleteBuffers(1, &cubeVBO);
        glDeleteProgram(shaderProgram);
        glDeleteFramebuffers(1, &framebuffer);
        glDeleteTextures(1, &colorTexture);
        glDeleteTextures(1, &normalMapTexture);
        glDeleteRenderbuffers(1, &depthRenderbuffer);
    }
    cubeVAO = 0;
    cubeVBO = 0;
    shaderProgram = 0;
    framebuffer = 0;
    colorTexture = 0;
    normalMapTexture = 0;
    depthRenderbuffer = 0;

    if (window) {
        glfwDestroyWindow(window);
        window = nullptr;
    }
    glfwTerminate();
}