
Example 1 with ToDoubleFunction

Use of java.util.function.ToDoubleFunction in project exotic by forax.

Class ConstantMemoizer, method doubleMemoizer.

/**
 * Return a function that returns a constant value (for the virtual machine) for each key taken as
 * argument. The value corresponding to a key is calculated by calling the {@code function} once
 * per key and is then cached in code similar to a cascade of {@code if equals else}.
 *
 * <p>To find whether a key was previously seen, {@link Object#equals(Object)} is called to
 * compare the actual key with possibly all the keys already seen, so if there are many different
 * keys, the worst-case performance is that of a linear search, i.e. O(number of seen keys).
 *
 * @param <K> type of the keys.
 * @param function a function that takes a non-null key as argument and returns a {@code double}.
 * @param keyClass the class of the key; if it's a primitive type, the key value will be boxed
 *     before calling the {@code function}.
 * @return a function that returns the value for a specific key.
 * @throws NullPointerException if the {@code function} or the {@code keyClass} is null, or if a
 *     key passed to the returned function is null.
 * @throws ClassCastException if the type of a key passed to the returned function doesn't match
 *     the {@code keyClass}.
 */
public static <K> ToDoubleFunction<K> doubleMemoizer(ToDoubleFunction<? super K> function, Class<K> keyClass) {
    Objects.requireNonNull(function);
    Objects.requireNonNull(keyClass);
    // Erase the key type to Object so that the handle can be called with invokeExact below.
    MethodHandle mh = new InliningCacheCallSite<>(methodType(double.class, keyClass), function::applyAsDouble)
        .dynamicInvoker()
        .asType(methodType(double.class, Object.class));
    return key -> {
        Objects.requireNonNull(key);
        try {
            return (double) mh.invokeExact(key);
        } catch (Throwable e) {
            throw Thrower.rethrow(e);
        }
    };
}
Also used: java.lang.invoke.MethodHandles.guardWithTest, java.lang.invoke.MethodType.methodType, java.lang.invoke.MethodHandles.constant, java.lang.invoke.MethodHandles.dropArguments, java.lang.invoke.MethodHandle, java.util.function.ToIntFunction, java.lang.invoke.MethodHandles.lookup, java.util.function.Function, java.util.Objects, java.lang.invoke.MethodHandles.Lookup, java.lang.invoke.MethodHandles.exactInvoker, java.lang.invoke.MethodType, java.util.function.ToDoubleFunction, java.lang.invoke.MethodHandles.foldArguments, java.lang.invoke.MutableCallSite, java.util.function.ToLongFunction
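The caching contract described in the Javadoc of doubleMemoizer (call the function once per key, then find the key again through a linear equals() scan) can be illustrated without method handles. The sketch below is a hypothetical plain-Java stand-in, `linearMemoizer`, that keeps the seen keys in a list; it is not exotic's implementation and does not reproduce the call-site inlining that lets the JIT treat each cached value as a constant.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.ToDoubleFunction;

public class LinearMemoizerSketch {
  // Hypothetical helper (not part of exotic): memoizes one double per distinct key and
  // looks keys up with a linear equals() scan, like a cascade of "if equals else".
  static <K> ToDoubleFunction<K> linearMemoizer(ToDoubleFunction<? super K> function) {
    Objects.requireNonNull(function);
    List<K> keys = new ArrayList<>();
    List<Double> values = new ArrayList<>();
    return key -> {
      Objects.requireNonNull(key);
      synchronized (keys) {
        // Worst case: O(number of keys seen so far).
        for (int i = 0; i < keys.size(); i++) {
          if (keys.get(i).equals(key)) {
            return values.get(i); // cache hit
          }
        }
        double value = function.applyAsDouble(key); // called once per distinct key
        keys.add(key);
        values.add(value);
        return value;
      }
    };
  }

  public static void main(String[] args) {
    int[] calls = {0};
    ToDoubleFunction<String> length = linearMemoizer(s -> {
      calls[0]++;
      return s.length();
    });
    System.out.println(length.applyAsDouble("abc")); // 3.0
    System.out.println(length.applyAsDouble("abc")); // 3.0 (cached)
    System.out.println(calls[0]);                    // 1
  }
}
```

A list-based scan keeps the sketch close to the O(number of seen keys) behavior the Javadoc warns about; a HashMap would give faster lookups but would no longer mirror the if/equals cascade.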

Example 2 with ToDoubleFunction

Use of java.util.function.ToDoubleFunction in project exotic by forax.

Class StableField, method doubleGetter.

/**
 * Create a getter on a field of type {@code double} of a class, with stable semantics.
 *
 * <p>If the field is not initialized or is initialized with its default value, the default value
 * will be returned when calling the getter. If the field is initialized with a value other than
 * the default value, the getter will return the first value of the field it observes, and any
 * subsequent calls to the getter will return this same value.
 *
 * <p>Once the getter has observed a value different from the default value, any subsequent calls
 * to the getter must pass the same object as argument.
 *
 * <p>This call is equivalent to a call to {@link #getter(Lookup, Class, String, Class)} with
 * {@code double.class} as last argument, except that the returned getter doesn't box the return
 * value.
 *
 * @param <T> the type of the object containing the field.
 * @param lookup a lookup object that can access the field.
 * @param declaringClass the class that declares the field.
 * @param name the name of the field.
 * @return a function that takes an object of the {@code declaringClass} and returns the value of
 *     the field.
 * @throws NullPointerException if either the lookup, the declaring class or the name is null.
 * @throws NoSuchFieldError if the field doesn't exist.
 * @throws IllegalAccessError if the field is not accessible from the lookup.
 * @throws IllegalStateException if the argument of the getter is not constant.
 */
public static <T> ToDoubleFunction<T> doubleGetter(Lookup lookup, Class<T> declaringClass, String name) {
    Objects.requireNonNull(lookup);
    Objects.requireNonNull(declaringClass);
    Objects.requireNonNull(name);
    MethodHandle getter = createGetter(lookup, declaringClass, name, double.class);
    MethodHandle mh = new StableFieldCS(getter, double.class).dynamicInvoker();
    return object -> {
        try {
            return (double) mh.invokeExact(object);
        } catch (Throwable t) {
            throw Thrower.rethrow(t);
        }
    };
}
Also used: java.lang.invoke.MethodType.methodType, java.lang.invoke.MethodHandles.constant, java.lang.invoke.MethodHandles.dropArguments, java.lang.invoke.MethodHandle, java.lang.invoke.MethodHandles, java.util.function.ToIntFunction, java.util.function.Function, java.util.Objects, java.lang.invoke.MethodHandles.Lookup, java.lang.invoke.MethodHandles.exactInvoker, java.util.function.ToDoubleFunction, java.lang.invoke.MethodHandles.foldArguments, java.lang.invoke.MutableCallSite, java.util.function.ToLongFunction
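The stable semantics described in the Javadoc of doubleGetter (return the default value until a non-default value is seen, then freeze that first value and insist on the same receiver) can be mimicked in plain Java. This is only a behavioral sketch under stated assumptions: `stableGetter` and `Point` are hypothetical names, the field is read through an ordinary ToDoubleFunction instead of a Lookup, and there is no MutableCallSite, so the JIT gets none of the constant-folding benefit of the real StableField.

```java
import java.util.Objects;
import java.util.function.ToDoubleFunction;

public class StableGetterSketch {
  // Hypothetical stand-in for StableField.doubleGetter: wraps an ordinary field reader
  // and freezes the first non-default value it observes.
  static <T> ToDoubleFunction<T> stableGetter(ToDoubleFunction<? super T> reader) {
    Objects.requireNonNull(reader);
    return new ToDoubleFunction<T>() {
      private Object owner;   // receiver whose value was frozen
      private double value;   // frozen value
      private boolean frozen; // true once a non-default value was seen

      @Override
      public synchronized double applyAsDouble(T object) {
        if (frozen) {
          if (object != owner) {
            throw new IllegalStateException("the argument of the getter is not constant");
          }
          return value;
        }
        double current = reader.applyAsDouble(object);
        // The default value of a double field is 0.0, whose raw bit pattern is 0.
        if (Double.doubleToRawLongBits(current) != 0L) {
          owner = object;
          value = current;
          frozen = true;
        }
        return current;
      }
    };
  }

  // Hypothetical class with a double field.
  static class Point {
    double x;
  }

  public static void main(String[] args) {
    Point p = new Point();
    ToDoubleFunction<Point> getX = stableGetter(pt -> pt.x);
    System.out.println(getX.applyAsDouble(p)); // 0.0 (default value, nothing frozen)
    p.x = 1.5;
    System.out.println(getX.applyAsDouble(p)); // 1.5 (now frozen)
    p.x = 2.5;
    System.out.println(getX.applyAsDouble(p)); // 1.5 (frozen value is kept)
  }
}
```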

Example 3 with ToDoubleFunction

Use of java.util.function.ToDoubleFunction in project MindsEye by SimiaCryptus.

Class AvgMetaLayer, method eval.

@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result input = inObj[0];
    input.addRef();
    TensorList inputData = input.getData();
    final int itemCnt = inputData.length();
    @Nullable Tensor thisResult;
    boolean passback;
    if (null == lastResult || inputData.length() > minBatchCount) {
        @Nonnull final ToDoubleFunction<Coordinate> f = (c) -> IntStream.range(0, itemCnt).mapToDouble(dataIndex -> {
            Tensor tensor = inputData.get(dataIndex);
            double v = tensor.get(c);
            tensor.freeRef();
            return v;
        }).sum() / itemCnt;
        Tensor tensor = inputData.get(0);
        thisResult = tensor.mapCoords(f);
        tensor.freeRef();
        passback = true;
        if (null != lastResult)
            lastResult.freeRef();
        lastResult = thisResult;
        lastResult.addRef();
    } else {
        passback = false;
        thisResult = lastResult;
        thisResult.freeRef();
    }
    return new Result(TensorArray.create(thisResult), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (passback && input.isAlive()) {
            @Nullable final Tensor delta = data.get(0);
            @Nonnull final Tensor[] feedback = new Tensor[itemCnt];
            Arrays.parallelSetAll(feedback, i -> new Tensor(delta.getDimensions()));
            thisResult.coordStream(true).forEach((inputCoord) -> {
                for (int inputItem = 0; inputItem < itemCnt; inputItem++) {
                    feedback[inputItem].add(inputCoord, delta.get(inputCoord) / itemCnt);
                }
            });
            delta.freeRef();
            @Nonnull TensorArray tensorArray = TensorArray.wrap(feedback);
            input.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        public boolean isAlive() {
            return input.isAlive();
        }

        @Override
        protected void _free() {
            thisResult.freeRef();
            input.freeRef();
        }
    };
}
Also used: java.util.stream.IntStream, com.google.gson.JsonObject, com.simiacryptus.mindseye.lang.Coordinate, java.util.Arrays, org.slf4j.Logger, org.slf4j.LoggerFactory, com.simiacryptus.mindseye.lang.Tensor, com.simiacryptus.mindseye.lang.Result, com.simiacryptus.mindseye.lang.DataSerializer, java.util.List, com.simiacryptus.mindseye.lang.LayerBase, java.util.function.ToDoubleFunction, com.simiacryptus.mindseye.lang.TensorList, java.util.Map, com.simiacryptus.mindseye.lang.Layer, com.simiacryptus.mindseye.lang.TensorArray, com.simiacryptus.mindseye.lang.DeltaSet, javax.annotation.Nonnull, javax.annotation.Nullable
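The arithmetic in AvgMetaLayer.eval is an element-wise mean over the batch in the forward pass and a feedback of delta / itemCnt to every item in the backward pass. The sketch below reproduces only that arithmetic, assuming a plain double[][] in place of MindsEye's TensorList and an Integer index in place of Coordinate; reference counting and the lastResult / minBatchCount caching are deliberately left out.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;

public class BatchMeanSketch {
  // Forward pass: element-wise mean over a batch of equally sized vectors.
  // double[][] stands in for TensorList and Integer for Coordinate (assumptions).
  static double[] mean(double[][] batch) {
    int itemCnt = batch.length;
    ToDoubleFunction<Integer> f = c -> IntStream.range(0, itemCnt)
        .mapToDouble(i -> batch[i][c])
        .sum() / itemCnt;
    double[] result = new double[batch[0].length];
    for (int c = 0; c < result.length; c++) {
      result[c] = f.applyAsDouble(c);
    }
    return result;
  }

  // Backward pass: every item of the batch receives delta / itemCnt at each coordinate,
  // like the feedback loop in AvgMetaLayer.
  static double[][] meanGradient(double[] delta, int itemCnt) {
    double[][] feedback = new double[itemCnt][delta.length];
    for (int c = 0; c < delta.length; c++) {
      for (int i = 0; i < itemCnt; i++) {
        feedback[i][c] += delta[c] / itemCnt;
      }
    }
    return feedback;
  }

  public static void main(String[] args) {
    double[][] batch = {{1, 2}, {3, 6}};
    System.out.println(Arrays.toString(mean(batch))); // [2.0, 4.0]
    System.out.println(Arrays.deepToString(meanGradient(new double[] {1, 1}, 2))); // [[0.5, 0.5], [0.5, 0.5]]
  }
}
```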

Example 4 with ToDoubleFunction

Use of java.util.function.ToDoubleFunction in project MindsEye by SimiaCryptus.

Class BiasMetaLayer, method eval.

@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final int itemCnt = inObj[0].getData().length();
    Tensor tensor1 = inObj[1].getData().get(0);
    final Tensor[] tensors = IntStream.range(0, itemCnt).parallel().mapToObj(dataIndex -> {
        Tensor tensor = inObj[0].getData().get(dataIndex);
        Tensor mapIndex = tensor.mapIndex((v, c) -> {
            return v + tensor1.get(c);
        });
        tensor.freeRef();
        return mapIndex;
    }).toArray(i -> new Tensor[i]);
    tensor1.freeRef();
    Tensor tensor0 = tensors[0];
    tensor0.addRef();
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    return new Result(TensorArray.wrap(tensors), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (inObj[0].isAlive()) {
            data.addRef();
            inObj[0].accumulate(buffer, data);
        }
        if (inObj[1].isAlive()) {
            @Nonnull final ToDoubleFunction<Coordinate> f = (c) -> {
                return IntStream.range(0, itemCnt).mapToDouble(i -> {
                    Tensor tensor = data.get(i);
                    double v = tensor.get(c);
                    tensor.freeRef();
                    return v;
                }).sum();
            };
            @Nullable final Tensor passback = tensor0.mapCoords(f);
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, inObj[1].getData().length()).mapToObj(i -> {
                if (i == 0)
                    return passback;
                else {
                    @Nullable Tensor map = passback.map(v -> 0);
                    passback.freeRef();
                    return map;
                }
            }).toArray(i -> new Tensor[i]));
            inObj[1].accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            tensor0.freeRef();
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive() || inObj[1].isAlive();
        }
    };
}
Also used: java.util.stream.IntStream, com.google.gson.JsonObject, com.simiacryptus.mindseye.lang.Coordinate, java.util.Arrays, org.slf4j.Logger, org.slf4j.LoggerFactory, com.simiacryptus.mindseye.lang.Tensor, com.simiacryptus.mindseye.lang.Result, com.simiacryptus.mindseye.lang.DataSerializer, java.util.List, com.simiacryptus.mindseye.lang.LayerBase, java.util.function.ToDoubleFunction, com.simiacryptus.mindseye.lang.TensorList, java.util.Map, com.simiacryptus.mindseye.lang.Layer, com.simiacryptus.mindseye.lang.TensorArray, com.simiacryptus.mindseye.lang.DeltaSet, javax.annotation.Nonnull, javax.annotation.Nullable
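In BiasMetaLayer.eval the forward pass adds the bias tensor (the first item of inObj[1]) to every item of the batch; in the backward pass the input receives the incoming delta unchanged, while the bias receives the per-coordinate sum of the deltas over the batch, which is what the ToDoubleFunction<Coordinate> computes. A minimal sketch of that arithmetic follows, assuming plain double arrays instead of Tensor/TensorList and an Integer index instead of Coordinate; the zero tensors the original emits for additional bias items are not modeled.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;

public class BiasSketch {
  // Forward pass: add the same bias vector to every item of the batch.
  static double[][] addBias(double[][] batch, double[] bias) {
    double[][] out = new double[batch.length][];
    for (int i = 0; i < batch.length; i++) {
      out[i] = new double[bias.length];
      for (int c = 0; c < bias.length; c++) {
        out[i][c] = batch[i][c] + bias[c];
      }
    }
    return out;
  }

  // Backward pass for the bias: sum the deltas over the batch at each coordinate
  // (Integer stands in for Coordinate, an assumption for illustration).
  static double[] biasGradient(double[][] deltas) {
    ToDoubleFunction<Integer> f = c -> IntStream.range(0, deltas.length)
        .mapToDouble(i -> deltas[i][c])
        .sum();
    double[] grad = new double[deltas[0].length];
    for (int c = 0; c < grad.length; c++) {
      grad[c] = f.applyAsDouble(c);
    }
    return grad;
  }

  public static void main(String[] args) {
    double[][] batch = {{1, 2}, {3, 4}};
    System.out.println(Arrays.deepToString(addBias(batch, new double[] {10, 20}))); // [[11.0, 22.0], [13.0, 24.0]]
    System.out.println(Arrays.toString(biasGradient(batch))); // [4.0, 6.0]
  }
}
```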

Example 5 with ToDoubleFunction

Use of java.util.function.ToDoubleFunction in project MindsEye by SimiaCryptus.

Class SumMetaLayer, method eval.

@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final Result input = inObj[0];
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    final int itemCnt = input.getData().length();
    if (null == lastResult || minBatches < itemCnt) {
        @Nonnull final ToDoubleFunction<Coordinate> f = (c) -> IntStream.range(0, itemCnt).mapToDouble(dataIndex -> input.getData().get(dataIndex).get(c)).sum();
        lastResult = input.getData().get(0).mapCoords(f);
    }
    return new Result(TensorArray.wrap(lastResult), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (input.isAlive()) {
            @Nullable final Tensor delta = data.get(0);
            @Nonnull final Tensor[] feedback = new Tensor[itemCnt];
            Arrays.parallelSetAll(feedback, i -> new Tensor(delta.getDimensions()));
            @Nonnull final ToDoubleFunction<Coordinate> f = (inputCoord) -> {
                for (int inputItem = 0; inputItem < itemCnt; inputItem++) {
                    feedback[inputItem].add(inputCoord, delta.get(inputCoord));
                }
                return 0;
            };
            delta.mapCoords(f);
            @Nonnull TensorArray tensorArray = TensorArray.wrap(feedback);
            input.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return input.isAlive();
        }
    };
}
Also used: java.util.stream.IntStream, com.google.gson.JsonObject, com.simiacryptus.mindseye.lang.Coordinate, java.util.Arrays, org.slf4j.Logger, org.slf4j.LoggerFactory, com.simiacryptus.mindseye.lang.Tensor, com.simiacryptus.mindseye.lang.Result, com.simiacryptus.mindseye.lang.DataSerializer, java.util.List, com.simiacryptus.mindseye.lang.LayerBase, java.util.function.ToDoubleFunction, com.simiacryptus.mindseye.lang.TensorList, java.util.Map, com.simiacryptus.mindseye.lang.Layer, com.simiacryptus.mindseye.lang.TensorArray, com.simiacryptus.mindseye.lang.DeltaSet, javax.annotation.Nonnull, javax.annotation.Nullable
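SumMetaLayer.eval differs from the averaging layer in its backward pass: the derivative of a sum with respect to each item is 1, so every item receives an unchanged copy of delta. The original builds that feedback inside a ToDoubleFunction passed to mapCoords purely for its side effects (returning 0 and discarding the result). The sketch below, assuming plain double arrays in place of Tensor and an Integer index in place of Coordinate, builds the copies directly instead; it is an illustration of the arithmetic, not MindsEye code.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;

public class BatchSumSketch {
  // Forward pass: element-wise sum over the batch (Integer stands in for Coordinate).
  static double[] sum(double[][] batch) {
    ToDoubleFunction<Integer> f = c -> IntStream.range(0, batch.length)
        .mapToDouble(i -> batch[i][c])
        .sum();
    return IntStream.range(0, batch[0].length).mapToDouble(c -> f.applyAsDouble(c)).toArray();
  }

  // Backward pass: each item gets its own copy of delta, with no side-effecting mapCoords call.
  static double[][] sumGradient(double[] delta, int itemCnt) {
    double[][] feedback = new double[itemCnt][];
    Arrays.setAll(feedback, i -> delta.clone());
    return feedback;
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(sum(new double[][] {{1, 2}, {3, 4}}))); // [4.0, 6.0]
    System.out.println(Arrays.deepToString(sumGradient(new double[] {1, 2}, 2))); // [[1.0, 2.0], [1.0, 2.0]]
  }
}
```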

Aggregations

ToDoubleFunction (java.util.function.ToDoubleFunction): 9
List (java.util.List): 4
Logger (org.slf4j.Logger): 4
LoggerFactory (org.slf4j.LoggerFactory): 4
JsonObject (com.google.gson.JsonObject): 3
Coordinate (com.simiacryptus.mindseye.lang.Coordinate): 3
DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 3
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 3
Layer (com.simiacryptus.mindseye.lang.Layer): 3
LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 3
Result (com.simiacryptus.mindseye.lang.Result): 3
Tensor (com.simiacryptus.mindseye.lang.Tensor): 3
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 3
TensorList (com.simiacryptus.mindseye.lang.TensorList): 3
Arrays (java.util.Arrays): 3
Map (java.util.Map): 3
Function (java.util.function.Function): 3
IntStream (java.util.stream.IntStream): 3
Nonnull (javax.annotation.Nonnull): 3
Nullable (javax.annotation.Nullable): 3