Search in sources :

Example 1 with FixedWeightRules

Use of edu.cmu.ml.proppr.learn.tools.FixedWeightRules in the project ProPPR by TeamCohen.

From the class GradientFinder, the method main:

/**
 * Entry point: computes a batch gradient over a grounded example file.
 *
 * <p>Flow: parse CLI options, optionally train for {@code --epochs} epochs to
 * obtain a parameter vector (otherwise load one from {@code --initParams} /
 * {@code --params}, or start from an empty vector), then compute and save the
 * gradient to {@code --gradient}. Any failure prints a stack trace and exits
 * with status -1.
 *
 * @param args command-line arguments, parsed by {@link CustomConfiguration}
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            // Whether --relaxFW was supplied; exposed through getCustomSetting("relaxFW").
            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // Params files are optional for gradient computation (we can start from empty).
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // BUGFIX: query the parsed CommandLine, not the Options definition set.
                // Options.hasOption() reports whether an option is *declared* (always true
                // here), so the previous code never applied the 0-epoch default and always
                // set relax=true regardless of the actual command line.
                // default to 0 epochs
                if (!line.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());
        ParamVector<String, ?> params = null;
        // Pre-load the feature index (if the grounder wrote one) so feature ids are
        // consistent across training and gradient computation.
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
            c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }
        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        // Boolean.TRUE.equals() avoids an NPE from unboxing if the setting were ever absent.
        if (Boolean.TRUE.equals(c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : Options(org.apache.commons.cli.Options) CustomConfiguration(edu.cmu.ml.proppr.util.CustomConfiguration) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) CommandLine(org.apache.commons.cli.CommandLine) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) FixedWeightRules(edu.cmu.ml.proppr.learn.tools.FixedWeightRules) Option(org.apache.commons.cli.Option) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 2 with FixedWeightRules

use of edu.cmu.ml.proppr.learn.tools.FixedWeightRules in project ProPPR by TeamCohen.

From the class Configuration, the method retrieveSettings:

/**
 * Populates this Configuration's fields from the parsed command line.
 *
 * <p>Processing order matters: input files (which must already exist) are read
 * first, then output/intermediate files (which may not exist yet, and which
 * overwrite the same fields if a flag appears in both sections), then scalar
 * constants. Finally, if any WAM program files were specified, they are loaded.
 *
 * @param line     the parsed command line
 * @param allFlags packed flag words selecting which option groups are active
 * @param options  the option definitions (used only for usage/help output)
 * @throws IOException if loading the program files fails
 */
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    int flags;
    // --help short-circuits into the usage message (usageOptions does not return normally).
    if (line.hasOption("help"))
        usageOptions(options, allFlags);
    // input files: must exist already
    flags = inputFiles(allFlags);
    if (isOn(flags, USE_QUERIES) && line.hasOption(QUERIES_FILE_OPTION))
        this.queryFile = getExistingFile(line.getOptionValue(QUERIES_FILE_OPTION));
    if (isOn(flags, USE_GROUNDED) && line.hasOption(GROUNDED_FILE_OPTION))
        this.groundedFile = getExistingFile(line.getOptionValue(GROUNDED_FILE_OPTION));
    if (isOn(flags, USE_ANSWERS) && line.hasOption(SOLUTIONS_FILE_OPTION))
        this.solutionsFile = getExistingFile(line.getOptionValue(SOLUTIONS_FILE_OPTION));
    if (isOn(flags, USE_TEST) && line.hasOption(TEST_FILE_OPTION))
        this.testFile = getExistingFile(line.getOptionValue(TEST_FILE_OPTION));
    // NOTE: TRAIN and QUERIES both assign this.queryFile — the train file is stored
    // in the same field as the query file, so the two flags are mutually exclusive in practice.
    if (isOn(flags, USE_TRAIN) && line.hasOption(TRAIN_FILE_OPTION))
        this.queryFile = getExistingFile(line.getOptionValue(TRAIN_FILE_OPTION));
    if (isOn(flags, USE_PARAMS) && line.hasOption(PARAMS_FILE_OPTION))
        this.paramsFile = getExistingFile(line.getOptionValue(PARAMS_FILE_OPTION));
    if (isOn(flags, USE_INIT_PARAMS) && line.hasOption(INIT_PARAMS_FILE_OPTION))
        this.initParamsFile = getExistingFile(line.getOptionValue(INIT_PARAMS_FILE_OPTION));
    if (isOn(flags, USE_GRADIENT) && line.hasOption(GRADIENT_FILE_OPTION))
        this.gradientFile = getExistingFile(line.getOptionValue(GRADIENT_FILE_OPTION));
    // output & intermediate files: may not exist yet
    // (same fields as above, but wrapped in plain File objects without an existence check)
    flags = outputFiles(allFlags);
    if (isOn(flags, USE_QUERIES) && line.hasOption(QUERIES_FILE_OPTION))
        this.queryFile = new File(line.getOptionValue(QUERIES_FILE_OPTION));
    if (isOn(flags, USE_GROUNDED) && line.hasOption(GROUNDED_FILE_OPTION))
        this.groundedFile = new File(line.getOptionValue(GROUNDED_FILE_OPTION));
    if (isOn(flags, USE_ANSWERS) && line.hasOption(SOLUTIONS_FILE_OPTION))
        this.solutionsFile = new File(line.getOptionValue(SOLUTIONS_FILE_OPTION));
    if (isOn(flags, USE_TEST) && line.hasOption(TEST_FILE_OPTION))
        this.testFile = new File(line.getOptionValue(TEST_FILE_OPTION));
    if (isOn(flags, USE_TRAIN) && line.hasOption(TRAIN_FILE_OPTION))
        this.queryFile = new File(line.getOptionValue(TRAIN_FILE_OPTION));
    if (isOn(flags, USE_PARAMS) && line.hasOption(PARAMS_FILE_OPTION))
        this.paramsFile = new File(line.getOptionValue(PARAMS_FILE_OPTION));
    if (isOn(flags, USE_GRADIENT) && line.hasOption(GRADIENT_FILE_OPTION))
        this.gradientFile = new File(line.getOptionValue(GRADIENT_FILE_OPTION));
    // constants
    flags = constants(allFlags);
    if (isOn(flags, USE_WAM)) {
        if (line.hasOption(PROGRAMFILES_CONST_OPTION))
            this.programFiles = line.getOptionValues(PROGRAMFILES_CONST_OPTION);
        if (line.hasOption(TERNARYINDEX_CONST_OPTION))
            this.ternaryIndex = Boolean.parseBoolean(line.getOptionValue(TERNARYINDEX_CONST_OPTION));
        if (line.hasOption(PRUNEDPREDICATE_CONST_OPTION)) {
            this.prunedPredicateRules = new FixedWeightRules(line.getOptionValues(PRUNEDPREDICATE_CONST_OPTION));
        }
    }
    if (anyOn(flags, USE_APR))
        if (line.hasOption(APR_CONST_OPTION))
            this.apr = new APROptions(line.getOptionValues(APR_CONST_OPTION));
    if (isOn(flags, USE_THREADS) && line.hasOption(THREADS_CONST_OPTION))
        this.nthreads = Integer.parseInt(line.getOptionValue(THREADS_CONST_OPTION));
    if (isOn(flags, USE_EPOCHS) && line.hasOption(EPOCHS_CONST_OPTION))
        this.epochs = Integer.parseInt(line.getOptionValue(EPOCHS_CONST_OPTION));
    if (isOn(flags, USE_FORCE) && line.hasOption(FORCE_CONST_OPTION))
        this.force = true;
    if (isOn(flags, USE_ORDER) && line.hasOption(ORDER_CONST_OPTION)) {
        String order = line.getOptionValue(ORDER_CONST_OPTION);
        // "same" and "maintain" are accepted synonyms; any other value disables order maintenance
        if (order.equals("same") || order.equals("maintain"))
            this.maintainOrder = true;
        else
            this.maintainOrder = false;
    }
    // parsed as double first so values like "1e6" are accepted, then truncated to int
    if (anyOn(flags, USE_DUPCHECK | USE_WAM) && line.hasOption(DUPCHECK_CONST_OPTION))
        this.duplicates = (int) Double.parseDouble(line.getOptionValue(DUPCHECK_CONST_OPTION));
    if (isOn(flags, USE_THROTTLE) && line.hasOption(THROTTLE_CONST_OPTION))
        this.throttle = Integer.parseInt(line.getOptionValue(THROTTLE_CONST_OPTION));
    if (isOn(flags, USE_EMPTYGRAPHS) && line.hasOption(EMPTYGRAPHS_CONST_OPTION))
        this.includeEmptyGraphs = true;
    if (isOn(flags, USE_FIXEDWEIGHTS) && line.hasOption(FIXEDWEIGHTS_CONST_OPTION))
        this.fixedWeightRules = new FixedWeightRules(line.getOptionValues(FIXEDWEIGHTS_CONST_OPTION));
    if (anyOn(flags, USE_SMART_COUNTFEATURES)) {
        if (line.hasOption(COUNTFEATURES_CONST_OPTION))
            this.countFeatures = Boolean.parseBoolean(line.getOptionValue(COUNTFEATURES_CONST_OPTION));
        else if (this.nthreads > 20) {
            // High thread counts contend in FeatureDictWeighter unless feature counting is disabled.
            log.warn("Large numbers of threads (>20, so " + this.nthreads + " qualifies) can cause a bottleneck in FeatureDictWeighter. If you're " + "seeing lower system loads than expected and you're sure your examples/query/param files are correct, you can reduce contention & increase speed performance by adding " + "'--" + COUNTFEATURES_CONST_OPTION + " false' to your command line.");
        }
    }
    if (this.programFiles != null)
        this.loadProgramFiles(line, allFlags, options);
}
Also used : FixedWeightRules(edu.cmu.ml.proppr.learn.tools.FixedWeightRules)

Aggregations

FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules)2 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)1 CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration)1 ParamsFile (edu.cmu.ml.proppr.util.ParamsFile)1 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)1 SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable)1 SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)1 File (java.io.File)1 CommandLine (org.apache.commons.cli.CommandLine)1 Option (org.apache.commons.cli.Option)1 Options (org.apache.commons.cli.Options)1