
Example 1 with RecordReaderFunction

Use of org.deeplearning4j.spark.datavec.RecordReaderFunction in project deeplearning4j by deeplearning4j.

From class SparkDl4jLayer, method fit.

/**
 * Fit the layer on the records in a text file, read through the Spark context.
 * @param path the path to the text file
 * @param labelIndex the index of the label column in each record
 * @param recordReader the record reader used to parse each line
 * @return the fitted layer
 */
public Layer fit(String path, int labelIndex, RecordReader recordReader) {
    // conf, sc and fitDataSet are members of SparkDl4jLayer (not shown in this excerpt)
    FeedForwardLayer ffLayer = (FeedForwardLayer) conf.getLayer();
    JavaRDD<String> lines = sc.textFile(path);
    // Map each line of text to a DataSet (features and labels as INDArrays);
    // the layer's nOut is passed as the number of possible labels
    JavaRDD<DataSet> points = lines.map(new RecordReaderFunction(recordReader, labelIndex, ffLayer.getNOut()));
    return fitDataSet(points);
}
Also used : RecordReaderFunction(org.deeplearning4j.spark.datavec.RecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer)
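To make the per-line mapping in `fit` concrete, here is a dependency-free sketch of the kind of conversion a function like `RecordReaderFunction` performs on each line, assuming comma-separated input where the column at `labelIndex` holds an integer class index. The class `LineToExample` and its `parse` method are hypothetical, not part of DL4J; plain `double[]` arrays stand in for the `INDArray` features and labels of a `DataSet`, and `numClasses` plays the role of `ffLayer.getNOut()`.

```java
import java.util.Arrays;

// Hypothetical sketch (not DL4J's API): converts one CSV line into
// {features, oneHotLabel}, mirroring what fit(...) needs per record.
public class LineToExample {

    /** Returns {features, oneHotLabel} for a single comma-separated line. */
    public static double[][] parse(String line, int labelIndex, int numClasses) {
        String[] fields = line.split(",");
        if (labelIndex < 0 || labelIndex >= fields.length) {
            throw new IllegalArgumentException("labelIndex out of range: " + labelIndex);
        }
        // Every column except the label column becomes a feature.
        double[] features = new double[fields.length - 1];
        int f = 0;
        for (int i = 0; i < fields.length; i++) {
            if (i != labelIndex) {
                features[f++] = Double.parseDouble(fields[i].trim());
            }
        }
        // Assumption: the label column holds a class index in [0, numClasses).
        int label = (int) Double.parseDouble(fields[labelIndex].trim());
        if (label < 0 || label >= numClasses) {
            throw new IllegalArgumentException("label " + label + " not in [0, " + numClasses + ")");
        }
        // One-hot encode the label into a vector of length numClasses.
        double[] oneHot = new double[numClasses];
        oneHot[label] = 1.0;
        return new double[][] {features, oneHot};
    }

    public static void main(String[] args) {
        // An Iris-style row: four features, label (class 0) in the last column.
        double[][] ex = parse("5.1,3.5,1.4,0.2,0", 4, 3);
        System.out.println(Arrays.toString(ex[0])); // [5.1, 3.5, 1.4, 0.2]
        System.out.println(Arrays.toString(ex[1])); // [1.0, 0.0, 0.0]
    }
}
```

In the Spark version this logic runs inside `lines.map(...)`, so each executor converts its own partition of the text file independently before `fitDataSet` trains on the resulting records.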
