diff --git a/morph_net/framework/batch_norm_source_op_handler_test.py b/morph_net/framework/batch_norm_source_op_handler_test.py index 369ba0f..7b7920a 100644 --- a/morph_net/framework/batch_norm_source_op_handler_test.py +++ b/morph_net/framework/batch_norm_source_op_handler_test.py @@ -41,7 +41,7 @@ def setUp(self): # Declare OpSlice and OpGroup for ops that are created in the test network. self.batch_norm_op = g.get_operation_by_name( - 'conv1/BatchNorm/FusedBatchNorm') + 'conv1/BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5)) self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice) diff --git a/morph_net/framework/concat_op_handler_test.py b/morph_net/framework/concat_op_handler_test.py index 0533536..db806d3 100644 --- a/morph_net/framework/concat_op_handler_test.py +++ b/morph_net/framework/concat_op_handler_test.py @@ -61,7 +61,7 @@ def setUp(self): self.axis_op = g.get_operation_by_name('concat/axis') - self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm') + self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18)) self.batch_norm_op_group = orm.OpGroup( self.batch_norm_op_slice, @@ -808,7 +808,7 @@ def setUp(self): self.relu3_op_group = orm.OpGroup( self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice]) - self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm') + self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6)) self.batch_norm_op_group = orm.OpGroup( self.batch_norm_op_slice, diff --git a/morph_net/framework/grouping_op_handler_test.py b/morph_net/framework/grouping_op_handler_test.py index 468c857..31301bb 100644 --- a/morph_net/framework/grouping_op_handler_test.py +++ b/morph_net/framework/grouping_op_handler_test.py @@ -39,7 +39,7 @@ def setUp(self): # Declare OpSlice and OpGroup for ops of interest. self.batch_norm_op = g.get_operation_by_name( - 'conv1/BatchNorm/FusedBatchNorm') + 'conv1/BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5)) self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice) diff --git a/morph_net/framework/leaf_op_handler_test.py b/morph_net/framework/leaf_op_handler_test.py index e90aa47..43816b3 100644 --- a/morph_net/framework/leaf_op_handler_test.py +++ b/morph_net/framework/leaf_op_handler_test.py @@ -39,7 +39,7 @@ def setUp(self): # Declare OpSlice and OpGroup for ops of interest. self.batch_norm_op = g.get_operation_by_name( - 'conv1/BatchNorm/FusedBatchNorm') + 'conv1/BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5)) self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice) diff --git a/morph_net/framework/op_handler_util_test.py b/morph_net/framework/op_handler_util_test.py index eca9739..c85cee7 100644 --- a/morph_net/framework/op_handler_util_test.py +++ b/morph_net/framework/op_handler_util_test.py @@ -46,7 +46,7 @@ def setUp(self): # Declare OpSlice and OpGroup for ops in the first test network. self.batch_norm_op = g.get_operation_by_name( - 'conv1/BatchNorm/FusedBatchNorm') + 'conv1/BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None) self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice) @@ -86,7 +86,7 @@ def setUp(self): self.relu4_op_slice, omit_source_op_slices=[self.relu4_op_slice]) self.unfused_batch_norm_op = g.get_operation_by_name( - 'BatchNorm/FusedBatchNorm') + 'BatchNorm/FusedBatchNormV3') self.unfused_batch_norm_op_slice = orm.OpSlice( self.unfused_batch_norm_op, orm.Slice(0, 18)) @@ -676,7 +676,7 @@ def testOpAssumptions(self): g = tf.get_default_graph() - # Verify that FusedBatchNorm has gamma as inputs[1]. + # Verify that FusedBatchNormV3 has gamma as inputs[1]. self.assertEqual('conv1/BatchNorm/gamma/read:0', self.batch_norm_op.inputs[1].name) diff --git a/morph_net/framework/op_regularizer_manager_test.py b/morph_net/framework/op_regularizer_manager_test.py index bc8816d..1b0f2f6 100644 --- a/morph_net/framework/op_regularizer_manager_test.py +++ b/morph_net/framework/op_regularizer_manager_test.py @@ -41,13 +41,14 @@ def setUp(self): self._default_op_handler_dict = collections.defaultdict( grouping_op_handler.GroupingOpHandler) self._default_op_handler_dict.update({ - 'FusedBatchNorm': IndexBatchNormSourceOpHandler(), + 'FusedBatchNormV3': + IndexBatchNormSourceOpHandler(), 'Conv2D': - output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(), + output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(), 'ConcatV2': - concat_op_handler.ConcatOpHandler(), + concat_op_handler.ConcatOpHandler(), 'DepthwiseConv2dNative': - depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(), + depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(), }) def _batch_norm_scope(self): @@ -86,7 +87,8 @@ def testSimpleOpGetRegularizer(self, use_batch_norm, use_partitioner, scope): # Instantiate OpRegularizerManager. op_handler_dict = self._default_op_handler_dict - op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub) + op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler( + model_stub) if not use_batch_norm: op_handler_dict['Conv2D'] = StubConv2DSourceOpHandler(model_stub) op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict) @@ -112,7 +114,8 @@ def testConcatOpGetRegularizer(self, use_batch_norm, use_partitioner): # Instantiate OpRegularizerManager. op_handler_dict = self._default_op_handler_dict - op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub) + op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler( + model_stub) if not use_batch_norm: op_handler_dict['Conv2D'] = StubConv2DSourceOpHandler(model_stub) op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict) @@ -139,7 +142,8 @@ def testGroupConcatOpGetRegularizerValues(self, op_name, short_name): # Instantiate OpRegularizerManager. op_handler_dict = self._default_op_handler_dict - op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub) + op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler( + model_stub) op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict) @@ -158,7 +162,8 @@ def testGroupConcatOpGetRegularizerObjects(self): # Instantiate OpRegularizerManager. op_handler_dict = self._default_op_handler_dict - op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub) + op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler( + model_stub) op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict) self.assertEqual( @@ -1688,14 +1693,15 @@ def testDfsForSourceOps(self): # Verify source ops were found. expected_queue = collections.deque([ - _get_op('conv3/BatchNorm/FusedBatchNorm'), - _get_op('conv2/BatchNorm/FusedBatchNorm'), - _get_op('conv1/BatchNorm/FusedBatchNorm')]) + _get_op('conv3/BatchNorm/FusedBatchNormV3'), + _get_op('conv2/BatchNorm/FusedBatchNormV3'), + _get_op('conv1/BatchNorm/FusedBatchNormV3') + ]) self.assertEqual(expected_queue, manager._op_deque) # Verify extra branch was not included. self.assertNotIn( - _get_op('conv4/BatchNorm/FusedBatchNorm'), manager._op_deque) + _get_op('conv4/BatchNorm/FusedBatchNormV3'), manager._op_deque) def testOpGroup_NewSourceGroup(self): inputs = tf.zeros([2, 4, 4, 3]) diff --git a/morph_net/framework/output_non_passthrough_op_handler_test.py b/morph_net/framework/output_non_passthrough_op_handler_test.py index 579b71a..3495904 100644 --- a/morph_net/framework/output_non_passthrough_op_handler_test.py +++ b/morph_net/framework/output_non_passthrough_op_handler_test.py @@ -60,7 +60,7 @@ def setUp(self): self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice]) self.batch_norm_op = g.get_operation_by_name( - 'conv2/BatchNorm/FusedBatchNorm') + 'conv2/BatchNorm/FusedBatchNormV3') self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6)) self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice) diff --git a/morph_net/network_regularizers/activation_regularizer.py b/morph_net/network_regularizers/activation_regularizer.py index 117e241..482f708 100644 --- a/morph_net/network_regularizers/activation_regularizer.py +++ b/morph_net/network_regularizers/activation_regularizer.py @@ -61,6 +61,7 @@ def __init__( op_handler_dict.update({ 'FusedBatchNorm': source_op_handler, 'FusedBatchNormV2': source_op_handler, + 'FusedBatchNormV3': source_op_handler, }) self._manager = orm.OpRegularizerManager( diff --git a/morph_net/network_regularizers/cost_calculator.py b/morph_net/network_regularizers/cost_calculator.py index 6744f01..c63d5ad 100644 --- a/morph_net/network_regularizers/cost_calculator.py +++ b/morph_net/network_regularizers/cost_calculator.py @@ -11,8 +11,9 @@ CONV3D_OPS = ('Conv3D',) CONV_OPS = CONV2D_OPS + CONV3D_OPS FLOP_OPS = CONV_OPS + ('MatMul',) -SUPPORTED_OPS = FLOP_OPS + ( - 'Add', 'AddN', 'ConcatV2', 'FusedBatchNorm', 'Mul', 'Relu', 'Relu6', 'Sum') +SUPPORTED_OPS = FLOP_OPS + ('Add', 'AddN', 'ConcatV2', 'FusedBatchNorm', + 'FusedBatchNormV2', 'FusedBatchNormV3', 'Mul', + 'Relu', 'Relu6', 'Sum') class CostCalculator(object): diff --git a/morph_net/network_regularizers/cost_calculator_test.py b/morph_net/network_regularizers/cost_calculator_test.py index 4137647..af89360 100644 --- a/morph_net/network_regularizers/cost_calculator_test.py +++ b/morph_net/network_regularizers/cost_calculator_test.py @@ -61,12 +61,12 @@ def testImageIsNotZerothOutputOfOp(self): op_handler_dict = collections.defaultdict( grouping_op_handler.GroupingOpHandler) op_handler_dict.update({ - 'FusedBatchNorm': - batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1), + 'FusedBatchNormV3': + batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1), 'Conv2D': - output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(), + output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(), 'ConcatV2': - concat_op_handler.ConcatOpHandler(), + concat_op_handler.ConcatOpHandler(), }) # Create OpRegularizerManager and NetworkRegularizer for test. diff --git a/morph_net/network_regularizers/flop_regularizer.py b/morph_net/network_regularizers/flop_regularizer.py index 9ba4da9..e0596db 100644 --- a/morph_net/network_regularizers/flop_regularizer.py +++ b/morph_net/network_regularizers/flop_regularizer.py @@ -61,6 +61,7 @@ def __init__( op_handler_dict.update({ 'FusedBatchNorm': source_op_handler, 'FusedBatchNormV2': source_op_handler, + 'FusedBatchNormV3': source_op_handler, }) self._manager = orm.OpRegularizerManager( diff --git a/morph_net/network_regularizers/flop_regularizer_test.py b/morph_net/network_regularizers/flop_regularizer_test.py index 647ac74..f4fde9e 100644 --- a/morph_net/network_regularizers/flop_regularizer_test.py +++ b/morph_net/network_regularizers/flop_regularizer_test.py @@ -129,8 +129,8 @@ def testInputBoundaryNone(self): self.BuildWithBatchNorm(fused=True) self.AddRegularizer(input_boundary=None) self.assertCountEqual(self.GetSourceOps(), [ - 'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm', - 'conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm' + 'conv1/BatchNorm/FusedBatchNormV3', 'conv2/BatchNorm/FusedBatchNormV3', + 'conv3/BatchNorm/FusedBatchNormV3', 'conv4/BatchNorm/FusedBatchNormV3' ]) def testInputBoundaryConv3(self): @@ -138,8 +138,8 @@ def testInputBoundaryConv3(self): self.BuildWithBatchNorm(fused=True) self.AddRegularizer(input_boundary=[self.conv3.op]) self.assertCountEqual(self.GetSourceOps(), [ - 'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm', - 'conv4/BatchNorm/FusedBatchNorm' + 'conv1/BatchNorm/FusedBatchNormV3', 'conv2/BatchNorm/FusedBatchNormV3', + 'conv4/BatchNorm/FusedBatchNormV3' ]) def testInputBoundaryConv3And4(self): @@ -152,9 +152,9 @@ def testInputBoundaryConcat(self): # Block concat, can only see conv3 and conv4. self.BuildWithBatchNorm(fused=True) self.AddRegularizer(input_boundary=[self.concat.op]) - self.assertCountEqual( - self.GetSourceOps(), - ['conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm']) + self.assertCountEqual(self.GetSourceOps(), [ + 'conv3/BatchNorm/FusedBatchNormV3', 'conv4/BatchNorm/FusedBatchNormV3' + ]) def testLossDecorated(self): self.BuildWithBatchNorm(True) diff --git a/morph_net/network_regularizers/latency_regularizer.py b/morph_net/network_regularizers/latency_regularizer.py index 928045c..8b09f06 100644 --- a/morph_net/network_regularizers/latency_regularizer.py +++ b/morph_net/network_regularizers/latency_regularizer.py @@ -69,6 +69,7 @@ def __init__( op_handler_dict.update({ 'FusedBatchNorm': source_op_handler, 'FusedBatchNormV2': source_op_handler, + 'FusedBatchNormV3': source_op_handler, }) self._manager = orm.OpRegularizerManager( diff --git a/morph_net/network_regularizers/model_size_regularizer.py b/morph_net/network_regularizers/model_size_regularizer.py index 15f233b..073eac7 100644 --- a/morph_net/network_regularizers/model_size_regularizer.py +++ b/morph_net/network_regularizers/model_size_regularizer.py @@ -63,6 +63,7 @@ def __init__( op_handler_dict.update({ 'FusedBatchNorm': source_op_handler, 'FusedBatchNormV2': source_op_handler, + 'FusedBatchNormV3': source_op_handler, }) self._manager = orm.OpRegularizerManager(