Use of ml.shifu.shifu.core.dtrain.nn.NNMaster in the project shifu by ShifuML.
Class NNModelSpecTest, method testModelFitIn.
/**
 * Loads model5.nn (the existing network) and model6.nn (the extended network) and checks how many
 * weight indices {@code NNMaster.fitExistingModelIn} reports as fixed when fitting the former into
 * the latter: once through the three-argument overload and once through the four-argument overload
 * with {@code false} as the last argument.
 */
@Test
public void testModelFitIn() {
    // Register the float-network persistor before any .nn file is loaded.
    // NOTE(review): presumably model5/model6 were persisted as float networks — confirm.
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    BasicNetwork existing = (BasicNetwork) BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    FlatNetwork existingFlat = existing.getFlat();

    BasicNetwork extended = (BasicNetwork) BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    FlatNetwork extendedFlat = extended.getFlat();

    NNMaster nnMaster = new NNMaster();

    // Three-argument overload.
    // NOTE(review): the meaning of the {1, 2, 3} id list is defined by NNMaster — confirm.
    Set<Integer> fixed = nnMaster.fitExistingModelIn(existingFlat, extendedFlat, Arrays.asList(1, 2, 3));
    Assert.assertEquals(fixed.size(), 931);

    // Four-argument overload with the flag off; the expected count is lower than above.
    Set<Integer> fixedFlagOff =
            nnMaster.fitExistingModelIn(existingFlat, extendedFlat, Arrays.asList(1, 2, 3), false);
    Assert.assertEquals(fixedFlagOff.size(), 910);
}
Use of ml.shifu.shifu.core.dtrain.nn.NNMaster in the project shifu by ShifuML.
Class NNModelSpecTest, method testFitExistingModelIn.
/**
 * Checks how many weight indices {@code NNMaster.fitExistingModelIn} reports as fixed for the
 * network in model0.nn. It first fits the network into itself with two different id lists. It then
 * fits it into the extended network from model1.nn, once through the three-argument overload and
 * once through the four-argument overload with {@code false} as the last argument.
 *
 * <p>Only the size of each returned set is asserted, so the set is checked directly. The earlier
 * version copied it into a list and sorted it, and that result was never inspected.
 */
@Test
public void testFitExistingModelIn() {
    // NOTE(review): unlike testModelFitIn, this test does not register PersistBasicFloatNetwork
    // before loading. If model0/model1.nn need that persistor, this test depends on execution
    // order — confirm.
    FlatNetwork flatNetwork = ((BasicNetwork) EncogDirectoryPersistence.loadObject(
            new File("src/test/resources/model/model0.nn"))).getFlat();
    NNMaster master = new NNMaster();

    // Fit the network into itself.
    // NOTE(review): the meaning of the id lists ({6}, {1}) is defined by NNMaster — confirm.
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(6));
    Assert.assertEquals(fixedWeightIndexSet.size(), 31);

    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(1));
    Assert.assertEquals(fixedWeightIndexSet.size(), 930);

    // Fit the same network into the extended model1.nn network.
    FlatNetwork extendedFlatNetwork = ((BasicNetwork) EncogDirectoryPersistence.loadObject(
            new File("src/test/resources/model/model1.nn"))).getFlat();

    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1));
    Assert.assertEquals(fixedWeightIndexSet.size(), 930);

    // The four-argument overload with the flag off is expected to report fewer fixed indices.
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 900);
}
Aggregations