Use of org.nd4j.linalg.dataset.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class ComputationGraphTestRNN, method checkMaskArrayClearance.
@Test
public void checkMaskArrayClearance() {
    for (boolean tbptt : new boolean[] {true, false}) {
        //Simple "does it throw an exception" type test...
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1).seed(12345)
                        .graphBuilder()
                        .addInputs("in")
                        .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                        .activation(Activation.IDENTITY).nIn(1).nOut(1).build(), "in")
                        .setOutputs("out")
                        .backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard)
                        .tBPTTForwardLength(8).tBPTTBackwardLength(8)
                        .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        MultiDataSet data = new MultiDataSet(
                        new INDArray[] {Nd4j.linspace(1, 10, 10).reshape(1, 1, 10)},
                        new INDArray[] {Nd4j.linspace(2, 20, 10).reshape(1, 1, 10)},
                        new INDArray[] {Nd4j.ones(10)},
                        new INDArray[] {Nd4j.ones(10)});

        net.fit(data);
        assertNull(net.getInputMaskArrays());
        assertNull(net.getLabelMaskArrays());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        DataSet ds = new DataSet(data.getFeatures(0), data.getLabels(0),
                        data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
        net.fit(ds);
        assertNull(net.getInputMaskArrays());
        assertNull(net.getLabelMaskArrays());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        assertNull(net.getInputMaskArrays());
        assertNull(net.getLabelMaskArrays());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        MultiDataSetIterator iter = new IteratorMultiDataSetIterator(
                        Collections.singletonList((org.nd4j.linalg.dataset.api.MultiDataSet) data).iterator(), 1);
        net.fit(iter);
        assertNull(net.getInputMaskArrays());
        assertNull(net.getLabelMaskArrays());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        DataSetIterator iter2 = new IteratorDataSetIterator(Collections.singletonList(ds).iterator(), 1);
        net.fit(iter2);
        assertNull(net.getInputMaskArrays());
        assertNull(net.getLabelMaskArrays());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }
    }
}
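As a standalone reference, here is a minimal sketch of the MultiDataSet constructor and the per-index accessors exercised by the test above. The shapes mirror the test (one example, one feature, ten time steps); the class name is purely illustrative.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MultiDataSetMaskSketch {

    public static void main(String[] args) {
        //Features and labels: [miniBatchSize, featureSize, timeSeriesLength] = [1, 1, 10]
        INDArray features = Nd4j.linspace(1, 10, 10).reshape(1, 1, 10);
        INDArray labels = Nd4j.linspace(2, 20, 10).reshape(1, 1, 10);

        //Mask arrays: one entry per time step; 1.0 marks a step that is present
        INDArray featuresMask = Nd4j.ones(10);
        INDArray labelsMask = Nd4j.ones(10);

        //The constructor takes arrays-of-arrays: one entry per input/output of the graph
        MultiDataSet mds = new MultiDataSet(new INDArray[] {features}, new INDArray[] {labels},
                        new INDArray[] {featuresMask}, new INDArray[] {labelsMask});

        //Per-index accessors used above to build a single-input/single-output DataSet
        System.out.println(mds.getFeatures(0));
        System.out.println(mds.getLabels(0));
        System.out.println(mds.getFeaturesMaskArray(0));
        System.out.println(mds.getLabelsMaskArray(0));
    }
}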
Use of org.nd4j.linalg.dataset.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class TransferLearningHelperTest, method testFitUnFrozen.
@Test
public void testFitUnFrozen() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.9).seed(124)
                    .activation(Activation.IDENTITY)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD);

    ComputationGraphConfiguration conf = overallConf.graphBuilder()
                    .addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
                    .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
                    .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
                    .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
                    .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
                    .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
                    .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
                    .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
                    .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
                    .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
                    .setOutputs("outLeft", "outCentre", "outRight")
                    .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();

    INDArray inRight = Nd4j.rand(10, 2);
    INDArray inCentre = Nd4j.rand(10, 10);
    INDArray outLeft = Nd4j.rand(10, 6);
    INDArray outRight = Nd4j.rand(10, 5);
    INDArray outCentre = Nd4j.rand(10, 4);
    MultiDataSet origData = new MultiDataSet(new INDArray[] {inCentre, inRight},
                    new INDArray[] {outLeft, outCentre, outRight});

    ComputationGraph modelIdentical = modelToTune.clone();
    modelIdentical.getVertex("denseCentre0").setLayerAsFrozen();
    modelIdentical.getVertex("denseCentre1").setLayerAsFrozen();
    modelIdentical.getVertex("denseCentre2").setLayerAsFrozen();

    TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
    MultiDataSet featurizedDataSet = helper.featurize(origData);
    assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());

    modelIdentical.fit(origData);
    helper.fitFeaturized(featurizedDataSet);

    assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params());
    assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params());
    assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params());
    assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params());
    assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params());
    assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(), modelToTune.getLayer("denseRight").conf().toJson());
    assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params());
    assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(), modelToTune.getLayer("denseRight0").conf().toJson());
    //assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
    assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params());
    assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
    assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params());
    assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params());

    log.info(modelIdentical.summary());
    log.info(helper.unfrozenGraph().summary());
}
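Stripped of the assertions, the featurize-then-fit workflow above reduces to the sketch below; `graph` stands for any initialised ComputationGraph containing a vertex named "denseCentre2", and the class and method names are illustrative only.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.nd4j.linalg.dataset.MultiDataSet;

public class FeaturizeSketch {

    static void fitUnfrozenPart(ComputationGraph graph, MultiDataSet trainData) {
        //Treat everything up to and including "denseCentre2" as a frozen feature extractor
        TransferLearningHelper helper = new TransferLearningHelper(graph, "denseCentre2");

        //Run the frozen part once and cache its activations as the new "features"
        MultiDataSet featurized = helper.featurize(trainData);

        //Only the unfrozen layers are updated by this call
        helper.fitFeaturized(featurized);

        //The trainable remainder of the graph can be inspected separately
        System.out.println(helper.unfrozenGraph().summary());
    }
}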
Use of org.nd4j.linalg.dataset.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class TransferLearningComplex, method testSimplerMergeBackProp.
@Test
public void testSimplerMergeBackProp() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.9)
                    .activation(Activation.IDENTITY)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD);
    /*
            inCentre                inRight
               |                        |
         denseCentre0            denseRight0
               |                        |
               |------ mergeRight ------|
                            |
                        outRight
    */
    ComputationGraphConfiguration conf = overallConf.graphBuilder()
                    .addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
                    .setOutputs("outRight")
                    .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();

    MultiDataSet randData = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)},
                    new INDArray[] {Nd4j.rand(2, 2)});
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    MultiDataSet otherRandData = new MultiDataSet(new INDArray[] {denseCentre0, randData.getFeatures(1)},
                    randData.getLabels());

    ComputationGraphConfiguration otherConf = overallConf.graphBuilder()
                    .addInputs("denseCentre0", "inRight")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
                    .setOutputs("outRight")
                    .build();
    ComputationGraph modelOther = new ComputationGraph(otherConf);
    modelOther.init();
    modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params());
    modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params());

    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
                    .setFeatureExtractor("denseCentre0")
                    .build();

    int n = 0;
    while (n < 5) {
        if (n == 0) {
            //confirm activations out of the merge are equivalent
            assertEquals(modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight"),
                            modelOther.feedForward(otherRandData.getFeatures(), false).get("mergeRight"));
            assertEquals(modelNow.feedForward(randData.getFeatures(), false).get("mergeRight"),
                            modelOther.feedForward(otherRandData.getFeatures(), false).get("mergeRight"));
        }
        //confirm activations out of the frozen vertex are the same as the input to the other model
        modelOther.fit(otherRandData);
        modelToTune.fit(randData);
        modelNow.fit(randData);

        assertEquals(otherRandData.getFeatures(0), modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
        assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
        n++;
    }
}
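The freezing step the loop above depends on can be shown in isolation. This is a sketch assuming `graph` already contains a layer vertex named "denseCentre0"; the class name is illustrative, and the FrozenLayer import path follows my reading of the 0.x API used in these tests.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.transferlearning.TransferLearning;

public class FreezeSketch {

    static ComputationGraph freezeFeatureExtractor(ComputationGraph graph) {
        //Everything from the inputs up to and including "denseCentre0" stops learning;
        //only the layers after it are updated by subsequent fit() calls
        ComputationGraph frozen = new TransferLearning.GraphBuilder(graph)
                        .setFeatureExtractor("denseCentre0")
                        .build();

        //The frozen vertex is wrapped in a FrozenLayer, as asserted in testLessSimpleMergeBackProp
        assert frozen.getLayer("denseCentre0") instanceof FrozenLayer;
        return frozen;
    }
}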
Use of org.nd4j.linalg.dataset.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class TransferLearningComplex, method testLessSimpleMergeBackProp.
@Test
public void testLessSimpleMergeBackProp() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.9)
                    .activation(Activation.IDENTITY)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD);
    /*
            inCentre                inRight
               |                        |
         denseCentre0            denseRight0
               |                        |
               |------ mergeRight ------|
               |            |
           outCentre    outRight
    */
    ComputationGraphConfiguration conf = overallConf.graphBuilder()
                    .addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
                    .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).build(), "denseCentre0")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(3).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
                    .setOutputs("outRight")
                    .setOutputs("outCentre")
                    .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();
    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();

    MultiDataSet randData = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)},
                    new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)});
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    MultiDataSet otherRandData = new MultiDataSet(new INDArray[] {denseCentre0, randData.getFeatures(1)},
                    randData.getLabels());

    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
                    .setFeatureExtractor("denseCentre0")
                    .build();
    assertTrue(modelNow.getLayer("denseCentre0") instanceof FrozenLayer);

    int n = 0;
    while (n < 5) {
        if (n == 0) {
            //confirm activations out of the merge are equivalent
            assertEquals(modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight"),
                            modelNow.feedForward(otherRandData.getFeatures(), false).get("mergeRight"));
        }
        //confirm activations out of the frozen vertex are the same as the input to the other model
        modelToTune.fit(randData);
        modelNow.fit(randData);

        assertEquals(otherRandData.getFeatures(0), modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());
        n++;
    }
}
Use of org.nd4j.linalg.dataset.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class TransferLearningComplex, method testAddOutput.
@Test
public void testAddOutput() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.9)
                    .activation(Activation.IDENTITY)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD);
    ComputationGraphConfiguration conf = overallConf.graphBuilder()
                    .addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
                    .setOutputs("outRight")
                    .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();

    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
                    .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(3).build(), "denseCentre0")
                    .setOutputs("outCentre")
                    .build();
    assertEquals(2, modelNow.getNumOutputArrays());

    MultiDataSet rand = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)},
                    new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)});
    modelNow.fit(rand);
    log.info(modelNow.summary());
}
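For comparison, here is the output-adding step of the last example isolated into a helper. It assumes `graph` already has a hidden layer named "denseCentre0" and one existing output; the class and method names are illustrative.

import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class AddOutputSketch {

    static ComputationGraph addSecondOutput(ComputationGraph graph) {
        //Attach a new output layer to an existing hidden layer and register it as a network output
        ComputationGraph withExtraOutput = new TransferLearning.GraphBuilder(graph)
                        .addLayer("outCentre",
                                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(3).build(),
                                        "denseCentre0")
                        .setOutputs("outCentre")
                        .build();

        //As asserted in testAddOutput, the original output is kept, so the graph now produces
        //two output arrays and fit() expects a MultiDataSet with two label arrays
        System.out.println(withExtraOutput.getNumOutputArrays());
        return withExtraOutput;
    }
}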