Search in sources:

Example 16 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project tribuo by Oracle.

In class TensorFlowTrainer, method validateGraph.

/**
 * Checks that {@code graphDef} can be imported into a fresh graph which contains every input
 * operation named by the feature converter, plus an operation named {@code outputName} whose
 * first output has exactly two dimensions.
 * <p>
 * Throws IllegalArgumentException or PropertyException if the graph is invalid.
 * @param throwPropertyException If true throw PropertyException instead of IllegalArgumentException.
 */
private void validateGraph(boolean throwPropertyException) {
    try (Graph candidateGraph = new Graph()) {
        candidateGraph.importGraphDef(graphDef);

        // Every input the feature converter will feed must exist as an op in the graph.
        for (String expectedInput : featureConverter.inputNamesSet()) {
            if (candidateGraph.operation(expectedInput) != null) {
                continue;
            }
            throw graphValidationError(throwPropertyException, "featureConverter",
                    "Unable to find an input operation, expected an op with name '" + expectedInput + "'");
        }

        Operation outputOp = candidateGraph.operation(outputName);
        if (outputOp == null) {
            throw graphValidationError(throwPropertyException, "outputName",
                    "Unable to find the output operation, expected an op with name '" + outputName + "'");
        }

        Shape outputShape = outputOp.output(0).shape();
        if (outputShape.numDimensions() == 2) {
            return;
        }
        throw graphValidationError(throwPropertyException, "outputName",
                "Expected a 2 dimensional output, found " + Arrays.toString(outputShape.asArray()));
    }
}

/**
 * Builds the exception used by {@code validateGraph} to report an invalid graph.
 * @param asPropertyException If true build a PropertyException attributed to {@code propertyName},
 *                            otherwise an IllegalArgumentException.
 * @param propertyName The configurable property the failure is attributed to.
 * @param msg The failure description.
 * @return The exception to throw.
 */
private static RuntimeException graphValidationError(boolean asPropertyException, String propertyName, String msg) {
    if (asPropertyException) {
        return new PropertyException("", propertyName, msg);
    }
    return new IllegalArgumentException(msg);
}
Also used : Graph(org.tensorflow.Graph) Shape(org.tensorflow.ndarray.Shape) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) Operation(org.tensorflow.Operation)

Example 17 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project tribuo by Oracle.

In class BERTFeatureExtractor, method postConfig.

/**
 * Loads the ONNX model at {@code modelPath}, checks that its inputs and outputs match what this
 * extractor expects, and loads the wordpiece tokenizer from {@code tokenizerPath}.
 * <p>
 * The ORT session is only published to the {@code session} field once every step has succeeded.
 * If any check or the tokenizer load fails, the freshly created session is closed rather than leaked.
 * @throws PropertyException If the model cannot be loaded, has unexpected inputs/outputs, or the
 *                           tokenizer cannot be loaded.
 */
@Override
public void postConfig() throws PropertyException {
    // Owned locally until validation passes; non-null in the finally block means "close it".
    OrtSession candidate = null;
    try {
        env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        if (useCUDA) {
            options.addCUDA();
        }
        candidate = env.createSession(modelPath.toString(), options);
        // Validate model and extract the embedding dimension
        bertDim = validateModelOutputs(candidate.getOutputInfo());
        validateModelInputs(candidate.getInputInfo());
        featureNames = generateFeatureNames(bertDim);
        TokenizerConfig config = loadTokenizer(tokenizerPath);
        Wordpiece wordpiece = new Wordpiece(config.tokenIDs.keySet(), config.unknownToken, config.maxInputCharsPerWord);
        tokenIDs = config.tokenIDs;
        unknownToken = config.unknownToken;
        classificationToken = config.classificationToken;
        separatorToken = config.separatorToken;
        tokenizer = new WordpieceTokenizer(wordpiece, new WordpieceBasicTokenizer(), config.lowercase, config.stripAccents, Collections.emptySet());
        // Everything succeeded, transfer ownership of the session to the field.
        session = candidate;
        candidate = null;
    } catch (OrtException e) {
        throw new PropertyException(e, "", "modelPath", "Failed to load model, ORT threw: " + e.getMessage());
    } catch (IOException e) {
        throw new PropertyException(e, "", "tokenizerPath", "Failed to load tokenizer, Jackson threw: " + e.getMessage());
    } finally {
        if (candidate != null) {
            try {
                candidate.close();
            } catch (OrtException closeFailure) {
                // Best effort cleanup; the exception already propagating is the one worth reporting.
            }
        }
    }
}

/**
 * Checks that the model has exactly two tensor outputs.
 * <ul>
 *     <li>{@code TOKEN_OUTPUT} must have 3 dimensions.</li>
 *     <li>{@code CLS_OUTPUT} must have 2 dimensions.</li>
 *     <li>The last dimension of {@code CLS_OUTPUT} must equal the last dimension of {@code TOKEN_OUTPUT}.</li>
 * </ul>
 * @param outputs The model's output metadata.
 * @return The embedding dimension, taken from the last dimension of {@code TOKEN_OUTPUT}.
 * @throws PropertyException If the outputs do not have the expected structure.
 */
private int validateModelOutputs(Map<String, NodeInfo> outputs) {
    if (outputs.size() != 2) {
        throw new PropertyException("", "modelPath", "Invalid model, expected 2 outputs, found " + outputs.size());
    }
    // check that the outputs have the expected names
    NodeInfo outputZero = outputs.get(TOKEN_OUTPUT);
    if ((outputZero == null) || !(outputZero.getInfo() instanceof TensorInfo)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + TOKEN_OUTPUT + "'");
    }
    long[] tokenShape = ((TensorInfo) outputZero.getInfo()).getShape();
    if (tokenShape.length != 3) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find " + TOKEN_OUTPUT + " with 3 dimensions, found :" + Arrays.toString(tokenShape));
    }
    // Bert embedding dim is the last dimension.
    // The first two are batch and sequence length.
    int embeddingDim = (int) tokenShape[2];

    NodeInfo outputOne = outputs.get(CLS_OUTPUT);
    if ((outputOne == null) || !(outputOne.getInfo() instanceof TensorInfo)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + CLS_OUTPUT + "'");
    }
    long[] clsShape = ((TensorInfo) outputOne.getInfo()).getShape();
    if (clsShape.length != 2) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find " + CLS_OUTPUT + " with 2 dimensions, found :" + Arrays.toString(clsShape));
    } else if (clsShape[1] != embeddingDim) {
        // dimension mismatch between the classification and token outputs, bail out
        throw new PropertyException("", "modelPath", "Invalid model, expected to find two outputs with the same embedding dimension, instead found " + embeddingDim + " and " + clsShape[1]);
    }
    return embeddingDim;
}

/**
 * Checks that the model has exactly three inputs, named {@code ATTENTION_MASK},
 * {@code INPUT_IDS} and {@code TOKEN_TYPE_IDS}.
 * @param inputs The model's input metadata.
 * @throws PropertyException If the inputs do not match.
 */
private void validateModelInputs(Map<String, NodeInfo> inputs) {
    if (inputs.size() != 3) {
        throw new PropertyException("", "modelPath", "Invalid model, expected 3 inputs, found " + inputs.size());
    }
    if (!inputs.containsKey(ATTENTION_MASK)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + ATTENTION_MASK + "'");
    }
    if (!inputs.containsKey(INPUT_IDS)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + INPUT_IDS + "'");
    }
    if (!inputs.containsKey(TOKEN_TYPE_IDS)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + TOKEN_TYPE_IDS + "'");
    }
}
Also used : WordpieceTokenizer(org.tribuo.util.tokens.impl.wordpiece.WordpieceTokenizer) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) WordpieceBasicTokenizer(org.tribuo.util.tokens.impl.wordpiece.WordpieceBasicTokenizer) IOException(java.io.IOException) OrtException(ai.onnxruntime.OrtException) Wordpiece(org.tribuo.util.tokens.impl.wordpiece.Wordpiece) OrtSession(ai.onnxruntime.OrtSession) NodeInfo(ai.onnxruntime.NodeInfo) TensorInfo(ai.onnxruntime.TensorInfo)

Example 18 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project tribuo by Oracle.

In class KMeansTrainer, method postConfig.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * Re-seeds {@code rng} from {@code seed}. If {@code distanceType} was configured, it is converted
 * into {@code distType} via {@code getDistanceType()} and then cleared.
 * @throws PropertyException If both {@code distType} and {@code distanceType} are set.
 */
@Override
public synchronized void postConfig() {
    this.rng = new SplittableRandom(seed);
    if (this.distanceType != null) {
        if (this.distType != null) {
            throw new PropertyException("distType", "distType and distanceType must not both be set.");
        }
        this.distType = this.distanceType.getDistanceType();
        // Only one representation is kept once the conversion is done.
        this.distanceType = null;
    }
}
Also used : PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) SplittableRandom(java.util.SplittableRandom)

Aggregations

PropertyException (com.oracle.labs.mlrg.olcut.config.PropertyException)18 Random (java.util.Random)6 ArrayList (java.util.ArrayList)5 Example (org.tribuo.Example)5 ArrayExample (org.tribuo.impl.ArrayExample)5 Locale (java.util.Locale)2 NeighboursBruteForceFactory (org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory)2 Regressor (org.tribuo.regression.Regressor)2 NodeInfo (ai.onnxruntime.NodeInfo)1 OrtException (ai.onnxruntime.OrtException)1 OrtSession (ai.onnxruntime.OrtSession)1 TensorInfo (ai.onnxruntime.TensorInfo)1 ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 FileNotFoundException (java.io.FileNotFoundException)1 IOException (java.io.IOException)1 MalformedURLException (java.net.MalformedURLException)1 MessageDigest (java.security.MessageDigest)1 HashSet (java.util.HashSet)1 SplittableRandom (java.util.SplittableRandom)1 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)1