diff --git a/DESCRIPTION b/DESCRIPTION index 33a8010..e3a071e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -14,7 +14,7 @@ Suggests: rmarkdown, tinytest Encoding: UTF-8 -RoxygenNote: 7.2.1 +RoxygenNote: 7.3.1 Imports: Rcpp (>= 1.0.8.3), RcppEigen LinkingTo: Rcpp, RcppEigen VignetteBuilder: knitr diff --git a/inst/include/autodiff/BUILD b/inst/include/autodiff/BUILD index a9bf12d..5c2114a 100644 --- a/inst/include/autodiff/BUILD +++ b/inst/include/autodiff/BUILD @@ -7,7 +7,7 @@ cc_library( cc_library( name = "reverse", - hdrs = glob(["reverse.hpp", "reverse/**/*.hpp"]), + hdrs = glob(["var.hpp", "reverse/**/*.hpp"]), deps = ["@com_github_eigen_eigen//:eigen", ":common"], visibility = ["//visibility:public"], @@ -15,7 +15,7 @@ cc_library( cc_library( name = "forward", - hdrs = glob(["forward.hpp", "forward/**/*.hpp"]), + hdrs = glob(["dual.hpp", "real.hpp", "forward/**/*.hpp"]), deps = [ "@com_github_eigen_eigen//:eigen", ":common", diff --git a/inst/include/autodiff/CMakeLists.txt b/inst/include/autodiff/CMakeLists.txt index e6146dc..eb409ed 100644 --- a/inst/include/autodiff/CMakeLists.txt +++ b/inst/include/autodiff/CMakeLists.txt @@ -14,6 +14,10 @@ target_include_directories(autodiff $ ) +if(CMAKE_CUDA_COMPILER) + target_compile_options(autodiff INTERFACE $<$:--expt-relaxed-constexpr --extended-lambda>) +endif() + # Install autodiff interface library install(TARGETS autodiff EXPORT autodiffTargets) diff --git a/inst/include/autodiff/common/binomialcoefficient.hpp b/inst/include/autodiff/common/binomialcoefficient.hpp index 0cea087..06f43d5 100644 --- a/inst/include/autodiff/common/binomialcoefficient.hpp +++ b/inst/include/autodiff/common/binomialcoefficient.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/common/classtraits.hpp b/inst/include/autodiff/common/classtraits.hpp index 154ba54..61c23c8 100644 --- a/inst/include/autodiff/common/classtraits.hpp +++ b/inst/include/autodiff/common/classtraits.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -95,5 +95,12 @@ CREATE_MEMBER_CHECK(size); template constexpr bool hasSize = has_member_size>::value; +// Create type trait struct `has_operator_bracket`. +template struct has_operator_bracket_impl : std::false_type {}; +template struct has_operator_bracket_impl().operator [](0)) ), T> : std::true_type {}; + +/// Boolean type that is true if type T implements `operator[](int)` method. +template struct has_operator_bracket : has_operator_bracket_impl {}; + } // namespace detail } // namespace autodiff diff --git a/inst/include/autodiff/common/eigen.hpp b/inst/include/autodiff/common/eigen.hpp index c5212f3..d89c06f 100644 --- a/inst/include/autodiff/common/eigen.hpp +++ b/inst/include/autodiff/common/eigen.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -130,6 +130,15 @@ struct VectorTraits> using ReplaceValueType = VectorReplaceValueType; }; +template +struct VectorTraits> +{ + using ValueType = VectorValueType; + + template + using ReplaceValueType = Eigen::Map, MapOptions, StrideType>; +}; + //===================================================================================================================== // // AUXILIARY TEMPLATE TYPE ALIASES diff --git a/inst/include/autodiff/common/meta.hpp b/inst/include/autodiff/common/meta.hpp index 63ade06..65da700 100644 --- a/inst/include/autodiff/common/meta.hpp +++ b/inst/include/autodiff/common/meta.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -34,24 +34,39 @@ #include #include +#ifndef AUTODIFF_DEVICE_FUNC +#ifdef AUTODIFF_EIGEN_FOUND + #include + #define AUTODIFF_DEVICE_FUNC EIGEN_DEVICE_FUNC +#else + #define AUTODIFF_DEVICE_FUNC +#endif +#endif + namespace autodiff { namespace detail { template -using EnableIf = typename std::enable_if::type; +using EnableIf = std::enable_if_t; + +template +using Requires = std::enable_if_t; template -using PlainType = typename std::remove_cv::type>::type; +using PlainType = std::remove_cv_t>; template -using ConditionalType = typename std::conditional::type; +using ConditionalType = std::conditional_t; template -using CommonType = typename std::common_type::type; +using CommonType = std::common_type_t; template using ReturnType = std::invoke_result_t; +template +constexpr bool isConst = std::is_const_v>; + template constexpr bool isConvertible = std::is_convertible, U>::value; @@ -62,13 +77,13 @@ template constexpr auto TupleSize = std::tuple_size_v>; template -constexpr auto TupleHead(Tuple&& tuple) +AUTODIFF_DEVICE_FUNC constexpr auto TupleHead(Tuple&& tuple) { return std::get<0>(std::forward(tuple)); } template -constexpr auto TupleTail(Tuple&& tuple) +AUTODIFF_DEVICE_FUNC constexpr auto TupleTail(Tuple&& tuple) { auto g = [&](auto&&, auto&&... args) constexpr { return std::forward_as_tuple(args...); @@ -85,7 +100,7 @@ struct Index }; template -constexpr auto AuxFor(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto AuxFor(Function&& f) { if constexpr (i < iend) { f(Index{}); @@ -94,19 +109,19 @@ constexpr auto AuxFor(Function&& f) } template -constexpr auto For(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto For(Function&& f) { AuxFor(std::forward(f)); } template -constexpr auto For(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto For(Function&& f) { For<0, iend>(std::forward(f)); } template -constexpr auto AuxReverseFor(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto AuxReverseFor(Function&& f) { if constexpr (i < iend) { @@ -116,19 +131,19 @@ constexpr auto AuxReverseFor(Function&& f) } template -constexpr auto ReverseFor(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto ReverseFor(Function&& f) { AuxReverseFor(std::forward(f)); } template -constexpr auto ReverseFor(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto ReverseFor(Function&& f) { ReverseFor<0, iend>(std::forward(f)); } template -constexpr auto ForEach(Tuple&& tuple, Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto ForEach(Tuple&& tuple, Function&& f) { constexpr auto N = TupleSize; For([&](auto i) constexpr { @@ -144,7 +159,7 @@ constexpr auto ForEach(Tuple&& tuple, Function&& f) } template -constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f) { constexpr auto N1 = TupleSize; constexpr auto N2 = TupleSize; @@ -155,7 +170,7 @@ constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f) } template -constexpr auto Sum(Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto Sum(Function&& f) { using ResultType = std::invoke_result_t>; ResultType res = {}; @@ -166,7 +181,7 @@ constexpr auto Sum(Function&& f) } template -constexpr auto Reduce(Tuple&& tuple, Function&& f) +AUTODIFF_DEVICE_FUNC constexpr auto Reduce(Tuple&& tuple, Function&& f) { using ResultType = std::invoke_result_t(tuple))>; ResultType res = {}; diff --git a/inst/include/autodiff/common/numbertraits.hpp b/inst/include/autodiff/common/numbertraits.hpp index b5f3c06..2126e02 100644 --- a/inst/include/autodiff/common/numbertraits.hpp +++ b/inst/include/autodiff/common/numbertraits.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/common/vectortraits.hpp b/inst/include/autodiff/common/vectortraits.hpp index 7b2a9f4..634b576 100644 --- a/inst/include/autodiff/common/vectortraits.hpp +++ b/inst/include/autodiff/common/vectortraits.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/forward/dual.hpp b/inst/include/autodiff/forward/dual.hpp index db67334..ba11962 100644 --- a/inst/include/autodiff/forward/dual.hpp +++ b/inst/include/autodiff/forward/dual.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/forward/dual/dual.hpp b/inst/include/autodiff/forward/dual/dual.hpp index 385398c..5e904e5 100644 --- a/inst/include/autodiff/forward/dual/dual.hpp +++ b/inst/include/autodiff/forward/dual/dual.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -479,7 +479,7 @@ template struct AuxDualOpType> { using type = Op; }; template -constexpr auto auxCommonDualType() +AUTODIFF_DEVICE_FUNC constexpr auto auxCommonDualType() { if constexpr (isArithmetic && isArithmetic) return CommonType(); @@ -508,31 +508,21 @@ struct AuxCommonDualType { using type = decltype(auxCommonDualType()); }; template struct Dual { - T val; + T val = {}; - G grad; + G grad = {}; - Dual() : Dual(0.0) {} + AUTODIFF_DEVICE_FUNC constexpr Dual() + {} - template && !isExpr>...> - Dual(U&& v) - : val(std::forward(v)), grad(0) - { - } - - Dual(const NumericType& val) - : val(val), grad(0) - { - } - - template && !isDual>...> - Dual(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual(U&& other) { assign(*this, std::forward(other)); } - template && !isDual>...> - Dual& operator=(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual& operator=(U&& other) { Dual tmp; assign(tmp, std::forward(other)); @@ -540,8 +530,8 @@ struct Dual return *this; } - template || isExpr>...> - Dual& operator+=(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual& operator+=(U&& other) { Dual tmp; assign(tmp, std::forward(other)); @@ -549,8 +539,8 @@ struct Dual return *this; } - template || isExpr>...> - Dual& operator-=(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual& operator-=(U&& other) { Dual tmp; assign(tmp, std::forward(other)); @@ -558,8 +548,8 @@ struct Dual return *this; } - template || isExpr>...> - Dual& operator*=(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual& operator*=(U&& other) { Dual tmp; assign(tmp, std::forward(other)); @@ -567,8 +557,8 @@ struct Dual return *this; } - template || isExpr>...> - Dual& operator/=(U&& other) + template || isArithmetic> = true> + AUTODIFF_DEVICE_FUNC Dual& operator/=(U&& other) { Dual tmp; assign(tmp, std::forward(other)); @@ -578,15 +568,15 @@ struct Dual /// Convert this Dual number into a value of type @p U. #if defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION_DUAL) || defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION) - operator T() const { return val; } + AUTODIFF_DEVICE_FUNC operator T() const { return val; } template - operator U() const { return static_cast(val); } + AUTODIFF_DEVICE_FUNC operator U() const { return static_cast(val); } #else - explicit operator T() const { return val; } + AUTODIFF_DEVICE_FUNC explicit operator T() const { return val; } template - explicit operator U() const { return static_cast(val); } + AUTODIFF_DEVICE_FUNC explicit operator U() const { return static_cast(val); } #endif }; @@ -612,19 +602,19 @@ struct TernaryExpr }; template -auto inner(const UnaryExpr& expr) -> const R +AUTODIFF_DEVICE_FUNC auto inner(const UnaryExpr& expr) -> const R { return expr.r; } template -auto left(const BinaryExpr& expr) -> const L +AUTODIFF_DEVICE_FUNC auto left(const BinaryExpr& expr) -> const L { return expr.l; } template -auto right(const BinaryExpr& expr) -> const R +AUTODIFF_DEVICE_FUNC auto right(const BinaryExpr& expr) -> const R { return expr.r; } @@ -636,7 +626,7 @@ auto right(const BinaryExpr& expr) -> const R //===================================================================================================================== template -auto eval(T&& expr) +AUTODIFF_DEVICE_FUNC constexpr auto eval(T&& expr) { static_assert(isDual || isExpr || isArithmetic); if constexpr (isDual) @@ -647,7 +637,7 @@ auto eval(T&& expr) } template -auto val(T&& expr) +AUTODIFF_DEVICE_FUNC constexpr auto val(T&& expr) { static_assert(isDual || isExpr || isArithmetic); if constexpr (isDual) @@ -664,7 +654,7 @@ auto val(T&& expr) //===================================================================================================================== template -auto derivative(const Dual& dual) +AUTODIFF_DEVICE_FUNC auto derivative(const Dual& dual) { if constexpr (order == 0) return val(dual.val); @@ -681,7 +671,7 @@ auto derivative(const Dual& dual) /// Traverse down along the `val` branch until depth `order` is reached, then return the `grad` node. template -auto& gradnode(Dual& dual) +AUTODIFF_DEVICE_FUNC auto& gradnode(Dual& dual) { constexpr auto N = Order>; static_assert(order <= N); @@ -692,9 +682,9 @@ auto& gradnode(Dual& dual) /// Set the `grad` node of a dual number along the `val` branch at a depth `order`. template -auto seed(Dual& dual, U&& seedval) +AUTODIFF_DEVICE_FUNC auto seed(Dual& dual, U&& seedval) { - gradnode(dual) = static_cast>(seedval); + gradnode(dual) = static_cast(dual))>>(seedval); } //===================================================================================================================== @@ -715,7 +705,7 @@ using PreventExprRef = ConditionalType, T, PlainType>; // NEGATIVE EXPRESSION GENERATOR FUNCTION //----------------------------------------------------------------------------- template -constexpr auto negative(U&& expr) +AUTODIFF_DEVICE_FUNC constexpr auto negative(U&& expr) { static_assert(isExpr || isArithmetic); if constexpr (isNegExpr) @@ -727,7 +717,7 @@ constexpr auto negative(U&& expr) // INVERSE EXPRESSION GENERATOR FUNCTION //----------------------------------------------------------------------------- template -constexpr auto inverse(U&& expr) +AUTODIFF_DEVICE_FUNC constexpr auto inverse(U&& expr) { static_assert(isExpr); if constexpr (isInvExpr) @@ -739,10 +729,10 @@ constexpr auto inverse(U&& expr) // AUXILIARY CONSTEXPR CONSTANTS //----------------------------------------------------------------------------- template -constexpr auto Zero() { return static_cast>(0); } +AUTODIFF_DEVICE_FUNC constexpr auto Zero() { return static_cast>(0); } template -constexpr auto One() { return static_cast>(1); } +AUTODIFF_DEVICE_FUNC constexpr auto One() { return static_cast>(1); } //===================================================================================================================== // @@ -753,8 +743,8 @@ constexpr auto One() { return static_cast>(1); } //----------------------------------------------------------------------------- // POSITIVE OPERATOR: +x //----------------------------------------------------------------------------- -template>...> -constexpr auto operator+(R&& expr) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator+(R&& expr) { return std::forward(expr); // expression optimization: +(expr) => expr } @@ -765,8 +755,8 @@ constexpr auto operator+(R&& expr) // //===================================================================================================================== -template>...> -constexpr auto operator-(R&& expr) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator-(R&& expr) { // NEGATIVE EXPRESSION CASE: -(-x) => x when expr is (-x) if constexpr (isNegExpr) @@ -784,8 +774,8 @@ constexpr auto operator-(R&& expr) // //===================================================================================================================== -template>...> -constexpr auto operator+(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator+(L&& l, R&& r) { // ADDITION EXPRESSION CASE: (-x) + (-y) => -(x + y) if constexpr (isNegExpr && isNegExpr) @@ -803,8 +793,8 @@ constexpr auto operator+(L&& l, R&& r) // //===================================================================================================================== -template>...> -constexpr auto operator*(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator*(L&& l, R&& r) { // MULTIPLICATION EXPRESSION CASE: (-expr) * (-expr) => expr * expr if constexpr (isNegExpr && isNegExpr) @@ -837,8 +827,8 @@ constexpr auto operator*(L&& l, R&& r) //----------------------------------------------------------------------------- // SUBTRACTION OPERATOR: expr - expr, scalar - expr, expr - scalar //----------------------------------------------------------------------------- -template>...> -constexpr auto operator-(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator-(L&& l, R&& r) { return std::forward(l) + ( -std::forward(r) ); } @@ -852,8 +842,8 @@ constexpr auto operator-(L&& l, R&& r) //----------------------------------------------------------------------------- // DIVISION OPERATOR: expr / expr //----------------------------------------------------------------------------- -template>...> -constexpr auto operator/(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto operator/(L&& l, R&& r) { if constexpr (isArithmetic) return std::forward(l) * (One() / std::forward(r)); @@ -866,16 +856,16 @@ constexpr auto operator/(L&& l, R&& r) // //===================================================================================================================== -template>...> constexpr auto sin(R&& r) -> SinExpr { return { r }; } -template>...> constexpr auto cos(R&& r) -> CosExpr { return { r }; } -template>...> constexpr auto tan(R&& r) -> TanExpr { return { r }; } -template>...> constexpr auto asin(R&& r) -> ArcSinExpr { return { r }; } -template>...> constexpr auto acos(R&& r) -> ArcCosExpr { return { r }; } -template>...> constexpr auto atan(R&& r) -> ArcTanExpr { return { r }; } -template>...> constexpr auto atan2(L&& l, R&& r) -> ArcTan2Expr { return { l, r }; } -template>...> constexpr auto hypot(L&& l, R&& r) -> Hypot2Expr { return { l, r }; } -template>...> - constexpr auto hypot(L&& l, C&& c, R&& r) -> Hypot3Expr { return { l, c, r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto sin(R&& r) -> SinExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto cos(R&& r) -> CosExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto tan(R&& r) -> TanExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto asin(R&& r) -> ArcSinExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto acos(R&& r) -> ArcCosExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto atan(R&& r) -> ArcTanExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto atan2(L&& l, R&& r) -> ArcTan2Expr { return { l, r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto hypot(L&& l, R&& r) -> Hypot2Expr { return { l, r }; } +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto hypot(L&& l, C&& c, R&& r) -> Hypot3Expr { return { l, c, r }; } //===================================================================================================================== // @@ -883,9 +873,9 @@ template>...> // //===================================================================================================================== -template>...> constexpr auto sinh(R&& r) -> SinhExpr { return { r }; } -template>...> constexpr auto cosh(R&& r) -> CoshExpr { return { r }; } -template>...> constexpr auto tanh(R&& r) -> TanhExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto sinh(R&& r) -> SinhExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto cosh(R&& r) -> CoshExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto tanh(R&& r) -> TanhExpr { return { r }; } //===================================================================================================================== // @@ -893,9 +883,9 @@ template>...> constexpr auto tanh(R&& r) -> TanhE // //===================================================================================================================== -template>...> constexpr auto exp(R&& r) -> ExpExpr { return { r }; } -template>...> constexpr auto log(R&& r) -> LogExpr { return { r }; } -template>...> constexpr auto log10(R&& r) -> Log10Expr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto exp(R&& r) -> ExpExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto log(R&& r) -> LogExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto log10(R&& r) -> Log10Expr { return { r }; } //===================================================================================================================== // @@ -903,8 +893,8 @@ template>...> constexpr auto log10(R&& r) -> Log1 // //===================================================================================================================== -template>...> constexpr auto pow(L&& l, R&& r) -> PowExpr { return { l, r }; } -template>...> constexpr auto sqrt(R&& r) -> SqrtExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto pow(L&& l, R&& r) -> PowExpr { return { l, r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto sqrt(R&& r) -> SqrtExpr { return { r }; } //===================================================================================================================== // @@ -912,23 +902,23 @@ template>...> constexpr auto sqrt(R&& r) -> SqrtE // //===================================================================================================================== -template>...> constexpr auto abs(R&& r) -> AbsExpr { return { r }; } -template>...> constexpr auto abs2(R&& r) { return std::forward(r) * std::forward(r); } -template>...> constexpr auto conj(R&& r) { return std::forward(r); } -template>...> constexpr auto real(R&& r) { return std::forward(r); } -template>...> constexpr auto imag(R&&) { return 0.0; } -template>...> constexpr auto erf(R&& r) -> ErfExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto abs(R&& r) -> AbsExpr { return { r }; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto abs2(R&& r) { return std::forward(r) * std::forward(r); } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto conj(R&& r) { return std::forward(r); } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto real(R&& r) { return std::forward(r); } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto imag(R&&) { return 0.0; } +template> = true> AUTODIFF_DEVICE_FUNC constexpr auto erf(R&& r) -> ErfExpr { return { r }; } -template>...> -constexpr auto min(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto min(L&& l, R&& r) { const auto x = eval(l); const auto y = eval(r); return (x <= y) ? x : y; } -template>...> -constexpr auto max(L&& l, R&& r) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto max(L&& l, R&& r) { const auto x = eval(l); const auto y = eval(r); @@ -941,12 +931,12 @@ constexpr auto max(L&& l, R&& r) // //===================================================================================================================== -template>...> bool operator==(L&& l, R&& r) { return val(l) == val(r); } -template>...> bool operator!=(L&& l, R&& r) { return val(l) != val(r); } -template>...> bool operator<=(L&& l, R&& r) { return val(l) <= val(r); } -template>...> bool operator>=(L&& l, R&& r) { return val(l) >= val(r); } -template>...> bool operator<(L&& l, R&& r) { return val(l) < val(r); } -template>...> bool operator>(L&& l, R&& r) { return val(l) > val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator==(L&& l, R&& r) { return val(l) == val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator!=(L&& l, R&& r) { return val(l) != val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator<=(L&& l, R&& r) { return val(l) <= val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator>=(L&& l, R&& r) { return val(l) >= val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator<(L&& l, R&& r) { return val(l) < val(r); } +template> = true> AUTODIFF_DEVICE_FUNC bool operator>(L&& l, R&& r) { return val(l) > val(r); } //===================================================================================================================== // @@ -954,14 +944,14 @@ template>...> bool operator>(L // //===================================================================================================================== template -constexpr void negate(Dual& self) +AUTODIFF_DEVICE_FUNC constexpr void negate(Dual& self) { self.val = -self.val; self.grad = -self.grad; } template -constexpr void scale(Dual& self, const U& scalar) +AUTODIFF_DEVICE_FUNC constexpr void scale(Dual& self, const U& scalar) { self.val *= scalar; self.grad *= scalar; @@ -974,7 +964,7 @@ constexpr void scale(Dual& self, const U& scalar) //===================================================================================================================== template -constexpr void apply(Dual& self); +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self); //===================================================================================================================== // @@ -983,14 +973,14 @@ constexpr void apply(Dual& self); //===================================================================================================================== template -constexpr void assign(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assign(Dual& self, U&& other) { static_assert(isExpr || isArithmetic); // ASSIGN A NUMBER: self = number if constexpr (isArithmetic) { self.val = other; - self.grad = Zero(); + self.grad = Zero(); } // ASSIGN A DUAL NUMBER: self = dual else if constexpr (isDual) { @@ -1040,7 +1030,7 @@ constexpr void assign(Dual& self, U&& other) } template -constexpr void assign(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assign(Dual& self, U&& other, Dual& tmp) { static_assert(isExpr || isArithmetic); @@ -1079,7 +1069,7 @@ constexpr void assign(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignAdd(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual& self, U&& other) { static_assert(isExpr || isArithmetic); @@ -1114,7 +1104,7 @@ constexpr void assignAdd(Dual& self, U&& other) } template -constexpr void assignAdd(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual& self, U&& other, Dual& tmp) { static_assert(isExpr || isArithmetic); @@ -1141,7 +1131,7 @@ constexpr void assignAdd(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignSub(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual& self, U&& other) { static_assert(isExpr || isArithmetic); @@ -1176,7 +1166,7 @@ constexpr void assignSub(Dual& self, U&& other) } template -constexpr void assignSub(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual& self, U&& other, Dual& tmp) { static_assert(isExpr || isArithmetic); @@ -1203,7 +1193,7 @@ constexpr void assignSub(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignMul(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual& self, U&& other) { static_assert(isExpr || isArithmetic); @@ -1242,7 +1232,7 @@ constexpr void assignMul(Dual& self, U&& other) } template -constexpr void assignMul(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual& self, U&& other, Dual& tmp) { static_assert(isExpr || isArithmetic); @@ -1270,7 +1260,7 @@ constexpr void assignMul(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignDiv(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual& self, U&& other) { static_assert(isExpr || isArithmetic); @@ -1313,7 +1303,7 @@ constexpr void assignDiv(Dual& self, U&& other) } template -constexpr void assignDiv(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual& self, U&& other, Dual& tmp) { static_assert(isExpr || isArithmetic); @@ -1345,7 +1335,7 @@ constexpr void assignDiv(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignPow(Dual& self, U&& other) +AUTODIFF_DEVICE_FUNC constexpr void assignPow(Dual& self, U&& other) { // ASSIGN-POW A NUMBER: self = pow(self, number) if constexpr (isArithmetic) { @@ -1370,7 +1360,7 @@ constexpr void assignPow(Dual& self, U&& other) } template -constexpr void assignPow(Dual& self, U&& other, Dual& tmp) +AUTODIFF_DEVICE_FUNC constexpr void assignPow(Dual& self, U&& other, Dual& tmp) { assign(tmp, other); assignPow(self, tmp); @@ -1383,7 +1373,7 @@ constexpr void assignPow(Dual& self, U&& other, Dual& tmp) //===================================================================================================================== template -constexpr void assignArcTan2(Dual& self, Y&&y, X&&x) +AUTODIFF_DEVICE_FUNC constexpr void assignArcTan2(Dual& self, Y&&y, X&&x) { static_assert(isArithmetic || isExpr); static_assert(isArithmetic || isExpr); @@ -1428,7 +1418,7 @@ constexpr void assignArcTan2(Dual& self, Y&&y, X&&x) //===================================================================================================================== template -constexpr void assignHypot2(Dual& self, X&& x, Y&& y) +AUTODIFF_DEVICE_FUNC constexpr void assignHypot2(Dual& self, X&& x, Y&& y) { static_assert(isArithmetic || isExpr); static_assert(isArithmetic || isExpr); @@ -1466,8 +1456,25 @@ constexpr void assignHypot2(Dual& self, X&& x, Y&& y) } } +template +AUTODIFF_DEVICE_FUNC inline T hypot_device_func(T x, T y, T z) +{ +#ifdef __CUDA_ARCH__ + x = std::abs(x); + y = std::abs(y); + z = std::abs(z); + if(T a = x < y ? y < z ? z : y : x < z ? z + : x) + return a * std::sqrt((x / a) * (x / a) + (y / a) * (y / a) + (z / a) * (z / a)); + else + return {}; +#else + return hypot(x, y, z); +#endif +} + template -constexpr void assignHypot3(Dual& self, X&& x, Y&& y, Z&& z) +AUTODIFF_DEVICE_FUNC constexpr void assignHypot3(Dual& self, X&& x, Y&& y, Z&& z) { static_assert(isArithmetic || isExpr); static_assert(isArithmetic || isExpr); @@ -1475,43 +1482,43 @@ constexpr void assignHypot3(Dual& self, X&& x, Y&& y, Z&& z) // self = hypot(dual, number, number) if constexpr (isDual && isArithmetic && isArithmetic) { - self.val = hypot(x.val, y, z); + self.val = hypot_device_func(x.val, y, z); self.grad = x.val / self.val * x.grad; } // self = hypot(number, dual, number) else if constexpr (isArithmetic && isDual && isArithmetic) { - self.val = hypot(x, y.val, z); + self.val = hypot_device_func(x, y.val, z); self.grad = y.val / self.val * y.grad; } // self = hypot(number, number, dual) else if constexpr (isArithmetic && isArithmetic && isDual) { - self.val = hypot(x, y, z.val); + self.val = hypot_device_func(x, y, z.val); self.grad = z.val / self.val * z.grad; } // self = hypot(dual, dual, number) else if constexpr (isDual && isDual && isArithmetic) { - self.val = hypot(x.val, y.val, z); + self.val = hypot_device_func(x.val, y.val, z); self.grad = (x.grad * x.val + y.grad * y.val ) / self.val; } // self = hypot(number, dual, dual) else if constexpr (isArithmetic && isDual && isDual) { - self.val = hypot(x, y.val, z.val); + self.val = hypot_device_func(x, y.val, z.val); self.grad = (y.grad * y.val + z.grad * z.val) / self.val; } // self = hypot(dual, number, dual) else if constexpr (isDual && isArithmetic && isDual) { - self.val = hypot(x.val, y, z.val); + self.val = hypot_device_func(x.val, y, z.val); self.grad = (x.grad * x.val + z.grad * z.val) / self.val; } // self = hypot(dual, dual, dual) else if constexpr (isDual && isDual && isDual) { - self.val = hypot(x.val, y.val, z.val); + self.val = hypot_device_func(x.val, y.val, z.val); self.grad = (x.grad * x.val + y.grad * y.val + z.grad * z.val) / self.val; } @@ -1543,35 +1550,35 @@ constexpr void assignHypot3(Dual& self, X&& x, Y&& y, Z&& z) // //===================================================================================================================== template -constexpr void apply(Dual& self, NegOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, NegOp) { self.val = -self.val; self.grad = -self.grad; } template -constexpr void apply(Dual& self, InvOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, InvOp) { self.val = One() / self.val; self.grad *= - self.val * self.val; } template -constexpr void apply(Dual& self, SinOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, SinOp) { self.grad *= cos(self.val); self.val = sin(self.val); } template -constexpr void apply(Dual& self, CosOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, CosOp) { self.grad *= -sin(self.val); self.val = cos(self.val); } template -constexpr void apply(Dual& self, TanOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, TanOp) { const T aux = One() / cos(self.val); self.val = tan(self.val); @@ -1579,21 +1586,21 @@ constexpr void apply(Dual& self, TanOp) } template -constexpr void apply(Dual& self, SinhOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, SinhOp) { self.grad *= cosh(self.val); self.val = sinh(self.val); } template -constexpr void apply(Dual& self, CoshOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, CoshOp) { self.grad *= sinh(self.val); self.val = cosh(self.val); } template -constexpr void apply(Dual& self, TanhOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, TanhOp) { const T aux = One() / cosh(self.val); self.val = tanh(self.val); @@ -1601,7 +1608,7 @@ constexpr void apply(Dual& self, TanhOp) } template -constexpr void apply(Dual& self, ArcSinOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, ArcSinOp) { const T aux = One() / sqrt(1.0 - self.val * self.val); self.val = asin(self.val); @@ -1609,7 +1616,7 @@ constexpr void apply(Dual& self, ArcSinOp) } template -constexpr void apply(Dual& self, ArcCosOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, ArcCosOp) { const T aux = -One() / sqrt(1.0 - self.val * self.val); self.val = acos(self.val); @@ -1617,7 +1624,7 @@ constexpr void apply(Dual& self, ArcCosOp) } template -constexpr void apply(Dual& self, ArcTanOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, ArcTanOp) { const T aux = One() / (1.0 + self.val * self.val); self.val = atan(self.val); @@ -1625,14 +1632,14 @@ constexpr void apply(Dual& self, ArcTanOp) } template -constexpr void apply(Dual& self, ExpOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, ExpOp) { self.val = exp(self.val); self.grad *= self.val; } template -constexpr void apply(Dual& self, LogOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, LogOp) { const T aux = One() / self.val; self.val = log(self.val); @@ -1640,7 +1647,7 @@ constexpr void apply(Dual& self, LogOp) } template -constexpr void apply(Dual& self, Log10Op) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, Log10Op) { constexpr NumericType ln10 = 2.3025850929940456840179914546843; const T aux = One() / (ln10 * self.val); @@ -1649,21 +1656,21 @@ constexpr void apply(Dual& self, Log10Op) } template -constexpr void apply(Dual& self, SqrtOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, SqrtOp) { self.val = sqrt(self.val); self.grad *= 0.5 / self.val; } template -constexpr void apply(Dual& self, AbsOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, AbsOp) { self.grad *= self.val < T(0) ? G(-1) : (self.val > T(0) ? G(1) : G(0)); self.val = abs(self.val); } template -constexpr void apply(Dual& self, ErfOp) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self, ErfOp) { constexpr NumericType sqrt_pi = 1.7724538509055160272981674833411451872554456638435; const T aux = self.val; @@ -1672,7 +1679,7 @@ constexpr void apply(Dual& self, ErfOp) } template -constexpr void apply(Dual& self) +AUTODIFF_DEVICE_FUNC constexpr void apply(Dual& self) { apply(self, Op{}); } @@ -1684,7 +1691,7 @@ std::ostream& operator<<(std::ostream& out, const Dual& x) return out; } -template>...> +template> = true> auto reprAux(const T& x) { std::stringstream ss; ss << x; diff --git a/inst/include/autodiff/forward/dual/eigen.hpp b/inst/include/autodiff/forward/dual/eigen.hpp index a706600..3b8e579 100644 --- a/inst/include/autodiff/forward/dual/eigen.hpp +++ b/inst/include/autodiff/forward/dual/eigen.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/forward/real.hpp b/inst/include/autodiff/forward/real.hpp index 22c9dfb..9e24b45 100644 --- a/inst/include/autodiff/forward/real.hpp +++ b/inst/include/autodiff/forward/real.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/forward/real/eigen.hpp b/inst/include/autodiff/forward/real/eigen.hpp index e6f5fba..a01b36c 100644 --- a/inst/include/autodiff/forward/real/eigen.hpp +++ b/inst/include/autodiff/forward/real/eigen.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/forward/real/real.hpp b/inst/include/autodiff/forward/real/real.hpp index 483ccb2..d283f61 100644 --- a/inst/include/autodiff/forward/real/real.hpp +++ b/inst/include/autodiff/forward/real/real.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and `associated documentation files (the "Software"), to deal @@ -59,23 +59,23 @@ class Real public: /// Construct a default Real number of order *N* and type *T*. - constexpr Real() + AUTODIFF_DEVICE_FUNC constexpr Real() {} /// Construct a Real number with given data. - constexpr Real(const T& value) + AUTODIFF_DEVICE_FUNC constexpr Real(const T& value) { m_data[0] = value; } /// Construct a Real number with given data. - constexpr Real(const std::array& data) + AUTODIFF_DEVICE_FUNC constexpr Real(const std::array& data) : m_data(data) {} /// Construct a Real number with given data. - template>...> - constexpr explicit Real(const Real& other) + template> = true> + AUTODIFF_DEVICE_FUNC constexpr explicit Real(const Real& other) { static_assert(N <= M); For<0, N + 1>([&](auto i) constexpr { @@ -84,84 +84,84 @@ class Real } /// Return the value of the Real number. - constexpr auto val() -> T& + AUTODIFF_DEVICE_FUNC constexpr auto val() -> T& { return m_data[0]; } /// Return the value of the Real number. - constexpr auto val() const -> const T& + AUTODIFF_DEVICE_FUNC constexpr auto val() const -> const T& { return m_data[0]; } - constexpr auto operator[](size_t i) -> T& + AUTODIFF_DEVICE_FUNC constexpr auto operator[](size_t i) -> T& { return m_data[i]; } - constexpr auto operator[](size_t i) const -> const T& + AUTODIFF_DEVICE_FUNC constexpr auto operator[](size_t i) const -> const T& { return m_data[i]; } - template>...> - constexpr auto operator=(const U& value) -> Real& + template> = true> + AUTODIFF_DEVICE_FUNC constexpr auto operator=(const U& value) -> Real& { m_data[0] = value; For<1, N + 1>([&](auto i) constexpr { m_data[i] = T{}; }); return *this; } - constexpr auto operator=(const std::array& data) + AUTODIFF_DEVICE_FUNC constexpr auto operator=(const std::array& data) { m_data = data; return *this; } - template>...> - constexpr auto operator+=(const U& value) -> Real& + template> = true> + AUTODIFF_DEVICE_FUNC constexpr auto operator+=(const U& value) -> Real& { m_data[0] += static_cast(value); return *this; } - template>...> - constexpr auto operator-=(const U& value) -> Real& + template> = true> + AUTODIFF_DEVICE_FUNC constexpr auto operator-=(const U& value) -> Real& { m_data[0] -= static_cast(value); return *this; } - template>...> - constexpr auto operator*=(const U& value) -> Real& + template> = true> + AUTODIFF_DEVICE_FUNC constexpr auto operator*=(const U& value) -> Real& { For<0, N + 1>([&](auto i) constexpr { m_data[i] *= static_cast(value); }); return *this; } - template>...> - constexpr auto operator/=(const U& value) -> Real& + template> = true> + AUTODIFF_DEVICE_FUNC constexpr auto operator/=(const U& value) -> Real& { For<0, N + 1>([&](auto i) constexpr { m_data[i] /= static_cast(value); }); return *this; } - constexpr auto operator+=(const Real& y) + AUTODIFF_DEVICE_FUNC constexpr auto operator+=(const Real& y) { auto& x = *this; For<0, N + 1>([&](auto i) constexpr { x[i] += y[i]; }); return *this; } - constexpr auto operator-=(const Real& y) + AUTODIFF_DEVICE_FUNC constexpr auto operator-=(const Real& y) { auto& x = *this; For<0, N + 1>([&](auto i) constexpr { x[i] -= y[i]; }); return *this; } - constexpr auto operator*=(const Real& y) + AUTODIFF_DEVICE_FUNC constexpr auto operator*=(const Real& y) { auto& x = *this; ReverseFor([&](auto i) constexpr { @@ -173,7 +173,7 @@ class Real return *this; } - constexpr auto operator/=(const Real& y) + AUTODIFF_DEVICE_FUNC constexpr auto operator/=(const Real& y) { auto& x = *this; For([&](auto i) constexpr { @@ -186,13 +186,16 @@ class Real return *this; } - /// Convert this Real number into a value of type @p U. #if defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION_REAL) || defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION) - template>...> - constexpr operator U() const { return static_cast(m_data[0]); } + AUTODIFF_DEVICE_FUNC constexpr operator T() const { return static_cast(m_data[0]); } + + template> = true> + AUTODIFF_DEVICE_FUNC constexpr operator U() const { return static_cast(m_data[0]); } #else - template>...> - constexpr explicit operator U() const { return static_cast(m_data[0]); } + AUTODIFF_DEVICE_FUNC constexpr explicit operator T() const { return static_cast(m_data[0]); } + + template> = true> + AUTODIFF_DEVICE_FUNC constexpr explicit operator U() const { return static_cast(m_data[0]); } #endif }; @@ -250,13 +253,13 @@ constexpr bool areReal = (... && isReal); //===================================================================================================================== template -auto operator+(const Real& x) +AUTODIFF_DEVICE_FUNC auto operator+(const Real& x) { return x; } template -auto operator-(const Real& x) +AUTODIFF_DEVICE_FUNC auto operator-(const Real& x) { Real res; For<0, N + 1>([&](auto i) constexpr { res[i] = -x[i]; }); @@ -270,19 +273,19 @@ auto operator-(const Real& x) //===================================================================================================================== template -auto operator+(Real x, const Real& y) +AUTODIFF_DEVICE_FUNC auto operator+(Real x, const Real& y) { return x += y; } -template>...> -auto operator+(Real x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator+(Real x, const U& y) { return x += y; } -template>...> -auto operator+(const U& x, Real y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator+(const U& x, Real y) { return y += x; } @@ -293,19 +296,19 @@ auto operator+(const U& x, Real y) // //===================================================================================================================== template -auto operator-(Real x, const Real& y) +AUTODIFF_DEVICE_FUNC auto operator-(Real x, const Real& y) { return x -= y; } -template>...> -auto operator-(Real x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator-(Real x, const U& y) { return x -= y; } -template>...> -auto operator-(const U& x, Real y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator-(const U& x, Real y) { y -= x; y *= -static_cast(1.0); @@ -319,19 +322,19 @@ auto operator-(const U& x, Real y) //===================================================================================================================== template -auto operator*(Real x, const Real& y) +AUTODIFF_DEVICE_FUNC auto operator*(Real x, const Real& y) { return x *= y; } -template>...> -auto operator*(Real x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator*(Real x, const U& y) { return x *= y; } -template>...> -auto operator*(const U& x, Real y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator*(const U& x, Real y) { return y *= x; } @@ -343,19 +346,19 @@ auto operator*(const U& x, Real y) //===================================================================================================================== template -auto operator/(Real x, const Real& y) +AUTODIFF_DEVICE_FUNC auto operator/(Real x, const Real& y) { return x /= y; } -template>...> -auto operator/(Real x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator/(Real x, const U& y) { return x /= y; } -template>...> -auto operator/(const U& x, Real y) +template> = true> +AUTODIFF_DEVICE_FUNC auto operator/(const U& x, Real y) { Real z = x; return z /= y; @@ -368,7 +371,7 @@ auto operator/(const U& x, Real y) //===================================================================================================================== template -constexpr auto exp(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto exp(const Real& x) { Real expx; expx[0] = exp(x[0]); @@ -382,7 +385,7 @@ constexpr auto exp(const Real& x) } template -constexpr auto log(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto log(const Real& x) { assert(x[0] != 0 && "autodiff::log(x) has undefined value and derivatives when x = 0"); Real logx; @@ -398,7 +401,7 @@ constexpr auto log(const Real& x) } template -constexpr auto log10(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto log10(const Real& x) { assert(x[0] != 0 && "autodiff::log10(x) has undefined value and derivatives when x = 0"); const auto ln10 = 2.302585092994046; @@ -407,7 +410,7 @@ constexpr auto log10(const Real& x) } template -constexpr auto sqrt(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto sqrt(const Real& x) { Real res; res[0] = sqrt(x[0]); @@ -434,7 +437,7 @@ constexpr auto sqrt(const Real& x) } template -constexpr auto cbrt(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto cbrt(const Real& x) { Real res; res[0] = cbrt(x[0]); @@ -461,7 +464,7 @@ constexpr auto cbrt(const Real& x) } template -constexpr auto pow(const Real& x, const Real& y) +AUTODIFF_DEVICE_FUNC constexpr auto pow(const Real& x, const Real& y) { Real res; res[0] = pow(x[0], y[0]); @@ -486,8 +489,8 @@ constexpr auto pow(const Real& x, const Real& y) return res; } -template>...> -constexpr auto pow(const Real& x, const U& c) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto pow(const Real& x, const U& c) { Real res; res[0] = pow(x[0], static_cast(c)); @@ -506,8 +509,8 @@ constexpr auto pow(const Real& x, const U& c) return res; } -template>...> -constexpr auto pow(const U& c, const Real& y) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto pow(const U& c, const Real& y) { Real res; res[0] = pow(static_cast(c), y[0]); @@ -533,7 +536,7 @@ constexpr auto pow(const U& c, const Real& y) //===================================================================================================================== template -auto sincos(const Real& x) -> std::tuple, Real> +AUTODIFF_DEVICE_FUNC auto sincos(const Real& x) -> std::tuple, Real> { Real sinx; Real cosx; @@ -557,19 +560,19 @@ auto sincos(const Real& x) -> std::tuple, Real> } template -auto sin(const Real& x) +AUTODIFF_DEVICE_FUNC auto sin(const Real& x) { return std::get<0>(sincos(x)); } template -auto cos(const Real& x) +AUTODIFF_DEVICE_FUNC auto cos(const Real& x) { return std::get<1>(sincos(x)); } template -auto tan(const Real& x) +AUTODIFF_DEVICE_FUNC auto tan(const Real& x) { Real tanx; tanx[0] = tan(x[0]); @@ -596,7 +599,7 @@ auto tan(const Real& x) } template -constexpr auto asin(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto asin(const Real& x) { Real res; res[0] = asin(x[0]); @@ -617,7 +620,7 @@ constexpr auto asin(const Real& x) } template -constexpr auto acos(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto acos(const Real& x) { Real res; res[0] = acos(x[0]); @@ -638,7 +641,7 @@ constexpr auto acos(const Real& x) } template -constexpr auto atan(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto atan(const Real& x) { Real res; res[0] = atan(x[0]); @@ -657,8 +660,8 @@ constexpr auto atan(const Real& x) return res; } -template>...> -constexpr auto atan2(const U& c, const Real& x) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto atan2(const U& c, const Real& x) { // d[atan2(c,x)]/dx = -c / (c^2 + x^2) Real res; @@ -677,8 +680,8 @@ constexpr auto atan2(const U& c, const Real& x) return res; } -template>...> -constexpr auto atan2(const Real& y, const U& c) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto atan2(const Real& y, const U& c) { // d[atan2(y,c)]/dy = c / (c^2 + y^2) Real res; @@ -698,7 +701,7 @@ constexpr auto atan2(const Real& y, const U& c) } template -constexpr auto atan2(const Real& y, const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto atan2(const Real& y, const Real& x) { Real res; res[0] = atan2(y[0], x[0]); @@ -718,7 +721,7 @@ constexpr auto atan2(const Real& y, const Real& x) //===================================================================================================================== template -auto sinhcosh(const Real& x) -> std::tuple, Real> +AUTODIFF_DEVICE_FUNC auto sinhcosh(const Real& x) -> std::tuple, Real> { Real sinhx; Real coshx; @@ -742,20 +745,20 @@ auto sinhcosh(const Real& x) -> std::tuple, Real> } template -auto sinh(const Real& x) +AUTODIFF_DEVICE_FUNC auto sinh(const Real& x) { return std::get<0>(sinhcosh(x)); } template -auto cosh(const Real& x) +AUTODIFF_DEVICE_FUNC auto cosh(const Real& x) { return std::get<1>(sinhcosh(x)); } template -auto tanh(const Real& x) +AUTODIFF_DEVICE_FUNC auto tanh(const Real& x) { Real tanhx; tanhx[0] = tanh(x[0]); @@ -783,7 +786,7 @@ auto tanh(const Real& x) } template -constexpr auto asinh(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto asinh(const Real& x) { Real res; res[0] = asinh(x[0]); @@ -799,7 +802,7 @@ constexpr auto asinh(const Real& x) } template -constexpr auto acosh(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto acosh(const Real& x) { Real res; res[0] = acosh(x[0]); @@ -816,7 +819,7 @@ constexpr auto acosh(const Real& x) } template -constexpr auto atanh(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto atanh(const Real& x) { Real res; res[0] = atanh(x[0]); @@ -839,7 +842,7 @@ constexpr auto atanh(const Real& x) //===================================================================================================================== template -constexpr auto abs(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto abs(const Real& x) { Real res; res[0] = std::abs(x[0]); @@ -856,39 +859,39 @@ constexpr auto abs(const Real& x) } template -constexpr auto min(const Real& x, const Real& y) +AUTODIFF_DEVICE_FUNC constexpr auto min(const Real& x, const Real& y) { return (x[0] <= y[0]) ? x : y; } -template>...> -constexpr auto min(const Real& x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto min(const Real& x, const U& y) { - return (x[0] <= y) ? x : y; + return (x[0] <= y) ? x : Real{y}; } -template>...> -constexpr auto min(const U& x, const Real& y) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto min(const U& x, const Real& y) { - return (x < y[0]) ? x : y; + return (x < y[0]) ? Real{x} : y; } template -constexpr auto max(const Real& x, const Real& y) +AUTODIFF_DEVICE_FUNC constexpr auto max(const Real& x, const Real& y) { return (x[0] >= y[0]) ? x : y; } -template>...> -constexpr auto max(const Real& x, const U& y) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto max(const Real& x, const U& y) { - return (x[0] >= y) ? x : y; + return (x[0] >= y) ? x : Real{y}; } -template>...> -constexpr auto max(const U& x, const Real& y) +template> = true> +AUTODIFF_DEVICE_FUNC constexpr auto max(const U& x, const Real& y) { - return (x > y[0]) ? x : y; + return (x > y[0]) ? Real{x} : y; } //===================================================================================================================== @@ -922,7 +925,7 @@ auto repr(const Real& x) //===================================================================================================================== template -bool operator==(const Real& x, const Real& y) +AUTODIFF_DEVICE_FUNC bool operator==(const Real& x, const Real& y) { bool res = true; For<0, N + 1>([&](auto i) constexpr { @@ -931,25 +934,25 @@ bool operator==(const Real& x, const Real& y) return res; } -template bool operator!=(const Real& x, const Real& y) { return !(x == y); } -template bool operator< (const Real& x, const Real& y) { return x[0] < y[0]; } -template bool operator> (const Real& x, const Real& y) { return x[0] > y[0]; } -template bool operator<=(const Real& x, const Real& y) { return x[0] <= y[0]; } -template bool operator>=(const Real& x, const Real& y) { return x[0] >= y[0]; } - -template>...> bool operator==(const Real& x, const U& y) { return x[0] == y; } -template>...> bool operator!=(const Real& x, const U& y) { return x[0] != y; } -template>...> bool operator< (const Real& x, const U& y) { return x[0] < y; } -template>...> bool operator> (const Real& x, const U& y) { return x[0] > y; } -template>...> bool operator<=(const Real& x, const U& y) { return x[0] <= y; } -template>...> bool operator>=(const Real& x, const U& y) { return x[0] >= y; } - -template>...> bool operator==(const U& x, const Real& y) { return x == y[0]; } -template>...> bool operator!=(const U& x, const Real& y) { return x != y[0]; } -template>...> bool operator< (const U& x, const Real& y) { return x < y[0]; } -template>...> bool operator> (const U& x, const Real& y) { return x > y[0]; } -template>...> bool operator<=(const U& x, const Real& y) { return x <= y[0]; } -template>...> bool operator>=(const U& x, const Real& y) { return x >= y[0]; } +template AUTODIFF_DEVICE_FUNC bool operator!=(const Real& x, const Real& y) { return !(x == y); } +template AUTODIFF_DEVICE_FUNC bool operator< (const Real& x, const Real& y) { return x[0] < y[0]; } +template AUTODIFF_DEVICE_FUNC bool operator> (const Real& x, const Real& y) { return x[0] > y[0]; } +template AUTODIFF_DEVICE_FUNC bool operator<=(const Real& x, const Real& y) { return x[0] <= y[0]; } +template AUTODIFF_DEVICE_FUNC bool operator>=(const Real& x, const Real& y) { return x[0] >= y[0]; } + +template> = true> AUTODIFF_DEVICE_FUNC bool operator==(const Real& x, const U& y) { return x[0] == y; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator!=(const Real& x, const U& y) { return x[0] != y; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator< (const Real& x, const U& y) { return x[0] < y; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator> (const Real& x, const U& y) { return x[0] > y; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator<=(const Real& x, const U& y) { return x[0] <= y; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator>=(const Real& x, const U& y) { return x[0] >= y; } + +template> = true> AUTODIFF_DEVICE_FUNC bool operator==(const U& x, const Real& y) { return x == y[0]; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator!=(const U& x, const Real& y) { return x != y[0]; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator< (const U& x, const Real& y) { return x < y[0]; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator> (const U& x, const Real& y) { return x > y[0]; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator<=(const U& x, const Real& y) { return x <= y[0]; } +template> = true> AUTODIFF_DEVICE_FUNC bool operator>=(const U& x, const Real& y) { return x >= y[0]; } //===================================================================================================================== // @@ -958,7 +961,7 @@ template>...> bool op //===================================================================================================================== template -auto seed(Real& real, U&& seedval) +AUTODIFF_DEVICE_FUNC auto seed(Real& real, U&& seedval) { static_assert(order == 1, "Real is optimized for higher-order **directional** derivatives. " @@ -976,14 +979,14 @@ auto seed(Real& real, U&& seedval) /// Return the value of a Real number. template -constexpr auto val(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto val(const Real& x) { return x[0]; } /// Return the derivative of a Real number with given order. template -constexpr auto derivative(const Real& x) +AUTODIFF_DEVICE_FUNC constexpr auto derivative(const Real& x) { return x[order]; } diff --git a/inst/include/autodiff/forward/utils/derivative.hpp b/inst/include/autodiff/forward/utils/derivative.hpp index 77c9524..f710cde 100644 --- a/inst/include/autodiff/forward/utils/derivative.hpp +++ b/inst/include/autodiff/forward/utils/derivative.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -60,21 +60,21 @@ struct Along /// The keyword used to denote the variables *with respect to* the derivative is calculated. template -auto wrt(Args&&... args) +AUTODIFF_DEVICE_FUNC auto wrt(Args&&... args) { return Wrt{ std::forward_as_tuple(std::forward(args)...) }; } /// The keyword used to denote the variables *at* which the derivatives are calculated. template -auto at(Args&&... args) +AUTODIFF_DEVICE_FUNC auto at(Args&&... args) { return At{ std::forward_as_tuple(std::forward(args)...) }; } /// The keyword used to denote the direction vector *along* which the derivatives are calculated. template -auto along(Args&&... args) +AUTODIFF_DEVICE_FUNC auto along(Args&&... args) { return Along{ std::forward_as_tuple(std::forward(args)...) }; } @@ -89,7 +89,7 @@ auto along(Args&&... args) /// y, z, z, z)`. This automatic seeding permits derivatives `fx`, `fxy`, /// `fxyz`, `fxyzz`, and `fxyzzz` to be computed in a more convenient way. template -auto seed(const Wrt& wrt, T&& seedval) +AUTODIFF_DEVICE_FUNC auto seed(const Wrt& wrt, T&& seedval) { constexpr static auto N = Order; constexpr static auto size = 1 + sizeof...(Vars); @@ -103,19 +103,19 @@ auto seed(const Wrt& wrt, T&& seedval) } template -auto seed(const Wrt& wrt) +AUTODIFF_DEVICE_FUNC auto seed(const Wrt& wrt) { seed(wrt, 1.0); } template -auto unseed(const Wrt& wrt) +AUTODIFF_DEVICE_FUNC auto unseed(const Wrt& wrt) { seed(wrt, 0.0); } template -auto seed(const At& at, const Along& along) +AUTODIFF_DEVICE_FUNC auto seed(const At& at, const Along& along) { static_assert(sizeof...(Args) == sizeof...(Vecs)); @@ -131,7 +131,7 @@ auto seed(const At& at, const Along& along) } template -auto unseed(const At& at) +AUTODIFF_DEVICE_FUNC auto unseed(const At& at) { ForEach(at.args, [&](auto& arg) constexpr { if constexpr (isVector) { @@ -142,61 +142,86 @@ auto unseed(const At& at) }); } -template>...> -auto seed(T& x) +template> = true> +AUTODIFF_DEVICE_FUNC auto seed(T& x) { seed(x, 1.0); } -template>...> -auto unseed(T& x) +template> = true> +AUTODIFF_DEVICE_FUNC auto unseed(T& x) { seed(x, 0.0); } +#ifdef __CUDA_ARCH__ +template +AUTODIFF_DEVICE_FUNC constexpr decltype(auto) device_apply_impl(F&& f, Tuple&& t, std::index_sequence) { + return std::forward(f)(std::get(std::forward(t))...); +} + +template +AUTODIFF_DEVICE_FUNC constexpr decltype(auto) device_apply(F&& f, Tuple&& t) { + return device_apply_impl( + std::forward(f), + std::forward(t), + std::make_index_sequence>>{} + ); +} +#endif + template -auto eval(const Fun& f, const At& at, const Wrt& wrt) +AUTODIFF_DEVICE_FUNC auto eval(const Fun& f, const At& at, const Wrt& wrt) { seed(wrt); +#ifdef __CUDA_ARCH__ + auto u = device_apply(f, at.args); +#else auto u = std::apply(f, at.args); +#endif unseed(wrt); return u; } template -auto eval(const Fun& f, const At& at, const Along& along) +AUTODIFF_DEVICE_FUNC auto eval(const Fun& f, const At& at, const Along& along) { seed(at, along); +#ifdef __CUDA_ARCH__ + auto u = device_apply(f, at.args); +#else auto u = std::apply(f, at.args); +#endif unseed(at); return u; } /// Extract the derivative of given order from a vector of dual/real numbers. -template>...> -auto derivative(const Vec& u) +template> = true> +AUTODIFF_DEVICE_FUNC auto derivative(const Vec& u) { size_t len = u.size(); // the length of the vector containing dual/real numbers using NumType = decltype(u[0]); // get the type of the dual/real number using T = NumericType; // get the numeric/floating point type of the dual/real number using Res = VectorReplaceValueType; // get the type of the vector containing numeric values instead of dual/real numbers (e.g., vector becomes vector, VectorXdual becomes VectorXd, etc.) Res res(len); // create an array to store the derivatives stored inside the dual/real number - for(auto i = 0; i < len; ++i) + for(auto i = 0U; i < len; ++i) res[i] = derivative(u[i]); // get the derivative of given order from i-th dual/real number return res; } /// Alias method to `derivative(x)` where `x` is either a dual/real number or vector/array of such numbers. template -auto grad(const T& x) +AUTODIFF_DEVICE_FUNC auto grad(const T& x) { return derivative(x); } /// Unpack the derivatives from the result of an @ref eval call into an array. template -auto derivatives(const Result& result) +AUTODIFF_DEVICE_FUNC auto derivatives(const Result& result) { +#ifndef __CUDA_ARCH__ if constexpr (isVector) // check if the argument is a vector container of dual/real numbers { size_t len = result.size(); // the length of the vector containing dual/real numbers @@ -207,12 +232,13 @@ auto derivatives(const Result& result) std::array values; // create an array to store the derivatives stored inside the dual/real number For([&](auto i) constexpr { values[i].resize(len); - for(auto j = 0; j < len; ++j) + for(auto j = 0U; j < len; ++j) values[i][j] = derivative(result[j]); // get the ith derivative of the jth dual/real number }); return values; } else // result is then just a dual/real number +#endif { using T = NumericType; // get the numeric/floating point type of the dual/real result number constexpr auto N = Order; // the order of the dual/real result number @@ -225,27 +251,27 @@ auto derivatives(const Result& result) } template -auto derivatives(const Fun& f, const Wrt& wrt, const At& at) +AUTODIFF_DEVICE_FUNC auto derivatives(const Fun& f, const Wrt& wrt, const At& at) { return derivatives(eval(f, at, wrt)); } template -auto derivative(const Fun& f, const Wrt& wrt, const At& at, Result& u) +AUTODIFF_DEVICE_FUNC auto derivative(const Fun& f, const Wrt& wrt, const At& at, Result& u) { u = derivatives(f, wrt, at); return derivative(u); } template -auto derivative(const Fun& f, const Wrt& wrt, const At& at) +AUTODIFF_DEVICE_FUNC auto derivative(const Fun& f, const Wrt& wrt, const At& at) { auto u = eval(f, at, wrt); return derivative(u); } template -auto derivatives(const Fun& f, const Along& along, const At& at) +AUTODIFF_DEVICE_FUNC auto derivatives(const Fun& f, const Along& along, const At& at) { return derivatives(eval(f, at, along)); } diff --git a/inst/include/autodiff/forward/utils/gradient.hpp b/inst/include/autodiff/forward/utils/gradient.hpp index 3527077..d060806 100644 --- a/inst/include/autodiff/forward/utils/gradient.hpp +++ b/inst/include/autodiff/forward/utils/gradient.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ // autodiff includes #include #include +#include #include namespace autodiff { @@ -65,9 +66,16 @@ constexpr auto ForEachWrtVar(const Wrt& wrt, Function&& f) auto i = 0; // the current index of the variable in the wrt list ForEach(wrt.args, [&](auto& item) constexpr { - if constexpr (isVector) { + using T = decltype(item); + static_assert(isVector || Order > 0, "Expecting a wrt list with either vectors or individual autodiff numbers."); + if constexpr (isVector) { for(auto j = 0; j < item.size(); ++j) - f(i++, item(j)); // call given f with current index and variable from item (a vector) + // call given f with current index and variable from item (a vector) + if constexpr (detail::has_operator_bracket()) { + f(i++, item[j]); + } else { + f(i++, item(j)); + } } else f(i++, item); // call given f with current index and variable from item (a number, not a vector) }); @@ -88,6 +96,7 @@ void gradient(const Fun& f, const Wrt& wrt, const At& at, Y& u ForEachWrtVar(wrt, [&](auto&& i, auto&& xi) constexpr { + static_assert(!isConst, "Expecting non-const autodiff numbers in wrt list because these need to be seeded, and thus altered!"); u = eval(f, at, detail::wrt(xi)); // evaluate u with xi seeded so that du/dxi is also computed g[i] = derivative<1>(u); }); @@ -125,6 +134,7 @@ void jacobian(const Fun& f, const Wrt& wrt, const At& at, Y& F size_t m = 0; ForEachWrtVar(wrt, [&](auto&& i, auto&& xi) constexpr { + static_assert(!isConst, "Expecting non-const autodiff numbers in wrt list because these need to be seeded, and thus altered!"); F = eval(f, at, detail::wrt(xi)); // evaluate F with xi seeded so that dF/dxi is also computed if(m == 0) { m = F.size(); J.resize(m, n); }; for(size_t row = 0; row < m; ++row) @@ -174,6 +184,7 @@ void hessian(const Fun& f, const Wrt& wrt, const At& at, U& u, ForEachWrtVar(wrt, [&](auto&& i, auto&& xi) constexpr { ForEachWrtVar(wrt, [&](auto&& j, auto&& xj) constexpr { + static_assert(!isConst && !isConst, "Expecting non-const autodiff numbers in wrt list because these need to be seeded, and thus altered!"); if(j >= i) { // this take advantage of the fact the Hessian matrix is symmetric u = eval(f, at, detail::wrt(xi, xj)); // evaluate u with xi and xj seeded to produce u0, du/dxi, d2u/dxidxj g[i] = derivative<1>(u); // get du/dxi from u diff --git a/inst/include/autodiff/forward/utils/taylorseries.hpp b/inst/include/autodiff/forward/utils/taylorseries.hpp index 3ae31b3..b312f6b 100644 --- a/inst/include/autodiff/forward/utils/taylorseries.hpp +++ b/inst/include/autodiff/forward/utils/taylorseries.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/pybind11/eigen.hpp b/inst/include/autodiff/pybind11/eigen.hpp index 4884d39..15fdff1 100644 --- a/inst/include/autodiff/pybind11/eigen.hpp +++ b/inst/include/autodiff/pybind11/eigen.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/reverse/var.hpp b/inst/include/autodiff/reverse/var.hpp index a818193..40bd97c 100644 --- a/inst/include/autodiff/reverse/var.hpp +++ b/inst/include/autodiff/reverse/var.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/inst/include/autodiff/reverse/var/eigen.hpp b/inst/include/autodiff/reverse/var/eigen.hpp index 8c80f28..2543626 100644 --- a/inst/include/autodiff/reverse/var/eigen.hpp +++ b/inst/include/autodiff/reverse/var/eigen.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -75,10 +75,40 @@ struct ScalarBinaryOpTraits, BinOp> typedef autodiff::Variable ReturnType; }; +template +struct NumTraits> : NumTraits // permits to get the epsilon, dummy_precision, lowest, highest functions +{ + typedef autodiff::Variable Real; + typedef autodiff::Variable NonInteger; + typedef autodiff::Variable Nested; + enum + { + IsComplex = 0, + IsInteger = 0, + IsSigned = 1, + RequireInitialization = 1, + ReadCost = 1, + AddCost = 3, + MulCost = 3 + }; +}; + +template +struct ScalarBinaryOpTraits, T, BinOp> +{ + typedef autodiff::Variable ReturnType; +}; + +template +struct ScalarBinaryOpTraits, BinOp> +{ + typedef autodiff::Variable ReturnType; +}; } // namespace Eigen namespace autodiff { +namespace reverse { namespace detail { template @@ -185,10 +215,12 @@ auto hessian(const Variable& y, Eigen::DenseBase& x) } } // namespace detail + // +} // namespace reverse AUTODIFF_DEFINE_EIGEN_TYPEDEFS_ALL_SIZES(autodiff::var, var) -using detail::gradient; -using detail::hessian; +using reverse::detail::gradient; +using reverse::detail::hessian; } // namespace autodiff diff --git a/inst/include/autodiff/reverse/var/var.hpp b/inst/include/autodiff/reverse/var/var.hpp index 4c140b2..1389112 100644 --- a/inst/include/autodiff/reverse/var/var.hpp +++ b/inst/include/autodiff/reverse/var/var.hpp @@ -7,7 +7,7 @@ // // Licensed under the MIT License . // -// Copyright (c) 2018-2022 Allan Leal +// Copyright © 2018–2024 Allan Leal // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -46,11 +46,13 @@ namespace autodiff {} namespace autodiff { -namespace detail { - -using detail::EnableIf; +// avoid clash with autodiff::detail in autodiff/forward/dual/dual.hpp +namespace reverse { +using detail::Requires; using detail::For; using detail::isArithmetic; +namespace detail { + using std::abs; using std::acos; @@ -135,6 +137,7 @@ struct isVariable { constexpr static bool value = false; }; template struct isVariable> { constexpr static bool value = true; }; + } // namespace traits template @@ -146,78 +149,6 @@ constexpr auto VariableOrder = traits::VariableOrder::value; template constexpr auto isVariable = traits::isVariable::value; -//------------------------------------------------------------------------------ -// CONVENIENT FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template -ExprPtr constant(const T& val); - -//------------------------------------------------------------------------------ -// ARITHMETIC OPERATORS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr operator+(const ExprPtr& r); -template ExprPtr operator-(const ExprPtr& r); - -template ExprPtr operator+(const ExprPtr& l, const ExprPtr& r); -template ExprPtr operator-(const ExprPtr& l, const ExprPtr& r); -template ExprPtr operator*(const ExprPtr& l, const ExprPtr& r); -template ExprPtr operator/(const ExprPtr& l, const ExprPtr& r); - -template>...> ExprPtr operator+(const U& l, const ExprPtr& r); -template>...> ExprPtr operator-(const U& l, const ExprPtr& r); -template>...> ExprPtr operator*(const U& l, const ExprPtr& r); -template>...> ExprPtr operator/(const U& l, const ExprPtr& r); - -template>...> ExprPtr operator+(const ExprPtr& l, const U& r); -template>...> ExprPtr operator-(const ExprPtr& l, const U& r); -template>...> ExprPtr operator*(const ExprPtr& l, const U& r); -template>...> ExprPtr operator/(const ExprPtr& l, const U& r); - -//------------------------------------------------------------------------------ -// TRIGONOMETRIC FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr sin(const ExprPtr& x); -template ExprPtr cos(const ExprPtr& x); -template ExprPtr tan(const ExprPtr& x); -template ExprPtr asin(const ExprPtr& x); -template ExprPtr acos(const ExprPtr& x); -template ExprPtr atan(const ExprPtr& x); -template ExprPtr atan2(const ExprPtr& l, const ExprPtr& r); -template>...> ExprPtr atan2(const U& l, const ExprPtr& r); -template>...> ExprPtr atan2(const ExprPtr& l, const U& r); - -//------------------------------------------------------------------------------ -// HYPERBOLIC FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr sinh(const ExprPtr& x); -template ExprPtr cosh(const ExprPtr& x); -template ExprPtr tanh(const ExprPtr& x); - -//------------------------------------------------------------------------------ -// EXPONENTIAL AND LOGARITHMIC FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr exp(const ExprPtr& x); -template ExprPtr log(const ExprPtr& x); -template ExprPtr log10(const ExprPtr& x); - -//------------------------------------------------------------------------------ -// POWER FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr sqrt(const ExprPtr& x); -template ExprPtr pow(const ExprPtr& l, const ExprPtr& r); -template>...> ExprPtr pow(const U& l, const ExprPtr& r); -template>...> ExprPtr pow(const ExprPtr& l, const U& r); - -//------------------------------------------------------------------------------ -// OTHER FUNCTIONS (DECLARATION ONLY) -//------------------------------------------------------------------------------ -template ExprPtr abs(const ExprPtr& x); -template ExprPtr abs2(const ExprPtr& x); -template ExprPtr conj(const ExprPtr& x); -template ExprPtr real(const ExprPtr& x); -template ExprPtr imag(const ExprPtr& x); -template ExprPtr erf(const ExprPtr& x); - /// The abstract type of any node type in the expression tree. template struct Expr @@ -233,6 +164,7 @@ struct Expr /// Bind a value pointer for writing the derivative during propagation virtual void bind_value(T* /* grad */) {} + /// Bind an expression pointer for writing the derivative expression during propagation virtual void bind_expr(ExprPtr* /* gradx */) {} @@ -333,6 +265,8 @@ struct ConstantExpr : Expr void update() override {} }; +template ExprPtr constant(const T& val) { return std::make_shared>(val); } + template struct UnaryExpr : Expr { @@ -427,8 +361,8 @@ struct SubExpr : BinaryExpr void propagatex(const ExprPtr& wprime) override { - l->propagatex(wprime); - r->propagatex(-wprime); + l->propagatex(wprime); // (l - r)'l = l' + r->propagatex(-wprime); // (l - r)'r = -r' } void update() override @@ -449,8 +383,8 @@ struct MulExpr : BinaryExpr void propagate(const T& wprime) override { - l->propagate(wprime * r->val); - r->propagate(wprime * l->val); + l->propagate(wprime * r->val); // (l * r)'l = w' * r + r->propagate(wprime * l->val); // (l * r)'r = l * w' } void propagatex(const ExprPtr& wprime) override @@ -769,7 +703,7 @@ struct ExpExpr : UnaryExpr void propagate(const T& wprime) override { - x->propagate(wprime * val); + x->propagate(wprime * val); // exp(x)' = exp(x) * x' } void propagatex(const ExprPtr& wprime) override @@ -793,7 +727,7 @@ struct LogExpr : UnaryExpr void propagate(const T& wprime) override { - x->propagate(wprime / x->val); + x->propagate(wprime / x->val); // log(x)' = x'/x } void propagatex(const ExprPtr& wprime) override @@ -849,19 +783,23 @@ struct PowExpr : BinaryExpr void propagate(const T& wprime) override { + using U = VariableValueType; + constexpr auto zero = U(0.0); const auto lval = l->val; const auto rval = r->val; const auto aux = wprime * pow(lval, rval - 1); l->propagate(aux * rval); - const auto auxr = lval == 0.0 ? 0.0 : lval * log(lval); // since x*log(x) -> 0 as x -> 0 + const auto auxr = lval == zero ? 0.0 : lval * log(lval); // since x*log(x) -> 0 as x -> 0 r->propagate(aux * auxr); } void propagatex(const ExprPtr& wprime) override { + using U = VariableValueType; + constexpr auto zero = U(0.0); const auto aux = wprime * pow(l, r - 1); l->propagatex(aux * r); - const auto auxr = l == 0.0 ? 0.0*l : l * log(l); // since x*log(x) -> 0 as x -> 0 + const auto auxr = l == zero ? 0.0*l : l * log(l); // since x*log(x) -> 0 as x -> 0 r->propagatex(aux * auxr); } @@ -918,7 +856,7 @@ struct PowConstantRightExpr : BinaryExpr void propagate(const T& wprime) override { - l->propagate(wprime * pow(l->val, r->val - 1) * r->val); + l->propagate(wprime * pow(l->val, r->val - 1) * r->val); // pow(l, r)'l = r * pow(l, r - 1) * l' } void propagatex(const ExprPtr& wprime) override @@ -943,7 +881,7 @@ struct SqrtExpr : UnaryExpr void propagate(const T& wprime) override { - x->propagate(wprime / (2.0 * sqrt(x->val))); + x->propagate(wprime / (2.0 * sqrt(x->val))); // sqrt(x)' = 1/2 * 1/sqrt(x) * x' } void propagatex(const ExprPtr& wprime) override @@ -1000,7 +938,7 @@ struct ErfExpr : UnaryExpr void propagate(const T& wprime) override { - const auto aux = 2.0 / sqrt_pi * exp(-(x->val) * (x->val)); + const auto aux = 2.0 / sqrt_pi * exp(-(x->val) * (x->val)); // erf(x)' = 2/sqrt(pi) * exp(-x * x) * x' x->propagate(wprime * aux); } @@ -1029,14 +967,14 @@ struct Hypot2Expr : BinaryExpr void propagate(const T& wprime) override { - l->propagate(wprime * l->val / val); - r->propagate(wprime * r->val / val); + l->propagate(wprime * l->val / val); // sqrt(l*l + r*r)'l = 1/2 * 1/sqrt(l*l + r*r) * (2*l*l') = (l*l')/sqrt(l*l + r*r) + r->propagate(wprime * r->val / val); // sqrt(l*l + r*r)'r = 1/2 * 1/sqrt(l*l + r*r) * (2*r*r') = (r*r')/sqrt(l*l + r*r) } void propagatex(const ExprPtr& wprime) override { - l->propagatex(wprime * l / val); - r->propagatex(wprime * r / val); + l->propagatex(wprime * l / hypot(l, r)); + r->propagatex(wprime * r / hypot(l, r)); } void update() override @@ -1067,9 +1005,9 @@ struct Hypot3Expr : TernaryExpr void propagatex(const ExprPtr& wprime) override { - l->propagatex(wprime * l / val); - c->propagatex(wprime * c / val); - r->propagatex(wprime * r / val); + l->propagatex(wprime * l / hypot(l, c, r)); + c->propagatex(wprime * c / hypot(l, c, r)); + r->propagatex(wprime * r / hypot(l, c, r)); } void update() override @@ -1159,7 +1097,6 @@ struct ConditionalExpr : Expr //------------------------------------------------------------------------------ // CONVENIENT FUNCTIONS //------------------------------------------------------------------------------ -template ExprPtr constant(const T& val) { return std::make_shared>(val); } //------------------------------------------------------------------------------ // ARITHMETIC OPERATORS @@ -1172,15 +1109,15 @@ template ExprPtr operator-(const ExprPtr& l, const ExprPtr& template ExprPtr operator*(const ExprPtr& l, const ExprPtr& r) { return std::make_shared>(l->val * r->val, l, r); } template ExprPtr operator/(const ExprPtr& l, const ExprPtr& r) { return std::make_shared>(l->val / r->val, l, r); } -template>...> ExprPtr operator+(const U& l, const ExprPtr& r) { return constant(l) + r; } -template>...> ExprPtr operator-(const U& l, const ExprPtr& r) { return constant(l) - r; } -template>...> ExprPtr operator*(const U& l, const ExprPtr& r) { return constant(l) * r; } -template>...> ExprPtr operator/(const U& l, const ExprPtr& r) { return constant(l) / r; } +template> = true> ExprPtr operator+(const U& l, const ExprPtr& r) { return constant(l) + r; } +template> = true> ExprPtr operator-(const U& l, const ExprPtr& r) { return constant(l) - r; } +template> = true> ExprPtr operator*(const U& l, const ExprPtr& r) { return constant(l) * r; } +template> = true> ExprPtr operator/(const U& l, const ExprPtr& r) { return constant(l) / r; } -template>...> ExprPtr operator+(const ExprPtr& l, const U& r) { return l + constant(r); } -template>...> ExprPtr operator-(const ExprPtr& l, const U& r) { return l - constant(r); } -template>...> ExprPtr operator*(const ExprPtr& l, const U& r) { return l * constant(r); } -template>...> ExprPtr operator/(const ExprPtr& l, const U& r) { return l / constant(r); } +template> = true> ExprPtr operator+(const ExprPtr& l, const U& r) { return l + constant(r); } +template> = true> ExprPtr operator-(const ExprPtr& l, const U& r) { return l - constant(r); } +template> = true> ExprPtr operator*(const ExprPtr& l, const U& r) { return l * constant(r); } +template> = true> ExprPtr operator/(const ExprPtr& l, const U& r) { return l / constant(r); } //------------------------------------------------------------------------------ // TRIGONOMETRIC FUNCTIONS @@ -1192,27 +1129,27 @@ template ExprPtr asin(const ExprPtr& x) { return std::make_sha template ExprPtr acos(const ExprPtr& x) { return std::make_shared>(acos(x->val), x); } template ExprPtr atan(const ExprPtr& x) { return std::make_shared>(atan(x->val), x); } template ExprPtr atan2(const ExprPtr& l, const ExprPtr& r) { return std::make_shared>(atan2(l->val, r->val), l, r); } -template>...> ExprPtr atan2(const U& l, const ExprPtr& r) { return std::make_shared>(atan2(l, r->val), constant(l), r); } -template>...> ExprPtr atan2(const ExprPtr& l, const U& r) { return std::make_shared>(atan2(l->val, r), l, constant(r)); } +template> = true> ExprPtr atan2(const U& l, const ExprPtr& r) { return std::make_shared>(atan2(l, r->val), constant(l), r); } +template> = true> ExprPtr atan2(const ExprPtr& l, const U& r) { return std::make_shared>(atan2(l->val, r), l, constant(r)); } //------------------------------------------------------------------------------ // HYPOT2 FUNCTIONS //------------------------------------------------------------------------------ template ExprPtr hypot(const ExprPtr& l, const ExprPtr& r) { return std::make_shared>(hypot(l->val, r->val), l, r); } -template>...> ExprPtr hypot(const U& l, const ExprPtr& r) { return std::make_shared>(hypot(l, r->val), constant(l), r); } -template>...> ExprPtr hypot(const ExprPtr& l, const U& r) { return std::make_shared>(hypot(l->val, r), l, constant(r)); } +template> = true> ExprPtr hypot(const U& l, const ExprPtr& r) { return std::make_shared>(hypot(l, r->val), constant(l), r); } +template> = true> ExprPtr hypot(const ExprPtr& l, const U& r) { return std::make_shared>(hypot(l->val, r), l, constant(r)); } //------------------------------------------------------------------------------ // HYPOT3 FUNCTIONS //------------------------------------------------------------------------------ template ExprPtr hypot(const ExprPtr& l, const ExprPtr& c, const ExprPtr& r) { return std::make_shared>(hypot(l->val,c->val, r->val), l, c, r); } -template>...> ExprPtr hypot(const ExprPtr& l, const ExprPtr& c, const U& r) { return std::make_shared>(hypot(l->val, c->val, r), l, c, constant(r)); } -template>...> ExprPtr hypot(const U& l, const ExprPtr& c, const ExprPtr& r) { return std::make_shared>(hypot(l, c->val, r->val), constant(l), c, r); } -template>...> ExprPtr hypot(const ExprPtr& l,const U& c, const ExprPtr& r) { return std::make_shared>(hypot(l->val, c, r->val), l, constant(c), r); } -template && isArithmetic>...> ExprPtr hypot(const ExprPtr& l, const U& c, const V& r) { return std::make_shared>(hypot(l->val, c, r), l, constant(c), constant(r)); } -template && isArithmetic>...> ExprPtr hypot(const U& l, const ExprPtr& c, const V& r) { return std::make_shared>(hypot(l, c->val, r), constant(l), c, constant(r)); } -template && isArithmetic>...> ExprPtr hypot(const V& l, const U& c, const ExprPtr& r) { return std::make_shared>(hypot(l, c, r->val), constant(l), constant(c), r); } +template> = true> ExprPtr hypot(const ExprPtr& l, const ExprPtr& c, const U& r) { return std::make_shared>(hypot(l->val, c->val, r), l, c, constant(r)); } +template> = true> ExprPtr hypot(const U& l, const ExprPtr& c, const ExprPtr& r) { return std::make_shared>(hypot(l, c->val, r->val), constant(l), c, r); } +template> = true> ExprPtr hypot(const ExprPtr& l,const U& c, const ExprPtr& r) { return std::make_shared>(hypot(l->val, c, r->val), l, constant(c), r); } +template && isArithmetic> = true> ExprPtr hypot(const ExprPtr& l, const U& c, const V& r) { return std::make_shared>(hypot(l->val, c, r), l, constant(c), constant(r)); } +template && isArithmetic> = true> ExprPtr hypot(const U& l, const ExprPtr& c, const V& r) { return std::make_shared>(hypot(l, c->val, r), constant(l), c, constant(r)); } +template && isArithmetic> = true> ExprPtr hypot(const V& l, const U& c, const ExprPtr& r) { return std::make_shared>(hypot(l, c, r->val), constant(l), constant(c), r); } //------------------------------------------------------------------------------ // HYPERBOLIC FUNCTIONS @@ -1233,8 +1170,8 @@ template ExprPtr log10(const ExprPtr& x) { return std::make_sh //------------------------------------------------------------------------------ template ExprPtr sqrt(const ExprPtr& x) { return std::make_shared>(sqrt(x->val), x); } template ExprPtr pow(const ExprPtr& l, const ExprPtr& r) { return std::make_shared>(pow(l->val, r->val), l, r); } -template>...> ExprPtr pow(const U& l, const ExprPtr& r) { return std::make_shared>(pow(l, r->val), constant(l), r); } -template>...> ExprPtr pow(const ExprPtr& l, const U& r) { return std::make_shared>(pow(l->val, r), l, constant(r)); } +template> = true> ExprPtr pow(const U& l, const ExprPtr& r) { return std::make_shared>(pow(l, r->val), constant(l), r); } +template> = true> ExprPtr pow(const ExprPtr& l, const U& r) { return std::make_shared>(pow(l->val, r), l, constant(r)); } //------------------------------------------------------------------------------ // OTHER FUNCTIONS @@ -1260,7 +1197,7 @@ struct Variable Variable(const Variable& other) : Variable(other.expr) {} /// Construct a Variable object with given arithmetic value - template>...> + template> = true> Variable(const U& val) : expr(std::make_shared>(val)) {} /// Construct a Variable object with given expression @@ -1284,11 +1221,8 @@ struct Variable /// Implicitly convert this Variable object into an expression pointer. operator const ExprPtr&() const { return expr; } - /// Explicitly convert this Variable object into its underlying arithmetic type. - explicit operator T() const { return expr->val; } - /// Assign an arithmetic value to this variable. - template>...> + template> = true> auto operator=(const U& val) -> Variable& { *this = Variable(val); return *this; } /// Assign an expression to this variable. @@ -1301,17 +1235,29 @@ struct Variable Variable& operator/=(const ExprPtr& x) { *this = Variable(expr / x); return *this; } // Assignment operators with arithmetic values - template>...> Variable& operator+=(const U& x) { *this = Variable(expr + x); return *this; } - template>...> Variable& operator-=(const U& x) { *this = Variable(expr - x); return *this; } - template>...> Variable& operator*=(const U& x) { *this = Variable(expr * x); return *this; } - template>...> Variable& operator/=(const U& x) { *this = Variable(expr / x); return *this; } + template> = true> Variable& operator+=(const U& x) { *this = Variable(expr + x); return *this; } + template> = true> Variable& operator-=(const U& x) { *this = Variable(expr - x); return *this; } + template> = true> Variable& operator*=(const U& x) { *this = Variable(expr * x); return *this; } + template> = true> Variable& operator/=(const U& x) { *this = Variable(expr / x); return *this; } + +#if defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION_VAR) || defined(AUTODIFF_ENABLE_IMPLICIT_CONVERSION) + operator T() const { return expr->val; } + + template + operator U() const { return static_cast(expr->val); } +#else + explicit operator T() const { return expr->val; } + + template + explicit operator U() const { return static_cast(expr->val); } +#endif }; //------------------------------------------------------------------------------ // EXPRESSION TRAITS //------------------------------------------------------------------------------ -template>...> T expr_value(const T& t) { return t; } +template> = true> T expr_value(const T& t) { return t; } template T expr_value(const ExprPtr& t) { return t->val; } template T expr_value(const Variable& t) { return t.expr->val; } @@ -1324,7 +1270,7 @@ template static auto is_expr_test(long) -> std::false_type; template struct is_expr : decltype(is_expr_test(0)) {}; template constexpr bool is_expr_v = is_expr::value; -template>...> ExprPtr coerce_expr(const U& u) { return constant(u); } +template> = true> ExprPtr coerce_expr(const U& u) { return constant(u); } template ExprPtr coerce_expr(const ExprPtr& t) { return t; } template ExprPtr coerce_expr(const Variable& t) { return t.expr; } @@ -1342,32 +1288,32 @@ auto comparison_operator(const T& t, const U& u) { return expr_comparison(coerce_expr(t), coerce_expr(u), Comparator {}); } -template>...> +template> = true> auto operator == (const T& t, const U& u) { return comparison_operator>(t, u); } -template>...> +template> = true> auto operator != (const T& t, const U& u) { return comparison_operator>(t, u); } -template>...> +template> = true> auto operator <= (const T& t, const U& u) { return comparison_operator>(t, u); } -template>...> +template> = true> auto operator >= (const T& t, const U& u) { return comparison_operator>(t, u); } -template>...> +template> = true> auto operator < (const T& t, const U& u) { return comparison_operator>(t, u); } -template>...> +template> = true> auto operator > (const T& t, const U& u) { return comparison_operator>(t, u); } //------------------------------------------------------------------------------ // CONDITION AND RELATED FUNCTIONS //------------------------------------------------------------------------------ -template && is_expr_v>...> +template && is_expr_v> = true> auto condition(BooleanExpr&& p, const T& t, const U& u) { using C = expr_common_t; ExprPtr expr = std::make_shared>(std::forward(p), coerce_expr(t), coerce_expr(u)); return expr; } -template>...> auto min(const T& x, const U& y) { return condition(x < y, x, y); } -template>...> auto max(const T& x, const U& y) { return condition(x > y, x, y); } +template> = true> auto min(const T& x, const U& y) { return condition(x < y, x, y); } +template> = true> auto max(const T& x, const U& y) { return condition(x > y, x, y); } template ExprPtr sgn(const ExprPtr& x) { return condition(x < 0, -1.0, condition(x > 0, 1.0, 0.0)); } template ExprPtr sgn(const Variable& x) { return condition(x.expr < 0, -1.0, condition(x.expr > 0, 1.0, 0.0)); } @@ -1392,15 +1338,15 @@ template ExprPtr operator-(const Variable& l, const ExprPtr template ExprPtr operator*(const Variable& l, const ExprPtr& r) { return l.expr * r; } template ExprPtr operator/(const Variable& l, const ExprPtr& r) { return l.expr / r; } -template>...> ExprPtr operator+(const U& l, const Variable& r) { return l + r.expr; } -template>...> ExprPtr operator-(const U& l, const Variable& r) { return l - r.expr; } -template>...> ExprPtr operator*(const U& l, const Variable& r) { return l * r.expr; } -template>...> ExprPtr operator/(const U& l, const Variable& r) { return l / r.expr; } +template> = true> ExprPtr operator+(const U& l, const Variable& r) { return l + r.expr; } +template> = true> ExprPtr operator-(const U& l, const Variable& r) { return l - r.expr; } +template> = true> ExprPtr operator*(const U& l, const Variable& r) { return l * r.expr; } +template> = true> ExprPtr operator/(const U& l, const Variable& r) { return l / r.expr; } -template>...> ExprPtr operator+(const Variable& l, const U& r) { return l.expr + r; } -template>...> ExprPtr operator-(const Variable& l, const U& r) { return l.expr - r; } -template>...> ExprPtr operator*(const Variable& l, const U& r) { return l.expr * r; } -template>...> ExprPtr operator/(const Variable& l, const U& r) { return l.expr / r; } +template> = true> ExprPtr operator+(const Variable& l, const U& r) { return l.expr + r; } +template> = true> ExprPtr operator-(const Variable& l, const U& r) { return l.expr - r; } +template> = true> ExprPtr operator*(const Variable& l, const U& r) { return l.expr * r; } +template> = true> ExprPtr operator/(const Variable& l, const U& r) { return l.expr / r; } //------------------------------------------------------------------------------ // TRIGONOMETRIC FUNCTIONS (DEFINED FOR ARGUMENTS OF TYPE Variable) @@ -1412,26 +1358,26 @@ template ExprPtr asin(const Variable& x) { return asin(x.expr) template ExprPtr acos(const Variable& x) { return acos(x.expr); } template ExprPtr atan(const Variable& x) { return atan(x.expr); } template ExprPtr atan2(const Variable & l, const Variable & r) { return atan2(l.expr, r.expr); } -template>...> ExprPtr atan2(const U& l, const Variable& r) { return atan2(l, r.expr); } -template>...> ExprPtr atan2(const Variable& l, const U& r) { return atan2(l.expr, r); } +template> = true> ExprPtr atan2(const U& l, const Variable& r) { return atan2(l, r.expr); } +template> = true> ExprPtr atan2(const Variable& l, const U& r) { return atan2(l.expr, r); } //------------------------------------------------------------------------------ // HYPOT2 FUNCTIONS (DEFINED FOR ARGUMENTS OF TYPE Variable) //------------------------------------------------------------------------------ template ExprPtr hypot(const Variable& l, const Variable& r) { return hypot(l.expr, r.expr); } -template>...> ExprPtr hypot(const U& l, const Variable& r) { return hypot(l, r.expr); } -template>...> ExprPtr hypot(const Variable& l, const U& r) { return hypot(l.expr, r); } +template> = true> ExprPtr hypot(const U& l, const Variable& r) { return hypot(l, r.expr); } +template> = true> ExprPtr hypot(const Variable& l, const U& r) { return hypot(l.expr, r); } //------------------------------------------------------------------------------ // HYPOT3 FUNCTIONS (DEFINED FOR ARGUMENTS OF TYPE Variable) //------------------------------------------------------------------------------ template ExprPtr hypot(const Variable &l, const Variable &c, const Variable &r) { return hypot(l.expr, c.expr, r.expr); } -template && isArithmetic>...> ExprPtr hypot(const Variable& l, const U& c, const V& r) { return hypot(l.expr, c, r); } -template && isArithmetic>...> ExprPtr hypot(const U& l, const Variable& c, const V& r) { return hypot(l, c.expr, r); } -template && isArithmetic>...> ExprPtr hypot(const U& l, const V& c, const Variable& r) { return hypot(l, c, r.expr); } -template>...> ExprPtr hypot(const Variable &l, const Variable &c, const U& r) { return hypot(l.expr, c.expr, r); } -template>...> ExprPtr hypot(const U &l, const Variable &c, const Variable& r) { return hypot(l, c.expr, r.expr); } -template>...> ExprPtr hypot(const Variable &l, const U &c, const Variable& r) { return hypot(l.expr, c, r.expr); } +template && isArithmetic> = true> ExprPtr hypot(const Variable& l, const U& c, const V& r) { return hypot(l.expr, c, r); } +template && isArithmetic> = true> ExprPtr hypot(const U& l, const Variable& c, const V& r) { return hypot(l, c.expr, r); } +template && isArithmetic> = true> ExprPtr hypot(const U& l, const V& c, const Variable& r) { return hypot(l, c, r.expr); } +template> = true> ExprPtr hypot(const Variable &l, const Variable &c, const U& r) { return hypot(l.expr, c.expr, r); } +template> = true> ExprPtr hypot(const U &l, const Variable &c, const Variable& r) { return hypot(l, c.expr, r.expr); } +template> = true> ExprPtr hypot(const Variable &l, const U &c, const Variable& r) { return hypot(l.expr, c, r.expr); } //------------------------------------------------------------------------------ // HYPERBOLIC FUNCTIONS (DEFINED FOR ARGUMENTS OF TYPE Variable) @@ -1452,8 +1398,8 @@ template ExprPtr log10(const Variable& x) { return log10(x.exp //------------------------------------------------------------------------------ template ExprPtr sqrt(const Variable& x) { return sqrt(x.expr); } template ExprPtr pow(const Variable& l, const Variable& r) { return pow(l.expr, r.expr); } -template>...> ExprPtr pow(const U& l, const Variable& r) { return pow(l, r.expr); } -template>...> ExprPtr pow(const Variable& l, const U& r) { return pow(l.expr, r); } +template> = true> ExprPtr pow(const U& l, const Variable& r) { return pow(l, r.expr); } +template> = true> ExprPtr pow(const Variable& l, const U& r) { return pow(l.expr, r); } //------------------------------------------------------------------------------ // OTHER FUNCTIONS (DEFINED FOR ARGUMENTS OF TYPE Variable) @@ -1465,7 +1411,7 @@ template ExprPtr real(const Variable& x) { return real(x.expr) template ExprPtr imag(const Variable& x) { return imag(x.expr); } template ExprPtr erf(const Variable& x) { return erf(x.expr); } -template>...> +template> = true> auto val(const T& t) { return expr_value(t); } /// Return the derivatives of a variable y with respect to all independent variables. @@ -1580,13 +1526,15 @@ using HigherOrderVariable = typename AuxHigherOrderVariable::type; } // namespace detail -using detail::wrt; -using detail::derivatives; -using detail::Variable; -using detail::val; +} // namespace reverse + +using reverse::detail::wrt; +using reverse::detail::derivatives; +using reverse::detail::Variable; +using reverse::detail::val; using var = Variable; -inline detail::BooleanExpr boolref(const bool& v) { return detail::BooleanExpr([&]() { return v; }); } +inline reverse::detail::BooleanExpr boolref(const bool& v) { return reverse::detail::BooleanExpr([&]() { return v; }); } } // namespace autodiff diff --git a/src/RcppExports.o b/src/RcppExports.o index 62e8592..c2d61bf 100644 Binary files a/src/RcppExports.o and b/src/RcppExports.o differ diff --git a/src/Rcppautodiff.so b/src/Rcppautodiff.so index c9e7f7d..ec2bbe0 100755 Binary files a/src/Rcppautodiff.so and b/src/Rcppautodiff.so differ diff --git a/src/autodiff-single-variable.o b/src/autodiff-single-variable.o index 6b2249c..55d3468 100644 Binary files a/src/autodiff-single-variable.o and b/src/autodiff-single-variable.o differ