forked from sql-machine-learning/elasticdl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SGD Kernel Go Package (sql-machine-learning#1637)
* add sgd kernel using eigen * test style * move kernel to go dir * add kernel test * refine test * fix eigen3 header file * trigger go unittest * fix ci * fix ci * fix ci * follow comments * trigger go test * update * follow comments * update * update * fix code style * fix dockerfile
- Loading branch information
Showing
13 changed files
with
147 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
--- | ||
Language: Cpp | ||
BasedOnStyle: Google | ||
IndentWidth: 2 | ||
TabWidth: 2 | ||
ContinuationIndentWidth: 4 | ||
AccessModifierOffset: -1 # The private/protected/public has no indent in class | ||
Standard: Cpp11 | ||
AllowAllParametersOfDeclarationOnNextLine: true | ||
BinPackParameters: false | ||
BinPackArguments: false | ||
--- | ||
Language: Proto | ||
BasedOnStyle: Google | ||
IndentWidth: 2 | ||
TabWidth: 2 | ||
ContinuationIndentWidth: 4 | ||
--- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ __pycache__/ | |
|
||
# C extensions | ||
*.so | ||
*.o | ||
*.a | ||
|
||
# Distribution / packaging | ||
.Python | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ addons: | |
- docker-ce | ||
- python3-pip | ||
- python3-setuptools | ||
- clang-format | ||
|
||
install: | ||
- docker version | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#include "kernel_api.h" | ||
|
||
#include <eigen3/Eigen/Dense> | ||
|
||
void SGD(float* grad, float* param, double lr, long long size) { | ||
Eigen::Map<Eigen::Array<float, 1, Eigen::Dynamic>> eg{ | ||
grad, static_cast<Eigen::Index>(size)}; | ||
|
||
Eigen::Map<Eigen::Array<float, 1, Eigen::Dynamic>> ep{ | ||
param, static_cast<Eigen::Index>(size)}; | ||
|
||
ep -= lr * eg; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#ifndef ELASTICDL_PKG_KERNEL_CAPI_KERNEL_API_H_ | ||
#define ELASTICDL_PKG_KERNEL_CAPI_KERNEL_API_H_ | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
void SGD(float* grad, float* param, double lr, long long size); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package kernel | ||
|
||
// #cgo LDFLAGS: -L./capi -lkernel_api | ||
// #include "capi/kernel_api.h" | ||
import "C" | ||
import "unsafe" | ||
|
||
// SGD kernel | ||
func SGD(grad []float32, param []float32, lr float64, size int64) { | ||
gradPtr := (*C.float)(unsafe.Pointer(&grad[0])) | ||
paramPtr := (*C.float)(unsafe.Pointer(¶m[0])) | ||
C.SGD(gradPtr, paramPtr, C.double(lr), C.longlong(size)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package kernel | ||
|
||
import "testing" | ||
import "math/rand" | ||
import "github.com/stretchr/testify/assert" | ||
|
||
func TestSGD(t *testing.T) { | ||
const size int = 10 | ||
a := make([]float32, size) | ||
b := make([]float32, size) | ||
var lr float32 = 0.1 | ||
|
||
for i := 0; i < size; i++ { | ||
a[i] = rand.Float32() | ||
b[i] = rand.Float32() | ||
} | ||
|
||
expected := make([]float32, size) | ||
|
||
for i := 0; i < size; i++ { | ||
expected[i] = b[i] - lr*a[i] | ||
} | ||
|
||
SGD(a, b, float64(lr), int64(size)) | ||
|
||
assert.Equal(t, b, expected) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
clang-format $@ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/bin/bash | ||
|
||
TOTAL_ERRORS=0 | ||
if [[ ! $TRAVIS_BRANCH ]]; then | ||
# install cpplint on local machine. | ||
if [[ ! $(which cpplint) ]]; then | ||
pip install cpplint | ||
fi | ||
# diff files on local machine. | ||
files=$(git diff --cached --name-status | awk 'Extra open brace or missing close brace2}') | ||
else | ||
# diff files between PR and latest commit on Travis CI. | ||
branch_ref=$(git rev-parse "$TRAVIS_BRANCH") | ||
head_ref=$(git rev-parse HEAD) | ||
files=$(git diff --name-status $branch_ref $head_ref | awk 'Extra open brace or missing close brace2}') | ||
fi | ||
# The trick to remove deleted files: https://stackoverflow.com/a/2413151 | ||
for file in $files; do | ||
if [[ $file =~ ^(patches/.*) ]]; then | ||
continue; | ||
else | ||
cpplint --filter=-readability/fn_size $file; | ||
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); | ||
fi | ||
done | ||
|
||
exit $TOTAL_ERRORS |