/* cspell:ignore multiclass */
import { ImageSegmenter, ImageSegmenterResult, MPMask } from "@mediapipe/tasks-vision";
import loader from "@mediapipe/tasks-vision/simd-loader.js";
import model from "@mediapipe/tasks-vision/simd-model.wasm";
import { Dimensions } from "../../../../shared/geometry/Dimensions.js";
import { logger } from "../../../../shared/infra/logger.js";
import selfie_multiclass from "../../../assets/facemesh/selfie_multiclass_256x256.tflite";
import selfie_landscape from "../../../assets/facemesh/selfie_segmenter_landscape.tflite";
import { BackgroundType, IVideoPipeline } from "../interfaces/IVideoPipeline.js";
import { PipelineSetupError, WebGLContextLostError } from "../interfaces/VideoPipelineError.js";
import { PipelineStatsRawDelta } from "../pipelineStats.js";
import { buildBilateralFilterStage } from "../stages/bilateralFilterStage.js";
import { buildBlurTextureStage } from "../stages/blurTextureStage.js";
import { buildRenderMaskStage } from "../stages/renderMaskStage.js";
import { buildRenderTextureStage } from "../stages/renderTextureStage.js";
import { compileShader, createTexture, glsl } from "../webglHelper.js";
import { FaceMeshWorkerProxy } from "./faceMeshWorker/FaceMeshWorkerProxy.js";
import { IFaceMeshWorkerProxy } from "./faceMeshWorker/IFaceMeshWorkerProxy.js";
import { buildAddForeheadStage } from "./stages/addForeheadStage.js";
import { buildBackgroundBlurStageMP } from "./stages/backgroundBlurStageMP.js";
import { buildBackgroundImageStageMP } from "./stages/backgroundImageStageMP.js";
import { buildJointBilateralFilterMPStage } from "./stages/jointBilateralFilterMPStage.js";
import { buildRenderFacialFeaturesStage } from "./stages/renderFacialFeaturesStage.js";

/**
 * A frame that may additionally carry plain `width`/`height` properties.
 *
 * VideoFrame itself exposes coded/display sizes rather than `width`/`height`; the pipeline's
 * render step copies `codedWidth`/`codedHeight` into these fields, presumably for consumers
 * (such as the segmenter) that read `width`/`height`.
 */
type ExtendedVideoFrame = VideoFrame & {
  width?: number;
  height?: number;
};

/**
 * WASM fileset handed to `ImageSegmenter.createFromOptions`: the SIMD loader script and binary
 * from `@mediapipe/tasks-vision`, imported at the top of this module (presumably resolved to asset
 * URLs by the bundler — TODO confirm).
 */
const visionWasmFileset = { wasmLoaderPath: loader, wasmBinaryPath: model };

/**
 * Version number of this pipeline implementation. NOTE(review): presumably bumped when the
 * pipeline's behaviour changes so consumers can distinguish implementations — confirm.
 */
export const mediaPipePipelineVersion = 1;

/**
 * Creates the MediaPipe-based video pipeline rendering into `gl`.
 *
 * Starts the FaceMesh worker (when `faceMesh` is set), creates the selfie ImageSegmenter on the
 * GPU delegate using `gl`'s canvas, and assembles the WebGL stages around them.
 *
 * @param sourcePlayback - Dimensions of the incoming frames; intermediate textures use this size.
 * @param gl - WebGL2 context to render into; its canvas is also handed to the segmenter.
 * @param goodSegmentation - Use the multiclass 256x256 model instead of the landscape selfie model.
 * @param faceMesh - Start a FaceMesh worker so touch-up can be applied.
 * @returns The assembled pipeline, which owns the worker and segmenter from then on.
 * @throws Whatever worker start-up, segmenter creation or pipeline construction throws. Any worker
 *   or segmenter already created is released before the error is rethrown.
 */
export const setupMediaPipePipeline = async (
  sourcePlayback: Dimensions,
  gl: WebGL2RenderingContext,
  goodSegmentation: boolean,
  faceMesh: boolean
): Promise<IVideoPipeline> => {
  logger.info(
    `setting up MediaPipePipeline: goodSegmentation: ${goodSegmentation}, faceMesh: ${faceMesh}`
  );

  let faceMeshWorkerProxy: FaceMeshWorkerProxy | undefined;
  let imageSegmenter: ImageSegmenter | undefined;

  try {
    if (faceMesh) {
      faceMeshWorkerProxy = new FaceMeshWorkerProxy();
      await faceMeshWorkerProxy.waitForWorker();
    }

    imageSegmenter = await ImageSegmenter.createFromOptions(visionWasmFileset, {
      baseOptions: {
        modelAssetPath: goodSegmentation ? selfie_multiclass : selfie_landscape,
        delegate: "GPU",
      },
      runningMode: "VIDEO",
      outputCategoryMask: false,
      outputConfidenceMasks: true,
      canvas: gl.canvas,
    });

    return buildMediaPipePipeline(
      sourcePlayback,
      gl,
      faceMeshWorkerProxy,
      imageSegmenter,
      goodSegmentation,
      faceMesh
    );
  } catch (e: unknown) {
    // Nothing owns these yet, so release them here rather than leaking a worker thread or the
    // segmenter's resources on a failed setup.
    faceMeshWorkerProxy?.terminate();
    imageSegmenter?.close();
    throw e;
  }
};

/**
 * Builds the WebGL part of the MediaPipe pipeline and returns its control surface.
 *
 * Compiles a shared pass-through vertex shader, creates the full-screen quad buffers and the
 * intermediate textures (all sized to `sourceDimensions`), constructs the render stages, and
 * returns closures that render a frame, change background/touch-up settings, adapt segmentation
 * frame skipping from stats feedback, and release resources.
 *
 * @param sourceDimensions - Size of the incoming frames; intermediate textures use this size.
 * @param gl - WebGL2 context every stage draws into.
 * @param faceMeshWorkerProxy - FaceMesh worker supplying landmarks for touch-up, or `undefined`
 *   when FaceMesh is disabled. Terminated in `cleanUp`.
 * @param imageSegmenter - Segmenter producing confidence masks. Mask 0 is used as the background
 *   mask and, with `goodSegmentation`, mask 3 as the face-skin mask. Closed in `cleanUp`.
 * @param goodSegmentation - Whether the multiclass model is loaded (face-skin mask expected and the
 *   forehead stage used).
 * @param faceMesh - Whether touch-up is available; it is applied only while strength > 0.
 * @returns The assembled pipeline.
 * @throws Error or PipelineSetupError if a WebGL buffer or texture cannot be created (the stage
 *   builders may throw as well).
 */
const buildMediaPipePipeline = (
  sourceDimensions: Dimensions,
  gl: WebGL2RenderingContext,
  faceMeshWorkerProxy: IFaceMeshWorkerProxy | undefined,
  imageSegmenter: ImageSegmenter,
  goodSegmentation: boolean,
  faceMesh: boolean
): IVideoPipeline => {
  // Pass-through vertex shader handed to the stages that take an externally compiled shader.
  const vertexShaderSource = glsl`#version 300 es

    in vec2 a_position;
    in vec2 a_texCoord;

    out vec2 v_texCoord;

    void main() {
      gl_Position = vec4(a_position, 0.0, 1.0);
      v_texCoord = a_texCoord;
    }
  `;

  const { width: frameWidth, height: frameHeight } = sourceDimensions;

  // Settings mutated through the setters returned at the bottom.
  let blurCount = 0;
  let filterStrength = 0;
  let backgroundType: BackgroundType = "blurred";
  // When set (1-5), render() stops early and draws the intermediate mask for that stage instead.
  let debugStage: number | undefined;

  const vertexShader = compileShader(gl, gl.VERTEX_SHADER, vertexShaderSource);

  const vertexArray = gl.createVertexArray();
  gl.bindVertexArray(vertexArray);

  // Corners of a full-screen quad in clip space.
  const positionBuffer = gl.createBuffer();
  if (!positionBuffer) {
    throw new Error("unable to create positionBuffer");
  }
  gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
  gl.bufferData(
    gl.ARRAY_BUFFER,
    new Float32Array([-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
    gl.STATIC_DRAW
  );

  // Texture coordinates for the same four corners.
  const texCoordBuffer = gl.createBuffer();
  if (!texCoordBuffer) {
    throw new Error("unable to create texCoordBuffer");
  }
  gl.bindBuffer(gl.ARRAY_BUFFER, texCoordBuffer);
  gl.bufferData(
    gl.ARRAY_BUFFER,
    new Float32Array([0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]),
    gl.STATIC_DRAW
  );

  gl.bindVertexArray(null);

  // Receives every incoming frame via texImage2D in render().
  const inputFrameTexture = gl.createTexture();
  if (!inputFrameTexture) {
    throw new Error("error creating input frame texture");
  }
  gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);

  // Intermediate textures, each passed to (and presumably written by) the stage noted.
  // renderFacialFeaturesStage:
  const featureMaskTexture = createTexture(gl, gl.R8, frameWidth, frameHeight);
  // addForeheadStage:
  const combinedMaskTexture = createTexture(gl, gl.R8, frameWidth, frameHeight);
  // blurMaskStage:
  const blurredFeatureMaskTexture = createTexture(gl, gl.R8, frameWidth, frameHeight);
  // bilateralFilterStage:
  const inputWithTouchupTexture = createTexture(gl, gl.RGBA8, frameWidth, frameHeight);
  // jointBilateralFilterStage:
  const sharpenedSelfieMaskTexture = createTexture(gl, gl.R8, frameWidth, frameHeight);

  // FaceMesh landmark vertices, refreshed from the worker's latest result in render().
  const landmarkBuffer = gl.createBuffer();
  if (!landmarkBuffer) {
    throw new PipelineSetupError("unable to create landmarkBuffer");
  }

  const renderFacialFeaturesStage = buildRenderFacialFeaturesStage(
    gl,
    landmarkBuffer,
    featureMaskTexture
  );
  const addForeheadStage = buildAddForeheadStage(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer,
    combinedMaskTexture
  );
  const renderMaskStage = buildRenderMaskStage(gl, positionBuffer, texCoordBuffer, "r");
  const renderRawMaskStage = buildRenderMaskStage(gl, positionBuffer, texCoordBuffer, "r", true);
  const blurMaskStage = buildBlurTextureStage(
    gl,
    positionBuffer,
    texCoordBuffer,
    0,
    blurredFeatureMaskTexture
  );
  const bilateralFilterStage = buildBilateralFilterStage(
    gl,
    positionBuffer,
    texCoordBuffer,
    inputWithTouchupTexture
  );
  const renderTextureStage = buildRenderTextureStage(gl, positionBuffer, texCoordBuffer);
  const backgroundBlurStage = buildBackgroundBlurStageMP(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer
  );
  const jointBilateralFilterStage = buildJointBilateralFilterMPStage(
    gl,
    vertexShader,
    positionBuffer,
    texCoordBuffer,
    // NOTE(review): presumably the segmentation mask resolution — confirm.
    { width: 256, height: 256 },
    sharpenedSelfieMaskTexture,
    goodSegmentation
  );
  const backgroundImageStage = buildBackgroundImageStageMP(gl, positionBuffer, texCoordBuffer);

  // Segmentation runs on every (selfieMaskSkipFrames + 1)-th frame, and on any frame arriving at
  // least maxSelfieTimeSkipped ms after the last segmented one. statsFeedback toggles
  // selfieMaskSkipFrames between 0 and 1.
  let selfieMaskCountdown = 0;
  let selfieMaskSkipFrames = 1;
  let lastSelfieTime = 0;
  const maxSelfieTimeSkipped = 50;

  // Same scheme for sending frames to the FaceMesh worker.
  let faceMeshFrameCountdown = 0;
  const faceMeshSkipFrames = 4;
  let lastFaceMeshTime = 0;
  const maxFaceMeshTimeSkipped = 300;

  // Latest mask textures from the segmenter; reused on frames where segmentation is skipped.
  // NOTE(review): they are obtained via MPMask.getAsWebGLTexture() and used after the mask is
  // closed, which relies on MediaPipe keeping the texture alive — confirm.
  let backgroundTexture: WebGLTexture | null | undefined;
  let faceSkinTexture: WebGLTexture | null | undefined;

  /**
   * Renders one frame through the currently enabled stages.
   *
   * @throws WebGLContextLostError if the context is lost; an Error if the segmenter returns no
   *   background mask (or no face-skin mask with `goodSegmentation`); PipelineSetupError if the
   *   background must be masked but no mask texture exists yet.
   */
  const render = async (src: ExtendedVideoFrame) => {
    // If we run the MediaPipe multiclass segmentation model with a lost context, the context never
    // gets restored. So bail out instead of rendering; we'll eventually be destroyed and
    // recreated.
    if (gl.isContextLost()) {
      throw new WebGLContextLostError();
    }

    const includeTouchup = faceMesh && filterStrength > 0;
    const maskBackground =
      (backgroundType === "blurred" && blurCount > 0) || backgroundType === "static_image";

    if ("codedWidth" in src) {
      // VideoFrame: mirror its coded size into width/height, presumably for consumers that read
      // those (the segmenter below receives the frame cast to ImageData).
      src.width = src.codedWidth;
      src.height = src.codedHeight;
    }

    // Using the timestamp from the VideoFrame for the inference doesn't seem to work, but it also
    // doesn't seem to be a problem to use the current time
    const timestamp = performance.now();

    if (includeTouchup && faceMeshWorkerProxy) {
      // Send one of every (faceMeshSkipFrames + 1) frames, or any frame arriving at least
      // maxFaceMeshTimeSkipped ms after the last one sent, to the FaceMesh worker for inference.
      // (Then we move on and separately use the latest available result when we need it, which
      // is probably not for the frame we just sent.)
      if (faceMeshFrameCountdown > 0 && timestamp - lastFaceMeshTime < maxFaceMeshTimeSkipped) {
        faceMeshFrameCountdown--;
      } else {
        lastFaceMeshTime = timestamp;
        faceMeshFrameCountdown = faceMeshSkipFrames;
        faceMeshWorkerProxy.infer(src.clone());
      }
    }

    let selfieMaskPromise: Promise<[MPMask, MPMask?]> | undefined;

    if (selfieMaskCountdown > 0 && timestamp - lastSelfieTime < maxSelfieTimeSkipped) {
      selfieMaskCountdown--;
    } else {
      lastSelfieTime = timestamp;
      selfieMaskCountdown = selfieMaskSkipFrames;
      // Wrap the callback-based segmentation API; src.width/height were normalised above.
      selfieMaskPromise = new Promise<[MPMask, MPMask?]>((resolve, reject) => {
        performance.mark("VideoPipelineSegment-start");
        imageSegmenter.segmentForVideo(
          src as unknown as ImageData,
          timestamp,
          (res: ImageSegmenterResult) => {
            const backgroundTex = res.confidenceMasks?.[0];
            const faceSkinTex = res.confidenceMasks?.[3];

            performance.mark("VideoPipelineSegment-end");
            performance.measure(
              "VideoPipelineSegment",
              "VideoPipelineSegment-start",
              "VideoPipelineSegment-end"
            );
            if (!backgroundTex) {
              reject(new Error("background texture not found"));
              return;
            }
            if (goodSegmentation && !faceSkinTex) {
              reject(new Error("face skin tex not found"));
              return;
            }
            resolve([backgroundTex, faceSkinTex]);
          }
        );
      });
    }

    let landmarkVertices: Float32Array | undefined;
    if (includeTouchup && faceMeshWorkerProxy) {
      // Latest landmarks the worker has produced, possibly for an earlier frame.
      landmarkVertices = faceMeshWorkerProxy.getLatestResults();

      if (landmarkVertices) {
        gl.bindBuffer(gl.ARRAY_BUFFER, landmarkBuffer);
        gl.bufferData(gl.ARRAY_BUFFER, landmarkVertices, gl.STREAM_DRAW);

        renderFacialFeaturesStage.render();
      }
    }

    // Upload the current frame into inputFrameTexture on texture unit 0.
    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);
    gl.texImage2D(
      gl.TEXTURE_2D,
      0,
      gl.RGBA,
      gl.RGBA,
      gl.UNSIGNED_BYTE,
      src as TexImageSourceWebCodecs
    );

    if (selfieMaskPromise) {
      const [backgroundMask, faceSkinMask] = await selfieMaskPromise;
      backgroundTexture = backgroundMask.getAsWebGLTexture();
      if (includeTouchup && faceSkinMask) {
        faceSkinTexture = faceSkinMask.getAsWebGLTexture();
      }
      backgroundMask.close();
      faceSkinMask?.close();
    }

    // NOTE(review): with goodSegmentation but no landmarks/face-skin texture yet, combinedMaskTexture
    // is not re-rendered this frame and may hold stale or empty contents.
    const finalTouchupMask = goodSegmentation ? combinedMaskTexture : blurredFeatureMaskTexture;
    if (includeTouchup) {
      if (landmarkVertices) {
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, featureMaskTexture);
        if (debugStage === 3) {
          renderMaskStage.render(0);
          return;
        }

        blurMaskStage.render();

        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, blurredFeatureMaskTexture);

        if (goodSegmentation && faceSkinTexture) {
          // NOTE(review): faceSkinTexture is bound to unit 1 only *after* addForeheadStage runs;
          // if that stage samples unit 1 it sees the previous binding — confirm intended order.
          addForeheadStage.render(landmarkVertices);
          gl.activeTexture(gl.TEXTURE1);
          gl.bindTexture(gl.TEXTURE_2D, faceSkinTexture);
          if (debugStage === 4) {
            renderMaskStage.render(1);
            return;
          }
        }

        if (debugStage === 5) {
          gl.activeTexture(gl.TEXTURE0);
          gl.bindTexture(gl.TEXTURE_2D, finalTouchupMask);
          renderMaskStage.render(0);
          return;
        }
      }

      // Smooth the input frame, guided by the touch-up mask on unit 1.
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);
      gl.activeTexture(gl.TEXTURE1);
      gl.bindTexture(gl.TEXTURE_2D, finalTouchupMask);
      bilateralFilterStage.render();
    }
    if (maskBackground) {
      if (!backgroundTexture) {
        throw new PipelineSetupError("selfieMaskTexture undefined");
      }
      // Refine the raw segmentation mask against the input frame.
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, inputFrameTexture);
      gl.activeTexture(gl.TEXTURE1);
      gl.bindTexture(gl.TEXTURE_2D, backgroundTexture);
      if (debugStage === 1) {
        renderRawMaskStage.render(1);
        return;
      }

      jointBilateralFilterStage.render();

      // Composite the (possibly touched-up) frame over the chosen background.
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, includeTouchup ? inputWithTouchupTexture : inputFrameTexture);
      gl.activeTexture(gl.TEXTURE1);
      gl.bindTexture(gl.TEXTURE_2D, sharpenedSelfieMaskTexture);
      if (debugStage === 2) {
        renderMaskStage.render(1);
        return;
      }

      if (backgroundType === "blurred") {
        backgroundBlurStage.render();
      } else if (backgroundType === "static_image") {
        backgroundImageStage.render();
      }
    } else if (includeTouchup) {
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, inputWithTouchupTexture);
      renderTextureStage.render(0);
    }
    // With neither background masking nor touch-up, nothing is drawn here.
  };

  /** Sets the image used when the background type is "static_image". */
  const setStaticBackground = (staticBg: ImageBitmap) => {
    backgroundImageStage.setBackgroundImage(staticBg);
  };

  /** Selects how the background is rendered on subsequent frames. */
  const setBackgroundType = (bgType: BackgroundType) => {
    backgroundType = bgType;
  };

  /** Updates the blur count; 0 disables background masking while in "blurred" mode. */
  const updateBlurCount = (newBlurCount: number) => {
    blurCount = newBlurCount;
    backgroundBlurStage.updateBlurCount(newBlurCount);
  };

  /** Forwards foreground overlay images to the static-background stage. */
  const setForegroundOverlays = (stages: ImageBitmap[]) => {
    backgroundImageStage.setForegroundOverlays(stages);
  };

  /** Updates touch-up strength; 0 disables touch-up (and FaceMesh inference) entirely. */
  const updateTouchUpStrength = (newFilterStrength: number) => {
    filterStrength = newFilterStrength;
    bilateralFilterStage.updateFilterStrength(newFilterStrength);
  };

  /** Sets (or clears) the debug stage rendered in place of the final output. */
  const setDebugStage = (stage?: number) => {
    debugStage = stage;
  };

  let selfieMaskSkipChanges = 0;
  /**
   * Adapts segmentation frame skipping from the latest stats delta.
   *
   * - "blurred": always skip every other frame, and reset the switch counter.
   * - other backgrounds except "input": stop skipping when at least 15 frames were segmented at an
   *   average under 10 (presumably ms) each, and resume skipping when no frames were segmented or
   *   the average exceeds 15. After more than 10 switches, stop adapting and stay at no skipping.
   */
  const statsFeedback = (delta: PipelineStatsRawDelta) => {
    if (backgroundType === "blurred") {
      if (selfieMaskSkipFrames === 0) {
        selfieMaskSkipFrames = 1;
        selfieMaskSkipChanges = 0;
        logger.info("mediaPipePipeline: skipping frames because not using image background");
      }
    } else if (backgroundType !== "input") {
      // Skipping frames of the segmentation can look bad with virtual backgrounds (not really an
      // issue with blur), so stop skipping frames if we can when in virtual background
      if (selfieMaskSkipChanges > 10) {
        if (selfieMaskSkipFrames > 0) {
          // If we've toggled too many times, we're thrashing, so stop
          logger.info("mediaPipePipeline: permanently downgrading since we're thrashing");
          selfieMaskSkipFrames = 0;
        }
      } else {
        if (
          delta.segmentationFrames >= 15 &&
          delta.segmentationTime / delta.segmentationFrames < 10 &&
          selfieMaskSkipFrames > 0
        ) {
          logger.info("mediaPipePipeline: upgrading to not skip any frames");
          selfieMaskSkipFrames = 0;
          selfieMaskSkipChanges++;
        } else if (
          selfieMaskSkipFrames === 0 &&
          (delta.segmentationFrames === 0 || delta.segmentationTime / delta.segmentationFrames > 15)
        ) {
          logger.info("mediaPipePipeline: downgrading to skip frames");
          selfieMaskSkipFrames = 1;
          selfieMaskSkipChanges++;
        }
      }
    }
  };

  /**
   * Releases the worker and segmenter and, when `includeWebGl` is true, the WebGL objects created
   * here (skip that when the context is already gone).
   */
  const cleanUp = (includeWebGl = true) => {
    if (faceMeshWorkerProxy) {
      faceMeshWorkerProxy.terminate();
      faceMeshWorkerProxy = undefined;
    }
    imageSegmenter.close();

    if (includeWebGl) {
      backgroundBlurStage.cleanUp();
      backgroundImageStage.cleanUp();
      jointBilateralFilterStage.cleanUp();
      renderFacialFeaturesStage.cleanUp();
      bilateralFilterStage.cleanUp();
      // NOTE(review): addForeheadStage, blurMaskStage, renderMaskStage, renderRawMaskStage and
      // renderTextureStage are not cleaned up — confirm whether their builders expose cleanUp().

      // Textures and buffers allocated by this function. Deleting an object that a stage already
      // deleted has no effect in WebGL, so this is safe even if a stage released its target.
      const ownedTextures = [
        inputFrameTexture,
        featureMaskTexture,
        combinedMaskTexture,
        blurredFeatureMaskTexture,
        inputWithTouchupTexture,
        sharpenedSelfieMaskTexture,
      ];
      for (const texture of ownedTextures) {
        gl.deleteTexture(texture);
      }
      gl.deleteBuffer(landmarkBuffer);
      gl.deleteBuffer(texCoordBuffer);
      gl.deleteBuffer(positionBuffer);
      gl.deleteVertexArray(vertexArray);
      gl.deleteShader(vertexShader);
    }
  };

  return {
    renderInternal: render,
    setStaticBackground,
    setBackgroundType,
    updateBlurCount,
    updateTouchUpStrength,
    setForegroundOverlays,
    statsFeedback,
    cleanUp,
    setDebugStage,
  };
};
