Use of org.encog.ml.BasicML in project shifu by ShifuML.
The class NNModelSpecTest, method testModelTraverse.
// @Test
public void testModelTraverse() {
    // Load the original model spec and the extended model spec (same topology, more inputs)
    BasicML basicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();

    BasicML extendedBasicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();

    // Walk the layers from input side to output side and map every weight index
    // of the original flat network onto the corresponding index in the extended one
    for (int layer = flatNetwork.getLayerIndex().length - 1; layer > 0; layer--) {
        int layerOutputCnt = flatNetwork.getLayerFeedCounts()[layer - 1];
        int layerInputCnt = flatNetwork.getLayerCounts()[layer];
        System.out.println("Weight index for layer " + (flatNetwork.getLayerIndex().length - layer));
        int extendedLayerInputCnt = extendedFlatNetwork.getLayerCounts()[layer];
        int indexPos = flatNetwork.getWeightIndex()[layer - 1];
        int extendedIndexPos = extendedFlatNetwork.getWeightIndex()[layer - 1];
        for (int i = 0; i < layerOutputCnt; i++) {
            for (int j = 0; j < layerInputCnt; j++) {
                int weightIndex = indexPos + (i * layerInputCnt) + j;
                int extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + j;
                if (j == layerInputCnt - 1) {
                    // the last input slot of each layer is the bias; keep it at the end
                    extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + (extendedLayerInputCnt - 1);
                }
                System.out.println(weightIndex + " --> " + extendedWeightIndex);
            }
        }
    }
}
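The traversal above relies on how Encog's FlatNetwork keeps all weights in one flat array: the weight feeding neuron i of the next layer from input j of the current layer sits at getWeightIndex()[layer - 1] + i * getLayerCounts()[layer] + j, with the bias acting as the last input of each layer. A minimal sketch, using the same FlatNetwork accessors as the test above (the chosen layer/neuron indices are only illustrative):

    // Sketch: read a single weight out of the flat weight array using the same
    // indexing convention as the traversal above (assumes model0.nn exists).
    BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model0.nn"));
    FlatNetwork flat = network.getFlat();
    int layer = flat.getLayerIndex().length - 1;   // the input-side layer
    int i = 0;                                     // first neuron of the next layer up
    int j = flat.getLayerCounts()[layer] - 1;      // last input slot, i.e. the bias
    int idx = flat.getWeightIndex()[layer - 1] + i * flat.getLayerCounts()[layer] + j;
    System.out.println("bias weight = " + flat.getWeights()[idx]);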
Use of org.encog.ml.BasicML in project shifu by ShifuML.
The class NNModelSpecTest, method testModelFitIn.
@Test
public void testModelFitIn() {
    // Register Shifu's custom persistor so the .nn model specs can be loaded
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    BasicML basicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();

    BasicML extendedBasicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();

    // Fit the existing model into the extended network and count the weights that stay fixed
    NNMaster master = new NNMaster();
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork,
            Arrays.asList(new Integer[] { 1, 2, 3 }));
    Assert.assertEquals(fixedWeightIndexSet.size(), 931);

    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork,
            Arrays.asList(new Integer[] { 1, 2, 3 }), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 910);
}
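The returned set holds the positions of weights copied from the existing model into the extended flat network. Purely as an illustration (this is not Shifu's actual training loop, and learningRate/gradients are hypothetical), such a set could be used to freeze those weights so only the newly added connections learn:

    // Illustrative sketch: skip updates for weights whose indices are in
    // fixedWeightIndexSet; only the new connections get adjusted.
    double learningRate = 0.01;                      // hypothetical value
    double[] weights = extendedFlatNetwork.getWeights();
    double[] gradients = new double[weights.length]; // assume gradients were computed elsewhere
    for (int idx = 0; idx < weights.length; idx++) {
        if (!fixedWeightIndexSet.contains(idx)) {
            weights[idx] -= learningRate * gradients[idx];
        }
    }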
Use of org.encog.ml.BasicML in project shifu by ShifuML.
The class ShifuCLI, method analysisModelFi.
public static int analysisModelFi(String modelPath) {
    File modelFile = new File(modelPath);
    // Only GBT/RF model specs are supported for feature-importance analysis
    if (!modelFile.exists() || !(modelPath.toUpperCase().endsWith("." + CommonConstants.GBT_ALG_NAME)
            || modelPath.toUpperCase().endsWith("." + CommonConstants.RF_ALG_NAME))) {
        log.error("The model {} doesn't exist or it isn't a GBT/RF model.", modelPath);
        return 1;
    }

    FileInputStream inputStream = null;
    String fiFileName = modelFile.getName() + ".fi";
    try {
        inputStream = new FileInputStream(modelFile);
        BasicML basicML = TreeModel.loadFromStream(inputStream);
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils
                .computeTreeModelFeatureImportance(Arrays.asList(new BasicML[] { basicML }));
        CommonUtils.writeFeatureImportance(fiFileName, featureImportances);
    } catch (IOException e) {
        log.error("Failed to analyze model FI for {}", modelPath);
        return 1;
    } finally {
        IOUtils.closeQuietly(inputStream);
    }
    return 0;
}
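On Java 7+, the same load-and-compute flow can be written with try-with-resources so the stream is closed without an explicit finally block. A minimal sketch under that assumption, reusing the same TreeModel and CommonUtils calls as above:

    // Sketch: same feature-importance computation with try-with-resources
    // (identical APIs; the stream is closed automatically).
    try (FileInputStream inputStream = new FileInputStream(modelFile)) {
        BasicML basicML = TreeModel.loadFromStream(inputStream);
        Map<Integer, MutablePair<String, Double>> featureImportances =
                CommonUtils.computeTreeModelFeatureImportance(Arrays.asList(new BasicML[] { basicML }));
        CommonUtils.writeFeatureImportance(modelFile.getName() + ".fi", featureImportances);
    } catch (IOException e) {
        log.error("Failed to analyze model FI for {}", modelPath);
        return 1;
    }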
Use of org.encog.ml.BasicML in project shifu by ShifuML.
The class PostTrainMapper, method setup.
@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    loadTagWeightNum();

    this.dataPurifier = new DataPurifier(this.modelConfig, false);
    this.outputKey = new IntWritable();
    this.outputValue = new Text();
    this.tags = new HashSet<String>(modelConfig.getFlattenTags());

    // Load all bagging models once per mapper and wrap them in a ModelRunner for scoring
    SourceType sourceType = this.modelConfig.getDataSet().getSource();
    List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelConfig, null, sourceType);
    this.headers = CommonUtils.getFinalHeaders(modelConfig);
    this.modelRunner = new ModelRunner(modelConfig, columnConfigList, this.headers,
            modelConfig.getDataSetDelimiter(), models);

    this.mos = new MultipleOutputs<NullWritable, Text>((TaskInputOutputContext) context);
    this.initFeatureStats();
}
Use of org.encog.ml.BasicML in project shifu by ShifuML.
The class ExportModelProcessor, method run.
/*
 * (non-Javadoc)
 *
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())
                && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported for the NN/GBT/RF algorithms.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath,
                    ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL),
                        "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only holds one TreeNode instance although it is a list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = total number of inputs
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees,
                            modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find the unified bagging model in local path {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one PMML file per bagging model
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath,
                ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported for the NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath,
                    ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // export WOE mappings for the requested (or all) categorical variables
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (columnConfig.getBinLength() > 1
                    && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory()))
                            || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary())
                                    && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    woeInfos.add("");
                }
            }
            FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
        }
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into a mapping list
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have run `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}