Skip to content

Commit 79513d2

Browse files
authored
Fix the regression that made yolo-v3 accuracy drop. (#514)
* Fix the regression that made yolo-v3 accuracy drop. * Improve the accuracy.
1 parent a79576e commit 79513d2

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/tensorflow/object_detection/yolo_v3/yolo_v3.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ model: # mandatory. neural_compres
66

77
quantization: # optional. tuning constraints on model-wise for advance user to reduce tuning space.
88
calibration:
9-
sampling_size: 10 # optional. default value is the size of whole dataset. used to set how many portions of calibration dataset is used. exclusive with iterations field.
9+
sampling_size: 2 # optional. default value is the size of whole dataset. used to set how many portions of calibration dataset is used. exclusive with iterations field.
1010
dataloader: # optional. if not specified, user need construct a q_dataloader in code for neural_compressor.Quantization.
1111
batch_size: 1
1212
dataset:

neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,11 @@ def apply_conv_biasadd_addn_relu_fusion(self, match_node_name):
264264
matched_node.node.name)
265265
weight_name = normal_inputs[1]
266266

267-
if not self._find_relu_node(matched_node.node):
268-
return self.apply_conv_biasadd_fusion(match_node_name[:2])
269-
270267
third_node = self.node_name_mapping[match_node_name[2]].node
271268
forth_node = self.node_name_mapping[match_node_name[3]].node
269+
if third_node.op != 'LeakyRelu' and not self._find_relu_node(matched_node.node):
270+
return self.apply_conv_biasadd_fusion(match_node_name[:2])
271+
272272
is_leakyrelu_add_fusion = third_node.op == 'LeakyRelu' and forth_node.op.find('Add') != -1
273273

274274
q_weights_name, q_weights_min_name, q_weights_max_name = \

0 commit comments

Comments
 (0)