#include "GreedyMesher.h"
#include <algorithm>

// Returns the voxel value stored at v[x][y][z], or 0 (treated as empty) when
// the coordinate lies outside the grid. Each index is checked against the
// size of the row it actually indexes, so ragged (non-rectangular) grids and
// empty grids are read safely instead of triggering out-of-range access.
inline int voxelAt(const std::vector<std::vector<std::vector<int>>>& v, int x, int y, int z) {
    if (x < 0 || y < 0 || z < 0) return 0;
    // Non-negative now, so comparing as the container's unsigned size type is exact.
    const auto ux = static_cast<std::vector<int>::size_type>(x);
    const auto uy = static_cast<std::vector<int>::size_type>(y);
    const auto uz = static_cast<std::vector<int>::size_type>(z);
    if (ux >= v.size()) return 0;
    const auto& plane = v[ux];
    if (uy >= plane.size()) return 0;
    const auto& column = plane[uy];
    if (uz >= column.size()) return 0;
    return column[uz];
}

// Builds a greedy-merged quad mesh of every boundary between a solid
// (non-zero) voxel and an empty (zero) one. `voxels` is indexed [x][y][z] and
// is assumed rectangular (extents are taken from voxels[0] / voxels[0][0]).
// Cells outside the grid count as empty, so the grid's outer surface is meshed.
//
// Each emitted Quad is one maximal rectangle of coplanar faces that point the
// same way and share the same voxel value:
//   normal    - index d into dirs below: 0=-x, 1=+x, 2=-y, 3=+y, 4=-z, 5=+z
//   x, y, z   - minimum corner in the face plane; along the normal axis it is
//               the coordinate of the solid cell owning the face (w for -faces,
//               w-1 for +faces of slice w).
//               NOTE(review): a consumer that treats x,y,z as the face plane
//               must add 1 along the normal for +faces — confirm with the Quad user.
//   du, dv    - extents: du[u] = merged width, dv[v] = merged height, others 0
//   textureID - the voxel value of the owning solid cell
// Returns an empty vector for an empty grid.
std::vector<Quad> GreedyMesher::mesh(const std::vector<std::vector<std::vector<int>>>& voxels) {
    std::vector<Quad> quads;
    // Indexing voxels[0] / voxels[0][0] below is undefined on an empty grid.
    if (voxels.empty() || voxels[0].empty() || voxels[0][0].empty()) return quads;

    // Grid extents indexed by axis (0 = x, 1 = y, 2 = z).
    const int dims[3] = {
        static_cast<int>(voxels.size()),
        static_cast<int>(voxels[0].size()),
        static_cast<int>(voxels[0][0].size())
    };

    // Face directions; the index d is stored as Quad::normal.
    static const int dirs[6][3] = {
        {-1,0,0},{1,0,0},{0,-1,0},{0,1,0},{0,0,-1},{0,0,1}
    };
    // In-plane (u, v) axes per face. Both must differ from the normal axis and
    // from each other, otherwise the slice->voxel coordinate mapping collapses.
    static const int axisU[6] = {2,2,0,0,0,0};
    static const int axisV[6] = {1,1,2,2,1,1};

    std::vector<int> mask; // reused across directions and slices

    for (int d = 0; d < 6; ++d) {
        const int n = (dirs[d][0] != 0) ? 0 : (dirs[d][1] != 0 ? 1 : 2); // normal axis
        const bool positive = dirs[d][n] > 0;
        const int u = axisU[d];
        const int v = axisV[d];
        const int dimU = dims[u];
        const int dimV = dims[v];
        const int dimW = dims[n];

        mask.assign(dimU * dimV, 0);

        // Slice w is the plane between cell w-1 and cell w along the normal;
        // w runs through dimW inclusive so the far boundary is included.
        for (int w = 0; w <= dimW; ++w) {
            // Build the mask: an entry is the value of the solid cell whose face
            // in this slice points in direction d, or 0 if there is no such face.
            // Only one orientation is recorded per direction so each face is
            // emitted exactly once, with the correct normal.
            for (int j = 0; j < dimV; ++j) {
                for (int i = 0; i < dimU; ++i) {
                    int coord[3];
                    coord[u] = i;
                    coord[v] = j;
                    coord[n] = w;
                    const int a = voxelAt(voxels, coord[0], coord[1], coord[2]); // cell w
                    coord[n] = w - 1;
                    const int b = voxelAt(voxels, coord[0], coord[1], coord[2]); // cell w-1

                    int face = 0;
                    if (positive) {
                        if (b != 0 && a == 0) face = b; // +face of cell w-1
                    } else {
                        if (a != 0 && b == 0) face = a; // -face of cell w
                    }
                    mask[i + j * dimU] = face;
                }
            }

            // Greedily merge equal, non-zero mask entries into rectangles:
            // extend along u first, then along v while the whole row matches.
            for (int j = 0; j < dimV; ++j) {
                for (int i = 0; i < dimU; ) {
                    const int c = mask[i + j * dimU];
                    if (c == 0) {
                        ++i;
                        continue;
                    }

                    int wdt = 1;
                    while (i + wdt < dimU && mask[i + wdt + j * dimU] == c) ++wdt;

                    int hgt = 1;
                    bool canGrow = true;
                    while (canGrow && j + hgt < dimV) {
                        for (int k = 0; k < wdt; ++k) {
                            if (mask[i + k + (j + hgt) * dimU] != c) {
                                canGrow = false;
                                break;
                            }
                        }
                        if (canGrow) ++hgt;
                    }

                    float pos[3] = {0, 0, 0};
                    pos[u] = static_cast<float>(i);
                    pos[v] = static_cast<float>(j);
                    // Owning cell's coordinate along the normal (see header comment).
                    pos[n] = static_cast<float>(positive ? w - 1 : w);

                    Quad q{};
                    q.x = pos[0];
                    q.y = pos[1];
                    q.z = pos[2];
                    for (int k = 0; k < 3; ++k) {
                        q.du[k] = 0;
                        q.dv[k] = 0;
                    }
                    q.du[u] = wdt;
                    q.dv[v] = hgt;
                    q.normal = d;
                    q.textureID = c;
                    quads.push_back(q);

                    // Consume the merged rectangle so it is not emitted again.
                    for (int jj = 0; jj < hgt; ++jj)
                        for (int ii = 0; ii < wdt; ++ii)
                            mask[i + ii + (j + jj) * dimU] = 0;

                    i += wdt;
                }
            }
        }
    }
    return quads;
}