Example 1 with SimpleTextEncoder

Use of ai.djl.basicmodelzoo.nlp.SimpleTextEncoder in project djl by deepjavalibrary.

From the class TrainSeq2Seq, method getSeq2SeqModel.

private static Block getSeq2SeqModel(TrainableTextEmbedding sourceEmbedding, TrainableTextEmbedding targetEmbedding, long vocabSize) {
    // Encoder: source embedding followed by a 2-layer, batch-first LSTM that also returns its states.
    SimpleTextEncoder simpleTextEncoder = new SimpleTextEncoder(sourceEmbedding, new LSTM.Builder().setStateSize(32).setNumLayers(2).optDropRate(0).optBatchFirst(true).optReturnState(true).build());
    // Decoder: target embedding followed by an LSTM of the same size; it returns only the output, sized by vocabSize.
    SimpleTextDecoder simpleTextDecoder = new SimpleTextDecoder(targetEmbedding, new LSTM.Builder().setStateSize(32).setNumLayers(2).optDropRate(0).optBatchFirst(true).optReturnState(false).build(), vocabSize);
    // Combine both halves into a single sequence-to-sequence block.
    return new EncoderDecoder(simpleTextEncoder, simpleTextDecoder);
}
Also used: LSTM (ai.djl.nn.recurrent.LSTM), SimpleTextDecoder (ai.djl.basicmodelzoo.nlp.SimpleTextDecoder), EncoderDecoder (ai.djl.modality.nlp.EncoderDecoder), SimpleTextEncoder (ai.djl.basicmodelzoo.nlp.SimpleTextEncoder)

Example 2 with SimpleTextEncoder

Use of ai.djl.basicmodelzoo.nlp.SimpleTextEncoder in project djl by deepjavalibrary.

From the class SimpleTextEncoderTest, method testEncoder.

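@Test
public void testEncoder() {
    // Trainable embedding over a 10-token vocabulary ("1" to "10") with embedding size 8.
    TrainableTextEmbedding trainableTextEmbedding = new TrainableTextEmbedding(TrainableWordEmbedding.builder().setEmbeddingSize(8).setVocabulary(new DefaultVocabulary(Arrays.asList("1 2 3 4 5 6 7 8 9 10".split(" ")))).build());
    // Batch-first, 2-layer LSTM with state size 16 that returns its states along with the output.
    SimpleTextEncoder encoder = new SimpleTextEncoder(trainableTextEmbedding, LSTM.builder().setNumLayers(2).setStateSize(16).optBatchFirst(true).optReturnState(true).build());
    try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) {
        // Input shape (4, 7): a batch of 4 sequences, each holding 7 token indices.
        encoder.initialize(manager, DataType.FLOAT32, new Shape(4, 7));
        // Forward pass in inference mode (training = false) on all-zero token indices.
        NDList output = encoder.forward(new ParameterStore(manager, false), new NDList(manager.zeros(new Shape(4, 7), DataType.INT64)), false);
        // First element is the per-step output: (batch, sequence length, state size).
        Assert.assertEquals(output.head().getShape(), new Shape(4, 7, 16));
        // The output is followed by two state arrays, each shaped (layers, batch, state size).
        Assert.assertEquals(output.size(), 3);
        Assert.assertEquals(output.get(1).getShape(), new Shape(2, 4, 16));
        Assert.assertEquals(output.get(2).getShape(), new Shape(2, 4, 16));
    }
}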
Also used: Shape (ai.djl.ndarray.types.Shape), ParameterStore (ai.djl.training.ParameterStore), NDList (ai.djl.ndarray.NDList), DefaultVocabulary (ai.djl.modality.nlp.DefaultVocabulary), NDManager (ai.djl.ndarray.NDManager), TrainableTextEmbedding (ai.djl.modality.nlp.embedding.TrainableTextEmbedding), SimpleTextEncoder (ai.djl.basicmodelzoo.nlp.SimpleTextEncoder), Test (org.testng.annotations.Test)
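
The shapes asserted in testEncoder follow a simple pattern, which the sketch below spells out in plain Java with no DJL dependency. The class LstmEncoderShapes and its two helper methods are hypothetical names introduced only for illustration and are not part of DJL. The sketch assumes a batch-first, unidirectional LSTM, as configured in the test, and reading the two extra arrays as the hidden and cell states is likewise an assumption.

```java
import java.util.Arrays;

/** Illustrative only: mirrors the shape arithmetic asserted in testEncoder. Not a DJL class. */
public class LstmEncoderShapes {

    /** Per-step output of a batch-first LSTM: (batch, sequenceLength, stateSize). */
    public static long[] outputShape(long batch, long sequenceLength, long stateSize) {
        return new long[] {batch, sequenceLength, stateSize};
    }

    /** Each returned state array (assumed hidden and cell): (numLayers, batch, stateSize). */
    public static long[] stateShape(long numLayers, long batch, long stateSize) {
        return new long[] {numLayers, batch, stateSize};
    }

    public static void main(String[] args) {
        // Values from testEncoder: input Shape(4, 7), 2 layers, state size 16.
        System.out.println(Arrays.toString(outputShape(4, 7, 16))); // prints [4, 7, 16]
        System.out.println(Arrays.toString(stateShape(2, 4, 16)));  // prints [2, 4, 16]
    }
}
```

Note that the embedding size (8) does not appear in any output shape: it only sets the width of the LSTM's input, while the state size determines the last dimension of everything the encoder returns.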
