Usage of ai.djl.basicmodelzoo.nlp.SimpleTextEncoder in the djl project by deepjavalibrary.
Example from class TrainSeq2Seq, method getSeq2SeqModel.
/**
 * Assembles the sequence-to-sequence model: an LSTM-backed {@link SimpleTextEncoder} and an
 * LSTM-backed {@link SimpleTextDecoder}, wrapped together in an {@link EncoderDecoder}.
 *
 * @param sourceEmbedding embedding handed to the encoder for source-side tokens
 * @param targetEmbedding embedding handed to the decoder for target-side tokens
 * @param vocabSize vocabulary size passed to the decoder (presumably the width of its output
 *     projection — confirm against SimpleTextDecoder)
 * @return the combined encoder-decoder block
 */
private static Block getSeq2SeqModel(
        TrainableTextEmbedding sourceEmbedding,
        TrainableTextEmbedding targetEmbedding,
        long vocabSize) {
    // Shared recurrent configuration. The encoder returns its final LSTM state, which the
    // EncoderDecoder presumably uses to seed the decoder. That hand-off only lines up when both
    // LSTMs agree on state size and layer count, so these settings live in one place instead of
    // being repeated per builder.
    final int stateSize = 32;
    final int numLayers = 2;
    final float dropRate = 0f;

    LSTM encoderLstm =
            new LSTM.Builder()
                    .setStateSize(stateSize)
                    .setNumLayers(numLayers)
                    .optDropRate(dropRate)
                    .optBatchFirst(true)
                    // Return the final recurrent states alongside the encoded sequence.
                    .optReturnState(true)
                    .build();
    LSTM decoderLstm =
            new LSTM.Builder()
                    .setStateSize(stateSize)
                    .setNumLayers(numLayers)
                    .optDropRate(dropRate)
                    .optBatchFirst(true)
                    // The decoder only needs its output sequence, not its final states.
                    .optReturnState(false)
                    .build();

    SimpleTextEncoder simpleTextEncoder = new SimpleTextEncoder(sourceEmbedding, encoderLstm);
    SimpleTextDecoder simpleTextDecoder =
            new SimpleTextDecoder(targetEmbedding, decoderLstm, vocabSize);
    return new EncoderDecoder(simpleTextEncoder, simpleTextDecoder);
}
Usage of ai.djl.basicmodelzoo.nlp.SimpleTextEncoder in the djl project by deepjavalibrary.
Example from class SimpleTextEncoderTest, method testEncoder.
/**
 * Runs a two-layer, batch-first LSTM {@link SimpleTextEncoder} on a batch of token ids and
 * checks the shapes of the encoded sequence and of the two returned state tensors.
 */
@Test
public void testEncoder() {
    // Test dimensions; every expected shape below is derived from these.
    int batchSize = 4;
    int seqLength = 7;
    int stateSize = 16;
    int numLayers = 2;

    DefaultVocabulary vocabulary =
            new DefaultVocabulary(Arrays.asList("1 2 3 4 5 6 7 8 9 10".split(" ")));
    TrainableWordEmbedding wordEmbedding =
            TrainableWordEmbedding.builder()
                    .setEmbeddingSize(8)
                    .setVocabulary(vocabulary)
                    .build();
    TrainableTextEmbedding textEmbedding = new TrainableTextEmbedding(wordEmbedding);

    LSTM lstm =
            LSTM.builder()
                    .setNumLayers(numLayers)
                    .setStateSize(stateSize)
                    .optBatchFirst(true)
                    .optReturnState(true)
                    .build();
    SimpleTextEncoder encoder = new SimpleTextEncoder(textEmbedding, lstm);

    try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) {
        Shape inputShape = new Shape(batchSize, seqLength);
        encoder.initialize(manager, DataType.FLOAT32, inputShape);

        NDList input = new NDList(manager.zeros(inputShape, DataType.INT64));
        ParameterStore parameterStore = new ParameterStore(manager, false);
        NDList output = encoder.forward(parameterStore, input, false);

        // Batch-first encoded sequence: (batch, time, state).
        Assert.assertEquals(output.head().getShape(), new Shape(batchSize, seqLength, stateSize));
        // With optReturnState(true), two state tensors follow the encoded sequence.
        Assert.assertEquals(output.size(), 3);
        // Each returned state is shaped (layers, batch, state).
        Shape stateShape = new Shape(numLayers, batchSize, stateSize);
        Assert.assertEquals(output.get(1).getShape(), stateShape);
        Assert.assertEquals(output.get(2).getShape(), stateShape);
    }
}
Aggregations