#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "imgui.h"
#include "imgui_impl_glfw.h"
#include "imgui_impl_opengl3.h"
#include <GLFW/glfw3.h>
#include "imgui_memory_editor.h" // Add this header for the Memory Editor

// Emulated display geometry: SCREEN_WIDTH x SCREEN_HEIGHT pixels, 3 bytes (R, G, B)
// per pixel, stored in CPU memory starting at SCREEN_MEM_START.
constexpr int SCREEN_WIDTH = 128;
constexpr int SCREEN_HEIGHT = 64;
constexpr int SCREEN_PIXELS = SCREEN_WIDTH * SCREEN_HEIGHT;
constexpr int SCREEN_MEM_START = 0x1000;
constexpr int SCREEN_MEM_SIZE = SCREEN_PIXELS * 3;  // RGB888
static_assert(SCREEN_MEM_START + SCREEN_MEM_SIZE <= 0x10000,
              "framebuffer must fit inside the 64 KiB CPU address space");

// ========== CPU ==========
// Toy 8-bit accumulator CPU with a flat 64 KiB memory.
// Every instruction is one opcode byte, optionally followed by one operand byte.
// Jump targets and STA/LDM addresses are single bytes, so they can only reach $00-$FF.
class CPU {
public:
    uint8_t A = 0;                 // accumulator; arithmetic wraps modulo 256
    uint8_t CMPF = 0;              // 1 if the last CMP found A equal to its operand, else 0
    uint16_t PC = 0;               // program counter; wraps at 0xFFFF
    bool running = true;           // cleared by HLT or by an unknown opcode
    std::vector<uint8_t> memory;   // 65536-byte address space

    CPU() : memory(65536, 0) {}

    /// Zeroes memory, copies `program` to address 0 and resets all registers.
    /// Bytes beyond the end of memory are ignored instead of being written out of bounds.
    void loadProgram(const std::vector<uint8_t>& program) {
        std::fill(memory.begin(), memory.end(), 0);
        const size_t count = std::min(program.size(), memory.size());
        std::copy_n(program.begin(), count, memory.begin());
        PC = 0;
        A = 0;
        CMPF = 0;
        running = true;
    }

    /// Fetches, decodes and executes one instruction at PC.
    /// Does nothing once the CPU has halted; call loadProgram() to restart it.
    void step() {
        if (!running) return;

        // Consumes the 8-bit target operand and jumps to it when `taken` is true.
        auto branchIf = [this](bool taken) {
            const uint8_t target = memory[PC++];
            if (taken) PC = target;
        };

        const uint8_t opcode = memory[PC++];
        switch (opcode) {
            case 0x01: A = memory[PC++]; break;                  // LDA #val
            case 0x02: A += memory[PC++]; break;                 // ADD #val
            case 0x03: memory[memory[PC++]] = A; break;          // STA addr
            case 0x04: A = memory[memory[PC++]]; break;          // LDM addr
            case 0x07: running = false; break;                   // HLT
            case 0x08: A -= memory[PC++]; break;                 // SUB #val
            case 0x0A:                                           // CMP #val
            case 0x11: CMPF = (A == memory[PC++]); break;        // CMP #val (alias)
            case 0x0B: break;                                    // NOP

            // Flow control
            case 0x05:                                           // JMP addr
            case 0x10: branchIf(true); break;                    // BRA addr (alias)
            case 0x06:                                           // JEZ addr
            case 0x0C: branchIf(A == 0); break;                  // BEQ addr (alias)
            case 0x09:                                           // JNZ addr
            case 0x0D: branchIf(A != 0); break;                  // BNE addr (alias)
            case 0x0E: branchIf((A & 0x80) != 0); break;         // BMI: bit 7 of A set
            case 0x0F: branchIf((A & 0x80) == 0); break;         // BPL: bit 7 of A clear
            // There is no carry flag; BCC/BCS test bit 0 of A instead.
            case 0x12: branchIf((A & 0x01) == 0); break;         // BCC
            case 0x13: branchIf((A & 0x01) != 0); break;         // BCS

            default: running = false; break;                     // unknown opcode halts
        }
    }

    /// Executes up to `steps` instructions, stopping early if the CPU halts.
    void run(int steps = 1) {
        while (running && steps-- > 0) step();
    }

    /// Disassembles the instruction at PC without executing it, e.g. "LDA 0x2A".
    /// Unknown opcodes are rendered as "??? 0xNN".
    std::string getCurrentInstruction() const {
        const uint8_t opcode = memory[PC];
        // The operand address wraps like PC does, so PC == 0xFFFF reads $0000, not past the end.
        const uint8_t operand = memory[static_cast<uint16_t>(PC + 1)];

        const char* mnemonic = nullptr;
        switch (opcode) {
            case 0x01: mnemonic = "LDA"; break;
            case 0x02: mnemonic = "ADD"; break;
            case 0x03: mnemonic = "STA"; break;
            case 0x04: mnemonic = "LDM"; break;
            case 0x05: mnemonic = "JMP"; break;
            case 0x06: mnemonic = "JEZ"; break;
            case 0x07: return "HLT";
            case 0x08: mnemonic = "SUB"; break;
            case 0x09: mnemonic = "JNZ"; break;
            case 0x0A: mnemonic = "CMP"; break;
            case 0x0B: return "NOP";
            case 0x0C: mnemonic = "BEQ"; break;
            case 0x0D: mnemonic = "BNE"; break;
            case 0x0E: mnemonic = "BMI"; break;
            case 0x0F: mnemonic = "BPL"; break;
            case 0x10: mnemonic = "BRA"; break;
            case 0x11: mnemonic = "CMP"; break;
            case 0x12: mnemonic = "BCC"; break;
            case 0x13: mnemonic = "BCS"; break;
            default: break;
        }

        char buffer[32];
        if (mnemonic != nullptr)
            std::snprintf(buffer, sizeof(buffer), "%s 0x%02X", mnemonic, static_cast<unsigned>(operand));
        else
            std::snprintf(buffer, sizeof(buffer), "??? 0x%02X", static_cast<unsigned>(opcode));
        return std::string(buffer);
    }
};

// ========== OpenGL Pixel Renderer ==========
void renderScreenToBackground(uint8_t* memory) {
    ImGuiIO& io = ImGui::GetIO();
    ImDrawList* bg = ImGui::GetBackgroundDrawList();
    float pixelSize = io.DisplaySize.x / SCREEN_WIDTH;

    for (int y = 0; y < SCREEN_HEIGHT; ++y) {
        for (int x = 0; x < SCREEN_WIDTH; ++x) {
            int index = SCREEN_MEM_START + (y * SCREEN_WIDTH + x) * 3;
            uint8_t r = memory[index + 0];
            uint8_t g = memory[index + 1];
            uint8_t b = memory[index + 2];
            ImVec2 tl(x * pixelSize, y * pixelSize);
            ImVec2 br(tl.x + pixelSize, tl.y + pixelSize);
            bg->AddRectFilled(tl, br, IM_COL32(r, g, b, 255));
        }
    }
}

// ========== Main ==========
int main() {
    if (!glfwInit()) return -1;
    GLFWwindow* window = glfwCreateWindow(1024, 640, "Rainbow CPU Emulator", NULL, NULL);
    glfwMakeContextCurrent(window);
    glfwSwapInterval(1);

    IMGUI_CHECKVERSION();
    ImGui::CreateContext();
    ImGui_ImplGlfw_InitForOpenGL(window, true);
    ImGui_ImplOpenGL3_Init("#version 130");
    ImGui::StyleColorsDark();

    CPU cpu;
    std::vector<uint8_t> rainbow = {
        0x01, 0xFF,       // LDA #0xFF
        0x03, 0x00,       // STA $00 (R)
        0x01, 0x00,
        0x03, 0x01,       // G
        0x01, 0x00,
        0x03, 0x02,       // B
        0x01, 0x00,
        0x03, 0x10,       // pixelIndex = 0

        // loop:
        0x04, 0x00, 0x03, 0x20, // load R -> $20
        0x04, 0x01, 0x03, 0x21, // G -> $21
        0x04, 0x02, 0x03, 0x22, // B -> $22

        // Write to screen: mem[0x1000 + pixelIndex * 3]
        0x04, 0x10,             // LDM pixelIndex
        0x02, 0x00,             // ADD #0 (A = index)
        0x03, 0x11,             // Store to $11 (low index)

        // simulate writing to screen: here you can inject into actual screen mem if you simulate 16-bit

        0x04, 0x00, 0x03, 0x00, // Rotate R -> temp
        0x04, 0x01, 0x03, 0x00,
        0x04, 0x02, 0x03, 0x01,
        0x04, 0x00, 0x03, 0x02,

        0x02, 0x01,             // ADD #1 (pixelIndex++)
        0x03, 0x10,             // STA pixelIndex

        0x05, 0x10              // JMP to loop start
    };

    cpu.loadProgram(rainbow);

    static MemoryEditor mem_edit_1;
//   static char data[0x10000];
//   size_t data_size = 0x10000;

    // --- Main loop ---
    while (!glfwWindowShouldClose(window)) {
        glfwPollEvents();
        ImGui_ImplOpenGL3_NewFrame();
        ImGui_ImplGlfw_NewFrame();
        ImGui::NewFrame();

        renderScreenToBackground(cpu.memory.data());

        ImGui::Begin("CPU Emulator");

        if (ImGui::Button("Step")) cpu.step();
        ImGui::SameLine();
        if (ImGui::Button("Run")) cpu.run(10000);
        ImGui::SameLine();
        if (ImGui::Button("Reset")) cpu.loadProgram(rainbow);

        ImGui::Text("A: 0x%02X | PC: 0x%04X", cpu.A, cpu.PC);
        ImGui::Text("Instruction: %s", cpu.getCurrentInstruction().c_str());

        ImGui::End();


        mem_edit_1.DrawWindow("Memory Editor", &cpu.memory, 65536);


        ImGui::Begin("Memory Inspector");

        static int selected = 0;
        ImGui::SliderInt("Address", &selected, 0, (int)cpu.memory.size() - 1);
        ImGui::Text("0x%04X: 0x%02X", selected, cpu.memory[selected]);
        if (ImGui::Button("Zero")) cpu.memory[selected] = 0;

        ImGui::Separator();
        ImGui::Text("Full Memory Dump:");

        ImGui::BeginChild("memory_dump", ImVec2(500, 300), true);
        for (int row = 0; row < 16; ++row) {
            for (int col = 0; col < 16; ++col) {
                int addr = row * 16 + col;
                ImGui::SameLine();
                ImGui::Text("0x%02X", cpu.memory[addr]);
            }
        }
        ImGui::EndChild();

        ImGui::End();

        ImGui::Render();
        int display_w, display_h;
        glfwGetFramebufferSize(window, &display_w, &display_h);
        glViewport(0, 0, display_w, display_h);
        glClearColor(0, 0, 0, 1);
        glClear(GL_COLOR_BUFFER_BIT);
        ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
        glfwSwapBuffers(window);
    }

    ImGui_ImplOpenGL3_Shutdown();
    ImGui_ImplGlfw_Shutdown();
    ImGui::DestroyContext();
    glfwDestroyWindow(window);
    glfwTerminate();
    return 0;
}