Module furiosa.quantizer.frontend.onnx.transformer.convert_2d_sum_to_add
Expand source code
import onnx
from onnx.helper import make_node, make_tensor, TensorProto
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
class Convert2dSumToAdd(Transformer):
    """Rewrite every two-input ``Sum`` node of an ONNX graph as an ``Add`` node.

    ``Sum`` nodes with any other number of inputs, and all non-``Sum`` nodes,
    are passed through unchanged.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return ``model`` with each two-input ``Sum`` node replaced by ``Add``.

        The replacement ``Add`` node reuses the two inputs, the first output
        and the name of the ``Sum`` node it replaces. After the rewrite, exact
        duplicate nodes are dropped (first occurrence kept, order preserved).
        The model is then rebuilt with ``utils.rebuild_model`` and passed to
        ``check_model``.

        Args:
            model: The ONNX model to transform.

        Returns:
            The model returned by ``utils.rebuild_model``.
        """
        optimized_nodes = []
        for node in model.graph.node:
            if node.op_type == 'Sum' and len(node.input) == 2:
                # Carry the original node name over; without it every
                # converted node would end up unnamed.
                new_node = make_node(
                    'Add',
                    inputs=[node.input[0], node.input[1]],
                    outputs=[node.output[0]],
                    name=node.name,
                )
                optimized_nodes.append(new_node)
            else:
                optimized_nodes.append(node)

        # Remove duplicate node(s) while preserving order. Protobuf messages
        # are unhashable, so key on their serialized bytes instead of scanning
        # a list with `in` (which is quadratic in the number of nodes).
        seen_keys = set()
        unique_nodes = []
        for op_node in optimized_nodes:
            key = op_node.SerializeToString()
            if key in seen_keys:
                continue
            seen_keys.add(key)
            unique_nodes.append(op_node)

        model = utils.rebuild_model(model, unique_nodes)
        check_model(model)
        return model
Classes
class Convert2dSumToAdd (*args, **kwds)
-
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
Expand source code
class Convert2dSumToAdd(Transformer):
    """Rewrite every two-input ``Sum`` node of an ONNX graph as an ``Add`` node.

    ``Sum`` nodes with any other number of inputs, and all non-``Sum`` nodes,
    are passed through unchanged.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return ``model`` with each two-input ``Sum`` node replaced by ``Add``.

        The replacement ``Add`` node reuses the two inputs, the first output
        and the name of the ``Sum`` node it replaces. After the rewrite, exact
        duplicate nodes are dropped (first occurrence kept, order preserved).
        The model is then rebuilt with ``utils.rebuild_model`` and passed to
        ``check_model``.

        Args:
            model: The ONNX model to transform.

        Returns:
            The model returned by ``utils.rebuild_model``.
        """
        optimized_nodes = []
        for node in model.graph.node:
            if node.op_type == 'Sum' and len(node.input) == 2:
                # Carry the original node name over; without it every
                # converted node would end up unnamed.
                new_node = make_node(
                    'Add',
                    inputs=[node.input[0], node.input[1]],
                    outputs=[node.output[0]],
                    name=node.name,
                )
                optimized_nodes.append(new_node)
            else:
                optimized_nodes.append(node)

        # Remove duplicate node(s) while preserving order. Protobuf messages
        # are unhashable, so key on their serialized bytes instead of scanning
        # a list with `in` (which is quadratic in the number of nodes).
        seen_keys = set()
        unique_nodes = []
        for op_node in optimized_nodes:
            key = op_node.SerializeToString()
            if key in seen_keys:
                continue
            seen_keys.add(key)
            unique_nodes.append(op_node)

        model = utils.rebuild_model(model, unique_nodes)
        check_model(model)
        return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Return ``model`` with each two-input ``Sum`` node replaced by ``Add``.

    The replacement ``Add`` node reuses the two inputs, the first output and
    the name of the ``Sum`` node it replaces. After the rewrite, exact
    duplicate nodes are dropped (first occurrence kept, order preserved).
    The model is then rebuilt with ``utils.rebuild_model`` and passed to
    ``check_model``.

    Args:
        model: The ONNX model to transform.

    Returns:
        The model returned by ``utils.rebuild_model``.
    """
    optimized_nodes = []
    for node in model.graph.node:
        if node.op_type == 'Sum' and len(node.input) == 2:
            # Carry the original node name over; without it every converted
            # node would end up unnamed.
            new_node = make_node(
                'Add',
                inputs=[node.input[0], node.input[1]],
                outputs=[node.output[0]],
                name=node.name,
            )
            optimized_nodes.append(new_node)
        else:
            optimized_nodes.append(node)

    # Remove duplicate node(s) while preserving order. Protobuf messages are
    # unhashable, so key on their serialized bytes instead of scanning a list
    # with `in` (which is quadratic in the number of nodes).
    seen_keys = set()
    unique_nodes = []
    for op_node in optimized_nodes:
        key = op_node.SerializeToString()
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_nodes.append(op_node)

    model = utils.rebuild_model(model, unique_nodes)
    check_model(model)
    return model