Skip to content

Commit d86b69c

Browse files
hyeontaekcopybara-github
authored andcommitted
[IFRT] Add "PjRt-compatible" interfaces for PjRt escape hatches
This change helps migrate various PjRt implementations that do not directly use/wrap around a low-level PjRt runtime. Implementing the interface will preserve the compatibility between the JAX Python binding and such PjRt-compatible IFRT implementations. PiperOrigin-RevId: 494186705
1 parent eaa170c commit d86b69c

File tree

12 files changed

+107
-52
lines changed

12 files changed

+107
-52
lines changed

xla/python/ifrt/client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
9999

100100
// TODO(hyeontaek): Potentially remove this method to encourage supporting
101101
// only ahead-of-time compilation.
102-
virtual Compiler* GetDefaultCompiler() const = 0;
102+
virtual Compiler* GetDefaultCompiler() = 0;
103103

104104
static char ID; // NOLINT
105105
};

xla/python/pjrt_ifrt/pjrt_array.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ StatusOr<DType> ToDType(xla::PrimitiveType primitive_type) {
9191
StatusOr<std::unique_ptr<Array>> PjRtArray::Create(
9292
Client* client, DType dtype, Shape shape,
9393
std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers) {
94-
if (!llvm::isa_and_nonnull<PjRtClient>(client)) {
95-
return InvalidArgument("PjRtClient expected");
94+
if (!llvm::isa_and_nonnull<PjRtCompatibleClient>(client)) {
95+
return InvalidArgument("PjRtCompatibleClient expected");
9696
}
9797
if (pjrt_buffers.empty()) {
9898
return InvalidArgument("pjrt_buffers must be non-empty");
@@ -101,22 +101,22 @@ StatusOr<std::unique_ptr<Array>> PjRtArray::Create(
101101
return InvalidArgument("device and buffer counts mismatch: %d vs. %d",
102102
sharding->devices().size(), pjrt_buffers.size());
103103
}
104-
return std::unique_ptr<Array>(
105-
new PjRtArray(static_cast<PjRtClient*>(client), dtype, std::move(shape),
106-
std::move(sharding), std::move(pjrt_buffers)));
104+
return std::unique_ptr<Array>(new PjRtArray(
105+
static_cast<PjRtCompatibleClient*>(client), dtype, std::move(shape),
106+
std::move(sharding), std::move(pjrt_buffers)));
107107
}
108108

109109
StatusOr<std::unique_ptr<Array>> PjRtArray::Create(
110110
Client* client, std::shared_ptr<PjRtBuffer> pjrt_buffer) {
111-
if (!llvm::isa_and_nonnull<PjRtClient>(client)) {
112-
return InvalidArgument("PjRtClient expected");
111+
if (!llvm::isa_and_nonnull<PjRtCompatibleClient>(client)) {
112+
return InvalidArgument("PjRtCompatibleClient expected");
113113
}
114114
TF_ASSIGN_OR_RETURN(auto dtype,
115115
ToDType(pjrt_buffer->on_device_shape().element_type()));
116116
Shape shape(pjrt_buffer->on_device_shape().dimensions());
117117
auto sharding = SingleDeviceSharding::Create(pjrt_buffer->device());
118118
return std::unique_ptr<Array>(new PjRtArray(
119-
static_cast<PjRtClient*>(client), dtype, std::move(shape),
119+
static_cast<PjRtCompatibleClient*>(client), dtype, std::move(shape),
120120
std::move(sharding), PjRtBuffers({std::move(pjrt_buffer)})));
121121
}
122122

@@ -149,7 +149,7 @@ StatusOr<std::unique_ptr<Array>> PjRtArray::Create(Client* client, Shape shape,
149149
std::move(pjrt_buffers));
150150
}
151151

152-
PjRtArray::PjRtArray(PjRtClient* client, DType dtype, Shape shape,
152+
PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape,
153153
std::shared_ptr<const Sharding> sharding,
154154
PjRtBuffers pjrt_buffers)
155155
: client_(client),

xla/python/pjrt_ifrt/pjrt_array.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,19 @@ StatusOr<xla::PrimitiveType> ToPrimitiveType(DType dtype);
3636
// Converts `xla::PrimitiveType` into IFRT `DType`.
3737
StatusOr<DType> ToDType(xla::PrimitiveType primitive_type);
3838

39+
// PjRt-compatible `Array` interface that wraps a list of `xla::PjRtBuffer`s.
40+
class PjRtCompatibleArray
41+
: public llvm::RTTIExtends<PjRtCompatibleArray, Array> {
42+
public:
43+
// APIs that allow direct access to `PjRtBuffer` for PjRt-only operations.
44+
virtual absl::Span<const std::shared_ptr<PjRtBuffer>> pjrt_buffers() = 0;
45+
virtual StatusOr<absl::Span<std::shared_ptr<PjRtBuffer>>>
46+
mutable_pjrt_buffers() = 0;
47+
};
48+
3949
// `Array` implementation that wraps a list of `xla::PjRtBuffer`s.
40-
class PjRtArray final : public llvm::RTTIExtends<PjRtArray, Array> {
50+
class PjRtArray final
51+
: public llvm::RTTIExtends<PjRtArray, PjRtCompatibleArray> {
4152
public:
4253
static constexpr int kPjRtBufferInlineSize = 1;
4354
using PjRtBuffers =
@@ -60,26 +71,25 @@ class PjRtArray final : public llvm::RTTIExtends<PjRtArray, Array> {
6071
static StatusOr<std::unique_ptr<Array>> Create(Client* client, Shape shape,
6172
PjRtBuffers pjrt_buffers);
6273

63-
absl::Span<const std::shared_ptr<PjRtBuffer>> pjrt_buffers() const {
74+
// PjRtCompatibleArray implementation.
75+
76+
absl::Span<const std::shared_ptr<PjRtBuffer>> pjrt_buffers() override {
6477
DCHECK(this);
6578
return pjrt_buffers_;
6679
}
67-
absl::Span<std::shared_ptr<PjRtBuffer>> pjrt_buffers() {
80+
StatusOr<absl::Span<std::shared_ptr<PjRtBuffer>>> mutable_pjrt_buffers()
81+
override {
6882
DCHECK(this);
6983
return absl::MakeSpan(pjrt_buffers_);
7084
}
71-
PjRtBuffer* pjrt_buffer(int device_id) const {
72-
DCHECK(this);
73-
return pjrt_buffers_[device_id].get();
74-
}
7585

7686
// Array implementation.
7787

7888
~PjRtArray() override = default;
7989

8090
Client* client() const override {
8191
DCHECK(this);
82-
return const_cast<PjRtClient*>(client_);
92+
return client_;
8393
}
8494

8595
DType dtype() const override {
@@ -121,10 +131,10 @@ class PjRtArray final : public llvm::RTTIExtends<PjRtArray, Array> {
121131
static char ID; // NOLINT
122132

123133
private:
124-
PjRtArray(PjRtClient* client, DType dtype, Shape shape,
134+
PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape,
125135
std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers);
126136

127-
PjRtClient* client_;
137+
PjRtCompatibleClient* client_;
128138
DType dtype_;
129139
Shape shape_;
130140
std::shared_ptr<const Sharding> sharding_;

xla/python/pjrt_ifrt/pjrt_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays(
8585
return InvalidArgument("Only PjRtArray is supported: arrays[%d]=%s", i,
8686
arrays[i]->DebugString());
8787
}
88-
const auto* array = static_cast<const PjRtArray*>(arrays[i]);
88+
auto* array = static_cast<PjRtArray*>(arrays[i]);
8989
if (array->dtype() != dtype) {
9090
return InvalidArgument(
9191
"Every input must have the same dtype: %s (shard 0) vs. %s (shard "

xla/python/pjrt_ifrt/pjrt_client.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,29 @@ limitations under the License.
3131
namespace xla {
3232
namespace ifrt {
3333

34+
// PjRt-compatible `Client` interface.
35+
class PjRtCompatibleClient
36+
: public llvm::RTTIExtends<PjRtCompatibleClient, Client> {
37+
public:
38+
// APIs that allow direct access to `xla::PjRtClient` for PjRt-only
39+
// operations.
40+
virtual xla::PjRtClient* pjrt_client() = 0;
41+
virtual std::shared_ptr<xla::PjRtClient> shared_ptr_pjrt_client() = 0;
42+
};
43+
3444
// `Client` implementation that wraps `xla::PjRtClient`.
35-
class PjRtClient final : public llvm::RTTIExtends<PjRtClient, Client> {
45+
class PjRtClient final
46+
: public llvm::RTTIExtends<PjRtClient, PjRtCompatibleClient> {
3647
public:
3748
static std::unique_ptr<ifrt::Client> Create(
3849
std::shared_ptr<xla::PjRtClient> pjrt_client);
3950
static std::unique_ptr<ifrt::Client> Create(
4051
std::unique_ptr<xla::PjRtClient> pjrt_client);
4152

42-
xla::PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
43-
std::shared_ptr<xla::PjRtClient> shared_ptr_pjrt_client() const {
53+
// PjRtCompatibleClient implementation.
54+
55+
xla::PjRtClient* pjrt_client() override { return pjrt_client_.get(); }
56+
std::shared_ptr<xla::PjRtClient> shared_ptr_pjrt_client() override {
4457
return pjrt_client_;
4558
}
4659

@@ -114,9 +127,9 @@ class PjRtClient final : public llvm::RTTIExtends<PjRtClient, Client> {
114127
return pjrt_client_->CreateHostToDeviceChannelHandle();
115128
}
116129

117-
Compiler* GetDefaultCompiler() const override {
130+
Compiler* GetDefaultCompiler() override {
118131
DCHECK(this);
119-
return const_cast<PjRtCompiler*>(&default_compiler_);
132+
return &default_compiler_;
120133
}
121134

122135
static char ID; // NOLINT

xla/python/pjrt_ifrt/pjrt_executable.h

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,40 @@ limitations under the License.
3131
namespace xla {
3232
namespace ifrt {
3333

34+
// PjRt-compatible `Executable` interface.
35+
class PjRtCompatibleExecutable
36+
: public llvm::RTTIExtends<PjRtCompatibleExecutable, Executable> {
37+
public:
38+
// APIs that allow direct access to `xla::PjRtExecutable` for PjRt-only
39+
// operations.
40+
virtual xla::PjRtExecutable* pjrt_executable() = 0;
41+
};
42+
43+
// PjRt-compatible `LoadedExecutable` interface.
44+
class PjRtCompatibleLoadedExecutable
45+
: public llvm::RTTIExtends<PjRtCompatibleLoadedExecutable,
46+
LoadedExecutable> {
47+
public:
48+
// APIs that allow direct access to `xla::PjRtLoadedExecutable` for PjRt-only
49+
// operations.
50+
virtual xla::PjRtLoadedExecutable* pjrt_loaded_executable() = 0;
51+
virtual std::shared_ptr<xla::PjRtLoadedExecutable>
52+
shared_ptr_pjrt_loaded_executable() = 0;
53+
};
54+
3455
// `Executable` implementation that wraps a `xla::PjRtExecutable`.
3556
class PjRtExecutable final
36-
: public llvm::RTTIExtends<PjRtExecutable, Executable> {
57+
: public llvm::RTTIExtends<PjRtExecutable, PjRtCompatibleExecutable> {
3758
public:
3859
// Creates PjRtExecutable from xla::PjRtExecutable.
3960
static StatusOr<std::unique_ptr<Executable>> Create(
4061
std::unique_ptr<xla::PjRtExecutable> pjrt_executable);
4162
static StatusOr<std::unique_ptr<Executable>> Create(
4263
std::shared_ptr<xla::PjRtExecutable> pjrt_executable);
4364

44-
xla::PjRtExecutable* pjrt_executable() const {
65+
// PjRtCompatibleExecutable implementation.
66+
67+
xla::PjRtExecutable* pjrt_executable() override {
4568
DCHECK(this);
4669
return pjrt_executable_.get();
4770
}
@@ -101,7 +124,8 @@ class PjRtExecutable final
101124

102125
// `LoadedExecutable` implementation that wraps a `xla::PjRtLoadedExecutable`.
103126
class PjRtLoadedExecutable final
104-
: public llvm::RTTIExtends<PjRtLoadedExecutable, LoadedExecutable> {
127+
: public llvm::RTTIExtends<PjRtLoadedExecutable,
128+
PjRtCompatibleLoadedExecutable> {
105129
public:
106130
using LoadedExecutable::ExecuteOptions;
107131
using LoadedExecutable::ExecuteResult;
@@ -125,12 +149,14 @@ class PjRtLoadedExecutable final
125149
PjRtClient* client, const XlaComputation& computation,
126150
CompileOptions options);
127151

128-
xla::PjRtLoadedExecutable* pjrt_loaded_executable() const {
152+
// PjRtCompatibleLoadedExecutable implementation.
153+
154+
xla::PjRtLoadedExecutable* pjrt_loaded_executable() override {
129155
DCHECK(this);
130156
return pjrt_loaded_executable_.get();
131157
}
132158
std::shared_ptr<xla::PjRtLoadedExecutable> shared_ptr_pjrt_loaded_executable()
133-
const {
159+
override {
134160
DCHECK(this);
135161
return pjrt_loaded_executable_;
136162
}
@@ -181,7 +207,7 @@ class PjRtLoadedExecutable final
181207

182208
Client* client() const override {
183209
DCHECK(this);
184-
return const_cast<PjRtClient*>(client_);
210+
return client_;
185211
}
186212
StatusOr<ExecuteResult> Execute(absl::Span<Array* const> args,
187213
const ExecuteOptions& options,
@@ -190,21 +216,21 @@ class PjRtLoadedExecutable final
190216
Future<Status> Delete() override;
191217
bool IsDeleted() const override {
192218
DCHECK(this);
193-
return pjrt_loaded_executable()->IsDeleted();
219+
return pjrt_loaded_executable_->IsDeleted();
194220
}
195221

196222
const DeviceAssignment& device_assignment() const override {
197223
DCHECK(this);
198-
return pjrt_loaded_executable()->device_assignment();
224+
return pjrt_loaded_executable_->device_assignment();
199225
}
200226
absl::Span<const LoadedExecutable::LogicalDeviceIds>
201227
addressable_device_logical_ids() const override {
202228
DCHECK(this);
203-
return pjrt_loaded_executable()->addressable_device_logical_ids();
229+
return pjrt_loaded_executable_->addressable_device_logical_ids();
204230
}
205231
absl::Span<Device* const> addressable_devices() const override {
206232
DCHECK(this);
207-
return pjrt_loaded_executable()->addressable_devices();
233+
return pjrt_loaded_executable_->addressable_devices();
208234
}
209235

210236
static char ID; // NOLINT

xla/python/py_array.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ class PyArray : public pybind11::object {
181181
if (ifrt_array_ptr == nullptr) {
182182
return {};
183183
}
184-
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtArray>(ifrt_array_ptr);
184+
auto* arr =
185+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_ptr);
185186
if (arr == nullptr) {
186187
throw XlaRuntimeError(
187188
"This operation is implemented for a PjRt-compatible backend only.");

xla/python/py_buffer.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ class PyBuffer {
8787
// Short-term escape hatch to get PjRtBuffer from PyBuffer.
8888
// TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
8989
PjRtBuffer* pjrt_buffer() const {
90-
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtArray>(ifrt_array_.get());
90+
auto* arr =
91+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_.get());
9192
if (arr == nullptr) {
9293
throw XlaRuntimeError(
9394
"This operation is implemented for a PjRt-compatible backend only.");
@@ -98,7 +99,8 @@ class PyBuffer {
9899
// Short-term escape hatch to get PjRtBuffer from PyBuffer.
99100
// TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
100101
std::shared_ptr<PjRtBuffer> shared_ptr_pjrt_buffer() const {
101-
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtArray>(ifrt_array_.get());
102+
auto* arr =
103+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_.get());
102104
if (arr == nullptr) {
103105
throw XlaRuntimeError(
104106
"This operation is implemented for a PjRt-compatible backend only.");
@@ -406,7 +408,8 @@ class PyShardedBuffer {
406408
#ifdef JAX_ENABLE_IFRT
407409
PyBuffer::object GetPyBuffer(int device_id) const {
408410
// TODO(hyeontaek): Remove this method. This method will not scale well.
409-
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtArray>(ifrt_array_.get());
411+
auto* arr =
412+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_.get());
410413
if (arr == nullptr) {
411414
throw XlaRuntimeError(
412415
"This operation is implemented for a PjRt-compatible backend only.");
@@ -446,7 +449,8 @@ class PyShardedBuffer {
446449
// Short-term escape hatch to get PjRtBuffer from PyShardedBuffer.
447450
// TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
448451
PjRtBuffer* pjrt_buffer(int device_id) const {
449-
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtArray>(ifrt_array_.get());
452+
auto* arr =
453+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array_.get());
450454
if (arr == nullptr) {
451455
throw XlaRuntimeError(
452456
"This operation is implemented for a PjRt-compatible backend only.");

xla/python/py_client.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,15 @@ Status PyClient::Defragment() {
206206
if (array->ifrt_array == nullptr) {
207207
continue;
208208
}
209-
auto* arr =
210-
llvm::dyn_cast_or_null<ifrt::PjRtArray>(array->ifrt_array.get());
209+
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(
210+
array->ifrt_array.get());
211211
if (arr == nullptr) {
212212
throw XlaRuntimeError(
213213
"This operation is implemented for a PjRt-compatible backend "
214214
"only.");
215215
}
216-
absl::Span<std::shared_ptr<PjRtBuffer>> pjrt_buffers =
217-
arr->pjrt_buffers();
216+
TF_ASSIGN_OR_RETURN(absl::Span<std::shared_ptr<PjRtBuffer>> pjrt_buffers,
217+
arr->mutable_pjrt_buffers());
218218
#else
219219
absl::Span<std::shared_ptr<PjRtBuffer>> pjrt_buffers =
220220
absl::MakeSpan(array->pjrt_buffers);
@@ -669,8 +669,8 @@ StatusOr<py::bytes> PyClient::HeapProfile() {
669669
if (array->ifrt_array == nullptr) {
670670
continue;
671671
}
672-
auto* arr =
673-
llvm::dyn_cast_or_null<ifrt::PjRtArray>(array->ifrt_array.get());
672+
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(
673+
array->ifrt_array.get());
674674
// TODO(hyeontaek): Support non-PjRt Arrays.
675675
if (arr == nullptr) {
676676
throw XlaRuntimeError(

xla/python/py_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
112112
// TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt.
113113
xla::PjRtClient* pjrt_client() const {
114114
auto* pjrt_client =
115-
llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
115+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleClient>(ifrt_client_.get());
116116
if (pjrt_client == nullptr) {
117117
throw XlaRuntimeError(
118118
"This operation is implemented for a PjRt-compatible backend only.");
@@ -121,7 +121,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
121121
}
122122
std::shared_ptr<PjRtClient> shared_ptr_pjrt_client() {
123123
auto* pjrt_client =
124-
llvm::dyn_cast_or_null<ifrt::PjRtClient>(ifrt_client_.get());
124+
llvm::dyn_cast_or_null<ifrt::PjRtCompatibleClient>(ifrt_client_.get());
125125
if (pjrt_client == nullptr) {
126126
throw XlaRuntimeError(
127127
"This operation is implemented for a PjRt-compatible backend only.");

0 commit comments

Comments
 (0)