Module furiosa.quantizer.furiosa_sdk_quantizer.frontend.onnx.utils.check_model

Expand source code
import onnx
from onnx import checker
import onnxruntime as ort


def check_model(model: onnx.ModelProto, check_runnable: bool = True) -> None:
    """
    Check that the model is well-defined and, optionally, loadable by onnxruntime.

    The model is first validated with ``onnx.checker.check_model``. A
    ``ValidationError`` is tolerated only when its leading message (the text
    before the first "==>") exactly matches one of the known-acceptable
    messages listed below; any other validation error propagates unchanged.

    Args:
        model: ONNX model to validate.
        check_runnable: When True (the default), additionally construct an
            ``onnxruntime.InferenceSession`` from the serialized model.

    Raises:
        onnx.checker.ValidationError: If the checker reports an error whose
            leading message is not in the acceptable list.

    Note:
        Exceptions raised by onnxruntime while creating the session are not
        caught here. The runnable check also changes onnxruntime's default
        logger severity to 3 for the whole process.
    """
    # TODO After collecting possible errors,
    #  pass through only if all error messages are "No opset import for domain 'com.microsoft'".
    #  The code below only inspects the first error encountered.
    acceptable_error_msg = ["No opset import for domain 'com.microsoft'",
                            'No Op registered for LayerNormalization with domain_version of 12']
    try:
        checker.check_model(model)
    except checker.ValidationError as e:
        # The checker appends context after "==>"; compare only the leading message.
        if str(e).split("==>")[0].rstrip() not in acceptable_error_msg:
            # Re-raise the error that was actually caught instead of running
            # the checker a second time: this keeps the original traceback and
            # avoids validating the model twice.
            raise

    if check_runnable:
        # Severity 3 is passed so that lower-severity onnxruntime log output is
        # suppressed (see onnxruntime's logging levels).
        ort.set_default_logger_severity(3)
        ort.InferenceSession(model.SerializeToString())

Functions

def check_model(model: onnx.onnx_ml_pb2.ModelProto, check_runnable: bool = True) ‑> NoneType

Check that the model is well-defined and executable on onnxruntime.

Expand source code
def check_model(model: onnx.ModelProto, check_runnable: bool = True) -> None:
    """
    Check that the model is well-defined and, optionally, loadable by onnxruntime.

    The model is first validated with ``onnx.checker.check_model``. A
    ``ValidationError`` is tolerated only when its leading message (the text
    before the first "==>") exactly matches one of the known-acceptable
    messages listed below; any other validation error propagates unchanged.

    Args:
        model: ONNX model to validate.
        check_runnable: When True (the default), additionally construct an
            ``onnxruntime.InferenceSession`` from the serialized model.

    Raises:
        onnx.checker.ValidationError: If the checker reports an error whose
            leading message is not in the acceptable list.

    Note:
        Exceptions raised by onnxruntime while creating the session are not
        caught here. The runnable check also changes onnxruntime's default
        logger severity to 3 for the whole process.
    """
    # TODO After collecting possible errors,
    #  pass through only if all error messages are "No opset import for domain 'com.microsoft'".
    #  The code below only inspects the first error encountered.
    acceptable_error_msg = ["No opset import for domain 'com.microsoft'",
                            'No Op registered for LayerNormalization with domain_version of 12']
    try:
        checker.check_model(model)
    except checker.ValidationError as e:
        # The checker appends context after "==>"; compare only the leading message.
        if str(e).split("==>")[0].rstrip() not in acceptable_error_msg:
            # Re-raise the error that was actually caught instead of running
            # the checker a second time: this keeps the original traceback and
            # avoids validating the model twice.
            raise

    if check_runnable:
        # Severity 3 is passed so that lower-severity onnxruntime log output is
        # suppressed (see onnxruntime's logging levels).
        ort.set_default_logger_severity(3)
        ort.InferenceSession(model.SerializeToString())