import { ShaderMaterial, Vector3, Vector2, DoubleSide } from "three";

/**
 * Creates the ShaderMaterial used to draw the water surface.
 *
 * Fragment shader overview (everything is evaluated at the fragment's
 * screen-space uv, i.e. gl_FragCoord.xy / uScreenSize):
 *  - the world-space fragment position is mapped onto the fluid-simulation
 *    plane and uFluidTexture is sampled there; the finite-difference gradient
 *    of its x channel perturbs the screen-space uv, and the x channel itself
 *    is used as the fluid intensity (fluidAlpha);
 *  - uMountainsDepth and uWaterDepth are sampled at the perturbed uv and the
 *    difference of their x channels drives the output alpha;
 *  - uColor is sampled at the perturbed uv, tinted, and added to a constant
 *    base color.
 *
 * All texture uniforms start as null and must be assigned by the caller.
 * uScreenSize / uPixelSize are initialised from the window size at
 * construction time; nothing in this function updates them afterwards, so
 * callers presumably refresh them on resize (TODO confirm).
 *
 * @returns {ShaderMaterial} a double-sided, transparent material.
 */
export function WaterMaterial() {
  const vertexShader = `
      varying vec3 vFragPos;

      void main() {
        gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);

        // World-space position, used by the fragment shader to locate the
        // fragment on the fluid-simulation plane.
        vFragPos = (modelMatrix * vec4(position, 1.0)).xyz;
      }
    `;

  const fragmentShader = `
    varying vec3 vFragPos;

    uniform sampler2D uMountainsDepth;
    uniform sampler2D uWaterDepth;
    uniform sampler2D uColor;
    // Declared for the JS-side uniform of the same name; not sampled here.
    uniform sampler2D uBlurredReflectionDistance;
    uniform sampler2D uFluidTexture;

    uniform vec2 uScreenSize;
    uniform vec2 uPixelSize;
    // Declared for the JS-side uniform of the same name; not read here.
    uniform vec3 uCameraPos;

    void main() {
      vec2 uv = gl_FragCoord.xy / uScreenSize.xy;

      // fluid sim - these calculations need to match inside fluidsim.js as well
      // get plane position
      vec3 sphereCenter = vec3(0.71, -0.13, -0.05);
      vec2 distVec = (vFragPos - sphereCenter).xz * 0.75;
      vec2 fluidUv = clamp(distVec, -1.0, 1.0) * 0.5 + 0.5;

      vec3 fluidOffset = vec3(0.0);
      vec3 fluidColor  = vec3(0.2, 0.1, 0.06);
      float fluidAlpha = 0.0;

      // NOTE(review): uPixelSize is initialised to 1 / window size, yet it is
      // used here as the sampling step inside the fluid texture - confirm
      // this is the intended step size.
      vec3 f0  = texture2D(uFluidTexture, fluidUv).xyz;
      vec3 fdx = texture2D(uFluidTexture, fluidUv + vec2(uPixelSize.x, 0.0)).xyz;
      vec3 fdy = texture2D(uFluidTexture, fluidUv + vec2(0.0, uPixelSize.y)).xyz;

      // Finite-difference gradient of the x channel. Both components must
      // difference the same channel, otherwise the y offset picks up a bias
      // wherever the channels of the fluid texture differ.
      fluidOffset.x = f0.x - fdx.x;
      fluidOffset.y = f0.x - fdy.x;
      fluidAlpha = clamp(f0.x, 0.0, 1.0);

      // Fragments outside the simulation plane get no fluid contribution.
      if(distVec.x < -1.0 || distVec.x > 1.0 || distVec.y < -1.0 || distVec.y > 1.0) {
        fluidOffset = vec3(0.0);
        fluidColor  = vec3(0.0);
        fluidAlpha  = 0.0;
      }

      // Perturb the screen-space lookup with the fluid gradient.
      uv += fluidOffset.xy * 0.055;

      float mountainsDepth = texture2D(uMountainsDepth, uv).x;
      float waterDepth = texture2D(uWaterDepth, uv).x;

      // Alpha grows with the (clamped) amount by which the mountains depth
      // exceeds the water depth.
      float diff = max(-(waterDepth - mountainsDepth), 0.0) * 13.0;
      diff = min(diff, 2.5);
      vec3 color = vec3(0.09, 0.027, 0.0);
      float alpha = 1.0 - exp(-diff);

      vec3 reflectionColor = texture2D(uColor, uv).xyz;
      reflectionColor *= vec3(1.0, 0.8, 0.4);

      vec3 colorPlusReflections = color + reflectionColor;

      // Brighten with the fluid intensity, then blend towards fluidColor
      // once the intensity exceeds 0.75.
      vec3 colorWithFluid = colorPlusReflections * (1.0 + fluidAlpha * 2.0);
      if(fluidAlpha > 0.75) {
        colorWithFluid = mix(colorWithFluid, fluidColor, (fluidAlpha - 0.75) * 3.5);
      }

      gl_FragColor = vec4(colorWithFluid, alpha);

      // Debug views:
      // gl_FragColor = vec4(vFragPos, 1.0);
      // gl_FragColor = vec4(waterDepth < mountainsDepth ? 1.0 : 0.0, 0.0, 0.0, 1.0);
      // gl_FragColor = vec4(vec3(mountainsDepth) * 0.1, 1.0);
      // gl_FragColor = vec4(vec3(diff), 1.0);
    }
    `;

  // The uniform set is kept unchanged so existing callers that assign these
  // uniforms keep working.
  const uniforms = {
    uMountainsDepth: { type: "t", value: null },
    uWaterDepth: { type: "t", value: null },
    uColor: { type: "t", value: null },
    uScreenSize: { value: new Vector2(window.innerWidth, window.innerHeight) },
    uPixelSize: { value: new Vector2(1 / window.innerWidth, 1 / window.innerHeight) },
    uCameraPos: { value: new Vector3(0, 0, 0) },
    uBlurredReflectionDistance: { type: "t", value: null },
    uFluidTexture: { type: "t", value: null },
  };

  return new ShaderMaterial({
    vertexShader,
    fragmentShader,
    uniforms,
    side: DoubleSide,
    transparent: true,
  });
}















/*
  unused SSR code





  // // SSR
  // float depth = length(vFragPos - uCameraPos);
  // vec3 viewDir = normalize(vFragPos - uCameraPos);
  // if(dot(viewDir, norm) > 0.0) norm = -norm; 
  // vec3 mult = vec3(1.0);
  // vec3 reflDir = reflect(viewDir, norm);
  // reflDir = normalize(reflDir);  
  // vec3 rd = reflDir;
  // vec3 ro = vFragPos + reflDir * max(0.01, 0.01 * depth);  
  // vec3 p2;
  // vec3 lastP;
  // vec3 debugColor = vec3(1.0);
  // bool intersected = intersect(ro, rd, p2, lastP, debugColor);
  // if(intersected) {
  //   // intersection validated
  //   vec4 projP2 = vProjViewMatrix * vec4(p2, 1.0);
  //   vec2 p2Uv = (projP2 / projP2.w).xy * 0.5 + 0.5;
  //   vec3 color = vec3(1.0); // texture2D(uColor, p2Uv).xyz;
  //   mult *= debugColor;
  // } else {
  //   // intersection is invalid
  //   mult *= 0.0;
  // }  
  // gl_FragColor.xyz += mult;





float rand(float co) { return fract(sin(co*(91.3458)) * 47453.5453); }
    float rand(vec2 co)  { return fract(sin(dot(co.xy ,vec2(12.9898,78.233))) * 43758.5453); }
    float rand(vec3 co)  { return rand(co.xy+rand(co.z)); }

    float depthBufferAtP(vec3 p) {
      vec4 projP = vProjViewMatrix * vec4(p, 1.0);
      vec2 pNdc = (projP / projP.w).xy;
      vec2 pUv  = pNdc * 0.5 + 0.5;
      float depthAtPointP = texture2D(uPosition, pUv).w;
      if(depthAtPointP == 0.0) depthAtPointP = 9999999.0; 
      return depthAtPointP;
    }

    bool intersect(
      vec3 ro, vec3 rd, 
      out vec3 intersectionP,
      out vec3 lastP,
      out vec3 color) 
    {
      // bool jitter = true;
      // float startingStep = 0.05;
      // float stepMult = 1.15;
      // const int steps = 40;
      // const int binarySteps = 5;
      // float maxIntersectionDepthDistance = 1.5;

      bool jitter = true;
      float startingStep = 0.01;
      float stepMult = 1.055;
      const int steps = 60;
      const int binarySteps = 7;
      float maxIntersectionDepthDistance = 0.05;

      // vec2 halfPixelOffs = vec2(1.0 / uScreenSize.x, 1.0 / uScreenSize.y) * 0.5;
      float step = startingStep;
      vec3 p = ro;
      bool intersected = false;
      bool possibleIntersection = false;
      float lastRecordedDepthBuffThatIntersected;
      vec3 p1, p2;
      vec3 initialP = p;
      int stepsTaken = 0;
      for(int i = 0; i < steps; i++) {
        stepsTaken = i;
        // at the end of the loop, we'll advance p by jittB to keep the jittered sampling in the proper "cell" 
        // float jittA = 0.5 + rand(p) * 0.5;
        float jittA = fract(rand(p)); // + uRandoms.x);
        if(!jitter) jittA = 1.0;
        float jittB = 1.0 - jittA;
        p += rd * step * jittA;
        vec4 projP = vProjViewMatrix * vec4(p, 1.0);
        vec2 pNdc = (projP / projP.w).xy;
        vec2 pUv  = (pNdc * 0.5 + 0.5);
        float depthAtPosBuff = texture2D(uPosition, pUv).w;
        if(depthAtPosBuff == 0.0) {
          depthAtPosBuff = 9999999.0;
        } 
        // out of screen bounds condition
        if(pUv.x < 0.0 || pUv.x > 1.0 || pUv.y < 0.0 || pUv.y > 1.0) {
          break;
        }
        float depthAtPointP = -(vViewMatrix * vec4(p, 1.0)).z;
        if(depthAtPointP > depthAtPosBuff) {
          // intersection found!
          p1 = initialP;
          p2 = p;
          lastRecordedDepthBuffThatIntersected = depthAtPosBuff;
          possibleIntersection = true;
          break;
        }
        // initialP needs to be the last jittered sample, and can't just be the "p" value at the start
        // of the loop iteration, otherwise you run the risk of having both p1 and p2 at the same side of the depth buffer
        // and (apparently) for the binary search to work properly you need to have p1 and p2 on different sides of the depth buffer
        // p1 at the side of the depth buffer plane that it's closer to the camera, and p2 at the other side
        initialP = p;
        p += rd * step * jittB;
        step *= stepMult; // this multiplication needs to happen AFTER we add jittB
      }
      // strangely, I find myself having to move the binary search outside of the first loop, otherwise
      // for some esoteric reason the GPU starts catching fire

      // ******** binary search start *********
      for(int j = 0; j < binarySteps; j++) {
        vec3 mid = (p1 + p2) * 0.5;
        float depthAtMid = - (vViewMatrix * vec4(mid, 1.0)).z;
        float depthAtPosBuff = depthBufferAtP(mid);
        if(depthAtMid > depthAtPosBuff) {
          p2 = (p1 + p2) * 0.5;
          // we need to save this value inside this if-statement: if it were saved outside and above this
          // if-statement, its value could end up being very large (e.g. if the sample hit the "background",
          // since in that case depthBufferAtP() returns 9999999.0)
          // and if that value is very large, it can create artifacts when evaluating this condition:
          // ---> if(abs(depthAtP2 - lastRecordedDepthBuffThatIntersected) < maxIntersectionDepthDistance)
          // to be honest though, these artifacts only appear for largerish values of maxIntersectionDepthDistance
          lastRecordedDepthBuffThatIntersected = depthAtPosBuff;
        } else {
          p1 = (p1 + p2) * 0.5;
        }
      }
      // ******** binary search end   *********

      intersectionP = p2;
      lastP = p;
      
      // use p2 as the intersection point
      float depthAtP2 = - (vViewMatrix * vec4(p2, 1.0)).z;
      if( possibleIntersection && // without using possibleIntersection apparently it's possible that lastRecordedDepthBuffThatIntersected
                                  // ends up being valid thanks to the binary search, and that causes all sorts of troubles
          abs(depthAtP2 - lastRecordedDepthBuffThatIntersected) < maxIntersectionDepthDistance
      ) {
        // intersection validated
        intersected = true;
      }

      return intersected;
    }
*/