Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.transformer.fuse_conv
Expand source code
import abc
import onnx
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer
class FuseConv(Transformer):
    """Run the ``Pattern_1`` and ``Pattern_2`` rewrite passes over a model, in that order."""

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Apply every pass in sequence and return the final model.

        Each pass class is instantiated with the model produced by the
        previous pass, and the result of its ``transform()`` call is fed
        to the next one.
        """
        passes = (Pattern_1, Pattern_2)

        result = model
        for pass_cls in passes:
            result = pass_cls(result).transform()
        return result
class Pattern_1(ONNXTransformer, abc.ABC):
    """
    transform
        prev --> MatMul --> Add --> next
    to
        prev --> Unsqueeze --> Conv --> Squeeze --> next

    if 1. MatMul.ndim == 2
       2. MatMul must have exactly one initializer (its right-hand operand)
       3. Add must have exactly one initializer
    """
    pattern_to_match = ['MatMul', 'Add']

    def pattern_matching(self, base_node):
        """Try to fuse the pattern that ends at ``base_node``.

        Returns ``base_node.input`` unchanged when the pattern does not match
        or a condition fails; otherwise registers the replacement nodes,
        initializers and value infos through ``transform_to_fuse`` and returns
        the inputs of the first matched node.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node = matched_nodes[0]

        self.transform_to_fuse(
            matched_nodes,
            nodes_to_add=[
                *self.make_nodes(**self.get_new_node_args(matched_nodes))
            ],
            inits_to_add=[
                *self.make_initializers(**self.get_new_init_args(matched_nodes))
            ],
            vis_to_add=[
                *self.make_value_infos(**self.get_new_vi_args(matched_nodes))
            ])

        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the matched (MatMul, Add) pair can be fused."""
        top_node, base_node = nodes_to_check

        if not self.check_condition_1(top_node.output[0]):
            return False

        if not self.check_condition_2(top_node):
            return False

        # The 1x1 Conv built below computes ``X @ W``, so the constant weight
        # has to be the right-hand operand of the MatMul.
        if top_node.input[1] not in self.initializer_map:
            return False

        if not self.check_condition_2(base_node):
            return False

        # Without this explicit result the method returned None, which the
        # caller treats as "condition failed", so the fusion never happened.
        return True

    def check_condition_1(self, tensor_name):
        """Return True when the recorded shape of ``tensor_name`` has two dimensions."""
        return len(self.get_value_info_shape(tensor_name)) == 2

    def check_condition_2(self, node):
        """Return True when exactly one input of ``node`` is an initializer."""
        num_init = sum(
            1 for node_input in node.input if node_input in self.initializer_map
        )
        return num_init == 1

    def get_new_vi_args(self, matched_nodes):
        """Collect the data input of the first node and the output of the last node."""
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]

        fnode_input = self.get_data_node_input(top_node)
        fnode_output = base_node.output[0]

        return {'node_input': fnode_input, 'node_output': fnode_output}

    def get_new_init_args(self, matched_nodes):
        """Collect the weight (first node) and bias (last node) initializer names."""
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]

        fw_input = self.get_init_node_input(top_node)
        fb_input = self.get_init_node_input(base_node)

        return {'w_input': fw_input, 'b_input': fb_input}

    def get_new_node_args(self, matched_nodes):
        """Merge the value-info and initializer arguments for ``make_nodes``."""
        args = dict()
        args.update(self.get_new_vi_args(matched_nodes))
        args.update(self.get_new_init_args(matched_nodes))
        return args

    def make_nodes(self, node_input, node_output, w_input, b_input, **kwargs):
        """Build the Unsqueeze -> Conv -> Squeeze replacement nodes.

        ``b_input`` may be None (or empty), in which case the Conv node is
        created without a bias input. Extra keyword arguments are accepted
        and ignored.
        """
        unsqueeze_node = self.make_node('Unsqueeze', inputs=[node_input],
                                        outputs=[node_input + '_unsqueezed'],
                                        name=node_input + '_1',
                                        **{'axes': [2, 3]})

        conv_inputs = [unsqueeze_node.output[0], w_input + '_fused']
        # Same guard as make_initializers: no bias initializer is created when
        # b_input is missing, and None + '_fused' would raise TypeError.
        if b_input:
            conv_inputs.append(b_input + '_fused')

        conv_node = self.make_node('Conv',
                                   inputs=conv_inputs,
                                   outputs=[node_input + '_fused'],
                                   name=node_input + '_2',
                                   **{
                                       'dilations': [1, 1],
                                       'group': 1,
                                       'kernel_shape': [1, 1],
                                       'pads': [0, 0, 0, 0],
                                       'strides': [1, 1]
                                   })

        squeeze_node = self.make_node('Squeeze', inputs=[conv_node.output[0]],
                                      outputs=[node_output],
                                      name=node_input + '_3',
                                      **{'axes': [2, 3]})

        return unsqueeze_node, conv_node, squeeze_node

    def make_initializers(self, w_input, b_input=None, **kwargs):
        """Create the ``*_fused`` initializers used by the Conv node.

        The weight array goes through ``weight_transformation`` (which
        receives ``kwargs``); the bias array, when ``b_input`` is given, is
        copied as-is under the new name.
        """
        new_inits = []

        w_arr = self.get_initializer_array(w_input)
        new_w_arr = self.weight_transformation(w_arr, **kwargs)
        new_w_init = self.make_initializer_from_array(new_w_arr, w_input + '_fused')
        new_inits.append(new_w_init)

        if b_input:
            # NOTE(review): the bias is forwarded unchanged; presumably it is
            # already 1-D with one entry per output channel -- confirm.
            b_arr = self.get_initializer_array(b_input)
            new_b_init = self.make_initializer_from_array(b_arr, b_input + '_fused')
            new_inits.append(new_b_init)

        return new_inits

    def weight_transformation(self, w_arr, **kwargs):
        """Turn a 2-D ``(c, n)`` weight into an ``(n, c, 1, 1)`` Conv kernel."""
        c, n = w_arr.shape
        new_w_arr = w_arr.transpose().reshape(n, c, 1, 1)
        return new_w_arr

    def make_value_infos(self, node_input, node_output):
        """Create FLOAT value infos for the Conv input and output tensors.

        Both shapes are the recorded shapes of ``node_input`` / ``node_output``
        with two trailing dimensions of size 1 appended.
        """
        conv_input_vi = self.make_tensor_value_info(
            node_input + '_unsqueezed',
            onnx.TensorProto.FLOAT,
            self.get_value_info_shape(node_input) + [1, 1])

        conv_output_vi = self.make_tensor_value_info(
            node_input + '_fused',
            onnx.TensorProto.FLOAT,
            self.get_value_info_shape(node_output) + [1, 1])

        return conv_input_vi, conv_output_vi
class Pattern_2(Pattern_1, abc.ABC):
    """
    transform
        prev --> Gemm --> next
    to
        prev --> Unsqueeze --> Conv --> Squeeze --> next

    if 1. exactly one of Gemm.A and Gemm.B is an initializer, and it is Gemm.B
       2. Gemm.C must have initializer if defined
    """
    pattern_to_match = ['Gemm']

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the matched Gemm node can be fused."""
        node = nodes_to_check[0]

        if not self.check_condition_3(node):
            return False

        # get_new_init_args always reads the weight from input[1], so a Gemm
        # whose only constant operand is A cannot be rewritten by this pass.
        if node.input[1] not in self.initializer_map:
            return False

        if not self.check_condition_4(node):
            return False

        return True

    def check_condition_3(self, node):
        """Return True when exactly one of the first two inputs is an initializer."""
        num_init = sum(
            1 for node_input in list(node.input)[:2]
            if node_input in self.initializer_map
        )
        return num_init == 1

    def check_condition_4(self, node):
        """Return False when a third input exists and is not an initializer."""
        if len(node.input) == 3:
            if node.input[2] not in self.initializer_map:
                return False
        return True

    def get_new_init_args(self, matched_nodes):
        """Collect weight/bias initializer names plus the Gemm ``transB`` flag.

        ``b_input`` is None when the node has no third input.
        """
        node = matched_nodes[0]

        fw_input = node.input[1]
        fb_input = None
        if len(node.input) == 3:
            fb_input = node.input[2]

        args = {'w_input': fw_input, 'b_input': fb_input}
        args.update(self.get_attrs(node))
        return args

    def get_new_vi_args(self, matched_nodes):
        """Return the first input and first output names of the Gemm node."""
        node = matched_nodes[0]

        fnode_input = node.input[0]
        fnode_output = node.output[0]

        return {'node_input': fnode_input, 'node_output': fnode_output}

    def weight_transformation(self, w_arr, **kwargs):
        """Reshape the 2-D Gemm weight to a 4-D kernel with two trailing 1s.

        When ``transB`` is 0 the array is transposed first; otherwise it is
        reshaped as stored.
        """
        transB = kwargs['transB']
        if transB == 0:
            w_arr = w_arr.transpose()

        n, c = w_arr.shape
        new_arr = w_arr.reshape(n, c, 1, 1)
        return new_arr

    def get_attrs(self, node):
        """Read the Gemm attributes needed for the rewrite.

        ``alpha``, ``beta`` and ``transB`` are optional on a Gemm node; when an
        exporter omits them the ONNX defaults (1.0, 1.0, 0) apply, so missing
        keys fall back to those values instead of raising KeyError.

        Raises AssertionError when alpha or beta differs from 1.0.
        """
        from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs

        # presumably a dict holding only the attributes present on the node
        attrs = attribute_to_kwargs(node.attribute)

        alpha = attrs.get('alpha', 1.0)
        beta = attrs.get('beta', 1.0)
        assert alpha == beta == 1.0, "Assume alpha = beta = 1.0"

        # NOTE(review): transA is not inspected here; a Gemm with transA != 0
        # would presumably be rewritten incorrectly -- confirm and guard.
        transB = attrs.get('transB', 0)

        return {'transB': transB}
Classes
class FuseConv (*args, **kwds)
-
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code
class FuseConv(Transformer): def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, Pattern_2, ]: model = transformer(model).transform() return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, Pattern_2, ]: model = transformer(model).transform() return model
class Pattern_1 (model)
-
transform prev --> MatMul --> Add --> next to prev --> Unsqueeze --> Conv --> Squeeze --> next
if 1. MatMul.ndim == 2 2. MatMul must have exactly one initializer 3. Add must have exactly one initializer
Expand source code
class Pattern_1(ONNXTransformer, abc.ABC): """ transform prev --> MatMul --> Add --> next to prev --> Unsqueeze --> Conv --> Squeeze --> next if 1. MatMul.ndim == 2 2. MatMul must have at most one initializer 3. Add must have at most one initializer """ pattern_to_match = ['MatMul', 'Add'] def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[ *self.make_nodes(**self.get_new_node_args(matched_nodes)) ], inits_to_add=[ *self.make_initializers(**self.get_new_init_args(matched_nodes)) ], vis_to_add=[ *self.make_value_infos(**self.get_new_vi_args(matched_nodes)) ]) return top_node.input def pattern_condition_checker(self, nodes_to_check): top_node, base_node = nodes_to_check if not self.check_condition_1(top_node.output[0]): return False if not self.check_condition_2(top_node): return False if not self.check_condition_2(base_node): return False def check_condition_1(self, tensor_name): if len(self.get_value_info_shape(tensor_name)) == 2: return True return False def check_condition_2(self, node): num_init = 0 for node_input in node.input: if node_input in self.initializer_map: num_init += 1 if num_init == 1: return True return False def get_new_vi_args(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] fnode_input = self.get_data_node_input(top_node) fnode_output = base_node.output[0] return {'node_input': fnode_input, 'node_output': fnode_output} def get_new_init_args(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] fw_input = self.get_init_node_input(top_node) fb_input = self.get_init_node_input(base_node) return {'w_input': fw_input, 'b_input': fb_input} def get_new_node_args(self, matched_nodes): args = dict() 
args.update(self.get_new_vi_args(matched_nodes)) args.update(self.get_new_init_args(matched_nodes)) return args def make_nodes(self, node_input, node_output, w_input, b_input, **kwargs): unsqueeze_node = self.make_node('Unsqueeze', inputs=[node_input], outputs=[node_input + '_unsqueezed'], name=node_input + '_1', **{'axes': [2, 3]}) conv_node = self.make_node('Conv', inputs=[unsqueeze_node.output[0], w_input + '_fused', b_input + '_fused'], outputs=[node_input + '_fused'], name=node_input + '_2', **{ 'dilations': [1, 1], 'group': 1, 'kernel_shape': [1, 1], 'pads': [0, 0, 0, 0], 'strides': [1, 1] }) squeeze_node = self.make_node('Squeeze', inputs=[conv_node.output[0]], outputs=[node_output], name=node_input + '_3', **{'axes': [2, 3]}) return unsqueeze_node, conv_node, squeeze_node def make_initializers(self, w_input, b_input=None, **kwargs): new_inits = [] w_arr = self.get_initializer_array(w_input) new_w_arr = self.weight_transformation(w_arr, **kwargs) new_w_init = self.make_initializer_from_array(new_w_arr, w_input + '_fused') new_inits.append(new_w_init) if b_input: b_arr = self.get_initializer_array(b_input) new_b_init = self.make_initializer_from_array(b_arr, b_input + '_fused') new_inits.append(new_b_init) return new_inits def weight_transformation(self, w_arr, **kwargs): c, n = w_arr.shape new_w_arr = w_arr.transpose().reshape(n, c, 1, 1) return new_w_arr def make_value_infos(self, node_input, node_output): conv_input_vi = self.make_tensor_value_info(node_input + '_unsqueezed', onnx.TensorProto.FLOAT, self.get_value_info_shape(node_input) + [1, 1]) conv_output_vi = self.make_tensor_value_info(node_input + '_fused', onnx.TensorProto.FLOAT, self.get_value_info_shape(node_output) + [1, 1]) return conv_input_vi, conv_output_vi
Ancestors
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Subclasses
Class variables
var pattern_to_match
Methods
def check_condition_1(self, tensor_name)
-
Expand source code
def check_condition_1(self, tensor_name): if len(self.get_value_info_shape(tensor_name)) == 2: return True return False
def check_condition_2(self, node)
-
Expand source code
def check_condition_2(self, node): num_init = 0 for node_input in node.input: if node_input in self.initializer_map: num_init += 1 if num_init == 1: return True return False
def get_new_init_args(self, matched_nodes)
-
Expand source code
def get_new_init_args(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] fw_input = self.get_init_node_input(top_node) fb_input = self.get_init_node_input(base_node) return {'w_input': fw_input, 'b_input': fb_input}
def get_new_node_args(self, matched_nodes)
-
Expand source code
def get_new_node_args(self, matched_nodes): args = dict() args.update(self.get_new_vi_args(matched_nodes)) args.update(self.get_new_init_args(matched_nodes)) return args
def get_new_vi_args(self, matched_nodes)
-
Expand source code
def get_new_vi_args(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] fnode_input = self.get_data_node_input(top_node) fnode_output = base_node.output[0] return {'node_input': fnode_input, 'node_output': fnode_output}
def make_initializers(self, w_input, b_input=None, **kwargs)
-
Expand source code
def make_initializers(self, w_input, b_input=None, **kwargs): new_inits = [] w_arr = self.get_initializer_array(w_input) new_w_arr = self.weight_transformation(w_arr, **kwargs) new_w_init = self.make_initializer_from_array(new_w_arr, w_input + '_fused') new_inits.append(new_w_init) if b_input: b_arr = self.get_initializer_array(b_input) new_b_init = self.make_initializer_from_array(b_arr, b_input + '_fused') new_inits.append(new_b_init) return new_inits
def make_nodes(self, node_input, node_output, w_input, b_input, **kwargs)
-
Expand source code
def make_nodes(self, node_input, node_output, w_input, b_input, **kwargs): unsqueeze_node = self.make_node('Unsqueeze', inputs=[node_input], outputs=[node_input + '_unsqueezed'], name=node_input + '_1', **{'axes': [2, 3]}) conv_node = self.make_node('Conv', inputs=[unsqueeze_node.output[0], w_input + '_fused', b_input + '_fused'], outputs=[node_input + '_fused'], name=node_input + '_2', **{ 'dilations': [1, 1], 'group': 1, 'kernel_shape': [1, 1], 'pads': [0, 0, 0, 0], 'strides': [1, 1] }) squeeze_node = self.make_node('Squeeze', inputs=[conv_node.output[0]], outputs=[node_output], name=node_input + '_3', **{'axes': [2, 3]}) return unsqueeze_node, conv_node, squeeze_node
def make_value_infos(self, node_input, node_output)
-
Expand source code
def make_value_infos(self, node_input, node_output): conv_input_vi = self.make_tensor_value_info(node_input + '_unsqueezed', onnx.TensorProto.FLOAT, self.get_value_info_shape(node_input) + [1, 1]) conv_output_vi = self.make_tensor_value_info(node_input + '_fused', onnx.TensorProto.FLOAT, self.get_value_info_shape(node_output) + [1, 1]) return conv_input_vi, conv_output_vi
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): top_node, base_node = nodes_to_check if not self.check_condition_1(top_node.output[0]): return False if not self.check_condition_2(top_node): return False if not self.check_condition_2(base_node): return False
def pattern_matching(self, base_node)
-
Expand source code
def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[ *self.make_nodes(**self.get_new_node_args(matched_nodes)) ], inits_to_add=[ *self.make_initializers(**self.get_new_init_args(matched_nodes)) ], vis_to_add=[ *self.make_value_infos(**self.get_new_vi_args(matched_nodes)) ]) return top_node.input
def weight_transformation(self, w_arr, **kwargs)
-
Expand source code
def weight_transformation(self, w_arr, **kwargs): c, n = w_arr.shape new_w_arr = w_arr.transpose().reshape(n, c, 1, 1) return new_w_arr
class Pattern_2 (model)
-
transform prev --> Gemm --> next to prev --> Unsqueeze --> Conv --> Squeeze --> next
if 1. one of Gemm.A and Gemm.B must have initializer 2. Gemm.C must have initializer if defined
Expand source code
class Pattern_2(Pattern_1, abc.ABC): """ transform prev --> Gemm --> next to prev --> Unsqueeze --> Conv --> Squeeze --> next if 1. one of Gemm.A and Gemm.B must have initializer 2. Gemm.C must have initializer if defined """ pattern_to_match = ['Gemm'] def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if not self.check_condition_3(node): return False if not self.check_condition_4(node): return False return True def check_condition_3(self, node): num_init = 0 for idx, node_input in enumerate(node.input): if idx == 2: break if node_input in self.initializer_map: num_init += 1 if num_init == 1: return True return False def check_condition_4(self, node): if len(node.input) == 3: if node.input[2] not in self.initializer_map: return False return True def get_new_init_args(self, matched_nodes): node = matched_nodes[0] fw_input = node.input[1] fb_input = None if len(node.input) == 3: fb_input = node.input[2] args = {'w_input': fw_input, 'b_input': fb_input} args.update(self.get_attrs(node)) return args def get_new_vi_args(self, matched_nodes): node = matched_nodes[0] fnode_input = node.input[0] fnode_output = node.output[0] return {'node_input': fnode_input, 'node_output': fnode_output} def weight_transformation(self, w_arr, **kwargs): transB = kwargs['transB'] if transB == 0: w_arr = w_arr.transpose() n, c = w_arr.shape new_arr = w_arr.reshape(n, c, 1, 1) return new_arr def get_attrs(self, node): from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs attrs = attribute_to_kwargs(node.attribute) alpha = attrs['alpha'] beta = attrs['beta'] assert alpha == beta == 1.0, "Assume alpha = beta = 1.0" transB = attrs['transB'] return {'transB': transB}
Ancestors
- Pattern_1
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Class variables
var pattern_to_match
Methods
def check_condition_3(self, node)
-
Expand source code
def check_condition_3(self, node): num_init = 0 for idx, node_input in enumerate(node.input): if idx == 2: break if node_input in self.initializer_map: num_init += 1 if num_init == 1: return True return False
def check_condition_4(self, node)
-
Expand source code
def check_condition_4(self, node): if len(node.input) == 3: if node.input[2] not in self.initializer_map: return False return True
def get_attrs(self, node)
-
Expand source code
def get_attrs(self, node): from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs attrs = attribute_to_kwargs(node.attribute) alpha = attrs['alpha'] beta = attrs['beta'] assert alpha == beta == 1.0, "Assume alpha = beta = 1.0" transB = attrs['transB'] return {'transB': transB}
def get_new_init_args(self, matched_nodes)
-
Expand source code
def get_new_init_args(self, matched_nodes): node = matched_nodes[0] fw_input = node.input[1] fb_input = None if len(node.input) == 3: fb_input = node.input[2] args = {'w_input': fw_input, 'b_input': fb_input} args.update(self.get_attrs(node)) return args
def get_new_vi_args(self, matched_nodes)
-
Expand source code
def get_new_vi_args(self, matched_nodes): node = matched_nodes[0] fnode_input = node.input[0] fnode_output = node.output[0] return {'node_input': fnode_input, 'node_output': fnode_output}
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if not self.check_condition_3(node): return False if not self.check_condition_4(node): return False return True
def weight_transformation(self, w_arr, **kwargs)
-
Expand source code
def weight_transformation(self, w_arr, **kwargs): transB = kwargs['transB'] if transB == 0: w_arr = w_arr.transpose() n, c = w_arr.shape new_arr = w_arr.reshape(n, c, 1, 1) return new_arr