70 changes: 61 additions & 9 deletions backend/cpp/llama-cpp/grpc-server.cpp
@@ -822,6 +822,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
}

ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
return false;
}

json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply.set_message(error_data.value("content", ""));
writer->Write(reply);
return true;
- }, [&]() {
- // NOTE: we should try to check when the writer is closed here
- return false;
+ }, [&context]() {
+ // Check if the gRPC context is cancelled
+ return context->IsCancelled();
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);

// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

return grpc::Status::OK;
}

@@ -1145,6 +1156,14 @@ class BackendServiceImpl final : public backend::Backend::Service {


std::cout << "[DEBUG] Waiting for results..." << std::endl;

// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
if (results.size() == 1) {
@@ -1176,13 +1195,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
}, [&](const json & error_data) {
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
reply->set_message(error_data.value("content", ""));
- }, [&]() {
- return false;
+ }, [&context]() {
+ // Check if the gRPC context is cancelled
+ // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
+ return context->IsCancelled();
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;

// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

return grpc::Status::OK;
}

@@ -1234,6 +1260,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
ctx_server.queue_tasks.post(std::move(tasks));
}

// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

// get the result
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
@@ -1242,12 +1275,18 @@
}
}, [&](const json & error_data) {
error = true;
- }, [&]() {
- return false;
+ }, [&context]() {
+ // Check if the gRPC context is cancelled
+ return context->IsCancelled();
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);

// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
@@ -1325,6 +1364,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
ctx_server.queue_tasks.post(std::move(tasks));
}

// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

// Get the results
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
@@ -1333,12 +1379,18 @@
}
}, [&](const json & error_data) {
error = true;
- }, [&]() {
- return false;
+ }, [&context]() {
+ // Check if the gRPC context is cancelled
+ return context->IsCancelled();
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);

// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}

if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
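The C++ changes above all follow one pattern: between units of work, poll `grpc::ServerContext::IsCancelled()`, cancel the queued tasks, and surface a `CANCELLED` status instead of blocking forever on results for a client that has gone away. For context, here is a self-contained Go sketch of the same check-then-clean-up loop; it is illustrative only, and none of these names come from the PR:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// processWithCancellation mirrors the pattern added in grpc-server.cpp:
// between results, check whether the caller has cancelled, and if so stop
// waiting and report a CANCELLED-style error.
func processWithCancellation(ctx context.Context, results <-chan string) error {
	for {
		select {
		case <-ctx.Done():
			// Equivalent of context->IsCancelled() + ctx_server.cancel_tasks(...)
			return errors.New("request cancelled by client")
		case r, ok := <-results:
			if !ok {
				return nil // all results delivered
			}
			fmt.Println("result:", r)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	results := make(chan string)
	go func() {
		results <- "partial result"
		time.Sleep(50 * time.Millisecond)
		cancel() // simulate the client disconnecting mid-request
	}()

	if err := processWithCancellation(ctx, results); err != nil {
		fmt.Println("stopped:", err)
	}
}
```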
9 changes: 4 additions & 5 deletions core/config/model_config.go
@@ -93,19 +93,18 @@ type AgentConfig struct {
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
}

- func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
+ func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers]

if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
- return remote, stdio
+ return remote, stdio, err
}

if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
- return remote, stdio
+ return remote, stdio, err
}

- return remote, stdio
+ return remote, stdio, nil
}

type MCPGenericConfig[T any] struct {
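With the new three-value signature, every caller must handle the unmarshal error instead of silently receiving zero-value configs. A hedged sketch of an updated call site follows; `loadMCPServers` is hypothetical and not from this PR, and the snippet assumes the usual `fmt` import plus the `core/config` package:

```go
// loadMCPServers is a hypothetical caller updated for the new signature.
func loadMCPServers(cfg *config.MCPConfig) error {
	remote, stdio, err := cfg.MCPConfigFromYAML()
	if err != nil {
		// Before this change the YAML error was dropped and callers
		// proceeded with empty server lists; now it can be surfaced.
		return fmt.Errorf("failed to parse MCP server config: %w", err)
	}
	_, _ = remote, stdio // use the parsed remote/stdio server lists here
	return nil
}
```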
64 changes: 64 additions & 0 deletions core/http/endpoints/openai/chat.go
@@ -3,8 +3,10 @@
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"time"

"github.com/gofiber/fiber/v2"
@@ -22,6 +24,59 @@
"github.com/valyala/fasthttp"
)

// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
// Fasthttp doesn't support context cancellation from the caller
// for non-streaming requests, so we need to monitor the connection directly.
// Monitor connection for client disconnection during non-streaming requests
// We access the connection directly via c.Context().Conn() to monitor it
// during ComputeChoices execution, not after the response is sent
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
var conn net.Conn = c.Context().Conn()
if conn == nil {
return
}

go func() {
defer func() {
// Clear read deadline when goroutine exits
conn.SetReadDeadline(time.Time{})

Code scanning warning (gosec): Errors unhandled
}()

buf := make([]byte, 1)
// Use a short read deadline to periodically check if connection is closed
// Without a deadline, Read() would block indefinitely waiting for data
// that will never come (client is waiting for response, not sending more data)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-requestCtx.Done():
// Request completed or was cancelled - exit goroutine
return
case <-ticker.C:
// Set a short deadline - if connection is closed, read will fail immediately
// If connection is open but no data, it will timeout and we check again
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))

Code scanning warning (gosec): Errors unhandled
_, err := conn.Read(buf)
if err != nil {
// Check if it's a timeout (connection still open, just no data)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout is expected - connection is still open, just no data to read
// Continue the loop to check again
continue
}
// Connection closed or other error - cancel the context to stop gRPC call
log.Debug().Msgf("Calling cancellation function")
cancelFunc()
return
}
}
}
}()
}

// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
// @Summary Generate a chat completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
@@ -358,6 +413,11 @@
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
@@ -511,6 +571,10 @@

}

// NOTE: this is a workaround as fasthttp
// context cancellation does not fire in non-streaming requests
handleConnectionCancellation(c, input.Cancel, input.Context)

result, tokenUsage, err := ComputeChoices(
input,
predInput,
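The probe at the heart of `handleConnectionCancellation` is a one-byte read under a short deadline: a timeout means the socket is open but idle, while `EOF` or a reset means the client went away. Below is a runnable sketch of just that distinction; all names are illustrative, not LocalAI APIs:

```go
package main

import (
	"fmt"
	"net"
	"time"
)

// isClosed reports whether the peer has gone away, using the same
// read-with-deadline probe as the workaround above: a timeout means
// "open but idle", any other read error means "disconnected".
func isClosed(conn net.Conn) bool {
	buf := make([]byte, 1)
	_ = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
	_, err := conn.Read(buf)
	if err == nil {
		// The peer actually sent a byte; the probe consumes it, which is
		// the same caveat the PR's workaround has for pipelined clients.
		return false
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return false // deadline expired: connection is open, just idle
	}
	return true // EOF or reset: the client disconnected
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	go func() {
		c, _ := net.Dial("tcp", ln.Addr().String())
		time.Sleep(120 * time.Millisecond)
		c.Close() // simulate the client hanging up mid-request
	}()

	conn, err := ln.Accept()
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	for {
		if isClosed(conn) {
			fmt.Println("client disconnected")
			return
		}
	}
}
```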
14 changes: 11 additions & 3 deletions core/http/endpoints/openai/mcp.go
@@ -1,6 +1,7 @@
package openai

import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -50,12 +51,15 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
}

// Get MCP config from model config
- remote, stdio := config.MCP.MCPConfigFromYAML()
+ remote, stdio, err := config.MCP.MCPConfigFromYAML()
+ if err != nil {
+ return fmt.Errorf("failed to get MCP config: %w", err)
+ }

// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
- return err
+ return fmt.Errorf("failed to get MCP sessions: %w", err)
}

if len(sessions) == 0 {
@@ -73,6 +77,10 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
if appConfig.ApiKeys != nil {
apiKey = appConfig.ApiKeys[0]
}

ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
handleConnectionCancellation(c, cancel, ctxWithCancellation)
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
Expand All @@ -83,7 +91,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
- cogito.WithContext(ctx),
+ cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
10 changes: 10 additions & 0 deletions core/http/middleware/request.go
@@ -161,7 +161,17 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)

//c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Use the application context as parent to ensure cancellation on app shutdown
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Monitor the Fiber context and cancel our context when it's canceled
// This ensures we respect request cancellation without causing panics
go func() {
<-ctx.Context().Done()
// Fiber context was canceled (request completed or client disconnected)
cancel()
}()
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)

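The middleware bridges two lifetimes: the application context as parent, and the per-request fasthttp context monitored from a goroutine. A hedged restatement of that pattern follows, with one design note: selecting on both channels keeps the monitor goroutine from leaking if the parent ever outlives the request. `bridgeRequestContext` and `requestDone` are stand-ins, not LocalAI APIs, and the snippet assumes the `context` import:

```go
// bridgeRequestContext derives a context that ends when either the
// application shuts down (appCtx) or the per-request fasthttp context
// completes (requestDone), mirroring the middleware change above.
func bridgeRequestContext(appCtx context.Context, requestDone <-chan struct{}) (context.Context, context.CancelFunc) {
	c1, cancel := context.WithCancel(appCtx)
	go func() {
		select {
		case <-requestDone:
			// Request finished or client disconnected.
		case <-c1.Done():
			// App shutdown (or cancel already called); returning here
			// keeps the goroutine from outliving the request.
		}
		cancel()
	}()
	return c1, cancel
}
```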