Module furiosa.quantizer.frontend.onnx.transformer.fuse_layer_normalization
Expand source code
import onnx
from onnx.helper import make_model, ModelProto
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
from onnxruntime_tools.transformers.onnx_model import OnnxModel
from onnxruntime_tools.transformers.fusion_layernorm import FusionLayerNormalization
class BertOnnxModel(OnnxModel):
    """``OnnxModel`` subclass that adds a layer-normalization fusion entry point."""

    def __init__(self, model):
        """Pass ``model`` straight through to the ``OnnxModel`` constructor."""
        super().__init__(model)

    def fuse_layer_normalization(self):
        """Build a ``FusionLayerNormalization`` around this wrapper and apply it."""
        FusionLayerNormalization(self).apply()
class FuseLayerNormalization(Transformer):
    """
    from:
        Input --> ReduceMean --> S --> Pow --> ReduceMean --> Add --> Sqrt --> D
              -----------------> ub -----------------------------------------> iv --> Mul --> Add Output
    to:
        LayerNormalization
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse layer-normalization subgraphs and restore a topological node order.

        Args:
            model: The ONNX model to transform. It is handed to the fusion
                pass directly; a snapshot is taken first, presumably because
                the pass edits the model in place.

        Returns:
            The model rebuilt by ``utils.rebuild_model`` with every newly
            fused ``LayerNormalization`` node placed right after the node
            that produces its first input. ``check_model`` is run on the
            result before it is returned.
        """
        # Snapshot taken before fusion: its node order is the reference order.
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_layer_normalization()
        model = optimizer.model

        # Materialise both node sequences once instead of re-reading the
        # repeated fields on every membership test below.
        orig_nodes = list(orig_model.graph.node)
        fused_nodes = list(model.graph.node)

        # Newly created LayerNormalization nodes, grouped by the tensor they
        # read. A list per tensor keeps every consumer when several fused
        # nodes share one input. Nodes that already existed before fusion are
        # skipped: they keep their original position and must not be emitted
        # a second time.
        pending = {}
        for node in fused_nodes:
            if node.op_type == 'LayerNormalization' and node not in orig_nodes:
                pending.setdefault(node.input[0], []).append(node)

        if not pending:
            # Nothing was fused, so the node list needs no re-ordering.
            sorted_nodes = model.graph.node
        else:
            # nodes are not topologically sorted as a result of
            # onnxruntime_tools optimization
            sorted_nodes = []

            def place(node):
                # Emit ``node`` followed by every pending fused node that
                # reads one of its outputs (recursively, so that directly
                # chained fused nodes stay in dependency order).
                sorted_nodes.append(node)
                for output_name in node.output:
                    for layer_norm in pending.pop(output_name, ()):
                        place(layer_norm)

            # Fused nodes whose input no node produces (presumably a graph
            # input or an initializer) have no predecessor, so they go first.
            produced = {name for node in fused_nodes for name in node.output}
            for name in [key for key in pending if key not in produced]:
                for layer_norm in pending.pop(name, ()):
                    place(layer_norm)

            # Surviving original nodes keep their original relative order.
            for node in orig_nodes:
                if node in fused_nodes:
                    place(node)

            # Anything still pending hangs off a node that was never emitted.
            # Keep it rather than dropping it silently; check_model below is
            # the judge of whether the resulting graph is valid.
            for group in pending.values():
                sorted_nodes.extend(group)

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)
        return model
Classes
class BertOnnxModel (model)
-
Expand source code
class BertOnnxModel(OnnxModel):
    """``OnnxModel`` subclass that adds a layer-normalization fusion entry point."""

    def __init__(self, model):
        """Pass ``model`` straight through to the ``OnnxModel`` constructor."""
        super().__init__(model)

    def fuse_layer_normalization(self):
        """Build a ``FusionLayerNormalization`` around this wrapper and apply it."""
        FusionLayerNormalization(self).apply()
Ancestors
- onnxruntime_tools.transformers.onnx_model.OnnxModel
Methods
def fuse_layer_normalization(self)
-
Expand source code
def fuse_layer_normalization(self):
    """Build a ``FusionLayerNormalization`` around this wrapper and apply it."""
    FusionLayerNormalization(self).apply()
class FuseLayerNormalization (*args, **kwds)
-
from: Input --> ReduceMean --> S --> Pow --> ReduceMean --> Add --> Sqrt --> D -----------------> ub -----------------------------------------> iv --> Mul --> Add Output to: LayerNormalization
Expand source code
class FuseLayerNormalization(Transformer):
    """
    from:
        Input --> ReduceMean --> S --> Pow --> ReduceMean --> Add --> Sqrt --> D
              -----------------> ub -----------------------------------------> iv --> Mul --> Add Output
    to:
        LayerNormalization
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse layer-normalization subgraphs and restore a topological node order.

        Args:
            model: The ONNX model to transform. It is handed to the fusion
                pass directly; a snapshot is taken first, presumably because
                the pass edits the model in place.

        Returns:
            The model rebuilt by ``utils.rebuild_model`` with every newly
            fused ``LayerNormalization`` node placed right after the node
            that produces its first input. ``check_model`` is run on the
            result before it is returned.
        """
        # Snapshot taken before fusion: its node order is the reference order.
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_layer_normalization()
        model = optimizer.model

        # Materialise both node sequences once instead of re-reading the
        # repeated fields on every membership test below.
        orig_nodes = list(orig_model.graph.node)
        fused_nodes = list(model.graph.node)

        # Newly created LayerNormalization nodes, grouped by the tensor they
        # read. A list per tensor keeps every consumer when several fused
        # nodes share one input. Nodes that already existed before fusion are
        # skipped: they keep their original position and must not be emitted
        # a second time.
        pending = {}
        for node in fused_nodes:
            if node.op_type == 'LayerNormalization' and node not in orig_nodes:
                pending.setdefault(node.input[0], []).append(node)

        if not pending:
            # Nothing was fused, so the node list needs no re-ordering.
            sorted_nodes = model.graph.node
        else:
            # nodes are not topologically sorted as a result of
            # onnxruntime_tools optimization
            sorted_nodes = []

            def place(node):
                # Emit ``node`` followed by every pending fused node that
                # reads one of its outputs (recursively, so that directly
                # chained fused nodes stay in dependency order).
                sorted_nodes.append(node)
                for output_name in node.output:
                    for layer_norm in pending.pop(output_name, ()):
                        place(layer_norm)

            # Fused nodes whose input no node produces (presumably a graph
            # input or an initializer) have no predecessor, so they go first.
            produced = {name for node in fused_nodes for name in node.output}
            for name in [key for key in pending if key not in produced]:
                for layer_norm in pending.pop(name, ()):
                    place(layer_norm)

            # Surviving original nodes keep their original relative order.
            for node in orig_nodes:
                if node in fused_nodes:
                    place(node)

            # Anything still pending hangs off a node that was never emitted.
            # Keep it rather than dropping it silently; check_model below is
            # the judge of whether the resulting graph is valid.
            for group in pending.values():
                sorted_nodes.extend(group)

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)
        return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Fuse layer-normalization subgraphs and restore a topological node order.

    Args:
        model: The ONNX model to transform. It is handed to the fusion pass
            directly; a snapshot is taken first, presumably because the pass
            edits the model in place.

    Returns:
        The model rebuilt by ``utils.rebuild_model`` with every newly fused
        ``LayerNormalization`` node placed right after the node that produces
        its first input. ``check_model`` is run on the result before it is
        returned.
    """
    # Snapshot taken before fusion: its node order is the reference order.
    orig_model = ModelProto()
    orig_model.CopyFrom(model)

    optimizer = BertOnnxModel(model)
    optimizer.fuse_layer_normalization()
    model = optimizer.model

    # Materialise both node sequences once instead of re-reading the
    # repeated fields on every membership test below.
    orig_nodes = list(orig_model.graph.node)
    fused_nodes = list(model.graph.node)

    # Newly created LayerNormalization nodes, grouped by the tensor they
    # read. A list per tensor keeps every consumer when several fused nodes
    # share one input. Nodes that already existed before fusion are skipped:
    # they keep their original position and must not be emitted twice.
    pending = {}
    for node in fused_nodes:
        if node.op_type == 'LayerNormalization' and node not in orig_nodes:
            pending.setdefault(node.input[0], []).append(node)

    if not pending:
        # Nothing was fused, so the node list needs no re-ordering.
        sorted_nodes = model.graph.node
    else:
        # nodes are not topologically sorted as a result of
        # onnxruntime_tools optimization
        sorted_nodes = []

        def place(node):
            # Emit ``node`` followed by every pending fused node that reads
            # one of its outputs (recursively, so that directly chained
            # fused nodes stay in dependency order).
            sorted_nodes.append(node)
            for output_name in node.output:
                for layer_norm in pending.pop(output_name, ()):
                    place(layer_norm)

        # Fused nodes whose input no node produces (presumably a graph input
        # or an initializer) have no predecessor, so they go first.
        produced = {name for node in fused_nodes for name in node.output}
        for name in [key for key in pending if key not in produced]:
            for layer_norm in pending.pop(name, ()):
                place(layer_norm)

        # Surviving original nodes keep their original relative order.
        for node in orig_nodes:
            if node in fused_nodes:
                place(node)

        # Anything still pending hangs off a node that was never emitted.
        # Keep it rather than dropping it silently; check_model below is the
        # judge of whether the resulting graph is valid.
        for group in pending.values():
            sorted_nodes.extend(group)

    model = utils.rebuild_model(model, sorted_nodes)
    check_model(model)
    return model