Module furiosa.quantizer.frontend.onnx.transformer.fuse_redundant_reshape_pattern
Expand source code
import abc
import onnx
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer
class FuseRedundantReshapePattern(Transformer):
    """Collapse chains of shape-only nodes into a single Reshape.

    Each fusion pass is instantiated on the model produced by the previous
    pass, in this order: the three-Reshape pattern (Pattern_2), the
    two-Reshape pattern (Pattern_1), then Flatten/Squeeze -> Unsqueeze
    (Pattern_3).
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Run every fusion pass in sequence and return the final model."""
        fusion_passes = (Pattern_2, Pattern_1, Pattern_3)
        for fusion_pass in fusion_passes:
            model = fusion_pass(model).transform()
        return model
class Pattern_1(ONNXTransformer, abc.ABC):
    """
    transform
        prev --> Reshape --> Reshape --> next
    to
        prev --> Reshape --> next
    if prev.output[0].shape != next.input[0].shape

    The fused Reshape keeps the first matched node's name and data input,
    writes to the last matched node's output, and reads its target shape from
    a newly created int64 initializer.
    """

    pattern_to_match = ['Reshape', 'Reshape']
    # Suffix appended to a tensor name to build the new shape-initializer name.
    postfix = '_reshape_fused'

    def pattern_matching(self, base_node):
        """Try to fuse the chain described by ``pattern_to_match`` ending at ``base_node``.

        Returns ``base_node.input`` when nothing is fused; otherwise returns the
        inputs of the first (top) node of the fused chain.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node = matched_nodes[0]
        new_node = self.make_new_node(matched_nodes)
        new_init = self.make_new_init(matched_nodes)
        # Build the optional value info exactly once: subclasses (e.g. the
        # Flatten/Squeeze variant) copy value info here, so calling it twice
        # just to test for None would duplicate that work.
        new_vi = self.make_new_vi(matched_nodes)
        self.transform_to_fuse(matched_nodes,
                               nodes_to_add=[new_node],
                               inits_to_add=[new_init],
                               vis_to_add=[new_vi] if new_vi else None)
        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True when the chain actually changes the tensor's shape.

        A chain whose first input and last output already have the same shape
        is not fused by this pass.
        """
        top_node = nodes_to_check[0]
        base_node = nodes_to_check[-1]
        return not self.is_same_shape(top_node.input[0], base_node.output[0])

    def _fused_shape_name(self, matched_nodes):
        """Return the name of the new shape initializer for this fused chain."""
        # Derived from the top node's output, which ONNX guarantees is produced
        # by exactly one node. The top node's shape input (input[1]) may be
        # shared by several Reshape nodes; deriving the name from it would give
        # two fused chains the same initializer name with different values.
        return matched_nodes[0].output[0] + self.postfix

    def make_new_node(self, matched_nodes):
        """Build the single Reshape node that replaces the whole matched chain."""
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        return self.make_node('Reshape',
                              [top_node.input[0], self._fused_shape_name(matched_nodes)],
                              [base_node.output[0]],
                              name=top_node.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 shape initializer consumed by the fused Reshape."""
        base_node = matched_nodes[-1]
        # NOTE(review): presumably make_int64_initializer fills the values from
        # the known shape of base_node.output[0] -- confirm in ONNXTransformer.
        return self.make_int64_initializer(self._fused_shape_name(matched_nodes),
                                           base_node.output[0])

    def make_new_vi(self, matched_nodes):
        """Return extra value info to register after fusing; none is needed here."""
        return None
class Pattern_2(Pattern_1, abc.ABC):
    """
    transform
        prev --> Reshape --> Reshape --> Reshape --> next
    to
        prev --> Reshape --> next
    if prev.output[0].shape != next.input[0].shape

    Only the matched pattern differs from Pattern_1; node, initializer and
    condition handling are all inherited.
    """

    # Three consecutive Reshape nodes.
    pattern_to_match = ['Reshape'] * 3
class Pattern_3(Pattern_1, abc.ABC):
    """
    transform
        prev --> Flatten/Squeeze --> Unsqueeze --> next
    to
        prev --> Reshape --> next
    if prev.output[0].shape != next.input[0].shape

    The top node may have no shape input (e.g. Flatten has a single input),
    so the shape-initializer name is not derived from ``input[1]`` as in
    Pattern_1.
    """

    pattern_to_match = ['Flatten/Squeeze', 'Unsqueeze']

    def _fused_shape_name(self, matched_nodes):
        """Return the name of the new shape initializer for this fused chain."""
        # Use the top node's output, which is unique in an ONNX graph. The top
        # node's data input may feed several Flatten/Squeeze nodes; a name
        # derived from it would collide between two fused chains.
        return matched_nodes[0].output[0] + self.postfix

    def make_new_node(self, matched_nodes):
        """Build the Reshape that replaces the Flatten/Squeeze -> Unsqueeze pair."""
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        # Pass the node name by keyword, matching Pattern_1.make_new_node.
        return self.make_node('Reshape',
                              [top_node.input[0], self._fused_shape_name(matched_nodes)],
                              [base_node.output[0]],
                              name=top_node.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 shape initializer consumed by the fused Reshape."""
        base_node = matched_nodes[-1]
        # NOTE(review): presumably the values come from the known shape of
        # base_node.output[0] -- confirm in ONNXTransformer.make_int64_initializer.
        return self.make_int64_initializer(self._fused_shape_name(matched_nodes),
                                           base_node.output[0])

    def make_new_vi(self, matched_nodes):
        """Return a copy of the value info for the chain's data input."""
        top_node = matched_nodes[0]
        # NOTE(review): presumably re-registered so the fused Reshape's data input
        # keeps its shape info after transform_to_fuse -- confirm.
        return self.copy_value_info(top_node.input[0])
Classes
class FuseRedundantReshapePattern (*args, **kwds)
-
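Fuses redundant reshape-like node chains by applying the Pattern_2, Pattern_1 and Pattern_3 passes in that order.
(The text below is the docstring inherited from `typing.Generic`.)
Abstract base class for generic types.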
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: … # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code
class FuseRedundantReshapePattern(Transformer): def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_2, Pattern_1, Pattern_3, ]: model = transformer(model).transform() return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto: for transformer in [ Pattern_2, Pattern_1, Pattern_3, ]: model = transformer(model).transform() return model
class Pattern_1 (model)
-
transform prev --> Reshape --> Reshape --> next to prev --> Reshape --> next
if prev.output[0].shape != next.input[0].shape
Expand source code
class Pattern_1(ONNXTransformer, abc.ABC): """ transform prev --> Reshape --> Reshape --> next to prev --> Reshape --> next if prev.output[0].shape != next.input[0].shape """ pattern_to_match = ['Reshape', 'Reshape'] postfix = '_reshape_fused' def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[self.make_new_node(matched_nodes)], inits_to_add=[self.make_new_init(matched_nodes)], vis_to_add=[self.make_new_vi(matched_nodes)] if self.make_new_vi( matched_nodes) else None) return top_node.input def pattern_condition_checker(self, nodes_to_check): top_node = nodes_to_check[0] base_node = nodes_to_check[-1] if self.is_same_shape(top_node.input[0], base_node.output[0]): return False return True def make_new_node(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_node('Reshape', [top_node.input[0], top_node.input[1] + self.postfix], [base_node.output[0]], name=top_node.name) def make_new_init(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_int64_initializer(top_node.input[1] + self.postfix, base_node.output[0]) def make_new_vi(self, matched_nodes): return None
Ancestors
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Subclasses
- Pattern_2
- Pattern_3
Class variables
var pattern_to_match
var postfix
Methods
def make_new_init(self, matched_nodes)
-
Expand source code
def make_new_init(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_int64_initializer(top_node.input[1] + self.postfix, base_node.output[0])
def make_new_node(self, matched_nodes)
-
Expand source code
def make_new_node(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_node('Reshape', [top_node.input[0], top_node.input[1] + self.postfix], [base_node.output[0]], name=top_node.name)
def make_new_vi(self, matched_nodes)
-
Expand source code
def make_new_vi(self, matched_nodes): return None
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check): top_node = nodes_to_check[0] base_node = nodes_to_check[-1] if self.is_same_shape(top_node.input[0], base_node.output[0]): return False return True
def pattern_matching(self, base_node)
-
Expand source code
def pattern_matching(self, base_node): inputs = base_node.input matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match) if not matched_nodes: return inputs if not self.pattern_condition_checker(matched_nodes): return inputs top_node = matched_nodes[0] self.transform_to_fuse(matched_nodes, nodes_to_add=[self.make_new_node(matched_nodes)], inits_to_add=[self.make_new_init(matched_nodes)], vis_to_add=[self.make_new_vi(matched_nodes)] if self.make_new_vi( matched_nodes) else None) return top_node.input
class Pattern_2 (model)
-
transform prev --> Reshape --> Reshape --> Reshape --> next to prev --> Reshape --> next
if prev.output[0].shape != next.input[0].shape
Expand source code
class Pattern_2(Pattern_1, abc.ABC): """ transform prev --> Reshape --> Reshape --> Reshape --> next to prev --> Reshape --> next if prev.output[0].shape != next.input[0].shape """ pattern_to_match = ['Reshape', 'Reshape', 'Reshape']
Ancestors
- Pattern_1
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Class variables
var pattern_to_match
class Pattern_3 (model)
-
transform prev --> Flatten/Squeeze --> Unsqueeze --> next to prev --> Reshape --> next if prev.output[0].shape != next.input[0].shape
Expand source code
class Pattern_3(Pattern_1, abc.ABC): """ transform prev --> Flatten/Squeeze --> Unsqueeze --> next to prev --> Reshape --> next if prev.output[0].shape != next.input[0].shape """ pattern_to_match = ['Flatten/Squeeze', 'Unsqueeze'] def make_new_node(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_node('Reshape', [top_node.input[0], top_node.input[0] + self.postfix], [base_node.output[0]], top_node.name) def make_new_init(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_int64_initializer(top_node.input[0] + self.postfix, base_node.output[0]) def make_new_vi(self, matched_nodes): top_node = matched_nodes[0] return self.copy_value_info(top_node.input[0])
Ancestors
- Pattern_1
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Class variables
var pattern_to_match
Methods
def make_new_init(self, matched_nodes)
-
Expand source code
def make_new_init(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_int64_initializer(top_node.input[0] + self.postfix, base_node.output[0])
def make_new_node(self, matched_nodes)
-
Expand source code
def make_new_node(self, matched_nodes): top_node = matched_nodes[0] base_node = matched_nodes[-1] return self.make_node('Reshape', [top_node.input[0], top_node.input[0] + self.postfix], [base_node.output[0]], top_node.name)
def make_new_vi(self, matched_nodes)
-
Expand source code
def make_new_vi(self, matched_nodes): top_node = matched_nodes[0] return self.copy_value_info(top_node.input[0])