Module furiosa.quantizer.frontend.onnx.transformer.deprecated.fuse_softmax
Expand source code
import onnx
import numpy as np
from onnx.helper import make_node, make_tensor_value_info
from furiosa_sdk_quantizer.interfaces.transformer import Transformer
from furiosa_sdk_quantizer.frontend.onnx.transformer import utils
from furiosa_sdk_quantizer.frontend.onnx.utils.check_model import check_model
class FuseSoftmax(Transformer):
    """
    Fuse a decomposed softmax sub-graph into a single ONNX Softmax node.

    https://github.com/onnx/onnx/blob/master/docs/Operators.md#softmax
        from: Exp -> ReduceSum -> Div
        to: Transpose -> Softmax -> Transpose
    The two Transpose nodes are only emitted when the reduced axis is not
    already the last axis; otherwise a lone Softmax replaces the pattern.

    Assume NCHW Input
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return `model` with every matching Exp/ReduceSum/Div trio fused.

        A Div node is fused only when exactly one of its inputs is produced
        by an Exp node, exactly one is produced by a ReduceSum node, that
        ReduceSum consumes the Exp output, and the ReduceSum carries an
        `axes` attribute with a single axis. Any other Div is left untouched.

        The shape of the Exp output is read from the graph's value_info /
        input / output entries, so a KeyError is raised if that entry is
        missing for a matched pattern.
        """
        # Map every produced tensor name to its producing node.
        nodes_by_output_name = {
            output: node for node in model.graph.node for output in node.output
        }
        initializer = {init.name: init for init in model.graph.initializer}
        value_info = {
            vi.name: vi
            for vi in list(model.graph.value_info)
            + list(model.graph.input)
            + list(model.graph.output)
        }
        post_fix = '_transposed'

        def _is_input_op_type(node_input, op_type):
            # Initializers and graph inputs have no producing node, so they
            # can never match an op type.
            if node_input in initializer:
                return False
            producer = nodes_by_output_name.get(node_input)
            return producer is not None and producer.op_type == op_type

        optimized_nodes = []
        removed_nodes = []
        for node in model.graph.node:
            if node.op_type != 'Div':
                optimized_nodes.append(node)
                continue

            # NOTE(review): the original code accepts Exp and ReduceSum in
            # either input position of Div; that is kept here, although a
            # real softmax has Exp as the numerator -- TODO confirm intent.
            idx_exp = [i for i, name in enumerate(node.input)
                       if _is_input_op_type(name, 'Exp')]
            idx_rsum = [i for i, name in enumerate(node.input)
                        if _is_input_op_type(name, 'ReduceSum')]

            # Expect one of the inputs is Exp and the other is ReduceSum.
            # Both conditions must hold, hence `or` when rejecting.
            if len(idx_exp) != 1 or len(idx_rsum) != 1:
                optimized_nodes.append(node)
                continue

            exp_node = nodes_by_output_name[node.input[idx_exp[0]]]
            rsum_node = nodes_by_output_name[node.input[idx_rsum[0]]]

            # Look the axes up by name: attribute order is not guaranteed,
            # so attribute[0] could just as well be `keepdims`.
            axes = next(
                (list(attr.ints) for attr in rsum_node.attribute if attr.name == 'axes'),
                [],
            )
            # The denominator must sum the very same Exp output over exactly
            # one axis; otherwise this is not a softmax and is left alone.
            if (
                len(rsum_node.input) == 0
                or rsum_node.input[0] != exp_node.output[0]
                or len(axes) != 1
            ):
                optimized_nodes.append(node)
                continue

            # NOTE(review): Exp/ReduceSum are dropped even if some other node
            # also consumes their outputs -- presumably not the case for the
            # targeted models; verify against callers.
            removed_nodes.extend([node, exp_node, rsum_node])

            exp_shape = [
                dim.dim_value
                for dim in value_info[exp_node.output[0]].type.tensor_type.shape.dim
            ]
            length = len(exp_shape)

            axis = axes[0]
            if axis < 0:
                axis += length

            if axis == length - 1:
                # Already reducing over the last axis: a single Softmax suffices.
                softmax_node = make_node('Softmax', inputs=[exp_node.input[0]],
                                         outputs=[node.output[0]], axis=length - 1)
                optimized_nodes.append(softmax_node)
                continue

            # Swap the reduced axis with the last one. The permutation is its
            # own inverse, so the same `perm` restores the original layout.
            perm = list(range(length))
            perm[axis], perm[-1] = perm[-1], perm[axis]

            transposed_name = exp_node.output[0] + post_fix
            softmax_name = exp_node.output[0] + '_softmax'
            trans_node_1 = make_node('Transpose', inputs=[exp_node.input[0]],
                                     outputs=[transposed_name], perm=perm)
            softmax_node = make_node('Softmax', inputs=[transposed_name],
                                     outputs=[softmax_name], axis=length - 1)
            trans_node_2 = make_node('Transpose', inputs=[softmax_name],
                                     outputs=[node.output[0]], perm=perm)
            optimized_nodes.extend([trans_node_1, softmax_node, trans_node_2])

            # Register shapes for the two new intermediate tensors.
            perm1_shape = [exp_shape[i] for i in perm]
            model.graph.value_info.extend([
                make_tensor_value_info(name=softmax_name,
                                       elem_type=onnx.TensorProto.FLOAT,
                                       shape=perm1_shape),
                make_tensor_value_info(name=transposed_name,
                                       elem_type=onnx.TensorProto.FLOAT,
                                       shape=perm1_shape),
            ])

        # remove duplicate node(s) in optimized nodes
        seen = []
        for op_node in optimized_nodes:
            if op_node not in seen:
                seen.append(op_node)

        new_nodes = [op_node for op_node in seen if op_node not in removed_nodes]
        model = utils.rebuild_model(model, new_nodes)
        check_model(model)

        return model
Classes
class FuseSoftmax (*args, **kwds)
-
https://github.com/onnx/onnx/blob/master/docs/Operators.md#softmax
from: Exp -> ReduceSum -> Div
to: Transpose -> Softmax -> Transpose
Assume NCHW Input
Expand source code
class FuseSoftmax(Transformer):
    """
    Fuse a decomposed softmax sub-graph into a single ONNX Softmax node.

    https://github.com/onnx/onnx/blob/master/docs/Operators.md#softmax
        from: Exp -> ReduceSum -> Div
        to: Transpose -> Softmax -> Transpose
    The two Transpose nodes are only emitted when the reduced axis is not
    already the last axis; otherwise a lone Softmax replaces the pattern.

    Assume NCHW Input
    """

    def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Return `model` with every matching Exp/ReduceSum/Div trio fused.

        A Div node is fused only when exactly one of its inputs is produced
        by an Exp node, exactly one is produced by a ReduceSum node, that
        ReduceSum consumes the Exp output, and the ReduceSum carries an
        `axes` attribute with a single axis. Any other Div is left untouched.

        The shape of the Exp output is read from the graph's value_info /
        input / output entries, so a KeyError is raised if that entry is
        missing for a matched pattern.
        """
        # Map every produced tensor name to its producing node.
        nodes_by_output_name = {
            output: node for node in model.graph.node for output in node.output
        }
        initializer = {init.name: init for init in model.graph.initializer}
        value_info = {
            vi.name: vi
            for vi in list(model.graph.value_info)
            + list(model.graph.input)
            + list(model.graph.output)
        }
        post_fix = '_transposed'

        def _is_input_op_type(node_input, op_type):
            # Initializers and graph inputs have no producing node, so they
            # can never match an op type.
            if node_input in initializer:
                return False
            producer = nodes_by_output_name.get(node_input)
            return producer is not None and producer.op_type == op_type

        optimized_nodes = []
        removed_nodes = []
        for node in model.graph.node:
            if node.op_type != 'Div':
                optimized_nodes.append(node)
                continue

            # NOTE(review): the original code accepts Exp and ReduceSum in
            # either input position of Div; that is kept here, although a
            # real softmax has Exp as the numerator -- TODO confirm intent.
            idx_exp = [i for i, name in enumerate(node.input)
                       if _is_input_op_type(name, 'Exp')]
            idx_rsum = [i for i, name in enumerate(node.input)
                        if _is_input_op_type(name, 'ReduceSum')]

            # Expect one of the inputs is Exp and the other is ReduceSum.
            # Both conditions must hold, hence `or` when rejecting.
            if len(idx_exp) != 1 or len(idx_rsum) != 1:
                optimized_nodes.append(node)
                continue

            exp_node = nodes_by_output_name[node.input[idx_exp[0]]]
            rsum_node = nodes_by_output_name[node.input[idx_rsum[0]]]

            # Look the axes up by name: attribute order is not guaranteed,
            # so attribute[0] could just as well be `keepdims`.
            axes = next(
                (list(attr.ints) for attr in rsum_node.attribute if attr.name == 'axes'),
                [],
            )
            # The denominator must sum the very same Exp output over exactly
            # one axis; otherwise this is not a softmax and is left alone.
            if (
                len(rsum_node.input) == 0
                or rsum_node.input[0] != exp_node.output[0]
                or len(axes) != 1
            ):
                optimized_nodes.append(node)
                continue

            # NOTE(review): Exp/ReduceSum are dropped even if some other node
            # also consumes their outputs -- presumably not the case for the
            # targeted models; verify against callers.
            removed_nodes.extend([node, exp_node, rsum_node])

            exp_shape = [
                dim.dim_value
                for dim in value_info[exp_node.output[0]].type.tensor_type.shape.dim
            ]
            length = len(exp_shape)

            axis = axes[0]
            if axis < 0:
                axis += length

            if axis == length - 1:
                # Already reducing over the last axis: a single Softmax suffices.
                softmax_node = make_node('Softmax', inputs=[exp_node.input[0]],
                                         outputs=[node.output[0]], axis=length - 1)
                optimized_nodes.append(softmax_node)
                continue

            # Swap the reduced axis with the last one. The permutation is its
            # own inverse, so the same `perm` restores the original layout.
            perm = list(range(length))
            perm[axis], perm[-1] = perm[-1], perm[axis]

            transposed_name = exp_node.output[0] + post_fix
            softmax_name = exp_node.output[0] + '_softmax'
            trans_node_1 = make_node('Transpose', inputs=[exp_node.input[0]],
                                     outputs=[transposed_name], perm=perm)
            softmax_node = make_node('Softmax', inputs=[transposed_name],
                                     outputs=[softmax_name], axis=length - 1)
            trans_node_2 = make_node('Transpose', inputs=[softmax_name],
                                     outputs=[node.output[0]], perm=perm)
            optimized_nodes.extend([trans_node_1, softmax_node, trans_node_2])

            # Register shapes for the two new intermediate tensors.
            perm1_shape = [exp_shape[i] for i in perm]
            model.graph.value_info.extend([
                make_tensor_value_info(name=softmax_name,
                                       elem_type=onnx.TensorProto.FLOAT,
                                       shape=perm1_shape),
                make_tensor_value_info(name=transposed_name,
                                       elem_type=onnx.TensorProto.FLOAT,
                                       shape=perm1_shape),
            ])

        # remove duplicate node(s) in optimized nodes
        seen = []
        for op_node in optimized_nodes:
            if op_node not in seen:
                seen.append(op_node)

        new_nodes = [op_node for op_node in seen if op_node not in removed_nodes]
        model = utils.rebuild_model(model, new_nodes)
        check_model(model)

        return model
Ancestors
- furiosa_sdk_quantizer.interfaces.transformer.Transformer
- typing.Generic
Methods
def transform(self, model: onnx.onnx_ml_pb2.ModelProto) ‑> onnx.onnx_ml_pb2.ModelProto
-
Expand source code
def transform(self, model: onnx.ModelProto) -> onnx.ModelProto:
    """Return `model` with every Exp -> ReduceSum -> Div trio fused into Softmax.

    A Div node is fused only when exactly one of its inputs is produced by an
    Exp node, exactly one is produced by a ReduceSum node, that ReduceSum
    consumes the Exp output, and the ReduceSum carries an `axes` attribute
    with a single axis. Any other Div is left untouched. When the reduced
    axis is not the last one, the Softmax is wrapped in two Transpose nodes.

    The shape of the Exp output is read from the graph's value_info / input /
    output entries, so a KeyError is raised if that entry is missing for a
    matched pattern.
    """
    # Map every produced tensor name to its producing node.
    nodes_by_output_name = {
        output: node for node in model.graph.node for output in node.output
    }
    initializer = {init.name: init for init in model.graph.initializer}
    value_info = {
        vi.name: vi
        for vi in list(model.graph.value_info)
        + list(model.graph.input)
        + list(model.graph.output)
    }
    post_fix = '_transposed'

    def _is_input_op_type(node_input, op_type):
        # Initializers and graph inputs have no producing node, so they can
        # never match an op type.
        if node_input in initializer:
            return False
        producer = nodes_by_output_name.get(node_input)
        return producer is not None and producer.op_type == op_type

    optimized_nodes = []
    removed_nodes = []
    for node in model.graph.node:
        if node.op_type != 'Div':
            optimized_nodes.append(node)
            continue

        # NOTE(review): the original code accepts Exp and ReduceSum in either
        # input position of Div; that is kept here, although a real softmax
        # has Exp as the numerator -- TODO confirm intent.
        idx_exp = [i for i, name in enumerate(node.input)
                   if _is_input_op_type(name, 'Exp')]
        idx_rsum = [i for i, name in enumerate(node.input)
                    if _is_input_op_type(name, 'ReduceSum')]

        # Expect one of the inputs is Exp and the other is ReduceSum.
        # Both conditions must hold, hence `or` when rejecting.
        if len(idx_exp) != 1 or len(idx_rsum) != 1:
            optimized_nodes.append(node)
            continue

        exp_node = nodes_by_output_name[node.input[idx_exp[0]]]
        rsum_node = nodes_by_output_name[node.input[idx_rsum[0]]]

        # Look the axes up by name: attribute order is not guaranteed, so
        # attribute[0] could just as well be `keepdims`.
        axes = next(
            (list(attr.ints) for attr in rsum_node.attribute if attr.name == 'axes'),
            [],
        )
        # The denominator must sum the very same Exp output over exactly one
        # axis; otherwise this is not a softmax and is left alone.
        if (
            len(rsum_node.input) == 0
            or rsum_node.input[0] != exp_node.output[0]
            or len(axes) != 1
        ):
            optimized_nodes.append(node)
            continue

        # NOTE(review): Exp/ReduceSum are dropped even if some other node also
        # consumes their outputs -- presumably not the case for the targeted
        # models; verify against callers.
        removed_nodes.extend([node, exp_node, rsum_node])

        exp_shape = [
            dim.dim_value
            for dim in value_info[exp_node.output[0]].type.tensor_type.shape.dim
        ]
        length = len(exp_shape)

        axis = axes[0]
        if axis < 0:
            axis += length

        if axis == length - 1:
            # Already reducing over the last axis: a single Softmax suffices.
            softmax_node = make_node('Softmax', inputs=[exp_node.input[0]],
                                     outputs=[node.output[0]], axis=length - 1)
            optimized_nodes.append(softmax_node)
            continue

        # Swap the reduced axis with the last one. The permutation is its own
        # inverse, so the same `perm` restores the original layout.
        perm = list(range(length))
        perm[axis], perm[-1] = perm[-1], perm[axis]

        transposed_name = exp_node.output[0] + post_fix
        softmax_name = exp_node.output[0] + '_softmax'
        trans_node_1 = make_node('Transpose', inputs=[exp_node.input[0]],
                                 outputs=[transposed_name], perm=perm)
        softmax_node = make_node('Softmax', inputs=[transposed_name],
                                 outputs=[softmax_name], axis=length - 1)
        trans_node_2 = make_node('Transpose', inputs=[softmax_name],
                                 outputs=[node.output[0]], perm=perm)
        optimized_nodes.extend([trans_node_1, softmax_node, trans_node_2])

        # Register shapes for the two new intermediate tensors.
        perm1_shape = [exp_shape[i] for i in perm]
        model.graph.value_info.extend([
            make_tensor_value_info(name=softmax_name,
                                   elem_type=onnx.TensorProto.FLOAT,
                                   shape=perm1_shape),
            make_tensor_value_info(name=transposed_name,
                                   elem_type=onnx.TensorProto.FLOAT,
                                   shape=perm1_shape),
        ])

    # remove duplicate node(s) in optimized nodes
    seen = []
    for op_node in optimized_nodes:
        if op_node not in seen:
            seen.append(op_node)

    new_nodes = [op_node for op_node in seen if op_node not in removed_nodes]
    model = utils.rebuild_model(model, new_nodes)
    check_model(model)

    return model