use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.
the class RedBlueGraph method setup.
/**
 * Test fixture: builds the red/blue learning graph used by the tests.
 *
 * Steps, in order:
 * 1. create {@code brGraph} through an {@link ArrayLearningGraphBuilder} and
 *    hand the builder a size of {@code magicNumber * 2};
 * 2. link the nodes "b0" and "r0" to each other in both directions;
 * 3. call {@code addColor} for "r" and then for "b";
 * 4. sort the symbols with ids 1..2*magicNumber into {@code reds} / {@code blues};
 * 5. give subclasses a chance to extend the graph via {@code moreSetup}, then freeze it.
 */
@Before
public void setup() {
    // Logging bootstrap, currently disabled:
    // if (!Logger.getRootLogger().getAllAppenders().hasMoreElements()) {
    // BasicConfigurator.configure(); Logger.getRootLogger().setLevel(Level.WARN);
    // }
    LearningGraphBuilder builder = new ArrayLearningGraphBuilder();
    brGraph = (LearningGraph) builder.create(new SimpleSymbolTable<String>());
    builder.index(1);
    builder.setGraphSize(brGraph, magicNumber * 2, -1, -1);
    // brSRWs = new ArrayList<SRW>();
    // Collections.addAll(brSRWs, new L2SqLossSRW(), new L2SqLossSRW(), new L2SqLossSRW());

    // Look up "b0" before "r0" so the lookups happen in the same order as before.
    final int blueRoot = nodes.getId("b0");
    final int redRoot = nodes.getId("r0");

    // Edge b0 -> r0
    HashMap<String, Double> blueToRed = new HashMap<String, Double>();
    blueToRed.put("fromb", 1.0);
    blueToRed.put("tor", 1.0);
    builder.addOutlink(brGraph, blueRoot, makeOutlink(builder, blueToRed, redRoot));

    // Edge r0 -> b0
    HashMap<String, Double> redToBlue = new HashMap<String, Double>();
    redToBlue.put("fromr", 1.0);
    redToBlue.put("tob", 1.0);
    builder.addOutlink(brGraph, redRoot, makeOutlink(builder, redToBlue, blueRoot));

    addColor(builder, brGraph, magicNumber, "r");
    addColor(builder, brGraph, magicNumber, "b");

    // Remember which node names are red and which are blue:
    // anything starting with "b" is blue, everything else is red.
    reds = new TreeSet<String>();
    blues = new TreeSet<String>();
    for (int id = 1; id <= 2 * magicNumber; id++) {
        String symbol = nodes.getSymbol(id);
        (symbol.startsWith("b") ? blues : reds).add(symbol);
    }

    moreSetup(builder);
    builder.freeze(brGraph);
    // System.err.println("\n"+brGraphs.get(0).dump("r0"));
}
use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.
the class Diagnostic method main.
/**
 * Diagnostic driver: parses every line of the grounded query file into a
 * {@code PosNegRWExample} on one thread pool and trains on each parsed
 * example on a second pool, then logs the elapsed wall-clock time.
 *
 * The active variant is the one marked "DiagSrwES"; the other variants are
 * kept below as commented-out code so they can be swapped in by hand.
 *
 * Any {@link Throwable} escaping the body is printed and the JVM exits with -1.
 *
 * @param args command-line arguments, handed to {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, 0.75f, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        // Each pool gets half of the requested threads, but never fewer than one.
        final int poolSize = c.nthreads > 1 ? c.nthreads / 2 : 1;
        final ExecutorService trainerPool = Executors.newFixedThreadPool(poolSize);
        final ExecutorService parserPool = Executors.newFixedThreadPool(poolSize);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {
                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    // Signals "unparseable example" to the trainer task, which skips it.
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser already reported the problem; do not hand null to the SRW.
                            log.warn("Skipping training for #" + id + ": example could not be parsed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // Restore the flag so the pool can observe the interruption.
                        Thread.currentThread().interrupt();
                        e.printStackTrace();
                    } catch (ExecutionException e) {
                        e.printStackTrace();
                    } catch (RuntimeException e) {
                        // The Future returned by submit() is discarded, so without this
                        // catch a training failure would disappear without a trace.
                        log.error("Training failed for #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop:
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads/2, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        @Override
        public void run() {
        try {
        final PosNegRWExample ex = in.get();
        log.debug("Cleanup start "+id);
        trainerPool.submit(new Runnable() {
        @Override
        public void run(){
        log.debug("Training start "+id);
        srw.trainOnExample(params,ex);
        log.debug("Training done "+id);
        }
        });
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        trainerPool.shutdown();
        try {
        trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
        log.error("Interrupted?",e);
        }
        /SrwTtwop */
        /* all diag tasks except SrwO:
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        log.debug("Training start "+id);
        srw.trainOnExample(params,ret);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        /* SrwO:
        Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
        m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()),
        new Transformer<PosNegRWExample,Integer>() {
        @Override
        public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
        return new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Training start "+id);
        srw.trainOnExample(params,in);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return in.length();
        }};
        }}, new Cleanup<Integer>() {
        @Override
        public Runnable cleanup(final Future<Integer> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.
the class GradientFinder method main.
/**
 * Computes the batch gradient over a grounded file and saves it to the
 * configured gradient file.
 *
 * Parameter source, in priority order: train for {@code c.epochs} epochs
 * (when epochs &gt; 0), else load {@code c.initParamsFile}, else load
 * {@code c.paramsFile}, else start from an empty parameter vector.
 *
 * Custom option {@code --relaxFW}: when supplied, the trainer's fixed-weight
 * rules are replaced with an empty set before the gradient is computed.
 *
 * Any {@link Throwable} escaping the body is printed and the JVM exits with -1.
 *
 * @param args command-line arguments, handed to {@code CustomConfiguration}
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {
            /** True only when --relaxFW was given on the command line. */
            boolean relax;

            /** Makes the params and init-params file options optional for this tool. */
            @Override
            protected Option checkOption(Option o) {
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt())) {
                    o.setRequired(false);
                }
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists()) {
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                }
                if (gradientFile == null) {
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                }
                // Options.hasOption() only says whether an option is *defined*; whether the
                // user actually supplied it must be asked of the parsed CommandLine.
                // default to 0 epochs
                if (!line.hasOption("epochs")) {
                    this.epochs = 0;
                }
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name)) {
                    return this.relax;
                }
                return null;
            }
        };
        System.out.println(c.toString());

        ParamVector<String, ?> params = null;
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        // The feature index, when present, sits next to the grounded file.
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }

        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
                    c.initParamsFile, c.epochs);
            if (c.paramsFile != null) {
                ParamsFile.save(params, c.paramsFile, c);
            }
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }

        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        if (Boolean.TRUE.equals(c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }

        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.
the class SRWRestartTest method moreOutlinks.
/**
 * Ensures node {@code u} has an outlink to the node "r0".
 *
 * If one of {@code u}'s existing outlinks already targets "r0" nothing is
 * added; otherwise a single-feature outlink labelled "id(restart)" is
 * appended, weighted with the squashing function's default value.
 *
 * @param lgb   graph builder; must be an {@link ArrayLearningGraphBuilder}
 * @param graph graph receiving the outlink
 * @param u     id of the source node
 */
@Override
public void moreOutlinks(LearningGraphBuilder lgb, LearningGraph graph, int u) {
    final ArrayLearningGraphBuilder arrayBuilder = (ArrayLearningGraphBuilder) lgb;
    final int restartTarget = nodes.getId("r0");

    // A link to r0 may already be present (makeOutlink() can create one).
    boolean alreadyLinked = false;
    if (arrayBuilder.outlinks[u] != null) {
        for (RWOutlink existing : arrayBuilder.outlinks[u]) {
            if (existing.nodeid == restartTarget) {
                alreadyLinked = true;
                break;
            }
        }
    }

    if (!alreadyLinked) {
        // No link to r0 yet: create one carrying only the restart feature.
        int[] featureIds = { lgb.getFeatureLibrary().getId("id(restart)") };
        double[] featureWeights = { this.srw.getSquashingFunction().defaultValue() };
        lgb.addOutlink(graph, u, new RWOutlink(featureIds, featureWeights, restartTarget));
    }
}
use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.
the class Trainer method main.
/**
 * Command-line entry point for training: reads a grounded query file,
 * trains model parameters on it, prints the training time and, when a
 * params file is configured, saves the learned parameters there.
 *
 * The query file name must end with {@code Grounder.GROUNDED_SUFFIX};
 * otherwise an {@link IllegalStateException} is raised (and, like every other
 * failure here, printed before the JVM exits with -1).
 *
 * @param args command-line arguments, handed to {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    try {
        final int inputFiles = Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS;
        final int outputFiles = Configuration.USE_PARAMS;
        final int constants = Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS;
        final int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());

        final String groundedPath = c.queryFile.getPath();
        final String queryName = c.queryFile.getName();
        if (!queryName.endsWith(Grounder.GROUNDED_SUFFIX)) {
            throw new IllegalStateException("Run Grounder on " + queryName + " first. Ground+Train in one go is not supported yet.");
        }

        // Pre-load the feature index, if one sits next to the grounded file.
        SymbolTable<String> featureTable = new SimpleSymbolTable<String>();
        File indexFile = new File(groundedPath + Grounder.FEATURE_INDEX_EXTENSION);
        if (indexFile.exists()) {
            log.info("Reading feature index from " + indexFile.getName() + "...");
            for (String entry : new ParsedFile(indexFile)) {
                featureTable.insert(entry.trim());
            }
        }

        log.info("Training model parameters on " + groundedPath + "...");
        final long startedAt = System.currentTimeMillis();
        ParsedFile examples = new ParsedFile(groundedPath);
        ArrayLearningGraphBuilder graphBuilder = new ArrayLearningGraphBuilder();
        ParamVector<String, ?> learned = c.trainer.train(featureTable, examples, graphBuilder, c.initParamsFile, c.epochs);
        final long elapsedMillis = System.currentTimeMillis() - startedAt;
        System.out.println("Training time: " + elapsedMillis);

        if (c.paramsFile != null) {
            log.info("Saving parameters to " + c.paramsFile + "...");
            ParamsFile.save(learned, c.paramsFile, c);
        }
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Aggregations