Skip to content

Commit

Permalink
add Pow
Browse files Browse the repository at this point in the history
  • Loading branch information
arseniybelkov committed Jul 22, 2024
1 parent 5e7dd3d commit 571eda4
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 19 deletions.
29 changes: 28 additions & 1 deletion src/differentiable.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::ops::{Add, Div, Mul, Neg, Sub};

pub trait Differentiable:
Add<Output = Self>
// TODO: should Differentiable require Float
Float
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
Expand Down Expand Up @@ -33,3 +35,28 @@ impl Differentiable for f64 {
1f64
}
}

pub trait Float: Copy {
fn pow(self, n: Self) -> Self;
fn log(self) -> Self;
}

impl Float for f32 {
fn pow(self, n: Self) -> Self {
self.powf(n)
}

fn log(self) -> Self {
self.ln()
}
}

impl Float for f64 {
fn pow(self, n: Self) -> Self {
self.powf(n)
}

fn log(self) -> Self {
self.ln()
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod differentiable;
mod value;

pub use crate::differentiable::Differentiable;
pub use crate::differentiable::{Differentiable, Float};
pub use crate::value::Value;
45 changes: 36 additions & 9 deletions src/value.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
use crate::differentiable::Differentiable;
use crate::Differentiable;
use std::cell::Cell;
use std::ops::{Add, Div, Mul, Neg, Sub};

#[derive(Clone)]
enum Operation<'a, T: Differentiable + Copy> {
enum Operation<'a, T: Copy> {
Add(&'a Value<'a, T>, &'a Value<'a, T>),
Sub(&'a Value<'a, T>, &'a Value<'a, T>),
Mul(&'a Value<'a, T>, &'a Value<'a, T>),
Div(&'a Value<'a, T>, &'a Value<'a, T>),
Pow(&'a Value<'a, T>, &'a Value<'a, T>),
Neg(&'a Value<'a, T>),
}

#[derive(Clone)]
pub struct Value<'a, T: Differentiable + Copy> {
pub struct Value<'a, T: Copy> {
data: Cell<T>,
grad: Cell<T>,
operation: Option<Operation<'a, T>>,
}

impl<'a, T: Differentiable> Value<'a, T> {
pub fn pow(&'a self, n: &'a Value<'a, T>) -> Value<'a, T> {
let mut value = Self::new(self.data.get().pow(n.data.get()));
value.operation = Some(Operation::Pow(self, n));
value
}
}

impl<'a, T: Differentiable + Copy> Value<'a, T> {
pub fn new(data: T) -> Self {
Self {
Expand All @@ -27,14 +36,20 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> {
}
}

pub fn data(&self) -> T {
self.data.get()
}

pub fn grad(&self) -> T {
self.grad.get()
}

pub fn zero_grad(&self) {
self.grad.set(T::zero_grad());
}
}

impl<'a, T: Differentiable> Value<'a, T> {
pub fn backward(&self) {
// dy / dy = 1
self._backward(T::eye_grad());
Expand Down Expand Up @@ -68,6 +83,18 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> {
v.grad.set(v.grad() - T::eye_grad() * grad);
v._backward(v.grad());
}
Operation::Pow(v1, v2) => {
v1.grad.set(
v1.grad()
+ v2.data.get()
* v1.data.get().pow(v2.data.get() - T::eye_grad())
* grad,
);
v2.grad.set(
v2.grad() + v1.data.get().pow(v2.data.get()) * v1.data.get().log() * grad,
);
pair_backward(v1, v2)
}
},
None => {
// end of graph
Expand All @@ -76,7 +103,7 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> {
}
}

fn pair_backward<'a, T: Differentiable + Copy>(v1: &'a Value<T>, v2: &'a Value<T>) {
fn pair_backward<'a, T: Differentiable>(v1: &'a Value<T>, v2: &'a Value<T>) {
v1._backward(v1.grad());
if !std::ptr::eq(v1, v2) {
v2._backward(v2.grad());
Expand All @@ -85,7 +112,7 @@ fn pair_backward<'a, T: Differentiable + Copy>(v1: &'a Value<T>, v2: &'a Value<T

impl<'a, T> Add<&'a Value<'a, T>> for &'a Value<'a, T>
where
T: Add<T, Output = T> + Differentiable + Copy,
T: Differentiable,
{
type Output = Value<'a, T>;
fn add(self, rhs: &'a Value<'a, T>) -> Self::Output {
Expand All @@ -97,7 +124,7 @@ where

impl<'a, T> Sub<&'a Value<'a, T>> for &'a Value<'a, T>
where
T: Sub<T, Output = T> + Differentiable + Copy,
T: Differentiable,
{
type Output = Value<'a, T>;
fn sub(self, rhs: &'a Value<'a, T>) -> Self::Output {
Expand All @@ -109,7 +136,7 @@ where

impl<'a, T> Mul<&'a Value<'a, T>> for &'a Value<'a, T>
where
T: Mul<T, Output = T> + Differentiable + Copy,
T: Differentiable,
{
type Output = Value<'a, T>;
fn mul(self, rhs: &'a Value<'a, T>) -> Self::Output {
Expand All @@ -121,7 +148,7 @@ where

impl<'a, T> Div<&'a Value<'a, T>> for &'a Value<'a, T>
where
T: Div<T, Output = T> + Differentiable + Copy,
T: Differentiable,
{
type Output = Value<'a, T>;
fn div(self, rhs: &'a Value<'a, T>) -> Self::Output {
Expand All @@ -133,7 +160,7 @@ where

impl<'a, T> Neg for &'a Value<'a, T>
where
T: Neg<Output = T> + Differentiable + Copy,
T: Differentiable,
{
type Output = Value<'a, T>;
fn neg(self) -> Self::Output {
Expand Down
47 changes: 39 additions & 8 deletions tests/test_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@ fn test_simple() {

#[test]
fn test_deep() {
let x = Value::new(5f32);
let y = Value::new(4f32);
let z = Value::new(8f32);
let x = Value::new(5f64);
let y = Value::new(4f64);
let z = Value::new(8f64);

let a = &y + &x;
let b = &z * &a;
let c = &b * &b;
let r = &c - &z;
let d = &c - &z;
let e = &z * &d;
let f = &e + &e;
let result = &f / &x;

r.backward();
assert_eq!(x.grad(), 1152f32);
assert_eq!(y.grad(), 1152f32);
assert_eq!(z.grad(), 1295f32);
result.backward();
assert!(f64::abs(x.grad() - 373.76) < 10e-8f64);
assert!(f64::abs(y.grad() - 3686.4) < 10e-8f64);
assert!(f64::abs(z.grad() - 6214.4) < 10e-8f64);
}

#[cfg(test)]
Expand Down Expand Up @@ -150,6 +153,34 @@ mod operations {
assert_eq!(x.grad(), 0f32);
}

#[test]
fn test_pow() {
let x = Value::new(1f32);
let y = Value::new(2f32);

let result = x.pow(&y);
result.backward();
assert_eq!(x.grad(), 2f32);
assert_eq!(y.grad(), 0f32);

let x = Value::new(2f64);
let y = Value::new(3f64);
let z = Value::new(0.1f64);

let a = x.pow(&y);
let b = &a + &z;
let result = b.pow(&z);

result.backward();
assert!(f64::abs(x.grad() - 0.1826f64) < 1e-3f64);
assert!(f64::abs(y.grad() - 0.0844f64) < 1e-3f64);
assert!(f64::abs(z.grad() - 2.5938f64) < 1e-3f64);

let x = Value::new(3f64);
x.pow(&x).backward();
assert!(f64::abs(x.grad() - 56.6625f64) < 1e-3f64);
}

#[test]
fn test_neg() {
let x = Value::new(123f32);
Expand Down

0 comments on commit 571eda4

Please sign in to comment.