Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j (by deeplearning4j).
In class KerasModelConfigurationTest, method importMnistCnnTensorFlowConfigurationTest:
/**
 * Imports the Keras MNIST CNN (TensorFlow backend) JSON configuration as a sequential
 * model with training configuration enforced, then initializes the resulting network.
 * The test passes if import and {@code init()} complete without throwing.
 */
@Test
public void importMnistCnnTensorFlowConfigurationTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_config.json", KerasModelConfigurationTest.class.getClassLoader());
    MultiLayerConfiguration config;
    // try-with-resources: the classpath stream was previously never closed.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildSequential()
                .getMultiLayerConfiguration();
    }
    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j (by deeplearning4j).
In class MiniBatchTests, method testMiniBatches:
/**
 * Reads the iris svmLight file into Spark and checks the line count (300). Each line is
 * mapped to a DataSet via an SVMLight record reader and that count is checked (300).
 * The points are then grouped into mini-batches of 10 and that count is checked (30).
 * Finally every mini-batch is run through {@code DataSetAssertionFunction}.
 */
@Test
public void testMiniBatches() throws Exception {
    log.info("Setting up Spark Context...");
    JavaRDD<String> lines = sc.textFile(new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive().toURI().toString()).cache();
    long count = lines.count();
    assertEquals(300, count);
    // gotta map this to a Matrix/INDArray
    JavaRDD<DataSet> points = lines.map(new RecordReaderFunction(new SVMLightRecordReader(), 4, 3)).cache();
    count = points.count();
    assertEquals(300, count);
    JavaRDD<DataSet> miniBatches = new RDDMiniBatches(10, points).miniBatchesJava();
    count = miniBatches.count();
    assertEquals(30, count);
    // map() is a lazy transformation: without an action the assertion function never
    // executes. count() forces the job so any failure inside the function surfaces here.
    miniBatches.map(new DataSetAssertionFunction()).count();
    // Release the cached RDDs only after the last job that depends on them.
    lines.unpersist();
    points.unpersist();
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j (by deeplearning4j).
In class TestSparkMultiLayerParameterAveraging, method testFromSvmLight:
/**
 * Loads the iris svmLight file through MLlib and converts each point to dense features.
 * Fits a {@link SparkDl4jMultiLayer} on that data using parameter averaging, then prints
 * evaluation statistics against the first 150-example batch of {@link IrisDataSetIterator}.
 */
@Test
public void testFromSvmLight() throws Exception {
    String svmLightPath = new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive().getAbsolutePath();
    JavaRDD<LabeledPoint> rawPoints = MLUtils.loadLibSVMFile(sc.sc(), svmLightPath).toJavaRDD();
    // Re-wrap every point so its feature vector is stored densely.
    JavaRDD<LabeledPoint> data = rawPoints.map(new Function<LabeledPoint, LabeledPoint>() {
        @Override
        public LabeledPoint call(LabeledPoint point) throws Exception {
            double[] denseFeatures = point.features().toArray();
            return new LabeledPoint(point.label(), Vectors.dense(denseFeatures));
        }
    });

    DataSet irisBatch = new IrisDataSetIterator(150, 150).next();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .iterations(100)
            .miniBatch(true)
            .maxNumLineSearchIterations(10)
            .list()
            .layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
                    .nIn(4)
                    .nOut(100)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU)
                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100)
                    .nOut(3)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .backprop(false)
            .build();
    // NOTE(review): `conf` and `network` are built and initialised but never used afterwards.
    // The Spark master below is configured with getBasicConf() instead. Confirm this is intended.
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    System.out.println("Initializing network");

    SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(), new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
    MultiLayerNetwork trained = master.fitLabeledPoint(data);

    Evaluation evaluation = new Evaluation();
    evaluation.eval(irisBatch.getLabels(), trained.output(irisBatch.getFeatureMatrix()));
    System.out.println(evaluation.stats());
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j (by deeplearning4j).
In class ContextHolder, method getInstance:
// holder for memory strategies override
/**
 * Singleton accessor for the context holder.
 *
 * <p>On the first call, this loads {@code /cudafunctions.properties} from the classpath,
 * creates and configures the instance, and copies every loaded property into the JVM
 * system properties. Later calls return the existing instance. The method is
 * {@code synchronized}, so initialization happens at most once.
 *
 * @return the instance for the context holder.
 * @throws RuntimeException if the properties resource cannot be read (cause preserved).
 */
public static synchronized ContextHolder getInstance() {
    if (INSTANCE == null) {
        Properties props = new Properties();
        // try-with-resources: the stream was previously leaked after loading.
        try (java.io.InputStream in = new ClassPathResource("/cudafunctions.properties", ContextHolder.class.getClassLoader()).getInputStream()) {
            props.load(in);
        } catch (IOException e) {
            throw new RuntimeException("Unable to load /cudafunctions.properties from the classpath", e);
        }
        INSTANCE = new ContextHolder();
        INSTANCE.configure();
        // set the properties to be accessible globally
        for (String name : props.stringPropertyNames()) {
            System.setProperty(name, props.getProperty(name));
        }
    }
    return INSTANCE;
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j (by deeplearning4j).
In class OnnxImportTest, method testOnnxImportEmbedding:
/**
 * Imports the embedding-only ONNX graph and checks that variable {@code "2"} has
 * shape {@code [100, 300]}.
 */
@Test
public void testOnnxImportEmbedding() throws Exception {
    // try-with-resources: the classpath stream was previously never closed.
    try (java.io.InputStream graphStream = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        val importGraph = OnnxGraphMapper.getInstance().importGraph(graphStream);
        val embeddingMatrix = importGraph.getVariable("2");
        assertArrayEquals(new int[] { 100, 300 }, embeddingMatrix.getShape());
    }
    // TODO(review): the check below (the op for variable "3" is a Gather) was disabled in the
    // original. Restore it once the vertex-lookup API it relies on is available again.
    /* val onlyOp = importGraph.getFunctionForVertexId(importGraph.getVariable("3").getVertexId());
    assertNotNull(onlyOp);
    assertTrue(onlyOp instanceof Gather);
    */
}
Aggregations