// Segmentation.js
// This processor refines the raw segmentation mask from the model by
// applying soft thresholding, temporal smoothing, and a lightweight blur filter
// to produce smoother edges similar to production applications like Google Meet.

class SegmentationProcessor {
    /**
     * Creates a processor that owns its off-screen working canvases.
     *
     * Every option is optional; omitting them reproduces the original tuning.
     *
     * @param {Object} [options]
     * @param {number} [options.threshold=0.2] Mask value at the centre of the soft ramp.
     * @param {number} [options.softness=0.07] Half-width of the soft ramp around
     *     `threshold`. Values <= 0 produce a hard cutoff.
     * @param {number} [options.temporalWeight=0.35] Opacity (0..1) used when layering
     *     the previous frame's mask over the current one; 0 disables the effect.
     * @param {number} [options.blurRadius=2] Blur radius, in processing-canvas pixels,
     *     of the edge-softening pass.
     */
    constructor({
        threshold = 0.2,
        softness = 0.07,
        temporalWeight = 0.35,
        blurRadius = 2
    } = {}) {
        // Working resolution for mask refinement. The refined mask is scaled to
        // the caller's output size during final composition.
        this.PROCESSING_WIDTH = 512;
        this.PROCESSING_HEIGHT = 288;

        // Soft-threshold parameters (see maskValueToAlpha).
        this.threshold = threshold;
        this.softness = softness;

        // Temporal smoothing: opacity of the previous mask, and the canvas that
        // holds it (created lazily on the first processed frame).
        this.temporalWeight = temporalWeight;
        this.lastMask = null;

        // Radius of the blur applied before composition.
        this.blurRadius = blurRadius;

        // Off-screen canvas into which the raw mask is rasterized.
        this.processingCanvas = this.createWorkingCanvas();
        this.processingCtx = this.processingCanvas.getContext('2d', {
            willReadFrequently: true
        });

        // Reused every frame rather than allocating a fresh canvas per call.
        this.blurredCanvas = this.createWorkingCanvas();
        this.blurredCtx = this.blurredCanvas.getContext('2d');

        // Reusable pixel buffer for rasterizeMask; created on first use.
        this.maskImageData = null;
    }

    /**
     * Creates an off-screen canvas sized to the processing resolution.
     *
     * @returns {HTMLCanvasElement}
     */
    createWorkingCanvas() {
        const canvas = document.createElement('canvas');
        canvas.width = this.PROCESSING_WIDTH;
        canvas.height = this.PROCESSING_HEIGHT;
        return canvas;
    }

    /**
     * Maps a raw mask value to an alpha value using a linear ramp of half-width
     * `softness` centred on `threshold`.
     *
     * @param {number} value Raw mask value.
     * @param {number} threshold Centre of the ramp.
     * @param {number} softness Half-width of the ramp. When <= 0 the ramp
     *     collapses to a hard cutoff: 255 only when value > threshold.
     * @returns {number} Alpha in [0, 255]. It may be fractional; storing it into a
     *     Uint8ClampedArray rounds it, and NaN is stored as 0.
     */
    static maskValueToAlpha(value, threshold, softness) {
        if (softness <= 0) {
            // The ramp below would divide by zero; fall back to a binary cutoff.
            return value > threshold ? 255 : 0;
        }
        const low = threshold - softness;
        const high = threshold + softness;
        if (value < low) return 0;
        if (value > high) return 255;
        return 255 * ((value - low) / (2 * softness));
    }

    /**
     * Rasterizes the raw mask into the processing canvas. Each pixel is white,
     * with alpha taken from maskValueToAlpha, using nearest-neighbour resampling.
     *
     * @param {{width: number, height: number, getAsFloat32Array: function(): Float32Array}} categoryMask
     *     Raw model mask. Its values are assumed to be on the same scale as
     *     `threshold` — NOTE(review): confirm against the segmenter in use.
     */
    rasterizeMask(categoryMask) {
        const outWidth = this.PROCESSING_WIDTH;
        const outHeight = this.PROCESSING_HEIGHT;

        if (!this.maskImageData) {
            this.maskImageData = this.processingCtx.createImageData(outWidth, outHeight);
            // RGB is always white; only the alpha channel is rewritten per frame.
            this.maskImageData.data.fill(255);
        }

        const data = this.maskImageData.data;
        const values = categoryMask.getAsFloat32Array();
        const maskWidth = categoryMask.width;
        const maskHeight = categoryMask.height;
        const { threshold, softness } = this;

        for (let y = 0; y < outHeight; y++) {
            // Integer arithmetic keeps the source-pixel mapping exact. Dividing by
            // a floating-point scale factor may round just below an integer.
            const rowOffset = Math.floor((y * maskHeight) / outHeight) * maskWidth;
            // Index of the alpha byte of the first pixel in this row.
            let alphaIndex = y * outWidth * 4 + 3;
            for (let x = 0; x < outWidth; x++, alphaIndex += 4) {
                const maskX = Math.floor((x * maskWidth) / outWidth);
                data[alphaIndex] = SegmentationProcessor.maskValueToAlpha(
                    values[rowOffset + maskX],
                    threshold,
                    softness
                );
            }
        }

        this.processingCtx.putImageData(this.maskImageData, 0, 0);
    }

    /**
     * Layers the previous frame's mask over the freshly rasterized one, then
     * saves the result for the next frame.
     *
     * Source-over compositing gives
     * alpha_out = w*a_prev + a_cur*(1 - w*a_prev), where w = temporalWeight.
     * This never lowers the current alpha. Pixels that were opaque before but
     * are transparent now fade by a factor of w per frame.
     */
    applyTemporalSmoothing() {
        if (this.lastMask) {
            this.processingCtx.globalAlpha = this.temporalWeight;
            this.processingCtx.drawImage(this.lastMask, 0, 0);
            this.processingCtx.globalAlpha = 1.0;
        } else {
            this.lastMask = this.createWorkingCanvas();
        }

        const lastMaskCtx = this.lastMask.getContext('2d');
        if (lastMaskCtx) {
            lastMaskCtx.clearRect(0, 0, this.PROCESSING_WIDTH, this.PROCESSING_HEIGHT);
            lastMaskCtx.drawImage(this.processingCanvas, 0, 0);
        }
    }

    /**
     * Draws a blurred copy of the processing canvas into the reusable blur canvas.
     *
     * @returns {HTMLCanvasElement} The blurred mask canvas.
     */
    blurMask() {
        const blurredCtx = this.blurredCtx;
        // The canvas is reused across frames, so the previous result is cleared first.
        blurredCtx.clearRect(0, 0, this.PROCESSING_WIDTH, this.PROCESSING_HEIGHT);
        blurredCtx.filter = `blur(${this.blurRadius}px)`;
        blurredCtx.drawImage(this.processingCanvas, 0, 0);
        blurredCtx.filter = 'none';
        return this.blurredCanvas;
    }

    /**
     * Draws the current video frame into `ctx`, then erases it wherever the
     * refined segmentation mask is opaque.
     *
     * @param {{categoryMask?: Object}} results Segmentation output. Nothing is
     *     drawn when `categoryMask` is missing.
     * @param {CanvasRenderingContext2D} ctx Destination context. Its state is
     *     saved on entry and restored on exit, even if processing throws.
     * @param {HTMLVideoElement} video Source frame. It is drawn only once it has
     *     data for the current position (readyState >= 2).
     * @param {number} width Output width in canvas pixels.
     * @param {number} height Output height in canvas pixels.
     */
    processSegmentation(results, ctx, video, width, height) {
        if (!results?.categoryMask || !ctx) return;

        ctx.save();
        try {
            ctx.clearRect(0, 0, width, height);

            // Base layer: the current video frame.
            if (video?.readyState >= 2) {
                ctx.drawImage(video, 0, 0, width, height);
            }

            this.rasterizeMask(results.categoryMask);
            this.applyTemporalSmoothing();
            const refinedMask = this.blurMask();

            // destination-out removes the video in proportion to the mask's alpha,
            // i.e. where mask values exceed the threshold. Which segment that is
            // (person vs. background) depends on the model's category mapping —
            // NOTE(review): confirm.
            ctx.globalCompositeOperation = 'destination-out';
            ctx.drawImage(
                refinedMask,
                0,
                0,
                this.PROCESSING_WIDTH,
                this.PROCESSING_HEIGHT,
                0,
                0,
                width,
                height
            );
        } finally {
            // Always undo save() so the caller's context state stack stays balanced.
            ctx.restore();
        }
    }
}

// Expose the processor class as this module's default export.
export default SegmentationProcessor;
