Module furiosa.quantizer.frontend.onnx.transformer.fuse_redundant_reshape_pattern

Expand source code
import abc

import onnx

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer


class FuseRedundantReshapePattern(Transformer):
    """Run all redundant-reshape fusion passes over an ONNX model."""

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Apply Pattern_2, then Pattern_1, then Pattern_3 to ``model``.

        Each pass is built from the model returned by the previous pass, and
        the result of the last pass is returned.
        """
        # The three-Reshape pass is listed ahead of the two-Reshape pass;
        # presumably so the longest chain is collapsed in one step -- confirm.
        fusion_passes = (Pattern_2, Pattern_1, Pattern_3)
        for fusion_pass in fusion_passes:
            model = fusion_pass(model).transform()
        return model


class Pattern_1(ONNXTransformer, abc.ABC):
    """
        transform
            prev --> Reshape --> Reshape --> next
        to
            prev --> Reshape --> next

        if prev.output[0].shape != next.input[0].shape
    """
    # Op types of the chain to match, in the order drawn in the diagram above.
    pattern_to_match = ['Reshape', 'Reshape']
    # Suffix appended to an existing tensor name to name the fused Reshape's
    # second input (its initializer).
    postfix = '_reshape_fused'

    def pattern_matching(self, base_node):
        """Fuse the chain ending at ``base_node`` when it matches the pattern.

        Returns ``base_node.input`` unchanged when the pattern does not match
        or the condition check fails; otherwise the matched nodes are replaced
        through ``transform_to_fuse`` and the first matched node's inputs are
        returned.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node = matched_nodes[0]

        # Build each replacement exactly once. The value info used to be
        # created twice (once for the truthiness test, once for the list),
        # doubling the work of subclasses that override make_new_vi.
        new_node = self.make_new_node(matched_nodes)
        new_init = self.make_new_init(matched_nodes)
        new_vi = self.make_new_vi(matched_nodes)

        self.transform_to_fuse(matched_nodes,
                               nodes_to_add=[new_node],
                               inits_to_add=[new_init],
                               vis_to_add=[new_vi] if new_vi else None)
        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the chain actually changes the tensor shape.

        Compares the first node's data input with the last node's output via
        ``is_same_shape``; a chain whose input and output shapes are the same
        is left alone.
        """
        top_node = nodes_to_check[0]
        base_node = nodes_to_check[-1]
        return not self.is_same_shape(top_node.input[0], base_node.output[0])

    def make_new_node(self, matched_nodes):
        """Build the single Reshape that replaces the matched chain.

        It reads the first node's data input and writes the last node's
        output, keeps the first node's name, and takes its second input from
        the initializer named ``<first node's input[1]> + postfix``.
        """
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        return self.make_node('Reshape', [top_node.input[0], top_node.input[1] + self.postfix],
                              [base_node.output[0]],
                              name=top_node.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 initializer referenced by the fused Reshape.

        The name matches the second input used in ``make_new_node``. The last
        node's output name is handed to ``make_int64_initializer``; presumably
        that tensor's shape becomes the initializer's values -- confirm in
        ONNXTransformer.
        """
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        return self.make_int64_initializer(top_node.input[1] + self.postfix, base_node.output[0])

    def make_new_vi(self, matched_nodes):
        """Return the value info to add for the fusion; this pattern adds none."""
        return None


class Pattern_2(Pattern_1, abc.ABC):
    """
        Collapse three consecutive Reshape nodes into one:

            prev --> Reshape --> Reshape --> Reshape --> next

        becomes

            prev --> Reshape --> next

        Applied only if prev.output[0].shape != next.input[0].shape
    """
    # Same chain as Pattern_1, one Reshape longer.
    pattern_to_match = ['Reshape'] * 3


class Pattern_3(Pattern_1, abc.ABC):
    """
        Replace a Flatten/Squeeze followed by an Unsqueeze with one Reshape:

            prev --> Flatten/Squeeze --> Unsqueeze --> next

        becomes

            prev --> Reshape --> next

        Applied only if prev.output[0].shape != next.input[0].shape
    """
    pattern_to_match = ['Flatten/Squeeze', 'Unsqueeze']

    def make_new_node(self, matched_nodes):
        """Build the Reshape that replaces the matched pair.

        Unlike Pattern_1, the second input is named after the first node's
        data input (``input[0] + postfix``) rather than its ``input[1]``.
        """
        first, last = matched_nodes[0], matched_nodes[-1]
        source = first.input[0]
        return self.make_node('Reshape', [source, source + self.postfix],
                              [last.output[0]], first.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 initializer referenced by the new Reshape.

        Named ``<first node's input[0]> + postfix``; the last node's output
        name is passed on to ``make_int64_initializer``.
        """
        first, last = matched_nodes[0], matched_nodes[-1]
        init_name = first.input[0] + self.postfix
        return self.make_int64_initializer(init_name, last.output[0])

    def make_new_vi(self, matched_nodes):
        """Return a copy of the value info of the first node's data input."""
        source = matched_nodes[0].input[0]
        return self.copy_value_info(source)

Classes

class FuseRedundantReshapePattern (*args, **kwds)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

Expand source code
class FuseRedundantReshapePattern(Transformer):
    """Run all redundant-reshape fusion passes over an ONNX model."""

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Apply Pattern_2, then Pattern_1, then Pattern_3 to ``model``.

        Each pass is built from the model returned by the previous pass, and
        the result of the last pass is returned.
        """
        # The three-Reshape pass is listed ahead of the two-Reshape pass;
        # presumably so the longest chain is collapsed in one step -- confirm.
        fusion_passes = (Pattern_2, Pattern_1, Pattern_3)
        for fusion_pass in fusion_passes:
            model = fusion_pass(model).transform()
        return model

Ancestors

  • furiosa_sdk_quantizer.interfaces.transformer.Transformer
  • typing.Generic

Methods

def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Apply Pattern_2, then Pattern_1, then Pattern_3 to ``model``.

    Each pass is built from the model returned by the previous pass, and the
    result of the last pass is returned.
    """
    fusion_passes = (Pattern_2, Pattern_1, Pattern_3)
    for fusion_pass in fusion_passes:
        model = fusion_pass(model).transform()
    return model
class Pattern_1 (model)

transform `prev --> Reshape --> Reshape --> next` to `prev --> Reshape --> next`

if prev.output[0].shape != next.input[0].shape

Expand source code
class Pattern_1(ONNXTransformer, abc.ABC):
    """
        transform
            prev --> Reshape --> Reshape --> next
        to
            prev --> Reshape --> next

        if prev.output[0].shape != next.input[0].shape
    """
    # Op types of the chain to match, in the order drawn in the diagram above.
    pattern_to_match = ['Reshape', 'Reshape']
    # Suffix appended to an existing tensor name to name the fused Reshape's
    # second input (its initializer).
    postfix = '_reshape_fused'

    def pattern_matching(self, base_node):
        """Fuse the chain ending at ``base_node`` when it matches the pattern.

        Returns ``base_node.input`` unchanged when the pattern does not match
        or the condition check fails; otherwise the matched nodes are replaced
        through ``transform_to_fuse`` and the first matched node's inputs are
        returned.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node = matched_nodes[0]

        # Build each replacement exactly once. The value info used to be
        # created twice (once for the truthiness test, once for the list),
        # doubling the work of subclasses that override make_new_vi.
        new_node = self.make_new_node(matched_nodes)
        new_init = self.make_new_init(matched_nodes)
        new_vi = self.make_new_vi(matched_nodes)

        self.transform_to_fuse(matched_nodes,
                               nodes_to_add=[new_node],
                               inits_to_add=[new_init],
                               vis_to_add=[new_vi] if new_vi else None)
        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the chain actually changes the tensor shape.

        Compares the first node's data input with the last node's output via
        ``is_same_shape``; a chain whose input and output shapes are the same
        is left alone.
        """
        top_node = nodes_to_check[0]
        base_node = nodes_to_check[-1]
        return not self.is_same_shape(top_node.input[0], base_node.output[0])

    def make_new_node(self, matched_nodes):
        """Build the single Reshape that replaces the matched chain.

        It reads the first node's data input and writes the last node's
        output, keeps the first node's name, and takes its second input from
        the initializer named ``<first node's input[1]> + postfix``.
        """
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        return self.make_node('Reshape', [top_node.input[0], top_node.input[1] + self.postfix],
                              [base_node.output[0]],
                              name=top_node.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 initializer referenced by the fused Reshape.

        The name matches the second input used in ``make_new_node``. The last
        node's output name is handed to ``make_int64_initializer``; presumably
        that tensor's shape becomes the initializer's values -- confirm in
        ONNXTransformer.
        """
        top_node = matched_nodes[0]
        base_node = matched_nodes[-1]
        return self.make_int64_initializer(top_node.input[1] + self.postfix, base_node.output[0])

    def make_new_vi(self, matched_nodes):
        """Return the value info to add for the fusion; this pattern adds none."""
        return None

Ancestors

  • furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
  • abc.ABC

Subclasses

Class variables

var pattern_to_match
var postfix

Methods

def make_new_init(self, matched_nodes)
Expand source code
def make_new_init(self, matched_nodes):
    """Build the int64 initializer referenced by the fused Reshape.

    The initializer is named ``<first node's input[1]> + postfix``, and the
    last node's output name is passed on to ``make_int64_initializer``.
    """
    first, last = matched_nodes[0], matched_nodes[-1]
    init_name = first.input[1] + self.postfix
    return self.make_int64_initializer(init_name, last.output[0])
def make_new_node(self, matched_nodes)
Expand source code
def make_new_node(self, matched_nodes):
    """Build the single Reshape that replaces the matched chain.

    It reads the first node's data input, writes the last node's output and
    reuses the first node's name; its second input is the initializer named
    ``<first node's input[1]> + postfix``.
    """
    first, last = matched_nodes[0], matched_nodes[-1]
    node_inputs = [first.input[0], first.input[1] + self.postfix]
    node_outputs = [last.output[0]]
    return self.make_node('Reshape', node_inputs, node_outputs, name=first.name)
def make_new_vi(self, matched_nodes)
Expand source code
def make_new_vi(self, matched_nodes):
    """Return the value info to add for the fusion; this pattern adds none."""
    return None
def pattern_condition_checker(self, nodes_to_check)
Expand source code
def pattern_condition_checker(self, nodes_to_check):
    """Return True only when the chain actually changes the tensor shape.

    The first node's data input is compared with the last node's output via
    ``is_same_shape``; when the two have the same shape the result is False.
    """
    first, last = nodes_to_check[0], nodes_to_check[-1]
    if not self.is_same_shape(first.input[0], last.output[0]):
        return True
    return False
def pattern_matching(self, base_node)
Expand source code
def pattern_matching(self, base_node):
    """Fuse the chain ending at ``base_node`` when it matches the pattern.

    Returns ``base_node.input`` unchanged when the pattern does not match or
    the condition check fails; otherwise the matched nodes are replaced
    through ``transform_to_fuse`` and the first matched node's inputs are
    returned.
    """
    inputs = base_node.input

    matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
    if not matched_nodes:
        return inputs

    if not self.pattern_condition_checker(matched_nodes):
        return inputs

    top_node = matched_nodes[0]

    # Build each replacement exactly once. The value info used to be created
    # twice (once for the truthiness test, once for the list), doubling the
    # work of subclasses that override make_new_vi.
    new_node = self.make_new_node(matched_nodes)
    new_init = self.make_new_init(matched_nodes)
    new_vi = self.make_new_vi(matched_nodes)

    self.transform_to_fuse(matched_nodes,
                           nodes_to_add=[new_node],
                           inits_to_add=[new_init],
                           vis_to_add=[new_vi] if new_vi else None)
    return top_node.input
class Pattern_2 (model)

transform `prev --> Reshape --> Reshape --> Reshape --> next` to `prev --> Reshape --> next`

if prev.output[0].shape != next.input[0].shape

Expand source code
class Pattern_2(Pattern_1, abc.ABC):
    """
        Collapse three consecutive Reshape nodes into one:

            prev --> Reshape --> Reshape --> Reshape --> next

        becomes

            prev --> Reshape --> next

        Applied only if prev.output[0].shape != next.input[0].shape
    """
    # Same chain as Pattern_1, one Reshape longer.
    pattern_to_match = ['Reshape'] * 3

Ancestors

  • Pattern_1
  • furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
  • abc.ABC

Class variables

var pattern_to_match
class Pattern_3 (model)

transform `prev --> Flatten/Squeeze --> Unsqueeze --> next` to `prev --> Reshape --> next` if prev.output[0].shape != next.input[0].shape

Expand source code
class Pattern_3(Pattern_1, abc.ABC):
    """
        Replace a Flatten/Squeeze followed by an Unsqueeze with one Reshape:

            prev --> Flatten/Squeeze --> Unsqueeze --> next

        becomes

            prev --> Reshape --> next

        Applied only if prev.output[0].shape != next.input[0].shape
    """
    pattern_to_match = ['Flatten/Squeeze', 'Unsqueeze']

    def make_new_node(self, matched_nodes):
        """Build the Reshape that replaces the matched pair.

        Unlike Pattern_1, the second input is named after the first node's
        data input (``input[0] + postfix``) rather than its ``input[1]``.
        """
        first, last = matched_nodes[0], matched_nodes[-1]
        source = first.input[0]
        return self.make_node('Reshape', [source, source + self.postfix],
                              [last.output[0]], first.name)

    def make_new_init(self, matched_nodes):
        """Build the int64 initializer referenced by the new Reshape.

        Named ``<first node's input[0]> + postfix``; the last node's output
        name is passed on to ``make_int64_initializer``.
        """
        first, last = matched_nodes[0], matched_nodes[-1]
        init_name = first.input[0] + self.postfix
        return self.make_int64_initializer(init_name, last.output[0])

    def make_new_vi(self, matched_nodes):
        """Return a copy of the value info of the first node's data input."""
        source = matched_nodes[0].input[0]
        return self.copy_value_info(source)

Ancestors

  • Pattern_1
  • furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
  • abc.ABC

Class variables

var pattern_to_match

Methods

def make_new_init(self, matched_nodes)
Expand source code
def make_new_init(self, matched_nodes):
    """Build the int64 initializer referenced by the new Reshape.

    The initializer is named ``<first node's input[0]> + postfix``, and the
    last node's output name is passed on to ``make_int64_initializer``.
    """
    first, last = matched_nodes[0], matched_nodes[-1]
    init_name = first.input[0] + self.postfix
    return self.make_int64_initializer(init_name, last.output[0])
def make_new_node(self, matched_nodes)
Expand source code
def make_new_node(self, matched_nodes):
    """Build the Reshape that replaces the matched pair.

    It reads the first node's data input, writes the last node's output and
    reuses the first node's name; its second input is named
    ``<first node's input[0]> + postfix``.
    """
    first, last = matched_nodes[0], matched_nodes[-1]
    source = first.input[0]
    node_inputs = [source, source + self.postfix]
    node_outputs = [last.output[0]]
    return self.make_node('Reshape', node_inputs, node_outputs, first.name)
def make_new_vi(self, matched_nodes)
Expand source code
def make_new_vi(self, matched_nodes):
    """Return a copy of the value info of the first node's data input."""
    source = matched_nodes[0].input[0]
    return self.copy_value_info(source)