
Example 1 with NNMaster

Use of ml.shifu.shifu.core.dtrain.nn.NNMaster in the project shifu by ShifuML.

From the class NNModelSpecTest, method testModelFitIn.

@Test
public void testModelFitIn() {
    // Register Shifu's float-network persistor so Encog can deserialize these .nn files
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    // Load the existing model and take its flat (array-based) representation
    BasicML basicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    // Load the extended model that the existing one is fitted into
    BasicML extendedBasicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    NNMaster master = new NNMaster();
    // Three-argument overload: returns the set of fixed weight indices
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1, 2, 3));
    Assert.assertEquals(fixedWeightIndexSet.size(), 931);
    // Same call with the trailing boolean set to false yields fewer fixed indices
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1, 2, 3), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 910);
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster), BasicNetwork (org.encog.neural.networks.BasicNetwork), BasicML (org.encog.ml.BasicML), File (java.io.File), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork), Test (org.testng.annotations.Test)
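Example 1 only asserts how many indices fitExistingModelIn returns. Presumably these index into the extended network's flat weight array, so they are bounded by that array's size. As a rough aid for reasoning about such counts, the sketch below computes the weight count of a fully connected feed-forward network in which every non-output layer carries one bias unit. The class WeightCount and its method totalWeights are hypothetical illustrations, not part of Shifu or Encog, and the layer sizes of model5.nn and model6.nn are not shown on this page.

```java
import java.util.Arrays;

/**
 * Hypothetical helper, not part of Shifu or Encog. Assumes a fully connected
 * feed-forward network where every non-output layer has one extra bias unit
 * feeding every neuron of the next layer.
 */
final class WeightCount {

    private WeightCount() {
    }

    /**
     * @param layerSizes neuron counts per layer, input first, output last, bias units excluded
     * @return total number of weights, bias weights included
     */
    static int totalWeights(int[] layerSizes) {
        if (layerSizes.length < 2) {
            throw new IllegalArgumentException("need at least an input and an output layer");
        }
        int total = 0;
        for (int i = 0; i + 1 < layerSizes.length; i++) {
            // (regular units + 1 bias unit) each connect to every neuron of the next layer
            total += (layerSizes[i] + 1) * layerSizes[i + 1];
        }
        return total;
    }

    public static void main(String[] args) {
        int[] layers = {3, 2, 1};
        System.out.println(Arrays.toString(layers) + " -> " + totalWeights(layers));
    }
}
```

With layer sizes {3, 2, 1} the formula gives (3 + 1) * 2 + (2 + 1) * 1 = 11 weights; widening the input layer to 4 raises that to 13.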

Example 2 with NNMaster

Use of ml.shifu.shifu.core.dtrain.nn.NNMaster in the project shifu by ShifuML.

From the class NNModelSpecTest, method testFitExistingModelIn.

@Test
public void testFitExistingModelIn() {
    // Load model0.nn and take its flat (array-based) representation
    BasicML basicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    NNMaster master = new NNMaster();
    // Fit the network into itself with the list argument { 6 }
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(6));
    // Sorted copy, handy when inspecting the indices; only the size is asserted
    List<Integer> indexList = new ArrayList<>(fixedWeightIndexSet);
    Collections.sort(indexList);
    Assert.assertEquals(indexList.size(), 31);
    // Same self-fit with the list argument { 1 }
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(1));
    indexList = new ArrayList<>(fixedWeightIndexSet);
    Collections.sort(indexList);
    Assert.assertEquals(indexList.size(), 930);
    // Fit model0.nn into the extended model1.nn
    BasicML extendedBasicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1));
    indexList = new ArrayList<>(fixedWeightIndexSet);
    Collections.sort(indexList);
    Assert.assertEquals(indexList.size(), 930);
    // Trailing boolean set to false: fewer fixed indices than the default overload
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1), false);
    indexList = new ArrayList<>(fixedWeightIndexSet);
    Collections.sort(indexList);
    Assert.assertEquals(indexList.size(), 900);
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster), BasicNetwork (org.encog.neural.networks.BasicNetwork), ArrayList (java.util.ArrayList), BasicML (org.encog.ml.BasicML), File (java.io.File), Test (org.testng.annotations.Test)
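Example 2 fits model0.nn both into itself and into the extended model1.nn. One way to picture what a set of "fixed weight indices" could mean is a mapping from every connection of the trained network onto the matching slot of a wider network's flat weight array. The sketch below does that under loudly stated assumptions: the flat layout described in its comments is invented for illustration and is not claimed to be Encog's FlatNetwork layout, FitInSketch and sharedWeightIndices are hypothetical names, and the List<Integer> argument of the real method is ignored because its semantics are not visible in these tests.

```java
import java.util.Set;
import java.util.TreeSet;

/**
 * Hypothetical sketch, not Shifu's NNMaster.fitExistingModelIn.
 * Assumed flat layout: one block per pair of adjacent layers, input side first;
 * inside a block, the weight from source unit i (the bias unit is the last source)
 * to target unit j sits at offset j * (fanIn + 1) + i.
 */
final class FitInSketch {

    private FitInSketch() {
    }

    /**
     * Returns the flat indices in the new (wider) network whose connections also exist
     * in the old network, i.e. weights that could be copied over and held fixed.
     *
     * @param oldSizes neuron counts per layer of the trained network, bias units excluded
     * @param newSizes neuron counts per layer of the wider network, bias units excluded
     */
    static Set<Integer> sharedWeightIndices(int[] oldSizes, int[] newSizes) {
        if (oldSizes.length != newSizes.length) {
            throw new IllegalArgumentException("networks must have the same number of layers");
        }
        for (int l = 0; l < oldSizes.length; l++) {
            if (oldSizes[l] > newSizes[l]) {
                throw new IllegalArgumentException("new network must be at least as wide in layer " + l);
            }
        }
        Set<Integer> indices = new TreeSet<>();
        int blockStart = 0; // offset of the current layer block in the new flat array
        for (int l = 0; l + 1 < newSizes.length; l++) {
            int newFanIn = newSizes[l] + 1; // regular sources plus one bias unit
            for (int j = 0; j < oldSizes[l + 1]; j++) {
                // Regular source units that already existed in the old network
                for (int i = 0; i < oldSizes[l]; i++) {
                    indices.add(blockStart + j * newFanIn + i);
                }
                // The old bias unit maps onto the new bias unit (last source slot)
                indices.add(blockStart + j * newFanIn + newSizes[l]);
            }
            blockStart += newFanIn * newSizes[l + 1];
        }
        return indices;
    }

    public static void main(String[] args) {
        // One extra input neuron: 13 weights in the new network, 11 shared with the old one
        System.out.println(sharedWeightIndices(new int[] {3, 2, 1}, new int[] {4, 2, 1}));
        // prints [0, 1, 2, 4, 5, 6, 7, 9, 10, 11, 12]
    }
}
```

Fitting a network into itself returns every index (11 for {3, 2, 1}); adding one input neuron leaves out exactly the two weights fed by that new neuron, which land at indices 3 and 8 in this assumed layout.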

Aggregations

File (java.io.File): 2
NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster): 2
BasicML (org.encog.ml.BasicML): 2
FlatNetwork (org.encog.neural.flat.FlatNetwork): 2
BasicNetwork (org.encog.neural.networks.BasicNetwork): 2
Test (org.testng.annotations.Test): 2
ArrayList (java.util.ArrayList): 1
PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 1