Use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.
From the class TestOptimizers, method testRosenbrockFnMultipleStepsHelper:
/**
 * Runs the given optimization algorithm on the Rosenbrock test function for 0..nOptIter
 * iterations, recording the score after each run, and asserts that the score is finite,
 * improves on the first step, and never regresses afterwards.
 *
 * @param oa                   optimization algorithm under test
 * @param nOptIter             number of optimization iterations to evaluate
 * @param maxNumLineSearchIter max line-search iterations per optimization step
 */
private static void testRosenbrockFnMultipleStepsHelper(OptimizationAlgorithm oa, int nOptIter, int maxNumLineSearchIter) {
    // scores[0] is the pre-optimization score; scores[k] is the score after k iterations.
    double[] scores = new double[nOptIter + 1];
    for (int iter = 0; iter <= nOptIter; iter++) {
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                        .maxNumLineSearchIterations(maxNumLineSearchIter)
                        .iterations(iter)
                        .stepFunction(new org.deeplearning4j.nn.conf.stepfunctions.NegativeDefaultStepFunction())
                        .learningRate(1e-1)
                        .layer(new RBM.Builder().nIn(1).nOut(1).updater(Updater.SGD).build())
                        .build();
        //Normally done by ParamInitializers, but obviously that isn't done here
        conf.addVariable("W");
        Model m = new RosenbrockFunctionModel(100, conf);
        if (iter == 0) {
            m.computeGradientAndScore();
            //Before optimization
            scores[0] = m.score();
        } else {
            ConvexOptimizer opt = getOptimizer(oa, conf, m);
            opt.optimize();
            m.computeGradientAndScore();
            double s = m.score();
            scores[iter] = s;
            assertTrue("NaN or infinite score: " + s, !Double.isNaN(s) && !Double.isInfinite(s));
        }
    }
    if (PRINT_OPT_RESULTS) {
        System.out.println("Rosenbrock: Multiple optimization iterations ( " + nOptIter + " opt. iter.) score vs iteration, maxNumLineSearchIter= " + maxNumLineSearchIter + ": " + oa);
        System.out.println(Arrays.toString(scores));
    }
    for (int i = 1; i < scores.length; i++) {
        //Require at least one step of improvement (strict on the first step), monotone non-increasing after
        assertTrue(i == 1 ? scores[i] < scores[i - 1] : scores[i] <= scores[i - 1]);
    }
}
Use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.
From the class ModelGuesserTest, method testModelGuess:
@Test
public void testModelGuess() throws Exception {
    // First model: a Keras MLP exported as an HDF5 file.
    ClassPathResource sequenceResource = new ClassPathResource("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_model.h5");
    assertTrue(sequenceResource.exists());
    File f = getTempFile(sequenceResource);
    Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
    assumeNotNull(guess1);
    // Second model: a Keras CNN. NOTE: the original code passed sequenceResource (the MLP)
    // here instead of sequenceResource2, so the CNN file was never actually tested.
    ClassPathResource sequenceResource2 = new ClassPathResource("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_model.h5");
    assertTrue(sequenceResource2.exists());
    File f2 = getTempFile(sequenceResource2);
    Model guess2 = ModelGuesser.loadModelGuess(f2.getAbsolutePath());
    assumeNotNull(guess2);
}
Use of org.deeplearning4j.nn.api.Model in project deeplearning4j by deeplearning4j.
From the class ParameterServerParallelWrapper, method init:
/**
 * Initializes the parameter-server training infrastructure: counts the updates per epoch
 * from the supplied iterator, launches an embedded Aeron media driver and a
 * {@link ParameterServerNode}, waits for its subscribers to come up, then creates one
 * {@link Trainer} (with its own clone of the model) per worker on a fixed thread pool.
 *
 * @param iterator a DataSetIterator or MultiDataSetIterator used to size the epoch
 * @throws IllegalStateException    if numEpochs < 1
 * @throws IllegalArgumentException if the iterator is of an unsupported type
 */
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        DataSetIterator dataSetIterator = (DataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(dataSetIterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        MultiDataSetIterator iterator1 = (MultiDataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(iterator1);
    } else
        throw new IllegalArgumentException("Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    // FIX: resolve the default worker count BEFORE it is used. Previously this was done
    // after constructing ParameterServerNode, so a node created with numWorkers == 0 was
    // later asked for availableProcessors() subscribers in the trainer loop below.
    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);
    // Poll until the server's subscriber threads report launched; re-assert the interrupt
    // flag if interrupted so callers can observe it.
    while (!parameterServerNode.subscriberLaunched()) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    // Extra settling time after the subscriber reports up — presumably to let Aeron
    // connections stabilize; TODO confirm whether this fixed sleep is still required.
    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
    log.info("Parameter server started");
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each worker trains on its own clone so workers do not share mutable model state.
        Model model = null;
        if (this.model instanceof ComputationGraph) {
            ComputationGraph computationGraph = (ComputationGraph) this.model;
            model = computationGraph.clone();
        } else if (this.model instanceof MultiLayerNetwork) {
            MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) this.model;
            model = multiLayerNetwork.clone();
        }
        parameterServerClient[i] = new Trainer(ParameterServerClient.builder().aeron(parameterServerNode.getAeron()).ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()).ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()).subscriberHost("localhost").masterStatusHost("localhost").masterStatusPort(statusServerPort).subscriberPort(40625 + i).subscriberStream(12 + i).build(), running, linkedBlockingQueue, model);
        final int j = i;
        executorService.submit(() -> parameterServerClient[j].start());
    }
    init = true;
    log.info("Initialized wrapper");
}
Aggregations