Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
Class ParameterServerParallelWrapper, method fit.
public void fit(MultiDataSetIterator multiDataSetIterator) {
    if (!init)
        init(multiDataSetIterator);

    // Wrap the source iterator in an asynchronous prefetching iterator when
    // prefetch is enabled and the underlying iterator supports it
    MultiDataSetIterator iterator;
    if (preFetchSize > 0 && multiDataSetIterator.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, preFetchSize);
    } else {
        iterator = multiDataSetIterator;
    }

    // Drain the iterator, handing each MultiDataSet to the workers
    while (iterator.hasNext()) {
        org.nd4j.linalg.dataset.api.MultiDataSet next = iterator.next();
        addObject(next);
    }
}
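To make the prefetch branch concrete, here is a minimal, self-contained sketch of the same decision applied outside the wrapper. The class name PrefetchSketch, the helper maybeAsync, and the tiny random dataset are illustrative assumptions, not part of ParameterServerParallelWrapper; only AsyncMultiDataSetIterator, SingletonMultiDataSetIterator, and the nd4j types come from the libraries themselves.

import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

public class PrefetchSketch {

    // Mirrors the prefetch decision in fit(): wrap only when prefetch is
    // requested and the underlying iterator supports asynchronous iteration.
    static MultiDataSetIterator maybeAsync(MultiDataSetIterator source, int preFetchSize) {
        if (preFetchSize > 0 && source.asyncSupported()) {
            return new AsyncMultiDataSetIterator(source, preFetchSize);
        }
        return source;
    }

    public static void main(String[] args) {
        // A toy single-element dataset purely for demonstration
        INDArray features = Nd4j.rand(8, 4);
        INDArray labels = Nd4j.rand(8, 2);
        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
                        new INDArray[] {features}, new INDArray[] {labels});
        MultiDataSetIterator it = new SingletonMultiDataSetIterator(mds);

        MultiDataSetIterator wrapped = maybeAsync(it, 2);
        while (wrapped.hasNext()) {
            MultiDataSet next = wrapped.next();
            System.out.println("features shape: " + Arrays.toString(next.getFeatures(0).shape()));
        }
    }
}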
Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
Class ParameterServerParallelWrapper, method init.
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");

    // Determine the number of updates per epoch from the iterator type
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        DataSetIterator dataSetIterator = (DataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(dataSetIterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        MultiDataSetIterator multiDataSetIterator = (MultiDataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(multiDataSetIterator);
    } else {
        throw new IllegalArgumentException(
                        "Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    }

    // Launch an embedded Aeron media driver and the parameter server node on top of it
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);

    // Default arguments for the parameter server subscriber if none were supplied
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] {"-m", "true",
                        "-s", "1," + String.valueOf(model.numParams()),
                        "-p", "40323",
                        "-h", "localhost",
                        "-id", "11",
                        "-md", mediaDriver.aeronDirectoryName(),
                        "-sh", "localhost",
                        "-sp", String.valueOf(statusServerPort),
                        "-u", String.valueOf(numUpdatesPerEpoch)};

    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();

    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);

    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);

    // Poll until the subscriber is launched, then wait a further 10 seconds to let it settle
    while (!parameterServerNode.subscriberLaunched()) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
    log.info("Parameter server started");

    // One Trainer (parameter server client) per worker, each with its own clone of the model
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        Model model = null;
        if (this.model instanceof ComputationGraph) {
            ComputationGraph computationGraph = (ComputationGraph) this.model;
            model = computationGraph.clone();
        } else if (this.model instanceof MultiLayerNetwork) {
            MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) this.model;
            model = multiLayerNetwork.clone();
        }

        parameterServerClient[i] = new Trainer(ParameterServerClient.builder()
                        .aeron(parameterServerNode.getAeron())
                        .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl())
                        .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl())
                        .subscriberHost("localhost")
                        .masterStatusHost("localhost")
                        .masterStatusPort(statusServerPort)
                        .subscriberPort(40625 + i)
                        .subscriberStream(12 + i)
                        .build(), running, linkedBlockingQueue, model);

        final int j = i;
        executorService.submit(() -> parameterServerClient[j].start());
    }

    init = true;
    log.info("Initialized wrapper");
}
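The distinctive step above is cloning the model once per worker, so each Trainer updates an independent parameter copy while the parameter server aggregates them. Below is a minimal, standalone sketch of that pattern; the class name CloneePerWorkerSketch is hypothetical, and the single-layer network configuration is an arbitrary placeholder chosen only so the example runs.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class CloneePerWorkerSketch {

    public static void main(String[] args) {
        // A deliberately tiny network; the layer choice is a placeholder assumption
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .list()
                        .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                        .nIn(4).nOut(2)
                                        .activation(Activation.IDENTITY)
                                        .build())
                        .build();
        MultiLayerNetwork template = new MultiLayerNetwork(conf);
        template.init();

        // Same default the wrapper uses when numWorkers == 0
        int numWorkers = Runtime.getRuntime().availableProcessors();
        MultiLayerNetwork[] workers = new MultiLayerNetwork[numWorkers];
        for (int i = 0; i < numWorkers; i++) {
            // Each worker gets an independent copy; parameter updates to one
            // clone do not affect the template or the other clones
            workers[i] = template.clone();
        }
        System.out.println("Cloned " + numWorkers + " worker models, "
                        + template.numParams() + " parameters each");
    }
}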