Module furiosa.quantizer.frontend.onnx.transformer.convert_2d_sum_to_add

Expand source code
import onnx

from onnx.helper import make_node, make_tensor, TensorProto

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model


class Convert2dSumToAdd(Transformer):
    """Rewrite each two-input ONNX ``Sum`` node as an equivalent ``Add`` node.

    ``Sum`` nodes with any other number of inputs, ``Sum`` nodes outside the
    default ONNX domain, and all other ops are passed through unchanged.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Replace two-input ``Sum`` nodes in *model* with ``Add`` nodes.

        Args:
            model: The ONNX model whose graph nodes are scanned.

        Returns:
            The model produced by ``utils.rebuild_model`` from the rewritten
            node list, after it has been validated by ``check_model``.
            (NOTE(review): whether ``rebuild_model`` mutates *model* in place
            is not visible here — confirm before relying on the input.)

        Raises:
            Whatever ``check_model`` raises when the rebuilt model is invalid.
        """
        optimized_nodes = []
        for node in model.graph.node:
            # Only a default-domain Sum with exactly two operands maps 1:1 onto
            # Add (both ops use the same numpy-style broadcasting rules).
            is_binary_sum = (
                node.op_type == 'Sum'
                and node.domain in ('', 'ai.onnx')
                and len(node.input) == 2
            )
            if not is_binary_sum:
                optimized_nodes.append(node)
                continue

            # Keep the original node's name/doc_string so downstream tooling
            # that refers to nodes by name still finds this one.
            optimized_nodes.append(
                make_node(
                    'Add',
                    inputs=[node.input[0], node.input[1]],
                    outputs=[node.output[0]],
                    name=node.name,
                    doc_string=node.doc_string,
                )
            )

        # Remove exact duplicate nodes, keeping first-occurrence order.
        # NodeProto messages are unhashable, so bucket them by output names
        # (equal nodes necessarily have equal outputs) and only run the full
        # protobuf equality check within a bucket instead of against every
        # previously seen node.
        unique_nodes = []
        buckets = {}
        for op_node in optimized_nodes:
            bucket = buckets.setdefault(tuple(op_node.output), [])
            if op_node in bucket:
                continue
            bucket.append(op_node)
            unique_nodes.append(op_node)

        model = utils.rebuild_model(model, unique_nodes)
        check_model(model)

        return model

Classes

class Convert2dSumToAdd (*args, **kwds)

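Abstract base class for generic types. (This docstring is inherited from `typing.Generic`, because `Convert2dSumToAdd` defines no docstring of its own. The class itself rewrites two-input ONNX `Sum` nodes as `Add` nodes.)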

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
        # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

Expand source code
class Convert2dSumToAdd(Transformer):
    """Rewrite each two-input ONNX ``Sum`` node as an equivalent ``Add`` node.

    ``Sum`` nodes with any other number of inputs, ``Sum`` nodes outside the
    default ONNX domain, and all other ops are passed through unchanged.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Replace two-input ``Sum`` nodes in *model* with ``Add`` nodes.

        Args:
            model: The ONNX model whose graph nodes are scanned.

        Returns:
            The model produced by ``utils.rebuild_model`` from the rewritten
            node list, after it has been validated by ``check_model``.
            (NOTE(review): whether ``rebuild_model`` mutates *model* in place
            is not visible here — confirm before relying on the input.)

        Raises:
            Whatever ``check_model`` raises when the rebuilt model is invalid.
        """
        optimized_nodes = []
        for node in model.graph.node:
            # Only a default-domain Sum with exactly two operands maps 1:1 onto
            # Add (both ops use the same numpy-style broadcasting rules).
            is_binary_sum = (
                node.op_type == 'Sum'
                and node.domain in ('', 'ai.onnx')
                and len(node.input) == 2
            )
            if not is_binary_sum:
                optimized_nodes.append(node)
                continue

            # Keep the original node's name/doc_string so downstream tooling
            # that refers to nodes by name still finds this one.
            optimized_nodes.append(
                make_node(
                    'Add',
                    inputs=[node.input[0], node.input[1]],
                    outputs=[node.output[0]],
                    name=node.name,
                    doc_string=node.doc_string,
                )
            )

        # Remove exact duplicate nodes, keeping first-occurrence order.
        # NodeProto messages are unhashable, so bucket them by output names
        # (equal nodes necessarily have equal outputs) and only run the full
        # protobuf equality check within a bucket instead of against every
        # previously seen node.
        unique_nodes = []
        buckets = {}
        for op_node in optimized_nodes:
            bucket = buckets.setdefault(tuple(op_node.output), [])
            if op_node in bucket:
                continue
            bucket.append(op_node)
            unique_nodes.append(op_node)

        model = utils.rebuild_model(model, unique_nodes)
        check_model(model)

        return model

Ancestors

  • furiosa_sdk_quantizer.interfaces.transformer.Transformer
  • typing.Generic

Methods

def transform(self, model: onnx.onnx_ml_pb2.ModelProto) -> onnx.onnx_ml_pb2.ModelProto
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Replace two-input ``Sum`` nodes in *model* with ``Add`` nodes.

    ``Sum`` nodes with any other number of inputs, ``Sum`` nodes outside the
    default ONNX domain, and all other ops are passed through unchanged.

    Args:
        model: The ONNX model whose graph nodes are scanned.

    Returns:
        The model produced by ``utils.rebuild_model`` from the rewritten node
        list, after it has been validated by ``check_model``.
        (NOTE(review): whether ``rebuild_model`` mutates *model* in place is
        not visible here — confirm before relying on the input.)

    Raises:
        Whatever ``check_model`` raises when the rebuilt model is invalid.
    """
    optimized_nodes = []
    for node in model.graph.node:
        # Only a default-domain Sum with exactly two operands maps 1:1 onto
        # Add (both ops use the same numpy-style broadcasting rules).
        is_binary_sum = (
            node.op_type == 'Sum'
            and node.domain in ('', 'ai.onnx')
            and len(node.input) == 2
        )
        if not is_binary_sum:
            optimized_nodes.append(node)
            continue

        # Keep the original node's name/doc_string so downstream tooling that
        # refers to nodes by name still finds this one.
        optimized_nodes.append(
            make_node(
                'Add',
                inputs=[node.input[0], node.input[1]],
                outputs=[node.output[0]],
                name=node.name,
                doc_string=node.doc_string,
            )
        )

    # Remove exact duplicate nodes, keeping first-occurrence order.
    # NodeProto messages are unhashable, so bucket them by output names
    # (equal nodes necessarily have equal outputs) and only run the full
    # protobuf equality check within a bucket instead of against every
    # previously seen node.
    unique_nodes = []
    buckets = {}
    for op_node in optimized_nodes:
        bucket = buckets.setdefault(tuple(op_node.output), [])
        if op_node in bucket:
            continue
        bucket.append(op_node)
        unique_nodes.append(op_node)

    model = utils.rebuild_model(model, unique_nodes)
    check_model(model)

    return model