Skip to content

Commit

Permalink
Using ArrayLikeXyzNode in SortVectorNode
Browse files Browse the repository at this point in the history
  • Loading branch information
JaroslavTulach committed Aug 30, 2023
1 parent b49cc25 commit 7439403
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
package org.enso.interpreter.node.expression.builtin.ordering;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.Node;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -26,7 +14,6 @@
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.callable.dispatch.CallOptimiserNode;
import org.enso.interpreter.node.callable.resolver.MethodResolverNode;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode;
Expand All @@ -36,7 +23,9 @@
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.data.vector.ArrayLikeAtNode;
import org.enso.interpreter.runtime.data.vector.ArrayLikeHelpers;
import org.enso.interpreter.runtime.data.vector.ArrayLikeLengthNode;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.Warning;
Expand All @@ -45,6 +34,18 @@
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
import org.enso.interpreter.runtime.state.State;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.Node;

/**
* Sorts a vector with elements that have only Default_Comparator, thus, only elements with a
* builtin type, which is the most common scenario for sorting.
Expand Down Expand Up @@ -97,7 +98,7 @@ public abstract Object execute(
@Specialization(
guards = {
"interop.hasArrayElements(self)",
"areAllDefaultComparators(interop, hostValueToEnsoNode, comparators)",
"areAllDefaultComparators(lengthNode, atNode, comparators)",
"interop.isNull(byFunc)",
"interop.isNull(onFunc)"
}, limit = "3")
Expand All @@ -112,32 +113,23 @@ Object sortPrimitives(
long problemBehavior,
@Shared("lessThanNode") @Cached LessThanNode lessThanNode,
@Shared("equalsNode") @Cached EqualsNode equalsNode,
@Cached HostValueToEnsoNode hostValueToEnsoNode,
@Shared("lengthNode") @Cached ArrayLikeLengthNode lengthNode,
@Shared("atNode") @Cached ArrayLikeAtNode atNode,
@Shared("typeOfNode") @Cached TypeOfNode typeOfNode,
@Shared("anyToTextNode") @Cached AnyToTextNode toTextNode,
@Shared("interop") @CachedLibrary(limit = "10") InteropLibrary interop) {
EnsoContext ctx = EnsoContext.get(this);
Object[] elems;
long longSize = 0L;
try {
long size = interop.getArraySize(self);
assert size < Integer.MAX_VALUE;
elems = new Object[(int) size];
longSize = lengthNode.executeLength(self);
int size = Math.toIntExact(longSize);
elems = new Object[size];
for (int i = 0; i < size; i++) {
if (interop.isArrayElementReadable(self, i)) {
elems[i] = hostValueToEnsoNode.execute(interop.readArrayElement(self, i));
} else {
CompilerDirectives.transferToInterpreter();
throw new PanicException(
ctx.getBuiltins()
.error()
.makeUnsupportedArgumentsError(
new Object[] {self},
"Cannot read array element at index " + i + " of " + self),
this);
}
elems[i] = atNode.executeAt(self, i);
}
} catch (UnsupportedMessageException | InvalidArrayIndexException e) {
throw new IllegalStateException("Should not reach here", e);
} catch (ArithmeticException | InvalidArrayIndexException e) {
throw invalidArrayIndexException(e, longSize);
}
var javaComparator =
createDefaultComparator(
Expand Down Expand Up @@ -189,17 +181,18 @@ Object sortGeneric(
@Shared("lessThanNode") @Cached LessThanNode lessThanNode,
@Shared("equalsNode") @Cached EqualsNode equalsNode,
@Shared("typeOfNode") @Cached TypeOfNode typeOfNode,
@Shared("lengthNode") @Cached ArrayLikeLengthNode lengthNode,
@Shared("atNode") @Cached ArrayLikeAtNode atNode,
@Shared("anyToTextNode") @Cached AnyToTextNode toTextNode,
@Cached MethodResolverNode methodResolverNode,
@Cached(value = "build()", uncached = "build()") HostValueToEnsoNode hostValueToEnsoNode,
@Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) {
var problemBehavior = ProblemBehavior.fromInt((int) problemBehaviorNum);
// Split into groups
List<Object> elems = readInteropArray(interop, hostValueToEnsoNode, warningsLib, self);
List<Object> elems = readInteropArray(lengthNode, atNode, warningsLib, self);
List<Type> comparators =
readInteropArray(interop, hostValueToEnsoNode, warningsLib, comparatorsArray);
readInteropArray(lengthNode, atNode, warningsLib, comparatorsArray);
List<Function> compareFuncs =
readInteropArray(interop, hostValueToEnsoNode, warningsLib, compareFuncsArray);
readInteropArray(lengthNode, atNode, warningsLib, compareFuncsArray);
List<Group> groups = splitByComparators(elems, comparators, compareFuncs);

// Prepare input for DefaultSortComparator and GenericSortComparator and sort the elements
Expand Down Expand Up @@ -388,23 +381,25 @@ private Object incomparableValuesError(Object left, Object right) {
*/
@SuppressWarnings("unchecked")
private <T> List<T> readInteropArray(
InteropLibrary interop,
HostValueToEnsoNode hostValueToEnsoNode,
ArrayLikeLengthNode lengthNode,
ArrayLikeAtNode atNode,
WarningsLibrary warningsLib,
Object vector) {
var longSize = 0L;
try {
int size = (int) interop.getArraySize(vector);
longSize = lengthNode.executeLength(vector);
int size = Math.toIntExact(longSize);
List<T> res = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
Object elem = hostValueToEnsoNode.execute(interop.readArrayElement(vector, i));
Object elem = atNode.executeAt(vector, i);
if (warningsLib.hasWarnings(elem)) {
elem = warningsLib.removeWarnings(elem);
}
res.add((T) elem);
}
return res;
} catch (UnsupportedMessageException | InvalidArrayIndexException | ClassCastException e) {
throw new IllegalStateException("Should not be reachable", e);
throw invalidArrayIndexException(e, longSize);
}
}

Expand Down Expand Up @@ -454,20 +449,21 @@ private boolean isTrue(Object object) {

/** Returns true iff the given array of comparators is all Default_Comparator */
boolean areAllDefaultComparators(
InteropLibrary interop, HostValueToEnsoNode hostValueToEnsoNode, Object comparators) {
assert interop.hasArrayElements(comparators);
ArrayLikeLengthNode lengthNode, ArrayLikeAtNode atNode, Object comparators
) {
var ctx = EnsoContext.get(this);
var longSize = 0L;
try {
int compSize = (int) interop.getArraySize(comparators);
longSize = lengthNode.executeLength(comparators);
int compSize = (int) longSize;
for (int i = 0; i < compSize; i++) {
assert interop.isArrayElementReadable(comparators, i);
Object comparator = hostValueToEnsoNode.execute(interop.readArrayElement(comparators, i));
Object comparator = atNode.executeAt(comparators, i);
if (!isDefaultComparator(comparator, ctx)) {
return false;
}
}
} catch (UnsupportedMessageException | InvalidArrayIndexException e) {
throw new IllegalStateException("Should not be reachable", e);
} catch (ArithmeticException | InvalidArrayIndexException e) {
throw invalidArrayIndexException(e, longSize);
}
return true;
}
Expand All @@ -480,6 +476,13 @@ private boolean isNan(Object object) {
return object instanceof Double dbl && dbl.isNaN();
}

private PanicException invalidArrayIndexException(Exception e, long size) {
var index = e instanceof InvalidArrayIndexException ex ? ex.getInvalidIndex() : 0L;
var ctx = EnsoContext.get(this);
var payload = ctx.getBuiltins().error().makeInvalidArrayIndex(index, size);
throw new PanicException(payload, this);
}

private enum ProblemBehavior {
IGNORE,
REPORT_WARNING,
Expand Down Expand Up @@ -618,7 +621,7 @@ private int handleIncomparableValues(Object x, Object y) {
} else if (isPrimitiveValue(y)) {
return ascending ? 1 : -1;
} else {
throw new IllegalStateException("Should not be reachable");
throw CompilerDirectives.shouldNotReachHere();
}
} else {
// Values other than primitives are compared just by their type's FQN.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ public static ArrayLikeAtNode create() {
return ArrayLikeAtNodeGen.create();
}

@NeverDefault
public static ArrayLikeAtNode getUncached() {
return ArrayLikeAtNodeGen.getUncached();
}

//
// implementation
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public static ArrayLikeLengthNode create() {
return ArrayLikeLengthNodeGen.create();
}

@NeverDefault
public static ArrayLikeLengthNode getUncached() {
return ArrayLikeLengthNodeGen.getUncached();
}

//
// implementation
//
Expand Down

0 comments on commit 7439403

Please sign in to comment.