Use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
In class NNTrainerTest, method computeScore:
/**
 * Loads the persisted network stored in {@code modelFile} and prints one line per data pair.
 * Each line is the network's first output for that pair's input, multiplied by 1000 and
 * truncated to an {@code int}. Pairs are printed in argument order.
 *
 * @param modelFile file holding the persisted {@code BasicNetwork}
 * @param dataPair0 first pair to score
 * @param dataPair1 second pair to score
 * @param dataPair2 third pair to score
 * @param dataPair3 fourth pair to score
 */
private void computeScore(File modelFile, MLDataPair dataPair0, MLDataPair dataPair1, MLDataPair dataPair2, MLDataPair dataPair3) {
    final BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(modelFile);
    final MLDataPair[] pairsToScore = { dataPair0, dataPair1, dataPair2, dataPair3 };
    for (MLDataPair pair : pairsToScore) {
        // Scale the first output by 1000 and truncate, so the printed score is an integer.
        int scaledScore = (int) (network.compute(pair.getInput()).getData(0) * 1000);
        System.out.println(scaledScore);
    }
}
Use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
In class DTrainTest, method setup:
/**
 * Builds the shared fixture before the tests run.
 *
 * <p>Steps:
 * <ul>
 *   <li>Creates a network with three layers, sized by INPUT_COUNT, HIDDEN_COUNT and
 *       OUTPUT_COUNT, in that order.</li>
 *   <li>Finalizes its structure and calls {@code reset()}.</li>
 *   <li>Captures the flat weight array.</li>
 *   <li>Builds the training set via {@code RandomTrainingFactory.generate}.</li>
 * </ul>
 */
@BeforeTest
public void setup() {
    network = new BasicNetwork();
    final int[] layerSizes = { DTrainTest.INPUT_COUNT, DTrainTest.HIDDEN_COUNT, DTrainTest.OUTPUT_COUNT };
    for (int neuronCount : layerSizes) {
        network.addLayer(new BasicLayer(neuronCount));
    }
    network.getStructure().finalizeStructure();
    network.reset();

    weights = network.getFlat().getWeights();
    // NOTE(review): the meaning of the literal arguments (1000, 10000, -1, 1) is defined by
    // RandomTrainingFactory.generate; confirm against the Encog docs.
    training = RandomTrainingFactory.generate(1000, 10000, INPUT_COUNT, OUTPUT_COUNT, -1, 1);
}
Use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
In class NNModelSpecTest, method testModelStructureCompare:
/**
 * Checks the results of {@code NNStructureComparator.compare} across the persisted test models
 * model0 through model3.
 *
 * <p>Models are loaded in the same order as they are first needed by the assertions.
 */
@Test
public void testModelStructureCompare() {
    FlatNetwork baseFlat = ((BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn"))).getFlat();
    FlatNetwork extendedFlat = ((BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn"))).getFlat();

    // model0 versus the extended model1, plus model0 against itself.
    Assert.assertEquals(new NNStructureComparator().compare(baseFlat, extendedFlat), -1);
    Assert.assertEquals(new NNStructureComparator().compare(baseFlat, baseFlat), 0);
    Assert.assertEquals(new NNStructureComparator().compare(extendedFlat, baseFlat), 1);

    FlatNetwork differentFlat = ((BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model2.nn"))).getFlat();

    // model2 yields -1 in both directions against model0.
    Assert.assertEquals(new NNStructureComparator().compare(baseFlat, differentFlat), -1);
    Assert.assertEquals(new NNStructureComparator().compare(differentFlat, baseFlat), -1);
    Assert.assertEquals(new NNStructureComparator().compare(extendedFlat, differentFlat), 1);
    Assert.assertEquals(new NNStructureComparator().compare(differentFlat, extendedFlat), -1);

    FlatNetwork deepFlat = ((BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model3.nn"))).getFlat();

    Assert.assertEquals(new NNStructureComparator().compare(deepFlat, baseFlat), 1);
    Assert.assertEquals(new NNStructureComparator().compare(baseFlat, deepFlat), -1);
}
Use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
In class NNModelSpecTest, method testFitExistingModelIn:
/**
 * Checks how many weight indices {@code NNMaster.fitExistingModelIn} reports as fixed.
 * The cases are model0 fitted into itself, and model0 fitted into the extended model1.
 *
 * <p>Only the number of returned indices is asserted. The returned set is therefore checked
 * directly, rather than being copied into a list and sorted first, which had no effect on the
 * assertions.
 */
@Test
public void testFitExistingModelIn() {
    BasicNetwork basicNetwork = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn"));
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    NNMaster master = new NNMaster();

    // model0 fitted into itself, with two different single-element index lists.
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(6));
    Assert.assertEquals(fixedWeightIndexSet.size(), 31);

    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Arrays.asList(1));
    Assert.assertEquals(fixedWeightIndexSet.size(), 930);

    BasicNetwork extendedBasicNetwork = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn"));
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();

    // model0 fitted into model1, first via the 3-argument overload.
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1));
    Assert.assertEquals(fixedWeightIndexSet.size(), 930);

    // The 4-argument overload with a trailing false reports fewer indices (900 vs 930).
    // NOTE(review): the meaning of the boolean flag is defined in NNMaster; confirm there.
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(1), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 900);
}
Aggregations