Skip to content

File spatial.h

File List > spatial.h

Go to the documentation of this file

#pragma once

#include <chrono>
#include <cstdint>

#include "aproximate_math.hpp"
#include "fpm_adapter.hpp"
#include "logger.h"
#include "utilities.h"
#include "vector.hpp"

// Scalar type shared by every spatial vector below (fixed_16_16 from fpm_adapter.hpp).
using fixed = fixed_16_16;

// Forward declarations: the vector classes mention one another in constructor
// signatures before all of them have been defined.
class DistanceVector;
class PositionVector;
class VelocityVector;

/// @brief CRTP base adding fixed-point geometry (dot/cross products, magnitudes,
///        angles, normalization) on top of Vector3D<fixed, Derived>.
///
/// Products are formed from the coordinates' raw integer representation in 64-bit
/// arithmetic and rescaled by fixed::FixedBits afterwards.
/// NOTE(review): the overflow bounds quoted below assume a signed 32-bit raw
/// representation (as fixed_16_16 and the static_cast<int> calls suggest) — confirm
/// against fpm_adapter.hpp.
template <typename Derived>
class FixedVector3D: public Vector3D<fixed, Derived> {
public:
    // -- Type Definitions --
    using Vec = Vector3D<fixed, Derived>;
    using typename Vec::NumericType;
    using Vec::rad2DegFactor;
    using Vec::Vec;
    using Vec::X_coord;
    using Vec::Y_coord;
    using Vec::Z_coord;

    // -- Public Methods --

    /// @brief floor(sqrt(n)) computed bit by bit, returned as the *raw* value of a fixed.
    ///
    /// When @p n is a raw Q32 quantity (a product of two raw Q16 values), the result is a
    /// Q16 fixed on the same scale as the original coordinates.
    ///
    /// @param n Raw radicand.
    /// @return fixed whose raw value is floor(sqrt(n)), saturated to INT32_MAX when the
    ///         root does not fit the signed 32-bit raw range.
    static fixed integer_sqrt(uint64_t n) {
        // 0 and 1 are their own roots, but they must still come back as raw values like
        // every other result; converting the integer n to fixed would turn sqrt(raw 1)
        // into 1.0 instead of raw 1.
        if (n < 2) {
            return fixed::from_raw_value(static_cast<int>(n));
        }

        uint64_t root = 0;
        uint64_t bit = 1ULL << 62; // The "test bit": the largest power of four in uint64_t

        // Optimization: skip leading zeros, starting at the largest power of four <= n.
        while (bit > n) {
            bit >>= 2;
        }

        while (bit != 0) {
            if (n >= root + bit) {
                n -= (root + bit);
                // Equivalent to: root += 2*bit; root >>= 1;
                root = (root >> 1) + bit;
            } else {
                root >>= 1;
            }
            bit >>= 2;
        }

        // root can reach 2^32 - 1, beyond the signed 32-bit raw range. Saturate instead of
        // letting the narrowing cast wrap it into a negative magnitude.
        if (root > static_cast<uint64_t>(INT32_MAX)) {
            root = static_cast<uint64_t>(INT32_MAX);
        }
        return fixed::from_raw_value(static_cast<int>(root));
    }

    /// @brief Angle between this vector and @p other, in degrees.
    ///
    /// Computed as atan2(|a x b|, a . b). The first argument is never negative, so the
    /// result lies in the upper half-plane (0..180 degrees assuming a standard atan2).
    /// @return 0 if either vector tests false (presumably the zero vector — see Vector3D).
    fixed angleTo(const VectorCompatible<fixed> auto& other) const {
        if (!(*this && other)) {
            return NumericType(0);
        }
        // 1. Raw Coordinates (Scale: 1Q)
        const int64_t aX = X_coord.raw_value();
        const int64_t aY = Y_coord.raw_value();
        const int64_t aZ = Z_coord.raw_value();
        const int64_t bX = other.X_coord.raw_value();
        const int64_t bY = other.Y_coord.raw_value();
        const int64_t bZ = other.Z_coord.raw_value();

        // 2. Dot Product (Scale: 2Q)
        const int64_t dot_raw = aX * bX + aY * bY + aZ * bZ;

        // 3. Cross Product Components (Scale: 2Q)
        const int64_t cX_2Q = aY * bZ - aZ * bY;
        const int64_t cY_2Q = aZ * bX - aX * bZ;
        const int64_t cZ_2Q = aX * bY - aY * bX;

        // 4. SAFE MAGNITUDE CALCULATION
        // Shift down to 1Q before squaring: squaring a 2Q number gives 4Q, which overflows
        // int64 once the component's value exceeds ~0.7.
        const int64_t cX_1Q = cX_2Q >> fixed::FixedBits;
        const int64_t cY_1Q = cY_2Q >> fixed::FixedBits;
        const int64_t cZ_1Q = cZ_2Q >> fixed::FixedBits;

        // Sum of Squares (Scale: 2Q), in unsigned arithmetic so an oversized input wraps
        // instead of hitting signed-overflow UB. Exact while every 1Q component has
        // magnitude <= 2^31 (i.e. the cross product itself is representable as a fixed).
        const uint64_t cross_sq_2Q = raw_square(cX_1Q) + raw_square(cY_1Q) + raw_square(cZ_1Q);

        // 5. Square Root: sqrt of a 2Q (Q32) value is a 1Q (Q16) value, stored as a fixed.
        const fixed cross_magnitude_1Q = integer_sqrt(cross_sq_2Q);

        // 6. Final Angle: bring the dot product down to 1Q as well.
        const fixed dot_product_1Q = fixed::from_raw_value(static_cast<int>(dot_raw >> fixed::FixedBits));

        return atan2(cross_magnitude_1Q, dot_product_1Q) * rad2DegFactor;
    }

    /// @brief Dot product with @p other, rescaled back to a fixed.
    /// @return 0 if either vector tests false.
    fixed dot(const VectorCompatible<fixed> auto& other) const {
        if (!(*this && other)) {
            return NumericType(0);
        }

        const int64_t aX = X_coord.raw_value();
        const int64_t aY = Y_coord.raw_value();
        const int64_t aZ = Z_coord.raw_value();
        const int64_t bX = other.X_coord.raw_value();
        const int64_t bY = other.Y_coord.raw_value();
        const int64_t bZ = other.Z_coord.raw_value();

        // Raw products are 2Q; shift one FixedBits back out to return a 1Q value.
        const int64_t dot_product_raw = aX * bX + aY * bY + aZ * bZ;
        return fixed::from_raw_value(static_cast<int>(dot_product_raw >> fixed::FixedBits));
    }

    /// @brief Cross product this x @p other, returned as the derived vector type.
    Derived cross(const VectorCompatible<fixed> auto& other) const {
        const int64_t aX = X_coord.raw_value();
        const int64_t aY = Y_coord.raw_value();
        const int64_t aZ = Z_coord.raw_value();
        const int64_t bX = other.X_coord.raw_value();
        const int64_t bY = other.Y_coord.raw_value();
        const int64_t bZ = other.Z_coord.raw_value();

        // Each component is a 2Q difference of raw products.
        const int64_t cX_raw = aY * bZ - aZ * bY;
        const int64_t cY_raw = aZ * bX - aX * bZ;
        const int64_t cZ_raw = aX * bY - aY * bX;

        return Derived(
            fixed::from_raw_value(static_cast<int>(cX_raw >> fixed::FixedBits)),
            fixed::from_raw_value(static_cast<int>(cY_raw >> fixed::FixedBits)),
            fixed::from_raw_value(static_cast<int>(cZ_raw >> fixed::FixedBits))
        );
    }

    /// @brief Euclidean length of the vector.
    /// @return 0 if the vector tests false.
    fixed magnitude() const {
        if (!(*this)) {
            return NumericType(0);
        }
        // Raw squares are Q32; integer_sqrt returns the matching Q16 length.
        return integer_sqrt(raw_square(X_coord.raw_value())
                            + raw_square(Y_coord.raw_value())
                            + raw_square(Z_coord.raw_value()));
    }

    /// @brief Length of the projection onto the X/Y plane.
    fixed magnitudeXY() const {
        if (!(*this)) {
            return NumericType(0);
        }
        return integer_sqrt(raw_square(X_coord.raw_value()) + raw_square(Y_coord.raw_value()));
    }

    /// @brief Length of the projection onto the X/Z plane.
    fixed magnitudeXZ() const {
        if (!(*this)) {
            return NumericType(0);
        }
        return integer_sqrt(raw_square(X_coord.raw_value()) + raw_square(Z_coord.raw_value()));
    }

    /// @brief Length of the projection onto the Y/Z plane.
    fixed magnitudeYZ() const {
        if (!(*this)) {
            return NumericType(0);
        }
        return integer_sqrt(raw_square(Y_coord.raw_value()) + raw_square(Z_coord.raw_value()));
    }

    /// @brief Unit-length copy of this vector; a zero-length vector is returned unchanged.
    Derived normalize() const {
        const fixed mag = magnitude();
        if (mag != 0) {
            return Derived(X_coord / mag, Y_coord / mag, Z_coord / mag);
        }
        return Derived(X_coord, Y_coord, Z_coord);
    }

    /// @brief Elevation of the vector above the X/Y plane: atan2(Z, |XY|) scaled by
    ///        rad2DegFactor. Returns 0 when both Z and the X/Y length are zero.
    NumericType pitch() const {
        const fixed horizontal = magnitudeXY();
        // Explicit comparisons rather than a bool conversion of fixed, so "zero" means
        // exactly zero regardless of how the adapter converts fixed to bool.
        if (Z_coord != 0 || horizontal != 0) {
            return atan2(Z_coord, horizontal) * rad2DegFactor;
        }
        return 0;
    }

    /// @brief Heading in the X/Y plane: atan2(X, Y) scaled by rad2DegFactor.
    ///
    /// Assuming atan2 follows the std::atan2(y, x) convention, this measures the angle
    /// from the +Y axis toward +X. Returns 0 when both X and Y are zero.
    NumericType yaw() const {
        if (X_coord != 0 || Y_coord != 0) {
            return atan2(X_coord, Y_coord) * rad2DegFactor;
        }
        return 0;
    }

private:
    /// Square of a raw value in unsigned 64-bit arithmetic. For |v| <= 2^31 the square
    /// (<= 2^62) is exact and up to three of them sum without overflow (<= 3 * 2^62),
    /// whereas the same sum in int64_t can overflow, which is undefined behaviour.
    static constexpr uint64_t raw_square(int64_t v) {
        // For negative v the conversion yields 2^64 - |v|, whose square is |v|^2 mod 2^64.
        const uint64_t magnitude_bits = static_cast<uint64_t>(v);
        return magnitude_bits * magnitude_bits;
    }
};

/// @brief A displacement in fixed-point X/Y/Z components (e.g. the difference of two
///        PositionVectors, or a velocity integrated over a time interval).
class DistanceVector: public FixedVector3D<DistanceVector> {
public:
    // -- Type Definitions --
    using Vec = FixedVector3D<DistanceVector>;

    // -- Constructors --
    // Copy/move construction and assignment are left implicit (Rule of Zero). An
    // explicitly defaulted copy constructor would suppress the implicit move operations
    // and make the implicitly generated copy assignment deprecated.
    DistanceVector() = default;
    constexpr DistanceVector(fixed x, fixed y, fixed z);
    /// Displacement covered when moving with velocity @p v for @p interval.
    DistanceVector(VelocityVector v, ChronoDuration auto interval);
};

/// @brief A point in space, in fixed-point X/Y/Z coordinates.
///
/// Pitch(), Yaw() and Distance() memoize pitch(), yaw() and magnitudeXY() in the
/// private cache members; a cached value that tests false (zero) means "not computed
/// yet", so a genuinely zero result is simply recomputed on each call.
/// NOTE(review): nothing in this header resets the caches, so they would go stale if the
/// coordinates were modified in place after the first call — confirm that never happens.
class PositionVector: public FixedVector3D<PositionVector> {
public:
    // -- Type Definitions --
    using Vec = FixedVector3D<PositionVector>;

    // -- Constructors --
    // Copy/move construction and assignment are left implicit (Rule of Zero). An
    // explicitly defaulted copy constructor would suppress the implicit move operations
    // and make the implicitly generated copy assignment deprecated.
    PositionVector() = default;
    constexpr PositionVector(fixed x, fixed y, fixed z);
    /// Position reached by displacing the first argument by the second.
    PositionVector(PositionVector, DistanceVector);
    /// Position reached from @p p after moving with velocity @p v for @p interval.
    PositionVector(PositionVector p, VelocityVector v, ChronoDuration auto interval);

    // -- Public Methods --
    /// Cached pitch() of this position.
    fixed Pitch();
    /// Cached yaw() of this position.
    fixed Yaw();
    /// Cached horizontal (X/Y-plane) distance from the origin, i.e. magnitudeXY().
    fixed Distance();

private:
    // -- Private Attributes --
    // Lazily filled caches; 0 doubles as the "not yet computed" marker.
    fixed _distance = 0;
    fixed _pitch = 0;
    fixed _yaw = 0;
};

/// @brief A rate of change of position, in fixed-point X/Y/Z components per tick of a
///        time interval (see the DistanceVector/TimeInterval constructor).
class VelocityVector: public FixedVector3D<VelocityVector> {
public:
    // -- Type Definitions --
    using Vec = FixedVector3D<VelocityVector>;

    // -- Constructors --
    // Copy/move construction and assignment are left implicit (Rule of Zero). An
    // explicitly defaulted copy constructor would suppress the implicit move operations
    // and make the implicitly generated copy assignment deprecated.
    VelocityVector() = default;
    constexpr VelocityVector(fixed x, fixed y, fixed z);
    /// Velocity that covers the given displacement over @p interval.
    VelocityVector(DistanceVector, TimeInterval interval);
};

// ======================================================================================
// --- Inline-Defined Methods ---
// ======================================================================================

// -- DistanceVector --
// Build a displacement directly from its three components.
constexpr DistanceVector::DistanceVector(fixed x, fixed y, fixed z)
    : Vec(x, y, z) {
}

// Displacement covered by velocity v over the given interval; built straight from the
// result of `v * interval` rather than default-constructing and then assigning.
inline DistanceVector::DistanceVector(VelocityVector v, ChronoDuration auto interval)
    : DistanceVector(v * interval) {
}

// -- PositionVector --
// Build a position directly from its three coordinates.
constexpr PositionVector::PositionVector(fixed x, fixed y, fixed z)
    : Vec(x, y, z) {
}

/// @brief Position reached by displacing @p p by @p d (component-wise sum).
///
/// This is a non-template function, so an expression like `p + d` would be looked up
/// where it is written — before the free operator+(const PositionVector&,
/// const DistanceVector&) later in this header is declared. Spelling out the
/// component-wise sum and delegating to the component constructor removes that
/// declaration-order dependency and the default-construct-then-assign round trip.
inline PositionVector::PositionVector(PositionVector p, DistanceVector d)
    : PositionVector(p.X_coord + d.X_coord, p.Y_coord + d.Y_coord, p.Z_coord + d.Z_coord) {
}

// Position reached from p after travelling with velocity v for the given interval;
// constructed directly from `p + v * interval`.
inline PositionVector::PositionVector(PositionVector p, VelocityVector v, ChronoDuration auto interval)
    : PositionVector(p + v * interval) {
}

// Memoized pitch(): a cache entry that tests true is returned as-is; otherwise pitch()
// is evaluated and stored (so a zero pitch is recomputed on every call).
inline fixed PositionVector::Pitch() {
    if (_pitch) {
        return _pitch;
    }
    _pitch = pitch();
    return _pitch;
}

// Memoized yaw(): a cache entry that tests true is returned as-is; otherwise yaw()
// is evaluated and stored (so a zero yaw is recomputed on every call).
inline fixed PositionVector::Yaw() {
    if (_yaw) {
        return _yaw;
    }
    _yaw = yaw();
    return _yaw;
}

// Memoized horizontal distance (magnitudeXY()); recomputed while the cache is zero.
inline fixed PositionVector::Distance() {
    if (_distance) {
        return _distance;
    }
    _distance = magnitudeXY();
    return _distance;
}

// -- VelocityVector --
// Build a velocity directly from its three components.
constexpr VelocityVector::VelocityVector(fixed x, fixed y, fixed z)
    : Vec(x, y, z) {
}

// Velocity that covers `dist` over `interval`: each component is divided by the
// interval's tick count. For a zero-length interval the coordinates are left as
// default-constructed instead of dividing by zero.
inline VelocityVector::VelocityVector(DistanceVector dist, TimeInterval interval) {
    const auto ticks = interval.count();
    if (!ticks) {
        return;
    }
    X_coord = dist.X_coord / ticks;
    Y_coord = dist.Y_coord / ticks;
    Z_coord = dist.Z_coord / ticks;
}

// ======================================================================================
// --- Operator Overloads ---
// ======================================================================================

/// @brief Velocity obtained by spreading displacement @p D evenly over @p interval.
///
/// Each component is divided by the interval's tick count. A zero-length interval
/// yields a zero vector instead of dividing by zero — the same situation that the
/// VelocityVector(DistanceVector, TimeInterval) constructor explicitly guards against.
constexpr const VelocityVector operator/(const DistanceVector& D, const ChronoDuration auto& interval) {
    const auto scale = interval.count();
    if (scale == 0) {
        return VelocityVector(fixed(0), fixed(0), fixed(0));
    }
    return VelocityVector(D.X_coord / scale, D.Y_coord / scale, D.Z_coord / scale);
}

// Displacement covered by velocity V over `interval`: the velocity components are
// reinterpreted as a DistanceVector and scaled by the interval's tick count.
constexpr const DistanceVector operator*(const VelocityVector& V, const ChronoDuration auto& interval) {
    const fixed ticks = static_cast<fixed>(interval.count());
    return DistanceVector(V.X_coord, V.Y_coord, V.Z_coord) * ticks;
}

// Translate point A by displacement B, component by component.
constexpr const PositionVector operator+(const PositionVector& A, const DistanceVector& B) {
    const auto x = A.X_coord + B.X_coord;
    const auto y = A.Y_coord + B.Y_coord;
    const auto z = A.Z_coord + B.Z_coord;
    return PositionVector(x, y, z);
}

// Displacement leading from point B to point A, component by component.
constexpr const DistanceVector operator-(const PositionVector& A, const PositionVector& B) {
    const auto dx = A.X_coord - B.X_coord;
    const auto dy = A.Y_coord - B.Y_coord;
    const auto dz = A.Z_coord - B.Z_coord;
    return DistanceVector(dx, dy, dz);
}