Usage of `ml.shifu.shifu.core.dtrain.nn.NNStructureComparator` in the project shifu by ShifuML.
Example from class `NNModelSpecTest`, method `testModelStructureCompare`.
/**
 * Checks {@link NNStructureComparator} results across four persisted Encog networks.
 *
 * <p>Expectations pinned by the assertions below:
 * <ul>
 *   <li>model0 vs model1 ("extended"): -1; model1 vs model0: 1; model0 vs itself: 0.</li>
 *   <li>model0 vs model2 ("diff"): -1 in BOTH directions. NOTE(review): this suggests -1 also
 *       signals "structures not comparable", not only "smaller" — confirm against the comparator.</li>
 *   <li>model1 vs model2: 1; model2 vs model1: -1.</li>
 *   <li>model3 ("deep") vs model0: 1; model0 vs model3: -1.</li>
 * </ul>
 *
 * <p>NOTE(review): argument order is (actual, expected), which matches TestNG's
 * {@code Assert.assertEquals}; confirm the imported Assert is TestNG's.
 */
@Test
public void testModelStructureCompare() {
    FlatNetwork flatNetwork = loadFlatNetwork("src/test/resources/model/model0.nn");
    FlatNetwork extendedFlatNetwork = loadFlatNetwork("src/test/resources/model/model1.nn");
    Assert.assertEquals(compareStructure(flatNetwork, extendedFlatNetwork), -1);
    Assert.assertEquals(compareStructure(flatNetwork, flatNetwork), 0);
    Assert.assertEquals(compareStructure(extendedFlatNetwork, flatNetwork), 1);

    FlatNetwork diffFlatNetwork = loadFlatNetwork("src/test/resources/model/model2.nn");
    // model0 and model2 compare as -1 in both directions (see the javadoc note above).
    Assert.assertEquals(compareStructure(flatNetwork, diffFlatNetwork), -1);
    Assert.assertEquals(compareStructure(diffFlatNetwork, flatNetwork), -1);
    Assert.assertEquals(compareStructure(extendedFlatNetwork, diffFlatNetwork), 1);
    Assert.assertEquals(compareStructure(diffFlatNetwork, extendedFlatNetwork), -1);

    FlatNetwork deepFlatNetwork = loadFlatNetwork("src/test/resources/model/model3.nn");
    Assert.assertEquals(compareStructure(deepFlatNetwork, flatNetwork), 1);
    Assert.assertEquals(compareStructure(flatNetwork, deepFlatNetwork), -1);
}

/**
 * Loads a persisted Encog model and returns its flat network representation.
 *
 * @param path file path of the persisted model
 * @return the {@link FlatNetwork} backing the loaded {@link BasicNetwork}
 * @throws ClassCastException if the persisted object is not a {@link BasicNetwork}
 */
private static FlatNetwork loadFlatNetwork(String path) {
    BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(path));
    return network.getFlat();
}

/**
 * Compares two network structures using a fresh {@link NNStructureComparator}.
 * A new instance is created per call, as the original test did, because SOURCE does not show
 * whether the comparator holds state between calls.
 */
private static int compareStructure(FlatNetwork left, FlatNetwork right) {
    return new NNStructureComparator().compare(left, right);
}
Aggregations