Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.
Class DL4jServeRouteBuilder, method configure:
/**
 * <b>Called on initialization to build the routes using the fluent builder syntax.</b>
 * <p/>
 * This is a central method for RouteBuilder implementations to implement
 * the routes using the Java fluent builder syntax.
 *
 * @throws Exception can be thrown during configuration
 */
@Override
public void configure() throws Exception {
    if (groupId == null)
        groupId = "dl4j-serving";
    if (zooKeeperHost == null)
        zooKeeperHost = "localhost";
    String kafkaUri = String.format("kafka:%s?topic=%s&groupId=%s", kafkaBroker, consumingTopic, groupId);
    // Default to a no-op pre-processor if none was supplied
    if (beforeProcessor == null) {
        beforeProcessor = new Processor() {
            @Override
            public void process(Exchange exchange) throws Exception {
            }
        };
    }
    from(kafkaUri).process(beforeProcessor).process(new Processor() {
        @Override
        public void process(Exchange exchange) throws Exception {
            INDArray predict;
            // The message body is either a Base64-encoded ND4J binary array (as bytes) or an INDArray
            if (exchange.getIn().getBody() instanceof byte[]) {
                byte[] o = (byte[]) exchange.getIn().getBody();
                byte[] arr = Base64.decodeBase64(new String(o));
                ByteArrayInputStream bis = new ByteArrayInputStream(arr);
                DataInputStream dis = new DataInputStream(bis);
                predict = Nd4j.read(dis);
            } else {
                predict = (INDArray) exchange.getIn().getBody();
            }
            // Restore the model from modelUri, run inference, and write the result to both messages
            if (computationGraph) {
                ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelUri);
                INDArray[] output = graph.output(predict);
                exchange.getOut().setBody(output);
                exchange.getIn().setBody(output);
            } else {
                MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelUri);
                INDArray output = network.output(predict);
                exchange.getOut().setBody(output);
                exchange.getIn().setBody(output);
            }
        }
    }).process(finalProcessor).to(outputUri);
}
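The route above accepts either an INDArray body or a byte[] containing a Base64-encoded ND4J binary array. A minimal producer-side sketch of the second case (class name, broker address, topic, and Kafka serializer settings are assumptions, not taken from the project):

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.nio.charset.StandardCharsets;

import org.apache.camel.CamelContext;
import org.apache.camel.ProducerTemplate;
import org.apache.camel.impl.DefaultCamelContext;
import org.apache.commons.codec.binary.Base64;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Dl4jServingProducerSketch {
    public static void main(String[] args) throws Exception {
        // Example feature row; shape is arbitrary here
        INDArray features = Nd4j.rand(1, 784);

        // Serialize with Nd4j.write(...) and Base64-encode, mirroring the decode path in the route
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            Nd4j.write(features, dos);
        }
        byte[] payload = Base64.encodeBase64String(bos.toByteArray()).getBytes(StandardCharsets.UTF_8);

        CamelContext context = new DefaultCamelContext();
        context.start();
        ProducerTemplate template = context.createProducerTemplate();
        // Broker address and topic are placeholders; they must match kafkaBroker and consumingTopic above
        template.sendBody("kafka:localhost:9092?topic=predictions", payload);
        context.stop();
    }
}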
Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.
Class BaseOptimizer, method incrementIterationCount:
public static void incrementIterationCount(Model model, int incrementBy) {
    if (model instanceof MultiLayerNetwork) {
        MultiLayerConfiguration conf = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
        conf.setIterationCount(conf.getIterationCount() + incrementBy);
    } else if (model instanceof ComputationGraph) {
        ComputationGraphConfiguration conf = ((ComputationGraph) model).getConfiguration();
        conf.setIterationCount(conf.getIterationCount() + incrementBy);
    } else {
        model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy);
    }
}
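BaseOptimizer also exposes the read-side counterpart, getIterationCount(Model), which the updater code below relies on. A sketch that mirrors the same type dispatch (an approximation of the helper, not a verbatim copy of the project source):

public static int getIterationCount(Model model) {
    // Same dispatch as the increment path: MultiLayerNetwork and ComputationGraph keep the
    // count on their top-level configuration, everything else on the per-layer configuration
    if (model instanceof MultiLayerNetwork) {
        return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount();
    } else if (model instanceof ComputationGraph) {
        return ((ComputationGraph) model).getConfiguration().getIterationCount();
    } else {
        return model.conf().getIterationCount();
    }
}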
Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.
Class BaseOptimizer, method updateGradientAccordingToParams:
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        if (computationGraphUpdater == null) {
            computationGraphUpdater = new ComputationGraphUpdater(graph);
        }
        computationGraphUpdater.update(graph, gradient, getIterationCount(model), batchSize);
    } else {
        if (updater == null)
            updater = UpdaterCreator.getUpdater(model);
        Layer layer = (Layer) model;
        updater.update(layer, gradient, getIterationCount(model), batchSize);
    }
}
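For the non-graph branch, the same call pattern can be driven by hand: compute a gradient for a MultiLayerNetwork, then let the updater apply its learning-rate and momentum rules to that gradient in place. A rough sketch against 0.x-era DL4J APIs (configuration values and names are placeholders, not taken from the project):

import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ManualUpdaterSketch {
    public static void main(String[] args) {
        // Tiny single-layer network purely for illustration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Compute a gradient for one random minibatch of 8 examples
        net.setInput(Nd4j.rand(8, 4));
        net.setLabels(Nd4j.rand(8, 2));
        net.computeGradientAndScore();
        Gradient gradient = net.gradient();

        // Same call pattern as the else-branch above: the network is treated as a Layer
        Updater updater = UpdaterCreator.getUpdater(net);
        updater.update(net, gradient, 0, 8);
    }
}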
Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.
Class ParameterServerParallelWrapper, method init:
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        DataSetIterator dataSetIterator = (DataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(dataSetIterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        MultiDataSetIterator iterator1 = (MultiDataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(iterator1);
    } else {
        throw new IllegalArgumentException(
                        "Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    }
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p",
                        "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh",
                        "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);
    // Wait for the parameter server subscriber to come up before starting workers
    while (!parameterServerNode.subscriberLaunched()) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
    log.info("Parameter server started");
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each worker trains on its own clone of the model
        Model model = null;
        if (this.model instanceof ComputationGraph) {
            ComputationGraph computationGraph = (ComputationGraph) this.model;
            model = computationGraph.clone();
        } else if (this.model instanceof MultiLayerNetwork) {
            MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) this.model;
            model = multiLayerNetwork.clone();
        }
        parameterServerClient[i] = new Trainer(ParameterServerClient.builder()
                        .aeron(parameterServerNode.getAeron())
                        .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl())
                        .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl())
                        .subscriberHost("localhost").masterStatusHost("localhost")
                        .masterStatusPort(statusServerPort).subscriberPort(40625 + i)
                        .subscriberStream(12 + i).build(), running, linkedBlockingQueue, model);
        final int j = i;
        executorService.submit(() -> parameterServerClient[j].start());
    }
    init = true;
    log.info("Initialized wrapper");
}
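The numUpdatesPerEpoch(...) helpers called at the top of init are not shown in this excerpt. A plausible stand-in (an assumption, not the project's implementation) simply counts minibatches in one pass and resets the iterator:

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Hypothetical counter for the number of minibatch updates in one epoch
final class EpochUpdateCounter {
    static int numUpdatesPerEpoch(DataSetIterator iterator) {
        int count = 0;
        while (iterator.hasNext()) {
            iterator.next();
            count++;
        }
        // Rewind so training can start from the beginning of the data
        iterator.reset();
        return count;
    }
}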
Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.
Class CGVaeReconstructionProbWithKeyFunction, method getVaeLayer:
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network =
                    new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
    network.setParams(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE "
                        + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
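The function expects the network configuration as broadcast JSON and the parameters as a single flattened INDArray. A sketch of the driver-side preparation (class and variable names are placeholders; the Spark setup is assumed, not taken from the project):

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

public class VaeBroadcastSketch {
    public static void broadcastModel(JavaSparkContext sc, ComputationGraph trainedGraph) {
        // JSON form of the configuration; workers rebuild it via ComputationGraphConfiguration.fromJson(...)
        Broadcast<String> jsonConfig = sc.broadcast(trainedGraph.getConfiguration().toJson());
        // Flattened parameter vector; workers call network.setParams(...) after network.init()
        Broadcast<INDArray> params = sc.broadcast(trainedGraph.params());
        // jsonConfig and params would then be handed to CGVaeReconstructionProbWithKeyFunction
    }
}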