Skip to content

Commit

Permalink
refactor(internal/onnx/ir): rename internal/pb-onnx to internal/onnx/…
Browse files Browse the repository at this point in the history
…ir (#160)
  • Loading branch information
Romain Lespinasse authored and owulveryck committed Oct 18, 2019
1 parent cfe97d3 commit 91bc4b0
Show file tree
Hide file tree
Showing 721 changed files with 2,139 additions and 2,139 deletions.
34 changes: 17 additions & 17 deletions attributes.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package onnx

import (
pb "github.com/owulveryck/onnx-go/internal/pb-onnx"
"github.com/owulveryck/onnx-go/internal/onnx/ir"
)

func toOperationAttributes(attrs []*pb.AttributeProto) (map[string]interface{}, error) {
func toOperationAttributes(attrs []*ir.AttributeProto) (map[string]interface{}, error) {
output := make(map[string]interface{}, len(attrs))
for _, attr := range attrs {
o, err := toOperationAttribute(attr)
Expand All @@ -16,46 +16,46 @@ func toOperationAttributes(attrs []*pb.AttributeProto) (map[string]interface{},
return output, nil
}

func toOperationAttribute(attr *pb.AttributeProto) (interface{}, error) {
func toOperationAttribute(attr *ir.AttributeProto) (interface{}, error) {
switch attr.GetType() {
case pb.AttributeProto_UNDEFINED:
case ir.AttributeProto_UNDEFINED:
return struct{}{}, nil
case pb.AttributeProto_FLOAT:
case ir.AttributeProto_FLOAT:
return attr.GetF(), nil
case pb.AttributeProto_INT:
case ir.AttributeProto_INT:
return attr.GetI(), nil
case pb.AttributeProto_STRING:
case ir.AttributeProto_STRING:
return string(attr.GetS()), nil
case pb.AttributeProto_TENSOR:
case ir.AttributeProto_TENSOR:
return attr.GetT().Tensor()
case pb.AttributeProto_GRAPH:
case ir.AttributeProto_GRAPH:
return nil, &ErrNotImplemented{
AttributeName: attr.GetName(),
AttributeValue: attr,
Message: "pb.AttributeProto_GRAPH not handled yet",
Message: "ir.AttributeProto_GRAPH not handled yet",
}
case pb.AttributeProto_FLOATS:
case ir.AttributeProto_FLOATS:
return attr.GetFloats(), nil
case pb.AttributeProto_INTS:
case ir.AttributeProto_INTS:
return attr.GetInts(), nil
case pb.AttributeProto_STRINGS:
case ir.AttributeProto_STRINGS:
s := attr.GetStrings()
strings := make([]string, len(s))
for i := 0; i < len(s); i++ {
strings[i] = string(s[i])
}
return strings, nil
case pb.AttributeProto_TENSORS:
case ir.AttributeProto_TENSORS:
return nil, &ErrNotImplemented{
AttributeName: attr.GetName(),
AttributeValue: attr,
Message: "pb.AttributeProto_TENSORS not handled yet",
Message: "ir.AttributeProto_TENSORS not handled yet",
}
case pb.AttributeProto_GRAPHS:
case ir.AttributeProto_GRAPHS:
return nil, &ErrNotImplemented{
AttributeName: attr.GetName(),
AttributeValue: attr,
Message: "pb.AttributeProto_GRAPHS not handled yet",
Message: "ir.AttributeProto_GRAPHS not handled yet",
}
default:
return nil, &ErrNotImplemented{
Expand Down
36 changes: 18 additions & 18 deletions attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package onnx
import (
"testing"

pb "github.com/owulveryck/onnx-go/internal/pb-onnx"
"github.com/owulveryck/onnx-go/internal/onnx/ir"
"github.com/stretchr/testify/assert"
)

Expand All @@ -21,36 +21,36 @@ import (
AttributeProto_GRAPHS AttributeProto_AttributeType = 10
*/

func GetTestPBAttributeProto() []*pb.AttributeProto {
return []*pb.AttributeProto{
func GetTestPBAttributeProto() []*ir.AttributeProto {
return []*ir.AttributeProto{
{
Name: "floats",
Type: pb.AttributeProto_FLOATS,
Type: ir.AttributeProto_FLOATS,
Floats: []float32{1, 2},
},
{
Name: "float",
Type: pb.AttributeProto_FLOAT,
Type: ir.AttributeProto_FLOAT,
F: 1,
},
{
Name: "int",
Type: pb.AttributeProto_INT,
Type: ir.AttributeProto_INT,
I: 1,
},
{
Name: "ints",
Type: pb.AttributeProto_INTS,
Type: ir.AttributeProto_INTS,
Ints: []int64{1, 2},
},
{
Name: "string",
Type: pb.AttributeProto_STRING,
Type: ir.AttributeProto_STRING,
S: []byte("a"),
},
{
Name: "strings",
Type: pb.AttributeProto_STRINGS,
Type: ir.AttributeProto_STRINGS,
Strings: [][]byte{[]byte("a"), []byte("b")},
},
}
Expand Down Expand Up @@ -126,38 +126,38 @@ func TestToOperationAttributes_Float(t *testing.T) {
}

func TestToOperationAttributes_NotImplemented(t *testing.T) {
_, err := toOperationAttributes([]*pb.AttributeProto{
_, err := toOperationAttributes([]*ir.AttributeProto{
{
Type: pb.AttributeProto_GRAPH,
Type: ir.AttributeProto_GRAPH,
},
})
_, ok := err.(*ErrNotImplemented)
assert.True(t, ok)
_, err = toOperationAttributes([]*pb.AttributeProto{
_, err = toOperationAttributes([]*ir.AttributeProto{
{
Type: pb.AttributeProto_TENSORS,
Type: ir.AttributeProto_TENSORS,
},
})
_, ok = err.(*ErrNotImplemented)
assert.True(t, ok)
_, err = toOperationAttributes([]*pb.AttributeProto{
_, err = toOperationAttributes([]*ir.AttributeProto{
{
Type: pb.AttributeProto_GRAPHS,
Type: ir.AttributeProto_GRAPHS,
},
})
_, ok = err.(*ErrNotImplemented)
assert.True(t, ok)
_, err = toOperationAttributes([]*pb.AttributeProto{
_, err = toOperationAttributes([]*ir.AttributeProto{
{
Type: pb.AttributeProto_AttributeType(-1),
Type: ir.AttributeProto_AttributeType(-1),
},
})
_, ok = err.(*ErrNotImplemented)
assert.True(t, ok)
}

func TestToOperationAttributes_Undefined(t *testing.T) {
attrs, err := toOperationAttributes([]*pb.AttributeProto{
attrs, err := toOperationAttributes([]*ir.AttributeProto{
nil,
})
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion backend/simple/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Node struct {
description string
value tensor.Tensor
opType string
//attributes []*pb.Attribute
//attributes []*ir.Attribute
}

// ID to fulfil the graph.Node interface
Expand Down
28 changes: 14 additions & 14 deletions backend/testbackend/onnx/gen_cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"text/template"

"github.com/davecgh/go-spew/spew"
pb "github.com/owulveryck/onnx-go/internal/pb-onnx"
"github.com/owulveryck/onnx-go/internal/onnx/ir"
)

var (
Expand Down Expand Up @@ -85,7 +85,7 @@ func processFile(file os.FileInfo) (string, string, error) {
return "", "", err
}
tv.ModelB = fmt.Sprintf("%#v", b)
model := new(pb.ModelProto)
model := new(ir.ModelProto)
err = model.XXX_Unmarshal(b)
if err != nil {
return "", "", err
Expand Down Expand Up @@ -160,7 +160,7 @@ func processFile(file os.FileInfo) (string, string, error) {
return mv.NodeProto[0].OpType, tv.TestName, nil
}

func processModelGraphInput(model *pb.ModelProto, mv *modelValue) {
func processModelGraphInput(model *ir.ModelProto, mv *modelValue) {
mv.Input = make([]valueInfoProto, len(model.Graph.Input))
for i := range model.Graph.Input {
mv.Input[i] = valueInfoProto{
Expand All @@ -169,12 +169,12 @@ func processModelGraphInput(model *pb.ModelProto, mv *modelValue) {
Dims: make([]string, len(model.Graph.Input[i].Type.GetTensorType().Shape.Dim)),
}
for j, v := range model.Graph.Input[i].Type.GetTensorType().Shape.Dim {
mv.Input[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*pb.TensorShapeProto_Dimension_DimValue).DimValue)
mv.Input[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*ir.TensorShapeProto_Dimension_DimValue).DimValue)
}
}
}

func processModelGraphOutput(model *pb.ModelProto, mv *modelValue) {
func processModelGraphOutput(model *ir.ModelProto, mv *modelValue) {
mv.Output = make([]valueInfoProto, len(model.Graph.Output))
for i := range model.Graph.Output {
mv.Output[i] = valueInfoProto{
Expand All @@ -183,12 +183,12 @@ func processModelGraphOutput(model *pb.ModelProto, mv *modelValue) {
Dims: make([]string, len(model.Graph.Output[i].Type.GetTensorType().Shape.Dim)),
}
for j, v := range model.Graph.Output[i].Type.GetTensorType().Shape.Dim {
mv.Output[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*pb.TensorShapeProto_Dimension_DimValue).DimValue)
mv.Output[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*ir.TensorShapeProto_Dimension_DimValue).DimValue)
}
}
}

func processModelGraphValueInfo(model *pb.ModelProto, mv *modelValue) {
func processModelGraphValueInfo(model *ir.ModelProto, mv *modelValue) {
mv.ValueInfo = make([]valueInfoProto, len(model.Graph.ValueInfo))
for i := range model.Graph.ValueInfo {
mv.ValueInfo[i] = valueInfoProto{
Expand All @@ -197,22 +197,22 @@ func processModelGraphValueInfo(model *pb.ModelProto, mv *modelValue) {
Dims: make([]string, len(model.Graph.ValueInfo[i].Type.GetTensorType().Shape.Dim)),
}
for j, v := range model.Graph.ValueInfo[i].Type.GetTensorType().Shape.Dim {
mv.ValueInfo[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*pb.TensorShapeProto_Dimension_DimValue).DimValue)
mv.ValueInfo[i].Dims[j] = fmt.Sprintf("%v", v.GetValue().(*ir.TensorShapeProto_Dimension_DimValue).DimValue)

}
}
}

func processModelGraphNodeInput(filename string, node *pb.NodeProto, tv *testValue) error {
func processModelGraphNodeInput(filename string, node *ir.NodeProto, tv *testValue) error {
tv.Input = make([]iO, len(node.GetInput()))
for i := range node.GetInput() {
// Open the tensorproto sample file
filepath := fmt.Sprintf("%v%v/test_data_set_0/input_%v.pb", *testdir, filename, i)
filepath := fmt.Sprintf("%v%v/test_data_set_0/input_%v.ir", *testdir, filename, i)
b, err := ioutil.ReadFile(filepath)
if err != nil {
return err
}
sampleTestData := new(pb.TensorProto)
sampleTestData := new(ir.TensorProto)
err = sampleTestData.XXX_Unmarshal(b)
if err != nil {
return err
Expand All @@ -239,16 +239,16 @@ func processModelGraphNodeInput(filename string, node *pb.NodeProto, tv *testVal
return nil
}

func processModelGraphNodeOutput(filename string, node *pb.NodeProto, tv *testValue) error {
func processModelGraphNodeOutput(filename string, node *ir.NodeProto, tv *testValue) error {
tv.ExpectedOutput = make([]iO, len(node.GetOutput()))
for i := range node.Output {
// Open the tensorproto sample file
filepath := fmt.Sprintf("%v%v/test_data_set_0/output_%v.pb", *testdir, filename, i)
filepath := fmt.Sprintf("%v%v/test_data_set_0/output_%v.ir", *testdir, filename, i)
b, err := ioutil.ReadFile(filepath)
if err != nil {
return err
}
sampleTestData := new(pb.TensorProto)
sampleTestData := new(ir.TensorProto)
err = sampleTestData.XXX_Unmarshal(b)
if err != nil {
return err
Expand Down
Loading

0 comments on commit 91bc4b0

Please sign in to comment.