Search in sources:

Example 91 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

The class DL4jServeRouteBuilder, method configure.

/**
 * <b>Called on initialization to build the routes using the fluent builder syntax.</b>
 * <p/>
 * This is a central method for RouteBuilder implementations to implement
 * the routes using the Java fluent builder syntax.
 * <p/>
 * The route reads from Kafka ({@code kafkaBroker}, topic {@code consumingTopic}, group
 * {@code groupId}), runs {@code beforeProcessor}, feeds the message body to the model
 * restored from {@code modelUri}, runs {@code finalProcessor} and sends the result to
 * {@code outputUri}. The message body is either an {@link INDArray} or a {@code byte[]}
 * holding Base64 text that is decoded and read back with {@code Nd4j.read}.
 * A null {@code groupId} defaults to {@code "dl4j-serving"}, a null {@code zooKeeperHost}
 * to {@code "localhost"}, and null {@code beforeProcessor}/{@code finalProcessor} to no-ops.
 *
 * @throws Exception can be thrown during configuration
 */
@Override
public void configure() throws Exception {
    if (groupId == null)
        groupId = "dl4j-serving";
    if (zooKeeperHost == null)
        zooKeeperHost = "localhost";
    String kafkaUri = String.format("kafka:%s?topic=%s&groupId=%s", kafkaBroker, consumingTopic, groupId);
    if (beforeProcessor == null) {
        beforeProcessor = new Processor() {

            @Override
            public void process(Exchange exchange) throws Exception {
                // intentionally a no-op
            }
        };
    }
    // Mirror the beforeProcessor default so a null finalProcessor is never handed to
    // process(...) below.
    if (finalProcessor == null) {
        finalProcessor = new Processor() {

            @Override
            public void process(Exchange exchange) throws Exception {
                // intentionally a no-op
            }
        };
    }
    from(kafkaUri).process(beforeProcessor).process(new Processor() {

        @Override
        public void process(Exchange exchange) throws Exception {
            Object body = exchange.getIn().getBody();
            INDArray predict;
            if (body instanceof byte[]) {
                byte[] decoded = Base64.decodeBase64(new String((byte[]) body));
                try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(decoded))) {
                    predict = Nd4j.read(dis);
                }
            } else {
                predict = (INDArray) body;
            }
            // NOTE(review): the model is restored from modelUri for every message. This is
            // costly, but each exchange gets its own instance; only cache it if the model
            // is confirmed safe to share between concurrent exchanges.
            Object output;
            if (computationGraph) {
                ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelUri);
                output = graph.output(predict);
            } else {
                MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelUri);
                output = network.output(predict);
            }
            exchange.getOut().setBody(output);
            exchange.getIn().setBody(output);
        }
    }).process(finalProcessor).to(outputUri);
}
Also used : Exchange(org.apache.camel.Exchange) Processor(org.apache.camel.Processor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ByteArrayInputStream(java.io.ByteArrayInputStream) DataInputStream(java.io.DataInputStream) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 92 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

The class BaseOptimizer, method incrementIterationCount.

/**
 * Adds {@code incrementBy} to the iteration counter held by the model's configuration.
 * <p>
 * A {@link ComputationGraph} keeps the counter in its {@link ComputationGraphConfiguration},
 * a {@link MultiLayerNetwork} in its {@link MultiLayerConfiguration}; any other model uses
 * the counter of {@code model.conf()}.
 *
 * @param model       the model whose iteration counter is advanced
 * @param incrementBy amount added to the current iteration count
 */
public static void incrementIterationCount(Model model, int incrementBy) {
    if (model instanceof ComputationGraph) {
        ComputationGraphConfiguration graphConf = ((ComputationGraph) model).getConfiguration();
        graphConf.setIterationCount(graphConf.getIterationCount() + incrementBy);
        return;
    }
    if (model instanceof MultiLayerNetwork) {
        MultiLayerConfiguration netConf = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
        netConf.setIterationCount(netConf.getIterationCount() + incrementBy);
        return;
    }
    model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy);
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 93 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

The class BaseOptimizer, method updateGradientAccordingToParams.

/**
 * Passes {@code gradient} to the updater that matches the model type, creating that
 * updater on first use.
 * <p>
 * A {@link ComputationGraph} is handled by a {@link ComputationGraphUpdater}; any other
 * model gets an updater from {@code UpdaterCreator.getUpdater} and is cast to {@link Layer}.
 *
 * @param gradient  the gradient handed to the updater
 * @param model     the model being optimized
 * @param batchSize the minibatch size forwarded to the updater
 */
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (!(model instanceof ComputationGraph)) {
        if (updater == null) {
            updater = UpdaterCreator.getUpdater(model);
        }
        Layer asLayer = (Layer) model;
        updater.update(asLayer, gradient, getIterationCount(model), batchSize);
        return;
    }
    ComputationGraph cg = (ComputationGraph) model;
    if (computationGraphUpdater == null) {
        computationGraphUpdater = new ComputationGraphUpdater(cg);
    }
    computationGraphUpdater.update(cg, gradient, getIterationCount(model), batchSize);
}
Also used : ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Layer(org.deeplearning4j.nn.api.Layer)

Example 94 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

The class ParameterServerParallelWrapper, method init.

/**
 * Starts the embedded Aeron media driver and the parameter server. It then creates one
 * {@code Trainer} per worker and submits each trainer's {@code start()} to a fixed-size
 * thread pool.
 *
 * @param iterator a {@link DataSetIterator} or {@link MultiDataSetIterator}; it is only used
 *                 here to compute the number of updates per epoch
 * @throws IllegalStateException    if {@code numEpochs < 1}
 * @throws IllegalArgumentException if {@code iterator} is of any other type
 */
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        numUpdatesPerEpoch = numUpdatesPerEpoch((DataSetIterator) iterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        numUpdatesPerEpoch = numUpdatesPerEpoch((MultiDataSetIterator) iterator);
    } else
        throw new IllegalArgumentException("Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    // Resolve the worker count before anything is sized from it. The parameter server node
    // receives numWorkers, and the trainer loop below indexes getSubscriber()[i] for every
    // i < numWorkers, so both must see the same (non-zero) value.
    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);
    // Thread.sleep clears the interrupt flag when it throws. Re-setting the flag inside the
    // loop would make every later sleep throw at once and turn this wait into a busy spin,
    // so remember the interrupt and restore it once waiting is over.
    boolean interrupted = false;
    while (!parameterServerNode.subscriberLaunched()) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            interrupted = true;
        }
    }
    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        interrupted = true;
    }
    if (interrupted)
        Thread.currentThread().interrupt();
    log.info("Parameter server started");
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each trainer gets its own clone of the model. For any other Model type the
        // trainer receives null (NOTE(review): confirm Trainer tolerates that).
        Model workerModel = null;
        if (this.model instanceof ComputationGraph) {
            workerModel = ((ComputationGraph) this.model).clone();
        } else if (this.model instanceof MultiLayerNetwork) {
            workerModel = ((MultiLayerNetwork) this.model).clone();
        }
        parameterServerClient[i] = new Trainer(ParameterServerClient.builder()
                .aeron(parameterServerNode.getAeron())
                .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl())
                .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl())
                .subscriberHost("localhost")
                .masterStatusHost("localhost")
                .masterStatusPort(statusServerPort)
                .subscriberPort(40625 + i)
                .subscriberStream(12 + i)
                .build(), running, linkedBlockingQueue, workerModel);
        final int j = i;
        executorService.submit(() -> parameterServerClient[j].start());
    }
    init = true;
    log.info("Initialized wrapper");
}
Also used : ParameterServerNode(org.nd4j.parameterserver.node.ParameterServerNode) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) MediaDriver(io.aeron.driver.MediaDriver) Model(org.deeplearning4j.nn.api.Model) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 95 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

The class CGVaeReconstructionProbWithKeyFunction, method getVaeLayer.

/**
 * Rebuilds a {@link ComputationGraph} from the broadcast JSON configuration and
 * parameters, and returns its first layer as a {@link VariationalAutoencoder}.
 *
 * @return layer 0 of the rebuilt graph
 * @throws IllegalStateException if the broadcast parameter count differs from the graph's
 * @throws RuntimeException      if layer 0 is not a {@link VariationalAutoencoder}
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    String configJson = (String) jsonConfig.getValue();
    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(configJson));
    graph.init();
    // Copy of the broadcast parameters (presumably so the broadcast value itself is not
    // aliased by the graph — confirm).
    INDArray broadcastParams = ((INDArray) params.value()).unsafeDuplication();
    if (broadcastParams.length() != graph.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
    }
    graph.setParams(broadcastParams);
    Layer firstLayer = graph.getLayer(0);
    if (firstLayer instanceof VariationalAutoencoder) {
        return (VariationalAutoencoder) firstLayer;
    }
    throw new RuntimeException("Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE layer as layer 0. Layer type: " + firstLayer.getClass());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Layer(org.deeplearning4j.nn.api.Layer)

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)109 Test (org.junit.Test)73 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)62 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)36 DataSet (org.nd4j.linalg.dataset.DataSet)25 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)21 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)19 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)19 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)17 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)17 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)14 Layer (org.deeplearning4j.nn.api.Layer)14 Random (java.util.Random)11 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)10 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)10 TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)10 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)9 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)9