-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrbf.go
100 lines (75 loc) · 2.3 KB
/
rbf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package gorbi
import (
"fmt"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"math"
)
// multiquadric evaluates the multiquadric radial basis function
// sqrt((r/epsilon)^2 + 1) for a radial distance r and shape parameter epsilon.
// The ratio is formed as (1/epsilon)*r.
func multiquadric(epsilon, r float64) float64 {
	scaled := 1.0 / epsilon * r
	squared := math.Pow(scaled, 2.0)
	return math.Sqrt(squared + 1)
}
// RBF is a radial basis function interpolator. Values of this type are
// created by NewRBF, which fills in every field.
type RBF struct {
	// xi holds the sample coordinates passed to NewRBF, one slice per sample.
	xi [][]float64
	// vi holds the sample values passed to NewRBF.
	vi []float64
	// n is the number of samples (len(vi) at construction time).
	n int
	// epsilon is the shape parameter handed to the basis function.
	epsilon float64
	// function is the radial basis function selected by NewRBF.
	function func(epsilon, r float64) float64
	// nodes is the n x 1 matrix of node weights solved for by NewRBF.
	nodes *mat.Dense
}
// NewRBF builds a radial basis function interpolator through the given samples.
//
// args holds one coordinate slice per sample and values holds the sampled
// value at each of those coordinates. Both must be non-empty and have the same
// length; otherwise an error is returned instead of letting the matrix
// constructors panic.
//
// The shape parameter epsilon is derived from the bounding hypercube of args
// (as reported by HypercubeDims) as (prod(sides)/n)^(1/len(sides)). Sides of
// zero length are left out of that computation, because a single zero side
// would force epsilon to 0 and make the basis function divide by zero.
//
// On failure the returned RBF is the zero value. An error from the linear
// solve is wrapped, so errors.Is / errors.As still reach the underlying error.
func NewRBF(args [][]float64, values []float64) (RBF, error) {
	nPts := len(values)
	if nPts == 0 {
		return RBF{}, fmt.Errorf("gorbi: no values to interpolate")
	}
	if len(args) != nPts {
		return RBF{}, fmt.Errorf("gorbi: got %d coordinates for %d values", len(args), nPts)
	}

	// Keep only the sides of the bounding hypercube that have an extent.
	// NOTE(review): assumes HypercubeDims returns one side length per
	// dimension — confirm against its definition.
	hypercubeDim := HypercubeDims(args)
	sides := make([]float64, 0, len(hypercubeDim))
	for _, side := range hypercubeDim {
		if side != 0 {
			sides = append(sides, side)
		}
	}

	// With no extent in any dimension there is nothing to scale by; fall back
	// to a neutral shape parameter rather than dividing by zero dimensions.
	epsilon := 1.0
	if len(sides) > 0 {
		epsilon = math.Pow(floats.Prod(sides)/float64(nPts), 1./float64(len(sides)))
	}
	if math.IsNaN(epsilon) || math.IsInf(epsilon, 0) || epsilon <= 0 {
		return RBF{}, fmt.Errorf("gorbi: cannot derive a usable epsilon from the coordinates (got %v)", epsilon)
	}

	// Set the radial basis function
	// TODO: Add more basis functions and a nice API for changing basis functions
	function := multiquadric

	// Evaluate the basis function on the pairwise distances between all
	// samples, flattened row by row into the data for the system matrix.
	dists := Cdist(args, args)
	A := make([]float64, 0, nPts*nPts)
	for _, row := range dists {
		for _, d := range row {
			A = append(A, function(epsilon, d))
		}
	}

	// Solve A * nodes = values for the node weights.
	diMat := mat.NewDense(nPts, 1, values)
	AMat := mat.NewDense(nPts, nPts, A)
	nodes := mat.NewDense(nPts, 1, nil)
	if err := nodes.Solve(AMat, diMat); err != nil {
		return RBF{}, fmt.Errorf("gorbi: solving for node weights: %w", err)
	}

	return RBF{
		xi:       args,
		vi:       values,
		n:        nPts,
		epsilon:  epsilon,
		function: function,
		nodes:    nodes,
	}, nil
}
// At returns the interpolated value at each coordinate in xs, in the same
// order as xs. An empty xs yields a nil slice.
func (rbf *RBF) At(xs [][]float64) []float64 {
	nPts := len(xs)
	if nPts == 0 {
		// Nothing to evaluate; building a zero-row matrix below would panic.
		return nil
	}

	// Use the basis function the interpolator was built with, so evaluation
	// stays consistent with the solved node weights. Fall back to the
	// multiquadric for values whose function field was never set.
	basis := rbf.function
	if basis == nil {
		basis = multiquadric
	}

	// Evaluate the basis function on the distances between every query point
	// and every sample point, flattened row by row.
	dists := Cdist(xs, rbf.xi)
	A := make([]float64, 0, nPts*rbf.n)
	for _, row := range dists {
		for _, d := range row {
			A = append(A, basis(rbf.epsilon, d))
		}
	}

	// The interpolated values are the product of the basis matrix and the
	// node weights.
	AMat := mat.NewDense(nPts, rbf.n, A)
	vals := mat.NewDense(nPts, 1, nil)
	vals.Mul(AMat, rbf.nodes)
	return vals.RawMatrix().Data
}