Module furiosa.quantizer.frontend.onnx.transformer.experimental.fuse_div_for_bert
Expand source code
import onnx
from onnx import numpy_helper
from onnx.helper import make_model

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.transformer.polish_model import PolishModel
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model


class FuseDivForBert(Transformer):
    """
    Only works for some BERT Models
    """

    def __init__(self):
        self.nodes_by_output_name = None
        self.initializers = None
        self.outputs_by_name = None

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        model = PolishModel().transform(model)

        # Index the graph for lookups by tensor name.
        self.nodes_by_output_name = {node.output[0]: node for node in model.graph.node}
        self.initializers = {init.name: init for init in model.graph.initializer}
        self.outputs_by_name = {oup.name: oup for oup in model.graph.output}

        # Fuse Div(MatMul(...), scalar) into the bias initializer of the upstream Add.
        model = self.transform_matmul_add(model)
        check_model(model)

        return PolishModel().transform(model)

    def transform_matmul_add(self, model):
        optimized_nodes = []
        removed_nodes = []

        # Case 1: a Div that immediately follows a MatMul.
        for node in model.graph.node:
            if node.op_type != 'Div':
                optimized_nodes.append(node)
                continue

            prev_node_1 = self.nodes_by_output_name[node.input[0]]
            if prev_node_1.op_type != 'MatMul':
                optimized_nodes.append(node)
                continue

            # Walk three nodes up the first-input chain of the MatMul;
            # the pattern expects an Add there.
            prev_node_2 = self.nodes_by_output_name[prev_node_1.input[0]]
            prev_node_3 = self.nodes_by_output_name[prev_node_2.input[0]]
            prev_node_4 = self.nodes_by_output_name[prev_node_3.input[0]]
            assert prev_node_4.op_type == 'Add'

            # Pre-divide the Add's bias initializer by the Div's scalar,
            # rewire the MatMul to produce the Div's output, and drop the Div.
            scalar = numpy_helper.to_array(self.initializers[node.input[1]])
            arr = numpy_helper.to_array(self.initializers[prev_node_4.input[1]])
            model.graph.initializer.append(
                numpy_helper.from_array(arr / scalar, name=prev_node_4.input[1] + '_div_fused'))
            prev_node_4.input[1] = prev_node_4.input[1] + '_div_fused'
            prev_node_1.output[0] = node.output[0]
            removed_nodes.append(node)

        new_nodes = list(filter(lambda node: node not in removed_nodes, optimized_nodes))
        model.graph.ClearField('node')
        model.graph.node.extend(new_nodes)

        model = make_model(model.graph)
        model = utils.eliminate_unused_protos(model)
        return model
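Example usage, as a minimal sketch: the import path follows this page's module name, while "bert.onnx" and the output file name are hypothetical.

import onnx
from furiosa.quantizer.frontend.onnx.transformer.experimental.fuse_div_for_bert import FuseDivForBert

model = onnx.load("bert.onnx")             # hypothetical input model
model = FuseDivForBert().transform(model)
onnx.save(model, "bert_div_fused.onnx")    # hypothetical output path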
Classes
class FuseDivForBert
-
Only works for some BERT Models
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) -> onnx.onnx_ml_pb2.ModelProto
-
Polishes the model with PolishModel, indexes its nodes, initializers, and outputs by name, fuses eligible Div nodes via transform_matmul_add, validates the result with check_model, and returns the model after a second PolishModel pass.
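A quick way to check that the fusion took effect is to count Div nodes before and after the pass. This is a sketch; the model path is hypothetical and the input is assumed to be a compatible BERT graph.

import onnx

model = onnx.load("bert.onnx")  # hypothetical path
before = sum(1 for n in model.graph.node if n.op_type == 'Div')
fused = FuseDivForBert().transform(model)
after = sum(1 for n in fused.graph.node if n.op_type == 'Div')
print(f"Div nodes: {before} -> {after}")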
def transform_matmul_add(self, model)
-
Matches a Div node whose first input is produced by a MatMul and whose ancestor three nodes further up the first-input chain is an Add. The Add's bias initializer is divided by the Div's scalar operand and stored under a new name with the suffix '_div_fused', the MatMul is rewired to produce the Div's output, and the Div node is dropped from the graph.
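To illustrate the initializer rewrite in isolation, the sketch below uses hypothetical bias and scalar values (8.0 stands in for the sqrt(head_dim) attention scale commonly seen in BERT) and mirrors the numpy_helper calls the method makes.

import numpy as np
from onnx import numpy_helper

bias = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # hypothetical Add bias initializer
scalar = np.array(8.0, dtype=np.float32)             # hypothetical Div scalar operand

# The method appends bias / scalar as a new initializer named '<bias>_div_fused'
# and points the Add node at it, so the explicit Div node can be removed.
fused = numpy_helper.from_array(bias / scalar, name='add_bias_div_fused')
print(numpy_helper.to_array(fused))  # roughly [0.0125, -0.025, 0.0375]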