-
Notifications
You must be signed in to change notification settings - Fork 36
/
UnaryExpr.hpp
123 lines (104 loc) · 3.93 KB
/
UnaryExpr.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*******************************************************************************
Unary expressions.
This file is part of XAD, a comprehensive C++ library for
automatic differentiation.
Copyright (C) 2010-2024 Xcelerit Computing Ltd.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
#pragma once
#include <XAD/Expression.hpp>
#include <XAD/Macros.hpp>
#include <XAD/Traits.hpp>
namespace xad
{
namespace detail
{

/// Compile-time dispatcher for the local derivative of a unary operator.
///
/// The boolean template argument chooses which `derivative` overload of the
/// operator is invoked:
///  - `false`: `op.derivative(x)` - only the operand value is used and the
///    third argument is ignored;
///  - `true`:  `op.derivative(x, fx)` - the third argument is forwarded too
///    (UnaryExpr passes its cached result `op(x)` here).
template <bool UseResult>
struct UnaryDerivativeImpl;

/// Operand-only variant: the third argument is accepted for a uniform
/// call signature but not used.
template <>
struct UnaryDerivativeImpl<false>
{
    template <class OpT, class S>
    XAD_INLINE static S derivative(const OpT& fn, const S& x, const S& /* fx */)
    {
        return fn.derivative(x);
    }
};

/// Result-based variant: forwards both the operand value and the
/// precomputed result to the operator.
template <>
struct UnaryDerivativeImpl<true>
{
    template <class OpT, class S>
    XAD_INLINE static S derivative(const OpT& fn, const S& x, const S& fx)
    {
        return fn.derivative(x, fx);
    }
};

}  // namespace detail
template <class, class>
struct Expression;

/// Base class of all unary expressions.
///
/// Represents `op(a)` for a single operand expression `a`. The value
/// `op(a.value())` is computed once at construction and cached. Derivative
/// propagation multiplies the incoming multiplier by the operator's local
/// derivative and forwards the product to the operand.
///
/// @tparam Scalar  scalar type of the expression's value
/// @tparam Op      unary operator functor. It must be callable as `op(x)` and
///                 provide `derivative(x)` or `derivative(x, result)`, as
///                 selected by `OperatorTraits<Op>::useResultBasedDerivatives`.
/// @tparam Expr    type of the operand expression
template <class Scalar, class Op, class Expr>
struct UnaryExpr : Expression<Scalar, UnaryExpr<Scalar, Op, Expr> >
{
    /// Chooses whether the operator's derivative is evaluated from the operand
    /// value alone or from the operand value plus the cached result `v_`.
    typedef detail::UnaryDerivativeImpl<OperatorTraits<Op>::useResultBasedDerivatives == 1>
        der_impl;

    /// Stores the operand and operator and eagerly evaluates `op(a)`.
    XAD_INLINE explicit UnaryExpr(const Expr& a, Op op = Op()) : a_(a), op_(op), v_(op_(a_.value()))
    {
    }

    /// Returns the cached value `op(a)`.
    XAD_INLINE Scalar value() const { return v_; }

    /// Forwards `mul * d op/d a` to the operand's tape-based propagation.
    template <class Tape>
    XAD_INLINE void calc_derivatives(Tape& s, const Scalar& mul) const
    {
        a_.calc_derivatives(s, mul * localDerivative());
    }

    /// Forwards `d op/d a` to the operand's tape-based propagation
    /// (as the two-argument overload, without an outer multiplier).
    template <class Tape>
    XAD_INLINE void calc_derivatives(Tape& s) const
    {
        a_.calc_derivatives(s, localDerivative());
    }

    /// Buffer-based propagation: passes `slot`, `muls` and `n` through to the
    /// operand together with `mul * d op/d a`.
    /// NOTE(review): presumably the operand appends (slot, multiplier) entries
    /// and advances `n`; confirm against the operand implementations.
    template <typename Slot>
    XAD_INLINE void calc_derivatives(Slot* slot, Scalar* muls, int& n, const Scalar& mul) const
    {
        a_.calc_derivatives(slot, muls, n, mul * localDerivative());
    }

    /// Iterator-based propagation: passes both iterators through to the
    /// operand together with `mul * d op/d a`.
    template <typename It1, typename It2>
    XAD_INLINE void calc_derivatives(It1& sit, It2& mit, const Scalar& mul) const
    {
        a_.calc_derivatives(sit, mit, mul * localDerivative());
    }

    /// Recording is needed exactly when the operand needs it.
    XAD_INLINE bool shouldRecord() const { return a_.shouldRecord(); }

    /// Chain rule: local derivative of `op` times the operand's derivative.
    XAD_INLINE Scalar derivative() const
    {
        using xad::derivative;
        return localDerivative() * derivative(a_);
    }

  private:
    /// Local partial derivative d op/d a at the operand's current value.
    /// Result-based operators also receive the cached result `v_`.
    XAD_INLINE Scalar localDerivative() const
    {
        using xad::value;
        // No `template` disambiguator here. The template arguments are
        // deduced, so there is no argument list after the name, and a bare
        // `template` keyword in that position is ill-formed and rejected by
        // recent compilers (e.g. Clang 19).
        return der_impl::derivative(op_, value(a_), v_);
    }

    // Declaration order matters: v_ is initialised from a_ and op_ in the
    // constructor's member-init list, so both must be declared before it.
    Expr a_;
    Op op_;
    Scalar v_;
};
/// Compile-time traits of a unary expression node.
///
/// The node is always a non-literal expression. Its variable count and its
/// forward/reverse/direction flags are taken from the operand: the count from
/// `ExprTraits<Expr>` and the mode flags from the operand's `value_type`.
template <class Scalar, class Op, class Expr>
struct ExprTraits<UnaryExpr<Scalar, Op, Expr> >
{
    // Operand's value type; declared first so the mode flags below can reuse it.
    typedef typename ExprTraits<Expr>::value_type value_type;
    typedef typename ExprTraits<Scalar>::nested_type nested_type;

    static const bool isExpr = true;
    static const bool isLiteral = false;

    // A unary node references exactly the variables of its single operand.
    static const int numVariables = ExprTraits<Expr>::numVariables;

    static const bool isForward = ExprTraits<value_type>::isForward;
    static const bool isReverse = ExprTraits<value_type>::isReverse;
    static const Direction direction = ExprTraits<value_type>::direction;
};
} // namespace xad