Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dual number, add grad method for dual number #19

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions attention/src/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ mod tests {
embed_dim: 20,
num_head: 4,
seed: 0,
num_blocks: 1,
}
}

Expand Down
76 changes: 67 additions & 9 deletions elements/src/dual_number.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
fmt::Display,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign},
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
};

use interfaces::{
Expand Down Expand Up @@ -41,9 +41,23 @@ impl AddAssign for DualNumber {
}
}

impl Sub for DualNumber {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self::new(self.real - rhs.real, self.dual - rhs.dual)
}
}

impl SubAssign for DualNumber {
fn sub_assign(&mut self, rhs: Self) {
self.real -= rhs.real;
self.dual -= rhs.dual;
}
}

impl Mul for DualNumber {
type Output = Self;
// (ax + (ay + xb) i) = (x + iy) * (a + bi)
// (x + ie) * (a + be) = (ax + (ay + xb) e)
fn mul(self, rhs: Self) -> Self::Output {
let real = self.real * rhs.real;
let dual = self.real * rhs.dual + self.dual * rhs.real;
Expand Down Expand Up @@ -99,7 +113,7 @@ impl Ln for DualNumber {
impl Pow for DualNumber {
fn pow(self, exp: Self) -> Self {
let real = self.real.powf(exp.real);
let dual = real * (exp.dual * self.real.ln() + exp.real * self.dual / self.real);
let dual = real * (exp.dual * self.real.ln() + (self.dual * exp.real / self.real));
Self::new(real, dual)
}
}
Expand All @@ -113,13 +127,37 @@ impl Zero for DualNumber {
}
}

// impl Element for DualNumber {}
impl Element for DualNumber {}

// impl RealElement for DualNumber {
// fn neg_inf() -> Self {
// Self::new(-f64::INFINITY, 0.)
// }
// }
fn grad<F: Fn(DualNumber) -> DualNumber>(func: F, value: f64) -> f64 {
func(DualNumber::new(value, 1.)).dual
}

impl DualNumber {
pub fn grad<F: Fn(DualNumber) -> DualNumber>(func: F, value: f64) -> f64 {
func(DualNumber::new(value, 1.)).dual
}
}

pub trait Grad {
fn grad(&self, value: f64) -> f64;
}

// Blanket implementation for Fn(f64) -> f64
impl<F> Grad for F
where
F: Fn(DualNumber) -> DualNumber,
{
fn grad(&self, value: f64) -> f64 {
grad(self, value)
}
}

impl RealElement for DualNumber {
fn neg_inf() -> Self {
Self::new(-f64::INFINITY, 0.)
}
}

impl From<f64> for DualNumber {
fn from(value: f64) -> Self {
Expand All @@ -146,11 +184,31 @@ mod tests {
dual_number.pow(DualNumber::new(3., 0.))
}

// Expression: f(x) = 2x^2 + exp(5x)
// f'(x)= 4x + 5 * exp(5x)
fn test_exp_fn(dual_number: DualNumber) -> DualNumber {
DualNumber::new(2., 0.) * dual_number.pow(DualNumber::new(2.0, 0.))
+ DualNumber::new(f64::exp(1.), 0.).pow(DualNumber::new(5., 0.) * dual_number)
}

fn test_exp_fn_deriv(value: f64) -> f64 {
4. * value + 5. * f64::exp(1.).powf(5. * value)
}

#[test]
fn test_cube() {
let dual_number = DualNumber::new(0.1, 1.);
let result = cube(dual_number);
assert_approx_eq!(f64, result.real, 0.001);
assert_approx_eq!(f64, result.dual, 0.03);
}

#[test]
fn test_grad() {
assert_approx_eq!(f64, DualNumber::grad(cube, 0.1), 0.03);
assert_approx_eq!(f64, grad(cube, 0.1), 0.03);
assert_approx_eq!(f64, cube.grad(0.1), 0.03);
assert_approx_eq!(f64, test_exp_fn.grad(3.0), test_exp_fn_deriv(3.0));
assert_approx_eq!(f64, test_exp_fn.grad(6.0), test_exp_fn_deriv(6.0));
}
}