Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.transformer.fuse_gelu

Expand source code
import onnx
from onnx.helper import ModelProto

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model

from onnxruntime_tools.transformers.onnx_model import OnnxModel
from onnxruntime_tools.transformers.fusion_gelu import FusionGelu


class BertOnnxModel(OnnxModel):
    """OnnxModel wrapper exposing onnxruntime_tools' GELU fusion pass."""

    def __init__(self, model):
        # Wrapping of the ModelProto is left entirely to OnnxModel.
        super().__init__(model)

    def fuse_gelu(self):
        """Run onnxruntime_tools' ``FusionGelu`` over the wrapped model."""
        FusionGelu(self).apply()


class FuseGELU(Transformer):
    """Fuse Erf-based GELU subgraphs into a single ``Gelu`` node.

    from:
        Input --> Div --> Erf --> Add --> Mul --> Mul --> Output
          |                                ^
          +--------------------------------+
    to:
        Input --> Gelu --> Output

    Pattern matching is delegated to onnxruntime_tools' ``FusionGelu``; the
    exact set of subgraph variants it recognizes is defined there.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse GELU subgraphs of ``model`` and return a topologically ordered model.

        ``FusionGelu`` leaves the node list out of topological order. The
        surviving original nodes are therefore re-emitted in their original
        order, and each new ``Gelu`` node is placed right after the node that
        produces its input.

        Args:
            model: Model to transform. ``BertOnnxModel`` wraps this proto
                directly, so the fusion may modify it in place (an untouched
                copy is kept internally for ordering).

        Returns:
            The model rebuilt via ``utils.rebuild_model`` and validated with
            ``check_model``.
        """
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_gelu()
        model = optimizer.model

        # Group Gelu nodes by the tensor they consume. Use a list per key so
        # that several Gelus reading the same tensor are all kept.
        gelus_by_input = {}
        for node in model.graph.node:
            if node.op_type == 'Gelu':
                gelus_by_input.setdefault(node.input[0], []).append(node)

        sorted_nodes = []
        # Serialized bytes of nodes already emitted. Protobuf messages are
        # unhashable, so bytes stand in as keys. This prevents a Gelu from
        # being appended twice.
        emitted = set()

        def emit(node):
            # Append `node`, then every Gelu consuming any of its outputs (any
            # output index, not only output[0]). Recursion covers a Gelu whose
            # output feeds another Gelu.
            key = node.SerializeToString()
            if key in emitted:
                return
            emitted.add(key)
            sorted_nodes.append(node)
            for output_name in node.output:
                for gelu in gelus_by_input.get(output_name, ()):
                    emit(gelu)

        # A Gelu whose input no node produces (e.g. a graph input) has no
        # producer to follow. Emit it first so it is not dropped.
        produced = {name for node in model.graph.node for name in node.output}
        for input_name, gelus in gelus_by_input.items():
            if input_name not in produced:
                for gelu in gelus:
                    emit(gelu)

        # Nodes are not topologically sorted as a result of onnxruntime_tools
        # optimization. Restore the original order for the nodes that
        # survived the fusion. A set of serialized nodes replaces the linear
        # `node in model.graph.node` scan.
        surviving = {node.SerializeToString() for node in model.graph.node}
        any_survived = False
        for node in orig_model.graph.node:
            if node.SerializeToString() in surviving:
                any_survived = True
                emit(node)

        if not any_survived:
            sorted_nodes = model.graph.node

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)

        return model

Classes

class BertOnnxModel (model)
Expand source code
class BertOnnxModel(OnnxModel):
    """OnnxModel wrapper exposing onnxruntime_tools' GELU fusion pass."""

    def __init__(self, model):
        # Wrapping of the ModelProto is left entirely to OnnxModel.
        super().__init__(model)

    def fuse_gelu(self):
        """Run onnxruntime_tools' ``FusionGelu`` over the wrapped model."""
        FusionGelu(self).apply()

Ancestors

  • onnxruntime_tools.transformers.onnx_model.OnnxModel

Methods

def fuse_gelu(self)
Expand source code
def fuse_gelu(self):
    """Run onnxruntime_tools' ``FusionGelu`` over the wrapped model."""
    FusionGelu(self).apply()
class FuseGELU (*args, **kwds)

from: Input –> Div –> Erf –> Add –> Mul –> Mul –> Output (Input also feeds the first Mul) to: Input –> Gelu –> Output

Expand source code
class FuseGELU(Transformer):
    """Fuse Erf-based GELU subgraphs into a single ``Gelu`` node.

    from:
        Input --> Div --> Erf --> Add --> Mul --> Mul --> Output
          |                                ^
          +--------------------------------+
    to:
        Input --> Gelu --> Output

    Pattern matching is delegated to onnxruntime_tools' ``FusionGelu``; the
    exact set of subgraph variants it recognizes is defined there.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Fuse GELU subgraphs of ``model`` and return a topologically ordered model.

        ``FusionGelu`` leaves the node list out of topological order. The
        surviving original nodes are therefore re-emitted in their original
        order, and each new ``Gelu`` node is placed right after the node that
        produces its input.

        Args:
            model: Model to transform. ``BertOnnxModel`` wraps this proto
                directly, so the fusion may modify it in place (an untouched
                copy is kept internally for ordering).

        Returns:
            The model rebuilt via ``utils.rebuild_model`` and validated with
            ``check_model``.
        """
        orig_model = ModelProto()
        orig_model.CopyFrom(model)

        optimizer = BertOnnxModel(model)
        optimizer.fuse_gelu()
        model = optimizer.model

        # Group Gelu nodes by the tensor they consume. Use a list per key so
        # that several Gelus reading the same tensor are all kept.
        gelus_by_input = {}
        for node in model.graph.node:
            if node.op_type == 'Gelu':
                gelus_by_input.setdefault(node.input[0], []).append(node)

        sorted_nodes = []
        # Serialized bytes of nodes already emitted. Protobuf messages are
        # unhashable, so bytes stand in as keys. This prevents a Gelu from
        # being appended twice.
        emitted = set()

        def emit(node):
            # Append `node`, then every Gelu consuming any of its outputs (any
            # output index, not only output[0]). Recursion covers a Gelu whose
            # output feeds another Gelu.
            key = node.SerializeToString()
            if key in emitted:
                return
            emitted.add(key)
            sorted_nodes.append(node)
            for output_name in node.output:
                for gelu in gelus_by_input.get(output_name, ()):
                    emit(gelu)

        # A Gelu whose input no node produces (e.g. a graph input) has no
        # producer to follow. Emit it first so it is not dropped.
        produced = {name for node in model.graph.node for name in node.output}
        for input_name, gelus in gelus_by_input.items():
            if input_name not in produced:
                for gelu in gelus:
                    emit(gelu)

        # Nodes are not topologically sorted as a result of onnxruntime_tools
        # optimization. Restore the original order for the nodes that
        # survived the fusion. A set of serialized nodes replaces the linear
        # `node in model.graph.node` scan.
        surviving = {node.SerializeToString() for node in model.graph.node}
        any_survived = False
        for node in orig_model.graph.node:
            if node.SerializeToString() in surviving:
                any_survived = True
                emit(node)

        if not any_survived:
            sorted_nodes = model.graph.node

        model = utils.rebuild_model(model, sorted_nodes)
        check_model(model)

        return model

Ancestors

  • furiosa_sdk_quantizer.interfaces.transformer.Transformer
  • typing.Generic

Methods

def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Fuse GELU subgraphs of ``model`` and return a topologically ordered model.

    ``FusionGelu`` leaves the node list out of topological order. The
    surviving original nodes are therefore re-emitted in their original
    order, and each new ``Gelu`` node is placed right after the node that
    produces its input.

    Args:
        model: Model to transform. ``BertOnnxModel`` wraps this proto
            directly, so the fusion may modify it in place (an untouched copy
            is kept internally for ordering).

    Returns:
        The model rebuilt via ``utils.rebuild_model`` and validated with
        ``check_model``.
    """
    orig_model = ModelProto()
    orig_model.CopyFrom(model)

    optimizer = BertOnnxModel(model)
    optimizer.fuse_gelu()
    model = optimizer.model

    # Group Gelu nodes by the tensor they consume. Use a list per key so that
    # several Gelus reading the same tensor are all kept.
    gelus_by_input = {}
    for node in model.graph.node:
        if node.op_type == 'Gelu':
            gelus_by_input.setdefault(node.input[0], []).append(node)

    sorted_nodes = []
    # Serialized bytes of nodes already emitted. Protobuf messages are
    # unhashable, so bytes stand in as keys. This prevents a Gelu from being
    # appended twice.
    emitted = set()

    def emit(node):
        # Append `node`, then every Gelu consuming any of its outputs (any
        # output index, not only output[0]). Recursion covers a Gelu whose
        # output feeds another Gelu.
        key = node.SerializeToString()
        if key in emitted:
            return
        emitted.add(key)
        sorted_nodes.append(node)
        for output_name in node.output:
            for gelu in gelus_by_input.get(output_name, ()):
                emit(gelu)

    # A Gelu whose input no node produces (e.g. a graph input) has no producer
    # to follow. Emit it first so it is not dropped.
    produced = {name for node in model.graph.node for name in node.output}
    for input_name, gelus in gelus_by_input.items():
        if input_name not in produced:
            for gelu in gelus:
                emit(gelu)

    # Nodes are not topologically sorted as a result of onnxruntime_tools
    # optimization. Restore the original order for the nodes that survived
    # the fusion. A set of serialized nodes replaces the linear
    # `node in model.graph.node` scan.
    surviving = {node.SerializeToString() for node in model.graph.node}
    any_survived = False
    for node in orig_model.graph.node:
        if node.SerializeToString() in surviving:
            any_survived = True
            emit(node)

    if not any_survived:
        sorted_nodes = model.graph.node

    model = utils.rebuild_model(model, sorted_nodes)
    check_model(model)

    return model