Skip to content

File aproximate_math.hpp

File List > aproximate_math.hpp

Go to the documentation of this file

#pragma once

#include <cmath> // For std::signbit
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>

#include "fpm_adapter.hpp"

namespace Approximate {
    /**
     * @brief Outcome of an approximation routine.
     *
     * @tparam T Type of the approximated value.
     */
    template <typename T>
    struct ApproximateResult {
        bool converged = false; ///< True when the routine met its tolerance.
        T    result{};          ///< Best estimate found; value-initialized when not set explicitly.
    };

    /**
     * @brief Searches for the first root of @p func at or to the right of x = 0.
     *
     * Phase 1 (bracketing): starting from the window [0, e], the window slides right with a
     * growing step until @p func changes sign across it. Here e is 1 for integer T and
     * std::numeric_limits<T>::epsilon() otherwise. At most @p rounds steps are taken.
     *
     * Phase 2 (refinement): the bracket is narrowed with a secant step through its endpoints
     * (false position). It falls back to bisection when that step would overflow, divide by
     * zero, or land outside the bracket. Refinement stops once the relative bracket width
     * (right - left) / right is at most @p error.
     *
     * A probe at which @p func evaluates to exactly zero is returned at once as a converged root.
     *
     * @param func   Function whose root is sought.
     * @param error  Relative bracket width (relative to the right edge) accepted as converged.
     * @param rounds Maximum number of bracketing steps; refinement runs at most rounds + 1 iterations.
     * @return One of the following:
     *         - {true, root estimate} on convergence.
     *         - {false, furthest point probed} when no sign change is found while bracketing.
     *         - {false, last estimate} when refinement runs out of iterations.
     *         - {false, 0} when @p func throws std::runtime_error.
     */
    template <typename T>
    ApproximateResult<T>
    small_root(const std::function<T(const T&)> func, const T error = T(0.001), const uint8_t rounds = 16) {
        try {
            const T zero = T(0);

            T closestNonZero;
            if constexpr (std::numeric_limits<T>::is_integer) {
                closestNonZero = 1;
            } else {
                closestNonZero = std::numeric_limits<T>::epsilon();
            }

            auto sign_bit = [](auto val) {
                if constexpr (fpm::is_fixed<T>::value) {
                    return fpm::signbit(val);
                } else {
                    return std::signbit(val);
                }
            };

            T leftInput = zero;
            T leftValue = func(leftInput);
            if (leftValue == zero) {
                // Exact root at the origin; the relative convergence test below could never accept it.
                return ApproximateResult<T>{true, leftInput};
            }

            T rightInput = closestNonZero;
            T rightValue = func(rightInput);

            // Phase 1: slide the window right until func changes sign across it.
            for (uint8_t step = 0;; ++step) {
                if (rightValue == zero) {
                    // signbit(0) is false, so an exact root would otherwise be stepped over
                    // whenever leftValue is positive.
                    return ApproximateResult<T>{true, rightInput};
                }
                if (sign_bit(leftValue) != sign_bit(rightValue)) {
                    break;
                }
                if (step >= rounds) {
                    // No sign change, so no bracketed root. Refining anyway would shrink an
                    // interval around an arbitrary point and misreport it as converged.
                    return ApproximateResult<T>{false, rightInput};
                }
                leftInput = rightInput;
                leftValue = rightValue;
                rightInput += 4 * (closestNonZero + step);
                rightValue = func(rightInput);
            }

            // Phase 2: shrink [leftInput, rightInput], which always straddles a sign change.
            T midInput = rightInput;
            for (unsigned int iteration = 0; iteration <= rounds; ++iteration) {
                const T deltaInput = rightInput - leftInput;

                T deltaValue;
                if constexpr (fpm::is_fixed<T>::value) {
                    using Raw = typename T::base_type;
                    // The endpoint values have opposite signs, so their difference can exceed the
                    // base type's range. Subtract the raw values in 64 bits instead.
                    // NOTE(review): assumes the base type is narrower than 64 bits.
                    int64_t raw_delta = static_cast<int64_t>(rightValue.raw_value())
                                        - static_cast<int64_t>(leftValue.raw_value());

                    // Saturate to the base type's *raw* range. std::numeric_limits<T> would give
                    // limits in value units, which is the wrong scale here.
                    const int64_t min_raw = std::numeric_limits<Raw>::min();
                    const int64_t max_raw = std::numeric_limits<Raw>::max();
                    if (raw_delta > max_raw) {
                        raw_delta = max_raw;
                    } else if (raw_delta < min_raw) {
                        raw_delta = min_raw;
                    }
                    deltaValue = T::from_raw_value(static_cast<Raw>(raw_delta));
                } else {
                    deltaValue = rightValue - leftValue;
                }

                // Guard the product rightValue * deltaInput used by the secant step.
                bool willOverflow = false;
                if (rightValue != zero && deltaInput != zero) {
                    willOverflow = std::abs(deltaInput) > std::numeric_limits<T>::max() / std::abs(rightValue);
                }

                bool useSecant = !willOverflow && deltaValue != zero;
                if (useSecant) {
                    midInput = rightInput - (rightValue * deltaInput) / deltaValue;
                    // The step must land strictly inside the bracket. This form also rejects NaN.
                    useSecant = leftInput < midInput && midInput < rightInput;
                }
                if (!useSecant) {
                    // Bisection fallback.
                    midInput = leftInput + (deltaInput / 2);
                }

                const T midValue = func(midInput);
                if (midValue == zero) {
                    return ApproximateResult<T>{true, midInput};
                }

                // Keep the half of the bracket that still contains the sign change.
                if (sign_bit(leftValue) == sign_bit(midValue)) {
                    leftInput = midInput;
                    leftValue = midValue;
                } else {
                    rightInput = midInput;
                    rightValue = midValue;
                }

                // Relative bracket width; rightInput is strictly positive here.
                if ((rightInput - leftInput) / rightInput <= error) {
                    return ApproximateResult<T>{true, midInput};
                }
            }
            return ApproximateResult<T>{false, midInput};
        } catch (const std::runtime_error&) {
            // A failing evaluation of func is reported as "no root found".
            return ApproximateResult<T>{false, T(0)};
        }
    }

    /**
     * @brief Finds up to @p n_roots successive roots of @p func to the right of x = 0.
     *
     * Each search runs small_root on @p func shifted so that its origin lies just past the
     * previously found root. The first search starts at x = error, so a root at exactly
     * x = 0 is not reported.
     *
     * @param func    Function whose roots are sought.
     * @param n_roots Maximum number of roots to find.
     * @param error   Tolerance forwarded to small_root. It is also the minimum gap between a
     *                found root and the start of the next search.
     * @param rounds  Iteration budget forwarded to small_root.
     * @return {true, roots} when all @p n_roots roots were found; otherwise {false, roots found so far}.
     */
    template <typename T>
    ApproximateResult<std::vector<T>> n_roots(
        const std::function<T(const T&)> func,
        const uint8_t                    n_roots,
        const T                          error = T(0.001),
        const uint8_t                    rounds = 16
    ) {
        std::vector<T> roots;
        roots.reserve(n_roots);

        // Where the next search begins; the first search starts just past x = 0.
        T search_start = error;

        for (uint8_t found = 0; found < n_roots; ++found) {
            // search_start is captured by value so this closure is unaffected by later updates.
            const std::function<T(const T&)> shifted_func = [&func, search_start](const T& x) {
                return func(x + search_start);
            };
            const ApproximateResult<T> result = small_root(shifted_func, error, rounds);
            if (!result.converged) {
                // If we can't find another root, we're done.
                return ApproximateResult<std::vector<T>>{false, std::move(roots)};
            }

            // small_root works in shifted coordinates; map the root back.
            const T root = search_start + result.result;
            roots.push_back(root);

            // small_root's tolerance is relative to the distance searched. The true root may
            // therefore lie up to about error * result.result past the estimate. Skip beyond
            // that band (with a 2x margin) plus the fixed `error` gap, so the next search
            // cannot rediscover the same root.
            search_start = root + error + T(2) * error * result.result;
        }

        return ApproximateResult<std::vector<T>>{true, std::move(roots)};
    }
} // namespace Approximate