File aproximate_math.hpp
File List > aproximate_math.hpp
Go to the documentation of this file
#pragma once
#include <cmath>      // For std::signbit
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>     // For std::numeric_limits
#include <stdexcept>  // For std::runtime_error
#include <utility>    // For std::move
#include <vector>
#include "fpm_adapter.hpp"
namespace Approximate {
/// @brief Outcome of an approximate numerical search.
/// @tparam T Type of the value produced by the search.
template <typename T>
struct ApproximateResult {
    /// True when the producing search reports success; false otherwise.
    bool converged = false;
    /// Best value the search produced. Value-initialized (zero for arithmetic types) so a
    /// default-constructed result never exposes an indeterminate value.
    T result{};
};
/**
 * @brief Searches rightward from zero for a root of @p func.
 *
 * Phase 1 brackets a root. Starting from the interval [0, closestNonZero], the right edge is
 * pushed outward (by steps that grow with each round) until @p func changes sign across the
 * interval, for at most @p rounds steps. Phase 2 refines the bracket with the secant method.
 * It falls back to bisection when the secant step would overflow, divide by zero, or land
 * outside the bracket.
 *
 * @param func   Function whose root is sought. A std::runtime_error thrown while searching
 *               aborts the search.
 * @param error  Relative tolerance: the search converges once (right - left) / right <= error.
 * @param rounds Maximum number of bracket-expansion steps; phase 2 runs at most rounds + 1
 *               iterations.
 * @return {true, x} when x is within tolerance of (or exactly) a root. Otherwise {false, x},
 *         where x is the last point probed, or 0 if a std::runtime_error was caught.
 */
template <typename T>
ApproximateResult<T>
small_root(const std::function<T(const T&)> func, const T error = T(0.001), const uint8_t rounds = 16) {
    try {
        const T zero = T(0);
        // Smallest positive step T can represent (1 for integers, epsilon otherwise).
        T closestNonZero;
        if constexpr (std::numeric_limits<T>::is_integer) {
            closestNonZero = 1;
        } else {
            closestNonZero = std::numeric_limits<T>::epsilon();
        }
        T leftInput = zero;
        T rightInput = closestNonZero;
        T midInput = zero;
        T leftValue = func(leftInput);
        T rightValue = func(rightInput);
        T midValue;
        uint8_t round = 0;
        auto sign_bit = [](auto val) {
            if constexpr (fpm::is_fixed<T>::value) {
                return fpm::signbit(val);
            } else {
                return std::signbit(val);
            }
        };

        // Phase 1: widen the window until it brackets a sign change. An exact zero also ends the
        // search: signbit(+0) equals the signbit of any positive value, so without this check a
        // root hit exactly while approaching from above would be stepped over.
        while (rightValue != zero && sign_bit(leftValue) == sign_bit(rightValue) && round < rounds) {
            leftInput = rightInput;
            leftValue = rightValue;
            rightInput += 4 * (closestNonZero + round);
            rightValue = func(rightInput);
            round++;  //-- TODO: evaluate the impact of this check.
        }
        if (rightValue == zero) {
            return ApproximateResult<T>{true, rightInput};
        }
        if (sign_bit(leftValue) == sign_bit(rightValue)) {
            // No sign change within the budget. Refining a non-bracketing window would eventually
            // satisfy the relative-width test below and report a spurious root.
            return ApproximateResult<T>{false, rightInput};
        }

        // Phase 2: shrink the bracket [leftInput, rightInput] around the sign change.
        round = 0;
        do {
            const T deltaInput = rightInput - leftInput;
            T deltaValue;
            if constexpr (fpm::is_fixed<T>::value) {
                // Subtract raw values in 64 bits so the difference cannot overflow the raw type.
                // Then clamp back into the RAW type's range. std::numeric_limits<T> would give the
                // fixed-point value limits, which are on a different scale from raw_value().
                using raw_type = decltype(rightValue.raw_value());
                // NOTE(review): if raw_type is itself 64 bits wide this subtraction can still overflow.
                int64_t raw_delta = static_cast<int64_t>(rightValue.raw_value())
                                  - static_cast<int64_t>(leftValue.raw_value());
                const int64_t min_raw = std::numeric_limits<raw_type>::min();
                const int64_t max_raw = std::numeric_limits<raw_type>::max();
                if (raw_delta > max_raw) {
                    raw_delta = max_raw;
                } else if (raw_delta < min_raw) {
                    raw_delta = min_raw;
                }
                deltaValue = T::from_raw_value(static_cast<raw_type>(raw_delta));
            } else {
                // Floating-point and other types use plain subtraction.
                deltaValue = rightValue - leftValue;
            }

            // Guard the secant numerator rightValue * deltaInput against overflow.
            bool willOverflow = false;
            if (rightValue != zero && deltaInput != zero) {
                if (std::abs(deltaInput) > std::numeric_limits<T>::max() / std::abs(rightValue)) {
                    willOverflow = true;
                }
            }
            bool secantFailure = false;
            if (!willOverflow && deltaValue != zero) {
                // Secant step: where the chord through both bracket ends crosses zero.
                midInput = rightInput - (rightValue * deltaInput) / deltaValue;
                // An estimate outside the open bracket is unusable.
                if (midInput <= leftInput || midInput >= rightInput) {
                    secantFailure = true;
                }
            }
            if (willOverflow || secantFailure || deltaValue == zero) {
                // Bisection fallback.
                midInput = leftInput + (deltaInput / 2);
            }
            midValue = func(midInput);
            // Exact hit: deliberately an exact comparison.
            if (midValue == zero) {
                return ApproximateResult<T>{true, midInput};
            }
            // Keep the half of the bracket across which the sign still changes.
            if (sign_bit(leftValue) == sign_bit(midValue)) {
                leftInput = midInput;
                leftValue = midValue;
            } else {
                rightInput = midInput;
                rightValue = midValue;
            }
            // Converged once the bracket width relative to its right edge is within tolerance.
            if ((rightInput - leftInput) / rightInput <= error) {
                return ApproximateResult<T>{true, midInput};
            }
        } while (round++ < rounds);
        return ApproximateResult<T>{false, midInput};
    } catch (const std::runtime_error&) {
        return ApproximateResult<T>{false, T(0)};
    }
}
/**
 * @brief Finds up to @p n_roots successive roots of @p func in increasing order.
 *
 * Each root is located by calling small_root on @p func shifted so the search begins just past
 * the previous root. The first search begins at @p error.
 *
 * @param func    Function whose roots are sought.
 * @param n_roots Number of roots to look for.
 * @param error   Tolerance forwarded to small_root; also sets the gap left after each root.
 * @param rounds  Iteration budget forwarded to small_root.
 * @return {true, roots} when all @p n_roots searches converged. Otherwise {false, roots},
 *         holding the roots found before the first search that did not converge.
 */
template <typename T>
ApproximateResult<std::vector<T>> n_roots(
    const std::function<T(const T&)> func,
    const uint8_t n_roots,
    const T error = T(0.001),
    const uint8_t rounds = 16
) {
    std::vector<T> roots;
    roots.reserve(n_roots);
    // Start of the next search window: just past zero at first, then just past each root found.
    T search_start = error;
    for (int i = 0; i < n_roots; i++) {
        const T origin = search_start;
        const std::function<T(const T&)> shifted_func = [&func, origin](const T& x) { return func(x + origin); };
        const auto result = small_root(shifted_func, error, rounds);
        if (!result.converged) {
            // If we can't find another root, we're done.
            return ApproximateResult<std::vector<T>>{false, std::move(roots)};
        }
        // small_root works in shifted coordinates; translate back.
        const T root = origin + result.result;
        roots.push_back(root);
        // small_root's tolerance is relative. It stops when (r - l) / r <= error, and its estimate x
        // is one end of [l, r], so the true root may lie up to error * x / (1 - error) past x.
        // For large x that is much more than `error`. Advancing only by `error` could start the
        // next window before that root and find it again, so scale the gap with x.
        T gap = error;
        if (error < T(1)) {
            gap = error * (T(1) + result.result) / (T(1) - error);
        }
        search_start = root + gap;
    }
    return ApproximateResult<std::vector<T>>{true, std::move(roots)};
}
} // namespace Approximate