Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
In class ProverTestTemplate, method testProveState:
@Test
public void testProveState() throws LogicProgramException {
	log.info("testProveState");
	// Give the "milk" feature a weight of 2; no other feature is registered with the weighter.
	FeatureDictWeighter milkWeighter = new InnerProductWeighter();
	SymbolTable<Feature> featureTab = new SimpleSymbolTable<Feature>();
	int milkId = featureTab.getId(new Feature("milk"));
	milkWeighter.put(featureTab.getSymbol(milkId), 2);
	prover.setWeighter(milkWeighter);

	// Build and prove the graph for isa(elsie,X).
	InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
	ProofGraph pg = prover.makeProofGraph(example, apr, featureTab, lpMilk, fMilk);
	Map<State, Double> dist = prover.prove(pg, new StatusLogger());

	// Largest weight per answer category (keyed on the 2nd argument of the first rhs goal),
	// plus the total probability mass over all states.
	double maxQuery = 0.0;     // 2nd argument is still the variable X1
	double maxPlatypus = 0.0;  // 2nd argument is "platypus"
	double maxOthers = 0.0;    // anything else
	double totalMass = 0.0;
	for (Map.Entry<State, Double> entry : dist.entrySet()) {
		Query filled = pg.fill(entry.getKey());
		String secondArg = filled.getRhs()[0].getArg(1).getName();
		double weight = entry.getValue();
		if ("platypus".equals(secondArg)) {
			maxPlatypus = Math.max(maxPlatypus, weight);
		} else if ("X1".equals(secondArg)) {
			maxQuery = Math.max(maxQuery, weight);
		} else {
			maxOthers = Math.max(maxOthers, weight);
		}
		System.out.println(filled + "\t" + weight);
		totalMass += weight;
	}

	System.out.println();
	System.out.println("query weight: " + maxQuery);
	System.out.println("platypus weight: " + maxPlatypus);
	System.out.println("others weight: " + maxOthers);
	// A check that the query state retains the most weight is deliberately not asserted here.
	assertTrue("milk-featured paths should score higher than others", maxPlatypus > maxOthers);
	assertEquals("Total weight of all states should be around 1.0", 1.0, totalMass, 10 * this.apr.epsilon);
	assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
	assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
In class EqualityTest, method test:
@Test
public void test() throws LogicProgramException, IOException {
	// Load the equality program and ask which X satisfy moral(X).
	WamProgram equalityProgram = WamBaseProgram.load(new File(EQUALITY_PROGRAM));
	Prover dprProver = new DprProver();
	ProofGraph moralGraph = new StateProofGraph(Query.parse("moral(X)"), new APROptions(), equalityProgram);
	Collection<Query> answers = dprProver.solvedQueries(moralGraph, new StatusLogger()).keySet();

	// Exactly one solution, with a single goal whose first argument is "bob".
	assertEquals(1, answers.size());
	Query onlyAnswer = answers.iterator().next();
	assertEquals(1, onlyAnswer.getRhs().length);
	String answerName = onlyAnswer.getRhs()[0].getArg(0).getName();
	assertEquals("Answer should be bob", "bob", answerName);
}
Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
In class Trainer, method train:
/**
 * Trains a parameter vector on {@code examples}, running epochs until the stopping criterion is satisfied.
 *
 * <p>For every example in an epoch, a Parse task and a Train task are submitted to a worker pool and a
 * TraceLosses task to a single-threaded cleanup pool; {@code cleanEpoch} is called after all examples of
 * the epoch have been submitted. Submission is throttled so the backlog of unfinished work stays bounded
 * (see the comment inside the example loop).
 *
 * @param masterFeatures feature table installed via {@code LearningGraphBuilder.setFeatures}; skipped when empty
 * @param examples training examples as strings; iterated once per epoch
 * @param builder graph builder handed to each Parse task
 * @param initialParamVec starting parameters, passed through {@code masterLearner.setupParams}
 * @param numEpochs epoch limit given to the {@link StoppingCriterion}
 * @return the trained parameter vector
 */
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
	ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
	if (masterFeatures.size() > 0)
		LearningGraphBuilder.setFeatures(masterFeatures);
	NamedThreadFactory workingThreads = new NamedThreadFactory("work-");
	NamedThreadFactory cleaningThreads = new NamedThreadFactory("cleanup-");
	ThreadPoolExecutor workingPool;
	ExecutorService cleanPool;
	TrainingStatistics total = new TrainingStatistics();
	StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
	boolean graphSizesStatusLog = true;
	StatusLogger stattime = new StatusLogger();
	// repeat until ready to stop
	while (!stopper.satisified()) {
		// set up current epoch
		this.epoch++;
		for (SRW learner : this.learners.values()) {
			learner.setEpoch(epoch);
			learner.clearLoss();
		}
		log.info("epoch " + epoch + " ...");
		status.tick();
		// reset counters & file pointers
		this.statistics = new TrainingStatistics();
		workingThreads.reset();
		cleaningThreads.reset();
		workingPool = new ThreadPoolExecutor(this.nthreads, Integer.MAX_VALUE, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), workingThreads);
		cleanPool = Executors.newSingleThreadExecutor(cleaningThreads);
		// run examples; id starts at 1 and is incremented after every submitted example
		int id = 1;
		stattime.start();
		int countdown = -1;
		Trainer notify = null;
		for (String s : examples) {
			if (log.isDebugEnabled())
				log.debug("Queue size " + unfinishedTaskCount(workingPool));
			statistics.updateReadingStatistics(stattime.sinceLast());
			/*
			 * Throttling behavior:
			 * Once the number of unfinished tasks exceeds 1.5x the number of threads,
			 * we add a 'notify' object to the next nthreads training tasks. Then, the
			 * master thread gathers 'notify' signals until the number of unfinished tasks
			 * is no longer greater than the number of threads. Then we start adding tasks again.
			 *
			 * This works more or less fine, since the master thread stops pulling examples
			 * from disk when there are then a maximum of 2.5x training examples in the queue (that's
			 * the original 1.5x, which could represent a maximum of 1.5x training examples,
			 * plus the nthreads training tasks with active 'notify' objects. There's an
			 * additional nthreads parsing tasks in the queue but those don't take up much
			 * memory so we don't care). This lets us read in a good-sized buffer without
			 * blowing up the heap.
			 *
			 * Worst-case: None of the backlog is cleared before the master thread enters
			 * the synchronized block. nthreads-1 threads will be training long jobs, and
			 * the one free thread works through the 0.5x backlog and all nthreads countdown
			 * examples. The notify() sent by the final countdown example will occur when
			 * there are nthreads unfinished tasks in the queue, and the master thread will exit
			 * the synchronized block and proceed.
			 *
			 * Best-case: The backlog is already cleared by the time the master thread enters
			 * the synchronized block. The while() loop immediately exits, and the notify()
			 * signals from the countdown examples have no effect.
			 */
			if (countdown > 0) {
				if (log.isDebugEnabled())
					log.debug("Countdown " + countdown);
				countdown--;
			} else if (countdown == 0) {
				if (log.isDebugEnabled())
					log.debug("Countdown " + countdown + "; throttling:");
				countdown--;
				notify = null;
				awaitTrainingQueueDrain(workingPool);
			} else if (unfinishedTaskCount(workingPool) > 1.5 * this.nthreads) {
				if (log.isDebugEnabled())
					log.debug("Starting countdown");
				countdown = this.nthreads;
				notify = this;
			}
			Future<PosNegRWExample> parsed = workingPool.submit(new Parse(s, builder, id));
			Future<ExampleStats> trained = workingPool.submit(new Train(parsed, paramVec, id, notify));
			cleanPool.submit(new TraceLosses(trained, id));
			id++;
			stattime.tick();
			if (log.isInfoEnabled() && status.due(1))
				log.info("parsed: " + id + " trained: " + statistics.exampleSetSize);
		}
		cleanEpoch(workingPool, cleanPool, paramVec, stopper, id, total);
		if (graphSizesStatusLog) {
			// id ends one past the last submitted example, so this epoch read id - 1 examples.
			// Dividing by id (as before) under-reported the average; guard the empty-epoch case.
			int numExamples = id - 1;
			log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (numExamples > 0 ? statistics.totalGraphSize / numExamples : 0));
			graphSizesStatusLog = false;
		}
	}
	log.info("Reading statistics: min " + total.minReadTime + " / max " + total.maxReadTime + " / total " + total.readTime);
	log.info("Parsing statistics: min " + total.minParseTime + " / max " + total.maxParseTime + " / total " + total.parseTime);
	log.info("Training statistics: min " + total.minTrainTime + " / max " + total.maxTrainTime + " / total " + total.trainTime);
	return paramVec;
}

/** Returns the (approximate) number of tasks submitted to {@code pool} that have not yet completed. */
private static long unfinishedTaskCount(ThreadPoolExecutor pool) {
	return pool.getTaskCount() - pool.getCompletedTaskCount();
}

/**
 * Blocks the calling (master) thread, waiting on this Trainer's monitor, until at most
 * {@code nthreads} tasks in {@code workingPool} remain unfinished. Wake-ups are expected from
 * training tasks that were given this Trainer as their 'notify' object (see the throttling comment in train).
 */
private void awaitTrainingQueueDrain(ThreadPoolExecutor workingPool) {
	try {
		synchronized (this) {
			if (log.isDebugEnabled())
				log.debug("Clearing training queue...");
			while (unfinishedTaskCount(workingPool) > this.nthreads) this.wait();
			if (log.isDebugEnabled())
				log.debug("Queue cleared.");
		}
	} catch (InterruptedException e) {
		// Best-effort throttling: log and keep submitting work, as before.
		// NOTE(review): the interrupt status is not restored, because later blocking calls this epoch
		// (e.g. inside cleanEpoch) might then fail immediately. Confirm whether aborting training is preferred.
		log.error("Interrupted while waiting for the training queue to clear; continuing", e);
	}
}
Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
In class SparseMatrixIndex, method load:
/**
 * Loads this sparse matrix from three files in {@code dir} named after {@code functor_arg1type_arg2type}:
 * <ul>
 * <li>{@code .rce}: line 1 = #rows, line 2 = #cols, line 3 = #entries;</li>
 * <li>{@code .rowOffset}: one integer per line, the start of each row in {@code colIndices};</li>
 * <li>{@code .colIndex}: one column index per line, optionally followed by {@code WEIGHT_DELIMITER}
 * and a weight (the weight defaults to 1.0).</li>
 * </ul>
 * Each file is closed even when reading it fails.
 *
 * @throws IllegalArgumentException if the .rce header is incomplete, a .colIndex line lies beyond the
 *         declared #entries, or a column index is not smaller than {@code arg2.length}
 * @throws NumberFormatException if a count, offset, column index or weight is not numeric
 * @throws IOException as declared by this method (presumably surfaced by ParsedFile — confirm)
 */
public void load(File dir, String functor_arg1type_arg2type) throws IOException {
	log.info("Loading matrix " + functor_arg1type_arg2type + " from " + dir.getName() + "...");
	this.name = dir + ":" + functor_arg1type_arg2type;
	StatusLogger status = new StatusLogger(LOGUPDATE_MS);
	/* Read the number of rows, columns, and entries - entry is a triple (i,j,m[i,j]) */
	ParsedFile file = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".rce"));
	try {
		Iterator<String> it = file.iterator();
		this.rows = readRceCount(it, functor_arg1type_arg2type, "line 1 must list #rows");
		this.cols = readRceCount(it, functor_arg1type_arg2type, "line 2 must list #cols");
		this.entries = readRceCount(it, functor_arg1type_arg2type, "line 3 must list #entries");
	} finally {
		file.close();
	}
	/* Data is stored like this: colIndices[] is one long
	 * array, and values is a parallel array. rowsOffsets is another array so that
	 * rowOffsets[i] is where the column indices for row i start. Thus
	 *
	 * for (k=rowOffsets[i]; k<rowOffsets[i+1]; k++) {
	 * j = colIndices[k];
	 * m_ij = values[k];
	 * // this would retrieve i,j and the corresponding value in the sparse matrix m[i,j]
	 * doSomethingWith(i,j,m_ij);
	 * }
	 *
	 */
	ArrayList<Integer> rowsOffsets = new ArrayList<Integer>();
	this.colIndices = new int[entries];
	this.values = new float[entries];
	long start = status.tick();
	file = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".rowOffset"));
	try {
		for (String line : file) {
			// trim() for consistency with the .rce header (tolerates stray whitespace / CR)
			rowsOffsets.add(Integer.parseInt(line.trim()));
			if (log.isInfoEnabled() && status.due()) {
				log.info("rowOffset: " + file.getLineNumber() + " lines (" + (file.getLineNumber() / status.since(start)) + " klps)");
			}
		}
	} finally {
		file.close();
	}
	start = status.tick();
	file = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".colIndex"));
	try {
		for (String line : file) {
			int ln = file.getLineNumber();
			if (ln >= colIndices.length) {
				throw new IllegalArgumentException("Malformed sparsegraph! For index " + this.name + ", colIndex line " + ln + " lies beyond the " + entries + " entries declared in .rce");
			}
			String[] parts = line.split(WEIGHT_DELIMITER);
			colIndices[ln] = Integer.parseInt(parts[0].trim());
			values[ln] = parts.length > 1 ? Float.parseFloat(parts[1]) : 1.0f;
			if (colIndices[ln] >= arg2.length) {
				throw new IllegalArgumentException("Malformed sparsegraph! For index " + this.name + ", colIndices[" + ln + "]=" + colIndices[ln] + "; arg2.length is only " + arg2.length);
			}
			if (log.isInfoEnabled() && status.due()) {
				log.info("colIndex: " + file.getLineNumber() + " lines (" + (file.getLineNumber() / status.since(start)) + " klps)");
			}
		}
	} finally {
		file.close();
	}
	// Copy offsets into a primitive array with a trailing sentinel (= #entries) so that
	// rowOffsets[i+1] bounds the last row as well.
	this.rowOffsets = new int[rowsOffsets.size() + 1];
	for (int i = 0; i < rowsOffsets.size(); i++) {
		rowOffsets[i] = rowsOffsets.get(i);
	}
	rowOffsets[rowsOffsets.size()] = entries;
	long del = status.sinceStart();
	if (del > LOGUPDATE_MS)
		log.info("Finished loading sparse graph matrix " + functor_arg1type_arg2type + " (" + (del / 1000.) + " sec)");
}

/**
 * Reads the next header line of a {@code .rce} file and parses it as an integer count.
 *
 * @param it iterator over the .rce lines
 * @param matrixName the functor_arg1type_arg2type name, used in the error message
 * @param requirement description of what the line must contain, used in the error message
 * @throws IllegalArgumentException if no line is available
 */
private static int readRceCount(Iterator<String> it, String matrixName, String requirement) {
	String line = it.hasNext() ? it.next() : null;
	if (line == null)
		throw new IllegalArgumentException("Bad format for " + matrixName + ".rce: " + requirement);
	return Integer.parseInt(line.trim());
}
Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
In class PathDprProver, method main:
/**
 * Command-line entry point: parses the configuration (which requires a {@code --query} option and
 * makes the params file optional), builds a proof graph for that query, and runs a
 * {@link PathDprProver} over it.
 */
public static void main(String[] args) throws LogicProgramException {
	CustomConfiguration config = new CustomConfiguration(args,
			/* input */ Configuration.USE_PARAMS,
			/* output */ 0,
			/* constants */ Configuration.USE_WAM | Configuration.USE_SQUASHFUNCTION,
			/* modules */ 0) {
		// Raw text of the --query option.
		// NOTE(review): deliberately no initializer, in case CustomConfiguration's constructor calls
		// retrieveCustomSettings before subclass field initializers would run.
		String queryText;

		@Override
		protected void addCustomOptions(Options options, int[] flags) {
			// The params file is optional here; the query is mandatory.
			options.getOption(Configuration.PARAMS_FILE_OPTION).setRequired(false);
			options.addOption(OptionBuilder
					.withLongOpt("query")
					.withArgName("functor(arg1,Var1)")
					.hasArg()
					.isRequired()
					.withDescription("specify query to print top paths for")
					.create());
			//TODO: add prompt option (for large datasets)
		}

		@Override
		protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
			queryText = line.getOptionValue("query");
		}

		@Override
		public Object getCustomSetting(String name) {
			// Only one custom setting exists, so the name is ignored.
			return queryText;
		}
	};
	PathDprProver pathProver = new PathDprProver(config.apr);
	String queryArg = (String) config.getCustomSetting(null);
	Query parsedQuery = Query.parse(queryArg);
	StateProofGraph graph = new StateProofGraph(parsedQuery, config.apr, config.program, config.plugins);
	pathProver.prove(graph, new StatusLogger());
}
Aggregations