Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
DragonPG2000 authored Sep 1, 2020
1 parent 7755f82 commit abdb614
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
55 changes: 55 additions & 0 deletions GroupNorm/groupnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.backend import image_data_format

class GroupNorm(Layer):
"""
Reimplementation of GroupNorm using the excellent post
https://amaarora.github.io/2020/08/09/groupnorm.html
"""
def __init__(self,groups=32,**kwargs):
"""
Arguments:
groups: The number of groups that the channels are divided into (Default value=32)
eps: The value used in order to prevent zero by division errors
"""
super(GroupNorm,self).__init__(**kwargs)
self.g=groups
self.eps=1e-5
if image_data_format()=='channels_first':
self.axis=1
else:
self.axis=-1
def build(self,input_shape):
"""
Arguments:
input_shape: The shape of the feature maps in the form N*H*W*C
"""
shape=[1,1,1,1]
shape[self.axis]=int(input_shape[self.axis])
self.gamma=self.add_weight('gamma',
shape=shape)
self.beta=self.add_weight('gamma',
shape=shape)
super().build(input_shape)
def call(self,inputs):
"""
Arguments:
inputs: The transformed features from the previous layers
"""
input_shape=K.int_shape(inputs)
n,h,w,c=input_shape
tensor_shape=tf.shape(inputs)
shape=[tensor_shape[i] for i in range(len(input_shape))]
shape[self.axis]=shape[self.axis]//self.g
shape.insert(self.axis,self.g)
shape=tf.stack(shape)
x=tf.reshape(inputs,shape=shape)
mean,variance=tf.nn.moments(x,axes=[1,2,3],keepdims=True)
x_transformed=(x-mean)/tf.sqrt(variance+self.eps)
x_transformed=tf.reshape(x_transformed,shape=tensor_shape)
x_transformed=self.gamma*x_transformed+self.beta
return x_transformed
15 changes: 15 additions & 0 deletions GroupNorm/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf
from tensorflow import keras
from groupnorm import GroupNorm
def make_simple_model(input_shape=(28,28,1),norm='group'):
inp=keras.layers.Input(input_shape)
model_gn=keras.layers.Conv2D(128,kernel_size=3,strides=(1,1),padding='same')(inp)
if norm=='group':
model_gn=GroupNorm()(model_gn)
else:
keras.layers.BatchNormalization()(model_gn)
model_gn=keras.layers.GlobalAveragePooling2D()(model_gn)
model_gn=keras.layers.Dense(10,activation='softmax')(model_gn)
model_gn=keras.models.Model(inputs=[inp],outputs=[model_gn])
model_gn.compile(loss='categorical_crossentropy',optimizer=keras.optimizers.Adam(1e-4),metrics=['accuracy'])
return model_gn
16 changes: 16 additions & 0 deletions GroupNorm/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Layer
from model import make_simple_model
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
Y_train,Y_test=tf.keras.utils.to_categorical(Y_train,num_classes=10),tf.keras.utils.to_categorical(Y_test,num_classes=10)


models={'group_norm':make_simple_model(norm='group'),'batch_norm':make_simple_model(norm='batch')}

for norm,model in models.items():
print(f'Running with {norm}')
history=model.fit(X_train,Y_train,batch_size=32,epochs=10,validation_data=(X_test,Y_test))

0 comments on commit abdb614

Please sign in to comment.