/* cspell:ignore gfactor bfactor bfinal MSIZE normpdf Znorm */
import { compileShader, createProgram, glsl } from "../webglHelper.js";

/**
 * Pass-through vertex shader for the bilateral filter stage.
 *
 * It writes the 2-D `a_position` attribute straight to `gl_Position` (z = 0, w = 1). It forwards
 * `a_texCoord` unchanged to the fragment shader as `v_texCoord`.
 *
 * NOTE(review): `glsl` is the template tag imported from webglHelper. It is presumably a no-op tag
 * used for editor syntax highlighting — confirm before relying on it to transform the source.
 */
const vertexShaderSource = glsl`#version 300 es
    in vec2 a_position;
    in vec2 a_texCoord;

    out vec2 v_texCoord;

    void main() {
      gl_Position = vec4(a_position, 0.0, 1.0);
      v_texCoord = a_texCoord;
    }
  `;

/**
 * Builds a render stage that applies a bilateral filter to the image sampled from texture unit 0.
 * The filtered result is blended in only where the mask sampled from texture unit 1 (red channel)
 * is non-zero. The spatial sigma and the range (colour) sigma are set with `updateFilterStrength`,
 * or individually with `updateSigma` / `updateBSigma`. Each `render()` call writes to
 * `outputTexture`.
 *
 * This function does not bind the input textures. The caller must bind the source image to
 * TEXTURE0 and the mask to TEXTURE1 before calling `render()`. Until a non-zero sigma is set,
 * the sigma uniforms keep their GL default of 0 and the stage copies the source unchanged.
 *
 * The texel size and the viewport are read from `gl.canvas` once, when the stage is built. The
 * stage must therefore be rebuilt if the canvas is resized.
 * NOTE(review): this assumes the source texture has the same dimensions as the canvas — confirm
 * with the caller.
 *
 * @param gl - WebGL2 context that owns every resource created here.
 * @param positionBuffer - vec2 clip-space positions, drawn as a 4-vertex TRIANGLE_STRIP.
 * @param texCoordBuffer - vec2 texture coordinates matching `positionBuffer`.
 * @param outputTexture - texture attached as COLOR_ATTACHMENT0 of this stage's framebuffer.
 * @returns `render`, the three parameter setters, and `cleanUp`. `cleanUp` releases the GL objects
 *   created here; the buffers and `outputTexture` stay owned by the caller.
 */
export const buildBilateralFilterStage = (
  gl: WebGL2RenderingContext,
  positionBuffer: WebGLBuffer,
  texCoordBuffer: WebGLBuffer,
  outputTexture: WebGLTexture
) => {
  const fragmentShaderSource = glsl`#version 300 es
  precision highp float;

// Bilateral filter.  Based on https://www.shadertoy.com/view/4dfGDH#
//
// MSIZE is the kernel width in texels: each output pixel averages an MSIZE x MSIZE
// neighbourhood. It must be odd so the kernel is centred on the current pixel.
#define MSIZE 15

// Gaussian density of a scalar distance x (0.39894 ~= 1 / sqrt(2 * pi)).
float normpdf(in float x, in float sigma)
{
  return 0.39894*exp(-0.5*x*x/(sigma*sigma))/sigma;
}

// Gaussian density of the length of v (used for the colour-difference term).
float normpdf3(in vec3 v, in float sigma)
{
  return 0.39894*exp(-0.5*dot(v,v)/(sigma*sigma))/sigma;
}

uniform sampler2D uSourceSampler;
uniform vec2 u_texelSize;
uniform sampler2D u_mask;
uniform float u_sigma;   // spatial sigma, in texels
uniform float u_bSigma;  // range (colour-difference) sigma

in vec2 v_texCoord;
out vec4 fragColor;

void main(void) {
  vec4 c = texture(uSourceSampler, v_texCoord);
  float a = texture(u_mask, v_texCoord).r;

  // Outside the mask there is nothing to do. A zero sigma would also divide by zero in
  // normpdf/normpdf3 (yielding NaN), so treat it as "filter disabled" and pass through.
  if (a == 0.0 || u_sigma == 0.0 || u_bSigma == 0.0) {
    fragColor = c;
    return;
  }

  const int kSize = (MSIZE-1)/2;

  // Separable 1-D spatial Gaussian; the 2-D weight at (i, j) is kernel[i] * kernel[j].
  float kernel[MSIZE];
  for (int j = 0; j <= kSize; ++j) {
    kernel[kSize+j] = kernel[kSize-j] = normpdf(float(j), u_sigma);
  }

  // Scale the colour weight so that an identical colour gets weight 1.
  float bZnorm = 1.0/normpdf(0.0, u_bSigma);
  vec3 bfinal_color = vec3(0.0);
  float bZ = 0.0;

  for (int i = -kSize; i <= kSize; ++i) {
    for (int j = -kSize; j <= kSize; ++j) {
      // Colour of the neighbour at offset (i, j) texels.
      vec2 coord = v_texCoord.xy + vec2(float(i), float(j))*u_texelSize.xy;
      vec3 cc = texture(uSourceSampler, coord).rgb;

      // Bilateral weight = spatial Gaussian * colour-similarity Gaussian.
      float gfactor = kernel[kSize+j]*kernel[kSize+i];
      float bfactor = normpdf3(cc-c.rgb, u_bSigma)*bZnorm*gfactor;
      bZ += bfactor;
      bfinal_color += bfactor*cc;
    }
  }

  // The centre sample (colour difference 0) normally contributes a positive weight, so bZ > 0.
  vec4 bc = vec4(bfinal_color/bZ, 1.0);
  fragColor = mix(c, bc, a);
}
  `;

  const { width: outputWidth, height: outputHeight } = gl.canvas;
  const texelWidth = 1 / outputWidth;
  const texelHeight = 1 / outputHeight;

  const vertexShader = compileShader(gl, gl.VERTEX_SHADER, vertexShaderSource);
  const fragmentShader = compileShader(gl, gl.FRAGMENT_SHADER, fragmentShaderSource);
  const program = createProgram(gl, vertexShader, fragmentShader);

  // Uniform names must match the declarations in fragmentShaderSource exactly. A misspelt name
  // yields a null location, and setting a null location is silently ignored by WebGL.
  const sourceLocation = gl.getUniformLocation(program, "uSourceSampler");
  const maskLocation = gl.getUniformLocation(program, "u_mask");
  const texelSizeLocation = gl.getUniformLocation(program, "u_texelSize");
  const sigmaLocation = gl.getUniformLocation(program, "u_sigma");
  const bSigmaLocation = gl.getUniformLocation(program, "u_bSigma");
  const positionAttributeLocation = gl.getAttribLocation(program, "a_position");
  const texCoordAttributeLocation = gl.getAttribLocation(program, "a_texCoord");

  // The VAO records both attribute bindings, so render() only needs to bind it.
  const vao = gl.createVertexArray();
  gl.bindVertexArray(vao);

  gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
  gl.vertexAttribPointer(positionAttributeLocation, 2, gl.FLOAT, false, 0, 0);
  gl.enableVertexAttribArray(positionAttributeLocation);

  gl.bindBuffer(gl.ARRAY_BUFFER, texCoordBuffer);
  gl.vertexAttribPointer(texCoordAttributeLocation, 2, gl.FLOAT, false, 0, 0);
  gl.enableVertexAttribArray(texCoordAttributeLocation);

  gl.bindVertexArray(null);

  const frameBuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, outputTexture, 0);

  // Uniforms that never change for this stage: sampler units and texel size.
  gl.useProgram(program);
  gl.uniform1i(sourceLocation, 0);
  gl.uniform1i(maskLocation, 1);
  gl.uniform2f(texelSizeLocation, texelWidth, texelHeight);

  /** Draws the filtered image into `outputTexture`, clearing it to transparent black first. */
  const render = () => {
    gl.useProgram(program);
    gl.bindVertexArray(vao);
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.viewport(0, 0, outputWidth, outputHeight);
    gl.clearColor(0.0, 0.0, 0.0, 0.0);
    gl.clear(gl.COLOR_BUFFER_BIT);
    gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
  };

  /** Sets the spatial sigma (in texels). A value of 0 disables the filter. */
  const updateSigma = (newSigma: number) => {
    gl.useProgram(program);
    gl.uniform1f(sigmaLocation, newSigma);
  };

  /** Sets the range (colour-difference) sigma. A value of 0 disables the filter. */
  const updateBSigma = (newBSigma: number) => {
    gl.useProgram(program);
    gl.uniform1f(bSigmaLocation, newBSigma);
  };

  /**
   * Derives both sigmas from a single strength value: spatial = strength / 2, and
   * range = 0.125 + strength / 40. A strength of 0 therefore disables the filter.
   */
  const updateFilterStrength = (filterStrength: number) => {
    updateSigma(filterStrength / 2);
    updateBSigma(0.125 + (filterStrength / 10) * 0.25);
  };

  /**
   * Deletes every GL object created by this stage. The caller-owned buffers and
   * `outputTexture` are left intact.
   */
  const cleanUp = () => {
    gl.deleteFramebuffer(frameBuffer);
    gl.deleteVertexArray(vao);
    gl.deleteProgram(program);
    gl.deleteShader(vertexShader);
    gl.deleteShader(fragmentShader);
  };

  return { render, updateSigma, updateBSigma, updateFilterStrength, cleanUp };
};
