import * as THREE from "three";
import { linterp, rand, remap, saturate, spectrum_offset } from "./shaderFragments";

/**
 * Full-screen post-processing pass combining:
 *  - a wind "shake" (noise-driven offset of the whole image, scaled by uWindLevel),
 *  - a glitch-style block-swap distortion whose strength comes from static and
 *    expanding rings around a fixed point on the xz plane (read from a position buffer),
 *  - a spectral dispersion blur along x, weighted by spectrum_offset().
 *
 * The vertex shader writes the 2x2 plane straight into clip space, so the camera
 * only exists to satisfy `renderer.render(scene, camera)`.
 */
export default class PostProcess {
    /**
     * @param {object} renderer - three.js renderer used by compute(); must provide
     *   setRenderTarget(), getRenderTarget() and render().
     * @param {object} positionRT - render target whose texture is bound to `uPosition`;
     *   the shader reads its .xyz as a per-pixel position (presumably world space —
     *   confirm against whatever pass fills it).
     */
    constructor(renderer, positionRT) {
        this.material = new THREE.ShaderMaterial({
            uniforms: {
                uTexture: { type: "t", value: null },
                uPosition: { type: "t", value: positionRT.texture },
                uWindLevel: { value: 0 },
                uTime: { value: 0 },
                uSphereActiveTimer: { value: 0 },
                // Declared in the shader but currently unused; not updated on resize.
                uAspect: { value: window.innerWidth / window.innerHeight },
                // GLSL `bool`. three.js uploads bools through gl.uniform1i, so the old
                // value 0.5 was truncated to 0 (off); `false` keeps the same output.
                // Set to true to enable the per-cell vertical jitter in randSwap().
                uBottomDistortion: { value: false },
                uDispersion: { value: 0 },
            },

            vertexShader: `
                varying vec2 vUv;

                // Full-screen pass: the 2x2 plane already spans clip space [-1, 1].
                void main() {
                    vUv = uv;
                    gl_Position = vec4(position.xy, 0.0, 1.0);    
                }
            `,

            fragmentShader: `
                uniform sampler2D uTexture;
                uniform sampler2D uPosition;
                uniform float uWindLevel;
                uniform float uTime;
                uniform float uSphereActiveTimer;
                uniform float uAspect;
                uniform bool  uBottomDistortion;
                uniform float uDispersion;

                varying vec2 vUv;

                // Helpers injected from shaderFragments; srand() and sat() used below
                // are expected to come from these snippets.
                ${rand}
                ${saturate}
                ${remap}
                ${linterp}
                ${spectrum_offset}


                /* discontinuous pseudorandom uniformly distributed in [-0.5, +0.5]^3 */
                vec3 random3(vec3 c) {
                    float j = 4096.0*sin(dot(c,vec3(17.0, 59.4, 15.0)));
                    vec3 r;
                    r.z = fract(512.0*j);
                    j *= .125;
                    r.x = fract(512.0*j);
                    j *= .125;
                    r.y = fract(512.0*j);
                    return r-0.5;
                }
                
                const float F3 =  0.3333333;
                const float G3 =  0.1666667;
                
                /* 3d simplex noise */
                float simplex3d(vec3 p) {
                     vec3 s = floor(p + dot(p, vec3(F3)));
                     vec3 x = p - s + dot(s, vec3(G3));
                     vec3 e = step(vec3(0.0), x - x.yzx);
                     vec3 i1 = e*(1.0 - e.zxy);
                     vec3 i2 = 1.0 - e.zxy*(1.0 - e);
                     vec3 x1 = x - i1 + G3;
                     vec3 x2 = x - i2 + 2.0*G3;
                     vec3 x3 = x - 1.0 + 3.0*G3;
                     vec4 w, d;
                     w.x = dot(x, x);
                     w.y = dot(x1, x1);
                     w.z = dot(x2, x2);
                     w.w = dot(x3, x3);
                     w = max(0.6 - w, 0.0);
                     d.x = dot(random3(s), x);
                     d.y = dot(random3(s + i1), x1);
                     d.z = dot(random3(s + i2), x2);
                     d.w = dot(random3(s + 1.0), x3);
                     w *= w;
                     w *= w;
                     d *= w;
                     return dot(d, vec4(52.0));
                }
                const mat3 rot1 = mat3(-0.37, 0.36, 0.85,-0.14,-0.93, 0.34,0.92, 0.01,0.4);
                const mat3 rot2 = mat3(-0.55,-0.39, 0.74, 0.33,-0.91,-0.24,0.77, 0.12,0.63);
                const mat3 rot3 = mat3(-0.71, 0.52,-0.47,-0.08,-0.72,-0.68,-0.7,-0.45,0.56);
                float simplex3d_fractal(vec3 m) {
                    return   0.5333333*simplex3d(m*rot1)
                            +0.2666667*simplex3d(2.0*m*rot2)
                            +0.1333333*simplex3d(4.0*m*rot3)
                            +0.0666667*simplex3d(8.0*m);
                }

                // block: .xy = bottom-left corner, .zw = top-right corner (uv space).
                // NOTE: despite the name this returns 0.0 for points inside the block
                // (boundary points are handled loosely) and -1.0 for points outside.
                // randSwap() relies on exactly this convention.
                float isInBlock(vec2 uv, vec4 block) {
                    vec2 a = sign(uv - block.xy);
                    vec2 b = sign(block.zw - uv);
                    return min(sign(a.x + a.y + b.x + b.y - 3.), 0.);
                }

                // Returns 0 when uv is inside swapA and -(swapB.xy - swapA.xy) outside it.
                // Summing moveDiff(uv, A, B) + moveDiff(uv, B, A) therefore gives
                // B - A inside A, A - B inside B, and 0 everywhere else.
                vec2 moveDiff(vec2 uv, vec4 swapA, vec4 swapB) {
                    vec2 diff = swapB.xy - swapA.xy;
                    return diff * isInBlock(uv, swapA);
                }

                // Splits uv space into cells of gridSize, each divided into subGridSize
                // sub-cells; picks two sub-cells per cell (seeded by cell position and
                // time) and swaps their contents. Also adds a signed random amount,
                // proportional to the cell size, to dispersion.
                vec2 randSwap(
                    vec2 uv, 
                    vec2 gridSize, /* in uv space */
                    vec2 subGridSize, /* e.g. vec2(3, 3) for a 3x3 grid */ 
                    float time,
                    inout float dispersion
                ) {
                    vec2 gridBottom = uv - mod(uv, gridSize);
                    vec2 gridCenter = gridBottom + gridSize * 0.5;

                    if(uBottomDistortion) {
                        uv.y += srand(gridBottom) * 0.025;
                    }

                    float subGridCellsCount = subGridSize.x * subGridSize.y;

                    float gridRand1 = rand(gridCenter + vec2(time));
                    float gridRand2 = rand(gridBottom + vec2(time));

                    dispersion += srand(gridBottom + vec2(time)) * (gridSize.x + gridSize.y);

                    float randSubGridIdx1 = floor( gridRand1 * subGridCellsCount  );
                    float randSubGridIdx2 = floor( gridRand2 * subGridCellsCount  );

                    vec2 subCellSize = gridSize / subGridSize;
                    
                    vec2 scell1Bottom = gridBottom + vec2(
                        mod(randSubGridIdx1, subGridSize.x) * subCellSize.x,
                        floor(randSubGridIdx1 / subGridSize.x) * subCellSize.y
                    );
                    vec2 scell2Bottom = gridBottom + vec2(
                        mod(randSubGridIdx2, subGridSize.x) * subCellSize.x,
                        floor(randSubGridIdx2 / subGridSize.x) * subCellSize.y
                    );
                    
                    vec4 swapA = vec4(scell1Bottom, scell1Bottom + subCellSize);
                    vec4 swapB = vec4(scell2Bottom, scell2Bottom + subCellSize);

                    // The two contributions together move points in swapA to swapB
                    // and points in swapB to swapA (see moveDiff).
                    vec2 newUv = uv;
                    newUv += moveDiff(uv, swapA, swapB);
                    newUv += moveDiff(uv, swapB, swapA);
                    return newUv;
                }

                void main() {
                    vec2 windOffs = vec2(0.0);

                    vec3 position = texture2D(uPosition, vUv).xyz;

                    // Wind: shift the whole image by a smooth noise-driven offset.
                    if(uWindLevel > 0.0) {
                        float noisex = simplex3d(vec3(uTime * 1.5, 0.0, 0.0));
                        float noisey = simplex3d(vec3(uTime * 1.5 + 57.897, 0.0, 0.0));

                        windOffs = vec2(noisex, noisey) * uWindLevel * 0.0035;
                    }

                    vec2 nuvs = vUv + windOffs;

                    // Distortion strength from rings around a fixed point on the xz plane.
                    // Each ring fades in with uSphereActiveTimer; the successive
                    // subtractions stagger when each ring starts to appear.
                    float sphereActiveTime = uSphereActiveTimer;
                    float strength = 0.0;
                    float targetDist = length(vec3(position.x, 0.0, position.z) - vec3(0.71, 0.0, -0.05));

                    // static ring 1
                    sphereActiveTime -= 0.3;
                    if(targetDist > 0.5 && targetDist < 0.65 && position.y < -0.4575) {
                        strength += 0.15 * clamp(sphereActiveTime, 0.0, 1.0);
                    }

                    // static ring 2
                    sphereActiveTime -= 0.5;
                    if(targetDist > 0.75 && targetDist < 0.8 && position.y < -0.40) {
                        strength += 0.15 * clamp(sphereActiveTime, 0.0, 1.0);
                    }
                    
                    // expanding wave 1: 0.6 wide, inner radius cycling over [0, 5)
                    float modT = mod(uTime, 5.0);
                    sphereActiveTime -= 0.2;
                    if(targetDist > modT && targetDist < (0.6 + modT)) {
                        float mid = 0.3 + modT;
                        // triangular profile: 1 at the wave centre, 0 at its edges
                        float t = 1.0 - abs(targetDist - mid) / 0.3;

                        if(targetDist < 0.9) t = 0.0;
                        // fade out between radius 4 and 5; max() stops t going negative
                        // (which would invert the distortion) beyond radius 5.
                        if(targetDist > 4.0) t *= max(0.0, 1.0 - (targetDist - 4.0));

                        strength += t * 0.15 * clamp(sphereActiveTime, 0.0, 1.0);
                    }

                    // expanding wave 2: 0.3 wide, twice as fast, inner radius cycling over [0, 10)
                    {
                        float modT = mod(uTime * 2.0 + 1.0, 10.0);
                        if(targetDist > modT && targetDist < (0.3 + modT)) {
                            float mid = 0.15 + modT;
                            float t = 1.0 - abs(targetDist - mid) / 0.15;
    
                            if(targetDist < 0.9) t = 0.0;
                            // same fade-out as wave 1, clamped at 0 beyond radius 5
                            if(targetDist > 4.0) t *= max(0.0, 1.0 - (targetDist - 4.0));
    
                            strength += t * 0.15 * clamp(sphereActiveTime, 0.0, 1.0);
                        }
                    }

                    float time = uTime;
                    float dispersion = 0.0;
                    // randSwap's dispersion parameter is inout, so every call needs an
                    // l-value; passes that must not contribute write into this scratch.
                    float nullop = 0.0;

                    // Four layered swap passes at different cell sizes and update rates
                    // (time is quantized per pass so each glitch holds for a few frames).
                    vec2 uv = randSwap(nuvs, vec2(0.40 + nuvs.x * 0.0000025, 0.20),  vec2(3.0, 3.0), time - mod(time, 0.15), dispersion);
                    uv      = randSwap(uv,  vec2(0.02 + uv.x  * 0.0000025, 0.015), vec2(3.0, 2.0), time - mod(time, 0.05), nullop);
                    uv      = randSwap(uv,  vec2(0.06 + uv.x  * 0.0000025, 0.12),  vec2(2.0, 3.0), time - mod(time, 0.02), nullop);
                    uv      = randSwap(uv,  vec2(0.35 + uv.x  * 0.0000025, 0.35),  vec2(2.0, 2.0), time - mod(time, 0.07), dispersion);

                    // Blend: strength 0 -> undistorted nuvs, strength 1 -> fully swapped uv.
                    vec2 dist = nuvs - uv;
                    uv += dist * (1.0 - strength);
                    dispersion = sat(dispersion) * strength * 15.0;


                    float direction = rand(vec2(dispersion)) > 0.5 ? -1.0 : 1.0;
                    const int steps = 10;
                    vec3 sum = vec3(0.0);
                    vec3 cumw = vec3(0.0);

                    // rand pixel offset to ease dispersion a bit
                    uv.x += srand(uv) * 0.01 * dispersion;

                    // Spectral blur: sample along x and weight each tap by its spectrum color.
                    for(int i = 0; i < steps; i++) {
                        float t = float(i) / float(steps);
                        vec2 dispUv = uv + vec2(dispersion * 0.1 * direction * t, 0.0);
                        vec3 spectr = spectrum_offset(t);
                        cumw += spectr;
                        sum += texture2D(uTexture, dispUv).rgb * spectr;
                    }

                    gl_FragColor = vec4(sum / cumw, 1.0);
                }
            `,

            depthTest:  false,
            depthWrite: false,
        });

        this.mesh = new THREE.Mesh(new THREE.PlaneBufferGeometry(2,2), this.material);
        // The vertex shader bypasses the camera, so never let the placeholder
        // camera's frustum cull the full-screen plane.
        this.mesh.frustumCulled = false;
        // Only needed to satisfy renderer.render(); its projection is never used.
        this.camera = new THREE.PerspectiveCamera( 45, 1, 1, 1000 );
        this.renderer = renderer;

        this.scene = new THREE.Scene();
        this.scene.add(this.mesh);
    }

    /**
     * Runs the pass: samples `textureFrom` and writes the result into `renderTargetDest`.
     * The render target bound before the call is restored afterwards, even if
     * rendering throws.
     *
     * @param {{windLevel: number, time: number, sphereActiveTimer: number}} params -
     *   values copied into uWindLevel, uTime and uSphereActiveTimer.
     * @param {object} textureFrom - input color texture (bound to uTexture).
     * @param {object|null} renderTargetDest - destination render target (null = canvas).
     */
    compute({ windLevel, time, sphereActiveTimer }, textureFrom, renderTargetDest) {
        const { uniforms } = this.material;
        uniforms.uTexture.value = textureFrom;
        uniforms.uWindLevel.value = windLevel;
        uniforms.uTime.value = time;
        uniforms.uSphereActiveTimer.value = sphereActiveTimer;

        const previousTarget = this.renderer.getRenderTarget();
        this.renderer.setRenderTarget(renderTargetDest);
        try {
            this.renderer.render(this.scene, this.camera);
        } finally {
            this.renderer.setRenderTarget(previousTarget);
        }
    }
}