Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.transformer.eliminate_argmax_output
Expand source code
import onnx
import abc
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
class EliminateArgmaxOutput(Transformer):
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
for transformer in [
Pattern_1,
]:
model = transformer(model).transform()
return model
class Pattern_1(ONNXTransformer, abc.ABC):
"""
transform
prev --> Argmax --> next
to
prev --> ( ) --> next
if next is one of graph outputs
"""
def pattern_matching(self, base_node):
inputs = base_node.input
pattern_to_match = ['ArgMax']
matched_nodes = self.pattern_matcher(base_node, pattern_to_match)
if not matched_nodes:
return inputs
if not self.pattern_condition_checker(matched_nodes):
return inputs
self.transform_to_eliminate(matched_nodes, base_node.input[0])
return inputs
def pattern_condition_checker(self, nodes_to_check):
node = nodes_to_check[0]
if any(node.output[0] == output for output in self.graph_output_map):
return True
return False
Classes
class EliminateArgmaxOutput (*args, **kwds)
-
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def getitem(self, key: KT) -> VT: … # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code
class EliminateArgmaxOutput(Transformer): def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, ]: model = transformer(model).transform() return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_1, ]: model = transformer(model).transform() return model
class Pattern_1 (model)
-
transform prev –> Argmax –> next to prev –> ( ) –> next if next is one of graph outputs
Expand source code
class Pattern_1(ONNXTransformer, abc.ABC): """ transform prev --> Argmax --> next to prev --> ( ) --> next if next is one of graph outputs """ def pattern_matching(self, base_node): inputs = base_node.input pattern_to_match = ['ArgMax'] matched_nodes = self.pattern_matcher(base_node, pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs self.transform_to_eliminate(matched_nodes, base_node.input[0]) return inputs def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if any(node.output[0] == output for output in self.graph_output_map): return True return False
Ancestors
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Methods
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): node = nodes_to_check[0] if any(node.output[0] == output for output in self.graph_output_map): return True return False
def pattern_matching(self, base_node)
-
Expand source code
def pattern_matching(self, base_node): inputs = base_node.input pattern_to_match = ['ArgMax'] matched_nodes = self.pattern_matcher(base_node, pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs self.transform_to_eliminate(matched_nodes, base_node.input[0]) return inputs