@@ -31,17 +31,40 @@ limitations under the License.
3131namespace xla {
3232namespace ifrt {
3333
34+ // PjRt-compatible `Executable` interface.
35+ class PjRtCompatibleExecutable
36+ : public llvm::RTTIExtends<PjRtCompatibleExecutable, Executable> {
37+ public:
38+ // APIs that allow direct access to `xla::PjRtExecutable` for PjRt-only
39+ // operations.
40+ virtual xla::PjRtExecutable* pjrt_executable () = 0;
41+ };
42+
43+ // PjRt-compatible `LoadedExecutable` interface.
44+ class PjRtCompatibleLoadedExecutable
45+ : public llvm::RTTIExtends<PjRtCompatibleLoadedExecutable,
46+ LoadedExecutable> {
47+ public:
48+ // APIs that allow direct access to `xla::PjRtLoadedExecutable` for PjRt-only
49+ // operations.
50+ virtual xla::PjRtLoadedExecutable* pjrt_loaded_executable () = 0;
51+ virtual std::shared_ptr<xla::PjRtLoadedExecutable>
52+ shared_ptr_pjrt_loaded_executable () = 0 ;
53+ };
54+
3455// `Executable` implementation that wraps a `xla::PjRtExecutable`.
3556class PjRtExecutable final
36- : public llvm::RTTIExtends<PjRtExecutable, Executable > {
57+ : public llvm::RTTIExtends<PjRtExecutable, PjRtCompatibleExecutable > {
3758 public:
3859 // Creates PjRtExecutable from xla::PjRtExecutable.
3960 static StatusOr<std::unique_ptr<Executable>> Create (
4061 std::unique_ptr<xla::PjRtExecutable> pjrt_executable);
4162 static StatusOr<std::unique_ptr<Executable>> Create (
4263 std::shared_ptr<xla::PjRtExecutable> pjrt_executable);
4364
44- xla::PjRtExecutable* pjrt_executable () const {
65+ // PjRtCompatibleExecutable implementation.
66+
67+ xla::PjRtExecutable* pjrt_executable () override {
4568 DCHECK (this );
4669 return pjrt_executable_.get ();
4770 }
@@ -101,7 +124,8 @@ class PjRtExecutable final
101124
102125// `LoadedExecutable` implementation that wraps a `xla::PjRtLoadedExecutable`.
103126class PjRtLoadedExecutable final
104- : public llvm::RTTIExtends<PjRtLoadedExecutable, LoadedExecutable> {
127+ : public llvm::RTTIExtends<PjRtLoadedExecutable,
128+ PjRtCompatibleLoadedExecutable> {
105129 public:
106130 using LoadedExecutable::ExecuteOptions;
107131 using LoadedExecutable::ExecuteResult;
@@ -125,12 +149,14 @@ class PjRtLoadedExecutable final
125149 PjRtClient* client, const XlaComputation& computation,
126150 CompileOptions options);
127151
128- xla::PjRtLoadedExecutable* pjrt_loaded_executable () const {
152+ // PjRtCompatibleLoadedExecutable implementation.
153+
154+ xla::PjRtLoadedExecutable* pjrt_loaded_executable () override {
129155 DCHECK (this );
130156 return pjrt_loaded_executable_.get ();
131157 }
132158 std::shared_ptr<xla::PjRtLoadedExecutable> shared_ptr_pjrt_loaded_executable ()
133- const {
159+ override {
134160 DCHECK (this );
135161 return pjrt_loaded_executable_;
136162 }
@@ -181,7 +207,7 @@ class PjRtLoadedExecutable final
181207
182208 Client* client () const override {
183209 DCHECK (this );
184- return const_cast <PjRtClient*>( client_) ;
210+ return client_;
185211 }
186212 StatusOr<ExecuteResult> Execute (absl::Span<Array* const > args,
187213 const ExecuteOptions& options,
@@ -190,21 +216,21 @@ class PjRtLoadedExecutable final
190216 Future<Status> Delete () override ;
191217 bool IsDeleted () const override {
192218 DCHECK (this );
193- return pjrt_loaded_executable () ->IsDeleted ();
219+ return pjrt_loaded_executable_ ->IsDeleted ();
194220 }
195221
196222 const DeviceAssignment& device_assignment () const override {
197223 DCHECK (this );
198- return pjrt_loaded_executable () ->device_assignment ();
224+ return pjrt_loaded_executable_ ->device_assignment ();
199225 }
200226 absl::Span<const LoadedExecutable::LogicalDeviceIds>
201227 addressable_device_logical_ids () const override {
202228 DCHECK (this );
203- return pjrt_loaded_executable () ->addressable_device_logical_ids ();
229+ return pjrt_loaded_executable_ ->addressable_device_logical_ids ();
204230 }
205231 absl::Span<Device* const > addressable_devices () const override {
206232 DCHECK (this );
207- return pjrt_loaded_executable () ->addressable_devices ();
233+ return pjrt_loaded_executable_ ->addressable_devices ();
208234 }
209235
210236 static char ID; // NOLINT
0 commit comments