Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
Class TestEarlyStoppingSpark, method testTimeTermination.
@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                    .updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc())).modelSaver(saver)
                    .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
                    new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);

    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for measuring elapsed time
    long startNanos = System.nanoTime();
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    long elapsedNanos = System.nanoTime() - startNanos;
    // Convert to whole seconds first, then narrow to int (the old cast bound before the division)
    int durationSeconds = (int) TimeUnit.NANOSECONDS.toSeconds(elapsedNanos);

    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
Class LoadSerializedDataSetFunction, method call.
/**
 * Reads one DataSet from the supplied portable stream via {@code DataSet.load}.
 *
 * @param pds stream holding a serialized DataSet
 * @return the freshly loaded DataSet
 * @throws Exception if opening or reading the stream fails
 */
@Override
public DataSet call(PortableDataStream pds) throws Exception {
    final DataSet loaded = new DataSet();
    // try-with-resources closes the opened stream whether or not load() succeeds
    try (InputStream in = pds.open()) {
        loaded.load(in);
    }
    return loaded;
}
Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
Class SparkDl4jMultiLayer, method fitLabeledPoint.
/**
 * Fit a MultiLayerNetwork using Spark MLLib LabeledPoint instances.
 * The labeled points are converted to the internal DL4J DataSet format and the model is trained on that.
 * The number of possible labels passed to the conversion is the {@code nOut} of the network's final
 * layer, which therefore must be a {@code FeedForwardLayer}.
 *
 * @param rdd the LabeledPoint RDD to fit the network on
 * @return the network returned by {@code fit(JavaRDD<DataSet>)} after training on the converted data
 * @throws IllegalStateException if the network configuration has no layers, or its final layer is not a
 *                               {@code FeedForwardLayer} (so the number of labels cannot be determined)
 */
public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> rdd) {
    int nLayers = network.getLayerWiseConfigurations().getConfs().size();
    if (nLayers == 0) {
        throw new IllegalStateException("Cannot fit LabeledPoint data: network configuration has no layers");
    }
    Object lastLayer = network.getLayerWiseConfigurations().getConf(nLayers - 1).getLayer();
    // An explicit check gives a clear message instead of an opaque ClassCastException
    if (!(lastLayer instanceof FeedForwardLayer)) {
        throw new IllegalStateException("Cannot fit LabeledPoint data: the final layer must be a FeedForwardLayer "
                        + "(its nOut determines the number of labels), but was "
                        + (lastLayer == null ? "null" : lastLayer.getClass().getSimpleName()));
    }
    JavaRDD<DataSet> ds = MLLibUtil.fromLabeledPoint(sc, rdd, ((FeedForwardLayer) lastLayer).getNOut());
    return fit(ds);
}
Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
Class IEvaluateFlatMapFunctionAdapter, method call.
/**
 * Evaluates every DataSet in this partition against a network rebuilt from the broadcast
 * configuration JSON and parameters.
 * Consecutive DataSets are merged into minibatches: each batch keeps taking DataSets until it holds at
 * least {@code evalBatchSize} examples, so the last batch may hold fewer.
 *
 * @param dataSetIterator the partition's DataSets
 * @return an empty list for an empty partition, otherwise a single-element list holding {@code evaluation}
 * @throws IllegalStateException if the broadcast parameter count does not match the rebuilt network
 */
@Override
public Iterable<T> call(Iterator<DataSet> dataSetIterator) throws Exception {
    // Nothing to evaluate: avoid rebuilding the network at all
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyList();
    }

    MultiLayerNetwork model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json.getValue()));
    model.init();
    INDArray broadcastParams = params.value().unsafeDuplication();
    if (broadcastParams.length() != model.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    model.setParameters(broadcastParams);

    List<DataSet> pending = new ArrayList<>();
    int evaluatedCount = 0;
    while (dataSetIterator.hasNext()) {
        pending.clear();
        int batchExamples = 0;
        while (dataSetIterator.hasNext() && batchExamples < evalBatchSize) {
            DataSet part = dataSetIterator.next();
            batchExamples += part.numExamples();
            pending.add(part);
        }
        evaluatedCount += batchExamples;

        DataSet minibatch = DataSet.merge(pending);
        INDArray features = minibatch.getFeatureMatrix();
        INDArray predictions = minibatch.hasMaskArrays()
                        ? model.output(features, false, minibatch.getFeaturesMaskArray(), minibatch.getLabelsMaskArray())
                        : model.output(features, false);

        INDArray labels = minibatch.getLabels();
        INDArray labelsMask = minibatch.getLabelsMaskArray();
        // Rank-3 labels go through the time-series path (masked when a labels mask exists)
        if (labels.rank() != 3) {
            evaluation.eval(labels, predictions);
        } else if (labelsMask == null) {
            evaluation.evalTimeSeries(labels, predictions);
        } else {
            evaluation.evalTimeSeries(labels, predictions, labelsMask);
        }
    }

    if (log.isDebugEnabled()) {
        log.debug("Evaluated {} examples ", evaluatedCount);
    }
    return Collections.singletonList(evaluation);
}
Use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
Class ScoreExamplesWithKeyFunctionAdapter, method call.
/**
 * Scores each keyed single-example DataSet in this partition with a network rebuilt from the broadcast
 * configuration JSON and parameters, returning one (key, score) pair per input element.
 * Examples are merged into minibatches of up to {@code batchSize} examples before scoring; with one example
 * per DataSet, a batch always holds exactly {@code min(batchSize, remaining)} examples.
 *
 * @param iterator the partition's (key, DataSet) pairs; every DataSet must contain exactly one example
 * @return the (key, score) pairs in input order; an empty list for an empty partition
 * @throws IllegalStateException if the broadcast parameter count does not match the network, if a DataSet
 *                               does not contain exactly one example, or if the network returns a score
 *                               count that does not match the number of keys in a batch
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }

    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    network.setParameters(val);

    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            // One key maps to one score, so each DataSet must hold exactly one example (0 is rejected too)
            if (n != 1) {
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set does not contain exactly 1 example (numExamples: " + n + ")");
            }
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        int nKeys = collectKey.size();
        // Read through the INDArray rather than its backing buffer: data() may be larger than a view,
        // which would misalign keys and scores
        if (scores.length() != nKeys) {
            throw new IllegalStateException("Expected " + nKeys + " example scores but network returned "
                            + scores.length());
        }
        for (int i = 0; i < nKeys; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), scores.getDouble(i)));
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Aggregations