Module furiosa.quantizer.frontend.onnx.transformer
Expand source code
from typing import List, Set, Optional
from collections import OrderedDict
import onnx
import numpy as np
from onnx import numpy_helper
from onnx.helper import make_node, make_tensor, make_tensor_value_info
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
class ONNXTransformer:
    """Base class for pattern-based rewrites of an ONNX graph.

    The constructor indexes the model's nodes, initializers and value infos
    into dictionaries. Subclasses implement ``pattern_matching`` and edit
    those dictionaries through the ``transform_to_*``, ``update_*`` and
    ``pop_*`` helpers; ``transform`` then rebuilds the model from them.
    """

    def __init__(self, model):
        """Index ``model.graph`` into the lookup maps used by the helpers."""
        self.model = model
        graph = model.graph
        # Tensor name -> node that produces it.
        self.producer_map = {name: node for node in graph.node for name in node.output}
        # Node name -> NodeProto, a list of replacement nodes, or [] once removed.
        self.optimizer_map = OrderedDict({node.name: node for node in graph.node})
        self.initializer_map = {init.name: init for init in graph.initializer}
        self.node_input_map = {node.name: node.input for node in graph.node}
        self.value_info_map = {vi.name: vi for vi in graph.value_info}
        self.graph_input_map = {inp.name: inp for inp in graph.input}
        self.graph_output_map = {out.name: out for out in graph.output}

    def transform(self):
        """Run ``pattern_matching`` on every producing node and rebuild the model.

        The walk starts at the graph outputs and follows the tensor names
        returned by ``pattern_matching`` towards the graph inputs.
        """
        outputs = list(self.graph_output_map.keys())
        # Tensor names already queued; prevents traversing cyclic connections.
        visited: Set[str] = set()
        # Nodes already handled; a multi-output node (e.g. Split) is reachable
        # through several tensors but must be matched only once.
        visited_node: List[onnx.NodeProto] = []
        while outputs:
            output = outputs.pop(0)
            if output not in self.producer_map:
                continue
            node = self.producer_map[output]
            if node in visited_node:
                continue
            inputs = self.pattern_matching(node)
            # Queue the predecessors' tensors that have not been seen yet.
            outputs.extend([name for name in inputs if name not in visited])
            visited.update(inputs)
            visited_node.append(node)
        return self.build_optimized_model(self.model)

    def update_graph_fields(self, model):
        """Overwrite the graph's initializer/input/output/value_info fields
        with the current contents of the corresponding maps."""
        for field in ['initializer', 'input', 'output', 'value_info']:
            model.graph.ClearField(field)
            getattr(model.graph, field).extend(self.get_map_values(field))
        return model

    def build_optimized_model(self, model):
        """Rebuild ``model`` from the edited maps, validate it and return it.

        Raises:
            Exception: if ``optimizer_map`` holds something that is neither a
                NodeProto nor a list.
        """
        model = self.update_graph_fields(model)
        new_nodes = []
        for member in self.get_map_values('node'):
            if isinstance(member, onnx.NodeProto):
                new_nodes.append(member)
            elif isinstance(member, list):
                # A list holds the replacement nodes of a fused/converted
                # pattern; an empty list is a removed node.
                new_nodes.extend(member)
            else:
                raise Exception(member)
        model = utils.rebuild_model(model, new_nodes)
        check_model(model)
        return model

    def make_node(self, op_type, inputs, outputs, name=None, **attrs):
        """Create a node with ``onnx.helper.make_node``, dropping ``None`` inputs."""
        inputs = [x for x in inputs if x is not None]
        return make_node(op_type, inputs, outputs, name, **attrs)

    def make_tensor_value_info(self, name, elem_type, shape):
        """Thin wrapper around ``onnx.helper.make_tensor_value_info``."""
        return make_tensor_value_info(name, elem_type, shape)

    def make_initializer_from_array(self, array: np.ndarray, name: Optional[str] = None) -> onnx.TensorProto:
        """Convert a numpy array into a TensorProto named ``name``."""
        return numpy_helper.from_array(array, name)

    def make_int64_initializer(self, name, target_name):
        """Create a 1-D INT64 tensor holding the shape recorded for ``target_name``."""
        shape = self.get_value_info_shape(target_name)
        return make_tensor(name, onnx.TensorProto.INT64, (len(shape),), shape)

    def copy_value_info(self, name):
        """Return the value info registered under ``name``.

        Graph inputs are searched first, then intermediate value infos, then
        graph outputs (so a tensor that is only a graph output is found too,
        as it is by ``get_value_info_shape``).

        Raises:
            Exception: if ``name`` is in none of the maps.
        """
        for vi_map in (self.graph_input_map, self.value_info_map, self.graph_output_map):
            if name in vi_map:
                return vi_map[name]
        raise Exception('%s not found.' % name)

    def get_value_info_shape(self, value_info_name: str) -> List[int]:
        """Return the ``dim_value`` of every dimension of the named tensor.

        Raises:
            Exception: if the name is in none of the value-info maps.
        """
        for vi_map in (self.value_info_map, self.graph_output_map, self.graph_input_map):
            if value_info_name in vi_map:
                dims = vi_map[value_info_name].type.tensor_type.shape.dim
                return [dim.dim_value for dim in dims]
        raise Exception('%s not found.' % value_info_name)

    def get_map_values(self, field):
        """Return the de-duplicated values of the map backing graph ``field``."""
        if field in ('input', 'output'):
            field_map = 'graph_' + field + '_map'
        elif field == 'node':
            field_map = 'optimizer_map'
        else:
            field_map = field + '_map'
        return self.make_field_unique(getattr(self, field_map).values())

    def get_initializer_array(self, node_input):
        """Return the initializer ``node_input`` as a numpy array, or ``None``."""
        if node_input not in self.initializer_map:
            return None
        return numpy_helper.to_array(self.initializer_map[node_input])

    def get_init_node_input(self, node):
        """Return the last input of ``node`` that is an initializer, or ``None``."""
        init_node_input = None
        for node_input in node.input:
            if node_input in self.initializer_map:
                init_node_input = node_input
        return init_node_input

    def get_data_node_input(self, node):
        """Return the last input of ``node`` that is not an initializer, or ``None``."""
        data_node_input = None
        for node_input in node.input:
            if node_input not in self.initializer_map:
                data_node_input = node_input
        return data_node_input

    def make_field_unique(self, values):
        """Return ``values`` as a list with duplicates removed, order kept.

        Membership is tested by equality on a list, so unhashable values work.
        """
        seen = []
        for v in values:
            if v not in seen:
                seen.append(v)
        return seen

    def find_next_node(self, node: onnx.NodeProto) -> List[onnx.NodeProto]:
        """Return the live nodes in ``optimizer_map`` that consume an output of ``node``.

        Removed entries (empty lists) are skipped; for a list of replacement
        nodes only its first node is inspected.
        """
        next_nodes = []
        for v in self.optimizer_map.values():
            if not v:
                continue
            if isinstance(v, list):
                v = v[0]
            if any(output == v_input for output in node.output for v_input in v.input):
                next_nodes.append(v)
        return next_nodes

    def find_prev_node(self, node_input: str) -> onnx.NodeProto:
        """Return the node producing ``node_input``, or ``None`` if there is none."""
        if node_input not in self.producer_map:
            return None
        return self.producer_map[node_input]

    def is_op_type(self, op_type: str, target_op_types: List[str]):
        """Return True when ``op_type`` equals one of ``target_op_types``."""
        return any(op_type == target for target in target_op_types)

    def is_same_shape(self, input_1, input_2):
        """Return True when both tensors have identical recorded shapes."""
        return self.get_value_info_shape(input_1) == self.get_value_info_shape(input_2)

    def traverse_prev_node(self, producer_map_key: str, target_op_types: List[str]):
        """Return the producer of ``producer_map_key`` if its op type matches.

        Returns ``None`` when the tensor has no producer and ``False`` when the
        producer's op type is not in ``target_op_types``; both are falsy.
        """
        prev_node = self.find_prev_node(producer_map_key)
        if not prev_node:
            return None
        if not self.is_op_type(prev_node.op_type, target_op_types):
            return False
        return prev_node

    def update_single_optimizer_map(self, node: onnx.NodeProto, dest_name):
        """Store ``node`` under ``dest_name`` in ``optimizer_map``."""
        self.optimizer_map[dest_name] = node

    def update_multiple_optimizer_map(self, nodes: List[onnx.NodeProto], dest_name):
        """Store the list ``nodes`` under ``dest_name`` in ``optimizer_map``."""
        self.optimizer_map[dest_name] = nodes

    def update_single_value_info_map(self, value_info: onnx.ValueInfoProto):
        """Register ``value_info``, replacing a graph input of the same name if any."""
        if value_info.name in self.graph_input_map:
            self.graph_input_map[value_info.name] = value_info
        else:
            self.value_info_map[value_info.name] = value_info

    def update_multiple_value_info_map(self, value_infos: List[onnx.ValueInfoProto]):
        """Register every value info in ``value_infos``."""
        for vi in value_infos:
            self.update_single_value_info_map(vi)

    def update_single_initializer_map(self, initializer: onnx.TensorProto):
        """Register ``initializer`` and a matching graph-input value info."""
        self.initializer_map[initializer.name] = initializer
        self.graph_input_map[initializer.name] = make_tensor_value_info(
            initializer.name, initializer.data_type, numpy_helper.to_array(initializer).shape)

    def update_multiple_initializer_map(self, initializers: List[onnx.TensorProto]):
        """Register every truthy entry of ``initializers``."""
        for init in initializers:
            if not init:
                continue
            self.update_single_initializer_map(init)

    def pop_single_optimizer_map(self, node: onnx.NodeProto):
        """Mark ``node`` as removed by replacing its entry with an empty list."""
        self.optimizer_map[node.name] = []

    def pop_multiple_optimizer_map(self, nodes: List[onnx.NodeProto]):
        """Mark every node in ``nodes`` as removed."""
        for node in nodes:
            self.pop_single_optimizer_map(node)

    def pop_single_value_info_map(self, vi: onnx.ValueInfoProto):
        """Forget the value info named ``vi.name`` in every map that holds it.

        The previous code popped ``value_info_map`` again for names that were
        also graph inputs/outputs, which raised ``KeyError`` on the second pop
        and never touched the graph input/output maps.
        """
        name = vi.name
        self.value_info_map.pop(name, None)
        self.graph_output_map.pop(name, None)
        self.graph_input_map.pop(name, None)

    def pop_multiple_value_info_map(self, vis: List[onnx.ValueInfoProto]):
        """Forget every value info in ``vis``."""
        for vi in vis:
            self.pop_single_value_info_map(vi)

    def pop_single_initializer_map(self, init: onnx.TensorProto):
        """Remove ``init`` from the initializer map and the graph inputs."""
        self.initializer_map.pop(init.name)
        self.graph_input_map.pop(init.name)

    def pop_multiple_initializer_map(self, nodes: List[onnx.TensorProto]):
        """Remove every initializer in ``nodes``."""
        for node in nodes:
            self.pop_single_initializer_map(node)

    def bridge_disconnected_nodes(self, node_0: onnx.NodeProto, next_nodes: List[onnx.NodeProto], new_input):
        """
        For a graph changed, for example,
            before) prev --> node_1 --> node_0 --> next
            after)  prev --> node_1 --> (   ) -/-> next
        This function bridges node_1 and next as follows:
            prev --> node_1 --> next
        by assigning next.input[y] = node_1.output[x]
        """
        for next_node in next_nodes:
            for idx, next_node_input in enumerate(next_node.input):
                for node_output in node_0.output:
                    if node_output != next_node_input:
                        continue
                    next_node.input[idx] = new_input
                    self.update_single_optimizer_map(next_node, next_node.name)
        # If node_0 fed a graph output, that output now carries new_input's
        # value info.
        for node_output in node_0.output:
            for output in self.model.graph.output:
                if node_output != output.name:
                    continue
                self.graph_output_map[node_output] = self.copy_value_info(new_input)

    def transform_to_eliminate(self, nodes_to_remove: List[onnx.NodeProto], new_input):
        """Remove ``nodes_to_remove`` and rewire consumers of the last one to ``new_input``."""
        self.pop_multiple_optimizer_map(nodes_to_remove)
        last_node = nodes_to_remove[-1]
        self.bridge_disconnected_nodes(last_node, self.find_next_node(last_node), new_input)

    def transform_to_convert(self, nodes_to_remove: List[onnx.NodeProto],
                             nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                             inits_to_add: Optional[List[onnx.TensorProto]] = None,
                             vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
        """Alias of ``transform_to_fuse``."""
        self.transform_to_fuse(nodes_to_remove,
                               nodes_to_add,
                               inits_to_add,
                               vis_to_add)

    def transform_to_fuse(self, nodes_to_remove: List[onnx.NodeProto],
                          nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                          inits_to_add: Optional[List[onnx.TensorProto]] = None,
                          vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
        """Replace ``nodes_to_remove`` with ``nodes_to_add`` and register the
        given initializers and value infos.

        The new nodes are stored under the name of the first removed node.
        """
        self.pop_multiple_optimizer_map(nodes_to_remove)
        if nodes_to_add:
            self.update_multiple_optimizer_map(nodes_to_add, nodes_to_remove[0].name)
        if inits_to_add:
            self.update_multiple_initializer_map(inits_to_add)
        if vis_to_add:
            self.update_multiple_value_info_map(vis_to_add)

    def pattern_matching(self, node):
        """Hook for subclasses; ``transform`` iterates over its return value
        as tensor names to visit next."""
        raise NotImplementedError

    def pattern_matcher(self, node, pattern_to_match: List[str]):
        """Match a linear chain of op types that ends at ``node``.

        Each pattern entry may list alternatives separated by ``/``. Returns
        the matched nodes in pattern order, or ``None`` when the chain does
        not match.
        """
        decoded_pattern = [p.split('/') for p in pattern_to_match]
        # Matching runs backwards, from ``node`` towards its producers.
        decoded_pattern.reverse()
        op_type_0 = decoded_pattern.pop(0)
        if not self.is_op_type(node.op_type, op_type_0):
            return None
        matched_nodes = [node]
        while decoded_pattern:
            op_type_1 = decoded_pattern.pop(0)
            node_1 = None
            # TODO impl "multi-path search"
            for node_input in node.input:
                node_1 = self.traverse_prev_node(node_input, op_type_1)
                if node_1:
                    break
            if not node_1:
                return None
            node = node_1
            matched_nodes.append(node)
        matched_nodes.reverse()
        return matched_nodes

    def pattern_condition_checker(self, nodes_to_check):
        """Hook for subclasses."""
        raise NotImplementedError
Sub-modules
furiosa.quantizer.frontend.onnx.transformer.convert_2d_sum_to_add
furiosa.quantizer.frontend.onnx.transformer.convert_clip_attr_to_input
furiosa.quantizer.frontend.onnx.transformer.convert_conv1d_to_conv2d
furiosa.quantizer.frontend.onnx.transformer.deprecated
furiosa.quantizer.frontend.onnx.transformer.eliminate_argmax_output
furiosa.quantizer.frontend.onnx.transformer.eliminate_redundant_reshape_pattern
furiosa.quantizer.frontend.onnx.transformer.experimental
furiosa.quantizer.frontend.onnx.transformer.extract_constant_to_initializer
furiosa.quantizer.frontend.onnx.transformer.fuse_bn_into_conv
furiosa.quantizer.frontend.onnx.transformer.fuse_conv
furiosa.quantizer.frontend.onnx.transformer.fuse_depth_to_space
furiosa.quantizer.frontend.onnx.transformer.fuse_gelu
furiosa.quantizer.frontend.onnx.transformer.fuse_layer_normalization
furiosa.quantizer.frontend.onnx.transformer.fuse_lp_normalization
furiosa.quantizer.frontend.onnx.transformer.fuse_pad
furiosa.quantizer.frontend.onnx.transformer.fuse_redundant_reshape_pattern
furiosa.quantizer.frontend.onnx.transformer.polish_model
furiosa.quantizer.frontend.onnx.transformer.utils
Classes
class ONNXTransformer (model)
-
Expand source code
class ONNXTransformer:
    """Base class for pattern-based rewrites of an ONNX graph.

    The constructor indexes the model's nodes, initializers and value infos
    into dictionaries. Subclasses implement ``pattern_matching`` and edit
    those dictionaries through the ``transform_to_*``, ``update_*`` and
    ``pop_*`` helpers; ``transform`` then rebuilds the model from them.
    """

    def __init__(self, model):
        """Index ``model.graph`` into the lookup maps used by the helpers."""
        self.model = model
        graph = model.graph
        # Tensor name -> node that produces it.
        self.producer_map = {name: node for node in graph.node for name in node.output}
        # Node name -> NodeProto, a list of replacement nodes, or [] once removed.
        self.optimizer_map = OrderedDict({node.name: node for node in graph.node})
        self.initializer_map = {init.name: init for init in graph.initializer}
        self.node_input_map = {node.name: node.input for node in graph.node}
        self.value_info_map = {vi.name: vi for vi in graph.value_info}
        self.graph_input_map = {inp.name: inp for inp in graph.input}
        self.graph_output_map = {out.name: out for out in graph.output}

    def transform(self):
        """Run ``pattern_matching`` on every producing node and rebuild the model.

        The walk starts at the graph outputs and follows the tensor names
        returned by ``pattern_matching`` towards the graph inputs.
        """
        outputs = list(self.graph_output_map.keys())
        # Tensor names already queued; prevents traversing cyclic connections.
        visited: Set[str] = set()
        # Nodes already handled; a multi-output node (e.g. Split) is reachable
        # through several tensors but must be matched only once.
        visited_node: List[onnx.NodeProto] = []
        while outputs:
            output = outputs.pop(0)
            if output not in self.producer_map:
                continue
            node = self.producer_map[output]
            if node in visited_node:
                continue
            inputs = self.pattern_matching(node)
            # Queue the predecessors' tensors that have not been seen yet.
            outputs.extend([name for name in inputs if name not in visited])
            visited.update(inputs)
            visited_node.append(node)
        return self.build_optimized_model(self.model)

    def update_graph_fields(self, model):
        """Overwrite the graph's initializer/input/output/value_info fields
        with the current contents of the corresponding maps."""
        for field in ['initializer', 'input', 'output', 'value_info']:
            model.graph.ClearField(field)
            getattr(model.graph, field).extend(self.get_map_values(field))
        return model

    def build_optimized_model(self, model):
        """Rebuild ``model`` from the edited maps, validate it and return it.

        Raises:
            Exception: if ``optimizer_map`` holds something that is neither a
                NodeProto nor a list.
        """
        model = self.update_graph_fields(model)
        new_nodes = []
        for member in self.get_map_values('node'):
            if isinstance(member, onnx.NodeProto):
                new_nodes.append(member)
            elif isinstance(member, list):
                # A list holds the replacement nodes of a fused/converted
                # pattern; an empty list is a removed node.
                new_nodes.extend(member)
            else:
                raise Exception(member)
        model = utils.rebuild_model(model, new_nodes)
        check_model(model)
        return model

    def make_node(self, op_type, inputs, outputs, name=None, **attrs):
        """Create a node with ``onnx.helper.make_node``, dropping ``None`` inputs."""
        inputs = [x for x in inputs if x is not None]
        return make_node(op_type, inputs, outputs, name, **attrs)

    def make_tensor_value_info(self, name, elem_type, shape):
        """Thin wrapper around ``onnx.helper.make_tensor_value_info``."""
        return make_tensor_value_info(name, elem_type, shape)

    def make_initializer_from_array(self, array: np.ndarray, name: Optional[str] = None) -> onnx.TensorProto:
        """Convert a numpy array into a TensorProto named ``name``."""
        return numpy_helper.from_array(array, name)

    def make_int64_initializer(self, name, target_name):
        """Create a 1-D INT64 tensor holding the shape recorded for ``target_name``."""
        shape = self.get_value_info_shape(target_name)
        return make_tensor(name, onnx.TensorProto.INT64, (len(shape),), shape)

    def copy_value_info(self, name):
        """Return the value info registered under ``name``.

        Graph inputs are searched first, then intermediate value infos, then
        graph outputs (so a tensor that is only a graph output is found too,
        as it is by ``get_value_info_shape``).

        Raises:
            Exception: if ``name`` is in none of the maps.
        """
        for vi_map in (self.graph_input_map, self.value_info_map, self.graph_output_map):
            if name in vi_map:
                return vi_map[name]
        raise Exception('%s not found.' % name)

    def get_value_info_shape(self, value_info_name: str) -> List[int]:
        """Return the ``dim_value`` of every dimension of the named tensor.

        Raises:
            Exception: if the name is in none of the value-info maps.
        """
        for vi_map in (self.value_info_map, self.graph_output_map, self.graph_input_map):
            if value_info_name in vi_map:
                dims = vi_map[value_info_name].type.tensor_type.shape.dim
                return [dim.dim_value for dim in dims]
        raise Exception('%s not found.' % value_info_name)

    def get_map_values(self, field):
        """Return the de-duplicated values of the map backing graph ``field``."""
        if field in ('input', 'output'):
            field_map = 'graph_' + field + '_map'
        elif field == 'node':
            field_map = 'optimizer_map'
        else:
            field_map = field + '_map'
        return self.make_field_unique(getattr(self, field_map).values())

    def get_initializer_array(self, node_input):
        """Return the initializer ``node_input`` as a numpy array, or ``None``."""
        if node_input not in self.initializer_map:
            return None
        return numpy_helper.to_array(self.initializer_map[node_input])

    def get_init_node_input(self, node):
        """Return the last input of ``node`` that is an initializer, or ``None``."""
        init_node_input = None
        for node_input in node.input:
            if node_input in self.initializer_map:
                init_node_input = node_input
        return init_node_input

    def get_data_node_input(self, node):
        """Return the last input of ``node`` that is not an initializer, or ``None``."""
        data_node_input = None
        for node_input in node.input:
            if node_input not in self.initializer_map:
                data_node_input = node_input
        return data_node_input

    def make_field_unique(self, values):
        """Return ``values`` as a list with duplicates removed, order kept.

        Membership is tested by equality on a list, so unhashable values work.
        """
        seen = []
        for v in values:
            if v not in seen:
                seen.append(v)
        return seen

    def find_next_node(self, node: onnx.NodeProto) -> List[onnx.NodeProto]:
        """Return the live nodes in ``optimizer_map`` that consume an output of ``node``.

        Removed entries (empty lists) are skipped; for a list of replacement
        nodes only its first node is inspected.
        """
        next_nodes = []
        for v in self.optimizer_map.values():
            if not v:
                continue
            if isinstance(v, list):
                v = v[0]
            if any(output == v_input for output in node.output for v_input in v.input):
                next_nodes.append(v)
        return next_nodes

    def find_prev_node(self, node_input: str) -> onnx.NodeProto:
        """Return the node producing ``node_input``, or ``None`` if there is none."""
        if node_input not in self.producer_map:
            return None
        return self.producer_map[node_input]

    def is_op_type(self, op_type: str, target_op_types: List[str]):
        """Return True when ``op_type`` equals one of ``target_op_types``."""
        return any(op_type == target for target in target_op_types)

    def is_same_shape(self, input_1, input_2):
        """Return True when both tensors have identical recorded shapes."""
        return self.get_value_info_shape(input_1) == self.get_value_info_shape(input_2)

    def traverse_prev_node(self, producer_map_key: str, target_op_types: List[str]):
        """Return the producer of ``producer_map_key`` if its op type matches.

        Returns ``None`` when the tensor has no producer and ``False`` when the
        producer's op type is not in ``target_op_types``; both are falsy.
        """
        prev_node = self.find_prev_node(producer_map_key)
        if not prev_node:
            return None
        if not self.is_op_type(prev_node.op_type, target_op_types):
            return False
        return prev_node

    def update_single_optimizer_map(self, node: onnx.NodeProto, dest_name):
        """Store ``node`` under ``dest_name`` in ``optimizer_map``."""
        self.optimizer_map[dest_name] = node

    def update_multiple_optimizer_map(self, nodes: List[onnx.NodeProto], dest_name):
        """Store the list ``nodes`` under ``dest_name`` in ``optimizer_map``."""
        self.optimizer_map[dest_name] = nodes

    def update_single_value_info_map(self, value_info: onnx.ValueInfoProto):
        """Register ``value_info``, replacing a graph input of the same name if any."""
        if value_info.name in self.graph_input_map:
            self.graph_input_map[value_info.name] = value_info
        else:
            self.value_info_map[value_info.name] = value_info

    def update_multiple_value_info_map(self, value_infos: List[onnx.ValueInfoProto]):
        """Register every value info in ``value_infos``."""
        for vi in value_infos:
            self.update_single_value_info_map(vi)

    def update_single_initializer_map(self, initializer: onnx.TensorProto):
        """Register ``initializer`` and a matching graph-input value info."""
        self.initializer_map[initializer.name] = initializer
        self.graph_input_map[initializer.name] = make_tensor_value_info(
            initializer.name, initializer.data_type, numpy_helper.to_array(initializer).shape)

    def update_multiple_initializer_map(self, initializers: List[onnx.TensorProto]):
        """Register every truthy entry of ``initializers``."""
        for init in initializers:
            if not init:
                continue
            self.update_single_initializer_map(init)

    def pop_single_optimizer_map(self, node: onnx.NodeProto):
        """Mark ``node`` as removed by replacing its entry with an empty list."""
        self.optimizer_map[node.name] = []

    def pop_multiple_optimizer_map(self, nodes: List[onnx.NodeProto]):
        """Mark every node in ``nodes`` as removed."""
        for node in nodes:
            self.pop_single_optimizer_map(node)

    def pop_single_value_info_map(self, vi: onnx.ValueInfoProto):
        """Forget the value info named ``vi.name`` in every map that holds it.

        The previous code popped ``value_info_map`` again for names that were
        also graph inputs/outputs, which raised ``KeyError`` on the second pop
        and never touched the graph input/output maps.
        """
        name = vi.name
        self.value_info_map.pop(name, None)
        self.graph_output_map.pop(name, None)
        self.graph_input_map.pop(name, None)

    def pop_multiple_value_info_map(self, vis: List[onnx.ValueInfoProto]):
        """Forget every value info in ``vis``."""
        for vi in vis:
            self.pop_single_value_info_map(vi)

    def pop_single_initializer_map(self, init: onnx.TensorProto):
        """Remove ``init`` from the initializer map and the graph inputs."""
        self.initializer_map.pop(init.name)
        self.graph_input_map.pop(init.name)

    def pop_multiple_initializer_map(self, nodes: List[onnx.TensorProto]):
        """Remove every initializer in ``nodes``."""
        for node in nodes:
            self.pop_single_initializer_map(node)

    def bridge_disconnected_nodes(self, node_0: onnx.NodeProto, next_nodes: List[onnx.NodeProto], new_input):
        """
        For a graph changed, for example,
            before) prev --> node_1 --> node_0 --> next
            after)  prev --> node_1 --> (   ) -/-> next
        This function bridges node_1 and next as follows:
            prev --> node_1 --> next
        by assigning next.input[y] = node_1.output[x]
        """
        for next_node in next_nodes:
            for idx, next_node_input in enumerate(next_node.input):
                for node_output in node_0.output:
                    if node_output != next_node_input:
                        continue
                    next_node.input[idx] = new_input
                    self.update_single_optimizer_map(next_node, next_node.name)
        # If node_0 fed a graph output, that output now carries new_input's
        # value info.
        for node_output in node_0.output:
            for output in self.model.graph.output:
                if node_output != output.name:
                    continue
                self.graph_output_map[node_output] = self.copy_value_info(new_input)

    def transform_to_eliminate(self, nodes_to_remove: List[onnx.NodeProto], new_input):
        """Remove ``nodes_to_remove`` and rewire consumers of the last one to ``new_input``."""
        self.pop_multiple_optimizer_map(nodes_to_remove)
        last_node = nodes_to_remove[-1]
        self.bridge_disconnected_nodes(last_node, self.find_next_node(last_node), new_input)

    def transform_to_convert(self, nodes_to_remove: List[onnx.NodeProto],
                             nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                             inits_to_add: Optional[List[onnx.TensorProto]] = None,
                             vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
        """Alias of ``transform_to_fuse``."""
        self.transform_to_fuse(nodes_to_remove,
                               nodes_to_add,
                               inits_to_add,
                               vis_to_add)

    def transform_to_fuse(self, nodes_to_remove: List[onnx.NodeProto],
                          nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                          inits_to_add: Optional[List[onnx.TensorProto]] = None,
                          vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
        """Replace ``nodes_to_remove`` with ``nodes_to_add`` and register the
        given initializers and value infos.

        The new nodes are stored under the name of the first removed node.
        """
        self.pop_multiple_optimizer_map(nodes_to_remove)
        if nodes_to_add:
            self.update_multiple_optimizer_map(nodes_to_add, nodes_to_remove[0].name)
        if inits_to_add:
            self.update_multiple_initializer_map(inits_to_add)
        if vis_to_add:
            self.update_multiple_value_info_map(vis_to_add)

    def pattern_matching(self, node):
        """Hook for subclasses; ``transform`` iterates over its return value
        as tensor names to visit next."""
        raise NotImplementedError

    def pattern_matcher(self, node, pattern_to_match: List[str]):
        """Match a linear chain of op types that ends at ``node``.

        Each pattern entry may list alternatives separated by ``/``. Returns
        the matched nodes in pattern order, or ``None`` when the chain does
        not match.
        """
        decoded_pattern = [p.split('/') for p in pattern_to_match]
        # Matching runs backwards, from ``node`` towards its producers.
        decoded_pattern.reverse()
        op_type_0 = decoded_pattern.pop(0)
        if not self.is_op_type(node.op_type, op_type_0):
            return None
        matched_nodes = [node]
        while decoded_pattern:
            op_type_1 = decoded_pattern.pop(0)
            node_1 = None
            # TODO impl "multi-path search"
            for node_input in node.input:
                node_1 = self.traverse_prev_node(node_input, op_type_1)
                if node_1:
                    break
            if not node_1:
                return None
            node = node_1
            matched_nodes.append(node)
        matched_nodes.reverse()
        return matched_nodes

    def pattern_condition_checker(self, nodes_to_check):
        """Hook for subclasses."""
        raise NotImplementedError
Methods
def bridge_disconnected_nodes(self, node_0: onnx.onnx_ml_pb2.NodeProto, next_nodes: List[onnx.onnx_ml_pb2.NodeProto], new_input)
-
For a graph changed, for example,
before) prev --> node_1 --> node_0 --> next
after) prev --> node_1 --> ( ) -/-> next
This function bridges node_1 and next as follows:
prev --> node_1 --> next
by assigning next.input[y] = node_1.output[x]
Expand source code
def bridge_disconnected_nodes(self, node_0: onnx.NodeProto, next_nodes: List[onnx.NodeProto], new_input):
    """
    Reconnect the consumers of a removed node, e.g.
        before) prev --> node_1 --> node_0 --> next
        after)  prev --> node_1 --> (   ) -/-> next
    becomes
        prev --> node_1 --> next
    by assigning next.input[y] = node_1.output[x] (``new_input``).
    """
    for consumer in next_nodes:
        for position, consumed in enumerate(consumer.input):
            for produced in node_0.output:
                if produced == consumed:
                    consumer.input[position] = new_input
                    self.update_single_optimizer_map(consumer, consumer.name)
    # A graph output that node_0 used to produce takes over the value info
    # of the tensor that replaces it.
    for produced in node_0.output:
        for graph_output in self.model.graph.output:
            if graph_output.name == produced:
                self.graph_output_map[produced] = self.copy_value_info(new_input)
def build_optimized_model(self, model)
-
Expand source code
def build_optimized_model(self, model):
    """Rebuild ``model`` from the edited maps, run ``check_model`` and return it.

    Raises ``Exception`` when an ``optimizer_map`` value is neither a
    NodeProto nor a list.
    """
    model = self.update_graph_fields(model)
    flattened = []
    for entry in self.get_map_values('node'):
        if isinstance(entry, list):
            # Replacement nodes of a rewritten pattern ([] for a removed node).
            flattened.extend(entry)
            continue
        if not isinstance(entry, onnx.NodeProto):
            raise Exception(entry)
        flattened.append(entry)
    model = utils.rebuild_model(model, flattened)
    check_model(model)
    return model
def copy_value_info(self, name)
-
Expand source code
def copy_value_info(self, name):
    """Return the value info registered under ``name``.

    Graph inputs are searched first, then intermediate value infos, then
    graph outputs. The graph-output fallback mirrors
    ``get_value_info_shape``, which already consults all three maps; without
    it a tensor known only as a graph output raised "not found".

    Raises:
        Exception: if ``name`` is in none of the maps.
    """
    for vi_map in (self.graph_input_map, self.value_info_map, self.graph_output_map):
        if name in vi_map:
            return vi_map[name]
    raise Exception('%s not found.' % name)
def find_next_node(self, node: onnx.onnx_ml_pb2.NodeProto) ‑> List[onnx.onnx_ml_pb2.NodeProto]
-
Expand source code
def find_next_node(self, node: onnx.NodeProto) -> List[onnx.NodeProto]:
    """Collect the live nodes in ``optimizer_map`` that read an output of ``node``.

    Falsy entries (removed nodes) are skipped; when an entry is a list only
    its first node is examined.
    """
    consumers = []
    for entry in self.optimizer_map.values():
        if not entry:
            continue
        candidate = entry[0] if isinstance(entry, list) else entry
        if any(produced == consumed
               for produced in node.output
               for consumed in candidate.input):
            consumers.append(candidate)
    return consumers
def find_prev_node(self, node_input: str) ‑> onnx.onnx_ml_pb2.NodeProto
-
Expand source code
def find_prev_node(self, node_input: str) -> onnx.NodeProto:
    """Look up the node that produces ``node_input``; ``None`` when nothing does."""
    return self.producer_map.get(node_input)
def get_data_node_input(self, node)
-
Expand source code
def get_data_node_input(self, node):
    """Return the last input of ``node`` that is not an initializer, else ``None``."""
    data_inputs = [name for name in node.input if name not in self.initializer_map]
    return data_inputs[-1] if data_inputs else None
def get_init_node_input(self, node)
-
Expand source code
def get_init_node_input(self, node):
    """Return the last input of ``node`` that is an initializer, else ``None``."""
    init_inputs = [name for name in node.input if name in self.initializer_map]
    return init_inputs[-1] if init_inputs else None
def get_initializer_array(self, node_input)
-
Expand source code
def get_initializer_array(self, node_input):
    """Return the initializer named ``node_input`` as a numpy array.

    Returns ``None`` when ``node_input`` is not an initializer.
    """
    if node_input in self.initializer_map:
        return numpy_helper.to_array(self.initializer_map[node_input])
    return None
def get_map_values(self, field)
-
Expand source code
def get_map_values(self, field):
    """Return the de-duplicated values of the map that backs graph ``field``.

    ``'input'``/``'output'`` read ``graph_<field>_map``, ``'node'`` reads
    ``optimizer_map`` and anything else reads ``<field>_map``.
    """
    if field == 'node':
        attribute = 'optimizer_map'
    elif field in ('input', 'output'):
        attribute = 'graph_' + field + '_map'
    else:
        attribute = field + '_map'
    backing_map = getattr(self, attribute)
    return self.make_field_unique(backing_map.values())
def get_value_info_shape(self, value_info_name: str) ‑> List[int]
-
Expand source code
def get_value_info_shape(self, value_info_name: str) -> List[int]:
    """Return the ``dim_value`` of each dimension recorded for the named tensor.

    Intermediate value infos are consulted first, then graph outputs, then
    graph inputs. Raises ``Exception`` when the name is in none of them.
    """
    lookup_order = (self.value_info_map, self.graph_output_map, self.graph_input_map)
    for vi_map in lookup_order:
        if value_info_name in vi_map:
            dims = vi_map[value_info_name].type.tensor_type.shape.dim
            return [dim.dim_value for dim in dims]
    raise Exception('%s not found.' % value_info_name)
def is_op_type(self, op_type: str, target_op_types: List[str])
-
Expand source code
def is_op_type(self, op_type: str, target_op_types: List[str]):
    """Tell whether ``op_type`` equals any entry of ``target_op_types``."""
    for candidate in target_op_types:
        if op_type == candidate:
            return True
    return False
def is_same_shape(self, input_1, input_2)
-
Expand source code
def is_same_shape(self, input_1, input_2):
    """Tell whether the two named tensors have identical recorded shapes."""
    first_shape = self.get_value_info_shape(input_1)
    second_shape = self.get_value_info_shape(input_2)
    return not (first_shape != second_shape)
def make_field_unique(self, values)
-
Expand source code
def make_field_unique(self, values):
    """Return the items of ``values`` as a list without duplicates.

    First occurrences keep their order. Duplicates are detected with ``in``
    on a list, i.e. by equality, so the items need not be hashable.
    """
    unique_items = []
    for item in values:
        if item in unique_items:
            continue
        unique_items.append(item)
    return unique_items
def make_initializer_from_array(self, array: <built-in function array>, name: Union[str, NoneType] = None) ‑> onnx.onnx_ml_pb2.TensorProto
-
Expand source code
def make_initializer_from_array(self, array: np.ndarray, name: Optional[str] = None) -> onnx.TensorProto:
    """Convert a numpy array into a TensorProto via ``numpy_helper.from_array``.

    The parameter is annotated ``np.ndarray``: ``np.array`` is the factory
    function, not a type.

    Args:
        array: the values to store.
        name: name given to the resulting tensor.
    """
    return numpy_helper.from_array(array, name)
def make_int64_initializer(self, name, target_name)
-
Expand source code
def make_int64_initializer(self, name, target_name):
    """Create a 1-D INT64 tensor whose values are the shape of ``target_name``.

    The shape is looked up once and reused for both the tensor's dims and
    its values (it was previously fetched twice).

    Args:
        name: name of the tensor to create.
        target_name: tensor whose recorded shape becomes the tensor's data.
    """
    shape = self.get_value_info_shape(target_name)
    return make_tensor(name, onnx.TensorProto.INT64, (len(shape),), shape)
def make_node(self, op_type, inputs, outputs, name=None, **attrs)
-
Expand source code
def make_node(self, op_type, inputs, outputs, name=None, **attrs):
    """Create a node with ``onnx.helper.make_node`` after dropping ``None`` inputs."""
    present_inputs = list(filter(lambda item: item is not None, inputs))
    return make_node(op_type, present_inputs, outputs, name, **attrs)
def make_tensor_value_info(self, name, elem_type, shape)
-
Expand source code
def make_tensor_value_info(self, name, elem_type, shape):
    """Delegate to ``onnx.helper.make_tensor_value_info`` with the same arguments."""
    value_info = make_tensor_value_info(name, elem_type, shape)
    return value_info
def pattern_condition_checker(self, nodes_to_check)
-
Expand source code
def pattern_condition_checker(self, nodes_to_check):
    """Hook left for subclasses; the base implementation always raises."""
    raise NotImplementedError()
def pattern_matcher(self, node, pattern_to_match: List[str])
-
Expand source code
def pattern_matcher(self, node, pattern_to_match: List[str]):
    """Match a linear chain of op types ending at ``node``.

    Every pattern entry may name several acceptable op types separated by
    ``/``. The chain is followed backwards through producers; the matched
    nodes are returned in pattern order, or ``None`` if matching fails.
    """
    stages = [entry.split('/') for entry in pattern_to_match][::-1]
    final_stage = stages.pop(0)
    if not self.is_op_type(node.op_type, final_stage):
        return None
    chain = [node]
    current = node
    for stage in stages:
        producer = None
        # TODO impl "multi-path search"
        for tensor_name in current.input:
            producer = self.traverse_prev_node(tensor_name, stage)
            if producer:
                break
        if not producer:
            return None
        current = producer
        chain.append(current)
    return chain[::-1]
def pattern_matching(self, node)
-
Expand source code
def pattern_matching(self, node):
    """Hook left for subclasses; the base implementation always raises."""
    raise NotImplementedError()
def pop_multiple_initializer_map(self, nodes: List[onnx.onnx_ml_pb2.TensorProto])
-
Expand source code
def pop_multiple_initializer_map(self, nodes: List[onnx.TensorProto]):
    """Call ``pop_single_initializer_map`` for each entry of ``nodes``, in order."""
    for initializer in nodes:
        self.pop_single_initializer_map(initializer)
def pop_multiple_optimizer_map(self, nodes: List[onnx.onnx_ml_pb2.NodeProto])
-
Expand source code
def pop_multiple_optimizer_map(self, nodes: List[onnx.NodeProto]):
    """Call ``pop_single_optimizer_map`` for each entry of ``nodes``, in order."""
    for removed in nodes:
        self.pop_single_optimizer_map(removed)
def pop_multiple_value_info_map(self, vis: List[onnx.onnx_ml_pb2.ValueInfoProto])
-
Expand source code
def pop_multiple_value_info_map(self, vis: List[onnx.ValueInfoProto]):
    """Call ``pop_single_value_info_map`` for each entry of ``vis``, in order."""
    for value_info in vis:
        self.pop_single_value_info_map(value_info)
def pop_single_initializer_map(self, init: onnx.onnx_ml_pb2.TensorProto)
-
Expand source code
def pop_single_initializer_map(self, init: onnx.TensorProto):
    """Delete ``init.name`` from ``initializer_map`` and then ``graph_input_map``.

    Raises ``KeyError`` if the name is missing from either map.
    """
    del self.initializer_map[init.name]
    del self.graph_input_map[init.name]
def pop_single_optimizer_map(self, node: onnx.onnx_ml_pb2.NodeProto)
-
Expand source code
def pop_single_optimizer_map(self, node: onnx.NodeProto):
    """Replace the ``optimizer_map`` entry for ``node.name`` with an empty list.

    The key is kept (not deleted); the empty list marks the node as removed.
    """
    removed_marker = []
    self.optimizer_map[node.name] = removed_marker
def pop_single_value_info_map(self, vi: onnx.onnx_ml_pb2.NodeProto)
-
Expand source code
def pop_single_value_info_map(self, vi: onnx.ValueInfoProto):
    """Forget the value info named ``vi.name`` in every map that holds it.

    The name is removed from ``value_info_map``, ``graph_output_map`` and
    ``graph_input_map`` where present. The earlier code popped
    ``value_info_map`` a second time when the name was also a graph input or
    output, which raised ``KeyError`` and left those maps untouched.

    Args:
        vi: object whose ``name`` identifies the value info to drop (the
            parameter was previously annotated as a NodeProto).
    """
    name = vi.name
    self.value_info_map.pop(name, None)
    self.graph_output_map.pop(name, None)
    self.graph_input_map.pop(name, None)
def transform(self)
-
Expand source code
def transform(self):
    """Apply ``pattern_matching`` once per producing node, then rebuild the model.

    Starting from the graph outputs, each tensor's producer is handed to
    ``pattern_matching``; the tensor names it returns are queued next.
    """
    pending = list(self.graph_output_map.keys())
    # Names already queued, so cyclic connections are not traversed again.
    seen_tensors: Set[str] = set()
    # Nodes already handled, so a multi-output node (e.g. Split) reached
    # through several tensors is processed only once.
    seen_nodes: List[onnx.NodeProto] = []
    while pending:
        tensor = pending.pop(0)
        if tensor not in self.producer_map:
            continue
        producer = self.producer_map[tensor]
        if producer in seen_nodes:
            continue
        predecessors = self.pattern_matching(producer)
        pending.extend([name for name in predecessors if name not in seen_tensors])
        seen_tensors.update(predecessors)
        seen_nodes.append(producer)
    return self.build_optimized_model(self.model)
def transform_to_convert(self, nodes_to_remove: List[onnx.onnx_ml_pb2.NodeProto], nodes_to_add: Union[List[onnx.onnx_ml_pb2.NodeProto], NoneType] = None, inits_to_add: Union[List[onnx.onnx_ml_pb2.TensorProto], NoneType] = None, vis_to_add: Union[List[onnx.onnx_ml_pb2.ValueInfoProto], NoneType] = None)
-
Expand source code
def transform_to_convert(self, nodes_to_remove: List[onnx.NodeProto],
                         nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                         inits_to_add: Optional[List[onnx.TensorProto]] = None,
                         vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
    """Forward all arguments unchanged to ``transform_to_fuse``."""
    forwarded = (nodes_to_remove, nodes_to_add, inits_to_add, vis_to_add)
    self.transform_to_fuse(*forwarded)
def transform_to_eliminate(self, nodes_to_remove: List[onnx.onnx_ml_pb2.NodeProto], new_input)
-
Expand source code
def transform_to_eliminate(self, nodes_to_remove: List[onnx.NodeProto], new_input):
    """Remove ``nodes_to_remove`` and rewire consumers of the last one to ``new_input``."""
    self.pop_multiple_optimizer_map(nodes_to_remove)
    last_removed = nodes_to_remove[-1]
    consumers = self.find_next_node(last_removed)
    self.bridge_disconnected_nodes(last_removed, consumers, new_input)
def transform_to_fuse(self, nodes_to_remove: List[onnx.onnx_ml_pb2.NodeProto], nodes_to_add: Union[List[onnx.onnx_ml_pb2.NodeProto], NoneType] = None, inits_to_add: Union[List[onnx.onnx_ml_pb2.TensorProto], NoneType] = None, vis_to_add: Union[List[onnx.onnx_ml_pb2.ValueInfoProto], NoneType] = None)
-
Expand source code
def transform_to_fuse(self, nodes_to_remove: List[onnx.NodeProto],
                      nodes_to_add: Optional[List[onnx.NodeProto]] = None,
                      inits_to_add: Optional[List[onnx.TensorProto]] = None,
                      vis_to_add: Optional[List[onnx.ValueInfoProto]] = None):
    """Replace ``nodes_to_remove`` by ``nodes_to_add`` and register extras.

    The removed nodes are marked first; the new nodes are then stored under
    the name of the first removed node. Empty or ``None`` arguments are
    skipped.
    """
    self.pop_multiple_optimizer_map(nodes_to_remove)
    if nodes_to_add:
        anchor_name = nodes_to_remove[0].name
        self.update_multiple_optimizer_map(nodes_to_add, anchor_name)
    if inits_to_add:
        self.update_multiple_initializer_map(inits_to_add)
    if vis_to_add:
        self.update_multiple_value_info_map(vis_to_add)
def traverse_prev_node(self, producer_map_key: str, target_op_types: List[str])
-
Expand source code
def traverse_prev_node(self, producer_map_key: str, target_op_types: List[str]):
    """Return the producer of ``producer_map_key`` when its op type matches.

    Returns ``None`` if the tensor has no producer, and ``False`` if the
    producer exists but its op type is not in ``target_op_types``.
    """
    producer = self.find_prev_node(producer_map_key)
    if not producer:
        return None
    if self.is_op_type(producer.op_type, target_op_types):
        return producer
    return False
def update_graph_fields(self, model)
-
Expand source code
def update_graph_fields(self, model):
    """Overwrite four graph fields with the current contents of their maps.

    ``initializer``, ``input``, ``output`` and ``value_info`` are each
    cleared and refilled from ``get_map_values``. Returns the same ``model``.
    """
    for field_name in ('initializer', 'input', 'output', 'value_info'):
        model.graph.ClearField(field_name)
        container = getattr(model.graph, field_name)
        container.extend(self.get_map_values(field_name))
    return model
def update_multiple_initializer_map(self, initializers: List[onnx.onnx_ml_pb2.TensorProto])
-
Expand source code
def update_multiple_initializer_map(self, initializers: List[onnx.TensorProto]):
    """Register every truthy entry of ``initializers``; falsy entries are skipped."""
    for initializer in initializers:
        if initializer:
            self.update_single_initializer_map(initializer)
def update_multiple_optimizer_map(self, nodes: List[onnx.onnx_ml_pb2.NodeProto], dest_name)
-
Expand source code
def update_multiple_optimizer_map(self, nodes: List[onnx.NodeProto], dest_name):
    """Store the list ``nodes`` itself (not a copy) under ``dest_name``."""
    self.optimizer_map.__setitem__(dest_name, nodes)
def update_multiple_value_info_map(self, value_infos: List[onnx.onnx_ml_pb2.ValueInfoProto])
-
Expand source code
def update_multiple_value_info_map(self, value_infos: List[onnx.ValueInfoProto]):
    """Call ``update_single_value_info_map`` for each entry, in order."""
    for value_info in value_infos:
        self.update_single_value_info_map(value_info)
def update_single_initializer_map(self, initializer: onnx.onnx_ml_pb2.TensorProto)
-
Expand source code
def update_single_initializer_map(self, initializer: onnx.TensorProto):
    """Register ``initializer`` and a graph-input value info of the same name.

    The value info is built from the initializer's ``data_type`` and the
    shape of its numpy representation.
    """
    tensor_name = initializer.name
    self.initializer_map[tensor_name] = initializer
    shape = numpy_helper.to_array(initializer).shape
    self.graph_input_map[tensor_name] = make_tensor_value_info(
        tensor_name, initializer.data_type, shape)
def update_single_optimizer_map(self, node: onnx.onnx_ml_pb2.NodeProto, dest_name)
-
Expand source code
def update_single_optimizer_map(self, node: onnx.NodeProto, dest_name):
    """Store ``node`` under ``dest_name``, replacing any existing entry."""
    self.optimizer_map.__setitem__(dest_name, node)
def update_single_value_info_map(self, value_info: onnx.onnx_ml_pb2.ValueInfoProto)
-
Expand source code
def update_single_value_info_map(self, value_info: onnx.ValueInfoProto):
    """Register ``value_info`` under its name.

    A name that is already a graph input is replaced in ``graph_input_map``;
    any other name is stored in ``value_info_map``.
    """
    key = value_info.name
    destination = self.graph_input_map if key in self.graph_input_map else self.value_info_map
    destination[key] = value_info