Usage of `org.dkpro.tc.api.type.TextClassificationOutcome` in the dkpro-tc project (by dkpro).
Class `TestReaderSingleLabel`, method `getNext`:
/**
 * Reads the next document through the superclass, then adds a {@link JCasId} annotation carrying
 * the {@code jcasId} field and a {@link TextClassificationOutcome} whose value is produced by
 * {@code getTextClassificationOutcome(jcas)}.
 *
 * @param aCAS the CAS to fill
 * @throws IOException as propagated from the superclass's {@code getNext}
 * @throws CollectionException as propagated from the superclass, or when the JCas view of
 *         {@code aCAS} cannot be obtained (the underlying {@link CASException} is kept as cause)
 */
@Override
public void getNext(CAS aCAS) throws IOException, CollectionException {
    super.getNext(aCAS);

    JCas jcas;
    try {
        jcas = aCAS.getJCas();
    } catch (CASException e) {
        // preserve the cause so the actual CAS problem is visible in the stack trace
        throw new CollectionException(e);
    }

    JCasId id = new JCasId(jcas);
    id.setId(jcasId);
    id.addToIndexes();

    TextClassificationOutcome outcome = new TextClassificationOutcome(jcas);
    outcome.setOutcome(getTextClassificationOutcome(jcas));
    outcome.addToIndexes();
}
Usage of `org.dkpro.tc.api.type.TextClassificationOutcome` in the dkpro-tc project (by dkpro).
Class `TestTaskUtils`, method `initJCas`:
/**
 * Creates a test JCas whose document text is "a car drives the hedgehogs dies". The JCas contains:
 * <ul>
 * <li>a {@link JCasId} with id 4711 and {@link DocumentMetaData} (title "title", id "4711"),</li>
 * <li>one {@link TextClassificationTarget} per token, with ids 0 to 5,</li>
 * <li>one {@link TextClassificationOutcome} per token whose outcome is the token's POS tag,</li>
 * <li>two {@link TextClassificationSequence}s covering "a car drives" and
 * "the hedgehogs dies".</li>
 * </ul>
 *
 * @param setUnitIdAsPartOfTheInstanceId if true, each target's suffix is set to its token text
 * @return the populated JCas
 * @throws Exception if the no-op engine or its JCas cannot be created
 */
private JCas initJCas(boolean setUnitIdAsPartOfTheInstanceId) throws Exception {
    // the engine is only used here to obtain an empty JCas
    AnalysisEngine engine = AnalysisEngineFactory.createEngine(NoOpAnnotator.class);
    JCas jCas = engine.newJCas();

    JCasId id = new JCasId(jCas);
    id.setId(4711);
    id.addToIndexes();

    DocumentMetaData meta = new DocumentMetaData(jCas);
    meta.setDocumentTitle("title");
    meta.setDocumentId("4711");
    meta.addToIndexes();

    // token text and POS tag; tokens 0-2 form sequence 1, tokens 3-5 form sequence 2
    String[][] tokens = {
            { "a", "DT" },         // sequence 1
            { "car", "NN" },       // sequence 1
            { "drives", "VBZ" },   // sequence 1
            { "the", "DT" },       // sequence 2
            { "hedgehogs", "NN" }, // sequence 2
            { "dies", "VBZ" } };   // sequence 2
    final int lastTokenOfSeq1 = 2;

    // tokens are joined by single spaces; offsets are taken from the builder's current length
    StringBuilder sb = new StringBuilder();
    int endOfSeq1 = 0;
    for (int i = 0; i < tokens.length; i++) {
        int start = sb.length();
        int end = start + tokens[i][0].length();

        TextClassificationTarget unit = new TextClassificationTarget(jCas, start, end);
        if (setUnitIdAsPartOfTheInstanceId) {
            unit.setSuffix(tokens[i][0]);
        }
        unit.setId(i);
        unit.addToIndexes();

        TextClassificationOutcome outcome = new TextClassificationOutcome(jCas, start, end);
        outcome.setOutcome(tokens[i][1]);
        outcome.addToIndexes();

        if (i == lastTokenOfSeq1) {
            endOfSeq1 = end;
        }

        sb.append(tokens[i][0]);
        if (i + 1 < tokens.length) {
            sb.append(" ");
        }
    }
    String text = sb.toString();
    jCas.setDocumentText(text);

    TextClassificationSequence seq1 = new TextClassificationSequence(jCas, 0, endOfSeq1);
    seq1.addToIndexes();
    // sequence 2 starts after the single space that follows sequence 1
    TextClassificationSequence seq2 = new TextClassificationSequence(jCas, endOfSeq1 + 1, text.length());
    seq2.addToIndexes();
    return jCas;
}
Usage of `org.dkpro.tc.api.type.TextClassificationOutcome` in the dkpro-tc project (by dkpro).
Class `TcAnnotator`, method `callConversionEngine`:
/**
 * Runs the annotator class named by {@code conversionAnnotator[0]} on the given CAS. The
 * remaining entries of {@code conversionAnnotator} are passed as configuration data to
 * {@code AnalysisEngineFactory.createEngine}. Afterwards, every
 * {@link TextClassificationOutcome} is removed from the CAS indexes.
 *
 * @param aJCas the CAS to convert
 * @throws AnalysisEngineProcessException if the annotator class cannot be loaded, is not an
 *         {@link AnalysisComponent}, cannot be instantiated, or fails while processing
 */
private void callConversionEngine(JCas aJCas) throws AnalysisEngineProcessException {
    String name = conversionAnnotator[0];
    // everything after the class name is forwarded as configuration data
    // (presumably alternating parameter names and values, as uimaFIT expects)
    Object[] parameters = java.util.Arrays.copyOfRange(conversionAnnotator, 1, conversionAnnotator.length);

    AnalysisEngine conversionEngine = null;
    try {
        // asSubclass verifies the type at runtime instead of an unchecked cast
        Class<? extends AnalysisComponent> componentClass = Class.forName(name).asSubclass(AnalysisComponent.class);
        conversionEngine = AnalysisEngineFactory.createEngine(componentClass, parameters);
        conversionEngine.process(aJCas);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    } finally {
        // a new engine is created on every call, so release its resources right away
        if (conversionEngine != null) {
            conversionEngine.destroy();
        }
    }

    // copy first: removing feature structures while iterating a live index view is unsafe
    java.util.List<TextClassificationOutcome> outcomes = new java.util.ArrayList<TextClassificationOutcome>(
            JCasUtil.select(aJCas, TextClassificationOutcome.class));
    for (TextClassificationOutcome o : outcomes) {
        o.removeFromIndexes();
    }
}
Usage of `org.dkpro.tc.api.type.TextClassificationOutcome` in the dkpro-tc project (by dkpro).
Class `TcAnnotator`, method `addTCUnitAndOutcomeAnnotation`:
/**
 * For every annotation of the type named by {@code nameUnit}, adds a
 * {@link TextClassificationTarget} and a {@link TextClassificationOutcome} with the same span.
 * Each outcome is set to {@code Constants.TC_OUTCOME_DUMMY_VALUE}.
 *
 * @param aJCas the CAS to annotate
 * @throws IllegalStateException if {@code nameUnit} does not name a type in the CAS's type system
 */
private void addTCUnitAndOutcomeAnnotation(JCas aJCas) {
    Type type = aJCas.getCas().getTypeSystem().getType(nameUnit);
    if (type == null) {
        // getType returns null for unknown names; fail with a clear message instead of deep inside select()
        throw new IllegalStateException("Unit type [" + nameUnit + "] is not declared in the type system of the CAS");
    }

    // snapshot the units: the loop adds annotations, which must not disturb iteration over a
    // live index view (e.g. if nameUnit names a supertype of the TC annotation types)
    Collection<AnnotationFS> units = new java.util.ArrayList<AnnotationFS>(CasUtil.select(aJCas.getCas(), type));
    for (AnnotationFS unit : units) {
        TextClassificationTarget target = new TextClassificationTarget(aJCas, unit.getBegin(), unit.getEnd());
        target.addToIndexes();

        TextClassificationOutcome outcome = new TextClassificationOutcome(aJCas, unit.getBegin(), unit.getEnd());
        outcome.setOutcome(Constants.TC_OUTCOME_DUMMY_VALUE);
        outcome.addToIndexes();
    }
}
Usage of `org.dkpro.tc.api.type.TextClassificationOutcome` in the dkpro-tc project (by dkpro).
Class `WekaLoadModelConnector`, method `process`:
/**
 * Classifies the CAS with the loaded model and writes the prediction back into the CAS as
 * {@link TextClassificationOutcome} annotation(s).
 *
 * <p>Single-label and regression: the value of the outcome annotation returned by
 * {@code getOutcome(jcas)} is overwritten with the predicted label (or the regression value).
 *
 * <p>Multi-label: every label whose predicted score is at least {@code bipartitionThreshold} is
 * kept. Existing outcome annotations are removed: all of them in document or pair mode, otherwise
 * the one returned by {@code getOutcome(jcas)}. One new outcome annotation per kept label is
 * then added.
 *
 * @param jcas the CAS to classify
 * @throws AnalysisEngineProcessException if feature extraction, instance conversion or
 *         classification fails
 */
@Override
public void process(JCas jcas) throws AnalysisEngineProcessException {
    Instance instance = extractFirstInstance(jcas);

    // constant-first comparison avoids an NPE if learningMode is unset
    boolean isMultiLabel = Constants.LM_MULTI_LABEL.equals(learningMode);
    boolean isRegression = Constants.LM_REGRESSION.equals(learningMode);

    if (isMultiLabel) {
        replaceOutcomeAnnotations(jcas, predictMultiLabel(instance));
    } else {
        String prediction = predictSingleLabel(instance, isRegression);
        getOutcome(jcas).setOutcome(prediction);
    }
}

/** Runs feature extraction on the CAS and returns the first extracted instance. */
private Instance extractFirstInstance(JCas jcas) throws AnalysisEngineProcessException {
    try {
        InstanceExtractor extractor = new InstanceExtractor(featureMode, featureExtractors, false);
        List<Instance> instances = extractor.getInstances(jcas, useSparse);
        if (instances.isEmpty()) {
            throw new IllegalStateException("Feature extraction produced no instance for the CAS");
        }
        return instances.get(0);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
}

/**
 * Converts the TC instance to a Weka instance and classifies it. Returns the predicted class label,
 * or the predicted numeric value (as a string) when {@code isRegression} is true.
 */
private String predictSingleLabel(Instance instance, boolean isRegression)
        throws AnalysisEngineProcessException {
    Object val;
    try {
        weka.core.Instance wekaInstance = WekaUtils.tcInstanceToWekaInstance(instance, trainingData,
                classLabels, isRegression);
        double prediction = cls.classifyInstance(wekaInstance);
        if (isRegression) {
            val = prediction;
        } else {
            // for classification the returned value is used as an index into classLabels
            val = classLabels.get((int) prediction);
        }
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
    return val.toString();
}

/**
 * Converts the TC instance to a Meka instance and returns the names of all labels whose predicted
 * score is at least {@code bipartitionThreshold}.
 */
private List<String> predictMultiLabel(Instance instance) throws AnalysisEngineProcessException {
    weka.core.Instance mekaInstance;
    double[] scores;
    try {
        mekaInstance = WekaUtils.tcInstanceToMekaInstance(instance, trainingData, classLabels);
        scores = cls.distributionForInstance(mekaInstance);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }

    // parse once instead of on every loop iteration
    double threshold = Double.valueOf(bipartitionThreshold);
    List<String> labels = new ArrayList<String>();
    for (int i = 0; i < scores.length; i++) {
        if (scores[i] >= threshold) {
            // assumes score i belongs to attribute i, i.e. the class attributes come first
            labels.add(stripClassAttributePrefix(mekaInstance.attribute(i).name()));
        }
    }
    return labels;
}

/**
 * Returns the part of a class attribute name after {@code WekaDataWriter.CLASS_ATTRIBUTE_PREFIX},
 * or the name unchanged if the prefix does not occur. Unlike {@code String.split}, this does not
 * interpret the prefix as a regex and does not truncate labels that themselves contain the prefix.
 */
private static String stripClassAttributePrefix(String attributeName) {
    String prefix = WekaDataWriter.CLASS_ATTRIBUTE_PREFIX;
    int pos = attributeName.indexOf(prefix);
    if (pos < 0) {
        return attributeName;
    }
    return attributeName.substring(pos + prefix.length());
}

/**
 * Removes the existing outcome annotation(s) and adds one new {@link TextClassificationOutcome}
 * per label. In document or pair mode all outcomes are removed; otherwise only the one returned
 * by {@code getOutcome(jcas)} is removed.
 */
private void replaceOutcomeAnnotations(JCas jcas, List<String> labels) {
    if (FM_DOCUMENT.equals(featureMode) || FM_PAIR.equals(featureMode)) {
        // copy first: removing while iterating the live index view is unsafe
        List<TextClassificationOutcome> oldOutcomes = new ArrayList<TextClassificationOutcome>(
                JCasUtil.select(jcas, TextClassificationOutcome.class));
        for (TextClassificationOutcome oldOutcome : oldOutcomes) {
            oldOutcome.removeFromIndexes();
        }
    } else {
        getOutcome(jcas).removeFromIndexes();
    }

    // NOTE(review): new outcomes are created without begin/end offsets, also in unit mode where
    // the removed outcome may have covered a unit span - confirm this is intended
    for (String label : labels) {
        TextClassificationOutcome newOutcome = new TextClassificationOutcome(jcas);
        newOutcome.setOutcome(label);
        newOutcome.addToIndexes();
    }
}
Aggregations