Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.transformer.deprecated.eliminate_identity
Expand source code
import onnx
from onnx.helper import make_model
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
class EliminateIdentity(Transformer):
    """Remove ``Identity`` nodes from an ONNX graph.

    from: X -> Identity -> Y
    to: X -> Y

    from: X -> Identity -> graph output
    to: X -> graph output
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Bypass removable Identity nodes and return the rebuilt model.

        Consumers of an Identity output are rewired, on every one of their
        inputs, to the tensor that the Identity forwards (chains of Identity
        nodes are followed to their origin). An Identity that feeds a graph
        output is removed by renaming that graph output, in place so the
        order of graph outputs is kept, to the forwarded tensor. Such an
        Identity is kept when the forwarded tensor is not produced by another
        node (e.g. it is a graph input or initializer) or is already a graph
        output, since renaming would then alias the output to a non-node
        tensor or create a duplicate output name.

        ``model`` is modified in place before it is handed to
        ``utils.rebuild_model``; the rebuilt model is validated with
        ``check_model`` and returned.

        NOTE(review): only ``model.graph.node`` is rewired here; consumers
        inside subgraph attributes (If/Loop bodies) are not visited - confirm
        this is acceptable for the models this transformer is applied to.
        """
        graph = model.graph

        def is_identity(node):
            # An Identity lacking an input or an output cannot be bypassed,
            # so it is treated like any other node.
            return (
                node.op_type == 'Identity'
                and len(node.input) > 0
                and len(node.output) > 0
            )

        # Identity output name -> name of the tensor it forwards.
        forwarded = {
            node.output[0]: node.input[0]
            for node in graph.node
            if is_identity(node)
        }
        # Every tensor produced by a non-Identity node (all outputs, not
        # only the first one of each node).
        produced = {
            name
            for node in graph.node
            if not is_identity(node)
            for name in node.output
        }
        outputs_by_name = {output.name: output for output in graph.output}

        def resolve(name):
            # Follow a chain of Identity nodes back to its origin; the
            # visited set guards against a malformed cyclic graph.
            visited = set()
            while name in forwarded and name not in visited:
                visited.add(name)
                name = forwarded[name]
            return name

        # Rewire every input of every node past the Identity nodes. Identity
        # nodes themselves are rewired too, so a kept Identity at the end of
        # a chain reads directly from the chain's origin.
        for node in graph.node:
            for index, name in enumerate(list(node.input)):
                source = resolve(name)
                if source != name:
                    node.input[index] = source

        new_nodes = []
        for node in graph.node:
            if not is_identity(node):
                new_nodes.append(node)
                continue

            output_info = outputs_by_name.get(node.output[0])
            if output_info is None:
                # Identity in the middle of the graph: its consumers were
                # rewired above, so the node is simply dropped.
                continue

            source = node.input[0]
            if source not in produced or source in outputs_by_name:
                # Renaming the graph output is not safe here; keep the node.
                new_nodes.append(node)
                continue

            # Identity at the end of the graph: expose the forwarded tensor
            # as the graph output instead, keeping the output's position.
            del outputs_by_name[node.output[0]]
            output_info.name = source
            outputs_by_name[source] = output_info

        model = utils.rebuild_model(model, new_nodes)
        check_model(model)
        return model
Classes
class EliminateIdentity (*args, **kwds)
-
from: X -> Identity -> graph output to: X -> graph output
Expand source code
class EliminateIdentity(Transformer):
    """Remove ``Identity`` nodes from an ONNX graph.

    from: X -> Identity -> Y
    to: X -> Y

    from: X -> Identity -> graph output
    to: X -> graph output
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Bypass removable Identity nodes and return the rebuilt model.

        Consumers of an Identity output are rewired, on every one of their
        inputs, to the tensor that the Identity forwards (chains of Identity
        nodes are followed to their origin). An Identity that feeds a graph
        output is removed by renaming that graph output, in place so the
        order of graph outputs is kept, to the forwarded tensor. Such an
        Identity is kept when the forwarded tensor is not produced by another
        node (e.g. it is a graph input or initializer) or is already a graph
        output, since renaming would then alias the output to a non-node
        tensor or create a duplicate output name.

        ``model`` is modified in place before it is handed to
        ``utils.rebuild_model``; the rebuilt model is validated with
        ``check_model`` and returned.

        NOTE(review): only ``model.graph.node`` is rewired here; consumers
        inside subgraph attributes (If/Loop bodies) are not visited - confirm
        this is acceptable for the models this transformer is applied to.
        """
        graph = model.graph

        def is_identity(node):
            # An Identity lacking an input or an output cannot be bypassed,
            # so it is treated like any other node.
            return (
                node.op_type == 'Identity'
                and len(node.input) > 0
                and len(node.output) > 0
            )

        # Identity output name -> name of the tensor it forwards.
        forwarded = {
            node.output[0]: node.input[0]
            for node in graph.node
            if is_identity(node)
        }
        # Every tensor produced by a non-Identity node (all outputs, not
        # only the first one of each node).
        produced = {
            name
            for node in graph.node
            if not is_identity(node)
            for name in node.output
        }
        outputs_by_name = {output.name: output for output in graph.output}

        def resolve(name):
            # Follow a chain of Identity nodes back to its origin; the
            # visited set guards against a malformed cyclic graph.
            visited = set()
            while name in forwarded and name not in visited:
                visited.add(name)
                name = forwarded[name]
            return name

        # Rewire every input of every node past the Identity nodes. Identity
        # nodes themselves are rewired too, so a kept Identity at the end of
        # a chain reads directly from the chain's origin.
        for node in graph.node:
            for index, name in enumerate(list(node.input)):
                source = resolve(name)
                if source != name:
                    node.input[index] = source

        new_nodes = []
        for node in graph.node:
            if not is_identity(node):
                new_nodes.append(node)
                continue

            output_info = outputs_by_name.get(node.output[0])
            if output_info is None:
                # Identity in the middle of the graph: its consumers were
                # rewired above, so the node is simply dropped.
                continue

            source = node.input[0]
            if source not in produced or source in outputs_by_name:
                # Renaming the graph output is not safe here; keep the node.
                new_nodes.append(node)
                continue

            # Identity at the end of the graph: expose the forwarded tensor
            # as the graph output instead, keeping the output's position.
            del outputs_by_name[node.output[0]]
            output_info.name = source
            outputs_by_name[source] = output_info

        model = utils.rebuild_model(model, new_nodes)
        check_model(model)
        return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Bypass removable Identity nodes and return the rebuilt model.

    Consumers of an Identity output are rewired, on every one of their
    inputs, to the tensor that the Identity forwards (chains of Identity
    nodes are followed to their origin). An Identity that feeds a graph
    output is removed by renaming that graph output, in place so the
    order of graph outputs is kept, to the forwarded tensor. Such an
    Identity is kept when the forwarded tensor is not produced by another
    node (e.g. it is a graph input or initializer) or is already a graph
    output, since renaming would then alias the output to a non-node
    tensor or create a duplicate output name.

    ``model`` is modified in place before it is handed to
    ``utils.rebuild_model``; the rebuilt model is validated with
    ``check_model`` and returned.

    NOTE(review): only ``model.graph.node`` is rewired here; consumers
    inside subgraph attributes (If/Loop bodies) are not visited - confirm
    this is acceptable for the models this transformer is applied to.
    """
    graph = model.graph

    def is_identity(node):
        # An Identity lacking an input or an output cannot be bypassed,
        # so it is treated like any other node.
        return (
            node.op_type == 'Identity'
            and len(node.input) > 0
            and len(node.output) > 0
        )

    # Identity output name -> name of the tensor it forwards.
    forwarded = {
        node.output[0]: node.input[0]
        for node in graph.node
        if is_identity(node)
    }
    # Every tensor produced by a non-Identity node (all outputs, not only
    # the first one of each node).
    produced = {
        name
        for node in graph.node
        if not is_identity(node)
        for name in node.output
    }
    outputs_by_name = {output.name: output for output in graph.output}

    def resolve(name):
        # Follow a chain of Identity nodes back to its origin; the visited
        # set guards against a malformed cyclic graph.
        visited = set()
        while name in forwarded and name not in visited:
            visited.add(name)
            name = forwarded[name]
        return name

    # Rewire every input of every node past the Identity nodes. Identity
    # nodes themselves are rewired too, so a kept Identity at the end of a
    # chain reads directly from the chain's origin.
    for node in graph.node:
        for index, name in enumerate(list(node.input)):
            source = resolve(name)
            if source != name:
                node.input[index] = source

    new_nodes = []
    for node in graph.node:
        if not is_identity(node):
            new_nodes.append(node)
            continue

        output_info = outputs_by_name.get(node.output[0])
        if output_info is None:
            # Identity in the middle of the graph: its consumers were
            # rewired above, so the node is simply dropped.
            continue

        source = node.input[0]
        if source not in produced or source in outputs_by_name:
            # Renaming the graph output is not safe here; keep the node.
            new_nodes.append(node)
            continue

        # Identity at the end of the graph: expose the forwarded tensor as
        # the graph output instead, keeping the output's position.
        del outputs_by_name[node.output[0]]
        output_info.name = source
        outputs_by_name[source] = output_info

    model = utils.rebuild_model(model, new_nodes)
    check_model(model)
    return model