diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 4fb88fa0de2f8..25045815a5a03 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -273,6 +273,17 @@ cc_library( ], ) +cc_library( + name = "hlo_pjrt_interpreter_reference_mixin", + testonly = True, + hdrs = ["hlo_pjrt_interpreter_reference_mixin.h"], + deps = [ + ":hlo_runner_agnostic_reference_mixin", + "//xla/pjrt/interpreter:interpreter_client", + "//xla/service:hlo_runner_pjrt", + ], +) + cc_library( name = "hlo_pjrt_test_base", testonly = True, diff --git a/xla/tests/hlo_pjrt_interpreter_reference_mixin.h b/xla/tests/hlo_pjrt_interpreter_reference_mixin.h new file mode 100644 index 0000000000000..cdbd2f3575cfb --- /dev/null +++ b/xla/tests/hlo_pjrt_interpreter_reference_mixin.h @@ -0,0 +1,50 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_ +#define XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_ + +#include + +#include "xla/pjrt/interpreter/interpreter_client.h" +#include "xla/service/hlo_runner_pjrt.h" +#include "xla/tests/hlo_runner_agnostic_reference_mixin.h" + +namespace xla { + +// A wrapper mixin around HloRunnerAgnosticReferenceMixin which provides a +// default reference backend via HloRunnerPjRt using the PjRt InterpreterClient. +// +// The mixin requires that that the test class is a subclass of +// HloRunnerAgnosticTestBase. +template +class HloPjRtInterpreterReferenceMixin + : public HloRunnerAgnosticReferenceMixin { + protected: + template + explicit HloPjRtInterpreterReferenceMixin(BaseArgs&&... base_args) + : HloRunnerAgnosticReferenceMixin( + std::make_unique( + std::make_unique(), + InterpreterClient::DeviceShapeRepresentation, + InterpreterClient::ShapeSizeBytes, + /*use_parameter_layout_on_device=*/true), + std::forward(base_args)...) {} + ~HloPjRtInterpreterReferenceMixin() override = default; +}; + +} // namespace xla + +#endif // XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_