Search in sources :

Example 1 with WORKSPACE_NAME

Use of ml.comet.experiment.impl.ExperimentTestFactory.WORKSPACE_NAME in project comet-java-sdk by comet-ml.

From the class ArtifactSupportTest, method testLogAndGetArtifact.

/**
 * Logs an artifact that carries remote assets, local file assets, an in-memory asset and an
 * assets folder, then verifies three things: the artifacts-in-progress counter of the experiment
 * before, during and after the log operation; the {@link LoggedArtifact} returned by the log
 * operation; and the artifact (with its assets) retrieved again through
 * {@code experiment.getArtifact(...)}.
 *
 * <p>Any {@link Throwable} raised along the way is reported through {@code fail(t)}.
 */
@Test
public void testLogAndGetArtifact() {
    try (OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment()) {
        ArtifactImpl artifact = createArtifact();

        // Remote assets: the first one is registered under a name unrelated to its URI,
        // the second one under the file name that also terminates its URI.
        String firstRemoteName = "firstAssetFileName";
        URI firstRemoteUri = new URI("s3://bucket/folder/firstAssetFile.extension");
        artifact.addRemoteAsset(firstRemoteUri, firstRemoteName);

        String secondRemoteName = "secondAssetFile.extension";
        URI secondRemoteUri = new URI("s3://bucket/folder/" + secondRemoteName);
        artifact.addRemoteAsset(secondRemoteUri, secondRemoteName);

        // Local file assets: one with metadata attached, one without.
        artifact.addAsset(
                Objects.requireNonNull(TestUtils.getFile(IMAGE_FILE_NAME)),
                IMAGE_FILE_NAME,
                false,
                SOME_METADATA);
        artifact.addAsset(
                Objects.requireNonNull(TestUtils.getFile(CODE_FILE_NAME)),
                CODE_FILE_NAME,
                false);

        // In-memory asset built from raw bytes.
        String inMemoryAssetName = "someDataName";
        byte[] inMemoryPayload = "some data".getBytes(StandardCharsets.UTF_8);
        artifact.addAsset(inMemoryPayload, inMemoryAssetName);

        // Whole folder of assets.
        artifact.addAssetFolder(assetsFolder.toFile(), true, true);

        // Reusable check comparing a logged artifact against the artifact it was created from.
        Function4<LoggedArtifact, ArtifactImpl, String, List<String>, Void> checkLoggedArtifact =
                (logged, source, sourceExperimentKey, aliases) -> {
                    assertNotNull(logged, "logged artifact expected");
                    assertEquals(source.getType(), logged.getArtifactType(), "wrong artifact type");
                    assertEquals(new HashSet<>(aliases), logged.getAliases(), "wrong aliases");
                    assertEquals(SOME_METADATA, logged.getMetadata(), "wrong metadata");
                    assertEquals(
                            new HashSet<>(source.getVersionTags()),
                            logged.getVersionTags(),
                            "wrong version tags");
                    assertEquals(WORKSPACE_NAME, logged.getWorkspace(), "wrong workspace");
                    assertEquals(
                            sourceExperimentKey,
                            logged.getSourceExperimentKey(),
                            "wrong experiment key");
                    assertEquals(source.getName(), logged.getName(), "wrong artifact name");
                    return null;
                };

        // Counter before logging starts.
        assertEquals(
                0,
                experiment.getArtifactsInProgress().get(),
                "artifacts-in-progress counter must be zero at start");

        // Start logging; the counter is checked right after the call returns its future.
        CompletableFuture<LoggedArtifact> pendingArtifact = experiment.logArtifact(artifact);
        assertEquals(
                1,
                experiment.getArtifactsInProgress().get(),
                "artifacts-in-progress counter has wrong value while still in progress");

        LoggedArtifact loggedArtifact = pendingArtifact.get(60, SECONDS);

        // Counter after completion: poll until it reaches zero, then assert it explicitly.
        Awaitility.await("artifacts-in-progress counter must be decreased")
                .pollInterval(10, TimeUnit.MILLISECONDS)
                .atMost(1, TimeUnit.SECONDS)
                .until(() -> experiment.getArtifactsInProgress().get() == 0);
        assertEquals(
                0,
                experiment.getArtifactsInProgress().get(),
                "artifacts-in-progress counter must be zero after log operation completed");

        // Validate the artifact returned by the log operation.
        List<String> aliasesToExpect = new ArrayList<>(artifact.getAliases());
        checkLoggedArtifact.apply(
                loggedArtifact, artifact, experiment.getExperimentKey(), aliasesToExpect);

        // Retrieve the same artifact again and validate it as well.
        LoggedArtifact artifactFromServer = experiment.getArtifact(
                loggedArtifact.getName(),
                loggedArtifact.getWorkspace(),
                loggedArtifact.getVersion());
        // this alias is added by the backend automatically
        aliasesToExpect.add(ALIAS_LATEST);
        checkLoggedArtifact.apply(
                artifactFromServer, artifact, experiment.getExperimentKey(), aliasesToExpect);

        // Every asset of the retrieved artifact must validate against the original assets.
        Collection<LoggedArtifactAsset> assetsFromServer = artifactFromServer.getAssets();
        Collection<ArtifactAsset> originalAssets = artifact.getAssets();
        assertEquals(originalAssets.size(), assetsFromServer.size(), "wrong size");
        for (LoggedArtifactAsset assetFromServer : assetsFromServer) {
            validateArtifactAsset(
                    new ArtifactAssetImpl((LoggedArtifactAssetImpl) assetFromServer),
                    originalAssets);
        }
    } catch (Throwable t) {
        fail(t);
    }
}
Also used : Arrays(java.util.Arrays) ArtifactAssetNotFoundException(ml.comet.experiment.artifact.ArtifactAssetNotFoundException) LoggedArtifactAsset(ml.comet.experiment.artifact.LoggedArtifactAsset) LoggedArtifact(ml.comet.experiment.artifact.LoggedArtifact) ArtifactException(ml.comet.experiment.artifact.ArtifactException) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) FAILED_TO_DOWNLOAD_ARTIFACT_ASSETS(ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DOWNLOAD_ARTIFACT_ASSETS) Tag(org.junit.jupiter.api.Tag) URI(java.net.URI) ArtifactAsset(ml.comet.experiment.artifact.ArtifactAsset) Path(java.nio.file.Path) LogMessages.getString(ml.comet.experiment.impl.resources.LogMessages.getString) Collection(java.util.Collection) Set(java.util.Set) UNKNOWN(ml.comet.experiment.impl.asset.AssetType.UNKNOWN) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) Test(org.junit.jupiter.api.Test) Objects(java.util.Objects) IOUtils(org.apache.commons.io.IOUtils) List(java.util.List) Stream(java.util.stream.Stream) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Awaitility(org.awaitility.Awaitility) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) Assertions.fail(org.junit.jupiter.api.Assertions.fail) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) ByteArrayOutputStream(java.io.ByteArrayOutputStream) SOME_METADATA(ml.comet.experiment.impl.ArtifactImplTest.SOME_METADATA) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) ASSET(ml.comet.experiment.impl.asset.AssetType.ASSET) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) ArtifactAssetImpl(ml.comet.experiment.impl.asset.ArtifactAssetImpl) ArrayList(java.util.ArrayList) Function4(io.reactivex.rxjava3.functions.Function4) HashSet(java.util.HashSet) AssetOverwriteStrategy(ml.comet.experiment.artifact.AssetOverwriteStrategy) 
Artifact(ml.comet.experiment.artifact.Artifact) REMOTE_ASSET_CANNOT_BE_DOWNLOADED(ml.comet.experiment.impl.resources.LogMessages.REMOTE_ASSET_CANNOT_BE_DOWNLOADED) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) ExperimentTestFactory.createOnlineExperiment(ml.comet.experiment.impl.ExperimentTestFactory.createOnlineExperiment) DownloadedArtifact(ml.comet.experiment.artifact.DownloadedArtifact) Files(java.nio.file.Files) IOException(java.io.IOException) File(java.io.File) DisplayName(org.junit.jupiter.api.DisplayName) TimeUnit(java.util.concurrent.TimeUnit) Assertions.assertArrayEquals(org.junit.jupiter.api.Assertions.assertArrayEquals) WORKSPACE_NAME(ml.comet.experiment.impl.ExperimentTestFactory.WORKSPACE_NAME) PathUtils(org.apache.commons.io.file.PathUtils) Timeout(org.junit.jupiter.api.Timeout) Collections(java.util.Collections) SECONDS(java.util.concurrent.TimeUnit.SECONDS) InputStream(java.io.InputStream) FAILED_TO_FIND_ASSET_IN_ARTIFACT(ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_FIND_ASSET_IN_ARTIFACT) LoggedArtifact(ml.comet.experiment.artifact.LoggedArtifact) LoggedArtifactAsset(ml.comet.experiment.artifact.LoggedArtifactAsset) ArtifactAsset(ml.comet.experiment.artifact.ArtifactAsset) ArrayList(java.util.ArrayList) LogMessages.getString(ml.comet.experiment.impl.resources.LogMessages.getString) URI(java.net.URI) LoggedArtifactAsset(ml.comet.experiment.artifact.LoggedArtifactAsset) ArtifactAssetImpl(ml.comet.experiment.impl.asset.ArtifactAssetImpl) List(java.util.List) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) Test(org.junit.jupiter.api.Test)

Aggregations

Function4 (io.reactivex.rxjava3.functions.Function4)1 ByteArrayOutputStream (java.io.ByteArrayOutputStream)1 File (java.io.File)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 URI (java.net.URI)1 StandardCharsets (java.nio.charset.StandardCharsets)1 Files (java.nio.file.Files)1 Path (java.nio.file.Path)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collection (java.util.Collection)1 Collections (java.util.Collections)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 List (java.util.List)1 Map (java.util.Map)1 Objects (java.util.Objects)1 Set (java.util.Set)1 CompletableFuture (java.util.concurrent.CompletableFuture)1