-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnanograd.cpp
143 lines (122 loc) · 4 KB
/
nanograd.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <math.h>

#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
using namespace std;
// Scalar value that records the arithmetic applied to it so that gradients
// can be computed by reverse-mode automatic differentiation (backward()).
//
// A Value is a lightweight handle to a reference-counted graph node. Every
// operation allocates a new node that keeps its operand nodes alive, so
// expressions built from temporaries -- e.g. (a * b).tanh() or a + 2.0f --
// remain valid for as long as the result exists. Copying a Value copies the
// handle: both copies refer to the same node and therefore observe the same
// data and gradient. Operand links are fixed when a node is created and only
// ever point at older nodes, so the graph is acyclic.
class Value {
private:
    // Operation that produced a node. Leaf marks a value created directly by
    // the user; it has no operands and nothing to propagate.
    enum class Op { Leaf, Add, Sub, Mul, Div, Neg, Pow, Tanh, Exp };

    // One vertex of the expression graph.
    struct Node {
        float data;                 // forward result
        float grad;                 // accumulated d(root)/d(this)
        Op op;                      // how this node was computed
        std::shared_ptr<Node> lhs;  // first operand (null for leaves)
        std::shared_ptr<Node> rhs;  // second operand (null for unary ops and leaves)
        float exponent;             // only meaningful when op == Op::Pow
        std::string label;          // human-readable expression text

        Node(float d, const std::string& l)
            : data(d), grad(0.0f), op(Op::Leaf), exponent(0.0f), label(l) {}
    };

    std::shared_ptr<Node> node;

    // Wraps an already-built node; used only by make_result().
    explicit Value(const std::shared_ptr<Node>& n) : node(n) {}

    // Allocates the node for an operation result and links it to its operands.
    static Value make_result(float data, const std::string& label, Op op,
                             const std::shared_ptr<Node>& lhs,
                             const std::shared_ptr<Node>& rhs = std::shared_ptr<Node>(),
                             float exponent = 0.0f) {
        std::shared_ptr<Node> created = std::make_shared<Node>(data, label);
        created->op = op;
        created->lhs = lhs;
        created->rhs = rhs;
        created->exponent = exponent;
        return Value(created);
    }

    // Depth-first post-order walk: every operand is appended to `topo` before
    // the node that consumes it. `visited` makes shared sub-expressions appear
    // exactly once.
    static void build_topo(Node* current, std::vector<Node*>& topo,
                           std::set<Node*>& visited) {
        if (current == nullptr || !visited.insert(current).second) {
            return;
        }
        build_topo(current->lhs.get(), topo, visited);
        build_topo(current->rhs.get(), topo, visited);
        topo.push_back(current);
    }

    // Applies the chain rule for a single node: adds this node's gradient,
    // scaled by the local derivative, into each operand's gradient. Using +=
    // (rather than =) is what makes a value that is used more than once, or
    // used as both operands (a * a), receive the sum of all contributions.
    static void propagate(Node& n) {
        const float upstream = n.grad;
        switch (n.op) {
        case Op::Leaf:
            break;
        case Op::Add:
            n.lhs->grad += upstream;
            n.rhs->grad += upstream;
            break;
        case Op::Sub:
            n.lhs->grad += upstream;
            n.rhs->grad -= upstream;
            break;
        case Op::Mul: {
            const float left = n.lhs->data;
            const float right = n.rhs->data;
            n.lhs->grad += right * upstream;
            n.rhs->grad += left * upstream;
            break;
        }
        case Op::Div: {
            const float numer = n.lhs->data;
            const float denom = n.rhs->data;
            n.lhs->grad += upstream / denom;
            n.rhs->grad -= upstream * numer / (denom * denom);
            break;
        }
        case Op::Neg:
            n.lhs->grad -= upstream;
            break;
        case Op::Pow:
            // d/dx x^0 is 0; skipping avoids 0 * inf = NaN when x == 0.
            if (n.exponent != 0.0f) {
                n.lhs->grad += n.exponent *
                               std::pow(n.lhs->data, n.exponent - 1.0f) * upstream;
            }
            break;
        case Op::Tanh:
            // d/dx tanh(x) = 1 - tanh(x)^2, and n.data already holds tanh(x).
            n.lhs->grad += (1.0f - n.data * n.data) * upstream;
            break;
        case Op::Exp:
            // d/dx exp(x) = exp(x), and n.data already holds exp(x).
            n.lhs->grad += n.data * upstream;
            break;
        }
    }

public:
    // Creates a leaf value. Intentionally non-explicit so that a plain float
    // can appear as the right-hand operand, e.g. `a + 2.0f`.
    Value(float data, const std::string& label = "")
        : node(std::make_shared<Node>(data, label)) {}

    // Current forward value.
    float get_data() const { return node->data; }

    // Gradient accumulated by backward() calls so far (0 until then).
    float get_grad() const { return node->grad; }

    // Sum: this + other.
    Value operator+(const Value& other) const {
        return make_result(node->data + other.node->data,
                           node->label + "+" + other.node->label,
                           Op::Add, node, other.node);
    }

    // Difference: this - other.
    Value operator-(const Value& other) const {
        return make_result(node->data - other.node->data,
                           node->label + "-" + other.node->label,
                           Op::Sub, node, other.node);
    }

    // Product: this * other.
    Value operator*(const Value& other) const {
        return make_result(node->data * other.node->data,
                           node->label + "*" + other.node->label,
                           Op::Mul, node, other.node);
    }

    // Quotient: this / other. No zero check is made; division by zero follows
    // ordinary floating-point rules.
    Value operator/(const Value& other) const {
        return make_result(node->data / other.node->data,
                           node->label + "/" + other.node->label,
                           Op::Div, node, other.node);
    }

    // Unary negation.
    Value operator-() const {
        return make_result(-node->data, "-" + node->label, Op::Neg, node);
    }

    // Prints "Value(data = <data>, grad = <grad>)".
    friend std::ostream& operator<<(std::ostream& out, const Value& obj) {
        out << "Value(data = " << obj.node->data
            << ", grad = " << obj.node->grad << ")";
        return out;
    }

    // Raises this value to a constant power. The exponent is a plain number
    // and does not receive a gradient.
    Value pow(float other) const {
        return make_result(std::pow(node->data, other),
                           node->label + "^" + std::to_string(other),
                           Op::Pow, node, std::shared_ptr<Node>(), other);
    }

    // Hyperbolic tangent. Uses std::tanh, which saturates to +/-1 for large
    // inputs instead of producing inf/inf as an exp-based formula would.
    Value tanh() const {
        return make_result(std::tanh(node->data),
                           "tanh(" + node->label + ")", Op::Tanh, node);
    }

    // Natural exponential.
    Value exp() const {
        return make_result(std::exp(node->data),
                           "exp(" + node->label + ")", Op::Exp, node);
    }

    // Resets the gradient of this value and of everything it was computed
    // from to zero, e.g. between independent backward() passes.
    void zero_grad() {
        std::vector<Node*> topo;
        std::set<Node*> visited;
        build_topo(node.get(), topo, visited);
        for (Node* n : topo) {
            n->grad = 0.0f;
        }
    }

    // Backpropagation: sets this value's gradient to 1 and pushes gradients
    // to every value it was computed from, visiting each node only after all
    // of its consumers. Gradients of the other nodes are accumulated into,
    // not overwritten; call zero_grad() first for a fresh pass.
    void backward() {
        std::vector<Node*> topo;
        std::set<Node*> visited;
        build_topo(node.get(), topo, visited);
        node->grad = 1.0f;
        for (std::vector<Node*>::reverse_iterator it = topo.rbegin();
             it != topo.rend(); ++it) {
            propagate(**it);
        }
    }
};
// Example usage:
int main() {
Value a(2.0, "a");
Value b(3.0, "b");
Value c = a * b;
Value d = c.tanh();
d.backward();
cout << "d: " << d << endl;
cout << "c: " << c << endl;
cout << "a: " << a << endl;
cout << "b: " << b << endl;
return 0;
}