Search in sources :

Example 16 with StatusLogger

use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

the class ProverTestTemplate method testProveState.

/**
 * Proves {@code isa(elsie,X)} with a weighter that knows only the "milk"
 * feature and checks how the resulting weight is spread over the solutions:
 * the "platypus" answer must outscore the remaining answers, the weights
 * must sum to roughly 1.0, and the weighter's known/unknown feature counters
 * must match the expected values.
 */
@Test
public void testProveState() throws LogicProgramException {
    log.info("testProveState");

    // Register "milk" in a fresh symbol table and give it weight 2.
    FeatureDictWeighter weighter = new InnerProductWeighter();
    SymbolTable<Feature> features = new SimpleSymbolTable<Feature>();
    int milkId = features.getId(new Feature("milk"));
    weighter.put(features.getSymbol(milkId), 2);
    prover.setWeighter(weighter);

    // Query: ("isa","elsie","X")
    InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
    ProofGraph graph = prover.makeProofGraph(example, apr, features, lpMilk, fMilk);
    Map<State, Double> solutions = prover.prove(graph, new StatusLogger());

    // Track the best weight seen per answer category, plus the grand total.
    double queryMax = 0.0;
    double platypusMax = 0.0;
    double othersMax = 0.0;
    double total = 0.0;
    for (Map.Entry<State, Double> entry : solutions.entrySet()) {
        Query filled = graph.fill(entry.getKey());
        String secondArg = filled.getRhs()[0].getArg(1).getName();
        double weight = entry.getValue();
        if ("platypus".equals(secondArg)) {
            platypusMax = Math.max(platypusMax, weight);
        } else if ("X1".equals(secondArg)) {
            queryMax = Math.max(queryMax, weight);
        } else {
            othersMax = Math.max(othersMax, weight);
        }
        System.out.println(filled + "\t" + weight);
        total += weight;
    }

    System.out.println();
    System.out.println("query    weight: " + queryMax);
    System.out.println("platypus weight: " + platypusMax);
    System.out.println("others   weight: " + othersMax);

    // Disabled check: assertTrue("query should retain most weight", query > Math.max(platypus, others));
    assertTrue("milk-featured paths should score higher than others", platypusMax > othersMax);
    assertEquals("Total weight of all states should be around 1.0", 1.0, total, 10 * this.apr.epsilon);
    assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
    assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) Feature(edu.cmu.ml.proppr.prove.wam.Feature) InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample) FeatureDictWeighter(edu.cmu.ml.proppr.prove.FeatureDictWeighter) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) State(edu.cmu.ml.proppr.prove.wam.State) Map(java.util.Map) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter) Test(org.junit.Test)

Example 17 with StatusLogger

use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

the class EqualityTest method test.

/**
 * Loads the equality program, solves {@code moral(X)} with a
 * {@link DprProver}, and expects exactly one solved query whose single
 * right-hand-side goal has "bob" as its first argument.
 */
@Test
public void test() throws LogicProgramException, IOException {
    File programFile = new File(EQUALITY_PROGRAM);
    WamProgram wam = WamBaseProgram.load(programFile);
    Prover prover = new DprProver();
    ProofGraph graph = new StateProofGraph(Query.parse("moral(X)"), new APROptions(), wam);

    // solvedQueries() is used here in place of calling prove() and
    // filling the completed states by hand.
    Collection<Query> answers = prover.solvedQueries(graph, new StatusLogger()).keySet();
    assertEquals(1, answers.size());

    Query only = answers.iterator().next();
    assertEquals(1, only.getRhs().length);
    assertEquals("Answer should be bob", "bob", only.getRhs()[0].getArg(0).getName());
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) DprProver(edu.cmu.ml.proppr.prove.DprProver) Prover(edu.cmu.ml.proppr.prove.Prover) DprProver(edu.cmu.ml.proppr.prove.DprProver) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) APROptions(edu.cmu.ml.proppr.util.APROptions) File(java.io.File) Test(org.junit.Test)

Example 18 with StatusLogger

use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

the class Trainer method train.

/**
 * Runs the multithreaded training loop until the stopping criterion is met.
 *
 * For every epoch this method creates a fresh working pool and a fresh
 * single-threaded cleanup pool, submits one Parse task, one Train task and one
 * TraceLosses task per example string, throttles submission when the working
 * pool's backlog grows too large, and finally hands both pools to
 * {@code cleanEpoch}.
 *
 * @param masterFeatures   if non-empty, passed to {@code LearningGraphBuilder.setFeatures}
 * @param examples         iterated once per epoch; each string is given to a Parse task
 * @param builder          passed through to every Parse task
 * @param initialParamVec  passed to {@code masterLearner.setupParams} to obtain the working vector
 * @param numEpochs        passed to the {@code StoppingCriterion} constructor
 * @return the parameter vector returned by {@code masterLearner.setupParams}, which is
 *         also the vector handed to every Train task and to {@code cleanEpoch}
 */
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    if (masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);
    NamedThreadFactory workingThreads = new NamedThreadFactory("work-");
    NamedThreadFactory cleaningThreads = new NamedThreadFactory("cleanup-");
    ThreadPoolExecutor workingPool;
    ExecutorService cleanPool;
    // Accumulator passed to cleanEpoch each epoch and reported in the final log lines.
    TrainingStatistics total = new TrainingStatistics();
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // The dataset-size summary is logged after the first epoch only.
    boolean graphSizesStatusLog = true;
    // Local timer used for the reading statistics; separate from `status` used below.
    StatusLogger stattime = new StatusLogger();
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        // NOTE(review): `status` is not declared in this method (presumably a field) -- confirm.
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        workingThreads.reset();
        cleaningThreads.reset();
        workingPool = new ThreadPoolExecutor(this.nthreads, Integer.MAX_VALUE, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), workingThreads);
        cleanPool = Executors.newSingleThreadExecutor(cleaningThreads);
        // run examples
        int id = 1;
        stattime.start();
        // countdown < 0: not throttling; > 0: still tagging tasks with `notify`; == 0: block until backlog drains.
        int countdown = -1;
        // Passed to each Train task; non-null only while a countdown is in progress.
        Trainer notify = null;
        for (String s : examples) {
            if (log.isDebugEnabled())
                log.debug("Queue size " + (workingPool.getTaskCount() - workingPool.getCompletedTaskCount()));
            statistics.updateReadingStatistics(stattime.sinceLast());
            /*
				 * Throttling behavior:
				 * Once the number of unfinished tasks exceeds 1.5x the number of threads,
				 * we add a 'notify' object to the next nthreads training tasks. Then, the
				 * master thread gathers 'notify' signals until the number of unfinished tasks
				 * is no longer greater than the number of threads. Then we start adding tasks again.
				 *
				 * This works more or less fine, since the master thread stops pulling examples
				 * from disk when there are then a maximum of 2.5x training examples in the queue (that's
				 * the original 1.5x, which could represent a maximum of 1.5x training examples,
				 * plus the nthreads training tasks with active 'notify' objects. There's an
				 * additional nthreads parsing tasks in the queue but those don't take up much
				 * memory so we don't care). This lets us read in a good-sized buffer without
				 * blowing up the heap.
				 *
				 * Worst-case: None of the backlog is cleared before the master thread enters
				 * the synchronized block. nthreads-1 threads will be training long jobs, and
				 * the one free thread works through the 0.5x backlog and all nthreads countdown
				 * examples. The notify() sent by the final countdown example will occur when
				 * there are nthreads unfinished tasks in the queue, and the master thread will exit
				 * the synchronized block and proceed.
				 *
				 * Best-case: The backlog is already cleared by the time the master thread enters
				 * the synchronized block. The while() loop immediately exits, and the notify()
				 * signals from the countdown examples have no effect.
				 */
            if (countdown > 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown);
                countdown--;
            } else if (countdown == 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown + "; throttling:");
                // Drops countdown back to -1 so the backlog check below is re-armed on later iterations.
                countdown--;
                notify = null;
                try {
                    synchronized (this) {
                        if (log.isDebugEnabled())
                            log.debug("Clearing training queue...");
                        // Re-check the backlog after every wake-up; wait() releases this object's monitor while blocked.
                        while (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > this.nthreads) this.wait();
                        if (log.isDebugEnabled())
                            log.debug("Queue cleared.");
                    }
                } catch (InterruptedException e) {
                    // NOTE(review): the interruption is only printed; the interrupt status is not
                    // restored and submission continues with the current example.
                    e.printStackTrace();
                }
            } else if (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > 1.5 * this.nthreads) {
                if (log.isDebugEnabled())
                    log.debug("Starting countdown");
                countdown = this.nthreads;
                notify = this;
            }
            // Chain: Parse result feeds Train, Train result feeds TraceLosses on the cleanup pool.
            Future<PosNegRWExample> parsed = workingPool.submit(new Parse(s, builder, id));
            Future<ExampleStats> trained = workingPool.submit(new Train(parsed, paramVec, id, notify));
            cleanPool.submit(new TraceLosses(trained, id));
            id++;
            stattime.tick();
            if (log.isInfoEnabled() && status.due(1))
                log.info("parsed: " + id + " trained: " + statistics.exampleSetSize);
        }
        cleanEpoch(workingPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // id started at 1 and was incremented once per example, so here it is one greater than
            // the number of examples submitted this epoch.
            // NOTE(review): confirm that dividing by id (rather than id - 1) is intended.
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (statistics.totalGraphSize / id));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading  statistics: min " + total.minReadTime + " / max " + total.maxReadTime + " / total " + total.readTime);
    log.info("Parsing  statistics: min " + total.minParseTime + " / max " + total.maxParseTime + " / total " + total.parseTime);
    log.info("Training statistics: min " + total.minTrainTime + " / max " + total.maxTrainTime + " / total " + total.trainTime);
    return paramVec;
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) StoppingCriterion(edu.cmu.ml.proppr.learn.tools.StoppingCriterion) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) ExecutorService(java.util.concurrent.ExecutorService) SRW(edu.cmu.ml.proppr.learn.SRW) ThreadPoolExecutor(java.util.concurrent.ThreadPoolExecutor)

Example 19 with StatusLogger

use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

the class SparseMatrixIndex method load.

/**
 * Loads one sparse matrix from three sibling files in {@code dir}, all named
 * after {@code functor_arg1type_arg2type}:
 * <ul>
 *   <li>{@code .rce} - three lines giving the number of rows, columns and entries;</li>
 *   <li>{@code .rowOffset} - one integer per line, collected into {@code rowOffsets};</li>
 *   <li>{@code .colIndex} - one column index per line, optionally followed by
 *       {@code WEIGHT_DELIMITER} and a float value (the value defaults to 1.0).</li>
 * </ul>
 * A final sentinel equal to the entry count is appended to {@code rowOffsets}.
 *
 * Each file is closed in a {@code finally} block so that a malformed line
 * (number-format error or an out-of-range column index) does not leak the
 * open file handle.
 *
 * @param dir                        directory containing the three matrix files
 * @param functor_arg1type_arg2type  base name shared by the three files
 * @throws IOException               propagated from opening, reading or closing a file
 * @throws IllegalArgumentException  if a header line of the {@code .rce} file is missing, or a
 *                                   column index is not smaller than {@code arg2.length}
 */
public void load(File dir, String functor_arg1type_arg2type) throws IOException {
    log.info("Loading matrix " + functor_arg1type_arg2type + " from " + dir.getName() + "...");
    this.name = dir + ":" + functor_arg1type_arg2type;
    StatusLogger status = new StatusLogger(LOGUPDATE_MS);
    /* Read the number of rows, columns, and entries - entry is a triple (i,j,m[i,j]) */
    ParsedFile rceFile = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".rce"));
    try {
        Iterator<String> it = rceFile.iterator();
        String line = it.next();
        if (line == null)
            throw new IllegalArgumentException("Bad format for " + functor_arg1type_arg2type + ".rce: line 1 must list #rows");
        this.rows = Integer.parseInt(line.trim());
        line = it.next();
        if (line == null)
            throw new IllegalArgumentException("Bad format for " + functor_arg1type_arg2type + ".rce: line 2 must list #cols");
        this.cols = Integer.parseInt(line.trim());
        line = it.next();
        if (line == null)
            throw new IllegalArgumentException("Bad format for " + functor_arg1type_arg2type + ".rce: line 3 must list #entries");
        this.entries = Integer.parseInt(line.trim());
    } finally {
        // Close even when a header line is missing or not a number.
        rceFile.close();
    }
    /* Data is stored like this: colIndices[] is one long
     * array, and values is a parallel array.  rowsOffsets is another array so that
     * rowOffsets[i] is where the column indices for row i start. Thus
     *
     * for (k=rowOffsets[i]; k<rowOffsets[i+1]; k++) {
     *   j = colIndices[k];
     *   m_ij = values[k];
     *   // this would retrieve i,j and the corresponding value in the sparse matrix m[i,j]
     *   doSomethingWith(i,j,m_ij);
     * }
     *
     */
    ArrayList<Integer> rowsOffsets = new ArrayList<Integer>();
    this.colIndices = new int[entries];
    this.values = new float[entries];
    long start = status.tick();
    ParsedFile rowOffsetFile = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".rowOffset"));
    try {
        for (String line : rowOffsetFile) {
            rowsOffsets.add(Integer.parseInt(line));
            if (log.isInfoEnabled()) {
                if (status.due()) {
                    log.info("rowOffset: " + rowOffsetFile.getLineNumber() + " lines (" + (rowOffsetFile.getLineNumber() / status.since(start)) + " klps)");
                }
            }
        }
    } finally {
        // Close even when a line fails to parse as an integer.
        rowOffsetFile.close();
    }
    start = status.tick();
    ParsedFile colIndexFile = new ParsedFile(new File(dir, functor_arg1type_arg2type + ".colIndex"));
    try {
        for (String line : colIndexFile) {
            // The file's current line number is used directly as the slot in the parallel arrays.
            int ln = colIndexFile.getLineNumber();
            String[] parts = line.split(WEIGHT_DELIMITER);
            colIndices[ln] = Integer.parseInt(parts[0]);
            values[ln] = (float) (parts.length > 1 ? Float.parseFloat(parts[1]) : 1.0);
            if (colIndices[ln] >= arg2.length) {
                throw new IllegalArgumentException("Malformed sparsegraph! For index " + this.name + ", colIndices[" + ln + "]=" + colIndices[ln] + "; arg2.length is only " + arg2.length);
            }
            if (log.isInfoEnabled()) {
                if (status.due()) {
                    log.info("colIndex: " + colIndexFile.getLineNumber() + " lines (" + (colIndexFile.getLineNumber() / status.since(start)) + " klps)");
                }
            }
        }
    } finally {
        // Close even when the graph is rejected as malformed above.
        colIndexFile.close();
    }
    // Copy the boxed offsets into the primitive array and append the end sentinel.
    this.rowOffsets = new int[rowsOffsets.size() + 1];
    for (int i = 0; i < rowsOffsets.size(); i++) {
        rowOffsets[i] = rowsOffsets.get(i);
    }
    rowOffsets[rowsOffsets.size()] = entries;
    long del = status.sinceStart();
    if (del > LOGUPDATE_MS)
        log.info("Finished loading sparse graph matrix " + functor_arg1type_arg2type + " (" + (del / 1000.) + " sec)");
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Iterator(java.util.Iterator) ArrayList(java.util.ArrayList) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File)

Example 20 with StatusLogger

use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

the class PathDprProver method main.

/**
 * Command-line entry point: builds a configuration that accepts a required
 * {@code --query} option, parses that query, and runs a {@link PathDprProver}
 * over the resulting proof graph.
 */
public static void main(String[] args) throws LogicProgramException {
    CustomConfiguration config = new CustomConfiguration(
            args,
            Configuration.USE_PARAMS, // input
            0, // output
            Configuration.USE_WAM | Configuration.USE_SQUASHFUNCTION, // constants
            0 /* modules */) {

        // Raw text of the --query option, captured in retrieveCustomSettings.
        private String queryText;

        @Override
        protected void addCustomOptions(Options options, int[] flags) {
            // The params file is made optional for this tool.
            options.getOption(Configuration.PARAMS_FILE_OPTION).setRequired(false);
            options.addOption(
                    OptionBuilder.withLongOpt("query")
                            .withArgName("functor(arg1,Var1)")
                            .hasArg()
                            .isRequired()
                            .withDescription("specify query to print top paths for")
                            .create());
            // TODO: add prompt option (for large datasets)
        }

        @Override
        protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
            this.queryText = line.getOptionValue("query");
        }

        @Override
        public Object getCustomSetting(String name) {
            // The requested name is ignored; the query text is the only custom setting.
            return this.queryText;
        }
    };

    PathDprProver pathProver = new PathDprProver(config.apr);
    String rawQuery = (String) config.getCustomSetting(null);
    Query parsedQuery = Query.parse(rawQuery);
    StateProofGraph graph = new StateProofGraph(parsedQuery, config.apr, config.program, config.plugins);
    pathProver.prove(graph, new StatusLogger());
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Options(org.apache.commons.cli.Options) APROptions(edu.cmu.ml.proppr.util.APROptions) CommandLine(org.apache.commons.cli.CommandLine) Query(edu.cmu.ml.proppr.prove.wam.Query) CustomConfiguration(edu.cmu.ml.proppr.util.CustomConfiguration) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph)

Aggregations

StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)23 Test (org.junit.Test)13 Query (edu.cmu.ml.proppr.prove.wam.Query)11 StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph)11 File (java.io.File)10 WamProgram (edu.cmu.ml.proppr.prove.wam.WamProgram)9 Prover (edu.cmu.ml.proppr.prove.Prover)8 APROptions (edu.cmu.ml.proppr.util.APROptions)8 ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph)6 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)4 DprProver (edu.cmu.ml.proppr.prove.DprProver)4 State (edu.cmu.ml.proppr.prove.wam.State)4 Map (java.util.Map)4 DfsProver (edu.cmu.ml.proppr.prove.DfsProver)3 ConstantArgument (edu.cmu.ml.proppr.prove.wam.ConstantArgument)3 Goal (edu.cmu.ml.proppr.prove.wam.Goal)3 WamPlugin (edu.cmu.ml.proppr.prove.wam.plugins.WamPlugin)3 ArrayList (java.util.ArrayList)3 GrounderTest (edu.cmu.ml.proppr.GrounderTest)2 InferenceExample (edu.cmu.ml.proppr.examples.InferenceExample)2