Module furiosa.quantizer.frontend.onnx.transformer.experimental.reify_conv_for_bert
Expand source code
import onnx
from onnx import numpy_helper
from onnx.helper import make_node, make_tensor_value_info, make_model
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.transformer.polish_model import PolishModel
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
class ReifyConvForBert(Transformer):
    """
    from: MatMul + Add
    to: Conv

    Assume NCHW Input
    """

    def __init__(self):
        self.nodes_by_output_name = None
        self.initializers = None
        self.outputs_by_name = None

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        model = PolishModel().transform(model)

        self.nodes_by_output_name = {node.output[0]: node for node in model.graph.node}
        self.initializers = {init.name: init for init in model.graph.initializer}
        self.outputs_by_name = {oup.name: oup for oup in model.graph.output}

        model = self.transform_matmul_add(model)  # transform MatMul + Add --> Conv
        check_model(model)

        return PolishModel().transform(model)
    def transform_matmul_add(self, model):
        optimized_nodes = []
        removed_nodes = []

        # Handle Case 1: MatMul + Add
        for node in model.graph.node:
            if node.op_type != 'Add':
                optimized_nodes.append(node)
                continue

            # The ONNX spec imposes no particular order on the inputs of Add,
            # so find which input of the Add node is produced by a MatMul.
            def _is_input_op_type(node_input, op_type):
                if node_input in self.initializers.keys():
                    return False
                return self.nodes_by_output_name[node_input].op_type == op_type

            try:
                idx_matmul = list(
                    filter(lambda enum: _is_input_op_type(enum[1], 'MatMul'), enumerate(node.input)))
            except KeyError:
                optimized_nodes.append(node)
                continue

            # Expect exactly one of the Add inputs to be produced by a MatMul node.
            if len(idx_matmul) != 1:
                optimized_nodes.append(node)
                continue

            idx_matmul = idx_matmul[0][0]
            matmul_node = self.nodes_by_output_name[node.input[idx_matmul]]

            def _is_initializer(node_input):
                return node_input in self.initializers.keys()

            # MatMul weight: a (c, n) initializer, reshaped into a 1x1 Conv kernel of shape (n, c, 1, 1).
            idx_matmul_init = list(filter(lambda enum: _is_initializer(enum[1]),
                                          enumerate(matmul_node.input)))[0][0]
            matmul_init = self.initializers[matmul_node.input[idx_matmul_init]]
            matmul_weight = numpy_helper.to_array(matmul_init)
            c, n = matmul_weight.shape
            conv_weight = matmul_weight.transpose().reshape(n, c, 1, 1)

            # Add bias: the initializer input of the Add node itself.
            idx_add_init = list(filter(lambda enum: _is_initializer(enum[1]),
                                       enumerate(node.input)))[0][0]
            add_init = self.initializers[node.input[idx_add_init]]
            bias = numpy_helper.to_array(add_init)

            conv_weight_init = numpy_helper.from_array(conv_weight,
                                                       name=matmul_node.input[idx_matmul_init] + '_reified')
            bias_init = numpy_helper.from_array(bias, name=node.input[idx_add_init] + '_reified')
            model.graph.initializer.extend([conv_weight_init, bias_init])
            removed_nodes.extend([node, matmul_node])
            # For a (N, S, c) activation: Unsqueeze --> (N, 1, S, c), Transpose --> (N, c, 1, S)
            unsqueeze_node = make_node('Unsqueeze',
                                       inputs=[matmul_node.input[0]],
                                       outputs=[matmul_node.output[0] + '_expanded'],
                                       axes=[1])
            transpose_node = make_node('Transpose',
                                       inputs=[unsqueeze_node.output[0]],
                                       outputs=[matmul_node.output[0] + '_transposed'],
                                       perm=[0, 3, 1, 2])
            # 1x1 Conv carrying the reshaped weight and bias: (N, c, 1, S) --> (N, n, 1, S)
            conv_node = make_node('Conv',
                                  inputs=[transpose_node.output[0],
                                          conv_weight_init.name,
                                          bias_init.name],
                                  outputs=[matmul_node.output[0] + '_conv_output'],
                                  dilations=[1, 1],
                                  group=1,
                                  kernel_shape=[1, 1],
                                  pads=[0, 0, 0, 0],
                                  strides=[1, 1])
            # Squeeze --> (N, n, S), Transpose --> (N, S, n), matching the original MatMul + Add output
            squeeze_node = make_node('Squeeze',
                                     inputs=[conv_node.output[0]],
                                     outputs=[matmul_node.output[0] + '_squeezed'],
                                     axes=[2])
            transpose_node_1 = make_node('Transpose',
                                         inputs=[squeeze_node.output[0]],
                                         outputs=[node.output[0]],
                                         perm=[0, 2, 1])
            optimized_nodes.extend([
                unsqueeze_node, transpose_node, conv_node, squeeze_node, transpose_node_1])
            # When the first input of the new Conv is the graph input itself, replace the
            # graph input's value_info with the corresponding NCHW shape.
            graph_input = model.graph.input[0]
            if conv_node.input[0] == graph_input.name:
                batch_size = graph_input.type.tensor_type.shape.dim[0].dim_value
                new_vi = make_tensor_value_info(name=graph_input.name,
                                                elem_type=graph_input.type.tensor_type.elem_type,
                                                shape=(batch_size, c, 1, 1))
                model.graph.input.remove(graph_input)
                model.graph.input.insert(0, new_vi)

        new_nodes = list(filter(lambda node: node not in removed_nodes, optimized_nodes))
        model.graph.ClearField('node')
        model.graph.node.extend(new_nodes)

        model = make_model(model.graph)
        model = utils.eliminate_unused_protos(model)

        return model
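The rewrite is valid because a MatMul with a (c, n) weight followed by an Add of an (n,) bias computes the same values as a 1x1 Conv whose kernel is the transposed weight reshaped to (n, c, 1, 1). Below is a minimal NumPy sketch of that equivalence, mirroring the inserted Unsqueeze -> Transpose -> Conv -> Squeeze -> Transpose sequence; the shapes and tensors are illustrative and not taken from the module.

import numpy as np

# Illustrative shapes: batch N=2, sequence S=4, c=8 input features, n=16 output features.
x = np.random.rand(2, 4, 8).astype(np.float32)   # MatMul activation, shape (N, S, c)
w = np.random.rand(8, 16).astype(np.float32)     # MatMul weight initializer, shape (c, n)
b = np.random.rand(16).astype(np.float32)        # Add bias initializer, shape (n,)

reference = x @ w + b                            # original MatMul + Add, shape (N, S, n)

# Mirror the inserted node sequence:
# Unsqueeze(axes=[1]) -> Transpose(perm=[0, 3, 1, 2]) -> 1x1 Conv -> Squeeze(axes=[2]) -> Transpose(perm=[0, 2, 1])
nchw = np.expand_dims(x, 1).transpose(0, 3, 1, 2)              # (N, c, 1, S)
conv_w = w.transpose().reshape(16, 8, 1, 1)                    # (n, c, 1, 1), as in the transformer
conv = np.einsum('nchs,oc->nohs', nchw, conv_w[:, :, 0, 0]) + b[None, :, None, None]
result = conv.squeeze(axis=2).transpose(0, 2, 1)               # back to (N, S, n)

np.testing.assert_allclose(reference, result, rtol=1e-5, atol=1e-5)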
Classes
class ReifyConvForBert

    from: MatMul + Add
    to: Conv

    Assume NCHW Input
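A minimal usage sketch (the file paths below are hypothetical; the import follows the module path shown at the top of this page):

import onnx
from furiosa.quantizer.frontend.onnx.transformer.experimental.reify_conv_for_bert import ReifyConvForBert

model = onnx.load_model('bert_encoder.onnx')          # hypothetical input path
model = ReifyConvForBert().transform(model)           # MatMul + Add pairs become 1x1 Conv subgraphs
onnx.save_model(model, 'bert_encoder_reified.onnx')   # hypothetical output path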
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) -> onnx.onnx_ml_pb2.ModelProto

    Polishes the model, rewrites each MatMul + Add pair into a Conv subgraph via
    transform_matmul_add, validates the result with check_model, and polishes once more.
def transform_matmul_add(self, model)

    Scans the graph for Add nodes fed by a MatMul with an initializer weight, and replaces
    each such pair with an Unsqueeze -> Transpose -> 1x1 Conv -> Squeeze -> Transpose
    subgraph whose Conv carries the reshaped weight and bias as initializers.