Search in sources :

Example 6 with Model

use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.

From class TestOptimizers, method testRosenbrockFnMultipleStepsHelper.

private static void testRosenbrockFnMultipleStepsHelper(OptimizationAlgorithm oa, int nOptIter, int maxNumLineSearchIter) {
    // scoreHistory[k] is the Rosenbrock score after a fresh model has been optimized for k iterations;
    // slot 0 is the untouched baseline.
    final double[] scoreHistory = new double[nOptIter + 1];
    for (int iters = 0; iters < scoreHistory.length; iters++) {
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                .maxNumLineSearchIterations(maxNumLineSearchIter)
                .iterations(iters)
                .stepFunction(new org.deeplearning4j.nn.conf.stepfunctions.NegativeDefaultStepFunction())
                .learningRate(1e-1)
                .layer(new RBM.Builder().nIn(1).nOut(1).updater(Updater.SGD).build())
                .build();
        // Parameter initializers would normally register "W"; this hand-built model needs it done explicitly.
        conf.addVariable("W");
        Model m = new RosenbrockFunctionModel(100, conf);
        final boolean trained = iters > 0;
        if (trained) {
            ConvexOptimizer opt = getOptimizer(oa, conf, m);
            opt.optimize();
        }
        m.computeGradientAndScore();
        scoreHistory[iters] = m.score();
        if (trained) {
            double s = scoreHistory[iters];
            assertTrue("NaN or infinite score: " + s, !Double.isNaN(s) && !Double.isInfinite(s));
        }
    }
    if (PRINT_OPT_RESULTS) {
        System.out.println("Rosenbrock: Multiple optimization iterations ( " + nOptIter + " opt. iter.) score vs iteration, maxNumLineSearchIter= " + maxNumLineSearchIter + ": " + oa);
        System.out.println(Arrays.toString(scoreHistory));
    }
    for (int k = 1; k < scoreHistory.length; k++) {
        double prev = scoreHistory[k - 1];
        double cur = scoreHistory[k];
        // The first optimizer step must strictly improve the score; later steps only must not regress.
        boolean monotone = (k == 1) ? cur < prev : cur <= prev;
        assertTrue(monotone);
    }
}
Also used : Model(org.deeplearning4j.nn.api.Model) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) NegativeDefaultStepFunction(org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction) ConvexOptimizer(org.deeplearning4j.optimize.api.ConvexOptimizer)

Example 7 with Model

use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.

From class ModelGuesserTest, method testModelGuess.

/**
 * Checks that {@code ModelGuesser.loadModelGuess} can load two different Keras HDF5 models from the
 * test classpath: the MNIST MLP example and the MNIST CNN example.
 *
 * <p>A {@code null} guess skips the test (JUnit {@code assumeNotNull}) rather than failing it.
 */
@Test
public void testModelGuess() throws Exception {
    // First model: MNIST MLP (TensorFlow-backend Keras export).
    ClassPathResource sequenceResource = new ClassPathResource("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_model.h5");
    assertTrue(sequenceResource.exists());
    File f = getTempFile(sequenceResource);
    Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
    assumeNotNull(guess1);
    // Second model: MNIST CNN. It must be materialized from sequenceResource2; reusing the first
    // resource here would silently re-test the MLP and never exercise the CNN file.
    ClassPathResource sequenceResource2 = new ClassPathResource("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_model.h5");
    assertTrue(sequenceResource2.exists());
    File f2 = getTempFile(sequenceResource2);
    Model guess2 = ModelGuesser.loadModelGuess(f2.getAbsolutePath());
    assumeNotNull(guess2);
}
Also used : Model(org.deeplearning4j.nn.api.Model) File(java.io.File) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 8 with Model

use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.

From class ParameterServerParallelWrapper, method init.

/**
 * Starts the embedded Aeron media driver and parameter server, then creates and launches one
 * {@code Trainer} per worker, each holding its own clone of {@code model}.
 *
 * @param iterator a {@link DataSetIterator} or {@link MultiDataSetIterator}; used only to compute
 *                 {@code numUpdatesPerEpoch}
 * @throws IllegalStateException    if {@code numEpochs < 1}, or if {@code model} is neither a
 *                                  {@link ComputationGraph} nor a {@link MultiLayerNetwork}
 * @throws IllegalArgumentException if {@code iterator} is of an unsupported type
 */
private void init(Object iterator) {
    if (numEpochs < 1) {
        throw new IllegalStateException("numEpochs must be >= 1");
    }
    // Fail before launching any Aeron resources: worker clones are only supported for these two types,
    // and a null clone would otherwise only blow up inside an unobserved executor task.
    if (!(model instanceof ComputationGraph) && !(model instanceof MultiLayerNetwork)) {
        throw new IllegalStateException("Unsupported model type for parameter server training: "
                + (model == null ? "null" : model.getClass().getName()));
    }
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        DataSetIterator dataSetIterator = (DataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(dataSetIterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        MultiDataSetIterator multiDataSetIterator = (MultiDataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(multiDataSetIterator);
    } else {
        throw new IllegalArgumentException("Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    }
    // Resolve the default worker count BEFORE anything is sized from it: the ParameterServerNode below
    // receives numWorkers, and the worker loop indexes its subscribers with the same count.
    if (numWorkers == 0) {
        numWorkers = Runtime.getRuntime().availableProcessors();
    }
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null) {
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    }
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);
    // Poll every 10 s until the subscriber reports it has launched.
    while (!parameterServerNode.subscriberLaunched()) {
        sleepIgnoringInterrupts(10000);
    }
    // Additional fixed 10 s delay after launch (kept from the original; rationale not visible here).
    sleepIgnoringInterrupts(10000);
    log.info("Parameter server started");
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each worker trains its own copy; the type was validated at the top of this method.
        Model workerModel = (model instanceof ComputationGraph)
                ? ((ComputationGraph) model).clone()
                : ((MultiLayerNetwork) model).clone();
        final Trainer trainer = new Trainer(ParameterServerClient.builder()
                .aeron(parameterServerNode.getAeron())
                .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl())
                .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl())
                .subscriberHost("localhost")
                .masterStatusHost("localhost")
                .masterStatusPort(statusServerPort)
                .subscriberPort(40625 + i)
                .subscriberStream(12 + i)
                .build(), running, linkedBlockingQueue, workerModel);
        parameterServerClient[i] = trainer;
        executorService.submit(() -> trainer.start());
    }
    init = true;
    log.info("Initialized wrapper");
}

/**
 * Sleeps for the full {@code millis} even if interrupted, then restores the thread's interrupt flag.
 * A plain {@code Thread.sleep} with a re-interrupt in the catch would make every later sleep in a
 * polling loop throw immediately, degrading the wait into a busy spin.
 */
private static void sleepIgnoringInterrupts(long millis) {
    boolean interrupted = false;
    final long deadline = System.nanoTime() + millis * 1_000_000L;
    try {
        long remainingMillis = millis;
        while (remainingMillis > 0) {
            try {
                Thread.sleep(remainingMillis);
                break;
            } catch (InterruptedException e) {
                // Remember the interrupt and keep waiting for whatever time is left.
                interrupted = true;
                remainingMillis = (deadline - System.nanoTime()) / 1_000_000L;
            }
        }
    } finally {
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}
Also used: ParameterServerNode(org.nd4j.parameterserver.node.ParameterServerNode) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) MediaDriver(io.aeron.driver.MediaDriver) Model(org.deeplearning4j.nn.api.Model) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)

Aggregations

Model (org.deeplearning4j.nn.api.Model)8 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 ConvexOptimizer (org.deeplearning4j.optimize.api.ConvexOptimizer)4 File (java.io.File)2 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)2 Test (org.junit.Test)2 DefaultRandom (org.nd4j.linalg.api.rng.DefaultRandom)2 Random (org.nd4j.linalg.api.rng.Random)2 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)2 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)2 JCommander (com.beust.jcommander.JCommander)1 ParameterException (com.beust.jcommander.ParameterException)1 MediaDriver (io.aeron.driver.MediaDriver)1 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)1 StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)1 RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter)1 AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)1 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)1 Evaluation (org.deeplearning4j.eval.Evaluation)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1