Skip to content

Commit

Permalink
some utility methods added for shape handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kaifox committed Feb 9, 2017
1 parent 65c0c71 commit 9bac4ca
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 54 deletions.
27 changes: 20 additions & 7 deletions src/java/org/tensorics/core/lang/Tensorics.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

package org.tensorics.core.lang;

import static org.tensorics.core.tensor.operations.PositionFunctions.forSupplier;

import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
Expand All @@ -46,9 +44,9 @@
import org.tensorics.core.tensor.lang.OngoingTensorManipulation;
import org.tensorics.core.tensor.lang.QuantityTensors;
import org.tensorics.core.tensor.lang.TensorStructurals;
import org.tensorics.core.tensor.operations.FunctionTensorCreationOperation;
import org.tensorics.core.tensor.operations.SingleValueTensorCreationOperation;
import org.tensorics.core.tensor.operations.TensorInternals;
import org.tensorics.core.tensor.stream.TensorStreams;
import org.tensorics.core.tensorbacked.OngoingTensorbackedCompletion;
import org.tensorics.core.tensorbacked.Tensorbacked;
import org.tensorics.core.tensorbacked.TensorbackedBuilder;
import org.tensorics.core.tensorbacked.Tensorbackeds;
Expand Down Expand Up @@ -303,15 +301,15 @@ public static <S> OngoingFlattening<S> flatten(Tensorbacked<S> tensorbacked) {
}

public static <S> Tensor<S> sameValues(Shape shape, S value) {
return new SingleValueTensorCreationOperation<S>(shape, value).perform();
return TensorInternals.sameValues(shape, value);
}

public static <S> Tensor<S> createFrom(Shape shape, Supplier<S> supplier) {
return new FunctionTensorCreationOperation<>(shape, forSupplier(supplier)).perform();
return TensorInternals.createFrom(shape, supplier);
}

public static <S> Tensor<S> createFrom(Shape shape, Function<Position, S> function) {
return new FunctionTensorCreationOperation<>(shape, function).perform();
return TensorInternals.createFrom(shape, function);
}

public static <S> OngoingCompletion<S> complete(Tensor<S> tensor) {
Expand Down Expand Up @@ -407,4 +405,19 @@ public static <S> Stream<Map.Entry<Position, S>> stream(Tensor<S> tensor) {
public static <S> Stream<Map.Entry<Position, S>> stream(Tensorbacked<S> tensorBacked) {
return TensorStreams.tensorEntryStream(tensorBacked.tensor());
}

/**
* @see Tensorbackeds#shapesOf(Tensorbacked)
*/
public static <TB extends Tensorbacked<?>> Iterable<Shape> shapesOf(Iterable<TB> tensorbackeds) {
return Tensorbackeds.shapesOf(tensorbackeds);
}

/**
* @see Tensorbackeds#complete(Tensorbacked)
*/
public static <S, TB extends Tensorbacked<S>> OngoingTensorbackedCompletion<TB, S> complete(TB tensorbacked) {
return Tensorbackeds.complete(tensorbacked);
}

}
97 changes: 86 additions & 11 deletions src/java/org/tensorics/core/tensor/Shapes.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Collections2.transform;
import static com.google.common.collect.Sets.union;
import static java.util.Objects.requireNonNull;

import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;

import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;

/**
Expand All @@ -48,20 +51,59 @@ private Shapes() {

/**
* Creates a shape, containing all the positions that are contained in both given shapes. This only makes sense, if
* the dimensions of the two sets are the same. If they are not, then an {@link IllegalArgumentException} is thrown.
* the dimensions of the two shapes are the same. If they are not, then an {@link IllegalArgumentException} is
* thrown.
*
* @param left the first shape from which to take positions
* @param right the second shape from which to take positions
* @return a shape containing all positions, which are contained in both given shapes
*/
public static Shape intersection(Shape left, Shape right) {
checkLeftRightNotNull(left, right);
if (!left.hasSameDimensionsAs(right)) {
throw new IllegalArgumentException("The two shapes do not have the same dimension, "
+ "therefore the intersection of coordinates cannot be determined. Left dimensions: "
+ left.dimensionSet() + "; Right dimensions: " + right.dimensionSet());
}
return Shape.of(Sets.intersection(left.positionSet(), right.positionSet()));
return combineLeftRightBy(left, right, Sets::intersection);
}

/**
* Creates a shape, containing all the positions that are either contained in the left or the right shape.This only
* makes sense, if the dimensions of the two shapes are the same. If they are not, then an
* {@link IllegalArgumentException} is thrown.
*
* @param left the first shape from which to take positions
* @param right the second shape from which to take positions
* @return a shape containing all positions, which are contained in at least one of the two shapes
*/
public static Shape union(Shape left, Shape right) {
return combineLeftRightBy(left, right, Sets::union);
}

/**
* Creates a shape, containing all the positions that are contained at least in one of the given shapes. This only
* makes sense, if the dimensions of the shapes are the same. If they are not, then an
* {@link IllegalArgumentException} is thrown. Further, it is required that at least one element is contained in the
* iterable.
*
* @param shapes the shapes for which the union shall be found
* @return a shape which represents the union of all the shapes
* @throws IllegalArgumentException if the shapes are not of the same dimension
* @throws NoSuchElementException in case the iterable is empty
* @throws NullPointerException if the given iterable is {@code null}
*/
public static final Shape union(Iterable<Shape> shapes) {
return combineBy(shapes, Shapes::union);
}

/**
* Creates a shape, containing the positions which are contained in each of the given shapes. This only makes sense,
* if the dimensions of the shapes are the same. If they are not, then an {@link IllegalArgumentException} is
* thrown. Further, it is required that at least one element is contained in the iterable.
*
* @param shapes the shapes for which the intersection shall be found
* @return a shape which represents the intersection of all the shapes
* @throws IllegalArgumentException if the shapes are not of the same dimension
* @throws NoSuchElementException in case the iterable is empty
* @throws NullPointerException if the given iterable is {@code null}
*/
public static final Shape intersection(Iterable<Shape> shapes) {
return combineBy(shapes, Shapes::intersection);
}

/**
Expand Down Expand Up @@ -89,7 +131,8 @@ public static Set<Class<?>> dimensionalIntersection(Shape left, Shape right) {
public static Shape dimensionStripped(Shape shape, Set<? extends Class<?>> dimensionsToStrip) {
checkNotNull(shape, "shape must not be null");
checkNotNull(dimensionsToStrip, "dimensions must not be null");
return Shape.of(Positions.unique(transform(shape.positionSet(), toGuavaFunction(Positions.stripping(dimensionsToStrip)))));
return Shape.of(Positions
.unique(transform(shape.positionSet(), toGuavaFunction(Positions.stripping(dimensionsToStrip)))));
}

/**
Expand All @@ -107,7 +150,7 @@ public static Shape dimensionStripped(Shape shape, Set<? extends Class<?>> dimen
public static Shape outerProduct(Shape left, Shape right) {
checkArgument(dimensionalIntersection(left, right).isEmpty(), "The two shapes have "
+ "overlapping dimensions. The outer product is not foreseen to be used in this situation.");
Shape.Builder builder = Shape.builder(union(left.dimensionSet(), right.dimensionSet()));
Shape.Builder builder = Shape.builder(Sets.union(left.dimensionSet(), right.dimensionSet()));
for (Position leftPosition : left.positionSet()) {
for (Position rightPosition : right.positionSet()) {
builder.add(Positions.union(leftPosition, rightPosition));
Expand Down Expand Up @@ -137,4 +180,36 @@ public R apply(T input) {
}
};
}

private static void checkLeftRightSameDimensions(Shape left, Shape right) {
if (!left.hasSameDimensionsAs(right)) {
throw new IllegalArgumentException("The two shapes do not have the same dimension, "
+ "therefore combining of positions does not make sense. Left dimensions: " + left.dimensionSet()
+ "; Right dimensions: " + right.dimensionSet());
}
}

private static Shape combineLeftRightBy(Shape left, Shape right,
BiFunction<Set<Position>, Set<Position>, Set<Position>> combiner) {
checkLeftRightNotNull(left, right);
checkLeftRightSameDimensions(left, right);
return Shape.of(combiner.apply(left.positionSet(), right.positionSet()));
}

private static Shape combineBy(Iterable<Shape> shapes, BiFunction<Shape, Shape, Shape> combiner) {
requireNonNull(shapes, "shapes must not be null");
if (Iterables.isEmpty(shapes)) {
throw new NoSuchElementException("At least one shape is required.");
}
Shape resultingShape = null;
for (Shape shape : shapes) {
if (shape == null) {
resultingShape = shape;
} else {
resultingShape = combiner.apply(resultingShape, shape);
}
}
return resultingShape;
}

}
30 changes: 5 additions & 25 deletions src/java/org/tensorics/core/tensor/lang/OngoingCompletion.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,9 @@

package org.tensorics.core.tensor.lang;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.Map.Entry;
import java.util.Set;

import org.tensorics.core.tensor.ImmutableTensor;
import org.tensorics.core.tensor.ImmutableTensor.Builder;
import org.tensorics.core.tensor.Position;
import org.tensorics.core.tensor.Shape;
import org.tensorics.core.tensor.Tensor;
import org.tensorics.core.tensor.operations.TensorInternals;

import com.google.common.base.Preconditions;

Expand All @@ -44,24 +37,11 @@ public class OngoingCompletion<S> {
}

public Tensor<S> with(Tensor<S> second) {
checkNotNull(second, "second tensor must not be null");
checkArgument(second.shape().dimensionSet().equals(dimensions()),
"Tensors do not have the same dimensions! Completion not supported!");
Builder<S> builder = ImmutableTensor.builder(dimensions());
builder.context(tensor.context());
for (Entry<Position, S> entry: second.asMap().entrySet()) {
Position position = entry.getKey();
if (tensor.shape().contains(position)) {
builder.putAt(tensor.get(position), position);
} else {
builder.put(entry);
}
}
return builder.build();
return TensorStructurals.completeWith(tensor, second);
}

private Set<Class<?>> dimensions() {
return tensor.shape().dimensionSet();
public Tensor<S> with(Shape shape, S value) {
return with(TensorInternals.sameValues(shape, value));
}

}
23 changes: 23 additions & 0 deletions src/java/org/tensorics/core/tensor/lang/TensorStructurals.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

package org.tensorics.core.tensor.lang;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;
Expand All @@ -30,6 +33,8 @@
import org.tensorics.core.tensor.ImmutableTensor;
import org.tensorics.core.tensor.ImmutableTensor.Builder;
import org.tensorics.core.tensor.Position;
import org.tensorics.core.tensor.Shape;
import org.tensorics.core.tensor.Shapes;
import org.tensorics.core.tensor.Tensor;

import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -163,4 +168,22 @@ public static final <S> OngoingTensorFiltering<S> filter(Tensor<S> tensor) {
return new OngoingTensorFiltering<>(tensor);
}

public static final <S> Tensor<S> completeWith(Tensor<S> tensor, Tensor<S> second) {
checkNotNull(second, "second tensor must not be null");
checkArgument(second.shape().dimensionSet().equals(tensor.shape().dimensionSet()),
"Tensors do not have the same dimensions! Completion not supported!");
Builder<S> builder = ImmutableTensor.builder(tensor.shape().dimensionSet());
builder.context(tensor.context());

Shape shape = Shapes.union(tensor.shape(), second.shape());
for (Position position : shape.positionSet()) {
if (tensor.shape().contains(position)) {
builder.putAt(tensor.get(position), position);
} else {
builder.putAt(second.get(position), position);
}
}
return builder.build();
}

}
17 changes: 17 additions & 0 deletions src/java/org/tensorics/core/tensor/operations/TensorInternals.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@

package org.tensorics.core.tensor.operations;

import static org.tensorics.core.tensor.operations.PositionFunctions.forSupplier;

import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import org.tensorics.core.tensor.Position;
import org.tensorics.core.tensor.Shape;
import org.tensorics.core.tensor.Tensor;

/**
Expand All @@ -47,4 +52,16 @@ public static <T> Set<Entry<Position, T>> entrySetOf(Tensor<T> tensor) {
return tensor.asMap().entrySet();
}

public static <S> Tensor<S> sameValues(Shape shape, S value) {
return new SingleValueTensorCreationOperation<S>(shape, value).perform();
}

public static <S> Tensor<S> createFrom(Shape shape, Supplier<S> supplier) {
return new FunctionTensorCreationOperation<>(shape, forSupplier(supplier)).perform();
}

public static <S> Tensor<S> createFrom(Shape shape, Function<Position, S> function) {
return new FunctionTensorCreationOperation<>(shape, function).perform();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// @formatter:off
/*******************************************************************************
*
* This file is part of tensorics.
*
* Copyright (c) 2008-2011, CERN. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
// @formatter:on

package org.tensorics.core.tensorbacked;

import static com.google.common.base.Preconditions.checkNotNull;
import static org.tensorics.core.tensor.lang.TensorStructurals.completeWith;
import static org.tensorics.core.tensorbacked.TensorbackedInternals.createBackedByTensor;

import org.tensorics.core.tensor.Shape;
import org.tensorics.core.tensor.Tensor;
import org.tensorics.core.tensor.operations.TensorInternals;

public class OngoingTensorbackedCompletion<TB extends Tensorbacked<S>, S> {

private final TB tensorbacked;

OngoingTensorbackedCompletion(TB tensorbacked) {
this.tensorbacked = checkNotNull(tensorbacked, "tensorbacked must not be null");
}

public TB with(Tensor<S> second) {
Tensor<S> tensor = completeWith(tensorbacked.tensor(), second);
return createBackedByTensor(TensorbackedInternals.classOf(tensorbacked), tensor);
}

public TB with(TB second) {
return with(second.tensor());
}

public TB with(Shape shape, S value) {
return with(TensorInternals.sameValues(shape, value));
}

}
Loading

0 comments on commit 9bac4ca

Please sign in to comment.