Skip to content

Commit

Permalink
Update resource placement and transfer for barrier operations (#19725)
Browse files Browse the repository at this point in the history
Barriers indicate within device blocking. The results of a barrier
should not transfer to another location, otherwise there would be a
transfer and not a barrier.

---------

Co-authored-by: Ben Vanik <[email protected]>
  • Loading branch information
rsuderman and benvanik authored Jan 17, 2025
1 parent 75c9e86 commit f31cc72
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
DFX::Resolution::REQUIRED);
getState() ^= targetUsage.getState();
})
.Case([&](IREE::Stream::AsyncBarrierOp op) {
auto &tiedUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getOperand(0)),
DFX::Resolution::REQUIRED);
getState() ^= tiedUsage.getState();
})
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_WRITE);
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
Expand Down Expand Up @@ -716,6 +722,12 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
getState() ^= resultUsage.getState();
}
})
.Case([&](IREE::Stream::AsyncBarrierOp op) {
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getResult()),
DFX::Resolution::OPTIONAL);
getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_READ);
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ util.func public @tensorSplat(%value: i8, %dim0: index) -> tensor<?x128xi8> {

util.global private @device : !hal.device

// CHECK-LABEL: @tensorBarrierDispatch
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
util.func public @tensorBarrierDispatch(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
%c0 = arith.constant 0 : index
%barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
%0 = flow.dispatch @ex::@entry[%c0](%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}

// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[BARRIER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[DIM0]]} -> !stream.resource<*>
// CHECK: %[[C0_2:.+]] = arith.constant 0 : index
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device>) tensor<?x128xi8>{%arg2} : index
// CHECK: %[[DISP:.+]] = stream.async.dispatch on(#hal.device.affinity<@device>) @ex::@entry[%[[C0]]](%[[BARRIER]][%[[C0_2]] to %[[DIM0]] for %[[DIM0]]])
// CHECK: util.return %[[DISP]], %[[SIZE]]
util.return %0 : tensor<?x128xi8>
}

// -----

util.global private @device : !hal.device

// CHECK-LABEL: @tensorTransfer
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index)
util.func public @tensorTransfer(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ ConvertedTensor transferTensorOperands(
Value resource = convertedOperand[0];
Value resourceSize = convertedOperand[1];
auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand);
if (affinityAttr != requiredAffinityAttr) {
bool isBarrier = resource.getDefiningOp() &&
isa<IREE::Stream::AsyncBarrierOp>(resource.getDefiningOp());
if (affinityAttr != requiredAffinityAttr && !isBarrier) {
resource = builder.create<IREE::Stream::AsyncTransferOp>(
loc, resource.getType(), resource, resourceSize, resourceSize,
affinityAttr, requiredAffinityAttr);
Expand Down

0 comments on commit f31cc72

Please sign in to comment.