Usage of `hex.genmodel.algos.word2vec.WordEmbeddingModel` in the h2o-3 project (by h2oai).
Found in class `EasyPredictModelWrapper`, method `predictWord2Vec`.
/**
 * Looks up word embeddings for the words contained in a row.
 * <p>
 * Every entry of {@code data} whose value is a {@code String} is treated as a word and translated
 * to its embedding. The embedding is stored in the result under the same key as the input entry.
 * Entries whose values are not Strings are ignored.
 *
 * @param data RowData structure; every key with a String value will be translated to an embedding
 * @return a {@link Word2VecPrediction} whose {@code wordEmbeddings} map holds one entry per
 *         String-valued input key
 * @throws PredictException if the model is not a WordEmbedding model
 */
public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
  validateModelCategory(ModelCategory.WordEmbedding);
  // The category check alone does not guarantee the concrete type needed for the cast below.
  if (!(m instanceof WordEmbeddingModel)) {
    throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
  }
  final WordEmbeddingModel weModel = (WordEmbeddingModel) m;
  final int vecSize = weModel.getVecSize();
  final HashMap<String, float[]> embeddings = new HashMap<>(data.size());
  // Iterate entries directly to avoid a second hash lookup per key.
  for (java.util.Map.Entry<String, Object> entry : data.entrySet()) {
    final Object value = entry.getValue();
    if (value instanceof String) {
      // Allocate a fresh output buffer per word: the array returned by transform0 is kept in the
      // result map, so buffers must not be shared between words.
      // NOTE(review): transform0 may return null for out-of-vocabulary words depending on the model
      // implementation — confirm; such words would then map to null here.
      embeddings.put(entry.getKey(), weModel.transform0((String) value, new float[vecSize]));
    }
  }
  final Word2VecPrediction p = new Word2VecPrediction();
  p.wordEmbeddings = embeddings;
  return p;
}
Aggregations