Module furiosa.quantizer.frontend.onnx.transformer.fuse_bn_into_conv
Expand source code
import abc
import warnings
import onnx
import numpy as np
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer
# Default ONNX Conv attribute values used when Pattern_1 rebuilds a fused
# Conv node. NOTE(review): values correspond to a plain 1x1 convolution;
# confirm against the matched Conv's own attributes before relying on this
# transformer for non-default convolutions.
default_conv_attrs = {
'dilations': [1, 1],
'group': 1,
'kernel_shape': [1, 1],
'pads': [0, 0, 0, 0],
'strides': [1, 1]
}
class FuseBnIntoConv(Transformer):
    """Fold BatchNormalization into the graph.

    Runs two passes in order: Pattern_1 fuses Conv --> BatchNormalization
    pairs into a single Conv, then Pattern_2 lowers any remaining
    standalone BatchNormalization into a Mul followed by an Add.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Apply both fusion patterns to *model* and return the result."""
        patterns = (Pattern_1, Pattern_2)
        for pattern in patterns:
            model = pattern(model).transform()
        return model
class Pattern_1(ONNXTransformer, abc.ABC):
    """
    transform
        prev --> Conv --> BatchNormalization --> next
    to
        prev --> Conv --> next
    """
    pattern_to_match = ['Conv', 'BatchNormalization']

    def pattern_matching(self, base_node):
        """Match a Conv --> BatchNormalization chain ending at base_node and fuse it.

        Returns the matched top node's inputs on success so traversal can
        continue upstream; otherwise returns base_node.input unchanged.
        """
        inputs = base_node.input
        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs
        if not self.pattern_condition_checker(matched_nodes):
            return inputs
        top_node = matched_nodes[0]
        # Compute once; the original called make_new_vi twice (guard + use).
        new_vis = self.make_new_vi(matched_nodes)
        self.transform_to_fuse(matched_nodes,
                               nodes_to_add=[*self.make_new_node(matched_nodes)],
                               inits_to_add=[*self.make_new_init(matched_nodes)],
                               vis_to_add=[*new_vis] if new_vis else None
                               )
        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        # Every Conv --> BatchNormalization pair is fusable.
        return True

    def make_new_node(self, matched_nodes):
        """Build the fused Conv node (as a one-element list).

        Initializer inputs are renamed with the '_bn_fused' suffix so they
        point at the initializers created by make_new_init.
        """
        top_node, base_node = matched_nodes
        input_names = [node_input if node_input not in self.initializer_map else node_input + '_bn_fused'
                       for node_input in top_node.input]
        from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs
        # Preserve the matched Conv's own attributes (strides, pads,
        # kernel_shape, ...), falling back to the defaults only for
        # attributes it did not set. Emitting default_conv_attrs alone
        # would silently turn e.g. a strided 3x3 Conv into a 1x1 Conv.
        attrs = {**default_conv_attrs, **attribute_to_kwargs(top_node.attribute)}
        # Return a list so pattern_matching's [*...] splat works uniformly
        # (a bare protobuf NodeProto is not iterable).
        return [self.make_node('Conv', [*input_names], [base_node.output[0]], top_node.name,
                               **attrs)]

    def make_new_init(self, matched_nodes):
        """Create '_bn_fused' initializers for the Conv's weight/bias inputs."""
        top_node, base_node = matched_nodes
        bn_params = self.get_bn_params(base_node)
        multiplier, shifter = self.get_multiplier_and_shifter(*bn_params)
        inits_to_add = []
        for node_input in top_node.input:
            # Skip dynamic inputs; only stored initializers are folded.
            if node_input not in self.initializer_map:
                continue
            weight = self.get_initializer_array(node_input)
            fused_weight = self.fuse_bn_params(weight, multiplier, shifter)
            inits_to_add.append(self.make_initializer_from_array(fused_weight, node_input + '_bn_fused'))
        return inits_to_add

    def make_new_vi(self, matched_nodes):
        # Fusing Conv+BN introduces no intermediate tensors.
        return None

    def get_bn_params(self, node):
        """Read (scale, B, mean, var, eps) from a BatchNormalization node."""
        scale = self.get_initializer_array(node.input[1])
        if all(v == 0. for v in scale):
            warnings.warn(f'BatchNormalization.scale is a zero tensor: {node.input[1]}')
        B = self.get_initializer_array(node.input[2])
        mean = self.get_initializer_array(node.input[3])
        var = self.get_initializer_array(node.input[4])
        from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs
        attrs = attribute_to_kwargs(node.attribute)
        # 1e-05 is the ONNX default for BatchNormalization.epsilon.
        eps = attrs.get('epsilon', 1e-05)
        return scale, B, mean, var, eps

    @staticmethod
    def get_multiplier_and_shifter(scale, B, mean, var, eps):
        """Fold BN into an affine form: BN(x) = multiplier * x + shifter.

        BN(x) = scale * (x - mean) / sqrt(var + eps) + B
              = multiplier * x + (B - multiplier * mean)
        """
        multiplier = scale / np.sqrt(var + eps)
        # The mean term must be scaled by the full multiplier; the previous
        # code used only `scale` here, producing a wrong shifter whenever
        # var + eps != 1.
        shifter = B - mean * multiplier
        return multiplier, shifter

    @staticmethod
    def fuse_bn_params(weight, multiplier, shifter):
        """Apply the folded BN affine transform to a Conv weight or bias."""
        if weight.ndim == 4:
            # Conv weight (O, I, kH, kW): scale each output channel.
            fused_weight = weight * multiplier.reshape(-1, 1, 1, 1)
            return fused_weight
        elif weight.ndim == 1:
            # Conv bias (O,): scale and shift per channel.
            fused_bias = weight * multiplier + shifter
            return fused_bias
        else:
            # numpy arrays expose `.ndim`; the original referenced the
            # nonexistent `.dim`, raising AttributeError instead.
            raise Exception('Unknown weight ndim: %s' % weight.ndim)
class Pattern_2(Pattern_1, abc.ABC):
    """
    transform
        prev --> BatchNormalization --> next
    to
        prev --> Mul --> Add --> next
    if prev.op_type != Conv
    """
    pattern_to_match = ['BatchNormalization']

    def pattern_condition_checker(self, nodes_to_check):
        # NOTE(review): nodes_to_check[0] is the matched BatchNormalization
        # node itself, so its op_type can never be 'Conv' and this check is
        # vacuous; presumably the producer of node.input[0] was meant to be
        # inspected — confirm upstream. (Pattern_1 having run first is what
        # keeps this from misfiring in practice.)
        node = nodes_to_check[0]
        return not self.is_op_type(node.op_type, ['Conv'])

    def make_new_node(self, matched_nodes):
        """Replace a standalone BatchNormalization with Mul then Add."""
        bn_node = matched_nodes[0]
        multiplier_name = bn_node.input[0] + '_bn_multiplier'
        shifter_name = bn_node.input[0] + '_bn_shifter'
        mul_output = bn_node.output[0] + '_bn_multiplied'
        mul_node = self.make_node('Mul', [bn_node.input[0], multiplier_name],
                                  [mul_output], bn_node.name)
        add_node = self.make_node('Add', [mul_output, shifter_name],
                                  [bn_node.output[0]], bn_node.name)
        return [mul_node, add_node]

    def make_new_init(self, matched_nodes):
        """Create multiplier/shifter initializers consumed by Mul/Add."""
        bn_node = matched_nodes[0]
        multiplier, shifter = self.get_multiplier_and_shifter(*self.get_bn_params(bn_node))
        # NOTE(review): shape[0] of the BN output is the batch dimension for
        # NCHW tensors; the reshape below broadcasts correctly only when
        # that leading dimension is 1 (or happens to equal the channel
        # count) — confirm against the models this pass targets.
        num_features = self.get_value_info_shape(bn_node.output[0])[0]
        broadcast_shape = (num_features, -1, 1, 1)
        return [
            self.make_initializer_from_array(multiplier.reshape(*broadcast_shape),
                                             name=bn_node.input[0] + '_bn_multiplier'),
            self.make_initializer_from_array(shifter.reshape(*broadcast_shape),
                                             name=bn_node.input[0] + '_bn_shifter'),
        ]

    def make_new_vi(self, matched_nodes):
        """Declare value info for the intermediate Mul output tensor."""
        bn_node = matched_nodes[0]
        return [
            self.make_tensor_value_info(bn_node.output[0] + '_bn_multiplied',
                                        onnx.TensorProto.FLOAT,
                                        shape=self.get_value_info_shape(bn_node.output[0])),
        ]
Classes
class FuseBnIntoConv (*args, **kwds)
-
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code
class FuseBnIntoConv(Transformer): def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, Pattern_2, ]: model = transformer(model).transform() return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, Pattern_2, ]: model = transformer(model).transform() return model
class Pattern_1 (model)
-
transform prev --> Conv --> BatchNormalization --> next to prev --> Conv --> next
Expand source code
class Pattern_1(ONNXTransformer, abc.ABC): """ transform prev --> Conv --> BatchNormalization --> next to prev --> Conv --> next """ pattern_to_match = ['Conv', 'BatchNormalization'] def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[*self.make_new_node(matched_nodes)], inits_to_add=[*self.make_new_init(matched_nodes)], vis_to_add=[*self.make_new_vi(matched_nodes)] if self.make_new_vi( matched_nodes) else None ) return top_node.input def pattern_condition_checker(self, nodes_to_check): return True def make_new_node(self, matched_nodes): top_node, base_node = matched_nodes input_names = [node_input if node_input not in self.initializer_map else node_input + '_bn_fused' for node_input in top_node.input] return self.make_node('Conv', [*input_names], [base_node.output[0]], top_node.name, **default_conv_attrs) def make_new_init(self, matched_nodes): top_node, base_node = matched_nodes bn_params = self.get_bn_params(base_node) multiplier, shifter = self.get_multiplier_and_shifter(*bn_params) inits_to_add = [] for node_input in top_node.input: if node_input not in self.initializer_map: continue weight = self.get_initializer_array(node_input) fused_weight = self.fuse_bn_params(weight, multiplier, shifter) inits_to_add.append(self.make_initializer_from_array(fused_weight, node_input + '_bn_fused')) return inits_to_add def make_new_vi(self, matched_nodes): return None def get_bn_params(self, node): scale = self.get_initializer_array(node.input[1]) if all(v == 0. 
for v in scale): warnings.warn(f'BatchNormalization.scale is a zero tensor: {node.input[1]}') B = self.get_initializer_array(node.input[2]) mean = self.get_initializer_array(node.input[3]) var = self.get_initializer_array(node.input[4]) from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs attrs = attribute_to_kwargs(node.attribute) eps = attrs.get('epsilon', 1e-05) return scale, B, mean, var, eps @staticmethod def get_multiplier_and_shifter(scale, B, mean, var, eps): multiplier = scale * 1 / np.sqrt(var + eps) shifter = - mean * scale + B return multiplier, shifter @staticmethod def fuse_bn_params(weight, multiplier, shifter): if weight.ndim == 4: fused_weight = weight * multiplier.reshape(-1, 1, 1, 1) return fused_weight elif weight.ndim == 1: fused_bias = weight * multiplier + shifter return fused_bias else: raise Exception('Unknown weight ndim: %s' % weight.dim)
Ancestors
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Subclasses
Class variables
var pattern_to_match
Static methods
def fuse_bn_params(weight, multiplier, shifter)
-
Expand source code
@staticmethod def fuse_bn_params(weight, multiplier, shifter): if weight.ndim == 4: fused_weight = weight * multiplier.reshape(-1, 1, 1, 1) return fused_weight elif weight.ndim == 1: fused_bias = weight * multiplier + shifter return fused_bias else: raise Exception('Unknown weight ndim: %s' % weight.dim)
def get_multiplier_and_shifter(scale, B, mean, var, eps)
-
Expand source code
@staticmethod def get_multiplier_and_shifter(scale, B, mean, var, eps): multiplier = scale * 1 / np.sqrt(var + eps) shifter = - mean * scale + B return multiplier, shifter
Methods
def get_bn_params(self, node)
-
Expand source code
def get_bn_params(self, node): scale = self.get_initializer_array(node.input[1]) if all(v == 0. for v in scale): warnings.warn(f'BatchNormalization.scale is a zero tensor: {node.input[1]}') B = self.get_initializer_array(node.input[2]) mean = self.get_initializer_array(node.input[3]) var = self.get_initializer_array(node.input[4]) from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs attrs = attribute_to_kwargs(node.attribute) eps = attrs.get('epsilon', 1e-05) return scale, B, mean, var, eps
def make_new_init(self, matched_nodes)
-
Expand source code
def make_new_init(self, matched_nodes): top_node, base_node = matched_nodes bn_params = self.get_bn_params(base_node) multiplier, shifter = self.get_multiplier_and_shifter(*bn_params) inits_to_add = [] for node_input in top_node.input: if node_input not in self.initializer_map: continue weight = self.get_initializer_array(node_input) fused_weight = self.fuse_bn_params(weight, multiplier, shifter) inits_to_add.append(self.make_initializer_from_array(fused_weight, node_input + '_bn_fused')) return inits_to_add
def make_new_node(self, matched_nodes)
-
Expand source code
def make_new_node(self, matched_nodes): top_node, base_node = matched_nodes input_names = [node_input if node_input not in self.initializer_map else node_input + '_bn_fused' for node_input in top_node.input] return self.make_node('Conv', [*input_names], [base_node.output[0]], top_node.name, **default_conv_attrs)
def make_new_vi(self, matched_nodes)
-
Expand source code
def make_new_vi(self, matched_nodes): return None
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): return True
def pattern_matching(self, base_node)
-
Expand source code
def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[*self.make_new_node(matched_nodes)], inits_to_add=[*self.make_new_init(matched_nodes)], vis_to_add=[*self.make_new_vi(matched_nodes)] if self.make_new_vi( matched_nodes) else None ) return top_node.input
class Pattern_2 (model)
-
transform prev --> BatchNormalization --> next to prev --> Mul --> Add --> next
if prev.op_type != Conv
Expand source code
class Pattern_2(Pattern_1, abc.ABC): """ transform prev --> BatchNormalization --> next to prev --> Mul --> Add --> next if prev.op_type != Conv """ pattern_to_match = ['BatchNormalization'] def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if self.is_op_type(node.op_type, ['Conv']): return False return True def make_new_node(self, matched_nodes): node = matched_nodes[0] return [ self.make_node('Mul', [node.input[0], node.input[0] + '_bn_multiplier'], [node.output[0] + '_bn_multiplied'], node.name), self.make_node('Add', [node.output[0] + '_bn_multiplied', node.input[0] + '_bn_shifter'], [node.output[0]], node.name) ] def make_new_init(self, matched_nodes): node = matched_nodes[0] bn_params = self.get_bn_params(node) multiplier, shifter = self.get_multiplier_and_shifter(*bn_params) num_features = self.get_value_info_shape(node.output[0])[0] return [ self.make_initializer_from_array(multiplier.reshape(num_features, -1, 1, 1), name=node.input[0] + '_bn_multiplier'), self.make_initializer_from_array(shifter.reshape(num_features, -1, 1, 1), name=node.input[0] + '_bn_shifter') ] def make_new_vi(self, matched_nodes): node = matched_nodes[0] return [self.make_tensor_value_info(node.output[0] + '_bn_multiplied', onnx.TensorProto.FLOAT, shape=self.get_value_info_shape(node.output[0]))]
Ancestors
- Pattern_1
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Class variables
var pattern_to_match
Methods
def make_new_init(self, matched_nodes)
-
Expand source code
def make_new_init(self, matched_nodes): node = matched_nodes[0] bn_params = self.get_bn_params(node) multiplier, shifter = self.get_multiplier_and_shifter(*bn_params) num_features = self.get_value_info_shape(node.output[0])[0] return [ self.make_initializer_from_array(multiplier.reshape(num_features, -1, 1, 1), name=node.input[0] + '_bn_multiplier'), self.make_initializer_from_array(shifter.reshape(num_features, -1, 1, 1), name=node.input[0] + '_bn_shifter') ]
def make_new_node(self, matched_nodes)
-
Expand source code
def make_new_node(self, matched_nodes): node = matched_nodes[0] return [ self.make_node('Mul', [node.input[0], node.input[0] + '_bn_multiplier'], [node.output[0] + '_bn_multiplied'], node.name), self.make_node('Add', [node.output[0] + '_bn_multiplied', node.input[0] + '_bn_shifter'], [node.output[0]], node.name) ]
def make_new_vi(self, matched_nodes)
-
Expand source code
def make_new_vi(self, matched_nodes): node = matched_nodes[0] return [self.make_tensor_value_info(node.output[0] + '_bn_multiplied', onnx.TensorProto.FLOAT, shape=self.get_value_info_shape(node.output[0]))]
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if self.is_op_type(node.op_type, ['Conv']): return False return True