Skip to content

Commit

Permalink
Updated backend autodiff
Browse files Browse the repository at this point in the history
Updated backend autodiff
  • Loading branch information
sn248 committed Jun 13, 2024
1 parent 4d66792 commit 0fb0353
Show file tree
Hide file tree
Showing 25 changed files with 541 additions and 479 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Suggests:
rmarkdown,
tinytest
Encoding: UTF-8
RoxygenNote: 7.2.1
RoxygenNote: 7.3.1
Imports: Rcpp (>= 1.0.8.3), RcppEigen
LinkingTo: Rcpp, RcppEigen
VignetteBuilder: knitr
4 changes: 2 additions & 2 deletions inst/include/autodiff/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ cc_library(

cc_library(
name = "reverse",
hdrs = glob(["reverse.hpp", "reverse/**/*.hpp"]),
hdrs = glob(["var.hpp", "reverse/**/*.hpp"]),
deps = ["@com_github_eigen_eigen//:eigen",
":common"],
visibility = ["//visibility:public"],
)

cc_library(
name = "forward",
hdrs = glob(["forward.hpp", "forward/**/*.hpp"]),
hdrs = glob(["dual.hpp", "real.hpp", "forward/**/*.hpp"]),
deps = [
"@com_github_eigen_eigen//:eigen",
":common",
Expand Down
4 changes: 4 additions & 0 deletions inst/include/autodiff/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ target_include_directories(autodiff
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
)

if(CMAKE_CUDA_COMPILER)
target_compile_options(autodiff INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --extended-lambda>)
endif()

# Install autodiff interface library
install(TARGETS autodiff EXPORT autodiffTargets)

Expand Down
2 changes: 1 addition & 1 deletion inst/include/autodiff/common/binomialcoefficient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
9 changes: 8 additions & 1 deletion inst/include/autodiff/common/classtraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -95,5 +95,12 @@ CREATE_MEMBER_CHECK(size);
template<typename T>
constexpr bool hasSize = has_member_size<PlainType<T>>::value;

// Create type trait struct `has_operator_bracket`.
template<class, typename T> struct has_operator_bracket_impl : std::false_type {};
template<typename T> struct has_operator_bracket_impl<decltype( void(std::declval<T>().operator [](0)) ), T> : std::true_type {};

/// Boolean type that is true if type T implements `operator[](int)` method.
template<typename T> struct has_operator_bracket : has_operator_bracket_impl<void, T> {};

} // namespace detail
} // namespace autodiff
11 changes: 10 additions & 1 deletion inst/include/autodiff/common/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -130,6 +130,15 @@ struct VectorTraits<Eigen::Ref<MatrixType>>
using ReplaceValueType = VectorReplaceValueType<MatrixType, NewValueType>;
};

template<typename VectorType, int MapOptions, typename StrideType>
struct VectorTraits<Eigen::Map<VectorType, MapOptions, StrideType>>
{
using ValueType = VectorValueType<VectorType>;

template<typename NewValueType>
using ReplaceValueType = Eigen::Map<VectorReplaceValueType<VectorType, NewValueType>, MapOptions, StrideType>;
};

//=====================================================================================================================
//
// AUXILIARY TEMPLATE TYPE ALIASES
Expand Down
49 changes: 32 additions & 17 deletions inst/include/autodiff/common/meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -34,24 +34,39 @@
#include <tuple>
#include <type_traits>

#ifndef AUTODIFF_DEVICE_FUNC
#ifdef AUTODIFF_EIGEN_FOUND
#include <Eigen/src/Core/util/Macros.h>
#define AUTODIFF_DEVICE_FUNC EIGEN_DEVICE_FUNC
#else
#define AUTODIFF_DEVICE_FUNC
#endif
#endif

namespace autodiff {
namespace detail {

template<bool value>
using EnableIf = typename std::enable_if<value>::type;
using EnableIf = std::enable_if_t<value>;

template<bool value>
using Requires = std::enable_if_t<value, bool>;

template<typename T>
using PlainType = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
using PlainType = std::remove_cv_t<std::remove_reference_t<T>>;

template<bool Cond, typename WhenTrue, typename WhenFalse>
using ConditionalType = typename std::conditional<Cond, WhenTrue, WhenFalse>::type;
using ConditionalType = std::conditional_t<Cond, WhenTrue, WhenFalse>;

template<typename A, typename B>
using CommonType = typename std::common_type<A, B>::type;
using CommonType = std::common_type_t<A, B>;

template<typename Fun, typename... Args>
using ReturnType = std::invoke_result_t<Fun, Args...>;

template<typename T>
constexpr bool isConst = std::is_const_v<std::remove_reference_t<T>>;

template<typename T, typename U>
constexpr bool isConvertible = std::is_convertible<PlainType<T>, U>::value;

Expand All @@ -62,13 +77,13 @@ template<typename Tuple>
constexpr auto TupleSize = std::tuple_size_v<std::decay_t<Tuple>>;

template<typename Tuple>
constexpr auto TupleHead(Tuple&& tuple)
AUTODIFF_DEVICE_FUNC constexpr auto TupleHead(Tuple&& tuple)
{
return std::get<0>(std::forward<Tuple>(tuple));
}

template<typename Tuple>
constexpr auto TupleTail(Tuple&& tuple)
AUTODIFF_DEVICE_FUNC constexpr auto TupleTail(Tuple&& tuple)
{
auto g = [&](auto&&, auto&&... args) constexpr {
return std::forward_as_tuple(args...);
Expand All @@ -85,7 +100,7 @@ struct Index
};

template<size_t i, size_t ibegin, size_t iend, typename Function>
constexpr auto AuxFor(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto AuxFor(Function&& f)
{
if constexpr (i < iend) {
f(Index<i>{});
Expand All @@ -94,19 +109,19 @@ constexpr auto AuxFor(Function&& f)
}

template<size_t ibegin, size_t iend, typename Function>
constexpr auto For(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto For(Function&& f)
{
AuxFor<ibegin, ibegin, iend>(std::forward<Function>(f));
}

template<size_t iend, typename Function>
constexpr auto For(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto For(Function&& f)
{
For<0, iend>(std::forward<Function>(f));
}

template<size_t i, size_t ibegin, size_t iend, typename Function>
constexpr auto AuxReverseFor(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto AuxReverseFor(Function&& f)
{
if constexpr (i < iend)
{
Expand All @@ -116,19 +131,19 @@ constexpr auto AuxReverseFor(Function&& f)
}

template<size_t ibegin, size_t iend, typename Function>
constexpr auto ReverseFor(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto ReverseFor(Function&& f)
{
AuxReverseFor<ibegin, ibegin, iend>(std::forward<Function>(f));
}

template<size_t iend, typename Function>
constexpr auto ReverseFor(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto ReverseFor(Function&& f)
{
ReverseFor<0, iend>(std::forward<Function>(f));
}

template<typename Tuple, typename Function>
constexpr auto ForEach(Tuple&& tuple, Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto ForEach(Tuple&& tuple, Function&& f)
{
constexpr auto N = TupleSize<Tuple>;
For<N>([&](auto i) constexpr {
Expand All @@ -144,7 +159,7 @@ constexpr auto ForEach(Tuple&& tuple, Function&& f)
}

template<typename Tuple1, typename Tuple2, typename Function>
constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f)
{
constexpr auto N1 = TupleSize<Tuple1>;
constexpr auto N2 = TupleSize<Tuple2>;
Expand All @@ -155,7 +170,7 @@ constexpr auto ForEach(Tuple1&& tuple1, Tuple2&& tuple2, Function&& f)
}

template<size_t ibegin, size_t iend, typename Function>
constexpr auto Sum(Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto Sum(Function&& f)
{
using ResultType = std::invoke_result_t<Function, Index<ibegin>>;
ResultType res = {};
Expand All @@ -166,7 +181,7 @@ constexpr auto Sum(Function&& f)
}

template<typename Tuple, typename Function>
constexpr auto Reduce(Tuple&& tuple, Function&& f)
AUTODIFF_DEVICE_FUNC constexpr auto Reduce(Tuple&& tuple, Function&& f)
{
using ResultType = std::invoke_result_t<Function, decltype(std::get<0>(tuple))>;
ResultType res = {};
Expand Down
2 changes: 1 addition & 1 deletion inst/include/autodiff/common/numbertraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion inst/include/autodiff/common/vectortraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion inst/include/autodiff/forward/dual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
//
// Copyright (c) 2018-2022 Allan Leal
// Copyright © 2018–2024 Allan Leal
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
Loading

0 comments on commit 0fb0353

Please sign in to comment.