diff --git a/README.md b/README.md index 9b8e0b0..57f4116 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ # micrograd.rs +> _Quit your shitty ass job and go learn some skills._ +> +> _-_ George Hotz + Insanely lightweight Rust implementation of Andrej Karpathy's [micrograd](https://github.com/karpathy/micrograd). ## Example ```rust diff --git a/src/differentiable.rs b/src/differentiable.rs index 75dadb9..654a4de 100644 --- a/src/differentiable.rs +++ b/src/differentiable.rs @@ -1,7 +1,9 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; pub trait Differentiable: - Add + // TODO: should Differentiable require Float + Float + + Add + Sub + Mul + Div @@ -33,3 +35,28 @@ impl Differentiable for f64 { 1f64 } } + +pub trait Float: Copy { + fn pow(self, n: Self) -> Self; + fn log(self) -> Self; +} + +impl Float for f32 { + fn pow(self, n: Self) -> Self { + self.powf(n) + } + + fn log(self) -> Self { + self.ln() + } +} + +impl Float for f64 { + fn pow(self, n: Self) -> Self { + self.powf(n) + } + + fn log(self) -> Self { + self.ln() + } +} diff --git a/src/lib.rs b/src/lib.rs index 462b35b..2ed203c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ mod differentiable; mod value; -pub use crate::differentiable::Differentiable; +pub use crate::differentiable::{Differentiable, Float}; pub use crate::value::Value; diff --git a/src/value.rs b/src/value.rs index d7774fc..006d9b9 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,23 +1,32 @@ -use crate::differentiable::Differentiable; +use crate::Differentiable; use std::cell::Cell; use std::ops::{Add, Div, Mul, Neg, Sub}; #[derive(Clone)] -enum Operation<'a, T: Differentiable + Copy> { +enum Operation<'a, T: Copy> { Add(&'a Value<'a, T>, &'a Value<'a, T>), Sub(&'a Value<'a, T>, &'a Value<'a, T>), Mul(&'a Value<'a, T>, &'a Value<'a, T>), Div(&'a Value<'a, T>, &'a Value<'a, T>), + Pow(&'a Value<'a, T>, &'a Value<'a, T>), Neg(&'a Value<'a, T>), } #[derive(Clone)] -pub struct Value<'a, T: Differentiable + Copy> { +pub struct Value<'a, T: Copy> { data: Cell, grad: Cell, operation: Option>, } +impl<'a, T: Differentiable> Value<'a, T> { + pub fn pow(&'a self, n: &'a Value<'a, T>) -> Value<'a, T> { + let mut value = Self::new(self.data.get().pow(n.data.get())); + value.operation = Some(Operation::Pow(self, n)); + value + } +} + impl<'a, T: Differentiable + Copy> Value<'a, T> { pub fn new(data: T) -> Self { Self { @@ -27,6 +36,10 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> { } } + pub fn data(&self) -> T { + self.data.get() + } + pub fn grad(&self) -> T { self.grad.get() } @@ -34,7 +47,9 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> { pub fn zero_grad(&self) { self.grad.set(T::zero_grad()); } +} +impl<'a, T: Differentiable> Value<'a, T> { pub fn backward(&self) { // dy / dy = 1 self._backward(T::eye_grad()); @@ -68,6 +83,18 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> { v.grad.set(v.grad() - T::eye_grad() * grad); v._backward(v.grad()); } + Operation::Pow(v1, v2) => { + v1.grad.set( + v1.grad() + + v2.data.get() + * v1.data.get().pow(v2.data.get() - T::eye_grad()) + * grad, + ); + v2.grad.set( + v2.grad() + v1.data.get().pow(v2.data.get()) * v1.data.get().log() * grad, + ); + pair_backward(v1, v2) + } }, None => { // end of graph @@ -76,7 +103,7 @@ impl<'a, T: Differentiable + Copy> Value<'a, T> { } } -fn pair_backward<'a, T: Differentiable + Copy>(v1: &'a Value, v2: &'a Value) { +fn pair_backward<'a, T: Differentiable>(v1: &'a Value, v2: &'a Value) { v1._backward(v1.grad()); if !std::ptr::eq(v1, v2) { v2._backward(v2.grad()); @@ -85,7 +112,7 @@ fn pair_backward<'a, T: Differentiable + Copy>(v1: &'a Value, v2: &'a Value Add<&'a Value<'a, T>> for &'a Value<'a, T> where - T: Add + Differentiable + Copy, + T: Differentiable, { type Output = Value<'a, T>; fn add(self, rhs: &'a Value<'a, T>) -> Self::Output { @@ -97,7 +124,7 @@ where impl<'a, T> Sub<&'a Value<'a, T>> for &'a Value<'a, T> where - T: Sub + Differentiable + Copy, + T: Differentiable, { type Output = Value<'a, T>; fn sub(self, rhs: &'a Value<'a, T>) -> Self::Output { @@ -109,7 +136,7 @@ where impl<'a, T> Mul<&'a Value<'a, T>> for &'a Value<'a, T> where - T: Mul + Differentiable + Copy, + T: Differentiable, { type Output = Value<'a, T>; fn mul(self, rhs: &'a Value<'a, T>) -> Self::Output { @@ -121,7 +148,7 @@ where impl<'a, T> Div<&'a Value<'a, T>> for &'a Value<'a, T> where - T: Div + Differentiable + Copy, + T: Differentiable, { type Output = Value<'a, T>; fn div(self, rhs: &'a Value<'a, T>) -> Self::Output { @@ -133,7 +160,7 @@ where impl<'a, T> Neg for &'a Value<'a, T> where - T: Neg + Differentiable + Copy, + T: Differentiable, { type Output = Value<'a, T>; fn neg(self) -> Self::Output { diff --git a/tests/test_value.rs b/tests/test_value.rs index cb37cea..5b99a6e 100644 --- a/tests/test_value.rs +++ b/tests/test_value.rs @@ -24,19 +24,22 @@ fn test_simple() { #[test] fn test_deep() { - let x = Value::new(5f32); - let y = Value::new(4f32); - let z = Value::new(8f32); + let x = Value::new(5f64); + let y = Value::new(4f64); + let z = Value::new(8f64); let a = &y + &x; let b = &z * &a; let c = &b * &b; - let r = &c - &z; + let d = &c - &z; + let e = &z * &d; + let f = &e + &e; + let result = &f / &x; - r.backward(); - assert_eq!(x.grad(), 1152f32); - assert_eq!(y.grad(), 1152f32); - assert_eq!(z.grad(), 1295f32); + result.backward(); + assert!(f64::abs(x.grad() - 373.76) < 10e-8f64); + assert!(f64::abs(y.grad() - 3686.4) < 10e-8f64); + assert!(f64::abs(z.grad() - 6214.4) < 10e-8f64); } #[cfg(test)] @@ -150,6 +153,34 @@ mod operations { assert_eq!(x.grad(), 0f32); } + #[test] + fn test_pow() { + let x = Value::new(1f32); + let y = Value::new(2f32); + + let result = x.pow(&y); + result.backward(); + assert_eq!(x.grad(), 2f32); + assert_eq!(y.grad(), 0f32); + + let x = Value::new(2f64); + let y = Value::new(3f64); + let z = Value::new(0.1f64); + + let a = x.pow(&y); + let b = &a + &z; + let result = b.pow(&z); + + result.backward(); + assert!(f64::abs(x.grad() - 0.1826f64) < 1e-3f64); + assert!(f64::abs(y.grad() - 0.0844f64) < 1e-3f64); + assert!(f64::abs(z.grad() - 2.5938f64) < 1e-3f64); + + let x = Value::new(3f64); + x.pow(&x).backward(); + assert!(f64::abs(x.grad() - 56.6625f64) < 1e-3f64); + } + #[test] fn test_neg() { let x = Value::new(123f32);