Module furiosa.quantizer.frontend.onnx.transformer.convert_conv1d_to_conv2d

Expand source code
import abc

import onnx
import numpy as np

from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import ONNXTransformer


class ConvertConv1dToConv2d(Transformer):
    """Apply every Conv1d-to-Conv2d pattern transformer to an ONNX model."""

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return the model obtained by running each pattern class in order."""
        pattern_classes = (Pattern_1,)
        for pattern_cls in pattern_classes:
            # Each pattern receives the model produced by the previous one.
            model = pattern_cls(model).transform()
        return model


class Pattern_1(ONNXTransformer, abc.ABC):
    """
        transform
            prev --> Reshape --> Conv (1d) --> Reshape --> next
        to
            prev --> Reshape --> Conv (2d) --> Reshape --> next

        if Conv.input[0].ndim == 3, i.e., if Conv1d

        The rewrite appends a trailing axis of size 1 to the first Reshape's
        target shape, to the Conv weight and to the value infos of the Conv
        input/output, and extends the Conv attributes to two spatial axes.
    """
    pattern_to_match = ['Reshape', 'Conv', 'Reshape']

    def pattern_matching(self, base_node):
        """Rewrite the Reshape -> Conv -> Reshape chain that ends at ``base_node``.

        Returns ``base_node.input`` untouched when the chain is not matched or
        the condition check fails; otherwise hands the replacement nodes,
        initializers and value infos to ``transform_to_convert`` and returns
        the inputs of the first Reshape.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node, mid_node, base_node = matched_nodes
        has_bias = len(mid_node.input) == 3
        reshape_shape_name = top_node.input[1] + '_converted'
        weight_name = mid_node.input[1] + '_converted'

        # Every shape gains a trailing axis of size 1.
        new_mid_input_shape = [*self.get_value_info_shape(mid_node.input[0]), 1]
        new_top_reshape_shape = [*self.get_initializer_array(top_node.input[1]), 1]
        new_mid_output_shape = [*self.get_value_info_shape(mid_node.output[0]), 1]
        new_mid_weight_shape = [*self.get_value_info_shape(mid_node.input[1]), 1]

        self.transform_to_convert(
            matched_nodes,
            nodes_to_add=[
                self.make_node('Reshape', [top_node.input[0], reshape_shape_name],
                               [top_node.output[0]],
                               top_node.name),
                self.make_node('Conv', [mid_node.input[0], weight_name,
                                        mid_node.input[2] if has_bias else None],
                               [mid_node.output[0]],
                               mid_node.name,
                               **self.get_attrs(mid_node)),
                base_node
            ],
            inits_to_add=[
                self.make_initializer_from_array(
                    np.array(new_top_reshape_shape), name=reshape_shape_name),
                self.make_initializer_from_array(
                    self.get_initializer_array(mid_node.input[1]).reshape(new_mid_weight_shape),
                    name=weight_name),
                # The optional bias is Conv input[2]; input[0] is the Conv's
                # data input, so it must not be used as the initializer key.
                self.initializer_map[mid_node.input[2]] if has_bias else None
            ],
            vis_to_add=[
                self.make_tensor_value_info(mid_node.input[0], onnx.TensorProto.FLOAT,
                                            new_mid_input_shape),
                self.make_tensor_value_info(mid_node.output[0], onnx.TensorProto.FLOAT,
                                            new_mid_output_shape)
            ]
        )

        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the Conv node's first input has rank 3."""
        _, mid_node, _ = nodes_to_check
        return len(self.get_value_info_shape(mid_node.input[0])) == 3

    def get_attrs(self, mid_node):
        """Build the Conv keyword attributes extended with a second spatial axis.

        Raises KeyError when the node carries no ``kernel_shape`` attribute.
        """
        from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs
        attrs = attribute_to_kwargs(mid_node.attribute)
        dilations = attrs.get('dilations', [1])
        group = attrs.get('group', 1)
        kernel_shape = attrs['kernel_shape']
        pads = attrs.get('pads', [0, 0])
        strides = attrs.get('strides', [1])

        # 1d pads are [begin, end]; the 2d form interleaves zeros for the
        # appended axis: [begin_0, begin_1, end_0, end_1].
        return {'dilations': [*dilations, 1],
                'group': group,
                'kernel_shape': [*kernel_shape, 1],
                'pads': [pads[0], 0, pads[1], 0],
                'strides': [strides[0], 1]}

Classes

class ConvertConv1dToConv2d (*args, **kwds)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

Expand source code
class ConvertConv1dToConv2d(Transformer):
    """Apply every Conv1d-to-Conv2d pattern transformer to an ONNX model."""

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return the model obtained by running each pattern class in order."""
        pattern_classes = (Pattern_1,)
        for pattern_cls in pattern_classes:
            # Each pattern receives the model produced by the previous one.
            model = pattern_cls(model).transform()
        return model

Ancestors

  • furiosa_sdk_quantizer.interfaces.transformer.Transformer
  • typing.Generic

Methods

def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Return the model obtained by running each pattern class in order."""
    pattern_classes = (Pattern_1,)
    for pattern_cls in pattern_classes:
        # Each pattern receives the model produced by the previous one.
        model = pattern_cls(model).transform()
    return model
class Pattern_1 (model)

Transforms prev –> Reshape –> Conv (1d) –> Reshape –> next into prev –> Reshape –> Conv (2d) –> Reshape –> next

if Conv.input[0].ndim == 3, i.e., if Conv1d

Expand source code
class Pattern_1(ONNXTransformer, abc.ABC):
    """
        transform
            prev --> Reshape --> Conv (1d) --> Reshape --> next
        to
            prev --> Reshape --> Conv (2d) --> Reshape --> next

        if Conv.input[0].ndim == 3, i.e., if Conv1d

        The rewrite appends a trailing axis of size 1 to the first Reshape's
        target shape, to the Conv weight and to the value infos of the Conv
        input/output, and extends the Conv attributes to two spatial axes.
    """
    pattern_to_match = ['Reshape', 'Conv', 'Reshape']

    def pattern_matching(self, base_node):
        """Rewrite the Reshape -> Conv -> Reshape chain that ends at ``base_node``.

        Returns ``base_node.input`` untouched when the chain is not matched or
        the condition check fails; otherwise hands the replacement nodes,
        initializers and value infos to ``transform_to_convert`` and returns
        the inputs of the first Reshape.
        """
        inputs = base_node.input

        matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
        if not matched_nodes:
            return inputs

        if not self.pattern_condition_checker(matched_nodes):
            return inputs

        top_node, mid_node, base_node = matched_nodes
        has_bias = len(mid_node.input) == 3
        reshape_shape_name = top_node.input[1] + '_converted'
        weight_name = mid_node.input[1] + '_converted'

        # Every shape gains a trailing axis of size 1.
        new_mid_input_shape = [*self.get_value_info_shape(mid_node.input[0]), 1]
        new_top_reshape_shape = [*self.get_initializer_array(top_node.input[1]), 1]
        new_mid_output_shape = [*self.get_value_info_shape(mid_node.output[0]), 1]
        new_mid_weight_shape = [*self.get_value_info_shape(mid_node.input[1]), 1]

        self.transform_to_convert(
            matched_nodes,
            nodes_to_add=[
                self.make_node('Reshape', [top_node.input[0], reshape_shape_name],
                               [top_node.output[0]],
                               top_node.name),
                self.make_node('Conv', [mid_node.input[0], weight_name,
                                        mid_node.input[2] if has_bias else None],
                               [mid_node.output[0]],
                               mid_node.name,
                               **self.get_attrs(mid_node)),
                base_node
            ],
            inits_to_add=[
                self.make_initializer_from_array(
                    np.array(new_top_reshape_shape), name=reshape_shape_name),
                self.make_initializer_from_array(
                    self.get_initializer_array(mid_node.input[1]).reshape(new_mid_weight_shape),
                    name=weight_name),
                # The optional bias is Conv input[2]; input[0] is the Conv's
                # data input, so it must not be used as the initializer key.
                self.initializer_map[mid_node.input[2]] if has_bias else None
            ],
            vis_to_add=[
                self.make_tensor_value_info(mid_node.input[0], onnx.TensorProto.FLOAT,
                                            new_mid_input_shape),
                self.make_tensor_value_info(mid_node.output[0], onnx.TensorProto.FLOAT,
                                            new_mid_output_shape)
            ]
        )

        return top_node.input

    def pattern_condition_checker(self, nodes_to_check):
        """Return True only when the Conv node's first input has rank 3."""
        _, mid_node, _ = nodes_to_check
        return len(self.get_value_info_shape(mid_node.input[0])) == 3

    def get_attrs(self, mid_node):
        """Build the Conv keyword attributes extended with a second spatial axis.

        Raises KeyError when the node carries no ``kernel_shape`` attribute.
        """
        from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs
        attrs = attribute_to_kwargs(mid_node.attribute)
        dilations = attrs.get('dilations', [1])
        group = attrs.get('group', 1)
        kernel_shape = attrs['kernel_shape']
        pads = attrs.get('pads', [0, 0])
        strides = attrs.get('strides', [1])

        # 1d pads are [begin, end]; the 2d form interleaves zeros for the
        # appended axis: [begin_0, begin_1, end_0, end_1].
        return {'dilations': [*dilations, 1],
                'group': group,
                'kernel_shape': [*kernel_shape, 1],
                'pads': [pads[0], 0, pads[1], 0],
                'strides': [strides[0], 1]}

Ancestors

  • furiosa_sdk_quantizer.frontend.onnx.transformer.ONNXTransformer
  • abc.ABC

Class variables

var pattern_to_match

Methods

def get_attrs(self, mid_node)
Expand source code
def get_attrs(self, mid_node):
    """Build Conv keyword attributes with a second spatial axis appended.

    Reads the node's attributes through ``attribute_to_kwargs`` and returns a
    dict with ``dilations``, ``group``, ``kernel_shape``, ``pads`` and
    ``strides``. Raises KeyError when ``kernel_shape`` is absent.
    """
    from furiosa_sdk_quantizer.frontend.onnx.quantizer.utils import attribute_to_kwargs

    conv1d_attrs = attribute_to_kwargs(mid_node.attribute)
    dilations_1d = conv1d_attrs.get('dilations', [1])
    group_count = conv1d_attrs.get('group', 1)
    kernel_1d = conv1d_attrs['kernel_shape']
    pads_1d = conv1d_attrs.get('pads', [0, 0])
    strides_1d = conv1d_attrs.get('strides', [1])

    conv2d_attrs = {}
    conv2d_attrs['dilations'] = [*dilations_1d, 1]
    conv2d_attrs['group'] = group_count
    conv2d_attrs['kernel_shape'] = [*kernel_1d, 1]
    # Zeros are interleaved for the appended axis: [begin_0, 0, end_0, 0].
    conv2d_attrs['pads'] = [pads_1d[0], 0, pads_1d[1], 0]
    conv2d_attrs['strides'] = [strides_1d[0], 1]
    return conv2d_attrs
def pattern_condition_checker(self, nodes_to_check)
Expand source code
def pattern_condition_checker(self, nodes_to_check):
    """Return True only when the middle node's first input has rank 3.

    ``nodes_to_check`` must unpack into exactly three nodes; only the middle
    one is inspected.
    """
    _top, conv_node, _base = nodes_to_check
    input_shape = self.get_value_info_shape(conv_node.input[0])
    return len(input_shape) == 3
def pattern_matching(self, base_node)
Expand source code
def pattern_matching(self, base_node):
    """Rewrite the Reshape -> Conv -> Reshape chain that ends at ``base_node``.

    Returns ``base_node.input`` untouched when the chain is not matched or
    the condition check fails; otherwise hands the replacement nodes,
    initializers and value infos to ``transform_to_convert`` and returns
    the inputs of the first Reshape.
    """
    inputs = base_node.input

    matched_nodes = self.pattern_matcher(base_node, self.pattern_to_match)
    if not matched_nodes:
        return inputs

    if not self.pattern_condition_checker(matched_nodes):
        return inputs

    top_node, mid_node, base_node = matched_nodes
    has_bias = len(mid_node.input) == 3
    reshape_shape_name = top_node.input[1] + '_converted'
    weight_name = mid_node.input[1] + '_converted'

    # Every shape gains a trailing axis of size 1.
    new_mid_input_shape = [*self.get_value_info_shape(mid_node.input[0]), 1]
    new_top_reshape_shape = [*self.get_initializer_array(top_node.input[1]), 1]
    new_mid_output_shape = [*self.get_value_info_shape(mid_node.output[0]), 1]
    new_mid_weight_shape = [*self.get_value_info_shape(mid_node.input[1]), 1]

    self.transform_to_convert(
        matched_nodes,
        nodes_to_add=[
            self.make_node('Reshape', [top_node.input[0], reshape_shape_name],
                           [top_node.output[0]],
                           top_node.name),
            self.make_node('Conv', [mid_node.input[0], weight_name,
                                    mid_node.input[2] if has_bias else None],
                           [mid_node.output[0]],
                           mid_node.name,
                           **self.get_attrs(mid_node)),
            base_node
        ],
        inits_to_add=[
            self.make_initializer_from_array(
                np.array(new_top_reshape_shape), name=reshape_shape_name),
            self.make_initializer_from_array(
                self.get_initializer_array(mid_node.input[1]).reshape(new_mid_weight_shape),
                name=weight_name),
            # The optional bias is Conv input[2]; input[0] is the Conv's
            # data input, so it must not be used as the initializer key.
            self.initializer_map[mid_node.input[2]] if has_bias else None
        ],
        vis_to_add=[
            self.make_tensor_value_info(mid_node.input[0], onnx.TensorProto.FLOAT,
                                        new_mid_input_shape),
            self.make_tensor_value_info(mid_node.output[0], onnx.TensorProto.FLOAT,
                                        new_mid_output_shape)
        ]
    )

    return top_node.input