Usage of edu.cmu.ml.proppr.learn.tools.Sigmoid in the ProPPR project by TeamCohen.
Found in class ModuleConfiguration, method retrieveSettings.
/**
 * Reads the module-selection options from the command line and instantiates the modules
 * enabled by {@code modules(allFlags)}: the prover, the squashing function, the grounder,
 * and the SRW / trainer pair.
 *
 * <p>Option formats handled here:
 * <ul>
 *   <li>prover: {@code name[:setting[:setting...]]}. Every setting after the name is passed
 *       to {@code prover.configure}. Defaults to a {@code DprProver}.</li>
 *   <li>squashing function: a single name. Defaults to
 *       {@code SRW.DEFAULT_SQUASHING_FUNCTION()}. Configured whenever the prover, SRW or
 *       squashing-function module flag is on.</li>
 *   <li>grounder: values {@code [first, threads, throttle]}. The first value is not read here;
 *       missing values default to {@code nthreads} and {@code Multithreading.DEFAULT_THROTTLE}.</li>
 *   <li>trainer: the first value is the trainer name (default {@code cached}). Other values
 *       starting with {@code shuff}, {@code pct} or {@code stableEpochs} are read as
 *       {@code key=value} settings.</li>
 * </ul>
 *
 * <p>An unrecognized module name is reported through {@code usageOptions} with a descriptive
 * message. If that call returns, the {@link IllegalArgumentException} from the enum lookup is
 * rethrown, as before.
 *
 * @param line     the parsed command line
 * @param allFlags configuration flag words; the module flags are taken from {@code modules(allFlags)}
 * @param options  the option set, used when printing usage errors
 * @throws IOException propagated from the superclass or from module setup
 */
@Override
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    super.retrieveSettings(line, allFlags, options);
    int flags = modules(allFlags);

    // ---- prover ----
    if (isOn(flags, USE_PROVER)) {
        if (!line.hasOption(PROVER_MODULE_OPTION)) {
            // default:
            this.prover = new DprProver(apr);
        } else {
            String[] values = line.getOptionValue(PROVER_MODULE_OPTION).split(":");
            boolean proverSupportsPruning = false;
            PROVERS proverType = parseModuleChoice(PROVERS.class, values[0], options, allFlags,
                    "No prover definition for '" + values[0] + "'");
            switch (proverType) {
                case ippr:
                    this.prover = new IdPprProver(apr);
                    break;
                case ppr:
                    this.prover = new PprProver(apr);
                    break;
                case dpr:
                    this.prover = new DprProver(apr);
                    break;
                case idpr:
                    this.prover = new IdDprProver(apr);
                    break;
                case p_idpr:
                    if (prunedPredicateRules == null) {
                        log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " not set");
                    }
                    this.prover = new PruningIdDprProver(apr, prunedPredicateRules);
                    proverSupportsPruning = true;
                    break;
                case qpr:
                    this.prover = new PriorityQueueProver(apr);
                    break;
                case pdpr:
                    this.prover = new PathDprProver(apr);
                    break;
                case dfs:
                    this.prover = new DfsProver(apr);
                    break;
                case tr:
                    this.prover = new TracingDfsProver(apr);
                    if (this.nthreads > 1) {
                        usageOptions(options, allFlags, "Tracing prover is not multithreaded. Remove --threads option or use --threads 1.");
                    }
                    break;
                default:
                    // Only reachable if a PROVERS constant is added without a case here.
                    usageOptions(options, allFlags, "No prover definition for '" + values[0] + "'");
            }
            if (prunedPredicateRules != null && !proverSupportsPruning) {
                log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " is ignored by this prover");
            }
            // Everything after the prover name is a prover-specific setting.
            for (int i = 1; i < values.length; i++) {
                this.prover.configure(values[i]);
            }
        }
    }

    // ---- squashing function ----
    if (anyOn(flags, USE_SQUASHFUNCTION | USE_PROVER | USE_SRW)) {
        if (!line.hasOption(SQUASHFUNCTION_MODULE_OPTION)) {
            // default:
            this.squashingFunction = SRW.DEFAULT_SQUASHING_FUNCTION();
        } else {
            String squashName = line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION);
            String unrecognizedSquash = "Unrecognized squashing function " + squashName;
            switch (parseModuleChoice(SQUASHFUNCTIONS.class, squashName, options, allFlags, unrecognizedSquash)) {
                case linear:
                    this.squashingFunction = new Linear();
                    break;
                case sigmoid:
                    this.squashingFunction = new Sigmoid();
                    break;
                case tanh:
                    this.squashingFunction = new Tanh();
                    break;
                case tanh1:
                    this.squashingFunction = new Tanh1();
                    break;
                case ReLU:
                    this.squashingFunction = new ReLU();
                    break;
                case LReLU:
                    this.squashingFunction = new LReLU();
                    break;
                case exp:
                    this.squashingFunction = new Exp();
                    break;
                case clipExp:
                    this.squashingFunction = new ClippedExp();
                    break;
                default:
                    // Only reachable if a SQUASHFUNCTIONS constant is added without a case here.
                    this.usageOptions(options, allFlags, unrecognizedSquash);
            }
        }
    }

    // ---- grounder ----
    if (isOn(flags, Configuration.USE_GROUNDER)) {
        if (!line.hasOption(GROUNDER_MODULE_OPTION)) {
            this.grounder = new Grounder(nthreads, Multithreading.DEFAULT_THROTTLE, apr, prover, program, plugins);
        } else {
            String[] values = line.getOptionValues(GROUNDER_MODULE_OPTION);
            // values[0] is not read here; values[1] and values[2] optionally override threads and throttle.
            int threads = nthreads;
            if (values.length > 1) {
                threads = Integer.parseInt(values[1]);
            }
            int throttle = Multithreading.DEFAULT_THROTTLE;
            if (values.length > 2) {
                throttle = Integer.parseInt(values[2]);
            }
            this.grounder = new Grounder(threads, throttle, apr, prover, program, plugins);
        }
        this.grounder.includeUnlabeledGraphs(includeEmptyGraphs);
    }

    // ---- SRW and trainer ----
    if (isOn(flags, USE_TRAIN)) {
        this.setupSRW(line, flags, options);
        seed(line);
        if (isOn(flags, USE_TRAINER)) {
            // set default stopping criteria
            double percent = StoppingCriterion.DEFAULT_MAX_PCT_IMPROVEMENT;
            int stableEpochs = StoppingCriterion.DEFAULT_MIN_STABLE_EPOCHS;
            TRAINERS type = TRAINERS.cached;
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                String trainerName = line.getOptionValues(TRAINER_MODULE_OPTION)[0];
                type = parseModuleChoice(TRAINERS.class, trainerName, options, allFlags,
                        "Unrecognized trainer " + trainerName);
            }
            switch (type) {
                case streaming:
                    this.trainer = new Trainer(this.srw, this.nthreads, this.throttle);
                    break;
                case caching:
                    // deliberate fallthrough: 'caching' is an alias for 'cached'
                case cached:
                    boolean shuff = CachingTrainer.DEFAULT_SHUFFLE;
                    if (line.hasOption(TRAINER_MODULE_OPTION)) {
                        for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                            if (val.startsWith("shuff")) {
                                shuff = Boolean.parseBoolean(val.substring(val.indexOf("=") + 1));
                            }
                        }
                    }
                    this.trainer = new CachingTrainer(this.srw, this.nthreads, this.throttle, shuff);
                    break;
                case adagrad:
                    this.usageOptions(options, allFlags, "Trainer 'adagrad' no longer necessary. Use '--srw adagrad' for adagrad descent method.");
                    // Without this break the case fell into default and reported a second,
                    // misleading "Unrecognized trainer" error.
                    break;
                default:
                    // Only reachable if a TRAINERS constant is added without a case here.
                    this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
            }
            if (this.srw instanceof AdaGradSRW) {
                // override default; command-line settings below still take precedence
                stableEpochs = 2;
            }
            // now get stopping criteria from command line
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                    if (val.startsWith("pct")) {
                        percent = Double.parseDouble(val.substring(val.indexOf("=") + 1));
                    } else if (val.startsWith("stableEpochs")) {
                        stableEpochs = Integer.parseInt(val.substring(val.indexOf("=") + 1));
                    }
                }
            }
            this.trainer.setStoppingCriteria(stableEpochs, percent);
        }
    }
    // The SRW may be needed even when not training; build it only if the training branch did not.
    if (isOn(flags, USE_SRW) && this.srw == null) {
        this.setupSRW(line, flags, options);
    }
}

/**
 * Looks up a module name in its enum of choices.
 *
 * <p>An unknown name is reported through {@code usageOptions} with {@code errorMessage}. If
 * {@code usageOptions} returns instead of terminating, the original exception is rethrown.
 * Callers therefore never receive a null choice, and an uncaught lookup failure is reported
 * the same way as before.
 *
 * @param choices      the enum class listing the valid module names
 * @param name         the module name given on the command line
 * @param options      the option set, used when printing usage errors
 * @param allFlags     configuration flag words, used when printing usage errors
 * @param errorMessage message shown to the user when {@code name} is not recognized
 * @return the enum constant named {@code name}
 * @throws IllegalArgumentException if {@code name} is not a constant of {@code choices}
 * @throws NullPointerException if {@code name} is null
 */
private <E extends Enum<E>> E parseModuleChoice(Class<E> choices, String name, Options options,
        int[] allFlags, String errorMessage) {
    try {
        return Enum.valueOf(choices, name);
    } catch (IllegalArgumentException e) {
        usageOptions(options, allFlags, errorMessage);
        throw e;
    }
}
Aggregations