Use of org.apache.ignite.ml.tree.NodeData in project ignite by apache.
From the class SparkModelParser, the method extractNodeDataFromParquetRow.
/**
 * Forms the node data according to the data in the Parquet row.
 *
 * @param g The given group representing the node data from the Spark DT model.
 * @return The extracted node data.
 */
@NotNull
private static NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
    NodeData nodeData = new NodeData();

    nodeData.id = g.getInteger(0, 0);
    nodeData.prediction = g.getDouble(1, 0);
    nodeData.leftChildId = g.getInteger(5, 0);
    nodeData.rightChildId = g.getInteger(6, 0);

    // Spark marks a leaf node by setting both child ids to -1.
    if (nodeData.leftChildId == -1 && nodeData.rightChildId == -1) {
        nodeData.featureIdx = -1;
        nodeData.threshold = -1;
        nodeData.isLeafNode = true;
    }
    else {
        // Inner node: read the feature index and split threshold from the nested split group.
        final SimpleGroup splitGrp = (SimpleGroup) g.getGroup(7, 0);
        nodeData.featureIdx = splitGrp.getInteger(0, 0);
        nodeData.threshold = splitGrp.getGroup(1, 0).getGroup(0, 0).getDouble(0, 0);
    }

    return nodeData;
}
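The NodeData holder itself is not listed on this page. Judging from the fields populated above, it is a plain data object roughly like the sketch below; the field types are assumptions inferred from the Parquet getters used, not Ignite's actual declaration.

// Sketch of the node data holder, inferred from the field accesses above.
public class NodeData {
    /** Node id within the tree. */
    int id;
    /** Predicted value (meaningful for leaf nodes). */
    double prediction;
    /** Id of the left child, or -1 for a leaf. */
    int leftChildId;
    /** Id of the right child, or -1 for a leaf. */
    int rightChildId;
    /** Index of the split feature, or -1 for a leaf. */
    int featureIdx;
    /** Split threshold, or -1 for a leaf. */
    double threshold;
    /** True when both child ids are -1. */
    boolean isLeafNode;
}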
Use of org.apache.ignite.ml.tree.NodeData in project ignite by apache.
From the class SparkModelParser, the method loadDecisionTreeModel.
/**
 * Loads the Decision Tree model from the given Parquet file.
 *
 * @param pathToMdl Path to the model.
 * @param learningEnvironment Learning environment.
 * @return The decision tree model, or {@code null} if the file could not be read.
 */
private static Model loadDecisionTreeModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, NodeData> nodes = new TreeMap<>();

        // Read every row group and collect the tree nodes keyed by node id.
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                NodeData nodeData = extractNodeDataFromParquetRow(g);
                nodes.put(nodeData.id, nodeData);
            }
        }

        return buildDecisionTreeModel(nodes);
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    return null;
}
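The buildDecisionTreeModel helper invoked above is not listed on this page. Conceptually it links the flat node map back into a tree by starting from the root and following the child ids. The sketch below illustrates that idea with a made-up TreeNode holder; it is not Ignite's implementation, and it assumes Spark assigns id 0 to the root node.

// Illustration only: TreeNode is a hypothetical holder, not part of Ignite's API.
class TreeNode {
    int featureIdx;
    double threshold;
    double prediction;
    boolean isLeaf;
    TreeNode left, right;
}

// Recursively links the flat node map into a tree; the whole tree is buildSubtree(0, nodes).
static TreeNode buildSubtree(int nodeId, Map<Integer, NodeData> nodes) {
    NodeData data = nodes.get(nodeId);
    TreeNode node = new TreeNode();
    node.isLeaf = data.isLeafNode;

    if (data.isLeafNode)
        node.prediction = data.prediction; // Leaf: keep only the predicted value.
    else {
        // Inner node: keep the split and recurse into both children.
        node.featureIdx = data.featureIdx;
        node.threshold = data.threshold;
        node.left = buildSubtree(data.leftChildId, nodes);
        node.right = buildSubtree(data.rightChildId, nodes);
    }

    return node;
}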
Use of org.apache.ignite.ml.tree.NodeData in project ignite by apache.
From the class SparkModelParser, the method parseTreesForRandomForestAlgorithm.
/**
 * Parses the trees from the file for a common Random Forest ensemble.
 *
 * @param pathToMdl Path to the model.
 * @param learningEnvironment Learning environment.
 * @return The list of tree models, or {@code null} if the file could not be read.
 */
private static List<IgniteModel<Vector, Double>> parseTreesForRandomForestAlgorithm(String pathToMdl,
    LearningEnvironment learningEnvironment) {
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();

        // Each row holds a tree id and a nested node-data group; group the nodes by tree id.
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);

                if (nodesByTreeId.containsKey(treeID)) {
                    Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
                    nodesByNodeId.put(nodeData.id, nodeData);
                }
                else {
                    TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
                    nodesByNodeId.put(nodeData.id, nodeData);
                    nodesByTreeId.put(treeID, nodesByNodeId);
                }
            }
        }

        // Build one decision tree per tree id.
        List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));

        return models;
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    return null;
}
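A common Random Forest prediction is the mean of the individual tree predictions. As a plain-Java illustration of that aggregation over the list returned above (not the parser's own code; predictForest and features are hypothetical names):

// Hedged sketch: a common Random Forest regressor averages the per-tree predictions.
private static double predictForest(List<IgniteModel<Vector, Double>> models, Vector features) {
    double sum = 0.0;
    for (IgniteModel<Vector, Double> tree : models)
        sum += tree.predict(features); // Each tree contributes its own prediction.
    return sum / models.size(); // Mean aggregation over the ensemble.
}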
Use of org.apache.ignite.ml.tree.NodeData in project ignite by apache.
From the class SparkModelParser, the method parseAndBuildGDBModel.
/**
 * Parses and builds a common GDB model with the custom label mapper.
 *
 * @param pathToMdl Path to the model.
 * @param pathToMdlMetaData Path to the model metadata.
 * @param lbMapper Label mapper.
 * @param learningEnvironment Learning environment.
 * @return The GDB model, or {@code null} if the files could not be read.
 */
@Nullable
private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData,
    IgniteFunction<Double, Double> lbMapper, LearningEnvironment learningEnvironment) {
    double[] treeWeights = null;
    final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();

    // First pass: read the per-tree weights from the metadata file.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdlMetaData), new Configuration()))) {
        PageReadStore pagesMetaData;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        while (null != (pagesMetaData = r.readNextRowGroup())) {
            final long rows = pagesMetaData.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                int treeId = g.getInteger(0, 0);
                double treeWeight = g.getDouble(2, 0);
                treeWeightsByTreeID.put(treeId, treeWeight);
            }
        }
    }
    catch (IOException e) {
        String msg = "Error reading parquet file with MetaData by the path: " + pathToMdlMetaData;
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    // Convert the weight map to an array indexed by tree id (assumes ids 0..n-1).
    treeWeights = new double[treeWeightsByTreeID.size()];
    for (int i = 0; i < treeWeights.length; i++)
        treeWeights[i] = treeWeightsByTreeID.get(i);

    // Second pass: read the tree nodes from the model file and group them by tree id.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();

        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);

                if (nodesByTreeId.containsKey(treeID)) {
                    Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
                    nodesByNodeId.put(nodeData.id, nodeData);
                }
                else {
                    TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
                    nodesByNodeId.put(nodeData.id, nodeData);
                    nodesByTreeId.put(treeID, nodesByNodeId);
                }
            }
        }

        // Build one decision tree per tree id and combine them with the weighted aggregator.
        final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));

        return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    return null;
}
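For reference, the assembled GDBModel conceptually computes a weighted sum of the tree predictions and passes the result through the label mapper. A hand-rolled sketch of that composition (predictGdb and features are hypothetical names; it assumes tree ids run 0..n-1 so that treeWeights[i] lines up with models.get(i), as the weight-array loop above implies):

// Hedged sketch of what the returned GDBModel computes for an input Vector.
private static double predictGdb(List<IgniteModel<Vector, Double>> models, double[] treeWeights,
    IgniteFunction<Double, Double> lbMapper, Vector features) {
    double raw = 0.0;
    for (int i = 0; i < models.size(); i++)
        raw += treeWeights[i] * models.get(i).predict(features); // Weighted sum over the boosted trees.
    return lbMapper.apply(raw); // Identity for regression; a thresholding map for classification labels.
}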