Search in sources:

Example 1 with SimpleSymbolTable

Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in project ProPPR by TeamCohen.

In class GradientFinder, method main.

/**
 * Command-line entry point: computes the batch gradient over a grounded file and saves it
 * to the configured gradient file.
 *
 * <p>The parameter vector at which the gradient is evaluated comes from, in order of
 * preference:
 * <ul>
 *   <li>training for {@code epochs} epochs, when that value is positive (the result is
 *       also saved to the params file if one was given);</li>
 *   <li>the init-params file;</li>
 *   <li>the params file;</li>
 *   <li>an empty parameter vector.</li>
 * </ul>
 * When {@code --epochs} is not passed on the command line, epochs defaults to 0, so no
 * training happens.
 *
 * <p>With {@code --relaxFW}, fixed-weight rules are replaced by an empty rule set before the
 * gradient is computed. Any preliminary training above still runs with the configured
 * rules.
 *
 * <p>Any failure prints a stack trace and exits the JVM with status -1.
 *
 * @param args command-line arguments
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            // True when --relaxFW was passed on the command line.
            // NOTE(review): deliberately left without an initializer. If retrieveCustomSettings
            // is invoked from the superclass constructor (it appears to be, since c is used
            // right after construction), an initializer here would run afterwards and
            // overwrite the parsed value.
            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // The params and init-params files are optional here: without them the
                // gradient is evaluated at an empty parameter vector.
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // Query the parsed command line, not the option definitions. Both "epochs" and
                // "relaxFW" are always *defined*, so Options.hasOption would always be true.
                // default to 0 epochs
                if (!line.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());
        ParamVector<String, ?> params = null;
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        // Optional feature index written next to the grounded file; pre-seeds the symbol table.
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
            c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }
        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        if (Boolean.TRUE.equals(c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used: Options(org.apache.commons.cli.Options) CustomConfiguration(edu.cmu.ml.proppr.util.CustomConfiguration) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) CommandLine(org.apache.commons.cli.CommandLine) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) FixedWeightRules(edu.cmu.ml.proppr.learn.tools.FixedWeightRules) Option(org.apache.commons.cli.Option) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 2 with SimpleSymbolTable

Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in project ProPPR by TeamCohen.

In class ProverTestTemplate, method testProveState.

/**
 * Proves {@code isa(elsie,X)} over the milk program with the "milk" feature weighted at 2.
 * Prints every solution, then checks four things:
 * <ul>
 *   <li>the best platypus solution outscores the best non-platypus, non-query solution;</li>
 *   <li>the solution weights sum to roughly 1;</li>
 *   <li>the weighter saw exactly one known feature;</li>
 *   <li>the weighter saw exactly five unknown features.</li>
 * </ul>
 */
@Test
public void testProveState() throws LogicProgramException {
    log.info("testProveState");
    FeatureDictWeighter weighter = new InnerProductWeighter();
    SymbolTable<Feature> features = new SimpleSymbolTable<Feature>();
    // Register "milk" in the table and weight the interned symbol.
    Feature milkFeature = features.getSymbol(features.getId(new Feature("milk")));
    weighter.put(milkFeature, 2);
    prover.setWeighter(weighter);

    InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
    ProofGraph graph = prover.makeProofGraph(example, apr, features, lpMilk, fMilk);
    Map<State, Double> solutions = prover.prove(graph, new StatusLogger());

    // Best weight seen per category of second argument, plus the overall total.
    double queryWeight = 0.0;
    double platypusWeight = 0.0;
    double otherWeight = 0.0;
    double totalWeight = 0.0;
    for (Map.Entry<State, Double> entry : solutions.entrySet()) {
        Query solution = graph.fill(entry.getKey());
        String secondArg = solution.getRhs()[0].getArg(1).getName();
        double weight = entry.getValue();
        if ("platypus".equals(secondArg)) {
            platypusWeight = Math.max(platypusWeight, weight);
        } else if ("X1".equals(secondArg)) {
            // still-unbound variable: the query state itself
            queryWeight = Math.max(queryWeight, weight);
        } else {
            otherWeight = Math.max(otherWeight, weight);
        }
        System.out.println(solution + "\t" + weight);
        totalWeight += weight;
    }

    System.out.println();
    System.out.println("query    weight: " + queryWeight);
    System.out.println("platypus weight: " + platypusWeight);
    System.out.println("others   weight: " + otherWeight);
    // A stronger check (query state keeps more weight than any answer) is intentionally disabled.
    assertTrue("milk-featured paths should score higher than others", platypusWeight > otherWeight);
    assertEquals("Total weight of all states should be around 1.0", 1.0, totalWeight, 10 * this.apr.epsilon);
    assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
    assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Also used: StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) Feature(edu.cmu.ml.proppr.prove.wam.Feature) InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample) FeatureDictWeighter(edu.cmu.ml.proppr.prove.FeatureDictWeighter) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) State(edu.cmu.ml.proppr.prove.wam.State) Map(java.util.Map) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter) Test(org.junit.Test)

Example 3 with SimpleSymbolTable

Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in project ProPPR by TeamCohen.

In class Trainer, method main.

/**
 * Command-line entry point: trains model parameters on an already-grounded query file.
 *
 * <p>The query file must already carry the grounded suffix; grounding and training in
 * one run is rejected. If a feature index exists next to the grounded file, it pre-seeds
 * the feature symbol table. The trained parameters are written to the params file when
 * one is configured.
 *
 * <p>Any failure prints a stack trace and exits the JVM with status -1.
 *
 * @param args command-line arguments, parsed by {@link ModuleConfiguration}
 */
public static void main(String[] args) {
    try {
        ModuleConfiguration config = new ModuleConfiguration(args,
                Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS,
                Configuration.USE_PARAMS,
                Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS,
                Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION);
        log.info(config.toString());

        String groundedPath = config.queryFile.getPath();
        boolean alreadyGrounded = config.queryFile.getName().endsWith(Grounder.GROUNDED_SUFFIX);
        if (!alreadyGrounded) {
            throw new IllegalStateException("Run Grounder on " + config.queryFile.getName() + " first. Ground+Train in one go is not supported yet.");
        }

        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File indexFile = new File(groundedPath + Grounder.FEATURE_INDEX_EXTENSION);
        if (indexFile.exists()) {
            log.info("Reading feature index from " + indexFile.getName() + "...");
            for (String entry : new ParsedFile(indexFile)) {
                masterFeatures.insert(entry.trim());
            }
        }

        log.info("Training model parameters on " + groundedPath + "...");
        long startMillis = System.currentTimeMillis();
        ParamVector<String, ?> trained = config.trainer.train(masterFeatures, new ParsedFile(groundedPath), new ArrayLearningGraphBuilder(), config.initParamsFile, config.epochs);
        long elapsedMillis = System.currentTimeMillis() - startMillis;
        System.out.println("Training time: " + elapsedMillis);

        if (config.paramsFile == null) {
            return;
        }
        log.info("Saving parameters to " + config.paramsFile + "...");
        ParamsFile.save(trained, config.paramsFile, config);
    } catch (Throwable failure) {
        failure.printStackTrace();
        System.exit(-1);
    }
}
Also used: ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Aggregations

SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable): 3; ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder): 2; ParamsFile (edu.cmu.ml.proppr.util.ParamsFile): 2; ParsedFile (edu.cmu.ml.proppr.util.ParsedFile): 2; File (java.io.File): 2; InferenceExample (edu.cmu.ml.proppr.examples.InferenceExample): 1; FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules): 1; FeatureDictWeighter (edu.cmu.ml.proppr.prove.FeatureDictWeighter): 1; InnerProductWeighter (edu.cmu.ml.proppr.prove.InnerProductWeighter): 1; Feature (edu.cmu.ml.proppr.prove.wam.Feature): 1; ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph): 1; Query (edu.cmu.ml.proppr.prove.wam.Query): 1; State (edu.cmu.ml.proppr.prove.wam.State): 1; StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph): 1; CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration): 1; ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration): 1; StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 1; SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector): 1; Map (java.util.Map): 1; CommandLine (org.apache.commons.cli.CommandLine): 1