Usage of `edu.illinois.cs.cogcomp.sl.util.WeightVector` in the cogcomp-nlp project by CogComp.
Class `SenseManager`, method `getScores`:
/**
 * Scores an instance against every label known to this manager.
 *
 * <p>Each label is scored as the dot product of the model weights with the feature vector of the
 * {@code SenseStructure} built for that label. When {@code rescoreInvalidLabels} is true, labels
 * for which {@code isValidLabel(x, label)} is false are not scored by the model and instead get a
 * fixed raw score of -50. The raw scores are then passed through {@code MathUtilities.softmax}.
 *
 * @param x the instance to score
 * @param rescoreInvalidLabels if true, labels that are invalid for {@code x} receive the fixed low
 *     raw score instead of a model score
 * @return one score per label (indexed by label id), after softmax
 * @throws RuntimeException wrapping the cause if loading the weight vector throws
 * @throws IllegalStateException if the model supplies a null weight vector
 */
public double[] getScores(SenseInstance x, boolean rescoreInvalidLabels) {
    // Raw (pre-softmax) score for invalid labels; presumably chosen so they end up with
    // negligible probability after softmax.
    final double invalidLabelScore = -50;

    int numLabels = this.getNumLabels();
    double[] scores = new double[numLabels];

    WeightVector w;
    try {
        w = this.getModelInfo().getWeights();
    } catch (Exception e) {
        // Pass the exception as the trailing argument so SLF4J logs the full stack trace.
        log.error("Unable to load weight vector", e);
        throw new RuntimeException(e);
    }
    // Explicit check instead of `assert`: assertions are usually disabled in production, where a
    // null vector would otherwise surface later as a NullPointerException in the loop below.
    if (w == null) {
        throw new IllegalStateException("Model returned a null weight vector");
    }

    for (int label = 0; label < numLabels; label++) {
        if (!this.isValidLabel(x, label) && rescoreInvalidLabels) {
            scores[label] = invalidLabelScore;
        } else {
            SenseStructure y = new SenseStructure(x, label, this);
            scores[label] = w.dotProduct(y.getFeatureVector());
        }
    }
    return MathUtilities.softmax(scores);
}
Usage of `edu.illinois.cs.cogcomp.sl.util.WeightVector` in the cogcomp-nlp project by CogComp.
Class `VerbSenseClassifierMain`, method `train`:
/**
 * Trains the verb-sense model and saves the learned weights to the manager's model file.
 *
 * <p>The pruned feature cache file is selected from the hash code of the manifest's included
 * feature set. The structural SVM constant {@code c} is first chosen by
 * {@code JLISLearner.crossvalStructSVMSense} on a problem read with a limit argument of 20000
 * (presumably the first 20000 cached examples — confirm). The final model is then trained with that
 * {@code c} on the full cached problem. The cache file is always closed after each read, even if
 * reading fails.
 *
 * @throws Exception propagated from cache reading, cross-validation, training or saving
 */
@CommandDescription(description = "Trains the verb-sense model.", usage = "train")
public static void train() throws Exception {
    SenseManager manager = getManager(true);
    int numThreads = Runtime.getRuntime().availableProcessors();
    ModelInfo modelInfo = manager.getModelInfo();
    String featureSet = String.valueOf(modelInfo.featureManifest.getIncludedFeatures().hashCode());
    String cacheFile = VerbSenseConfigurator.getPrunedFeatureCacheFile(featureSet, rm);

    // One inference solver per available processor.
    AbstractInferenceSolver[] inference = new AbstractInferenceSolver[numThreads];
    // TODO Can I replace this with ILPInference?
    for (int i = 0; i < inference.length; i++) {
        inference[i] = new MulticlassInference(manager);
    }

    StructuredProblem cvProblem;
    FeatureVectorCacheFile cvCache = new FeatureVectorCacheFile(cacheFile, manager);
    try {
        cvProblem = cvCache.getStructuredProblem(20000);
    } finally {
        cvCache.close();
    }
    // NOTE(review): the meaning of the last argument (4) is defined by JLISLearner — presumably
    // the number of cross-validation folds; confirm.
    LearnerParameters params = JLISLearner.crossvalStructSVMSense(cvProblem, inference, 4);
    double c = params.getcStruct();
    log.info("c = {} after cv", c);

    StructuredProblem problem;
    FeatureVectorCacheFile cache = new FeatureVectorCacheFile(cacheFile, manager);
    try {
        problem = cache.getStructuredProblem();
    } finally {
        cache.close();
    }
    WeightVector w = JLISLearner.trainStructSVM(inference, problem, c);
    JLISLearner.saveWeightVector(w, manager.getModelFileName());
}
Usage of `edu.illinois.cs.cogcomp.sl.util.WeightVector` in the cogcomp-nlp project by CogComp.
Class `CrossValidationHelper`, method `tryParamSerial`:
/**
 * Evaluates one parameter setting by running the cross-validation folds one after another.
 *
 * <p>For each fold id in {@code [0, nFolds)}, {@code foldSplitter} splits {@code train}. The model
 * is trained on the first part of the pair and evaluated on the second part using
 * {@code inference[0]}. The per-fold measures are combined with {@code averager}.
 *
 * @param train the full training dataset to split into folds
 * @param param the learner parameters being evaluated
 * @return the averaged performance over all folds
 * @throws Exception propagated from splitting, training or evaluation
 */
private PerformanceMeasure tryParamSerial(DatasetType train, LearnerParameters param) throws Exception {
    List<PerformanceMeasure> perf = new ArrayList<>();
    for (int foldId = 0; foldId < nFolds; foldId++) {
        Pair<DatasetType, DatasetType> foldData = foldSplitter.getFoldData(train, foldId);
        log.info("Starting fold {} for params {}", foldId, param);
        WeightVector w = trainer.train(foldData.getFirst(), param, inference);
        log.info("Finished training fold {} for params {}", foldId, param);
        PerformanceMeasure p = tester.evaluate(foldData.getSecond(), w, inference[0]);
        // Pass the summary as an argument rather than concatenating it into the format string:
        // a "{}" inside the summary would otherwise be treated as an SLF4J placeholder.
        log.info("Performance for fold {}, params {} ={}", foldId, param, p.summarize());
        perf.add(p);
    }
    return averager.average(perf);
}
Usage of `edu.illinois.cs.cogcomp.sl.util.WeightVector` in the cogcomp-nlp project by CogComp.
Class `WeightVectorUtils`, method `load`:
/**
 * Loads a gzip-compressed weight vector from a file on disk.
 *
 * <p>Expected format, one item per line: the literal header {@code WeightVector}, the vector size
 * as an integer, then any number of {@code index:value} entries. Blank lines are ignored.
 *
 * <p>On any failure (missing file, bad gzip data, bad header, malformed entry) the error and its
 * cause are logged and the JVM is terminated with {@code System.exit(-1)}. The trailing
 * {@code return null} exists only to satisfy the compiler.
 *
 * @param fileName path of the gzip-compressed model file
 * @return the loaded weight vector
 */
public static WeightVector load(String fileName) {
    // Both streams are resources so the file handle is released even if the gzip header is bad.
    try (FileInputStream fileIn = new FileInputStream(fileName);
         BufferedReader reader = new BufferedReader(new InputStreamReader(
                 new GZIPInputStream(fileIn), java.nio.charset.StandardCharsets.UTF_8))) {
        String header = reader.readLine();
        if (header == null || !header.trim().equals("WeightVector")) {
            throw new IOException("Invalid model file.");
        }
        String sizeLine = reader.readLine();
        if (sizeLine == null) {
            throw new IOException("Invalid model file.");
        }
        WeightVector w = new WeightVector(Integer.parseInt(sizeLine.trim()));
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue; // tolerate blank lines instead of failing in Integer.parseInt("")
            }
            String[] parts = line.split(":");
            if (parts.length < 2) {
                throw new IOException("Malformed weight entry: " + line);
            }
            int index = Integer.parseInt(parts[0]);
            float value = Float.parseFloat(parts[1]);
            w.setElement(index, value);
        }
        return w;
    } catch (Exception e) {
        // Trailing throwable argument makes SLF4J log the cause's stack trace.
        log.error("Error loading model file {}", fileName, e);
        System.exit(-1);
    }
    return null;
}
Usage of `edu.illinois.cs.cogcomp.sl.util.WeightVector` in the cogcomp-nlp project by CogComp.
Class `WeightVectorUtils`, method `loadWeightVectorFromClassPath`:
/**
 * Loads a gzip-compressed weight vector from a classpath resource.
 *
 * <p>The resource is located with {@code IOUtils.lsResources(WeightVectorUtils.class, fileName)}.
 * The first match is used. The format is the same as for files on disk, one item per line: the
 * literal header {@code WeightVector}, the vector size as an integer, then {@code index:value}
 * entries. Blank lines are ignored.
 *
 * <p>On any failure (resource not found, bad gzip data, bad header, malformed entry) the error and
 * its cause are logged and the JVM is terminated with {@code System.exit(-1)}. The trailing
 * {@code return null} exists only to satisfy the compiler.
 *
 * @param fileName name of the resource to look up on the classpath
 * @return the loaded weight vector
 */
public static WeightVector loadWeightVectorFromClassPath(String fileName) {
    try {
        List<URL> list = IOUtils.lsResources(WeightVectorUtils.class, fileName);
        if (list.isEmpty()) {
            log.error("File {} not found on the classpath", fileName);
            throw new IOException("File not found on classpath");
        }
        // The raw stream is its own resource so it is closed even if the gzip header is bad.
        try (InputStream stream = list.get(0).openStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(
                     new GZIPInputStream(stream), java.nio.charset.StandardCharsets.UTF_8))) {
            String header = reader.readLine();
            if (header == null || !header.trim().equals("WeightVector")) {
                throw new IOException("Invalid model file.");
            }
            String sizeLine = reader.readLine();
            if (sizeLine == null) {
                throw new IOException("Invalid model file.");
            }
            WeightVector w = new WeightVector(Integer.parseInt(sizeLine.trim()));
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank lines instead of failing in Integer.parseInt("")
                }
                String[] parts = line.split(":");
                if (parts.length < 2) {
                    throw new IOException("Malformed weight entry: " + line);
                }
                int index = Integer.parseInt(parts[0]);
                float value = Float.parseFloat(parts[1]);
                w.setElement(index, value);
            }
            return w;
        }
    } catch (Exception e) {
        // Trailing throwable argument makes SLF4J log the cause's stack trace.
        log.error("Error loading model file {}", fileName, e);
        System.exit(-1);
    }
    return null;
}
Aggregations