Use of org.vitrivr.cineast.core.data.tag.WeightedTag in the project cineast by vitrivr.
From the class SegmentTags, method getSimilar.
/**
 * Finds segments similar to an already stored segment, using that segment's stored tags as the
 * query.
 *
 * <p>The rows whose {@code id} column equals {@code segmentId} are read from the selector. Each
 * row becomes a query tag whose id is the row's {@code tagid} and whose weight is the row's
 * {@code score}; name and description are left empty. The actual scoring is delegated to the
 * tag-based overload.
 *
 * @param segmentId id of the segment whose stored tags serve as the query
 * @param qc query configuration, handed on unchanged to the tag-based overload
 * @return the scored segments, or an empty list if no rows exist for {@code segmentId}
 */
@Override
public List<ScoreElement> getSimilar(String segmentId, ReadableQueryConfig qc) {
  final List<Map<String, PrimitiveTypeProvider>> storedRows =
      this.selector.getRows("id", new StringTypeProvider(segmentId));
  if (storedRows.isEmpty()) {
    return Collections.emptyList();
  }

  /* Turn every stored (tagid, score) pair into a weighted query tag. */
  final List<WeightedTag> queryTags = new ArrayList<>(storedRows.size());
  storedRows.forEach(row -> {
    final String tagId = row.get("tagid").getString();
    final float weight = row.get("score").getFloat();
    queryTags.add(new IncompleteTag(tagId, "", "", weight));
  });
  return getSimilar(queryTags, qc);
}
Use of org.vitrivr.cineast.core.data.tag.WeightedTag in the project cineast by vitrivr.
From the class SegmentTags, method getSimilar.
/**
 * Finds segments similar to the given segment container, using the tags the container carries
 * as the query.
 *
 * <p>Tags that already are {@link WeightedTag}s are used as they are; every other tag is wrapped
 * in an {@link IncompleteTag}. The actual scoring is delegated to the tag-based overload.
 *
 * @param sc container providing the query tags via {@code getTags()}
 * @param qc query configuration, handed on unchanged to the tag-based overload
 * @return the scored segments, or an empty list if the container has no tags
 */
@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {
  final List<Tag> containerTags = sc.getTags();
  if (containerTags.isEmpty()) {
    return Collections.emptyList();
  }

  final List<WeightedTag> queryTags = new ArrayList<>(containerTags.size());
  for (Tag tag : containerTags) {
    /* Keep weighted tags untouched, wrap anything else. */
    queryTags.add(tag instanceof WeightedTag ? (WeightedTag) tag : new IncompleteTag(tag));
  }
  return getSimilar(queryTags, qc);
}
Use of org.vitrivr.cineast.core.data.tag.WeightedTag in the project cineast by vitrivr.
From the class SegmentTags, method getSimilar.
/**
 * Scores segments by how well their stored tags match the given weighted query tags.
 *
 * <p>All rows whose {@code tagid} is among the query tag ids are fetched from the selector. The
 * score of a row is its stored {@code score} multiplied by the weight of the matching query tag,
 * capped at 1. Per segment, the highest score for each tag is kept; these maxima are summed and
 * divided by the sum of the query tag weights (each weight capped at 1).
 *
 * <p>Every distinct tag id contributes to that normalizer exactly once. The per-segment sum also
 * counts every tag id once, so a query that lists the same id several times must not inflate the
 * normalizer.
 *
 * @param tags weighted query tags; if an id occurs more than once, the weight seen last is used
 * @param qc query configuration, may be {@code null}; if it reports relevant segment ids, rows of
 *     all other segments are skipped
 * @return one score element per matching segment, or an empty list if there are no query tags or
 *     the summed weight is not positive
 */
private List<ScoreElement> getSimilar(Iterable<WeightedTag> tags, ReadableQueryConfig qc) {
  /* Distinct tag ids in first-seen order, and the weight applied to each of them. */
  List<String> tagids = new ArrayList<>();
  TObjectFloatHashMap<String> tagWeights = new TObjectFloatHashMap<>();

  for (WeightedTag wt : tags) {
    String id = wt.getId();
    float weight = wt.getWeight();
    if (weight > 1) {
      LOGGER.error("Weight is > 1 -- this makes little sense.");
    }
    if (!tagWeights.containsKey(id)) {
      tagids.add(id);
    }
    tagWeights.put(id, weight);
  }

  /* Sum the weights of the distinct tags for normalization at a later point. */
  float weightSum = 0f;
  for (String id : tagids) {
    weightSum += Math.min(1, tagWeights.get(id));
  }

  if (tagids.isEmpty() || weightSum <= 0f) {
    return Collections.emptyList();
  }

  /* Retrieve all elements matching the provided ids */
  List<Map<String, PrimitiveTypeProvider>> rows = this.selector.getRows("tagid", tagids.stream().map(StringTypeProvider::new).collect(Collectors.toList()));

  /* Prepare the set of relevant ids (if this entity is used for filtering at a later stage) */
  Set<String> relevant = null;
  if (qc != null && qc.hasRelevantSegmentIds()) {
    relevant = qc.getRelevantSegmentIds();
  }

  /* segment id -> (tag id -> highest weighted score seen for that tag) */
  Map<String, TObjectFloatHashMap<String>> maxScoreByTag = new HashMap<>();

  /* Iterate over all matches */
  for (Map<String, PrimitiveTypeProvider> row : rows) {
    String segmentId = row.get("id").getString();

    /* Skip segments which are not desired by the query-config */
    if (relevant != null && !relevant.contains(segmentId)) {
      continue;
    }

    String tagid = row.get("tagid").getString();
    float score = row.get("score").getFloat() * (tagWeights.containsKey(tagid) ? tagWeights.get(tagid) : 0f);
    if (score > 1) {
      LOGGER.warn("Score is larger than 1 - this makes little sense");
      score = 1f;
    }

    /*
     * Update maximum score by tag. Presence is tested with containsKey rather than by comparing
     * the looked-up value against the map's no-entry value, so that a stored score which happens
     * to equal that value is not mistaken for a missing entry.
     */
    TObjectFloatHashMap<String> segmentScores = maxScoreByTag.computeIfAbsent(segmentId, key -> new TObjectFloatHashMap<>());
    if (segmentScores.containsKey(tagid)) {
      segmentScores.put(tagid, Math.max(score, segmentScores.get(tagid)));
    } else {
      segmentScores.put(tagid, score);
    }
  }

  /* per segment, the max score for all tags is summed and divided by the normalizer */
  final float normalizer = weightSum;
  List<ScoreElement> result = new ArrayList<>(maxScoreByTag.size());
  for (Map.Entry<String, TObjectFloatHashMap<String>> entry : maxScoreByTag.entrySet()) {
    result.add(new SegmentScoreElement(entry.getKey(), MathHelper.sum(entry.getValue().values()) / normalizer));
  }
  return result;
}
Aggregations