-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
43 lines (35 loc) · 1.42 KB
/
Copy pathmain.cpp
File metadata and controls
43 lines (35 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include <cstddef>
#include <iostream>
#include <vector>

#include "utils.h"
#include "NNet/NNet.h"
// Trains a small net::NNet on the four XOR-style samples below with per-sample
// updates, then prints the network output for every input.
//
// NOTE(review): both layers are net::Linear and no activation layer is visible
// here. If Linear applies no nonlinearity, the stack is a single affine map and
// cannot represent XOR. Confirm against NNet/NNet.h.
int main() {
    const int iterations = 100;

    // Inputs: one 1x2 row per sample.
    std::vector<matrix<double> > batchX;
    batchX.reserve(4);
    batchX.push_back(std::vector<std::vector<double> >{{0, 0}});
    batchX.push_back(std::vector<std::vector<double> >{{0, 1}});
    batchX.push_back(std::vector<std::vector<double> >{{1, 0}});
    batchX.push_back(std::vector<std::vector<double> >{{1, 1}});

    // Targets: one 1x2 row per sample; batchY[i] is the target for batchX[i].
    std::vector<matrix<double> > batchY;
    batchY.reserve(4);
    batchY.push_back(std::vector<std::vector<double> >{{1, 0}});
    batchY.push_back(std::vector<std::vector<double> >{{0, 1}});
    batchY.push_back(std::vector<std::vector<double> >{{0, 1}});
    batchY.push_back(std::vector<std::vector<double> >{{1, 0}});

    // Derive the sample count from the data rather than a separate constant, so
    // adding or removing a sample can't desynchronise the loops from the batches.
    if (batchX.size() != batchY.size()) {
        std::cerr << "batchX and batchY must have the same number of samples\n";
        return 1;
    }
    const std::size_t n = batchX.size();

    net::NNet nn({net::Linear(2, 2), net::Linear(2, 2)});

    // Per-sample training: forward pass, MSE loss against the target, then
    // back-propagate the loss derivative.
    for (int iteration = 0; iteration < iterations; ++iteration) {
        for (std::size_t i = 0; i < n; ++i) {
            nn.propagate_front(batchX[i]);
            net::Loss loss(nn, batchY[i], net::MSE, net::MSE_derivative);
            nn.propagate_back(loss.get_loss_derivative());
        }
    }

    // Print the trained network's output for each input, in batch order.
    // Non-const reference in case propagate_front takes its argument by
    // non-const reference (its signature is not visible here).
    for (auto& x : batchX) {
        nn.propagate_front(x);
        std::cout << nn.get_output();
    }
    return 0;
}