Skip to content

Commit

Permalink
YQL-17542 move TaskRunner dependent Execute to TDqSyncComputeActorBase (
Browse files Browse the repository at this point in the history
  • Loading branch information
zverevgeny authored Feb 8, 2024
1 parent 4b4be95 commit 112425d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 52 deletions.
54 changes: 26 additions & 28 deletions ydb/library/yql/dq/actors/compute/dq_compute_actor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1458,33 +1458,32 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
this->RegisterWithSameMailbox(source.Actor);
}
for (auto& [inputIndex, transform] : InputTransformsMap) {
Y_ABORT_UNLESS(TaskRunner);
transform.ProgramBuilder.ConstructInPlace(typeEnv, *FunctionRegistry);
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& inputDesc = Task.GetInputs(inputIndex);
CA_LOG_D("Create transform for input " << inputIndex << " " << inputDesc.ShortDebugString());
try {
std::tie(transform.AsyncInput, transform.Actor) = AsyncIoFactory->CreateDqInputTransform(
IDqAsyncIoFactory::TInputTransformArguments {
.InputDesc = inputDesc,
.InputIndex = inputIndex,
.StatsLevel = collectStatsLevel,
.TxId = TxId,
.TaskId = Task.GetId(),
.TransformInput = transform.InputBuffer,
.SecureParams = secureParams,
.TaskParams = taskParams,
.ComputeActorId = this->SelfId(),
.TypeEnv = typeEnv,
.HolderFactory = holderFactory,
.ProgramBuilder = *transform.ProgramBuilder,
.Alloc = Alloc,
.TraceId = ComputeActorSpan.GetTraceId()
});
} catch (const std::exception& ex) {
throw yexception() << "Failed to create input transform " << inputDesc.GetTransform().GetType() << ": " << ex.what();
}
this->RegisterWithSameMailbox(transform.Actor);
transform.ProgramBuilder.ConstructInPlace(typeEnv, *FunctionRegistry);
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& inputDesc = Task.GetInputs(inputIndex);
CA_LOG_D("Create transform for input " << inputIndex << " " << inputDesc.ShortDebugString());
try {
std::tie(transform.AsyncInput, transform.Actor) = AsyncIoFactory->CreateDqInputTransform(
IDqAsyncIoFactory::TInputTransformArguments {
.InputDesc = inputDesc,
.InputIndex = inputIndex,
.StatsLevel = collectStatsLevel,
.TxId = TxId,
.TaskId = Task.GetId(),
.TransformInput = transform.InputBuffer,
.SecureParams = secureParams,
.TaskParams = taskParams,
.ComputeActorId = this->SelfId(),
.TypeEnv = typeEnv,
.HolderFactory = holderFactory,
.ProgramBuilder = *transform.ProgramBuilder,
.Alloc = Alloc,
.TraceId = ComputeActorSpan.GetTraceId()
});
} catch (const std::exception& ex) {
throw yexception() << "Failed to create input transform " << inputDesc.GetTransform().GetType() << ": " << ex.what();
}
this->RegisterWithSameMailbox(transform.Actor);
}
for (auto& [outputIndex, transform] : OutputTransformsMap) {
transform.ProgramBuilder.ConstructInPlace(typeEnv, *FunctionRegistry);
Expand Down Expand Up @@ -2031,7 +2030,6 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
const IDqAsyncIoFactory::TPtr AsyncIoFactory;
const NKikimr::NMiniKQL::IFunctionRegistry* FunctionRegistry = nullptr;
const NDqProto::ECheckpointingMode CheckpointingMode;
TIntrusivePtr<IDqTaskRunner> TaskRunner;
TDqComputeActorChannels* Channels = nullptr;
TDqComputeActorCheckpoints* Checkpoints = nullptr;
THashMap<ui64, TInputChannelInfo> InputChannelsMap; // Channel id -> Channel info
Expand Down
50 changes: 26 additions & 24 deletions ydb/library/yql/dq/actors/compute/dq_sync_compute_actor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
auto sourcesState = static_cast<TDerived*>(this)->GetSourcesState();

TBase::PollAsyncInput();
ERunStatus status = this->TaskRunner->Run();
ERunStatus status = TaskRunner->Run();

CA_LOG_T("Resume execution, run status: " << status);

Expand All @@ -65,13 +65,13 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
}

void DoTerminateImpl() override {
this->TaskRunner.Reset();
TaskRunner.Reset();
}

void InvalidateMeminfo() override {
if (this->TaskRunner) {
this->TaskRunner->GetAllocator().InvalidateMemInfo();
this->TaskRunner->GetAllocator().DisableStrictAllocationCheck();
if (TaskRunner) {
TaskRunner->GetAllocator().InvalidateMemInfo();
TaskRunner->GetAllocator().DisableStrictAllocationCheck();
}
}

Expand All @@ -81,7 +81,7 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
mkqlProgramState.SetRuntimeVersion(NDqProto::RUNTIME_VERSION_YQL_1_0);
NDqProto::TStateData::TData& data = *mkqlProgramState.MutableData()->MutableStateData();
data.SetVersion(TDqComputeActorCheckpoints::ComputeActorCurrentStateVersion);
data.SetBlob(this->TaskRunner->Save());
data.SetBlob(TaskRunner->Save());

for (auto& [inputIndex, source] : this->SourcesMap) {
YQL_ENSURE(source.AsyncInput, "Source[" << inputIndex << "] is not created");
Expand All @@ -94,19 +94,19 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
void DoLoadRunnerState(TString&& blob) override {
TMaybe<TString> error = Nothing();
try {
this->TaskRunner->Load(blob);
TaskRunner->Load(blob);
} catch (const std::exception& e) {
error = e.what();
}
this->Checkpoints->AfterStateLoading(error);
}

void SetTaskRunner(const TIntrusivePtr<IDqTaskRunner>& taskRunner) {
this->TaskRunner = taskRunner;
TaskRunner = taskRunner;
}

void PrepareTaskRunner(const IDqTaskRunnerExecutionContext& execCtx) {
YQL_ENSURE(this->TaskRunner);
YQL_ENSURE(TaskRunner);

auto guard = TBase::BindAllocator();
auto* alloc = guard.GetMutex();
Expand All @@ -118,49 +118,49 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
limits.ChannelBufferSize = this->MemoryLimits.ChannelBufferSize;
limits.OutputChunkMaxSize = GetDqExecutionSettings().FlowControl.MaxOutputChunkSize;

this->TaskRunner->Prepare(this->Task, limits, execCtx);
TaskRunner->Prepare(this->Task, limits, execCtx);

for (auto& [channelId, channel] : this->InputChannelsMap) {
channel.Channel = this->TaskRunner->GetInputChannel(channelId);
channel.Channel = TaskRunner->GetInputChannel(channelId);
}

for (auto& [inputIndex, source] : this->SourcesMap) {
source.Buffer = this->TaskRunner->GetSource(inputIndex);
source.Buffer = TaskRunner->GetSource(inputIndex);
Y_ABORT_UNLESS(source.Buffer);
}

for (auto& [inputIndex, transform] : this->InputTransformsMap) {
std::tie(transform.InputBuffer, transform.Buffer) = this->TaskRunner->GetInputTransform(inputIndex);
std::tie(transform.InputBuffer, transform.Buffer) = TaskRunner->GetInputTransform(inputIndex);
}

for (auto& [channelId, channel] : this->OutputChannelsMap) {
channel.Channel = this->TaskRunner->GetOutputChannel(channelId);
channel.Channel = TaskRunner->GetOutputChannel(channelId);
}

for (auto& [outputIndex, transform] : this->OutputTransformsMap) {
std::tie(transform.Buffer, transform.OutputBuffer) = this->TaskRunner->GetOutputTransform(outputIndex);
std::tie(transform.Buffer, transform.OutputBuffer) = TaskRunner->GetOutputTransform(outputIndex);
}

for (auto& [outputIndex, sink] : this->SinksMap) {
sink.Buffer = this->TaskRunner->GetSink(outputIndex);
sink.Buffer = TaskRunner->GetSink(outputIndex);
}

TBase::FillIoMaps(
this->TaskRunner->GetHolderFactory(),
this->TaskRunner->GetTypeEnv(),
this->TaskRunner->GetSecureParams(),
this->TaskRunner->GetTaskParams(),
this->TaskRunner->GetReadRanges(),
this->TaskRunner->GetRandomProvider()
TaskRunner->GetHolderFactory(),
TaskRunner->GetTypeEnv(),
TaskRunner->GetSecureParams(),
TaskRunner->GetTaskParams(),
TaskRunner->GetReadRanges(),
TaskRunner->GetRandomProvider()
);
}

const NYql::NDq::TTaskRunnerStatsBase* GetTaskRunnerStats() override {
return this->TaskRunner ? this->TaskRunner->GetStats() : nullptr;
return TaskRunner ? TaskRunner->GetStats() : nullptr;
}

const NYql::NDq::TDqMeteringStats* GetMeteringStats() override {
return this->TaskRunner ? this->TaskRunner->GetMeteringStats() : nullptr;
return TaskRunner ? TaskRunner->GetMeteringStats() : nullptr;
}

protected:
Expand All @@ -171,6 +171,8 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo
void PollSources(void* /* state */) {
}

TIntrusivePtr<IDqTaskRunner> TaskRunner;

};

} //namespace NYql::NDq
Expand Down

0 comments on commit 112425d

Please sign in to comment.