Search in sources :

Example 36 with NeuralNetConfiguration

use of org.deeplearning4j.nn.conf.NeuralNetConfiguration in project deeplearning4j by deeplearning4j.

From class TrainModule, method getLayerInfoTable.

/**
 * Builds the rows of the layer-information table for one vertex of the model graph.
 *
 * <p>Row 0 always holds the vertex name (taken from {@code gi}); row 1 holds the layer type,
 * which is filled in only after the static model information has been inspected. Additional
 * rows (nIn/size, parameter count, weight init, updater, CNN kernel/stride/padding, pooling
 * type, activation function) are appended depending on the layer kind.
 *
 * @param layerIdx index of the vertex within {@code gi}
 * @param gi       graph information whose vertex lists are indexed by {@code layerIdx}
 * @param i18N     source of the localized row labels
 * @param noData   when true, only the name row and an empty type row are returned
 * @param ss       storage queried for the static initialization report
 * @param wid      worker ID passed to the static-info lookup
 * @return the table rows, each a two-element {label, value} array
 */
private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData, StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerName"), gi.getVertexNames().get(layerIdx) });
    // Placeholder: the type cell (row 1) is overwritten at the end once the type is known
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerType"), "" });
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            String configJson = initReport.getModelConfigJson();
            String modelClass = initReport.getModelClassName();
            if (modelClass == null) {
                // Unknown model class: none of the model-specific branches below apply
                modelClass = "";
            }
            //TODO further error handling (e.g. malformed configuration JSON)
            String layerType = "";
            Layer layer = null;
            NeuralNetConfiguration nnc = null;
            if (modelClass.endsWith("MultiLayerNetwork")) {
                MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
                //-1 because vertex 0 is the input
                int confIdx = layerIdx - 1;
                if (confIdx >= 0) {
                    nnc = conf.getConf(confIdx);
                    layer = nnc.getLayer();
                } else {
                    //Input layer
                    layerType = "Input";
                }
            } else if (modelClass.endsWith("ComputationGraph")) {
                ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
                String vertexName = gi.getVertexNames().get(layerIdx);
                Map<String, GraphVertex> vertices = conf.getVertices();
                GraphVertex gv = vertices.get(vertexName);
                if (gv instanceof LayerVertex) {
                    nnc = ((LayerVertex) gv).getLayerConf();
                    layer = nnc.getLayer();
                } else if (conf.getNetworkInputs().contains(vertexName)) {
                    layerType = "Input";
                } else if (gv != null) {
                    // Non-layer vertex (e.g. merge/elementwise): report its class name
                    layerType = gv.getClass().getSimpleName();
                }
            } else if (modelClass.endsWith("VariationalAutoencoder")) {
                layerType = gi.getVertexTypes().get(layerIdx);
                Map<String, String> map = gi.getVertexInfo().get(layerIdx);
                for (Map.Entry<String, String> entry : map.entrySet()) {
                    layerInfoRows.add(new String[] { entry.getKey(), entry.getValue() });
                }
            }
            if (layer != null) {
                layerType = getLayerType(layer);
                String activationFn = null;
                if (layer instanceof FeedForwardLayer) {
                    FeedForwardLayer ffl = (FeedForwardLayer) layer;
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNIn"), String.valueOf(ffl.getNIn()) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffl.getNOut()) });
                    activationFn = (layer.getActivationFn() == null ? null : layer.getActivationFn().toString());
                }
                int nParams = layer.initializer().numParams(nnc);
                layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams) });
                if (nParams > 0) {
                    WeightInit wi = layer.getWeightInit();
                    String str = (wi == null ? "" : wi.toString());
                    if (wi == WeightInit.DISTRIBUTION) {
                        str += layer.getDist();
                    }
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str });
                    Updater u = layer.getUpdater();
                    String us = (u == null ? "" : u.toString());
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerUpdater"), us });
                    //TODO: Maybe L1/L2, dropout, updater-specific values etc
                }
                if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
                    int[] kernel;
                    int[] stride;
                    int[] padding;
                    if (layer instanceof ConvolutionLayer) {
                        ConvolutionLayer cl = (ConvolutionLayer) layer;
                        kernel = cl.getKernelSize();
                        stride = cl.getStride();
                        padding = cl.getPadding();
                    } else {
                        SubsamplingLayer ssl = (SubsamplingLayer) layer;
                        kernel = ssl.getKernelSize();
                        stride = ssl.getStride();
                        padding = ssl.getPadding();
                        // Pooling layers show the pooling type instead of an activation function
                        activationFn = null;
                        layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), ssl.getPoolingType().toString() });
                    }
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnKernel"), Arrays.toString(kernel) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnStride"), Arrays.toString(stride) });
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnPadding"), Arrays.toString(padding) });
                }
                if (activationFn != null) {
                    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerActivationFn"), activationFn });
                }
            }
            layerInfoRows.get(1)[1] = layerType;
        }
    }
    return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
}
Also used: StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) Persistable(org.deeplearning4j.api.storage.Persistable) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) WeightInit(org.deeplearning4j.nn.weights.WeightInit) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) Updater(org.deeplearning4j.nn.conf.Updater) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration)

Example 37 with NeuralNetConfiguration

use of org.deeplearning4j.nn.conf.NeuralNetConfiguration in project deeplearning4j by deeplearning4j.

From class TrainModuleUtils, method buildGraphInfo (ComputationGraphConfiguration overload).

/**
 * Builds UI graph information for a ComputationGraph configuration.
 *
 * <p>Network inputs come first (indices {@code 0..nInputs-1}), followed by every entry of
 * {@code config.getVertices()} in that map's iteration order. At each index the returned
 * lists hold the vertex name, a type string, the indices of the vertex's inputs, a map of
 * layer details (empty for inputs and non-layer vertices), and the original vertex name.
 *
 * @param config the computation graph configuration to describe
 * @return the assembled graph information
 */
public static GraphInfo buildGraphInfo(ComputationGraphConfiguration config) {
    List<String> layerNames = new ArrayList<>();
    List<String> layerTypes = new ArrayList<>();
    List<List<Integer>> layerInputs = new ArrayList<>();
    List<Map<String, String>> layerInfo = new ArrayList<>();
    Map<String, GraphVertex> vertices = config.getVertices();
    Map<String, List<String>> vertexInputs = config.getVertexInputs();
    List<String> networkInputs = config.getNetworkInputs();
    List<String> originalVertexName = new ArrayList<>();
    Map<String, Integer> vertexToIndexMap = new HashMap<>();
    int vertexCount = 0;
    for (String s : networkInputs) {
        vertexToIndexMap.put(s, vertexCount++);
        layerNames.add(s);
        originalVertexName.add(s);
        // NOTE(review): the MultiLayerConfiguration variant uses the type "Input" here, while this one
        // uses the input's name - confirm which the UI expects
        layerTypes.add(s);
        layerInputs.add(Collections.emptyList());
        layerInfo.add(Collections.emptyMap());
    }
    // Index every vertex up front so that inputs can be resolved independent of iteration order
    for (String s : vertices.keySet()) {
        vertexToIndexMap.put(s, vertexCount++);
    }
    for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
        String vertexName = entry.getKey();
        GraphVertex gv = entry.getValue();
        layerNames.add(vertexName);
        List<String> inputsThisVertex = vertexInputs.get(vertexName);
        List<Integer> inputIndexes = new ArrayList<>();
        if (inputsThisVertex != null) {
            for (String s : inputsThisVertex) {
                inputIndexes.add(vertexToIndexMap.get(s));
            }
        }
        layerInputs.add(inputIndexes);
        if (gv instanceof LayerVertex) {
            NeuralNetConfiguration c = ((LayerVertex) gv).getLayerConf();
            Layer layer = c.getLayer();
            String layerType = layer.getClass().getSimpleName().replaceAll("Layer$", "");
            layerTypes.add(layerType);
            //Extract layer info
            layerInfo.add(getLayerInfo(c, layer));
        } else {
            layerTypes.add(gv.getClass().getSimpleName());
            //TODO: details for non-layer vertices
            layerInfo.add(Collections.<String, String>emptyMap());
        }
        originalVertexName.add(vertexName);
    }
    return new GraphInfo(layerNames, layerTypes, layerInputs, layerInfo, originalVertexName);
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex)

Example 38 with NeuralNetConfiguration

use of org.deeplearning4j.nn.conf.NeuralNetConfiguration in project deeplearning4j by deeplearning4j.

From class TrainModuleUtils, method buildGraphInfo (MultiLayerConfiguration overload).

/**
 * Describes a MultiLayerConfiguration as a linear chain of vertices for the training UI.
 *
 * <p>Position 0 is a synthetic "Input" vertex. The configuration's layer {@code i} (0-based)
 * becomes vertex {@code i + 1}, whose single input is vertex {@code i}. Its original name is
 * {@code String.valueOf(i)}, and it is named "layer" + (i + 1) when it has no layer name.
 *
 * @param config the multi-layer configuration to describe
 * @return graph information with names, types, inputs, per-layer details and original names
 */
public static GraphInfo buildGraphInfo(MultiLayerConfiguration config) {
    List<String> names = new ArrayList<>();
    List<String> originalNames = new ArrayList<>();
    List<String> types = new ArrayList<>();
    List<List<Integer>> inputs = new ArrayList<>();
    List<Map<String, String>> details = new ArrayList<>();

    // Synthetic input vertex
    names.add("Input");
    originalNames.add(null);
    types.add("Input");
    inputs.add(Collections.emptyList());
    details.add(Collections.emptyMap());

    List<NeuralNetConfiguration> confs = config.getConfs();
    for (int i = 0; i < confs.size(); i++) {
        NeuralNetConfiguration layerConf = confs.get(i);
        Layer layer = layerConf.getLayer();
        int vertexIdx = i + 1;
        String explicitName = layer.getLayerName();
        names.add(explicitName != null ? explicitName : "layer" + vertexIdx);
        originalNames.add(String.valueOf(i));
        types.add(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
        // Each layer is fed by the vertex immediately before it
        inputs.add(Collections.singletonList(i));
        details.add(getLayerInfo(layerConf, layer));
    }
    return new GraphInfo(names, types, inputs, details, originalNames);
}
Also used : NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration)

Example 39 with NeuralNetConfiguration

use of org.deeplearning4j.nn.conf.NeuralNetConfiguration in project deeplearning4j by deeplearning4j.

From class BaseStatsListener, method iterationDone.

/**
 * Called after each training iteration; when a report is due, gathers the configured
 * statistics into a new stats report and passes it to {@code router.putUpdate}.
 *
 * <p>When the reporting frequency is greater than 1, only iterations that are non-zero
 * multiples of the frequency produce a report; other iterations just record the iteration
 * count (and, when enabled, update the example/minibatch counters).
 *
 * @param model     the model being trained
 * @param iteration the iteration number just completed
 */
@Override
public void iterationDone(Model model, int iteration) {
    StatsUpdateConfiguration config = updateConfig;
    ModelInfo modelInfo = getModelInfo(model);
    boolean backpropParamsOnly = backpropParamsOnly(model);
    long currentTime = getTime();
    if (modelInfo.iterCount == 0) {
        modelInfo.initTime = currentTime;
        doInit(model);
    }
    if (config.collectPerformanceStats()) {
        updateExamplesMinibatchesCounts(model);
    }
    if (config.reportingFrequency() > 1 && (iteration == 0 || iteration % config.reportingFrequency() != 0)) {
        modelInfo.iterCount = iteration;
        return;
    }
    StatsReport report = getNewStatsReport();
    //TODO support NTP time
    report.reportIDs(getSessionID(model), TYPE_ID, workerID, System.currentTimeMillis());
    //--- Performance and System Stats ---
    if (config.collectPerformanceStats()) {
        //Stats to collect: total runtime, total examples, total minibatches, iterations/second, examples/second
        double examplesPerSecond;
        double minibatchesPerSecond;
        if (modelInfo.iterCount == 0) {
            //Not possible to work out perf/second: first iteration...
            examplesPerSecond = 0.0;
            minibatchesPerSecond = 0.0;
        } else {
            long deltaTimeMS = currentTime - modelInfo.lastReportTime;
            if (deltaTimeMS > 0) {
                examplesPerSecond = 1000.0 * modelInfo.examplesSinceLastReport / deltaTimeMS;
                minibatchesPerSecond = 1000.0 * modelInfo.minibatchesSinceLastReport / deltaTimeMS;
            } else {
                // Reports within the same millisecond: a rate would be Infinity/NaN
                examplesPerSecond = 0.0;
                minibatchesPerSecond = 0.0;
            }
        }
        long totalRuntimeMS = currentTime - modelInfo.initTime;
        report.reportPerformance(totalRuntimeMS, modelInfo.totalExamples, modelInfo.totalMinibatches, examplesPerSecond, minibatchesPerSecond);
        modelInfo.examplesSinceLastReport = 0;
        modelInfo.minibatchesSinceLastReport = 0;
    }
    if (config.collectMemoryStats()) {
        Runtime runtime = Runtime.getRuntime();
        long jvmTotal = runtime.totalMemory();
        long jvmMax = runtime.maxMemory();
        //Off-heap memory
        long offheapTotal = Pointer.totalBytes();
        long offheapMax = Pointer.maxBytes();
        //GPU
        long[] gpuCurrentBytes = null;
        long[] gpuMaxBytes = null;
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int nDevices = nativeOps.getAvailableDevices();
        if (nDevices > 0) {
            gpuCurrentBytes = new long[nDevices];
            gpuMaxBytes = new long[nDevices];
            for (int i = 0; i < nDevices; i++) {
                try {
                    Pointer p = getDevicePointer(i);
                    if (p == null) {
                        gpuMaxBytes[i] = 0;
                        gpuCurrentBytes[i] = 0;
                    } else {
                        gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(p);
                        gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(p);
                    }
                } catch (Exception e) {
                    // Best effort: a failing device query leaves that device's entries at 0
                    e.printStackTrace();
                }
            }
        }
        report.reportMemoryUse(jvmTotal, jvmMax, offheapTotal, offheapMax, gpuCurrentBytes, gpuMaxBytes);
    }
    if (config.collectGarbageCollectionStats()) {
        if (modelInfo.lastReportIteration == -1 || gcBeans == null) {
            //Haven't reported GC stats before: record a baseline only
            gcBeans = ManagementFactory.getGarbageCollectorMXBeans();
            gcStatsAtLastReport = new HashMap<>();
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                gcStatsAtLastReport.put(bean.getName(), new Pair<>(count, timeMs));
            }
        } else {
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                Pair<Long, Long> lastStats = gcStatsAtLastReport.get(bean.getName());
                if (lastStats == null) {
                    // No baseline for this collector: record one and report deltas next time
                    gcStatsAtLastReport.put(bean.getName(), new Pair<>(count, timeMs));
                    continue;
                }
                long deltaGCCount = count - lastStats.getFirst();
                long deltaGCTime = timeMs - lastStats.getSecond();
                lastStats.setFirst(count);
                lastStats.setSecond(timeMs);
                report.reportGarbageCollection(bean.getName(), (int) deltaGCCount, (int) deltaGCTime);
            }
        }
    }
    //--- General ---
    //Always report score
    report.reportScore(model.score());
    if (config.collectLearningRates()) {
        Map<String, Double> lrs = new HashMap<>();
        if (model instanceof MultiLayerNetwork) {
            //Need to prefix param names with "0_", "1_" etc (layer index)
            int layerIdx = 0;
            for (Layer l : ((MultiLayerNetwork) model).getLayers()) {
                putBackpropLearningRates(l, String.valueOf(layerIdx), lrs);
                layerIdx++;
            }
        } else if (model instanceof ComputationGraph) {
            for (Layer l : ((ComputationGraph) model).getLayers()) {
                //Need to prefix param names with the layer name
                putBackpropLearningRates(l, l.conf().getLayer().getLayerName(), lrs);
            }
        } else if (model instanceof Layer) {
            Layer l = (Layer) model;
            Map<String, Double> map = l.conf().getLearningRateByParam();
            lrs.putAll(map);
        }
        report.reportLearningRates(lrs);
    }
    if (config.collectHistograms(StatsType.Parameters)) {
        Map<String, Histogram> paramHistograms = getHistograms(model.paramTable(backpropParamsOnly), config.numHistogramBins(StatsType.Parameters));
        report.reportHistograms(StatsType.Parameters, paramHistograms);
    }
    if (config.collectHistograms(StatsType.Gradients)) {
        Map<String, Histogram> gradientHistograms = getHistograms(gradientsPreUpdateMap, config.numHistogramBins(StatsType.Gradients));
        report.reportHistograms(StatsType.Gradients, gradientHistograms);
    }
    if (config.collectHistograms(StatsType.Updates)) {
        Map<String, Histogram> updateHistograms = getHistograms(model.gradient().gradientForVariable(), config.numHistogramBins(StatsType.Updates));
        report.reportHistograms(StatsType.Updates, updateHistograms);
    }
    if (config.collectHistograms(StatsType.Activations)) {
        Map<String, Histogram> activationHistograms = getHistograms(activationsMap, config.numHistogramBins(StatsType.Activations));
        report.reportHistograms(StatsType.Activations, activationHistograms);
    }
    if (config.collectMean(StatsType.Parameters)) {
        Map<String, Double> meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean);
        report.reportMean(StatsType.Parameters, meanParams);
    }
    if (config.collectMean(StatsType.Gradients)) {
        Map<String, Double> meanGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.Mean);
        report.reportMean(StatsType.Gradients, meanGradients);
    }
    if (config.collectMean(StatsType.Updates)) {
        Map<String, Double> meanUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Mean);
        report.reportMean(StatsType.Updates, meanUpdates);
    }
    if (config.collectMean(StatsType.Activations)) {
        Map<String, Double> meanActivations = calculateSummaryStats(activationsMap, StatType.Mean);
        report.reportMean(StatsType.Activations, meanActivations);
    }
    if (config.collectStdev(StatsType.Parameters)) {
        Map<String, Double> stdevParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev);
        report.reportStdev(StatsType.Parameters, stdevParams);
    }
    if (config.collectStdev(StatsType.Gradients)) {
        Map<String, Double> stdevGradient = calculateSummaryStats(gradientsPreUpdateMap, StatType.Stdev);
        report.reportStdev(StatsType.Gradients, stdevGradient);
    }
    if (config.collectStdev(StatsType.Updates)) {
        Map<String, Double> stdevUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Stdev);
        report.reportStdev(StatsType.Updates, stdevUpdates);
    }
    if (config.collectStdev(StatsType.Activations)) {
        Map<String, Double> stdevActivations = calculateSummaryStats(activationsMap, StatType.Stdev);
        report.reportStdev(StatsType.Activations, stdevActivations);
    }
    if (config.collectMeanMagnitudes(StatsType.Parameters)) {
        Map<String, Double> meanMagParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
    }
    if (config.collectMeanMagnitudes(StatsType.Gradients)) {
        Map<String, Double> meanMagGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Gradients, meanMagGradients);
    }
    if (config.collectMeanMagnitudes(StatsType.Updates)) {
        Map<String, Double> meanMagUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Updates, meanMagUpdates);
    }
    if (config.collectMeanMagnitudes(StatsType.Activations)) {
        Map<String, Double> meanMagActivations = calculateSummaryStats(activationsMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Activations, meanMagActivations);
    }
    long endTime = getTime();
    //Amount of time required to calculate all histograms, means etc.
    report.reportStatsCollectionDurationMS((int) (endTime - currentTime));
    modelInfo.lastReportTime = currentTime;
    modelInfo.lastReportIteration = iteration;
    report.reportIterationCount(iteration);
    this.router.putUpdate(report);
    modelInfo.iterCount = iteration;
    activationsMap = null;
}

/**
 * Copies the learning rates of {@code layer}'s backprop parameters into {@code out}, keyed as
 * {@code prefix + "_" + paramName}. Parameters absent from {@code layer.paramTable(true)}
 * (pretrain-only parameters) are skipped.
 */
private static void putBackpropLearningRates(Layer layer, String prefix, Map<String, Double> out) {
    Map<String, Double> layerLrs = layer.conf().getLearningRateByParam();
    Set<String> backpropParams = layer.paramTable(true).keySet();
    for (Map.Entry<String, Double> entry : layerLrs.entrySet()) {
        if (backpropParams.contains(entry.getKey())) {
            out.put(prefix + "_" + entry.getKey(), entry.getValue());
        }
    }
}
Also used : DefaultStatsUpdateConfiguration(org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration) GarbageCollectorMXBean(java.lang.management.GarbageCollectorMXBean) Pointer(org.bytedeco.javacpp.Pointer) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NativeOps(org.nd4j.nativeblas.NativeOps) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer)

Example 40 with NeuralNetConfiguration

use of org.deeplearning4j.nn.conf.NeuralNetConfiguration in project deeplearning4j by deeplearning4j.

From class TestSparkLayer, method testIris2.

/**
 * Trains a single softmax output layer on the Iris data set through SparkDl4jLayer and
 * prints the resulting evaluation statistics.
 */
@Test
public void testIris2() throws Exception {
    NeuralNetConfiguration.Builder netBuilder = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(10)
            .learningRate(1e-1);
    // 4 Iris features in, 3 classes out
    org.deeplearning4j.nn.conf.layers.OutputLayer outputConf =
            new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(4)
                    .nOut(3)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .build();
    NeuralNetConfiguration conf = netBuilder.layer(outputConf).build();

    System.out.println("Initializing network");
    SparkDl4jLayer sparkLayer = new SparkDl4jLayer(sc, conf);

    DataSet iris = new IrisDataSetIterator(150, 150).next();
    iris.normalizeZeroMeanZeroUnitVariance();
    iris.shuffle();
    JavaRDD<DataSet> rdd = sc.parallelize(iris.asList());

    OutputLayer trained = (OutputLayer) sparkLayer.fitDataSet(rdd);
    Evaluation eval = new Evaluation();
    eval.eval(iris.getLabels(), trained.output(iris.getFeatureMatrix()));
    System.out.println(eval.stats());
}
Also used : OutputLayer(org.deeplearning4j.nn.layers.OutputLayer) Evaluation(org.deeplearning4j.eval.Evaluation) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Test(org.junit.Test) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest)

Aggregations

NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)83 INDArray (org.nd4j.linalg.api.ndarray.INDArray)65 Test (org.junit.Test)55 Layer (org.deeplearning4j.nn.api.Layer)29 Gradient (org.deeplearning4j.nn.gradient.Gradient)26 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)24 Updater (org.deeplearning4j.nn.api.Updater)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)21 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)21 DataSet (org.nd4j.linalg.dataset.DataSet)14 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)11 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)9 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)6 UniformDistribution (org.deeplearning4j.nn.conf.distribution.UniformDistribution)6 RnnOutputLayer (org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer)6 MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher)4 Evaluation (org.deeplearning4j.eval.Evaluation)4 Model (org.deeplearning4j.nn.api.Model)4 ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer)4