Search in sources:

Example 76 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in the deeplearning4j project by deeplearning4j.

Class RegressionTest060, method regressionTestMLP1.

/**
 * Regression test: restores a two-layer MLP serialized by version 0.6.0 and checks that its
 * configuration, parameters and updater state were read back exactly as written.
 */
@Test
public void regressionTestMLP1() throws Exception {
    final double eps = 1e-6;

    File modelZip = new ClassPathResource("regression_testing/060/060_ModelSerializer_Regression_MLP_1.zip").getTempFileFromArchive();
    // The second argument is passed as true; the updater state is asserted at the end of this test.
    MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);

    // Network-wide settings: exactly two layers, backprop enabled, pretraining disabled.
    MultiLayerConfiguration configuration = network.getLayerWiseConfigurations();
    assertEquals(2, configuration.getConfs().size());
    assertTrue(configuration.isBackprop());
    assertFalse(configuration.isPretrain());

    // Layer 0: dense layer, 3 inputs -> 4 outputs, ReLU, Xavier init, Nesterov updater.
    DenseLayer hidden = (DenseLayer) configuration.getConf(0).getLayer();
    assertEquals("relu", hidden.getActivationFn().toString());
    assertEquals(3, hidden.getNIn());
    assertEquals(4, hidden.getNOut());
    assertEquals(WeightInit.XAVIER, hidden.getWeightInit());
    assertEquals(Updater.NESTEROVS, hidden.getUpdater());
    assertEquals(0.9, hidden.getMomentum(), eps);
    assertEquals(0.15, hidden.getLearningRate(), eps);

    // Layer 1: softmax output layer, 4 inputs -> 5 outputs, MCXENT loss.
    OutputLayer output = (OutputLayer) configuration.getConf(1).getLayer();
    assertEquals("softmax", output.getActivationFn().toString());
    assertEquals(LossFunctions.LossFunction.MCXENT, output.getLossFunction());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals(WeightInit.XAVIER, output.getWeightInit());
    assertEquals(Updater.NESTEROVS, output.getUpdater());
    assertEquals(0.9, output.getMomentum(), eps);
    assertEquals(0.15, output.getLearningRate(), eps);

    // The fixture stores parameters and updater state as the sequences 1..N.
    int paramCount = network.numParams();
    assertEquals(Nd4j.linspace(1, paramCount, paramCount), network.params());

    int updaterStateSize = network.getUpdater().stateSizeForLayer(network);
    assertEquals(Nd4j.linspace(1, updaterStateSize, updaterStateSize), network.getUpdater().getStateViewArray());
}
Also used: LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT), File (java.io.File), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 77 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in the deeplearning4j project by deeplearning4j.

Class KerasModelConfigurationTest, method importKerasConvnetTheanoConfigTest.

/**
 * Imports the Keras convnet configuration {@code cnn_th_config.json} (Theano variant, per the
 * test name) as a sequential model, with training configuration enforced, and checks that the
 * resulting network can be initialized.
 */
@Test
public void importKerasConvnetTheanoConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/cnn_th_config.json", KerasModelConfigurationTest.class.getClassLoader());
    MultiLayerConfiguration config;
    // Close the classpath stream once the JSON has been parsed; it was previously leaked.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildSequential()
                .getMultiLayerConfiguration();
    }
    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();
}
Also used: MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 78 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in the deeplearning4j project by deeplearning4j.

Class KerasModelConfigurationTest, method importKerasMlpSequentialConfigTest.

/**
 * Imports the Keras MLP configuration {@code mlp_config.json} as a sequential model, with
 * training configuration enforced, and checks that the resulting network can be initialized.
 */
@Test
public void importKerasMlpSequentialConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_config.json", KerasModelConfigurationTest.class.getClassLoader());
    MultiLayerConfiguration config;
    // Close the classpath stream once the JSON has been parsed; it was previously leaked.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildSequential()
                .getMultiLayerConfiguration();
    }
    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();
}
Also used: MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 79 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in the deeplearning4j project by deeplearning4j.

Class KerasModelConfigurationTest, method importKerasMlpModelMultilossConfigTest.

/**
 * Imports the Keras functional-API MLP configuration with multiple losses
 * ({@code mlp_fapi_multiloss_config.json}) as a computation graph, with training configuration
 * enforced, and checks that the resulting graph can be initialized.
 */
@Test
public void importKerasMlpModelMultilossConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_multiloss_config.json", KerasModelConfigurationTest.class.getClassLoader());
    ComputationGraphConfiguration config;
    // Close the classpath stream once the JSON has been parsed; it was previously leaked.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildModel()
                .getComputationGraphConfiguration();
    }
    ComputationGraph model = new ComputationGraph(config);
    model.init();
}
Also used: ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 80 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in the deeplearning4j project by deeplearning4j.

Class KerasModelConfigurationTest, method importKerasMlpModelConfigTest.

/**
 * Imports the Keras functional-API MLP configuration {@code mlp_fapi_config.json} as a
 * computation graph, with training configuration enforced, and checks that the resulting
 * graph can be initialized.
 */
@Test
public void importKerasMlpModelConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_config.json", KerasModelConfigurationTest.class.getClassLoader());
    ComputationGraphConfiguration config;
    // Close the classpath stream once the JSON has been parsed; it was previously leaked.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildModel()
                .getComputationGraphConfiguration();
    }
    ComputationGraph model = new ComputationGraph(config);
    model.init();
}
Also used: ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 112; Test (org.junit.Test): 100; lombok.val (lombok.val): 31; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 26; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 23; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 23; DataSet (org.nd4j.linalg.dataset.DataSet): 23; File (java.io.File): 22; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 20; FileSplit (org.datavec.api.split.FileSplit): 18; CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 14; Ignore (org.junit.Ignore): 14; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 13; RecordReader (org.datavec.api.records.reader.RecordReader): 12; NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 12; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 12; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 11; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 8; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 7; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 7