import { Dimensions } from "../../../../shared/geometry/Dimensions.js";
import { TFLite } from "../../../assets/tflite/tflite-simd.js";
import { blankImageBitmap } from "../../../helpers/imagesWorker.js";
import { BackgroundType, IVideoPipeline } from "../interfaces/IVideoPipeline.js";
import { WebGLFactory } from "../interfaces/IVideoPipelineContainer.js";
import { WebGLContextLostError } from "../interfaces/VideoPipelineError.js";
import { buildRenderMaskStage } from "../stages/renderMaskStage.js";
import { compileShader, createTexture, glsl } from "../webglHelper.js";
import chooseTFLite from "./chooseTFLite.js";
import { SegmentationConfig, inputResolutions } from "./segmentationHelper.js";
import { buildBackgroundBlurStage } from "./stages/backgroundBlurStage.js";
import { buildBackgroundImageStage } from "./stages/backgroundImageStage.js";
import { buildJointBilateralFilterStage } from "./stages/jointBilateralFilterStage.js";
import { buildLoadSegmentationStage } from "./stages/loadSegmentationStage.js";
import { buildResizingStage } from "./stages/resizingStage.js";
import { buildSoftmaxStage } from "./stages/softmaxStage.js";

/**
 * Segmentation settings used by every entry point in this module: the same
 * object is handed to `chooseTFLite` and to `buildTFLitePipeline`.
 */
const segmentationConfig: SegmentationConfig = {
  // "meet" makes buildTFLitePipeline pick the softmax stage to load the mask.
  model: "meet",
  backend: "wasmSimd", // TODO: check whether SIMD is supported before using this backend
  // Key into `inputResolutions`; determines the segmentation texture size.
  inputResolution: "256x144",
  pipeline: "webgl2",
};

/**
 * Version number of this pipeline implementation. It is not referenced in the
 * visible part of this module — presumably read by importers (TODO confirm).
 */
export const tfLitePipelineVersion = 1;

/**
 * WebGL factory that produces the TFLite-backed virtual background pipeline.
 *
 * Resolves the TFLite module for the module-level segmentation config and, if
 * one is available, builds the pipeline on top of the supplied GL context.
 *
 * @param sourcePlayback - Dimensions of the source frames.
 * @param gl - WebGL2 context the pipeline renders into.
 * @returns The assembled video pipeline.
 * @throws Error when `chooseTFLite` yields no usable result.
 */
export const setupTFLitePipeline: WebGLFactory = async (
  sourcePlayback: Dimensions,
  gl: WebGL2RenderingContext
): Promise<IVideoPipeline> => {
  const loaded = await chooseTFLite(segmentationConfig);

  if (loaded) {
    return buildTFLitePipeline(sourcePlayback, segmentationConfig, gl, loaded.tflite);
  }

  throw new Error("Unable to setup Virtual Background pipeline");
};

/**
 * Runs `chooseTFLite` for the module-level segmentation config and discards
 * the result. Rejections from `chooseTFLite` propagate to the caller.
 *
 * NOTE(review): the name implies `chooseTFLite` caches what it loads so a
 * later `setupTFLitePipeline` call is faster — confirm in chooseTFLite.
 */
export const preloadTFLiteModel = async (): Promise<void> => {
  await chooseTFLite(segmentationConfig);
};

/**
 * Assembles the WebGL2 virtual-background pipeline around a loaded TFLite
 * module.
 *
 * Setup allocates the shared vertex shader, a vertex array, the position and
 * texture-coordinate buffers, and three textures (input frame, segmentation
 * output at the model's input resolution, person mask at frame resolution),
 * then builds the individual render stages on top of them.
 *
 * Per frame, `renderInternal` uploads the source, runs the resizing stage,
 * runs TFLite inference, loads the segmentation result, applies the joint
 * bilateral filter and finally draws either the blurred or the image
 * background. A debug stage can short-circuit the frame to show an
 * intermediate mask instead.
 *
 * @param sourceDimensions - Size of the incoming frames; used for the person mask texture.
 * @param segmentationConfig - Model and input-resolution selection for the stages.
 * @param gl - WebGL2 context that owns every resource created here.
 * @param tflite - Loaded TFLite module used for inference.
 * @returns The pipeline handlers (render, background setters, debug, clean-up).
 * @throws Error when a buffer or one of the segmentation/person-mask textures cannot be created.
 */
const buildTFLitePipeline = (
  sourceDimensions: Dimensions,
  segmentationConfig: SegmentationConfig,
  gl: WebGL2RenderingContext,
  tflite: TFLite
): IVideoPipeline => {
  // Pass-through vertex shader: forwards position and texture coordinates.
  const vertexShaderSource = glsl`#version 300 es

    in vec2 a_position;
    in vec2 a_texCoord;

    out vec2 v_texCoord;

    void main() {
      gl_Position = vec4(a_position, 0.0, 1.0);
      v_texCoord = a_texCoord;
    }
  `;

  const { width: frameWidth, height: frameHeight } = sourceDimensions;
  const [segmentationWidth, segmentationHeight] =
    inputResolutions[segmentationConfig.inputResolution];

  const vertexShader = compileShader(gl, gl.VERTEX_SHADER, vertexShaderSource);

  const vertexArray = gl.createVertexArray();
  gl.bindVertexArray(vertexArray);

  // Four corners of clip space.
  const positionBuffer = gl.createBuffer();
  if (!positionBuffer) {
    throw new Error("unable to create positionBuffer");
  }
  gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
  gl.bufferData(
    gl.ARRAY_BUFFER,
    new Float32Array([-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
    gl.STATIC_DRAW
  );

  // Matching texture coordinates for the four corners above.
  const texCoordBuffer = gl.createBuffer();
  if (!texCoordBuffer) {
    // Report the buffer that actually failed (previously said "positionBuffer").
    throw new Error("unable to create texCoordBuffer");
  }
  gl.bindBuffer(gl.ARRAY_BUFFER, texCoordBuffer);
  gl.bufferData(
    gl.ARRAY_BUFFER,
    new Float32Array([0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]),
    gl.STATIC_DRAW
  );

  // We don't use texStorage2D here because texImage2D seems faster
  // to upload video texture than texSubImage2D even though the latter
  // is supposed to be the recommended way:
  // https://developer.mozilla.org/en-US/docs/Web/API/WebGL_API/WebGL_best_practices#use_texstorage_to_create_textures
  const inputFrameTexture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);

  // Segmentation output lives at the model's input resolution; the refined
  // person mask is stored at full frame resolution.
  const segmentationTexture = createTexture(gl, gl.RGBA8, segmentationWidth, segmentationHeight);
  if (!segmentationTexture) {
    throw new Error("unable to create segmentationTexture");
  }
  const personMaskTexture = createTexture(gl, gl.RGBA8, frameWidth, frameHeight);
  if (!personMaskTexture) {
    throw new Error("unable to create personMaskTexture");
  }

  const resizingStage = buildResizingStage(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer,
    segmentationConfig,
    tflite
  );
  // The "meet" model goes through the softmax stage; any other model uses the
  // plain load-segmentation stage. Both write into segmentationTexture.
  const loadSegmentationStage =
    segmentationConfig.model === "meet"
      ? buildSoftmaxStage(
          gl,
          vertexShader,
          positionBuffer,
          texCoordBuffer,
          segmentationConfig,
          tflite,
          segmentationTexture
        )
      : buildLoadSegmentationStage(
          gl,
          vertexShader,
          positionBuffer,
          texCoordBuffer,
          segmentationConfig,
          tflite,
          segmentationTexture
        );
  const jointBilateralFilterStage = buildJointBilateralFilterStage(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer,
    segmentationTexture,
    segmentationConfig,
    personMaskTexture
  );

  // Only used by the debug paths in render() to display an intermediate mask.
  const renderMaskStage = buildRenderMaskStage(gl, positionBuffer, texCoordBuffer, "a");

  const backgroundBlurStage = buildBackgroundBlurStage(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer,
    personMaskTexture
  );
  const backgroundImageStage = buildBackgroundImageStage(
    gl,
    positionBuffer,
    texCoordBuffer,
    personMaskTexture
  );

  // Mutable pipeline state, changed through the setters returned below.
  let backgroundType: BackgroundType = "blurred";
  const blankBackground = blankImageBitmap(400, 300);
  let staticBackground = blankBackground;
  let debugStage: number | undefined;

  /**
   * Renders one frame from `source`.
   *
   * @throws WebGLContextLostError when the GL context has been lost.
   */
  const render = async (source: TexImageSource) => {
    if (gl.isContextLost()) {
      throw new WebGLContextLostError();
    }

    gl.clearColor(0, 0, 0, 0);
    gl.clear(gl.COLOR_BUFFER_BIT);

    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);

    // texImage2D seems faster than texSubImage2D to upload
    // video texture
    gl.texImage2D(
      gl.TEXTURE_2D,
      0,
      gl.RGBA,
      gl.RGBA,
      gl.UNSIGNED_BYTE,
      source as TexImageSourceWebCodecs
    );

    // Re-bind the person mask on unit 10 every frame: the debug path for
    // stage 1 below binds segmentationTexture to the same unit.
    gl.activeTexture(gl.TEXTURE10);
    gl.bindTexture(gl.TEXTURE_2D, personMaskTexture);

    gl.bindVertexArray(vertexArray);

    await resizingStage.render();

    // Time the inference call with User Timing marks.
    // NOTE(review): these marks/measures are created on every frame and are
    // not cleared in this function — confirm something else clears them.
    performance.mark("VideoPipelineSegment-start");
    // eslint-disable-next-line no-underscore-dangle
    tflite._runInference();

    performance.mark("VideoPipelineSegment-end");
    performance.measure(
      "VideoPipelineSegment",
      "VideoPipelineSegment-start",
      "VideoPipelineSegment-end"
    );

    loadSegmentationStage.render();
    // Debug stage 1: show the raw segmentation texture and stop.
    if (debugStage === 1) {
      gl.activeTexture(gl.TEXTURE10);
      gl.bindTexture(gl.TEXTURE_2D, segmentationTexture);
      renderMaskStage.render(10);
      return;
    }

    jointBilateralFilterStage.render();

    // Debug stage 2: show the filtered person mask and stop.
    if (debugStage === 2) {
      gl.activeTexture(gl.TEXTURE10);
      gl.bindTexture(gl.TEXTURE_2D, personMaskTexture);
      renderMaskStage.render(10);
      return;
    }

    if (backgroundType === "blurred") {
      backgroundBlurStage.render();
    } else {
      // Any non-"static_image" type falls back to the blank background.
      backgroundImageStage.render(
        backgroundType === "static_image" ? staticBackground : blankBackground
      );
    }
  };

  /** Replaces the bitmap drawn when the background type is "static_image". */
  const setStaticBackground = (staticBg: ImageBitmap) => {
    staticBackground = staticBg;
  };

  /** Selects which background the final stage of render() draws. */
  const setBackgroundType = (bgType: BackgroundType) => {
    backgroundType = bgType;
  };

  /** Forwards the blur count to the background blur stage. */
  const updateBlurCount = (blurCount: number) => {
    backgroundBlurStage.updateBlurCount(blurCount);
  };

  /** Forwards the foreground overlays to the background image stage. */
  const setForegroundOverlays = (stages: ImageBitmap[]) => {
    backgroundImageStage.setForegroundOverlays(stages);
  };

  /** Intentionally a no-op: this pipeline does nothing with the strength. */
  const updateTouchUpStrength = (strength: number) => {};

  /** Sets the debug stage (1 or 2) checked by render(); undefined disables it. */
  const setDebugStage = (stage?: number) => {
    debugStage = stage;
  };

  /**
   * Releases the stages and every GL resource created during setup.
   *
   * NOTE(review): renderMaskStage is not cleaned up here — confirm whether it
   * owns GL resources that need releasing.
   */
  const cleanUp = () => {
    backgroundBlurStage.cleanUp();
    backgroundImageStage.cleanUp();
    jointBilateralFilterStage.cleanUp();
    loadSegmentationStage.cleanUp();
    resizingStage.cleanUp();

    gl.deleteTexture(personMaskTexture);
    gl.deleteTexture(segmentationTexture);
    gl.deleteTexture(inputFrameTexture);
    gl.deleteBuffer(texCoordBuffer);
    gl.deleteBuffer(positionBuffer);
    gl.deleteVertexArray(vertexArray);
    gl.deleteShader(vertexShader);
  };

  return {
    renderInternal: render,
    setStaticBackground,
    setBackgroundType,
    updateBlurCount,
    setForegroundOverlays,
    cleanUp,
    updateTouchUpStrength,
    setDebugStage,
  };
};
