Search in sources:

Example 1 with NNStructureComparator

Use of ml.shifu.shifu.core.dtrain.nn.NNStructureComparator in project shifu by ShifuML.

In class NNModelSpecTest, method testModelStructureCompare:

/**
 * Checks the ordering that {@link NNStructureComparator} reports between pairs of persisted
 * Encog networks loaded from {@code src/test/resources/model/}.
 *
 * <p>Expected results, as asserted below:
 * <ul>
 *   <li>model0 compared with itself yields 0.</li>
 *   <li>model1 ("extended") compares as 1 against model0, and model0 as -1 against model1.</li>
 *   <li>model2 ("diff") yields -1 against model0 in BOTH argument orders. So for this pair the
 *       comparator is not antisymmetric, and callers should presumably treat -1 as "not a
 *       compatible extension" rather than strictly "less than". TODO confirm intent.</li>
 *   <li>model1 compares as 1 against model2, and model2 as -1 against model1.</li>
 *   <li>model3 ("deep") compares as 1 against model0, and model0 as -1 against model3.</li>
 * </ul>
 */
@Test
public void testModelStructureCompare() {
    FlatNetwork flatNetwork = loadFlatNetwork("model0.nn");
    FlatNetwork extendedFlatNetwork = loadFlatNetwork("model1.nn");
    FlatNetwork diffFlatNetwork = loadFlatNetwork("model2.nn");
    FlatNetwork deepFlatNetwork = loadFlatNetwork("model3.nn");

    // TestNG argument order is (actual, expected, message).
    Assert.assertEquals(compareStructures(flatNetwork, extendedFlatNetwork), -1, "model0 vs model1");
    Assert.assertEquals(compareStructures(flatNetwork, flatNetwork), 0, "model0 vs model0");
    Assert.assertEquals(compareStructures(extendedFlatNetwork, flatNetwork), 1, "model1 vs model0");

    Assert.assertEquals(compareStructures(flatNetwork, diffFlatNetwork), -1, "model0 vs model2");
    Assert.assertEquals(compareStructures(diffFlatNetwork, flatNetwork), -1, "model2 vs model0");
    Assert.assertEquals(compareStructures(extendedFlatNetwork, diffFlatNetwork), 1, "model1 vs model2");
    Assert.assertEquals(compareStructures(diffFlatNetwork, extendedFlatNetwork), -1, "model2 vs model1");

    Assert.assertEquals(compareStructures(deepFlatNetwork, flatNetwork), 1, "model3 vs model0");
    Assert.assertEquals(compareStructures(flatNetwork, deepFlatNetwork), -1, "model0 vs model3");
}

/**
 * Loads a persisted Encog network from {@code src/test/resources/model/<fileName>} and returns
 * its flat representation.
 *
 * @param fileName name of the model file inside the test model directory
 * @return the loaded network's {@link FlatNetwork}
 * @throws ClassCastException if the persisted object is not a {@link BasicNetwork}
 */
private static FlatNetwork loadFlatNetwork(String fileName) {
    BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(
            new File("src/test/resources/model", fileName));
    return network.getFlat();
}

/**
 * Compares two networks with a freshly constructed {@link NNStructureComparator}.
 * A new instance is created for every call, exactly as the original inline code did, so no
 * comparator state is shared between assertions.
 */
private static int compareStructures(FlatNetwork left, FlatNetwork right) {
    return new NNStructureComparator().compare(left, right);
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), BasicNetwork (org.encog.neural.networks.BasicNetwork), NNStructureComparator (ml.shifu.shifu.core.dtrain.nn.NNStructureComparator), BasicML (org.encog.ml.BasicML), File (java.io.File), Test (org.testng.annotations.Test)

Aggregations

File (java.io.File): 1, NNStructureComparator (ml.shifu.shifu.core.dtrain.nn.NNStructureComparator): 1, BasicML (org.encog.ml.BasicML): 1, FlatNetwork (org.encog.neural.flat.FlatNetwork): 1, BasicNetwork (org.encog.neural.networks.BasicNetwork): 1, Test (org.testng.annotations.Test): 1