Search in sources :

Example 6 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class GoEmotions, method prepare.

/**
 * Prepares the dataset for use with tracked progress. In this method the TSV file for the
 * configured usage will be parsed: the first column of every record is collected as source
 * text, and the second column, a comma-separated list of integers, is added to
 * {@code targetData} as that record's labels. The collected source texts are then
 * preprocessed.
 *
 * @param progress the progress tracker
 * @throws IOException for various exceptions depending on the dataset
 * @throws EmbeddingException if there are exceptions during the embedding process
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path csvFile;
    switch (usage) {
        case TRAIN:
            csvFile = root.resolve("train.tsv");
            break;
        case TEST:
            csvFile = root.resolve("test.tsv");
            break;
        case VALIDATION:
            csvFile = root.resolve("dev.tsv");
            break;
        default:
            throw new UnsupportedOperationException("Data not available.");
    }
    // Tab-delimited, with quote handling disabled so quote characters are kept literally.
    CSVFormat csvFormat =
            CSVFormat.TDF.builder().setQuote(null).setHeader(HeaderEnum.class).build();
    URL csvUrl = csvFile.toUri().toURL();
    List<CSVRecord> csvRecords;
    // Both the reader and the parser are closed, even if parsing fails part-way.
    try (Reader reader =
                    new InputStreamReader(
                            new BufferedInputStream(csvUrl.openStream()),
                            StandardCharsets.UTF_8);
            CSVParser csvParser = new CSVParser(reader, csvFormat)) {
        csvRecords = csvParser.getRecords();
    }
    List<String> sourceTextData = new ArrayList<>(csvRecords.size());
    for (CSVRecord csvRecord : csvRecords) {
        sourceTextData.add(csvRecord.get(0));
        String[] labels = csvRecord.get(1).split(",");
        int[] labelInt = new int[labels.length];
        for (int i = 0; i < labels.length; i++) {
            // Tolerate stray whitespace around label ids.
            labelInt[i] = Integer.parseInt(labels[i].trim());
        }
        targetData.add(labelInt);
    }
    preprocess(sourceTextData, true);
    prepared = true;
}
Also used : Path(java.nio.file.Path) InputStreamReader(java.io.InputStreamReader) ArrayList(java.util.ArrayList) Reader(java.io.Reader) Artifact(ai.djl.repository.Artifact) URL(java.net.URL) BufferedInputStream(java.io.BufferedInputStream) CSVParser(org.apache.commons.csv.CSVParser) CSVFormat(org.apache.commons.csv.CSVFormat) CSVRecord(org.apache.commons.csv.CSVRecord)

Example 7 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class StanfordMovieReview, method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Review texts are read from the {@code pos} and {@code neg} folders under
 * {@code aclImdb/aclImdb/<split>} in the artifact's cache location, where the split is
 * {@code train} or {@code test}. The sentiment and IMDb score lists are reset before loading,
 * and the collected texts are then preprocessed. Validation data is not available.
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path imdbRoot =
            mrl.getRepository()
                    .getCacheDirectory()
                    .resolve(artifact.getResourceUri().getPath())
                    .resolve("aclImdb")
                    .resolve("aclImdb");

    String split;
    switch (usage) {
        case TRAIN:
            split = "train";
            break;
        case TEST:
            split = "test";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Path splitDir = imdbRoot.resolve(split);

    // Reset the per-review result lists before they are filled by prepareDataSentiment.
    reviewSentiments = new ArrayList<>();
    reviewImdbScore = new ArrayList<>();
    List<String> texts = new ArrayList<>();
    prepareDataSentiment(splitDir.resolve("pos"), true, texts);
    prepareDataSentiment(splitDir.resolve("neg"), false, texts);
    preprocess(texts, true);
    prepared = true;
}
Also used : Path(java.nio.file.Path) ArrayList(java.util.ArrayList) URI(java.net.URI) Artifact(ai.djl.repository.Artifact)

Example 8 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class StanfordQuestionAnsweringDataset, method prepareUsagePath.

/**
 * Makes sure the default artifact is prepared locally, then returns the path of the JSON file
 * that matches the configured usage: {@code train-v2.0.json} for training and
 * {@code dev-v2.0.json} for testing.
 *
 * @param progress the progress tracker
 * @return the path of the JSON file for the current usage
 * @throws IOException if preparing the artifact or resolving its directory fails
 * @throws UnsupportedOperationException if validation data is requested
 */
private Path prepareUsagePath(Progress progress) throws IOException {
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path resourceDir = mrl.getRepository().getResourceDirectory(artifact);
    switch (usage) {
        case TRAIN:
            return resourceDir.resolve("train-v2.0.json");
        case TEST:
            return resourceDir.resolve("dev-v2.0.json");
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
}
Also used : Path(java.nio.file.Path) Artifact(ai.djl.repository.Artifact)

Example 9 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class UniversalDependenciesEnglishEWT, method prepare.

/**
 * Prepares the dataset for use with tracked progress. In this method the TXT file will be
 * parsed. Each non-blank line holds a token and its Universal POS tag, separated by a tab.
 * Blank lines separate sentences. The tokens of each sentence are joined with single spaces
 * and added to {@code sourceTextData}. The ordinals of the matching Universal POS tags are
 * added to {@code universalPosTags}. A final sentence that is not followed by a blank line is
 * still kept. Only {@code sourceTextData} will then be preprocessed.
 *
 * @param progress the progress tracker
 * @throws IOException for various exceptions depending on the dataset, or if a non-blank line
 *     does not contain a tab-separated tag column
 * @throws EmbeddingException if there are exceptions during the embedding process
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.train.txt");
            break;
        case TEST:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.test.txt");
            break;
        case VALIDATION:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.dev.txt");
            break;
        default:
            // Fail clearly instead of passing a null path to root.resolve(...).
            throw new UnsupportedOperationException("Data not available.");
    }
    usagePath = root.resolve(usagePath);
    StringBuilder sourceTextDatum = new StringBuilder();
    List<String> sourceTextData = new ArrayList<>();
    universalPosTags = new ArrayList<>();
    List<Integer> universalPosTag = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(usagePath)) {
        String row;
        int lineNumber = 0;
        while ((row = reader.readLine()) != null) {
            lineNumber++;
            if (row.trim().isEmpty()) {
                // Sentence boundary. Repeated blank lines must not produce empty sentences.
                if (!universalPosTag.isEmpty()) {
                    sourceTextData.add(sourceTextDatum.toString());
                    universalPosTags.add(universalPosTag);
                    sourceTextDatum.setLength(0);
                    universalPosTag = new ArrayList<>();
                }
                continue;
            }
            String[] splits = row.split("\t");
            if (splits.length < 2) {
                throw new IOException(
                        "Malformed line " + lineNumber + " in " + usagePath + ": " + row);
            }
            if (sourceTextDatum.length() != 0) {
                sourceTextDatum.append(' ');
            }
            sourceTextDatum.append(splits[0]);
            universalPosTag.add(UniversalPosTag.valueOf(splits[1]).ordinal());
        }
    }
    // Keep the last sentence when the file does not end with a blank line.
    if (!universalPosTag.isEmpty()) {
        sourceTextData.add(sourceTextDatum.toString());
        universalPosTags.add(universalPosTag);
    }
    preprocess(sourceTextData, true);
    prepared = true;
}
Also used : Path(java.nio.file.Path) ArrayList(java.util.ArrayList) BufferedReader(java.io.BufferedReader) Artifact(ai.djl.repository.Artifact)

Example 10 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class Cifar10, method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Loads {@code data_batch.bin} for training or {@code test_batch.bin} for testing from the
 * prepared artifact. Column 0 of the loaded array becomes the labels. The remaining columns are
 * reshaped to {@code (N, 3, IMAGE_HEIGHT, IMAGE_WIDTH)} and transposed to
 * {@code (N, IMAGE_HEIGHT, IMAGE_WIDTH, 3)}. The {@code data} and {@code labels} fields are only
 * assigned once the result has been validated.
 *
 * @throws IOException if the expected file is missing from the artifact, if the data and label
 *     counts differ, or as propagated while preparing or reading the artifact
 * @throws UnsupportedOperationException if validation data is requested
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Map<String, Artifact.Item> map = artifact.getFiles();
    String fileName;
    switch (usage) {
        case TRAIN:
            fileName = "data_batch.bin";
            break;
        case TEST:
            fileName = "test_batch.bin";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Artifact.Item item = map.get(fileName);
    if (item == null) {
        // Report a missing entry clearly instead of failing later inside readData(null).
        throw new IOException("Cifar10 artifact does not contain file: " + fileName);
    }
    NDArray dataAndLabels = readData(item);
    NDArray[] newData = {
        dataAndLabels
                .get(":, 1:")
                .reshape(-1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)
                .transpose(0, 2, 3, 1)
    };
    NDArray[] newLabels = {dataAndLabels.get(":,0")};
    // check if data and labels have the same size before publishing them to the fields
    if (newData[0].size(0) != newLabels[0].size(0)) {
        throw new IOException(
                "the size of data "
                        + newData[0].size(0)
                        + " didn't match with the size of labels "
                        + newLabels[0].size(0));
    }
    data = newData;
    labels = newLabels;
    prepared = true;
}
Also used : NDArray(ai.djl.ndarray.NDArray) IOException(java.io.IOException) Artifact(ai.djl.repository.Artifact)

Aggregations

Artifact (ai.djl.repository.Artifact): 40, Path (java.nio.file.Path): 20, MRL (ai.djl.repository.MRL): 10, ArrayList (java.util.ArrayList): 9, Test (org.testng.annotations.Test): 9, Repository (ai.djl.repository.Repository): 8, Metadata (ai.djl.repository.Metadata): 7, IOException (java.io.IOException): 5, BufferedReader (java.io.BufferedReader): 4, List (java.util.List): 4, Application (ai.djl.Application): 3, Model (ai.djl.Model): 3, Rectangle (ai.djl.modality.cv.output.Rectangle): 3, Reader (java.io.Reader): 3, ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 3, Point (ai.djl.modality.cv.output.Point): 2, PairList (ai.djl.util.PairList): 2, Progress (ai.djl.util.Progress): 2, Type (java.lang.reflect.Type): 2, Map (java.util.Map): 2