Skip to content

Commit

Permalink
feat: standard math functions have requirements and also support auto…
Browse files Browse the repository at this point in the history
…diff::real
  • Loading branch information
Franz R. Sattler committed Jan 15, 2025
1 parent 18d8d09 commit f8c2d44
Showing 1 changed file with 46 additions and 9 deletions.
55 changes: 46 additions & 9 deletions DiFfRG/include/DiFfRG/common/math.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <DiFfRG/common/complex_math.hh>
#include <DiFfRG/common/cuda_prefix.hh>
#include <DiFfRG/common/utils.hh>
#include <type_traits>

namespace DiFfRG
{
Expand All @@ -25,6 +26,10 @@ namespace DiFfRG
* @return x^n
*/
template <int n, typename NumberType>
requires requires(NumberType x) {
x *x;
NumberType(1.) / x;
}
constexpr __forceinline__ __host__ __device__ NumberType powr(const NumberType x)
{
if constexpr (n == 0)
Expand Down Expand Up @@ -82,42 +87,74 @@ namespace DiFfRG
* @brief A compile-time evaluatable theta function
*/
template <typename NumberType>
constexpr __forceinline__ __host__ __device__ double heaviside_theta(const NumberType x)
requires requires(NumberType x) { x >= 0; }
constexpr __forceinline__ __host__ __device__ auto heaviside_theta(const NumberType x)
{
return x > 0 ? 1. : 0.;
if constexpr (std::is_same_v<NumberType, autodiff::real>)
return x >= 0. ? 1. : 0.;
else
return x >= static_cast<NumberType>(0) ? static_cast<NumberType>(1) : static_cast<NumberType>(0);
}

/**
* @brief A compile-time evaluatable sign function
*/
template <typename NumberType> constexpr __forceinline__ __host__ __device__ double sign(const NumberType x)
template <typename NumberType>
requires requires(NumberType x) { x >= 0; }
constexpr __forceinline__ __host__ __device__ auto sign(const NumberType x)
{
return x >= 0. ? 1. : -1.;
if constexpr (std::is_same_v<NumberType, autodiff::real>)
return x >= 0. ? 1. : -1.;
else
return x >= static_cast<NumberType>(0) ? static_cast<NumberType>(1) : static_cast<NumberType>(-1);
}

/**
* @brief Function to evaluate whether two floats are equal to numerical precision.
* Tests for both relative and absolute equality.
*
* @tparam T Type of the float
* @param eps_ Precision with which to compare a and b
* @return bool
*/
template <typename T1, typename T2>
bool __forceinline__ __host__ __device__
is_close(T1 a, T2 b, decltype(std::numeric_limits<T1>::epsilon()) eps_ = std::numeric_limits<T1>::epsilon())
template <typename T1, typename T2, typename T3>
requires std::is_floating_point<T3>::value
bool __forceinline__ __host__ __device__ is_close(T1 a, T2 b, T3 eps_)
{
if constexpr (std::is_same_v<T1, autodiff::real> || std::is_same_v<T2, autodiff::real>)
return is_close((double)a, (double)b, (double)eps_);
T1 diff = std::fabs(a - b);
if (diff <= eps_) return true;
if (diff <= std::fmax(std::fabs(a), std::fabs(b)) * eps_) return true;
return false;
}

/**
* @brief Function to evaluate whether two floats are equal to numerical precision.
* Tests for both relative and absolute equality.
*
* @return bool
*/
template <typename T1, typename T2> bool __forceinline__ __host__ __device__ is_close(T1 a, T2 b)
{
if constexpr (std::is_same_v<T1, autodiff::real> || std::is_same_v<T2, autodiff::real>) {
constexpr auto eps_ = std::numeric_limits<double>::epsilon() * 10.;
return is_close((double)a, (double)b, eps_);
} else {
constexpr auto eps_ = std::max(std::numeric_limits<T1>::epsilon(), std::numeric_limits<T2>::epsilon());
T1 diff = std::fabs(a - b);
if (diff <= eps_) return true;
if (diff <= std::fmax(std::fabs(a), std::fabs(b)) * eps_) return true;
}
return false;
}

/**
* @brief A dot product which takes the dot product between a1 and a2, assuming each has n entries which can be
* accessed via the [] operator.
*/
template <uint n, typename NT, typename A1, typename A2> NT dot(const A1 &a1, const A2 &a2)
template <uint n, typename NT, typename A1, typename A2>
requires requires(A1 a1, A2 a2) { a1[0] * a2[0]; }
NT dot(const A1 &a1, const A2 &a2)
{
NT ret = a1[0] * a2[0];
for (uint i = 1; i < n; ++i)
Expand Down

0 comments on commit f8c2d44

Please sign in to comment.