Skip to content

Commit 22a6979

Browse files
authored
Fix for deadlock in python callback (#3073)
Fix update from release branch
1 parent ccc7079 commit 22a6979

File tree

7 files changed

+279
-36
lines changed

7 files changed

+279
-36
lines changed

.github/workflows/linux.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,8 @@ jobs:
522522
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
523523
timeout: 360
524524
- name: 'LLM & VLM'
525-
cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
526-
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
525+
cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
526+
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
527527
timeout: 180
528528
- name: 'GGUF Reader tests'
529529
cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'

.github/workflows/mac.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,8 @@ jobs:
447447
# timeout: 240
448448
# Only supported on X64 or ARM with SVE support
449449
# - name: 'LLM & VLM'
450-
# cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py'
451-
# run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
450+
# cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py'
451+
# run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
452452
# timeout: 180
453453
- name: 'GGUF Reader tests'
454454
cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'

.github/workflows/windows.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,8 @@ jobs:
611611
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
612612
timeout: 360
613613
- name: 'LLM & VLM'
614-
cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
615-
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
614+
cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
615+
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
616616
timeout: 180
617617
- name: 'GGUF Reader tests'
618618
cmd: 'python -m pytest -s -v tests/python_tests/test_gguf_reader.py'

src/python/py_image_generation_pipelines.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,12 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
180180
}
181181

182182
float next() override {
183+
py::gil_scoped_acquire acquire;
183184
return m_torch.attr("randn")(1, "generator"_a=m_torch_generator, "dtype"_a=m_float32).attr("item")().cast<float>();
184185
}
185186

186187
ov::Tensor randn_tensor(const ov::Shape& shape) override {
188+
py::gil_scoped_acquire acquire;
187189
py::object torch_tensor = m_torch.attr("randn")(to_py_list(shape), "generator"_a=m_torch_generator, "dtype"_a=m_float32);
188190
py::object numpy_tensor = torch_tensor.attr("numpy")();
189191
py::array numpy_array = py::cast<py::array>(numpy_tensor);
@@ -201,6 +203,32 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
201203
TorchTensorAllocator(size_t total_size, void * mutable_data, py::object torch_tensor) :
202204
m_total_size(total_size), m_mutable_data(mutable_data), m_torch_tensor(torch_tensor) { }
203205

206+
~TorchTensorAllocator() {
207+
if (m_torch_tensor && Py_IsInitialized()) {
208+
py::gil_scoped_acquire acquire;
209+
m_torch_tensor = py::object();
210+
}
211+
}
212+
213+
TorchTensorAllocator(const TorchTensorAllocator& other)
214+
: m_total_size(other.m_total_size), m_mutable_data(other.m_mutable_data) {
215+
py::gil_scoped_acquire acquire;
216+
m_torch_tensor = other.m_torch_tensor;
217+
}
218+
219+
TorchTensorAllocator& operator=(const TorchTensorAllocator& other) {
220+
if (this != &other) {
221+
m_total_size = other.m_total_size;
222+
m_mutable_data = other.m_mutable_data;
223+
py::gil_scoped_acquire acquire;
224+
m_torch_tensor = other.m_torch_tensor;
225+
}
226+
return *this;
227+
}
228+
229+
TorchTensorAllocator(TorchTensorAllocator&&) = default;
230+
TorchTensorAllocator& operator=(TorchTensorAllocator&&) = default;
231+
204232
void* allocate(size_t bytes, size_t) const {
205233
if (m_total_size == bytes) {
206234
return m_mutable_data;
@@ -221,6 +249,7 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
221249
}
222250

223251
void seed(size_t new_seed) override {
252+
py::gil_scoped_acquire acquire;
224253
create_torch_generator(new_seed);
225254
}
226255
};
@@ -448,12 +477,7 @@ void init_image_generation_pipelines(py::module_& m) {
448477
) -> py::typing::Union<ov::Tensor> {
449478
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
450479
ov::Tensor res;
451-
if (params_have_torch_generator(params)) {
452-
// TorchGenerator stores python object which causes segfault after gil_scoped_release
453-
// so if it was passed, we don't release GIL
454-
res = pipe.generate(prompt, params);
455-
}
456-
else {
480+
{
457481
py::gil_scoped_release rel;
458482
res = pipe.generate(prompt, params);
459483
}
@@ -565,12 +589,7 @@ void init_image_generation_pipelines(py::module_& m) {
565589
) -> py::typing::Union<ov::Tensor> {
566590
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
567591
ov::Tensor res;
568-
if (params_have_torch_generator(params)) {
569-
// TorchGenerator stores python object which causes segfault after gil_scoped_release
570-
// so if it was passed, we don't release GIL
571-
res = pipe.generate(prompt, image, params);
572-
}
573-
else {
592+
{
574593
py::gil_scoped_release rel;
575594
res = pipe.generate(prompt, image, params);
576595
}
@@ -676,12 +695,7 @@ void init_image_generation_pipelines(py::module_& m) {
676695
) -> py::typing::Union<ov::Tensor> {
677696
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
678697
ov::Tensor res;
679-
if (params_have_torch_generator(params)) {
680-
// TorchGenerator stores python object which causes segfault after gil_scoped_release
681-
// so if it was passed, we don't release GIL
682-
res = pipe.generate(prompt, image, mask_image, params);
683-
}
684-
else {
698+
{
685699
py::gil_scoped_release rel;
686700
res = pipe.generate(prompt, image, mask_image, params);
687701
}

src/python/py_utils.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,22 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
374374
return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
375375
} else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
376376
auto py_callback = py::cast<py::function>(py_obj);
377+
auto shared_callback = std::shared_ptr<py::function>(
378+
new py::function(py_callback),
379+
[](py::function* f) {
380+
if (Py_IsInitialized()) {
381+
py::gil_scoped_acquire acquire;
382+
delete f;
383+
} else {
384+
delete f;
385+
}
386+
}
387+
);
388+
377389
return std::function<bool(size_t, size_t, ov::Tensor&)>(
378-
[py_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
390+
[shared_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
379391
py::gil_scoped_acquire acquire;
380-
return py_callback(step, num_steps, latent).cast<bool>();
392+
return (*shared_callback)(step, num_steps, latent).cast<bool>();
381393
}
382394
);
383395
} else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
@@ -443,21 +455,40 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
443455

444456
std::visit(overloaded {
445457
[&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
446-
// Wrap python streamer with manual utf-8 decoding. Do not rely
447-
// on pybind automatic decoding since it raises exceptions on incomplete strings.
448-
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamingStatus {
458+
auto shared_callback = std::shared_ptr<std::function<std::optional<uint16_t>(py::str)>>(
459+
new std::function<std::optional<uint16_t>(py::str)>(py_callback),
460+
[](std::function<std::optional<uint16_t>(py::str)>* f) {
461+
if (Py_IsInitialized()) {
462+
py::gil_scoped_acquire acquire;
463+
delete f;
464+
} else {
465+
delete f;
466+
}
467+
}
468+
);
469+
470+
auto callback_wrapped = [shared_callback](std::string subword) -> ov::genai::StreamingStatus {
449471
py::gil_scoped_acquire acquire;
450-
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
451-
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
472+
PyObject* py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
473+
if (!py_str) {
474+
PyErr_WriteUnraisable(nullptr);
475+
return StreamingStatus::RUNNING;
476+
}
477+
auto py_str_obj = py::reinterpret_steal<py::str>(py_str);
478+
std::optional<uint16_t> callback_output;
479+
try {
480+
callback_output = (*shared_callback)(py_str_obj);
481+
} catch (const py::error_already_set&) {
482+
return StreamingStatus::RUNNING;
483+
}
452484
if (callback_output.has_value()) {
453-
if (*callback_output == (uint16_t)StreamingStatus::RUNNING)
485+
if (*callback_output == static_cast<uint16_t>(StreamingStatus::RUNNING))
454486
return StreamingStatus::RUNNING;
455-
else if (*callback_output == (uint16_t)StreamingStatus::CANCEL)
487+
else if (*callback_output == static_cast<uint16_t>(StreamingStatus::CANCEL))
456488
return StreamingStatus::CANCEL;
457489
return StreamingStatus::STOP;
458-
} else {
459-
return StreamingStatus::RUNNING;
460490
}
491+
return StreamingStatus::RUNNING;
461492
};
462493
streamer = callback_wrapped;
463494
},

src/python/py_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace ov::genai::pybind::utils {
1919
// When StreamerVariant is used utf-8 decoding is done by pybind and can lead to exception on incomplete texts.
2020
// Therefore strings decoding should be handled with PyUnicode_DecodeUTF8(..., "replace") to not throw errors.
2121
using PyBindStreamerVariant = std::variant<
22-
std::function<std::optional<uint16_t>(std::string)>,
22+
std::function<std::optional<uint16_t>(py::str)>,
2323
std::shared_ptr<StreamerBase>,
2424
std::monostate>;
2525

0 commit comments

Comments
 (0)