-
Notifications
You must be signed in to change notification settings - Fork 9
Add BatchNormalization operator #186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Swopper050
wants to merge
19
commits into
AdvancedClimateSystems:develop
Choose a base branch
from
Swopper050:5-batch-normalization
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
e613122
WIP on RNN
Swopper050 6992040
WIP on RNN
Swopper050 b938a25
Working RNN version
Swopper050 70ac771
Added tests for RNN
Swopper050 da24e6a
Merged develop and elaborated docstrings
Swopper050 cb74f72
Working version of LSTM operator
Swopper050 dd95f78
Reusable attrs and tests for LSTM
Swopper050 07612a7
Refactored recurrent operators to share code
Swopper050 18cd49e
WIP on batch norm
Swopper050 6643100
Merged develop
Swopper050 efb4ca4
Merged develop and implement test mode for batch normalization
Swopper050 e79e9c8
Added BatchNormalization implementation
Swopper050 5465c2c
Comment and remove print
Swopper050 18ca63b
Fix lint
Swopper050 faf43df
Ignore newly added tests
Swopper050 b3d0268
Remove print statement
Swopper050 b7a3044
Merged develop
Swopper050 fe10c12
Error earlier
Swopper050 b3521bf
Fixed tests
Swopper050 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,230 @@ | ||
| package opset13 | ||
|
|
||
| import ( | ||
| "github.com/advancedclimatesystems/gonnx/onnx" | ||
| "github.com/advancedclimatesystems/gonnx/ops" | ||
| "gorgonia.org/tensor" | ||
| ) | ||
|
|
||
// Number of input tensors the BatchNormalization operator takes. Minimum and
// maximum are equal, so all five inputs (X, scale, B, mean, variance) are
// required.
const (
	MinBatchNormalizationInputs = 5
	MaxBatchNormalizationInputs = 5
)

// Values given to the epsilon and momentum fields of a newly created
// BatchNormalization operator.
const (
	BatchNormalizationDefaultEpsilon  = 1e-5
	BatchNormalizationDefaultMomentum = 0.9
)
|
|
||
// BatchNormalization represents the ONNX batchNormalization operator.
type BatchNormalization struct {
	// epsilon is added to the variance before the square root is taken.
	// momentum is stored by Init; the test mode calculation does not read it.
	epsilon, momentum float32

	// testMode is switched on by Init when the node carries no momentum
	// attribute. Apply returns an error while it is false.
	testMode bool
}
|
|
||
| // newBatchNormalization creates a new batchNormalization operator. | ||
| func newBatchNormalization() ops.Operator { | ||
| return &BatchNormalization{ | ||
| epsilon: BatchNormalizationDefaultEpsilon, | ||
| momentum: BatchNormalizationDefaultMomentum, | ||
| } | ||
| } | ||
|
|
||
| // Init initializes the batchNormalization operator. | ||
| func (b *BatchNormalization) Init(n *onnx.NodeProto) error { | ||
| hasMomentum := false | ||
|
|
||
| for _, attr := range n.GetAttribute() { | ||
| switch attr.GetName() { | ||
| case "epsilon": | ||
| b.epsilon = attr.GetF() | ||
| case "momentum": | ||
| hasMomentum = true | ||
| b.momentum = attr.GetF() | ||
| default: | ||
| return ops.ErrInvalidAttribute(attr.GetName(), b) | ||
| } | ||
| } | ||
|
|
||
| if !hasMomentum { | ||
| b.testMode = true | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // Apply applies the batchNormalization operator. | ||
| func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
| X := inputs[0] | ||
| scale := inputs[1] | ||
| B := inputs[2] | ||
| mean := inputs[3] | ||
| variance := inputs[4] | ||
|
|
||
| // We only support test mode, as this is by far the most common for inference models. | ||
| if !b.testMode { | ||
| return nil, ops.ErrUnsupportedAttribute("momentum", b) | ||
| } | ||
|
|
||
| out, err := b.testModeCalculation(X, scale, B, mean, variance) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| return []tensor.Tensor{out}, nil | ||
| } | ||
|
|
||
| // ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
| func (b *BatchNormalization) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
| return ops.ValidateInputs(b, inputs) | ||
| } | ||
|
|
||
| // GetMinInputs returns the minimum number of input tensors this operator expects. | ||
| func (b *BatchNormalization) GetMinInputs() int { | ||
| return MinBatchNormalizationInputs | ||
| } | ||
|
|
||
| // GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
| func (b *BatchNormalization) GetMaxInputs() int { | ||
| return MaxBatchNormalizationInputs | ||
| } | ||
|
|
||
| // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
| // for the corresponding input tensor. | ||
| func (b *BatchNormalization) GetInputTypeConstraints() [][]tensor.Dtype { | ||
| return [][]tensor.Dtype{ | ||
| {tensor.Float32, tensor.Float64}, | ||
| {tensor.Float32, tensor.Float64}, | ||
| {tensor.Float32, tensor.Float64}, | ||
| {tensor.Float32, tensor.Float64}, | ||
| {tensor.Float32, tensor.Float64}, | ||
| } | ||
| } | ||
|
|
||
| // String implements the stringer interface, and can be used to format errors or messages. | ||
| func (b *BatchNormalization) String() string { | ||
| return "batchNormalization operator" | ||
| } | ||
|
|
||
| func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) { | ||
| nNonSpatialDims := 2 | ||
|
|
||
| nSpatialDims := len(X.Shape()) - nNonSpatialDims | ||
| if nSpatialDims <= 0 { | ||
| return scale, bias, mean, variance, nil | ||
| } | ||
|
|
||
| // The new shape for the `scale`, `bias`, `mean` and `variance` tensors should | ||
| // be (C, 1, 1, ...), such that they can be broadcasted to match the shape of `X`. | ||
| newShape := make([]int, 1+nSpatialDims) | ||
|
|
||
| // Here we set the channel dimension. The channel dimension is the same | ||
| // for all `X`, `scale`, `bias`, `mean` and `variance` tensors. | ||
| newShape[0] = scale.Shape()[0] | ||
|
|
||
| // Set all the remaining dimensions to 1 to allow for broadcasting. | ||
| for i := 1; i < len(newShape); i++ { | ||
| newShape[i] = 1 | ||
| } | ||
|
|
||
| // Now we create new tensors for all the input tensors (except `X`) and reshape | ||
| // them. | ||
| newScale, ok := scale.Clone().(tensor.Tensor) | ||
| if !ok { | ||
| return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", scale.Clone()) | ||
| } | ||
|
|
||
| newBias, ok = bias.Clone().(tensor.Tensor) | ||
| if !ok { | ||
| return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", bias.Clone()) | ||
| } | ||
|
|
||
| newMean, ok = mean.Clone().(tensor.Tensor) | ||
| if !ok { | ||
| return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", mean.Clone()) | ||
| } | ||
|
|
||
| newVariance, ok = variance.Clone().(tensor.Tensor) | ||
| if !ok { | ||
| return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", variance.Clone()) | ||
| } | ||
|
|
||
| err = newScale.Reshape(newShape...) | ||
| if err != nil { | ||
| return nil, nil, nil, nil, err | ||
| } | ||
|
|
||
| err = newBias.Reshape(newShape...) | ||
| if err != nil { | ||
| return nil, nil, nil, nil, err | ||
| } | ||
|
|
||
| err = newMean.Reshape(newShape...) | ||
| if err != nil { | ||
| return nil, nil, nil, nil, err | ||
| } | ||
|
|
||
| err = newVariance.Reshape(newShape...) | ||
| if err != nil { | ||
| return nil, nil, nil, nil, err | ||
| } | ||
|
|
||
| return | ||
| } | ||
|
|
||
| func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) { | ||
| newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| numerator, err := ops.ApplyBinaryOperation( | ||
| X, | ||
| newMean, | ||
| ops.Sub, | ||
| ops.UnidirectionalBroadcasting, | ||
| ) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| numerator, err = ops.ApplyBinaryOperation( | ||
| numerator[0], | ||
| newScale, | ||
| ops.Mul, | ||
| ops.UnidirectionalBroadcasting, | ||
| ) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| denominator, err := tensor.Add(newVariance, b.epsilon) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| denominator, err = tensor.Sqrt(denominator) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| outputs, err := ops.ApplyBinaryOperation( | ||
| numerator[0], | ||
| denominator, | ||
| ops.Div, | ||
| ops.UnidirectionalBroadcasting, | ||
| ) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| outputs, err = ops.ApplyBinaryOperation( | ||
| outputs[0], | ||
| newBias, | ||
| ops.Add, | ||
| ops.UnidirectionalBroadcasting, | ||
| ) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| return outputs[0], nil | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| package opset13 | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| "github.com/advancedclimatesystems/gonnx/onnx" | ||
| "github.com/advancedclimatesystems/gonnx/ops" | ||
| "github.com/stretchr/testify/assert" | ||
| "gorgonia.org/tensor" | ||
| ) | ||
|
|
||
| func TestBatchNormalizationInit(t *testing.T) { | ||
| b := &BatchNormalization{} | ||
|
|
||
| err := b.Init( | ||
| &onnx.NodeProto{ | ||
| Attribute: []*onnx.AttributeProto{ | ||
| {Name: "epsilon", F: 0.001}, | ||
| }, | ||
| }, | ||
| ) | ||
| assert.Nil(t, err) | ||
|
|
||
| assert.Equal(t, float32(0.001), b.epsilon) | ||
| assert.True(t, b.testMode) | ||
| } | ||
|
|
||
| func TestBatchNormalizationInitTrainingMode(t *testing.T) { | ||
| b := &BatchNormalization{} | ||
|
|
||
| err := b.Init( | ||
| &onnx.NodeProto{ | ||
| Attribute: []*onnx.AttributeProto{ | ||
| {Name: "epsilon", F: 0.001}, | ||
| {Name: "momentum", F: 0.99}, | ||
| }, | ||
| }, | ||
| ) | ||
| assert.Nil(t, err) | ||
|
|
||
| assert.Equal(t, float32(0.001), b.epsilon) | ||
| assert.Equal(t, float32(0.99), b.momentum) | ||
| assert.False(t, b.testMode) | ||
| } | ||
|
|
||
| func TestBatchNormalization(t *testing.T) { | ||
| tests := []struct { | ||
| batchNormalization *BatchNormalization | ||
| backings [][]float32 | ||
| shapes [][]int | ||
| expected []float32 | ||
| }{ | ||
| { | ||
| &BatchNormalization{ | ||
| epsilon: 1e5, | ||
| momentum: 0.9, | ||
| testMode: true, | ||
| }, | ||
| [][]float32{ | ||
| {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, | ||
| {0.2, 0.3, 0.4}, | ||
| {0.1, -0.1, 0.2}, | ||
| {4, 8, 12}, | ||
| {1, 2, 3}, | ||
| }, | ||
| [][]int{ | ||
| {2, 3, 2, 2}, | ||
| {3}, | ||
| {3}, | ||
| {3}, | ||
| {3}, | ||
| }, | ||
| []float32{0.097470194, 0.098102644, 0.098735094, 0.09936755, -0.103794694, -0.10284603, -0.10189735, -0.10094868, 0.19494043, 0.19620533, 0.19747022, 0.19873512, 0.10505962, 0.10569207, 0.10632452, 0.10695698, -0.09241061, -0.091461934, -0.09051326, -0.08956459, 0.21011914, 0.21138403, 0.21264893, 0.21391381}, | ||
| }, | ||
| } | ||
|
|
||
| for _, test := range tests { | ||
| inputs := []tensor.Tensor{ | ||
| ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), | ||
| ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), | ||
| ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...), | ||
| ops.TensorWithBackingFixture(test.backings[3], test.shapes[3]...), | ||
| ops.TensorWithBackingFixture(test.backings[4], test.shapes[4]...), | ||
| } | ||
|
|
||
| res, err := test.batchNormalization.Apply(inputs) | ||
| assert.Nil(t, err) | ||
|
|
||
| assert.Equal(t, test.expected, res[0].Data()) | ||
| } | ||
| } | ||
|
|
||
| func TestInputValidationBatchNormalization(t *testing.T) { | ||
| tests := []struct { | ||
| inputs []tensor.Tensor | ||
| err error | ||
| }{ | ||
| { | ||
| []tensor.Tensor{ | ||
| ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
| ops.TensorWithBackingFixture([]float32{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float32{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float32{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float32{3, 4}, 2), | ||
| }, | ||
| nil, | ||
| }, | ||
| { | ||
| []tensor.Tensor{ | ||
| ops.TensorWithBackingFixture([]float64{1, 2}, 2), | ||
| ops.TensorWithBackingFixture([]float64{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float64{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float64{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float64{3, 4}, 2), | ||
| }, | ||
| nil, | ||
| }, | ||
| { | ||
| []tensor.Tensor{ | ||
| ops.TensorWithBackingFixture([]int{1, 2}, 2), | ||
| }, | ||
| ops.ErrInvalidInputCount(1, &BatchNormalization{}), | ||
| }, | ||
| { | ||
| []tensor.Tensor{ | ||
| ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
| ops.TensorWithBackingFixture([]int{3, 4}, 2), | ||
| ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
| ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
| ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
| }, | ||
| ops.ErrInvalidInputType(1, "int", &BatchNormalization{}), | ||
| }, | ||
| } | ||
|
|
||
| for _, test := range tests { | ||
| batchNormalization := &BatchNormalization{} | ||
| validated, err := batchNormalization.ValidateInputs(test.inputs) | ||
|
|
||
| assert.Equal(t, test.err, err) | ||
|
|
||
| if test.err == nil { | ||
| assert.Equal(t, test.inputs, validated) | ||
| } | ||
| } | ||
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.