import * as Cesium from 'cesium';
import { ExtendedCamera, hasDistortionParameters } from '../store/slices/cameras';
import quartic from 'quartic';

/**
 * Outcome of projecting a 3D point onto a camera image.
 *
 * `u` is the horizontal raster coordinate (derived from the image width) and `v` the
 * vertical one (derived from the image height); `null` means the point could not be
 * projected onto the image.
 */
export type PointProjectionResult = null | { u: number; v: number };

/** Arguments shared by the per-model projection functions below. */
interface Params {
    /** Camera whose image model (size, focal length, distortion, ...) drives the projection. */
    camera: ExtendedCamera;
    /** Point already expressed in the camera's local coordinate system. */
    pointInCameraSpace: Cesium.Cartesian3;
}

/**
 * Projects a world-space point into the raster coordinates of a camera image.
 *
 * `camera.transform` maps camera space to world space, so its inverse is used to bring
 * the point into camera space. The point is then handed to the projection model matching
 * `camera.type`. Frame and Fisheye cameras use their full distortion model only when
 * `hasDistortionParameters` accepts the camera. Every other case falls back to a rough
 * pinhole projection.
 *
 * @param point - Point in world coordinates.
 * @param camera - Camera whose image the point is projected onto.
 * @returns Pixel coordinates `{ u, v }`, or `null` when the point has (near-)zero or
 *          negative depth along the camera's +Z axis, or when the selected model rejects it.
 */
export function projectPointOntoCamera(point: Cesium.Cartesian3, camera: ExtendedCamera): PointProjectionResult {
    // camera -> world matrix as stored on the camera, and its inverse (world -> camera)
    const cameraToWorld = Cesium.Matrix4.fromArray(camera.transform.flat());
    const worldToCamera = Cesium.Matrix4.inverse(cameraToWorld, new Cesium.Matrix4());
    const pointInCameraSpace = Cesium.Matrix4.multiplyByPoint(worldToCamera, point, new Cesium.Cartesian3());

    // Only points in front of the camera (positive depth) can be projected.
    if (pointInCameraSpace.z < Number.EPSILON) return null;

    const params: Params = { pointInCameraSpace, camera };

    switch (camera.type) {
        case 'Frame':
            if (hasDistortionParameters(camera)) return project3DPointOntoFrameCameraImage(params);
            break;
        case 'Fisheye':
            if (hasDistortionParameters(camera)) return project3DPointOntoFishEyeCameraImage(params);
            break;
        case 'Spherical':
            return project3DPointOntoSphericalCameraImage(params);
        case 'Cylindrical':
            return project3DPointOntoCylindricalCameraImage(params);
    }

    // No dedicated model applies (unknown type or missing calibration): approximate.
    return roughlyProject3DPointOntoAnyCameraImage(params);
}

/**
 * Projects a camera-space point onto the image of a calibrated frame (pinhole) camera.
 *
 * Applies the radial (k1..k4), tangential (p1, p2) and affinity/skew (b1, b2)
 * corrections stored on the camera, then converts to raster coordinates relative to
 * the image centre shifted by the principal-point offset.
 *
 * @returns Integer pixel coordinates, or `null` when they fall outside the
 *          `width` x `height` raster (or are not finite), or when the point lies beyond
 *          the radius where the radial distortion polynomial stays monotonic.
 */
function project3DPointOntoFrameCameraImage({ pointInCameraSpace, camera }: Params): PointProjectionResult {
    //image width & height (in pixels)
    const w = camera.width;
    const h = camera.height;
    //focal length (in pixels)
    const f = camera.f;
    //principal point offset (in pixels)
    const [cx, cy] = camera.pPoint;
    //radial distortion coefficients (dimensionless)
    const [k1, k2, k3, k4] = camera.rDistortion;
    //tangential distortion coefficients (dimensionless)
    const [p1, p2] = camera.tDistortion;
    //affinity and non-orthogonality (skew) coefficients (in pixels)
    const [b1, b2] = camera.aCoefficients;

    // normalized image-plane coordinates (perspective division)
    const x = pointInCameraSpace.x / pointInCameraSpace.z;
    const y = pointInCameraSpace.y / pointInCameraSpace.z;

    const xpow2 = x * x;
    const ypow2 = y * y;
    const rpow2 = xpow2 + ypow2;
    const rpow4 = rpow2 * rpow2;
    const rpow6 = rpow4 * rpow2;
    const rpow8 = rpow4 * rpow4;

    const radial = 1 + k1 * rpow2 + k2 * rpow4 + k3 * rpow6 + k4 * rpow8;
    // tangential terms: p1 weights (r² + 2x²) in x and the 2xy term in y; p2 the converse
    const xi = x * radial + (p1 * (rpow2 + 2 * xpow2) + 2 * p2 * x * y);
    const yi = y * radial + (p2 * (rpow2 + 2 * ypow2) + 2 * p1 * x * y);

    //raster coordinates of point
    const u = Math.floor(w * 0.5 + cx + xi * f + xi * b1 + yi * b2);
    const v = Math.floor(h * 0.5 + cy + yi * f);

    // Valid pixel indices are 0..w-1 and 0..h-1, so u === w / v === h is already off-image.
    // Written in the positive form so that NaN coordinates are rejected as well.
    if (!(u >= 0 && u < w && v >= 0 && v < h)) return null;
    if (!isPointWithinDistortionMonotonicityRadius(rpow2, k1, k2, k3, k4)) return null;
    return { u, v };
}

/**
 * Projects a camera-space point onto the image of a calibrated fisheye camera.
 *
 * The perspective coordinates (x0, y0) are rescaled by atan(r0) / r0, so the resulting
 * radius equals the angle from the optical axis (equidistant mapping). The same radial,
 * tangential and affinity corrections as the frame model are then applied.
 *
 * @returns Integer pixel coordinates, or `null` when they fall outside the
 *          `width` x `height` raster (or are not finite), or when the point lies beyond
 *          the radius where the radial distortion polynomial stays monotonic.
 */
const project3DPointOntoFishEyeCameraImage = ({ pointInCameraSpace, camera }: Params): PointProjectionResult => {
    //image width & height (in pixels)
    const w = camera.width;
    const h = camera.height;
    //focal length (in pixels)
    const f = camera.f;
    //principal point offset (in pixels)
    const [cx, cy] = camera.pPoint;
    //radial distortion coefficients (dimensionless)
    const [k1, k2, k3, k4] = camera.rDistortion;
    //tangential distortion coefficients (dimensionless)
    const [p1, p2] = camera.tDistortion;
    //affinity and non-orthogonality (skew) coefficients (in pixels)
    const [b1, b2] = camera.aCoefficients;

    const x0 = pointInCameraSpace.x / pointInCameraSpace.z;
    const y0 = pointInCameraSpace.y / pointInCameraSpace.z;
    const r0 = Math.hypot(x0, y0);

    // atan(r0) / r0 tends to 1 as r0 -> 0; use that limit on the optical axis instead of
    // evaluating 0 / 0 (NaN), which previously made on-axis points unprojectable.
    const scale = r0 > 0 ? Math.atan(r0) / r0 : 1;
    const x = x0 * scale;
    const y = y0 * scale;

    const xpow2 = x * x;
    const ypow2 = y * y;
    const rpow2 = xpow2 + ypow2;
    const rpow4 = rpow2 * rpow2;
    const rpow6 = rpow4 * rpow2;
    const rpow8 = rpow4 * rpow4;

    const radial = 1 + k1 * rpow2 + k2 * rpow4 + k3 * rpow6 + k4 * rpow8;
    // tangential terms: p1 weights (r² + 2x²) in x and the 2xy term in y; p2 the converse
    const xi = x * radial + (p1 * (rpow2 + 2 * xpow2) + 2 * p2 * x * y);
    const yi = y * radial + (p2 * (rpow2 + 2 * ypow2) + 2 * p1 * x * y);

    //raster coordinates of point
    const u = Math.floor(w * 0.5 + cx + xi * f + xi * b1 + yi * b2);
    const v = Math.floor(h * 0.5 + cy + yi * f);

    // Valid pixel indices are 0..w-1 and 0..h-1, so u === w / v === h is already off-image.
    // Written in the positive form so that NaN coordinates are rejected as well.
    if (!(u >= 0 && u < w && v >= 0 && v < h)) return null;
    if (!isPointWithinDistortionMonotonicityRadius(rpow2, k1, k2, k3, k4)) return null;
    return { u, v };
};

/**
 * Projects a camera-space point onto an equirectangular (spherical) camera image.
 *
 * The horizontal angle atan(x / z) and the elevation atan(y / hypot(x, z)) are both scaled
 * by f = width / (2π) pixels per radian and offset to the image centre. Note that atan
 * (rather than atan2) is only a true azimuth for z > 0; the caller in this file rejects
 * points with z < Number.EPSILON.
 *
 * @returns Integer pixel coordinates, or `null` when they fall outside the
 *          `width` x `height` raster (or are not finite).
 */
function project3DPointOntoSphericalCameraImage({ camera, pointInCameraSpace }: Params): PointProjectionResult {
    const { x, y, z } = pointInCameraSpace;
    //image width & height (in pixels)
    const w = camera.width;
    const h = camera.height;
    //focal length for spherical cameras (pixels per radian: 2π spans the image width)
    const f = w / (2 * Math.PI);

    const azimuth = Math.atan(x / z);
    const elevation = Math.atan(y / Math.hypot(x, z));

    //raster coordinates of point
    const u = Math.floor(w * 0.5 + f * azimuth);
    const v = Math.floor(h * 0.5 + f * elevation);

    // Valid pixel indices are 0..w-1 and 0..h-1, so u === w / v === h is already off-image.
    // Written in the positive form so that NaN coordinates are rejected as well.
    if (!(u >= 0 && u < w && v >= 0 && v < h)) return null;
    return { u, v };
}

/**
 * Projects a camera-space point onto a cylindrical camera image.
 *
 * Horizontally the point is placed by its angle atan(x / z); vertically by the height
 * ratio y / hypot(x, z) (a cylinder, not a sphere). Both are scaled by
 * f = width / (2π) pixels and offset to the image centre.
 *
 * @returns Integer pixel coordinates, or `null` when they fall outside the
 *          `width` x `height` raster (or are not finite).
 */
function project3DPointOntoCylindricalCameraImage({ pointInCameraSpace, camera }: Params): PointProjectionResult {
    const { x, y, z } = pointInCameraSpace;
    //image width & height (in pixels)
    const w = camera.width;
    const h = camera.height;

    //focal length for cylindrical cameras (in pixels)
    const f = w / (2 * Math.PI);

    const azimuth = Math.atan(x / z);
    const heightRatio = y / Math.hypot(x, z);

    //raster coordinates of point
    const u = Math.floor(w * 0.5 + f * azimuth);
    const v = Math.floor(h * 0.5 + f * heightRatio);

    // Valid pixel indices are 0..w-1 and 0..h-1, so u === w / v === h is already off-image.
    // Written in the positive form so that NaN coordinates are rejected as well.
    if (!(u >= 0 && u < w && v >= 0 && v < h)) return null;
    return { u, v };
}

/**
 * Fallback pinhole projection used when no dedicated camera model applies.
 *
 * Distortion and the principal-point offset are ignored. The optical axis is assumed
 * to hit the image centre, matching the `w * 0.5` / `h * 0.5` origin used by the
 * calibrated models in this file.
 *
 * @returns Integer pixel coordinates, or `null` when they fall outside the
 *          `width` x `height` raster or are not finite (e.g. `camera.f` is missing).
 */
function roughlyProject3DPointOntoAnyCameraImage({ pointInCameraSpace, camera }: Params): PointProjectionResult {
    //image width & height (in pixels)
    const w = camera.width;
    const h = camera.height;
    //focal length (in pixels)
    const f = camera.f;

    const xi = pointInCameraSpace.x / pointInCameraSpace.z;
    const yi = pointInCameraSpace.y / pointInCameraSpace.z;

    // Offset to the image centre. Without it, every point left of or above the optical
    // axis got a negative coordinate and was rejected.
    const u = Math.floor(w * 0.5 + f * xi);
    const v = Math.floor(h * 0.5 + f * yi);

    // Valid pixel indices are 0..w-1 and 0..h-1, so u === w / v === h is already off-image.
    // Written in the positive form so that NaN coordinates are rejected as well.
    if (!(u >= 0 && u < w && v >= 0 && v < h)) return null;
    return { u, v };
}

/**
 * Checks whether a point lies inside the radius where the radial distortion model is
 * still monotonic, i.e. where larger undistorted radii still map to larger distorted radii.
 *
 * The distorted radius is r_d(r) = r * (1 + k1 r² + k2 r⁴ + k3 r⁶ + k4 r⁸). Its derivative,
 * written in s = r², is p(s) = 1 + 3 k1 s + 5 k2 s² + 7 k3 s³ + 9 k4 s⁴. Because p(0) = 1 > 0,
 * the mapping is monotonic up to the smallest positive real root of p.
 *
 * p is solved through its reversed polynomial q(t) = t⁴ p(1/t) = t⁴ + 3 k1 t³ + 5 k2 t² + 7 k3 t + 9 k4.
 * Its positive real roots are t = 1/s. The leading coefficient of q is always 1, so the
 * quartic solver never receives a vanishing leading coefficient, even when k4 (or more
 * coefficients) is 0. The extra roots at t = 0 that appear in that case correspond to
 * s = ∞ and do not constrain the result.
 *
 * @param rpow2 - squared radius of the point before distortion
 * @param k1 - radial distortion coefficient of r²
 * @param k2 - radial distortion coefficient of r⁴
 * @param k3 - radial distortion coefficient of r⁶
 * @param k4 - radial distortion coefficient of r⁸
 * @returns `true` when `rpow2` is below the first positive root of p (capped at 1e9).
 */
function isPointWithinDistortionMonotonicityRadius(
    rpow2: number,
    k1: number,
    k2: number,
    k3: number,
    k4: number
): boolean {
    // Upper bound on r² used when p has no positive real root (monotonic everywhere).
    const r2Cap = 1e9;
    // Imaginary parts below this (relative) size are treated as solver round-off.
    const imagTolerance = 1e-9;

    const roots = quartic([1, 3 * k1, 5 * k2, 7 * k3, 9 * k4]);

    let r2max = r2Cap;
    for (const { re, im } of roots) {
        // NaN roots (degenerate solver output) fail `re > 0` and are ignored.
        if (re > 0 && Math.abs(im) <= imagTolerance * Math.max(1, re)) {
            // t = 1/s; tiny spurious t just yields a huge s, which the cap absorbs
            r2max = Math.min(r2max, 1 / re);
        }
    }
    return rpow2 < r2max;
}
