Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.
The class TensorFlowTrainer, method validateGraph.
/**
 * Validates that the graph contains the operations this trainer relies on: one operation for
 * every input name reported by the feature converter, and an operation called {@code outputName}
 * which has at least one output, the first of which is two dimensional.
 * <p>
 * Throws IllegalArgumentException or PropertyException if the graph is invalid.
 * @param throwPropertyException If true throw PropertyException instead of IllegalArgumentException.
 */
private void validateGraph(boolean throwPropertyException) {
    try (Graph graph = new Graph()) {
        graph.importGraphDef(graphDef);
        for (String inputName : featureConverter.inputNamesSet()) {
            if (graph.operation(inputName) == null) {
                throw graphValidationError(throwPropertyException, "featureConverter",
                        "Unable to find an input operation, expected an op with name '" + inputName + "'");
            }
        }
        Operation outputOp = graph.operation(outputName);
        if (outputOp == null) {
            throw graphValidationError(throwPropertyException, "outputName",
                    "Unable to find the output operation, expected an op with name '" + outputName + "'");
        }
        // output(0) only makes sense if the op produces something; check first so that the
        // failure is reported against the misconfigured "outputName" property.
        if (outputOp.numOutputs() < 1) {
            throw graphValidationError(throwPropertyException, "outputName",
                    "Expected the output operation '" + outputName + "' to produce a tensor, but it has no outputs");
        }
        Shape outputShape = outputOp.output(0).shape();
        if (outputShape.numDimensions() != 2) {
            throw graphValidationError(throwPropertyException, "outputName",
                    "Expected a 2 dimensional output, found " + Arrays.toString(outputShape.asArray()));
        }
    }
}

/**
 * Builds the exception {@link #validateGraph(boolean)} throws for an invalid graph.
 * @param throwPropertyException If true build a PropertyException, otherwise an IllegalArgumentException.
 * @param propertyName The configurable property the problem is attributed to (PropertyException only).
 * @param msg The error message.
 * @return The exception to throw.
 */
private static RuntimeException graphValidationError(boolean throwPropertyException, String propertyName, String msg) {
    if (throwPropertyException) {
        return new PropertyException("", propertyName, msg);
    } else {
        return new IllegalArgumentException(msg);
    }
}
Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.
The class BERTFeatureExtractor, method postConfig.
/**
 * Loads the ONNX model and tokenizer after configuration, validating the model's inputs and outputs.
 * @throws PropertyException If the model or tokenizer cannot be loaded, or the model has the wrong
 *                           inputs/outputs.
 */
@Override
public void postConfig() throws PropertyException {
    try {
        env = OrtEnvironment.getEnvironment();
        // The options are only needed to create the session; closing them releases their native
        // handle (including when addCUDA fails).
        try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
            if (useCUDA) {
                options.addCUDA();
            }
            session = env.createSession(modelPath.toString(), options);
        }
        // Validate model and extract the embedding dimension
        bertDim = validateOutputs(session.getOutputInfo());
        validateInputs(session.getInputInfo());
        featureNames = generateFeatureNames(bertDim);
        TokenizerConfig config = loadTokenizer(tokenizerPath);
        Wordpiece wordpiece = new Wordpiece(config.tokenIDs.keySet(), config.unknownToken, config.maxInputCharsPerWord);
        tokenIDs = config.tokenIDs;
        unknownToken = config.unknownToken;
        classificationToken = config.classificationToken;
        separatorToken = config.separatorToken;
        tokenizer = new WordpieceTokenizer(wordpiece, new WordpieceBasicTokenizer(), config.lowercase, config.stripAccents, Collections.emptySet());
    } catch (OrtException e) {
        throw new PropertyException(e, "", "modelPath", "Failed to load model, ORT threw: " + e.getMessage());
    } catch (IOException e) {
        throw new PropertyException(e, "", "tokenizerPath", "Failed to load tokenizer, Jackson threw: " + e.getMessage());
    }
}

/**
 * Checks the model has exactly two tensor outputs and returns the embedding dimension.
 * <p>
 * {@code TOKEN_OUTPUT} must have 3 dimensions (batch, sequence length, embedding) with a fixed
 * positive embedding dimension, and {@code CLS_OUTPUT} must have 2 dimensions whose last one
 * matches that embedding dimension.
 * @param outputs The model's output metadata.
 * @return The embedding dimension.
 * @throws PropertyException If the outputs do not match the expected layout.
 */
private int validateOutputs(Map<String, NodeInfo> outputs) {
    if (outputs.size() != 2) {
        throw new PropertyException("", "modelPath", "Invalid model, expected 2 outputs, found " + outputs.size());
    }
    // check that the outputs have the expected names
    NodeInfo tokenOutput = outputs.get(TOKEN_OUTPUT);
    if ((tokenOutput == null) || !(tokenOutput.getInfo() instanceof TensorInfo)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + TOKEN_OUTPUT + "'");
    }
    long[] tokenShape = ((TensorInfo) tokenOutput.getInfo()).getShape();
    if (tokenShape.length != 3) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find " + TOKEN_OUTPUT + " with 3 dimensions, found :" + Arrays.toString(tokenShape));
    }
    // Bert embedding dim is the last dimension, the first two are batch and sequence length.
    // A non-positive value (e.g. a symbolic dimension) cannot be used to size the feature names.
    long embeddingDim = tokenShape[2];
    if ((embeddingDim <= 0) || (embeddingDim > Integer.MAX_VALUE)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected " + TOKEN_OUTPUT + " to have a fixed positive embedding dimension, found " + embeddingDim);
    }
    NodeInfo clsOutput = outputs.get(CLS_OUTPUT);
    if ((clsOutput == null) || !(clsOutput.getInfo() instanceof TensorInfo)) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called '" + CLS_OUTPUT + "'");
    }
    long[] clsShape = ((TensorInfo) clsOutput.getInfo()).getShape();
    if (clsShape.length != 2) {
        throw new PropertyException("", "modelPath", "Invalid model, expected to find " + CLS_OUTPUT + " with 2 dimensions, found :" + Arrays.toString(clsShape));
    }
    if (clsShape[1] != embeddingDim) {
        // dimension mismatch between the classification and token outputs, bail out
        throw new PropertyException("", "modelPath", "Invalid model, expected to find two outputs with the same embedding dimension, instead found " + embeddingDim + " and " + clsShape[1]);
    }
    return (int) embeddingDim;
}

/**
 * Checks the model has exactly three inputs named {@code ATTENTION_MASK}, {@code INPUT_IDS}
 * and {@code TOKEN_TYPE_IDS}.
 * @param inputs The model's input metadata.
 * @throws PropertyException If the inputs do not match.
 */
private void validateInputs(Map<String, NodeInfo> inputs) {
    if (inputs.size() != 3) {
        throw new PropertyException("", "modelPath", "Invalid model, expected 3 inputs, found " + inputs.size());
    }
    for (String inputName : new String[]{ATTENTION_MASK, INPUT_IDS, TOKEN_TYPE_IDS}) {
        if (!inputs.containsKey(inputName)) {
            throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called '" + inputName + "'");
        }
    }
}
Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.
The class KMeansTrainer, method postConfig.
/**
 * Configuration hook invoked by OLCUT once fields have been populated; external code must not call it.
 * <p>
 * Seeds the random number generator from {@code seed}. If {@code distanceType} was supplied it is
 * converted into {@code distType} and then cleared; supplying both fields is rejected.
 */
@Override
public synchronized void postConfig() {
    this.rng = new SplittableRandom(seed);
    if (this.distanceType == null) {
        // Nothing to convert.
        return;
    }
    if (this.distType != null) {
        // Ambiguous configuration: both distance settings were provided.
        throw new PropertyException("distType", "Both distType and distanceType must not both be set.");
    }
    this.distType = this.distanceType.getDistanceType();
    this.distanceType = null;
}
Aggregations