Reshape with explicit memory layout and transpose with explicit permutation #474

Open · wants to merge 4 commits into base: master
10 changes: 10 additions & 0 deletions src/arraymancer/tensor/private/p_shapeshifting.nim
@@ -54,6 +54,16 @@ proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|MetadataArray, result: v
  else:
    reshape_with_copy(t, new_shape, result)

proc reshapeImplWithContig*(t: AnyTensor, new_shape: varargs[int]|MetadataArray, result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
  when compileOption("boundChecks"):
    when new_shape is MetadataArray:
      check_reshape(t, new_shape)
    else:
      check_reshape(t, new_shape.toMetadataArray)

  reshapeImpl(t.asContiguous(layout, force=true), new_shape, result)
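# A minimal usage sketch of this helper (names and values are hypothetical):
#
#   var r: Tensor[int]
#   let x = toSeq(1..6).toTensor().reshape(2, 3)
#   reshapeImplWithContig(x, [3, 2], r, colMajor)
#   # x is first forced contiguous in colMajor order (copying only if it is
#   # not already contiguous in that layout), then reshaped to [3, 2].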


proc broadcastImpl*(t: var AnyTensor, shape: varargs[int]|MetadataArray) {.noSideEffect.}=
  when compileOption("boundChecks"):
    assert t.rank == shape.len
59 changes: 59 additions & 0 deletions src/arraymancer/tensor/shapeshifting.nim
@@ -36,6 +36,39 @@ proc transpose*(t: Tensor): Tensor {.noInit,noSideEffect,inline.} =
  result.offset = t.offset
  result.storage = t.storage

proc transpose*(t: Tensor, axes: seq[int]): Tensor {.noInit,inline.} =
  ## Transpose a Tensor according to an explicit permutation of its axes.
  ## Negative axis indices count from the last axis.
  ##
  ## Data is not copied or modified, only metadata is modified.
  assert axes.len == t.rank

  let n = axes.len

  var
    perm = newSeqWith(t.rank, 0)
    mrep = newSeqWith(t.rank, -1)
    new_shape = t.shape
    new_strides = t.strides

  # Normalize negative axes and reject out-of-bounds or repeated ones
  for i in 0 ..< n:
    var axis = axes[i]
    if axis < 0:
      axis += t.rank
    assert axis >= 0 and axis < t.rank, "Out of bounds axis for the Tensor"
    assert mrep[axis] == -1, "Axes cannot be repeated"
    mrep[axis] = i
    perm[i] = axis

  # Permute shape and strides; the underlying storage is shared
  for i in 0 ..< n:
    new_shape[i] = t.shape[perm[i]]
    new_strides[i] = t.strides[perm[i]]

  result.shape = new_shape
  result.strides = new_strides
  result.offset = t.offset
  result.storage = t.storage
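# A minimal usage sketch, assuming a row-major tensor (values hypothetical):
#
#   let t = toSeq(1..24).toTensor().reshape(2, 3, 4)  # strides [12, 4, 1]
#   let u = t.transpose(@[2, 0, 1])                   # shape [4, 2, 3], strides [1, 12, 4]
#   let v = t.transpose(@[-1, 0, 1])                  # same as u: -1 maps to axis 2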

proc asContiguous*[T](t: Tensor[T], layout: OrderType = rowMajor, force: bool = false): Tensor[T] {.noInit.} =
  ## Transform a tensor with general striding to a Tensor with contiguous layout.
  ##
@@ -69,6 +102,19 @@ proc reshape*(t: Tensor, new_shape: varargs[int]): Tensor {.noInit.} =
  ## - a tensor with the same data but reshaped.
  reshapeImpl(t, new_shape, result)

proc reshape*(t: Tensor, new_shape: varargs[int], layout: OrderType): Tensor {.noInit.} =
  ## Reshape a tensor using an explicit memory layout. If possible, no data
  ## copy is done and the returned tensor shares data with the input. If the
  ## input is not contiguous in the requested layout, this is not possible
  ## and a copy will be made.
  ##
  ## Input:
  ## - a tensor
  ## - a new shape. Number of elements must be the same
  ## - a memory layout to use when reshaping the data
  ## Returns:
  ## - a tensor with the same data but reshaped.
  reshapeImplWithContig(t, new_shape, result, layout)
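# A minimal usage sketch (hypothetical values):
#
#   let t = toSeq(1..6).toTensor().reshape(2, 3)
#   echo t.reshape(3, 2)            # [[1, 2], [3, 4], [5, 6]] -- row-major walk
#   echo t.reshape(3, 2, colMajor)  # [[1, 5], [4, 3], [2, 6]] -- column-major walk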

proc reshape*(t: Tensor, new_shape: MetadataArray): Tensor {.noInit.} =
  ## Reshape a tensor. If possible, no data copy is done and the returned tensor
  ## shares data with the input. If the input is not contiguous, this is not possible
@@ -81,6 +127,19 @@ proc reshape*(t: Tensor, new_shape: MetadataArray): Tensor {.noInit.} =
  ## - a tensor with the same data but reshaped.
  reshapeImpl(t, new_shape, result)

proc reshape*(t: Tensor, new_shape: MetadataArray, layout: OrderType): Tensor {.noInit.} =
  ## Reshape a tensor using an explicit memory layout. If possible, no data
  ## copy is done and the returned tensor shares data with the input. If the
  ## input is not contiguous in the requested layout, this is not possible
  ## and a copy will be made.
  ##
  ## Input:
  ## - a tensor
  ## - a new shape. Number of elements must be the same
  ## - a memory layout to use when reshaping the data
  ## Returns:
  ## - a tensor with the same data but reshaped.
  reshapeImplWithContig(t, new_shape, result, layout)

proc broadcast*[T](t: Tensor[T], shape: varargs[int]): Tensor[T] {.noInit,noSideEffect.}=
  ## Explicitly broadcast a tensor to the specified shape.
  ##
27 changes: 27 additions & 0 deletions tests/tensor/test_shapeshifting.nim
@@ -64,6 +64,33 @@ testSuite "Shapeshifting":
    check: a == [[1,2],
                 [3,4]].toTensor()

test "Reshape with explicit order":
let a = toSeq(1..12).toTensor().reshape(3, 2, 2).asContiguous(rowMajor, force = true)
let b = toSeq(1..12).toTensor().reshape(3, 2, 2).asContiguous(colMajor, force = true)
check: a == b
# Default behavior is respecting memory layouts when reshaping
check: a.reshape(6, 2) != b.reshape(6, 2)

# Explicit ordering will reshape using the same memory layout
check: a.reshape(6, 2, colMajor) == b.reshape(6, 2)
check: a.reshape(6, 2) == b.reshape(6, 2, rowMajor)

test "Transpose with explicit permutation":
let a = toSeq(1..6).toTensor().reshape(1, 2, 3)
let b = a.transpose(@[0, 2, 1])
let c = a.transpose(@[2, 0, 1])
# Check different permutations other than a full transpose

let expected_b = @[1, 4, 2, 5, 3, 6].toTensor().reshape(1, 3, 2)
check: b == expected_b
check: b.shape == [1, 3, 2]
check: b.strides == [6, 1, 3]

let expected_c = @[1, 4, 2, 5, 3, 6].toTensor().reshape(3, 1, 2)
check: c == expected_c
check: c.shape == [3, 1, 2]
check: c.strides == [1, 6, 3]
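    # Strides are permuted rather than recomputed: a.strides == [6, 3, 1],
    # so c.strides == [a.strides[2], a.strides[0], a.strides[1]] == [1, 6, 3]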

test "Unsafe reshape":
block:
let a = toSeq(1..4).toTensor()