Module furiosa.quantizer.frontend.onnx.transformer.fuse_layer_normalization

Expand source code
import onnx
from onnx.helper import make_model, ModelProto

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model

from onnxruntime_tools.transformers.onnx_model import OnnxModel
from onnxruntime_tools.transformers.fusion_layernorm import FusionLayerNormalization


class BertOnnxModel(OnnxModel):
    """OnnxModel subclass used to run onnxruntime_tools' layer-normalization fusion."""

    def __init__(self, model):
        # No extra state of its own; initialisation is entirely OnnxModel's.
        super(BertOnnxModel, self).__init__(model)

    def fuse_layer_normalization(self):
        """Apply the FusionLayerNormalization pass to this model."""
        layer_norm_fusion = FusionLayerNormalization(self)
        layer_norm_fusion.apply()


class FuseLayerNormalization(Transformer):
    """
    Fuse the decomposed layer-normalization pattern

        Input --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add --> Output
          |                      ^ |                                              ^
          +----------------------+ +----------------------------------------------+

    (Input also feeds Sub directly, and the Sub output also feeds Div)
    into a single LayerNormalization node.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse layer-normalization subgraphs of ``model`` into LayerNormalization nodes.

        The fusion is delegated to onnxruntime_tools (via BertOnnxModel), which can
        leave the graph's nodes out of topological order. The fused nodes are therefore
        re-ordered against a copy of the input graph taken before fusion.

        Args:
            model: the ONNX model to transform. It is handed to the optimizer
                directly; a copy is taken first, presumably because the optimizer
                edits it in place (NOTE(review): confirm).

        Returns:
            The model rebuilt from the re-ordered nodes with ``utils.rebuild_model``
            and validated with ``check_model``.
        """
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_layer_normalization()

        model = optimizer.model

        # Nodes are not topologically sorted as a result of onnxruntime_tools optimization,
        # so re-order them. Surviving nodes keep their original position, and every new node
        # is placed right after the last original producer of any of its inputs.
        orig_position = {}
        orig_producer_position = {}
        for position, node in enumerate(orig_model.graph.node):
            # Key by serialized form: O(1) membership instead of a linear scan of
            # deep protobuf comparisons for every node.
            orig_position.setdefault(node.SerializeToString(), position)
            for name in node.output:
                orig_producer_position[name] = position

        def rank(node):
            position = orig_position.get(node.SerializeToString())
            if position is not None:
                return position
            # New node (e.g. LayerNormalization). Producers are looked up in the ORIGINAL
            # graph, which also covers the following inputs:
            #   - inputs fed by a graph input or initializer (-1: place before everything);
            #   - inputs that are a non-first output of their producer;
            #   - inputs whose original producer was removed by the fusion.
            last_producer = max(
                (orig_producer_position.get(name, -1) for name in node.input if name),
                default=-1,
            )
            # +0.5: strictly after that producer, strictly before the next original node.
            return last_producer + 0.5

        # sorted() is stable, so new nodes with equal rank keep their relative order.
        sorted_nodes = sorted(model.graph.node, key=rank)

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)

        return model

Classes

class BertOnnxModel (model)
Expand source code
class BertOnnxModel(OnnxModel):
    """OnnxModel subclass used to run onnxruntime_tools' layer-normalization fusion."""

    def __init__(self, model):
        # No extra state of its own; initialisation is entirely OnnxModel's.
        super(BertOnnxModel, self).__init__(model)

    def fuse_layer_normalization(self):
        """Apply the FusionLayerNormalization pass to this model."""
        layer_norm_fusion = FusionLayerNormalization(self)
        layer_norm_fusion.apply()

Ancestors

  • onnxruntime_tools.transformers.onnx_model.OnnxModel

Methods

def fuse_layer_normalization(self)
Expand source code
def fuse_layer_normalization(self):
    """Apply the FusionLayerNormalization pass to this model."""
    layer_norm_fusion = FusionLayerNormalization(self)
    layer_norm_fusion.apply()
class FuseLayerNormalization (*args, **kwds)

from: Input --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add --> Output, where Input also feeds Sub directly and the Sub output also feeds Div; to: a single LayerNormalization node.

Expand source code
class FuseLayerNormalization(Transformer):
    """
    Fuse the decomposed layer-normalization pattern

        Input --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add --> Output
          |                      ^ |                                              ^
          +----------------------+ +----------------------------------------------+

    (Input also feeds Sub directly, and the Sub output also feeds Div)
    into a single LayerNormalization node.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse layer-normalization subgraphs of ``model`` into LayerNormalization nodes.

        The fusion is delegated to onnxruntime_tools (via BertOnnxModel), which can
        leave the graph's nodes out of topological order. The fused nodes are therefore
        re-ordered against a copy of the input graph taken before fusion.

        Args:
            model: the ONNX model to transform. It is handed to the optimizer
                directly; a copy is taken first, presumably because the optimizer
                edits it in place (NOTE(review): confirm).

        Returns:
            The model rebuilt from the re-ordered nodes with ``utils.rebuild_model``
            and validated with ``check_model``.
        """
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_layer_normalization()

        model = optimizer.model

        # Nodes are not topologically sorted as a result of onnxruntime_tools optimization,
        # so re-order them. Surviving nodes keep their original position, and every new node
        # is placed right after the last original producer of any of its inputs.
        orig_position = {}
        orig_producer_position = {}
        for position, node in enumerate(orig_model.graph.node):
            # Key by serialized form: O(1) membership instead of a linear scan of
            # deep protobuf comparisons for every node.
            orig_position.setdefault(node.SerializeToString(), position)
            for name in node.output:
                orig_producer_position[name] = position

        def rank(node):
            position = orig_position.get(node.SerializeToString())
            if position is not None:
                return position
            # New node (e.g. LayerNormalization). Producers are looked up in the ORIGINAL
            # graph, which also covers the following inputs:
            #   - inputs fed by a graph input or initializer (-1: place before everything);
            #   - inputs that are a non-first output of their producer;
            #   - inputs whose original producer was removed by the fusion.
            last_producer = max(
                (orig_producer_position.get(name, -1) for name in node.input if name),
                default=-1,
            )
            # +0.5: strictly after that producer, strictly before the next original node.
            return last_producer + 0.5

        # sorted() is stable, so new nodes with equal rank keep their relative order.
        sorted_nodes = sorted(model.graph.node, key=rank)

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)

        return model

Ancestors

  • furiosa_sdk_quantizer.interfaces.transformer.Transformer
  • typing.Generic

Methods

def transform(self, model: onnx.onnx_ml_pb2.ModelProto) -> onnx.onnx_ml_pb2.ModelProto
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Fuse layer-normalization subgraphs of ``model`` into LayerNormalization nodes.

    The fusion is delegated to onnxruntime_tools (via BertOnnxModel), which can
    leave the graph's nodes out of topological order. The fused nodes are therefore
    re-ordered against a copy of the input graph taken before fusion.

    Args:
        model: the ONNX model to transform. It is handed to the optimizer
            directly; a copy is taken first, presumably because the optimizer
            edits it in place (NOTE(review): confirm).

    Returns:
        The model rebuilt from the re-ordered nodes with ``utils.rebuild_model``
        and validated with ``check_model``.
    """
    orig_model = ModelProto()
    orig_model.CopyFrom(model)

    optimizer = BertOnnxModel(model)
    optimizer.fuse_layer_normalization()

    model = optimizer.model

    # Nodes are not topologically sorted as a result of onnxruntime_tools optimization,
    # so re-order them. Surviving nodes keep their original position, and every new node
    # is placed right after the last original producer of any of its inputs.
    orig_position = {}
    orig_producer_position = {}
    for position, node in enumerate(orig_model.graph.node):
        # Key by serialized form: O(1) membership instead of a linear scan of
        # deep protobuf comparisons for every node.
        orig_position.setdefault(node.SerializeToString(), position)
        for name in node.output:
            orig_producer_position[name] = position

    def rank(node):
        position = orig_position.get(node.SerializeToString())
        if position is not None:
            return position
        # New node (e.g. LayerNormalization). Producers are looked up in the ORIGINAL
        # graph, which also covers the following inputs:
        #   - inputs fed by a graph input or initializer (-1: place before everything);
        #   - inputs that are a non-first output of their producer;
        #   - inputs whose original producer was removed by the fusion.
        last_producer = max(
            (orig_producer_position.get(name, -1) for name in node.input if name),
            default=-1,
        )
        # +0.5: strictly after that producer, strictly before the next original node.
        return last_producer + 0.5

    # sorted() is stable, so new nodes with equal rank keep their relative order.
    sorted_nodes = sorted(model.graph.node, key=rank)

    model = utils.rebuild_model(model, sorted_nodes)
    check_model(model)

    return model