Example 1 with WordEmbeddingModel

Use of hex.genmodel.algos.word2vec.WordEmbeddingModel in project h2o-3 by h2oai.

From class EasyPredictModelWrapper, method predictWord2Vec.

/**
 * Look up word embeddings for a given word (or set of words).
 * @param data RowData structure; every key with a String value will be translated to an embedding
 * @return A Word2VecPrediction whose wordEmbeddings map each such key to its embedding vector
 * @throws PredictException if the model is not a WordEmbedding model
 */
public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
    // Fail fast unless the wrapped model is a word-embedding model.
    validateModelCategory(ModelCategory.WordEmbedding);
    if (!(m instanceof WordEmbeddingModel))
        throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
    final WordEmbeddingModel weModel = (WordEmbeddingModel) m;
    final int vecSize = weModel.getVecSize();
    HashMap<String, float[]> embeddings = new HashMap<>(data.size());
    for (String wordKey : data.keySet()) {
        Object value = data.get(wordKey);
        // Only String values are translated; keys with other value types are skipped.
        if (value instanceof String) {
            String word = (String) value;
            // transform0 fills a fresh buffer of length vecSize with the word's embedding.
            embeddings.put(wordKey, weModel.transform0(word, new float[vecSize]));
        }
    }
    Word2VecPrediction p = new Word2VecPrediction();
    p.wordEmbeddings = embeddings;
    return p;
}
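To see the lookup logic in isolation, here is a minimal, self-contained sketch that mimics predictWord2Vec without the h2o-genmodel dependency. ToyEmbeddingModel and EmbeddingLookup are hypothetical names invented for illustration; a plain Map<String, Object> stands in for RowData, and having transform0 return null for out-of-vocabulary words is an assumption of this sketch, not something the snippet above confirms.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for WordEmbeddingModel: a fixed vocabulary of vectors.
class ToyEmbeddingModel {
    private final Map<String, float[]> vocab = new HashMap<>();
    private final int vecSize;

    ToyEmbeddingModel(int vecSize) { this.vecSize = vecSize; }

    // Registers a word; every vector must have exactly vecSize components.
    void add(String word, float[] vec) {
        if (vec.length != vecSize) throw new IllegalArgumentException("wrong vector size");
        vocab.put(word, vec.clone());
    }

    int getVecSize() { return vecSize; }

    // Copies the word's vector into output and returns it;
    // returns null for out-of-vocabulary words (assumed behavior in this sketch).
    float[] transform0(String word, float[] output) {
        float[] vec = vocab.get(word);
        if (vec == null) return null;
        System.arraycopy(vec, 0, output, 0, vecSize);
        return output;
    }
}

class EmbeddingLookup {
    // Mirrors predictWord2Vec: each key whose value is a String gets an embedding
    // (possibly null if unknown); keys with non-String values are skipped.
    static Map<String, float[]> lookup(ToyEmbeddingModel model, Map<String, Object> row) {
        Map<String, float[]> embeddings = new HashMap<>(row.size());
        for (Map.Entry<String, Object> e : row.entrySet()) {
            if (e.getValue() instanceof String) {
                String word = (String) e.getValue();
                embeddings.put(e.getKey(), model.transform0(word, new float[model.getVecSize()]));
            }
        }
        return embeddings;
    }
}
```

With a vocabulary containing only "cat", a row {w1="cat", w2="dog", n=42} yields two entries: w1 maps to the vector for "cat", w2 maps to null (out of vocabulary under this sketch's assumption), and n is skipped because its value is not a String. Allocating a fresh float[] per word, as the original method does, keeps the returned vectors independent of one another.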