Search in sources:

Example 1 with NodeInfo

Use of ai.onnxruntime.NodeInfo in the javacpp-presets project by bytedeco.

From the class ScoreMNIST, method main:

public static void main(String[] args) throws OrtException, IOException {
    if (args.length < 2 || args.length > 3) {
        System.out.println("Usage: ScoreMNIST <model-path> <test-data> <optional:scikit-learn-flag>");
        System.out.println("The test data input should be a libsvm format version of MNIST.");
        return;
    }
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new SessionOptions()) {
        opts.setOptimizationLevel(OptLevel.BASIC_OPT);
        logger.info("Loading model from " + args[0]);
        try (OrtSession session = env.createSession(args[0], opts)) {
            logger.info("Inputs:");
            for (NodeInfo i : session.getInputInfo().values()) {
                logger.info(i.toString());
            }
            logger.info("Outputs:");
            for (NodeInfo i : session.getOutputInfo().values()) {
                logger.info(i.toString());
            }
            SparseData data = load(args[1]);
            // Reusable input buffers: an NCHW image tensor for the CNN model,
            // and a flat feature vector for the scikit-learn pipeline.
            float[][][][] testData = new float[1][1][28][28];
            float[][] testDataSKL = new float[1][780];
            int correctCount = 0;
            int[][] confusionMatrix = new int[10][10];
            String inputName = session.getInputNames().iterator().next();
            for (int i = 0; i < data.labels.length; i++) {
                if (args.length == 3) {
                    writeDataSKL(testDataSKL, data.indices.get(i), data.values.get(i));
                } else {
                    writeData(testData, data.indices.get(i), data.values.get(i));
                }
                // Wrap the current example in an OnnxTensor and run the session;
                // both the tensor and the Result are closed at the end of each iteration.
                try (OnnxTensor test = OnnxTensor.createTensor(env, args.length == 3 ? testDataSKL : testData);
                    Result output = session.run(Collections.singletonMap(inputName, test))) {
                    int predLabel;
                    if (args.length == 3) {
                        long[] labels = (long[]) output.get(0).getValue();
                        predLabel = (int) labels[0];
                    } else {
                        float[][] outputProbs = (float[][]) output.get(0).getValue();
                        predLabel = pred(outputProbs[0]);
                    }
                    if (predLabel == data.labels[i]) {
                        correctCount++;
                    }
                    confusionMatrix[data.labels[i]][predLabel]++;
                    if (i % 2000 == 0) {
                        logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (i + 1));
                        logger.log(Level.INFO, "Output type = " + output.get(0).toString());
                        if (args.length == 3) {
                            logger.log(Level.INFO, "Output type = " + output.get(1).toString());
                            logger.log(Level.INFO, "Output value = " + output.get(1).getValue().toString());
                        }
                    }
                }
            }
            logger.info("Final accuracy = " + ((float) correctCount) / data.labels.length);
            StringBuilder sb = new StringBuilder();
            sb.append("Label");
            for (int i = 0; i < confusionMatrix.length; i++) {
                sb.append(String.format("%1$5s", "" + i));
            }
            sb.append("\n");
            for (int i = 0; i < confusionMatrix.length; i++) {
                sb.append(String.format("%1$5s", "" + i));
                for (int j = 0; j < confusionMatrix[i].length; j++) {
                    sb.append(String.format("%1$5s", "" + confusionMatrix[i][j]));
                }
                sb.append("\n");
            }
            System.out.println(sb.toString());
        }
    }
    logger.info("Done!");
}
Also used: SessionOptions(ai.onnxruntime.OrtSession.SessionOptions) Result(ai.onnxruntime.OrtSession.Result) OrtEnvironment(ai.onnxruntime.OrtEnvironment) OrtSession(ai.onnxruntime.OrtSession) NodeInfo(ai.onnxruntime.NodeInfo) OnnxTensor(ai.onnxruntime.OnnxTensor)
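
The same NodeInfo inspection shown in the logging loops above can also be used on its own to check a model's declared tensor shapes before sizing the input arrays. A minimal sketch under assumed names (the model path "model.onnx" and the class name InspectModel are placeholders, not part of the sample):

import java.util.Arrays;
import java.util.Map;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;

public class InspectModel {

    public static void main(String[] args) throws OrtException {
        // The shared environment is a singleton; it does not need to be closed here.
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            OrtSession session = env.createSession("model.onnx", opts)) {
            // Each NodeInfo pairs a node name with a ValueInfo; tensor nodes carry a TensorInfo.
            Map<String, NodeInfo> inputs = session.getInputInfo();
            for (Map.Entry<String, NodeInfo> entry : inputs.entrySet()) {
                if (entry.getValue().getInfo() instanceof TensorInfo) {
                    TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
                    // Symbolic dimensions (e.g. the batch size) are reported as -1.
                    System.out.println(entry.getKey() + " -> " + Arrays.toString(tensorInfo.getShape()));
                }
            }
        }
    }
}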

Example 2 with NodeInfo

Use of ai.onnxruntime.NodeInfo in the tribuo project by oracle.

From the class BERTFeatureExtractor, method postConfig:

@Override
public void postConfig() throws PropertyException {
    try {
        env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        if (useCUDA) {
            options.addCUDA();
        }
        session = env.createSession(modelPath.toString(), options);
        // Validate model and extract the embedding dimension
        Map<String, NodeInfo> outputs = session.getOutputInfo();
        if (outputs.size() != 2) {
            throw new PropertyException("", "modelPath", "Invalid model, expected 2 outputs, found " + outputs.size());
        } else {
            // check that the outputs have the expected names
            NodeInfo outputZero = outputs.get(TOKEN_OUTPUT);
            if ((outputZero == null) || !(outputZero.getInfo() instanceof TensorInfo)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + TOKEN_OUTPUT + "'");
            } else {
                TensorInfo outputZeroTensor = (TensorInfo) outputZero.getInfo();
                long[] shape = outputZeroTensor.getShape();
                if (shape.length != 3) {
                    throw new PropertyException("", "modelPath", "Invalid model, expected to find " + TOKEN_OUTPUT + " with 3 dimensions, found :" + Arrays.toString(shape));
                } else {
                    // Bert embedding dim is the last dimension.
                    // The first two are batch and sequence length.
                    bertDim = (int) shape[2];
                }
            }
            NodeInfo outputOne = outputs.get(CLS_OUTPUT);
            if ((outputOne == null) || !(outputOne.getInfo() instanceof TensorInfo)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + CLS_OUTPUT + "'");
            } else {
                TensorInfo outputOneTensor = (TensorInfo) outputOne.getInfo();
                long[] shape = outputOneTensor.getShape();
                if (shape.length != 2) {
                    throw new PropertyException("", "modelPath", "Invalid model, expected to find " + CLS_OUTPUT + " with 2 dimensions, found :" + Arrays.toString(shape));
                } else if (shape[1] != bertDim) {
                    // dimension mismatch between the classification and token outputs, bail out
                    throw new PropertyException("", "modelPath", "Invalid model, expected to find two outputs with the same embedding dimension, instead found " + bertDim + " and " + shape[1]);
                }
            }
        }
        Map<String, NodeInfo> inputs = session.getInputInfo();
        if (inputs.size() != 3) {
            throw new PropertyException("", "modelPath", "Invalid model, expected 3 inputs, found " + inputs.size());
        } else {
            if (!inputs.containsKey(ATTENTION_MASK)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + ATTENTION_MASK + "'");
            }
            if (!inputs.containsKey(INPUT_IDS)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + INPUT_IDS + "'");
            }
            if (!inputs.containsKey(TOKEN_TYPE_IDS)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + TOKEN_TYPE_IDS + "'");
            }
        }
        featureNames = generateFeatureNames(bertDim);
        TokenizerConfig config = loadTokenizer(tokenizerPath);
        Wordpiece wordpiece = new Wordpiece(config.tokenIDs.keySet(), config.unknownToken, config.maxInputCharsPerWord);
        tokenIDs = config.tokenIDs;
        unknownToken = config.unknownToken;
        classificationToken = config.classificationToken;
        separatorToken = config.separatorToken;
        tokenizer = new WordpieceTokenizer(wordpiece, new WordpieceBasicTokenizer(), config.lowercase, config.stripAccents, Collections.emptySet());
    } catch (OrtException e) {
        throw new PropertyException(e, "", "modelPath", "Failed to load model, ORT threw: ");
    } catch (IOException e) {
        throw new PropertyException(e, "", "tokenizerPath", "Failed to load tokenizer, Jackson threw: ");
    }
}
Also used: WordpieceTokenizer(org.tribuo.util.tokens.impl.wordpiece.WordpieceTokenizer) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) WordpieceBasicTokenizer(org.tribuo.util.tokens.impl.wordpiece.WordpieceBasicTokenizer) IOException(java.io.IOException) OrtException(ai.onnxruntime.OrtException) Wordpiece(org.tribuo.util.tokens.impl.wordpiece.Wordpiece) OrtSession(ai.onnxruntime.OrtSession) NodeInfo(ai.onnxruntime.NodeInfo) TensorInfo(ai.onnxruntime.TensorInfo)
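
The shape checks in postConfig all follow one pattern: look up a named output, confirm it is a tensor, and validate its rank. A sketch of that pattern factored into a helper; the class and method names and the use of IllegalStateException are illustrative assumptions, not Tribuo API:

import java.util.Arrays;
import java.util.Map;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.TensorInfo;

final class OutputShapeCheck {

    // Returns the size of the last dimension (the embedding dimension for BERT-style outputs).
    static int checkTensorOutput(Map<String, NodeInfo> outputs, String name, int expectedRank) {
        NodeInfo node = outputs.get(name);
        if ((node == null) || !(node.getInfo() instanceof TensorInfo)) {
            throw new IllegalStateException("Expected a tensor output called '" + name + "'");
        }
        long[] shape = ((TensorInfo) node.getInfo()).getShape();
        if (shape.length != expectedRank) {
            throw new IllegalStateException("Expected " + name + " to have " + expectedRank
                    + " dimensions, found " + Arrays.toString(shape));
        }
        return (int) shape[shape.length - 1];
    }
}

With such a helper, the token and CLS checks above reduce to two calls, e.g. checkTensorOutput(session.getOutputInfo(), TOKEN_OUTPUT, 3) and checkTensorOutput(session.getOutputInfo(), CLS_OUTPUT, 2), followed by a comparison of the two returned dimensions.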

Aggregations

NodeInfo (ai.onnxruntime.NodeInfo): 2
OrtSession (ai.onnxruntime.OrtSession): 2
OnnxTensor (ai.onnxruntime.OnnxTensor): 1
OrtEnvironment (ai.onnxruntime.OrtEnvironment): 1
OrtException (ai.onnxruntime.OrtException): 1
Result (ai.onnxruntime.OrtSession.Result): 1
SessionOptions (ai.onnxruntime.OrtSession.SessionOptions): 1
TensorInfo (ai.onnxruntime.TensorInfo): 1
PropertyException (com.oracle.labs.mlrg.olcut.config.PropertyException): 1
IOException (java.io.IOException): 1
Wordpiece (org.tribuo.util.tokens.impl.wordpiece.Wordpiece): 1
WordpieceBasicTokenizer (org.tribuo.util.tokens.impl.wordpiece.WordpieceBasicTokenizer): 1
WordpieceTokenizer (org.tribuo.util.tokens.impl.wordpiece.WordpieceTokenizer): 1
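
Taken together, the aggregated classes cover the whole inference path: OrtEnvironment and SessionOptions build an OrtSession, NodeInfo and TensorInfo describe its inputs and outputs, and OnnxTensor plus Result carry the data. A compact sketch of that path, where the model path, the single float[][] input, and the class name RunOnce are placeholder assumptions rather than details from either project:

import java.util.Collections;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class RunOnce {

    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            OrtSession session = env.createSession("model.onnx", opts)) {
            // Pick the first declared input; NodeInfo.toString() summarises its type and shape.
            NodeInfo inputInfo = session.getInputInfo().values().iterator().next();
            System.out.println("Feeding input: " + inputInfo);
            String inputName = session.getInputNames().iterator().next();
            // Placeholder batch of one example with four features; real shapes come from TensorInfo.
            float[][] features = new float[][] { { 0.1f, 0.2f, 0.3f, 0.4f } };
            try (OnnxTensor tensor = OnnxTensor.createTensor(env, features);
                OrtSession.Result result = session.run(Collections.singletonMap(inputName, tensor))) {
                System.out.println("First output: " + result.get(0));
            }
        }
    }
}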