Module furiosa.quantizer.frontend.onnx.transformer.fuse_depth_to_space
Expand source code
import abc
import onnx
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer
class FuseDepthToSpace(Transformer):
    """Fuse Reshape -> Transpose -> Reshape chains into DepthToSpace nodes.

    Each pattern transformer is applied in turn. The model returned by one
    pattern is the input to the next.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Run every DepthToSpace fusion pattern over ``model`` and return the result."""
        patterns = (Pattern_1,)
        for pattern_cls in patterns:
            model = pattern_cls(model).transform()
        return model
class Pattern_1(ONNXTransformer, abc.ABC):
    """
    transform
        prev --> Reshape --> Transpose --> Reshape --> next
    to
        prev --> DepthToSpace --> next

    if Transpose.perm == [0, 1, 4, 2, 5, 3] (mode 'CRD') or
    Transpose.perm == [0, 3, 4, 1, 5, 2] (mode 'DCR'), and the two block
    dimensions in the first Reshape's output shape are equal (DepthToSpace
    only supports a single, square blocksize).
    """

    # Maps each Transpose permutation that appears in the ONNX reference
    # decomposition of DepthToSpace to (mode, index of the first block
    # dimension in the first Reshape's output shape). The second block
    # dimension is at the following index.
    _PERM_TO_MODE = {
        (0, 1, 4, 2, 5, 3): ('CRD', 2),
        (0, 3, 4, 1, 5, 2): ('DCR', 1),
    }

    @staticmethod
    def _get_perm(node):
        """Return the node's ``perm`` attribute as a tuple, or None if it has none."""
        # Look up by name: Transpose's ``perm`` is optional, so it is not
        # guaranteed to be attribute[0] (or present at all).
        for attr in node.attribute:
            if attr.name == 'perm':
                return tuple(attr.ints)
        return None

    def pattern_matching(self, base_node):
        """Try to fuse the Reshape -> Transpose -> Reshape chain ending at ``base_node``.

        Returns ``base_node.input`` when nothing is fused. Otherwise returns the
        inputs of the first Reshape, whose data input feeds the new DepthToSpace node.
        """
        matched = self.pattern_matcher(base_node, ['Reshape', 'Transpose', 'Reshape'])
        if not (matched and self.pattern_condition_checker(matched)):
            return base_node.input
        first_reshape, transpose, _ = matched
        depth_to_space = self.make_node(
            'DepthToSpace',
            [first_reshape.input[0]],
            [base_node.output[0]],
            first_reshape.name,
            **self.get_attrs(first_reshape, transpose),
        )
        self.transform_to_fuse(matched, nodes_to_add=[depth_to_space])
        return first_reshape.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True if the matched nodes can be replaced by one DepthToSpace.

        Requires the Transpose perm to be one of the two DepthToSpace layouts,
        and the two block dimensions produced by the first Reshape to be equal.
        A Transpose without a ``perm`` attribute never matches.
        """
        top_node, mid_node, _ = nodes_to_check
        spec = self._PERM_TO_MODE.get(self._get_perm(mid_node))
        if spec is None:
            return False
        _, block_dim = spec
        reshaped_shape = self.get_value_info_shape(top_node.output[0])
        # DepthToSpace has a single blocksize, so non-square blocks cannot be fused.
        return bool(reshaped_shape[block_dim] == reshaped_shape[block_dim + 1])

    def get_attrs(self, top_node, mid_node):
        """Build the ``blocksize``/``mode`` attributes of the fused DepthToSpace node.

        Raises:
            ValueError: if the Transpose perm is not a DepthToSpace layout.
        """
        perm = self._get_perm(mid_node)
        reshaped_shape = self.get_value_info_shape(top_node.output[0])
        spec = self._PERM_TO_MODE.get(perm)
        if spec is None:
            raise ValueError(f'Transpose perm {perm} does not describe a DepthToSpace pattern')
        mode, block_dim = spec
        return {'blocksize': reshaped_shape[block_dim], 'mode': mode}
Classes
class FuseDepthToSpace (*args, **kwds)
-
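Abstract base class for generic types. (This docstring is inherited from `typing.Generic`, one of the ancestors listed below; it is not specific to `FuseDepthToSpace`.)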
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code
class FuseDepthToSpace(Transformer):
    """Fuse Reshape -> Transpose -> Reshape chains into DepthToSpace nodes.

    Each pattern transformer is applied in turn. The model returned by one
    pattern is the input to the next.
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Run every DepthToSpace fusion pattern over ``model`` and return the result."""
        patterns = (Pattern_1,)
        for pattern_cls in patterns:
            model = pattern_cls(model).transform()
        return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Run every DepthToSpace fusion pattern over ``model`` and return the result."""
    patterns = (Pattern_1,)
    for pattern_cls in patterns:
        model = pattern_cls(model).transform()
    return model
class Pattern_1 (model)
-
Transform `prev --> Reshape --> Transpose --> Reshape --> next` into `prev --> DepthToSpace --> next`
when `Transpose.perm` is `[0, 1, 4, 2, 5, 3]` or `[0, 3, 4, 1, 5, 2]`.
Expand source code
class Pattern_1(ONNXTransformer, abc.ABC):
    """
    transform
        prev --> Reshape --> Transpose --> Reshape --> next
    to
        prev --> DepthToSpace --> next

    if Transpose.perm == [0, 1, 4, 2, 5, 3] (mode 'CRD') or
    Transpose.perm == [0, 3, 4, 1, 5, 2] (mode 'DCR'), and the two block
    dimensions in the first Reshape's output shape are equal (DepthToSpace
    only supports a single, square blocksize).
    """

    # Maps each Transpose permutation that appears in the ONNX reference
    # decomposition of DepthToSpace to (mode, index of the first block
    # dimension in the first Reshape's output shape). The second block
    # dimension is at the following index.
    _PERM_TO_MODE = {
        (0, 1, 4, 2, 5, 3): ('CRD', 2),
        (0, 3, 4, 1, 5, 2): ('DCR', 1),
    }

    @staticmethod
    def _get_perm(node):
        """Return the node's ``perm`` attribute as a tuple, or None if it has none."""
        # Look up by name: Transpose's ``perm`` is optional, so it is not
        # guaranteed to be attribute[0] (or present at all).
        for attr in node.attribute:
            if attr.name == 'perm':
                return tuple(attr.ints)
        return None

    def pattern_matching(self, base_node):
        """Try to fuse the Reshape -> Transpose -> Reshape chain ending at ``base_node``.

        Returns ``base_node.input`` when nothing is fused. Otherwise returns the
        inputs of the first Reshape, whose data input feeds the new DepthToSpace node.
        """
        matched = self.pattern_matcher(base_node, ['Reshape', 'Transpose', 'Reshape'])
        if not (matched and self.pattern_condition_checker(matched)):
            return base_node.input
        first_reshape, transpose, _ = matched
        depth_to_space = self.make_node(
            'DepthToSpace',
            [first_reshape.input[0]],
            [base_node.output[0]],
            first_reshape.name,
            **self.get_attrs(first_reshape, transpose),
        )
        self.transform_to_fuse(matched, nodes_to_add=[depth_to_space])
        return first_reshape.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True if the matched nodes can be replaced by one DepthToSpace.

        Requires the Transpose perm to be one of the two DepthToSpace layouts,
        and the two block dimensions produced by the first Reshape to be equal.
        A Transpose without a ``perm`` attribute never matches.
        """
        top_node, mid_node, _ = nodes_to_check
        spec = self._PERM_TO_MODE.get(self._get_perm(mid_node))
        if spec is None:
            return False
        _, block_dim = spec
        reshaped_shape = self.get_value_info_shape(top_node.output[0])
        # DepthToSpace has a single blocksize, so non-square blocks cannot be fused.
        return bool(reshaped_shape[block_dim] == reshaped_shape[block_dim + 1])

    def get_attrs(self, top_node, mid_node):
        """Build the ``blocksize``/``mode`` attributes of the fused DepthToSpace node.

        Raises:
            ValueError: if the Transpose perm is not a DepthToSpace layout.
        """
        perm = self._get_perm(mid_node)
        reshaped_shape = self.get_value_info_shape(top_node.output[0])
        spec = self._PERM_TO_MODE.get(perm)
        if spec is None:
            raise ValueError(f'Transpose perm {perm} does not describe a DepthToSpace pattern')
        mode, block_dim = spec
        return {'blocksize': reshaped_shape[block_dim], 'mode': mode}
Ancestors
- furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
- abc.ABC
Methods
def get_attrs(self, top_node, mid_node)
-
Expand source code
def get_attrs(self, top_node, mid_node):
    """Build the ``blocksize``/``mode`` attributes of the fused DepthToSpace node.

    Raises:
        ValueError: if the Transpose perm is missing or is not a DepthToSpace layout.
    """
    # Look `perm` up by name: it is optional on Transpose, so attribute[0]
    # is not guaranteed to exist or to be `perm`.
    permutation = next(
        (list(attr.ints) for attr in mid_node.attribute if attr.name == 'perm'), None)
    reshaped_shape = self.get_value_info_shape(top_node.output[0])
    # Exact comparison (a zip-based check would also accept shorter prefixes).
    if permutation == [0, 1, 4, 2, 5, 3]:
        mode, blocksize = 'CRD', reshaped_shape[2]
    elif permutation == [0, 3, 4, 1, 5, 2]:
        mode, blocksize = 'DCR', reshaped_shape[1]
    else:
        raise ValueError(
            f'Transpose perm {permutation} does not describe a DepthToSpace pattern')
    return {'blocksize': blocksize, 'mode': mode}
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check):
    """Return True if the matched nodes can be replaced by one DepthToSpace.

    Requires the Transpose perm to be one of the two DepthToSpace layouts,
    and the two block dimensions produced by the first Reshape to be equal.
    A Transpose without a ``perm`` attribute never matches.
    """
    top_node, mid_node, _ = nodes_to_check
    # Look `perm` up by name: it is optional on Transpose, so attribute[0]
    # may not exist (the original indexing raised IndexError here).
    perm = next(
        (list(attr.ints) for attr in mid_node.attribute if attr.name == 'perm'), None)
    if perm == [0, 1, 4, 2, 5, 3]:
        block_dim = 2  # CRD: [N, C', b, b, H, W]
    elif perm == [0, 3, 4, 1, 5, 2]:
        block_dim = 1  # DCR: [N, b, b, C', H, W]
    else:
        return False
    reshaped_shape = self.get_value_info_shape(top_node.output[0])
    # DepthToSpace has a single blocksize, so non-square blocks cannot be fused.
    return bool(reshaped_shape[block_dim] == reshaped_shape[block_dim + 1])
def pattern_matching(self, base_node)
-
Expand source code
def pattern_matching(self, base_node):
    """Try to fuse the Reshape -> Transpose -> Reshape chain ending at ``base_node``.

    Returns ``base_node.input`` when nothing is fused. Otherwise returns the
    inputs of the first Reshape, whose data input feeds the new DepthToSpace node.
    """
    matched = self.pattern_matcher(base_node, ['Reshape', 'Transpose', 'Reshape'])
    if not (matched and self.pattern_condition_checker(matched)):
        return base_node.input
    first_reshape, transpose, _ = matched
    depth_to_space = self.make_node(
        'DepthToSpace',
        [first_reshape.input[0]],
        [base_node.output[0]],
        first_reshape.name,
        **self.get_attrs(first_reshape, transpose),
    )
    self.transform_to_fuse(matched, nodes_to_add=[depth_to_space])
    return first_reshape.input