Use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.
The class TensorFlowSequenceTrainer, method train.
@Override
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, Provenance> runProvenance) {
    // Creates a new RNG, adds one to the invocation count.
    SplittableRandom localRNG;
    TrainerProvenance provenance;
    synchronized (this) {
        localRNG = rng.split();
        provenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> labelMap = examples.getOutputIDInfo();
    ArrayList<SequenceExample<T>> batch = new ArrayList<>();
    int[] indices = Util.randperm(examples.size(), localRNG);
    try (Graph graph = new Graph();
         Session session = new Session(graph)) {
        //
        // Load the graph def into the session.
        graph.importGraphDef(graphDef);
        //
        // Run additional initialization routines, if needed.
        preTrainingHook(session, examples);
        int interval = 0;
        for (int i = 0; i < epochs; i++) {
            log.log(Level.INFO, "Starting epoch " + i);
            // Shuffle the order in which we'll look at examples
            Util.randpermInPlace(indices, localRNG);
            for (int j = 0; j < examples.size(); j += minibatchSize) {
                batch.clear();
                for (int k = j; k < (j + minibatchSize) && k < examples.size(); k++) {
                    int ix = indices[k];
                    batch.add(examples.getExample(ix));
                }
                //
                // Transform examples to tensors
                TensorMap featureTensors = featureConverter.encode(batch, featureMap);
                //
                // Add supervision
                TensorMap supervisionTensors = outputConverter.encode(batch, labelMap);
                //
                // Add any additional training hyperparameter values to the feed dict.
                TensorMap parameterTensors = getHyperparameterFeed();
                //
                // Populate the runner.
                Session.Runner runner = session.runner();
                featureTensors.feedInto(runner);
                supervisionTensors.feedInto(runner);
                parameterTensors.feedInto(runner);
                // Run a training batch.
                try (Tensor loss = runner.addTarget(trainOp).fetch(getLossOp).run().get(0)) {
                    if (interval % loggingInterval == 0) {
                        float lossVal = ((TFloat32) loss).getFloat(0);
                        log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]", lossVal, i, interval, j, Math.min(examples.size(), j + minibatchSize), examples.size()));
                    }
                    interval++;
                }
                //
                // Cleanup: close the tensors.
                featureTensors.close();
                supervisionTensors.close();
                parameterTensors.close();
            }
        }
        // This call **must** happen before the trainedGraphDef is generated.
        TensorFlowUtil.annotateGraph(graph, session);
        //
        // Generate the trained graph def.
        GraphDef trainedGraphDef = graph.toGraphDef();
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
        ModelProvenance modelProvenance = new ModelProvenance(TensorFlowSequenceModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), provenance, runProvenance);
        return new TensorFlowSequenceModel<>("tf-sequence-model", modelProvenance, featureMap, labelMap, trainedGraphDef, featureConverter, outputConverter, predictOp, tensorMap);
    } catch (TensorFlowException e) {
        log.log(Level.SEVERE, "TensorFlow threw an error", e);
        throw new IllegalStateException(e);
    }
}
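Note that the feature, supervision, and hyperparameter TensorMaps in the loop above are closed manually after each batch, so an exception thrown between encode and close would leak native memory. Below is a minimal sketch of the same batch step using try-with-resources instead; it assumes TensorMap is AutoCloseable (the TensorFlowTrainer.train snippet further down relies on exactly that) and reuses the field and variable names from the method above.

    // Sketch only: the inner batch step with automatic tensor cleanup.
    // Assumes TensorMap implements AutoCloseable; names are reused from the method above.
    try (TensorMap featureTensors = featureConverter.encode(batch, featureMap);
         TensorMap supervisionTensors = outputConverter.encode(batch, labelMap);
         TensorMap parameterTensors = getHyperparameterFeed()) {
        Session.Runner runner = session.runner();
        featureTensors.feedInto(runner);
        supervisionTensors.feedInto(runner);
        parameterTensors.feedInto(runner);
        try (Tensor loss = runner.addTarget(trainOp).fetch(getLossOp).run().get(0)) {
            // ... loss logging as in the original loop ...
            interval++;
        }
    } // all three TensorMaps are closed here, even if run() throws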
Use of org.tensorflow.exceptions.TensorFlowException in project djl by deepjavalibrary.
The class JavacppUtils, method loadSavedModelBundle.
@SuppressWarnings({ "unchecked", "try" })
public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        // load the session
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
        TF_Session sessionHandle = tensorflow.TF_LoadSessionFromSavedModel(opts, runOpts, new BytePointer(exportDir), new PointerPointer<>(tags), tags.length, graphHandle, metaGraphDef, status);
        status.throwExceptionIfNotOK();
        // handle the result
        try {
            return new SavedModelBundle(graphHandle, sessionHandle, MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer()));
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
        }
    }
}
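A minimal usage sketch for this helper, following the signature shown above. The directory path is illustrative, "serve" is the conventional SavedModel serving tag, and passing default proto instances for the optional ConfigProto and RunOptions arguments is an assumption rather than something taken from DJL's own callers.

    // Hypothetical caller (path and options are illustrative only).
    SavedModelBundle bundle =
            JavacppUtils.loadSavedModelBundle(
                    "/path/to/saved_model",
                    new String[] {"serve"},
                    ConfigProto.getDefaultInstance(),
                    RunOptions.getDefaultInstance());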
Use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.
The class TensorFlowSavedModelExternalModel, method createTensorflowModel.
/**
* Creates a TensorflowSavedModelExternalModel by loading in a {@code SavedModelBundle}.
* <p>
* Throws {@link IllegalArgumentException} if the model bundle could not be loaded.
* @param factory The output factory.
* @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
* @param outputMapping The output mapping between Tribuo's names and the TF integer ids.
* @param outputName The name of the output tensor.
* @param featureConverter The feature transformation function.
* @param outputConverter The output transformation function.
* @param bundleDirectory The path to load the saved model bundle from.
* @param <T> The type of the output.
* @return The TF model wrapped in a Tribuo {@link ExternalModel}.
*/
public static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory) {
    try {
        Path path = Paths.get(bundleDirectory);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        ModelProvenance provenance = new ModelProvenance(TensorFlowSavedModelExternalModel.class.getName(), now, datasetProvenance, trainerProvenance);
        return new TensorFlowSavedModelExternalModel<>("tf-saved-model-bundle", provenance, featureMap, outputInfo, featureMapping, bundleDirectory, outputName, featureConverter, outputConverter);
    } catch (IOException | TensorFlowException e) {
        throw new IllegalArgumentException("Unable to load model from path " + bundleDirectory, e);
    }
}
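This factory is typically called with a classification output. The sketch below assumes Tribuo's LabelFactory plus the DenseFeatureConverter and LabelConverter from the same TensorFlow interop package, and uses purely illustrative feature names, label mappings, tensor names, and bundle path.

    // Hedged sketch: wrapping a classification SavedModel bundle.
    Map<String, Integer> featureMapping = new HashMap<>();
    featureMapping.put("feature-a", 0);   // Tribuo feature name -> TF input index (example values)
    featureMapping.put("feature-b", 1);
    Map<Label, Integer> outputMapping = new HashMap<>();
    outputMapping.put(new Label("positive"), 0);
    outputMapping.put(new Label("negative"), 1);
    TensorFlowSavedModelExternalModel<Label> model =
            TensorFlowSavedModelExternalModel.createTensorflowModel(
                    new LabelFactory(),
                    featureMapping,
                    outputMapping,
                    "output-tensor-name",                              // illustrative output op name
                    new DenseFeatureConverter("input-tensor-name"),    // assumed converter and input name
                    new LabelConverter(),
                    "/path/to/saved-model-bundle");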
Use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.
The class TensorFlowTrainer, method train.
@Override
public TensorFlowModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
    ArrayList<Example<T>> batch = new ArrayList<>();
    Path curCheckpointPath;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        curCheckpointPath = checkpointPath != null ? Paths.get(checkpointPath.toString(), "invocation-" + trainInvocationCounter, "tribuo") : null;
        trainInvocationCounter++;
    }
    ConfigProto.Builder configBuilder = ConfigProto.newBuilder();
    if (interOpParallelism > -1) {
        configBuilder.setInterOpParallelismThreads(interOpParallelism);
    }
    if (intraOpParallelism > -1) {
        configBuilder.setIntraOpParallelismThreads(intraOpParallelism);
    }
    ConfigProto config = configBuilder.build();
    try (Graph graph = new Graph();
         Session session = new Session(graph, config)) {
        // Load in the graph definition
        graph.importGraphDef(graphDef);
        Ops tf = Ops.create(graph).withName("tribuo-internal");
        // Lookup output op
        Operand<TNumber> intermediateOutputOp = graph.operation(outputName).output(0);
        // Validate that the output op is the right shape
        Shape outputShape = intermediateOutputOp.shape();
        Shape expectedShape = Shape.of(trainBatchSize, outputInfo.size());
        if (!outputShape.isCompatibleWith(expectedShape)) {
            throw new IllegalArgumentException("Incompatible output shape, expected " + expectedShape.toString() + " found " + outputShape.toString());
        }
        // Add target placeholder
        Placeholder<? extends TNumber> targetPlaceholder = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(trainBatchSize, outputInfo.size())));
        // Add loss, optimiser and output
        Op outputOp = outputConverter.outputTransformFunction().apply(tf, intermediateOutputOp);
        Operand<TNumber> lossOp = outputConverter.loss().apply(tf, new Pair<>(targetPlaceholder, intermediateOutputOp));
        Op optimiser = optimiserEnum.applyOptimiser(graph, lossOp, gradientParams);
        // Initialise all the things
        session.initialize();
        logger.info("Initialised the model parameters");
        int interval = 0;
        for (int i = 0; i < epochs; i++) {
            logger.log(Level.INFO, "Starting epoch " + i);
            for (int j = 0; j < examples.size(); j += trainBatchSize) {
                batch.clear();
                for (int k = j; k < (j + trainBatchSize) && k < examples.size(); k++) {
                    batch.add(examples.getExample(k));
                }
                try (TensorMap input = featureConverter.convert(batch, featureMap);
                     Tensor target = outputConverter.convertToTensor(batch, outputInfo);
                     Tensor lossTensor = input.feedInto(session.runner()).feed(targetPlaceholder, target).addTarget(optimiser).fetch(lossOp).run().get(0)) {
                    if ((loggingInterval != -1) && (interval % loggingInterval == 0)) {
                        logger.log(Level.INFO, "Training loss at itr " + interval + " = " + ((TFloat32) lossTensor).getFloat());
                    }
                }
                interval++;
            }
        }
        // Setup the model serialization infrastructure.
        // **Must** happen before the trainedGraphDef is generated.
        // We unconditionally annotate the Graph for Tribuo's serialization
        TensorFlowUtil.annotateGraph(graph, session);
        // If it's a checkpoint we also save it out.
        if (modelFormat == TFModelFormat.CHECKPOINT) {
            session.save(curCheckpointPath.toString());
        }
        GraphDef trainedGraphDef = graph.toGraphDef();
        ModelProvenance modelProvenance = new ModelProvenance(TensorFlowModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
        TensorFlowModel<T> tfModel;
        switch (modelFormat) {
            case TRIBUO_NATIVE:
                Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
                tfModel = new TensorFlowNativeModel<>("tf-native-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, tensorMap, testBatchSize, outputOp.op().name(), featureConverter, outputConverter);
                break;
            case CHECKPOINT:
                tfModel = new TensorFlowCheckpointModel<>("tf-checkpoint-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, curCheckpointPath.getParent().toString(), curCheckpointPath.getFileName().toString(), testBatchSize, outputOp.op().name(), featureConverter, outputConverter);
                break;
            default:
                throw new IllegalStateException("Unexpected enum constant " + modelFormat);
        }
        return tfModel;
    } catch (TensorFlowException e) {
        logger.log(Level.SEVERE, "TensorFlow threw an error", e);
        throw new IllegalStateException(e);
    }
}
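As a hedged usage sketch, once a TensorFlowTrainer&lt;Label&gt; has been configured (construction omitted here), training and scoring follow the usual Tribuo pattern; the overload below is the one whose body is shown above, and closing the model is assumed to release the native Session it holds.

    // Hedged sketch: trainData and testData are hypothetical Tribuo datasets.
    TensorFlowModel<Label> model =
            trainer.train(trainData, Collections.emptyMap(), Trainer.INCREMENT_INVOCATION_COUNT);
    List<Prediction<Label>> predictions = model.predict(testData);
    model.close();   // assumed: releases the native Session held by the model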
Use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.
The class TensorFlowCheckpointModel, method readObject.
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    byte[] modelBytes = (byte[]) in.readObject();
    this.modelGraph = new Graph();
    this.modelGraph.importGraphDef(GraphDef.parseFrom(modelBytes));
    this.session = new Session(modelGraph);
    try {
        session.restore(resolvePath());
        initialized = true;
    } catch (TensorFlowException e) {
        logger.log(Level.WARNING, "Failed to initialise model after deserialization, attempted to load from " + checkpointDirectory, e);
    }
}
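Since readObject only restores the TensorFlow session when the checkpoint referenced by resolvePath() is reachable, a deserialized checkpoint model is only usable if the checkpoint files are present on the loading machine; otherwise the warning above is logged and the session is left unrestored. A minimal, hedged round-trip sketch using plain Java serialization (the file name is illustrative):

    // Hedged sketch: serialize and deserialize a TensorFlowCheckpointModel<Label>.
    try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(Paths.get("model.ser")))) {
        oos.writeObject(model);
    }
    try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(Paths.get("model.ser")))) {
        @SuppressWarnings("unchecked")
        TensorFlowCheckpointModel<Label> loaded = (TensorFlowCheckpointModel<Label>) ois.readObject();
        // The checkpoint directory baked into the model must exist here for restore to succeed.
    }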