Example 1 with TensorFlowException

use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.

the class TensorFlowSequenceTrainer method train.

@Override
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, Provenance> runProvenance) {
    // Creates a new RNG, adds one to the invocation count.
    SplittableRandom localRNG;
    TrainerProvenance provenance;
    synchronized (this) {
        localRNG = rng.split();
        provenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> labelMap = examples.getOutputIDInfo();
    ArrayList<SequenceExample<T>> batch = new ArrayList<>();
    int[] indices = Util.randperm(examples.size(), localRNG);
    try (Graph graph = new Graph();
        Session session = new Session(graph)) {
        // Load the graph def into the session.
        graph.importGraphDef(graphDef);
        // Run additional initialization routines, if needed.
        preTrainingHook(session, examples);
        int interval = 0;
        for (int i = 0; i < epochs; i++) {
            log.log(Level.INFO, "Starting epoch " + i);
            // Shuffle the order in which we'll look at examples
            Util.randpermInPlace(indices, localRNG);
            for (int j = 0; j < examples.size(); j += minibatchSize) {
                batch.clear();
                for (int k = j; k < (j + minibatchSize) && k < examples.size(); k++) {
                    int ix = indices[k];
                    batch.add(examples.getExample(ix));
                }
                // Transform examples to tensors
                TensorMap featureTensors = featureConverter.encode(batch, featureMap);
                // Add supervision
                TensorMap supervisionTensors = outputConverter.encode(batch, labelMap);
                // Add any additional training hyperparameter values to the feed dict.
                TensorMap parameterTensors = getHyperparameterFeed();
                // Populate the runner.
                Session.Runner runner = session.runner();
                featureTensors.feedInto(runner);
                supervisionTensors.feedInto(runner);
                parameterTensors.feedInto(runner);
                // Run a training batch.
                try (Tensor loss = runner.addTarget(trainOp).fetch(getLossOp).run().get(0)) {
                    if (interval % loggingInterval == 0) {
                        float lossVal = ((TFloat32) loss).getFloat(0);
                        log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]", lossVal, i, interval, j, Math.min(examples.size(), j + minibatchSize), examples.size()));
                    }
                    interval++;
                }
                // Cleanup: close the tensors.
                featureTensors.close();
                supervisionTensors.close();
                parameterTensors.close();
            }
        }
        // This call **must** happen before the trainedGraphDef is generated.
        TensorFlowUtil.annotateGraph(graph, session);
        // Generate the trained graph def.
        GraphDef trainedGraphDef = graph.toGraphDef();
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
        ModelProvenance modelProvenance = new ModelProvenance(TensorFlowSequenceModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), provenance, runProvenance);
        return new TensorFlowSequenceModel<>("tf-sequence-model", modelProvenance, featureMap, labelMap, trainedGraphDef, featureConverter, outputConverter, predictOp, tensorMap);
    } catch (TensorFlowException e) {
        log.log(Level.SEVERE, "TensorFlow threw an error", e);
        throw new IllegalStateException(e);
    }
}
Also used : ArrayList(java.util.ArrayList) SequenceExample(org.tribuo.sequence.SequenceExample) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SkeletalTrainerProvenance(org.tribuo.provenance.SkeletalTrainerProvenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) Tensor(org.tensorflow.Tensor) TensorMap(org.tribuo.interop.tensorflow.TensorMap) ModelProvenance(org.tribuo.provenance.ModelProvenance) TFloat32(org.tensorflow.types.TFloat32) GraphDef(org.tensorflow.proto.framework.GraphDef) Graph(org.tensorflow.Graph) SplittableRandom(java.util.SplittableRandom) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) Session(org.tensorflow.Session)
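
The inner loops of the train method above walk a shuffled index array in steps of minibatchSize, so each example is visited once per epoch and the final batch may be shorter than the rest. A minimal, dependency-free sketch of that indexing follows. The class MinibatchSketch and its methods are hypothetical helpers written for illustration; the shuffle is a plain Fisher-Yates pass and is only assumed to be comparable to Tribuo's Util.randperm, not a copy of it.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;

/** Hypothetical helper (not part of Tribuo): shuffled minibatch indexing. */
final class MinibatchSketch {
    private MinibatchSketch() {}

    /** Returns a Fisher-Yates shuffle of the indices 0..size-1. */
    static int[] shuffledIndices(int size, SplittableRandom rng) {
        int[] indices = new int[size];
        for (int i = 0; i < size; i++) {
            indices[i] = i;
        }
        for (int i = size - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1); // uniform in [0, i]
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return indices;
    }

    /** Cuts the index array into consecutive batches; the last one may be short. */
    static List<int[]> batches(int[] indices, int minibatchSize) {
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("minibatchSize must be positive");
        }
        List<int[]> out = new ArrayList<>();
        for (int j = 0; j < indices.length; j += minibatchSize) {
            int end = Math.min(indices.length, j + minibatchSize);
            out.add(Arrays.copyOfRange(indices, j, end));
        }
        return out;
    }
}
```

With 10 examples and a batch size of 4 this yields batches of 4, 4 and 2 indices, matching the `k < (j + minibatchSize) && k < examples.size()` bound in the original loop.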

Example 2 with TensorFlowException

use of org.tensorflow.exceptions.TensorFlowException in project djl by deepjavalibrary.

the class JavacppUtils method loadSavedModelBundle.

@SuppressWarnings({ "unchecked", "try" })
public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        // allocate parameters for TF_LoadSessionFromSavedModel
        TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
        if (config != null) {
            BytePointer configBytes = new BytePointer(config.toByteArray());
            tensorflow.TF_SetConfig(opts, configBytes, configBytes.capacity(), status);
            status.throwExceptionIfNotOK();
        }
        TF_Buffer runOpts = TF_Buffer.newBufferFromString(runOptions);
        // load the session
        TF_Graph graphHandle = AbstractTF_Graph.newGraph().retainReference();
        TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
        TF_Session sessionHandle = tensorflow.TF_LoadSessionFromSavedModel(opts, runOpts, new BytePointer(exportDir), new PointerPointer<>(tags), tags.length, graphHandle, metaGraphDef, status);
        status.throwExceptionIfNotOK();
        // handle the result
        try {
            return new SavedModelBundle(graphHandle, sessionHandle, MetaGraphDef.parseFrom(metaGraphDef.dataAsByteBuffer()));
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
        }
    }
}
Also used : TF_Graph(org.tensorflow.internal.c_api.TF_Graph) AbstractTF_Graph(org.tensorflow.internal.c_api.AbstractTF_Graph) TF_Session(org.tensorflow.internal.c_api.TF_Session) TF_Status(org.tensorflow.internal.c_api.TF_Status) BytePointer(org.bytedeco.javacpp.BytePointer) InvalidProtocolBufferException(com.google.protobuf.InvalidProtocolBufferException) TF_Buffer(org.tensorflow.internal.c_api.TF_Buffer) SavedModelBundle(ai.djl.tensorflow.engine.SavedModelBundle) PointerScope(org.bytedeco.javacpp.PointerScope) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) TF_SessionOptions(org.tensorflow.internal.c_api.TF_SessionOptions)
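
The final try/catch in loadSavedModelBundle turns a checked parsing failure into an unchecked exception while keeping the original as its cause. The sketch below shows the same shape without any TensorFlow or protobuf dependency. BundleLoadException and parseHeader are hypothetical stand-ins (for TensorFlowException and MetaGraphDef.parseFrom respectively), introduced only so the example compiles on its own.

```java
import java.io.IOException;

/** Hypothetical stand-in for an unchecked library exception such as TensorFlowException. */
final class BundleLoadException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    BundleLoadException(String message, Throwable cause) {
        super(message, cause);
    }
}

/** Hypothetical sketch: wrap a checked failure in an unchecked one, preserving the cause. */
final class WrapSketch {
    private WrapSketch() {}

    /** Stand-in parser that fails with a checked exception on empty input. */
    static int parseHeader(byte[] bytes) throws IOException {
        if (bytes.length == 0) {
            throw new IOException("empty buffer");
        }
        return bytes[0];
    }

    /** Callers see only the unchecked exception; the checked one stays reachable via getCause(). */
    static int load(byte[] bytes) {
        try {
            return parseHeader(bytes);
        } catch (IOException e) {
            throw new BundleLoadException("Cannot parse header", e);
        }
    }
}
```

Passing the caught exception as the cause keeps the original stack trace available to whoever logs or inspects the wrapper.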

Example 3 with TensorFlowException

use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.

the class TensorFlowSavedModelExternalModel method createTensorflowModel.

/**
 * Creates a TensorFlowSavedModelExternalModel by loading in a {@code SavedModelBundle}.
 * <p>
 * Throws {@link IllegalArgumentException} if the model bundle could not be loaded.
 * @param factory The output factory.
 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
 * @param outputMapping The output mapping between Tribuo's names and the TF integer ids.
 * @param outputName The name of the output tensor.
 * @param featureConverter The feature transformation function.
 * @param outputConverter The output transformation function.
 * @param bundleDirectory The path to load the saved model bundle from.
 * @param <T> The type of the output.
 * @return The TF model wrapped in a Tribuo {@link ExternalModel}.
 */
public static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory) {
    try {
        Path path = Paths.get(bundleDirectory);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        ModelProvenance provenance = new ModelProvenance(TensorFlowSavedModelExternalModel.class.getName(), now, datasetProvenance, trainerProvenance);
        return new TensorFlowSavedModelExternalModel<>("tf-saved-model-bundle", provenance, featureMap, outputInfo, featureMapping, bundleDirectory, outputName, featureConverter, outputConverter);
    } catch (IOException | TensorFlowException e) {
        throw new IllegalArgumentException("Unable to load model from path " + bundleDirectory, e);
    }
}
Also used : Path(java.nio.file.Path) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) IOException(java.io.IOException) URL(java.net.URL) OffsetDateTime(java.time.OffsetDateTime) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) TensorFlowException(org.tensorflow.exceptions.TensorFlowException)
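
createTensorflowModel converts the bundle directory to a URL for provenance and uses a single multi-catch to report any failure as an IllegalArgumentException that names the offending path. A reduced sketch of that step, using only the JDK, might look like the following. ProvenanceLocationSketch and locate are hypothetical names, and InvalidPathException is used here as the unchecked alternative in place of TensorFlowException, which is an assumption made so the example has no external dependencies.

```java
import java.io.IOException;
import java.net.URL;
import java.nio.file.InvalidPathException;
import java.nio.file.Paths;

/** Hypothetical sketch: translate path and URL failures into one argument error. */
final class ProvenanceLocationSketch {
    private ProvenanceLocationSketch() {}

    /** Resolves a directory string to a file URL, reporting failures against the input path. */
    static URL locate(String bundleDirectory) {
        try {
            return Paths.get(bundleDirectory).toUri().toURL();
        } catch (IOException | InvalidPathException e) {
            // One handler for a checked and an unchecked failure; the cause is preserved.
            throw new IllegalArgumentException("Unable to load model from path " + bundleDirectory, e);
        }
    }
}
```

The two alternatives in a multi-catch must be unrelated by inheritance, which holds here just as it does for IOException and TensorFlowException in the original.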

Example 4 with TensorFlowException

use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.

the class TensorFlowTrainer method train.

@Override
public TensorFlowModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
    ArrayList<Example<T>> batch = new ArrayList<>();
    Path curCheckpointPath;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        curCheckpointPath = checkpointPath != null ? Paths.get(checkpointPath.toString(), "invocation-" + trainInvocationCounter, "tribuo") : null;
        trainInvocationCounter++;
    }
    ConfigProto.Builder configBuilder = ConfigProto.newBuilder();
    if (interOpParallelism > -1) {
        configBuilder.setInterOpParallelismThreads(interOpParallelism);
    }
    if (intraOpParallelism > -1) {
        configBuilder.setIntraOpParallelismThreads(intraOpParallelism);
    }
    ConfigProto config = configBuilder.build();
    try (Graph graph = new Graph();
        Session session = new Session(graph, config)) {
        // Load in the graph definition
        graph.importGraphDef(graphDef);
        Ops tf = Ops.create(graph).withName("tribuo-internal");
        // Lookup output op
        Operand<TNumber> intermediateOutputOp = graph.operation(outputName).output(0);
        // Validate that the output op is the right shape
        Shape outputShape = intermediateOutputOp.shape();
        Shape expectedShape = Shape.of(trainBatchSize, outputInfo.size());
        if (!outputShape.isCompatibleWith(expectedShape)) {
            throw new IllegalArgumentException("Incompatible output shape, expected " + expectedShape.toString() + " found " + outputShape.toString());
        }
        // Add target placeholder
        Placeholder<? extends TNumber> targetPlaceholder = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(trainBatchSize, outputInfo.size())));
        // Add loss, optimiser and output
        Op outputOp = outputConverter.outputTransformFunction().apply(tf, intermediateOutputOp);
        Operand<TNumber> lossOp = outputConverter.loss().apply(tf, new Pair<>(targetPlaceholder, intermediateOutputOp));
        Op optimiser = optimiserEnum.applyOptimiser(graph, lossOp, gradientParams);
        // Initialise all the things
        session.initialize();
        logger.info("Initialised the model parameters");
        int interval = 0;
        for (int i = 0; i < epochs; i++) {
            logger.log(Level.INFO, "Starting epoch " + i);
            for (int j = 0; j < examples.size(); j += trainBatchSize) {
                batch.clear();
                for (int k = j; k < (j + trainBatchSize) && k < examples.size(); k++) {
                    batch.add(examples.getExample(k));
                }
                try (TensorMap input = featureConverter.convert(batch, featureMap);
                    Tensor target = outputConverter.convertToTensor(batch, outputInfo);
                    Tensor lossTensor = input.feedInto(session.runner()).feed(targetPlaceholder, target).addTarget(optimiser).fetch(lossOp).run().get(0)) {
                    if ((loggingInterval != -1) && (interval % loggingInterval == 0)) {
                        logger.log(Level.INFO, "Training loss at itr " + interval + " = " + ((TFloat32) lossTensor).getFloat());
                    }
                }
                interval++;
            }
        }
        // Setup the model serialization infrastructure.
        // **Must** happen before the trainedGraphDef is generated.
        // We unconditionally annotate the Graph for Tribuo's serialization
        TensorFlowUtil.annotateGraph(graph, session);
        // If it's a checkpoint we also save it out.
        if (modelFormat == TFModelFormat.CHECKPOINT) {
            session.save(curCheckpointPath.toString());
        }
        GraphDef trainedGraphDef = graph.toGraphDef();
        ModelProvenance modelProvenance = new ModelProvenance(TensorFlowModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
        TensorFlowModel<T> tfModel;
        switch(modelFormat) {
            case TRIBUO_NATIVE:
                Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(graph, session);
                tfModel = new TensorFlowNativeModel<>("tf-native-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, tensorMap, testBatchSize, outputOp.op().name(), featureConverter, outputConverter);
                break;
            case CHECKPOINT:
                tfModel = new TensorFlowCheckpointModel<>("tf-checkpoint-model", modelProvenance, featureMap, outputInfo, trainedGraphDef, curCheckpointPath.getParent().toString(), curCheckpointPath.getFileName().toString(), testBatchSize, outputOp.op().name(), featureConverter, outputConverter);
                break;
            default:
                throw new IllegalStateException("Unexpected enum constant " + modelFormat);
        }
        return tfModel;
    } catch (TensorFlowException e) {
        logger.log(Level.SEVERE, "TensorFlow threw an error", e);
        throw new IllegalStateException(e);
    }
}
Also used : Op(org.tensorflow.op.Op) Shape(org.tensorflow.ndarray.Shape) ConfigProto(org.tensorflow.proto.framework.ConfigProto) ArrayList(java.util.ArrayList) Ops(org.tensorflow.op.Ops) Example(org.tribuo.Example) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Path(java.nio.file.Path) Tensor(org.tensorflow.Tensor) ModelProvenance(org.tribuo.provenance.ModelProvenance) TFloat32(org.tensorflow.types.TFloat32) GraphDef(org.tensorflow.proto.framework.GraphDef) Graph(org.tensorflow.Graph) TNumber(org.tensorflow.types.family.TNumber) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) Session(org.tensorflow.Session)
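
The train method above builds one checkpoint path per invocation and later splits it back into a directory and a file name when constructing the TensorFlowCheckpointModel. The sketch below reproduces only that path arithmetic with java.nio.file. CheckpointPathSketch and invocationPath are hypothetical names; the "invocation-" prefix and the "tribuo" leaf are taken from the code above.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical sketch of the per-invocation checkpoint path used in the trainer above. */
final class CheckpointPathSketch {
    private CheckpointPathSketch() {}

    /**
     * Builds root/invocation-N/tribuo, or returns null when no checkpoint root is set,
     * mirroring the ternary in the original method.
     */
    static Path invocationPath(Path checkpointRoot, int invocation) {
        if (checkpointRoot == null) {
            return null;
        }
        return Paths.get(checkpointRoot.toString(), "invocation-" + invocation, "tribuo");
    }
}
```

For a root of `ckpt` and invocation 3, getParent() gives the `ckpt/invocation-3` directory and getFileName() gives `tribuo`, which are the two strings the CHECKPOINT branch passes on. That branch dereferences the path unconditionally, so it relies on the checkpoint root being non-null whenever the checkpoint format is selected.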

Example 5 with TensorFlowException

use of org.tensorflow.exceptions.TensorFlowException in project tribuo by oracle.

the class TensorFlowCheckpointModel method readObject.

private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    byte[] modelBytes = (byte[]) in.readObject();
    this.modelGraph = new Graph();
    this.modelGraph.importGraphDef(GraphDef.parseFrom(modelBytes));
    this.session = new Session(modelGraph);
    try {
        session.restore(resolvePath());
        initialized = true;
    } catch (TensorFlowException e) {
        logger.log(Level.WARNING, "Failed to initialise model after deserialization, attempted to load from " + checkpointDirectory, e);
    }
}
Also used : Graph(org.tensorflow.Graph) TensorFlowException(org.tensorflow.exceptions.TensorFlowException) Session(org.tensorflow.Session)
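
readObject above rebuilds non-serializable state after default deserialization and, if the restore fails, logs a warning and leaves the initialized flag false instead of aborting deserialization. The following dependency-free sketch shows the same control flow. LazyResourceHolder, restore and roundTrip are hypothetical, and the rule that an empty location fails to restore is an assumption chosen purely to trigger the failure branch.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.logging.Level;
import java.util.logging.Logger;

/** Hypothetical sketch: tolerate a failed restore during deserialization. */
final class LazyResourceHolder implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(LazyResourceHolder.class.getName());

    private final String location;
    // Not serialized; recomputed in readObject.
    private transient boolean initialized;

    LazyResourceHolder(String location) {
        this.location = location;
    }

    boolean isInitialized() {
        return initialized;
    }

    /** Stand-in for restoring native state; assumed to fail for an empty location. */
    private void restore() {
        if (location.isEmpty()) {
            throw new IllegalStateException("nothing to restore from");
        }
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        try {
            restore();
            initialized = true;
        } catch (IllegalStateException e) {
            // Deserialization still succeeds; the object is simply left uninitialized.
            logger.log(Level.WARNING, "Failed to initialise after deserialization from " + location, e);
        }
    }

    /** Serializes and deserializes the holder in memory. */
    static LazyResourceHolder roundTrip(LazyResourceHolder holder) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(holder);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (LazyResourceHolder) in.readObject();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

The trade-off of this design is that callers must check the flag before use, since a deserialized object can exist without its backing state.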

Aggregations

TensorFlowException (org.tensorflow.exceptions.TensorFlowException) 5
Graph (org.tensorflow.Graph) 3
Session (org.tensorflow.Session) 3
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 3
ModelProvenance (org.tribuo.provenance.ModelProvenance) 3
Path (java.nio.file.Path) 2
ArrayList (java.util.ArrayList) 2
Tensor (org.tensorflow.Tensor) 2
GraphDef (org.tensorflow.proto.framework.GraphDef) 2
TFloat32 (org.tensorflow.types.TFloat32) 2
SavedModelBundle (ai.djl.tensorflow.engine.SavedModelBundle) 1
InvalidProtocolBufferException (com.google.protobuf.InvalidProtocolBufferException) 1
IOException (java.io.IOException) 1
URL (java.net.URL) 1
OffsetDateTime (java.time.OffsetDateTime) 1
SplittableRandom (java.util.SplittableRandom) 1
BytePointer (org.bytedeco.javacpp.BytePointer) 1
PointerScope (org.bytedeco.javacpp.PointerScope) 1
AbstractTF_Graph (org.tensorflow.internal.c_api.AbstractTF_Graph) 1
TF_Buffer (org.tensorflow.internal.c_api.TF_Buffer) 1