Search in sources :

Example 6 with LossMCXENT

Use of org.nd4j.linalg.lossfunctions.impl.LossMCXENT in the project deeplearning4j (owner: deeplearning4j).

From class RegressionTest071, method regressionTestCGLSTM1.

@Test
public void regressionTestCGLSTM1() throws Exception {
    // Restore the ComputationGraph stored in the "071" CG_LSTM regression archive.
    // NOTE(review): the second argument to restoreComputationGraph is presumably
    // "load updater state" - confirm against the ModelSerializer API.
    final String archivePath = "regression_testing/071/071_ModelSerializer_Regression_CG_LSTM_1.zip";
    File modelFile = new ClassPathResource(archivePath).getTempFileFromArchive();
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();

    // Tolerance used for every floating-point comparison in this test.
    final double eps = 1e-5;

    // Graph-wide settings: three vertices, backprop on, pretrain off.
    assertEquals(3, graphConf.getVertices().size());
    assertTrue(graphConf.isBackprop());
    assertFalse(graphConf.isPretrain());

    // Vertex "0": GravesLSTM, 3 -> 4, tanh, element-wise gradient clipping at 1.5.
    LayerVertex firstVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) firstVertex.getLayerConf().getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), eps);

    // Vertex "1": GravesBidirectionalLSTM, 4 -> 4, softsign, same clipping settings.
    LayerVertex secondVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) secondVertex.getLayerConf().getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), eps);

    // Vertex "2": RnnOutputLayer, 4 -> 5, softmax with a multi-class cross-entropy loss.
    LayerVertex thirdVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer output = (RnnOutputLayer) thirdVertex.getLayerConf().getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals("softmax", output.getActivationFn().toString());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
Also used: LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) File(java.io.File) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 7 with LossMCXENT

Use of org.nd4j.linalg.lossfunctions.impl.LossMCXENT in the project deeplearning4j (owner: deeplearning4j).

From class RegressionTest071, method regressionTestLSTM1.

@Test
public void regressionTestLSTM1() throws Exception {
    // Restore the MultiLayerNetwork stored in the "071" LSTM regression archive.
    // NOTE(review): the boolean flag is presumably "load updater state" - confirm.
    final String archivePath = "regression_testing/071/071_ModelSerializer_Regression_LSTM_1.zip";
    File modelFile = new ClassPathResource(archivePath).getTempFileFromArchive();
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile, true);
    MultiLayerConfiguration layerConfs = restored.getLayerWiseConfigurations();

    // Tolerance used for every floating-point comparison in this test.
    final double eps = 1e-5;

    // Network-wide settings: three layers, backprop on, pretrain off.
    assertEquals(3, layerConfs.getConfs().size());
    assertTrue(layerConfs.isBackprop());
    assertFalse(layerConfs.isPretrain());

    // Layer 0: GravesLSTM, 3 -> 4, tanh, element-wise gradient clipping at 1.5.
    GravesLSTM lstm = (GravesLSTM) layerConfs.getConf(0).getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), eps);

    // Layer 1: GravesBidirectionalLSTM, 4 -> 4, softsign, same clipping settings.
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) layerConfs.getConf(1).getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), eps);

    // Layer 2: RnnOutputLayer, 4 -> 5, softmax with a multi-class cross-entropy loss.
    RnnOutputLayer output = (RnnOutputLayer) layerConfs.getConf(2).getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals("softmax", output.getActivationFn().toString());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
Also used: LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) File(java.io.File) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 8 with LossMCXENT

Use of org.nd4j.linalg.lossfunctions.impl.LossMCXENT in the project deeplearning4j (owner: deeplearning4j).

From class RegressionTest060, method regressionTestMLP1.

@Test
public void regressionTestMLP1() throws Exception {
    // Restore the MultiLayerNetwork stored in the "060" MLP regression archive.
    // NOTE(review): the boolean flag is presumably "load updater state" - confirm.
    final String archivePath = "regression_testing/060/060_ModelSerializer_Regression_MLP_1.zip";
    File modelFile = new ClassPathResource(archivePath).getTempFileFromArchive();
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile, true);
    MultiLayerConfiguration layerConfs = restored.getLayerWiseConfigurations();

    // Tolerance used for every floating-point comparison in this test.
    final double eps = 1e-6;

    // Network-wide settings: two layers, backprop on, pretrain off.
    assertEquals(2, layerConfs.getConfs().size());
    assertTrue(layerConfs.isBackprop());
    assertFalse(layerConfs.isPretrain());

    // Layer 0: DenseLayer, 3 -> 4, relu, Xavier init, Nesterov momentum 0.9, learning rate 0.15.
    DenseLayer hidden = (DenseLayer) layerConfs.getConf(0).getLayer();
    assertEquals("relu", hidden.getActivationFn().toString());
    assertEquals(3, hidden.getNIn());
    assertEquals(4, hidden.getNOut());
    assertEquals(WeightInit.XAVIER, hidden.getWeightInit());
    assertEquals(Updater.NESTEROVS, hidden.getUpdater());
    assertEquals(0.9, hidden.getMomentum(), eps);
    assertEquals(0.15, hidden.getLearningRate(), eps);

    // Layer 1: OutputLayer, 4 -> 5, softmax + MCXENT, with the same init/updater settings.
    OutputLayer output = (OutputLayer) layerConfs.getConf(1).getLayer();
    assertEquals("softmax", output.getActivationFn().toString());
    assertEquals(LossFunctions.LossFunction.MCXENT, output.getLossFunction());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals(WeightInit.XAVIER, output.getWeightInit());
    assertEquals(Updater.NESTEROVS, output.getUpdater());
    assertEquals(0.9, output.getMomentum(), eps);
    assertEquals(0.15, output.getLearningRate(), eps);

    // The serialized parameters and updater state are expected to be the sequence
    // 1, 2, ..., n, as produced by Nd4j.linspace with that length.
    int paramCount = restored.numParams();
    assertEquals(Nd4j.linspace(1, paramCount, paramCount), restored.params());
    int stateLength = restored.getUpdater().stateSizeForLayer(restored);
    assertEquals(Nd4j.linspace(1, stateLength, stateLength), restored.getUpdater().getStateViewArray());
}
Also used: LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) File(java.io.File) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT)8 File (java.io.File)7 Test (org.junit.Test)7 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)7 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)5 LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex)2 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 IOException (java.io.IOException)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 IActivation (org.nd4j.linalg.activations.IActivation)1 LossFunctions (org.nd4j.linalg.lossfunctions.LossFunctions)1 LossBinaryXENT (org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT)1 LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE)1 LossNegativeLogLikelihood (org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood)1 JsonNode (org.nd4j.shade.jackson.databind.JsonNode)1 ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper)1 ArrayNode (org.nd4j.shade.jackson.databind.node.ArrayNode)1