Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.transformer.fuse_gelu
Expand source code
import onnx
from onnx.helper import ModelProto
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
from onnxruntime_tools.transformers.onnx_model import OnnxModel
from onnxruntime_tools.transformers.fusion_gelu import FusionGelu
class BertOnnxModel(OnnxModel):
    """``OnnxModel`` subclass that exposes a single fusion pass: GELU.

    Despite the name, no BERT-specific fusions other than ``fuse_gelu`` are
    provided here.
    """

    def __init__(self, model):
        # No extra state; construction is delegated entirely to OnnxModel.
        super(BertOnnxModel, self).__init__(model)

    def fuse_gelu(self):
        """Run onnxruntime_tools' ``FusionGelu`` pass over this model."""
        FusionGelu(self).apply()
class FuseGELU(Transformer):
    """
    Fuse the erf-based GELU subgraph into a single ``Gelu`` node, e.g.

    from:
        Input --> Div --> Erf --> Add --> Mul --> Mul --> Output
          |                               ^
          +-------------------------------+
    to:
        Input --> Gelu --> Output

    Pattern matching itself is delegated to onnxruntime_tools' ``FusionGelu``.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return a new model in which GELU subgraphs are fused into ``Gelu`` nodes.

        ``model`` itself is left unmodified; the fusion runs on a copy. The
        fusion pass does not keep the node list topologically sorted, so it is
        rebuilt: surviving nodes keep their relative order from ``model``, and
        each newly created ``Gelu`` is placed right after the node producing
        its input (or first, when no node produces that input, e.g. a graph
        input or initializer). Nested ``Gelu`` chains follow their producer.

        Args:
            model: the ONNX model to transform.

        Returns:
            The fused model, rebuilt by ``utils.rebuild_model`` and then passed
            through ``check_model``.
        """
        # The optimizer edits the ModelProto it wraps, so fuse a private copy;
        # the untouched ``model`` then serves as the reference node order.
        fused = onnx.ModelProto()
        fused.CopyFrom(model)
        optimizer = BertOnnxModel(fused)
        optimizer.fuse_gelu()
        fused = optimizer.model

        # Protobuf messages are unhashable, so nodes are keyed by their
        # deterministic serialization. This turns the per-node linear scan of a
        # repeated field into an O(1) set/dict lookup.
        def node_key(node):
            return node.SerializeToString(deterministic=True)

        original_order = [node_key(node) for node in model.graph.node]
        original_keys = set(original_order)
        produced = {name for node in model.graph.node for name in node.output}

        kept_by_key = {}          # surviving nodes of ``fused``, grouped by key
        new_gelus_by_input = {}   # Gelu nodes created by the fusion, by input[0]
        for node in fused.graph.node:
            key = node_key(node)
            # A Gelu that already existed before fusion is an ordinary
            # surviving node; treating it as new would emit it twice.
            if node.op_type == 'Gelu' and key not in original_keys:
                new_gelus_by_input.setdefault(node.input[0], []).append(node)
            else:
                kept_by_key.setdefault(key, []).append(node)

        sorted_nodes = []

        def emit(node):
            # Append ``node`` followed by every new Gelu consuming any of its
            # outputs, depth-first, so chains such as Gelu(Gelu(x)) stay sorted.
            pending = [node]
            while pending:
                current = pending.pop()
                sorted_nodes.append(current)
                consumers = [gelu for name in current.output
                             for gelu in new_gelus_by_input.pop(name, [])]
                pending.extend(reversed(consumers))

        # New Gelus fed by a value that no node produces (graph input or
        # initializer) are ready before any other node runs.
        for name in [name for name in new_gelus_by_input if name not in produced]:
            for gelu in new_gelus_by_input.pop(name, []):
                emit(gelu)

        # Surviving nodes keep their original relative order.
        for key in original_order:
            matches = kept_by_key.get(key)
            if matches:
                emit(matches.pop(0))

        # Nodes placed by neither rule above are not expected for a valid
        # graph; keep them (at the end) rather than silently dropping them.
        sorted_nodes.extend(node for nodes in kept_by_key.values() for node in nodes)
        sorted_nodes.extend(gelu for gelus in new_gelus_by_input.values() for gelu in gelus)

        fused = utils.rebuild_model(fused, sorted_nodes)
        check_model(fused)
        return fused
Classes
class BertOnnxModel (model)
-
Expand source code
class BertOnnxModel(OnnxModel):
    """``OnnxModel`` subclass that exposes a single fusion pass: GELU.

    Despite the name, no BERT-specific fusions other than ``fuse_gelu`` are
    provided here.
    """

    def __init__(self, model):
        # No extra state; construction is delegated entirely to OnnxModel.
        super(BertOnnxModel, self).__init__(model)

    def fuse_gelu(self):
        """Run onnxruntime_tools' ``FusionGelu`` pass over this model."""
        FusionGelu(self).apply()
Ancestors
- onnxruntime_tools.transformers.onnx_model.OnnxModel
Methods
def fuse_gelu(self)
-
Expand source code
def fuse_gelu(self):
    """Run onnxruntime_tools' ``FusionGelu`` pass over this model."""
    FusionGelu(self).apply()
class FuseGELU (*args, **kwds)
-
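Fuse the erf-based GELU subgraph into a single Gelu node, e.g. from: Input --> Div --> Erf --> Add --> Mul --> Mul --> Output (Input also feeds the first Mul) to: Input --> Gelu --> Output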
Expand source code
class FuseGELU(Transformer):
    """
    Fuse the erf-based GELU subgraph into a single ``Gelu`` node, e.g.

    from:
        Input --> Div --> Erf --> Add --> Mul --> Mul --> Output
          |                               ^
          +-------------------------------+
    to:
        Input --> Gelu --> Output

    Pattern matching itself is delegated to onnxruntime_tools' ``FusionGelu``.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return a new model in which GELU subgraphs are fused into ``Gelu`` nodes.

        ``model`` itself is left unmodified; the fusion runs on a copy. The
        fusion pass does not keep the node list topologically sorted, so it is
        rebuilt: surviving nodes keep their relative order from ``model``, and
        each newly created ``Gelu`` is placed right after the node producing
        its input (or first, when no node produces that input, e.g. a graph
        input or initializer). Nested ``Gelu`` chains follow their producer.

        Args:
            model: the ONNX model to transform.

        Returns:
            The fused model, rebuilt by ``utils.rebuild_model`` and then passed
            through ``check_model``.
        """
        # The optimizer edits the ModelProto it wraps, so fuse a private copy;
        # the untouched ``model`` then serves as the reference node order.
        fused = onnx.ModelProto()
        fused.CopyFrom(model)
        optimizer = BertOnnxModel(fused)
        optimizer.fuse_gelu()
        fused = optimizer.model

        # Protobuf messages are unhashable, so nodes are keyed by their
        # deterministic serialization. This turns the per-node linear scan of a
        # repeated field into an O(1) set/dict lookup.
        def node_key(node):
            return node.SerializeToString(deterministic=True)

        original_order = [node_key(node) for node in model.graph.node]
        original_keys = set(original_order)
        produced = {name for node in model.graph.node for name in node.output}

        kept_by_key = {}          # surviving nodes of ``fused``, grouped by key
        new_gelus_by_input = {}   # Gelu nodes created by the fusion, by input[0]
        for node in fused.graph.node:
            key = node_key(node)
            # A Gelu that already existed before fusion is an ordinary
            # surviving node; treating it as new would emit it twice.
            if node.op_type == 'Gelu' and key not in original_keys:
                new_gelus_by_input.setdefault(node.input[0], []).append(node)
            else:
                kept_by_key.setdefault(key, []).append(node)

        sorted_nodes = []

        def emit(node):
            # Append ``node`` followed by every new Gelu consuming any of its
            # outputs, depth-first, so chains such as Gelu(Gelu(x)) stay sorted.
            pending = [node]
            while pending:
                current = pending.pop()
                sorted_nodes.append(current)
                consumers = [gelu for name in current.output
                             for gelu in new_gelus_by_input.pop(name, [])]
                pending.extend(reversed(consumers))

        # New Gelus fed by a value that no node produces (graph input or
        # initializer) are ready before any other node runs.
        for name in [name for name in new_gelus_by_input if name not in produced]:
            for gelu in new_gelus_by_input.pop(name, []):
                emit(gelu)

        # Surviving nodes keep their original relative order.
        for key in original_order:
            matches = kept_by_key.get(key)
            if matches:
                emit(matches.pop(0))

        # Nodes placed by neither rule above are not expected for a valid
        # graph; keep them (at the end) rather than silently dropping them.
        sorted_nodes.extend(node for nodes in kept_by_key.values() for node in nodes)
        sorted_nodes.extend(gelu for gelus in new_gelus_by_input.values() for gelu in gelus)

        fused = utils.rebuild_model(fused, sorted_nodes)
        check_model(fused)
        return fused
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Return a new model in which GELU subgraphs are fused into ``Gelu`` nodes.

    ``model`` itself is left unmodified; the fusion runs on a copy. The
    fusion pass does not keep the node list topologically sorted, so it is
    rebuilt: surviving nodes keep their relative order from ``model``, and
    each newly created ``Gelu`` is placed right after the node producing
    its input (or first, when no node produces that input, e.g. a graph
    input or initializer). Nested ``Gelu`` chains follow their producer.

    Args:
        model: the ONNX model to transform.

    Returns:
        The fused model, rebuilt by ``utils.rebuild_model`` and then passed
        through ``check_model``.
    """
    # The optimizer edits the ModelProto it wraps, so fuse a private copy;
    # the untouched ``model`` then serves as the reference node order.
    fused = onnx.ModelProto()
    fused.CopyFrom(model)
    optimizer = BertOnnxModel(fused)
    optimizer.fuse_gelu()
    fused = optimizer.model

    # Protobuf messages are unhashable, so nodes are keyed by their
    # deterministic serialization. This turns the per-node linear scan of a
    # repeated field into an O(1) set/dict lookup.
    def node_key(node):
        return node.SerializeToString(deterministic=True)

    original_order = [node_key(node) for node in model.graph.node]
    original_keys = set(original_order)
    produced = {name for node in model.graph.node for name in node.output}

    kept_by_key = {}          # surviving nodes of ``fused``, grouped by key
    new_gelus_by_input = {}   # Gelu nodes created by the fusion, by input[0]
    for node in fused.graph.node:
        key = node_key(node)
        # A Gelu that already existed before fusion is an ordinary
        # surviving node; treating it as new would emit it twice.
        if node.op_type == 'Gelu' and key not in original_keys:
            new_gelus_by_input.setdefault(node.input[0], []).append(node)
        else:
            kept_by_key.setdefault(key, []).append(node)

    sorted_nodes = []

    def emit(node):
        # Append ``node`` followed by every new Gelu consuming any of its
        # outputs, depth-first, so chains such as Gelu(Gelu(x)) stay sorted.
        pending = [node]
        while pending:
            current = pending.pop()
            sorted_nodes.append(current)
            consumers = [gelu for name in current.output
                         for gelu in new_gelus_by_input.pop(name, [])]
            pending.extend(reversed(consumers))

    # New Gelus fed by a value that no node produces (graph input or
    # initializer) are ready before any other node runs.
    for name in [name for name in new_gelus_by_input if name not in produced]:
        for gelu in new_gelus_by_input.pop(name, []):
            emit(gelu)

    # Surviving nodes keep their original relative order.
    for key in original_order:
        matches = kept_by_key.get(key)
        if matches:
            emit(matches.pop(0))

    # Nodes placed by neither rule above are not expected for a valid
    # graph; keep them (at the end) rather than silently dropping them.
    sorted_nodes.extend(node for nodes in kept_by_key.values() for node in nodes)
    sorted_nodes.extend(gelu for gelus in new_gelus_by_input.values() for gelu in gelus)

    fused = utils.rebuild_model(fused, sorted_nodes)
    check_model(fused)
    return fused