train.go
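A toy training script for micrograd-go: it fits a small multi-layer perceptron (3 inputs, two hidden layers of 4 neurons, one output) to four hand-written examples by plain gradient descent on a mean squared error loss, printing the loss each iteration. Run it with go run train.go, or add -verbose for per-iteration predictions and a sample of parameter gradients and updates.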
package main

import (
	"flag"
	"fmt"

	"github.com/ishandhanani/micrograd-go/engine"
	"github.com/ishandhanani/micrograd-go/nn"
)

func main() {
	verbose := flag.Bool("verbose", false, "Enable verbose output")
	flag.Parse()

	// Four training examples, each with three input features.
	x := [][]*engine.Value{
		{engine.NewValue(2.0, "x1", []*engine.Value{}), engine.NewValue(3.0, "x2", []*engine.Value{}), engine.NewValue(-1.0, "x3", []*engine.Value{})},
		{engine.NewValue(3.0, "x1", []*engine.Value{}), engine.NewValue(-1.0, "x2", []*engine.Value{}), engine.NewValue(0.5, "x3", []*engine.Value{})},
		{engine.NewValue(0.5, "x1", []*engine.Value{}), engine.NewValue(1.0, "x2", []*engine.Value{}), engine.NewValue(1.0, "x3", []*engine.Value{})},
		{engine.NewValue(1.0, "x1", []*engine.Value{}), engine.NewValue(1.0, "x2", []*engine.Value{}), engine.NewValue(-1.0, "x3", []*engine.Value{})},
	}

	// Observed (target) output for each training example.
	y_obs := []*engine.Value{
		engine.NewValue(1.0, "y1", []*engine.Value{}),
		engine.NewValue(-1.0, "y2", []*engine.Value{}),
		engine.NewValue(-1.0, "y3", []*engine.Value{}),
		engine.NewValue(1.0, "y4", []*engine.Value{}),
	}

	// A 3-input MLP with two hidden layers of 4 neurons and a single output.
	mlp := nn.NewMLP(3, []int{4, 4, 1})
	learningRate := 0.01

	for i := 0; i < 500; i++ {
		// Forward pass: run every training example through the network.
		y_pred := make([]*engine.Value, len(x))
		for j, x_i := range x {
			y_pred[j] = mlp.Forward(x_i)[0]
		}

		// Compute the mean squared error between predictions and targets.
		loss := nn.MSE(y_pred, y_obs)

		// Backward pass: fill in the gradient of the loss for every parameter.
		loss.BackwardPass()

		fmt.Printf("Iteration %d: Loss: %.6f\n", i, loss.Data)
		if *verbose {
			fmt.Println("Predictions vs Observations:")
			for j := range y_pred {
				fmt.Printf(" Pred: %.6f, Obs: %.6f\n", y_pred[j].Data, y_obs[j].Data)
			}
			fmt.Println("Sample gradients and updates:")
			for j, p := range mlp.Parameters() {
				if j < 5 { // Print only the first 5 parameters.
					update := -learningRate * p.Grad
					fmt.Printf(" Param: %.6f, Grad: %.6f, Update: %.6f\n", p.Data, p.Grad, update)
				}
			}
			fmt.Println()
		}

		// Gradient descent step: move each parameter against its gradient,
		// then zero the gradient so it does not accumulate across iterations.
		for _, p := range mlp.Parameters() {
			p.Data = p.Data - (learningRate * p.Grad)
			p.Grad = 0.0
		}
	}
fmt.Println("\nFinal Predictions:")
for j, x_i := range x {
pred := mlp.Forward(x_i)[0].Data
fmt.Printf("Input %d: Pred: %.6f, Obs: %.6f\n", j, pred, y_obs[j].Data)
}
}
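
A quick way to sanity-check that loss.BackwardPass produces correct gradients is a central finite-difference probe. The standalone sketch below is hypothetical, not part of this repository; it assumes only that NewValue, NewMLP, Forward, Parameters, MSE, and BackwardPass behave the way train.go already uses them.

// gradcheck.go: hypothetical standalone sketch, not part of this repository.
// It compares the analytic gradient from BackwardPass against a central
// finite-difference estimate for a single parameter.
package main

import (
	"fmt"

	"github.com/ishandhanani/micrograd-go/engine"
	"github.com/ishandhanani/micrograd-go/nn"
)

func main() {
	// A tiny network and a single training example keep the check fast.
	mlp := nn.NewMLP(2, []int{3, 1})
	x := []*engine.Value{
		engine.NewValue(1.0, "x1", []*engine.Value{}),
		engine.NewValue(-2.0, "x2", []*engine.Value{}),
	}
	y := []*engine.Value{engine.NewValue(0.5, "y1", []*engine.Value{})}

	// Analytic gradient: one forward/backward pass on a fresh network,
	// so parameter gradients start at zero.
	loss := nn.MSE([]*engine.Value{mlp.Forward(x)[0]}, y)
	loss.BackwardPass()
	p := mlp.Parameters()[0]
	analytic := p.Grad

	// Numeric gradient: perturb the same parameter by +/-eps, re-evaluate
	// the loss, and take the centered difference as dLoss/dParam.
	const eps = 1e-6
	orig := p.Data
	p.Data = orig + eps
	lossPlus := nn.MSE([]*engine.Value{mlp.Forward(x)[0]}, y).Data
	p.Data = orig - eps
	lossMinus := nn.MSE([]*engine.Value{mlp.Forward(x)[0]}, y).Data
	p.Data = orig

	numeric := (lossPlus - lossMinus) / (2 * eps)
	fmt.Printf("analytic: %.8f  numeric: %.8f\n", analytic, numeric)
}

If the analytic and numeric values agree to several decimal places, the backward pass is almost certainly correct for that parameter; a large gap points at a bug in the backward pass or at a poorly chosen eps.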