diff --git a/samples/cpp/live-audio-transcription/README.md b/samples/cpp/live-audio-transcription/README.md index a9fca9774..3e8b8e8d6 100644 --- a/samples/cpp/live-audio-transcription/README.md +++ b/samples/cpp/live-audio-transcription/README.md @@ -26,3 +26,7 @@ g++ -std=c++20 main.cpp -lfoundry_local -o live-audio-transcription-example # Synthetic 440Hz sine wave (no microphone needed) ./live-audio-transcription-example --synth ``` + +Press `Ctrl+C` to request a graceful stop. The sample passes that signal to +execution-provider and model downloads so long-running downloads can be +cancelled before transcription starts. diff --git a/samples/cpp/live-audio-transcription/main.cpp b/samples/cpp/live-audio-transcription/main.cpp index 1a3341e4c..5c94d6180 100644 --- a/samples/cpp/live-audio-transcription/main.cpp +++ b/samples/cpp/live-audio-transcription/main.cpp @@ -122,7 +122,8 @@ int main(int argc, char* argv[]) { foundry_local::Manager::Create(config); auto& manager = foundry_local::Manager::Instance(); - manager.EnsureEpsDownloaded(); + auto isCancellationRequested = [] { return !g_running.load(); }; + manager.DownloadAndRegisterEps(nullptr, isCancellationRequested); auto& catalog = manager.GetCatalog(); auto* model = catalog.GetModel("nemotron-speech-streaming-en-0.6b"); @@ -131,9 +132,12 @@ int main(int argc, char* argv[]) { } std::cout << "Downloading model (if needed)..." << std::endl; - model->Download([](float pct) { - std::cout << "\rDownloading: " << pct << "% " << std::flush; - }); + model->Download( + [](float pct) { + std::cout << "\rDownloading: " << pct << "% " << std::flush; + return true; + }, + isCancellationRequested); std::cout << std::endl; std::cout << "Loading model..." 
<< std::endl; model->Load(); diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h index 51af7f161..ac04a335a 100644 --- a/sdk/cpp/include/foundry_local_manager.h +++ b/sdk/cpp/include/foundry_local_manager.h @@ -83,15 +83,21 @@ namespace foundry_local { /// Download and register all available execution providers. /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. Return true to cancel. /// @return Result describing which EPs were registered or failed. - EpDownloadResult DownloadAndRegisterEps(EpProgressCallback progressCallback = nullptr) const; + EpDownloadResult DownloadAndRegisterEps( + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; /// Download and register specific execution providers by name. /// @param names EP names to download (as returned by DiscoverEps). /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. Return true to cancel. /// @return Result describing which EPs were registered or failed. 
- EpDownloadResult DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback = nullptr) const; + EpDownloadResult DownloadAndRegisterEps( + const std::vector& names, + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; private: explicit Manager(Configuration configuration, ILogger* logger); diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index b52fae76c..052cf45e9 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ namespace foundry_local { #endif using DownloadProgressCallback = std::function; + using CancellationCallback = std::function; class IModel { public: @@ -43,7 +45,11 @@ namespace foundry_local { virtual bool IsLoaded() const = 0; virtual bool IsCached() const = 0; virtual const std::filesystem::path& GetPath() const = 0; - virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0; + + /// Download the model, with an optional cancellation callback checked on each progress update. + /// Return true from isCancellationRequested to cancel the in-progress download. 
+ virtual void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) = 0; virtual void Load() = 0; virtual void Unload() = 0; virtual void RemoveFromCache() = 0; @@ -123,7 +129,8 @@ namespace foundry_local { const ModelInfo& GetInfo() const; const std::filesystem::path& GetPath() const override; - void Download(DownloadProgressCallback onProgress = nullptr) override; + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override; void Load() override; bool IsLoaded() const override; @@ -158,8 +165,9 @@ namespace foundry_local { bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } bool IsCached() const override { return SelectedVariant().IsCached(); } const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } - void Download(DownloadProgressCallback onProgress = nullptr) override { - SelectedVariant().Download(std::move(onProgress)); + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override { + SelectedVariant().Download(std::move(onProgress), std::move(isCancellationRequested)); } void Load() override { SelectedVariant().Load(); } void Unload() override { SelectedVariant().Unload(); } diff --git a/sdk/cpp/include/openai/audio_client.h b/sdk/cpp/include/openai/audio_client.h index c58fad1c3..5de7bd265 100644 --- a/sdk/cpp/include/openai/audio_client.h +++ b/sdk/cpp/include/openai/audio_client.h @@ -33,9 +33,13 @@ namespace foundry_local { const std::string& GetModelId() const noexcept { return modelId_; } AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const; + AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath, + std::function isCancellationRequested) const; using StreamCallback = std::function; void TranscribeAudioStreaming(const 
std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; + void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk, + std::function isCancellationRequested) const; /// Create a new live audio transcription session for streaming PCM audio. std::unique_ptr CreateLiveTranscriptionSession() const; diff --git a/sdk/cpp/include/openai/chat_client.h b/sdk/cpp/include/openai/chat_client.h index 8a114e6ab..5302eb4ac 100644 --- a/sdk/cpp/include/openai/chat_client.h +++ b/sdk/cpp/include/openai/chat_client.h @@ -89,17 +89,30 @@ namespace foundry_local { ChatCompletionCreateResponse CompleteChat(gsl::span messages, const ChatSettings& settings) const; + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + const ChatSettings& settings, + std::function isCancellationRequested) const; ChatCompletionCreateResponse CompleteChat(gsl::span messages, gsl::span tools, const ChatSettings& settings) const; + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings, + std::function isCancellationRequested) const; using StreamCallback = std::function; void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, const StreamCallback& onChunk) const; + void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk, + std::function isCancellationRequested) const; void CompleteChatStreaming(gsl::span messages, gsl::span tools, const ChatSettings& settings, const StreamCallback& onChunk) const; + void CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk, + std::function isCancellationRequested) const; private: OpenAIChatClient(gsl::not_null core, std::string_view modelId, diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 7c377da99..b82047800 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -4,6 
+4,8 @@ #include "foundry_local.h" #include +#include +#include #include #include #include @@ -14,6 +16,18 @@ using namespace foundry_local; +namespace { +std::atomic g_cancelRequested{false}; + +void SignalHandler(int /*signum*/) { + g_cancelRequested.store(true); +} + +bool IsCancellationRequested() { + return g_cancelRequested.load(); +} +} // namespace + // --------------------------------------------------------------------------- // Logger // --------------------------------------------------------------------------- @@ -118,7 +132,8 @@ void ChatNonStreaming(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -211,7 +226,8 @@ void TranscribeAudio(Manager& manager, const std::string& alias, const std::stri PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -263,7 +279,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -376,6 +393,8 @@ int main(int argc, char* argv[]) { const std::string audioPath = (argc > 3) ? 
argv[3] : ""; try { + std::signal(SIGINT, SignalHandler); + StdLogger logger; Manager::Create({"SampleApp"}, &logger); auto& manager = Manager::Instance(); @@ -399,7 +418,7 @@ int main(int argc, char* argv[]) { } printf("\r %-30s %5.1f%%", epName.c_str(), percent); fflush(stdout); - }); + }, IsCancellationRequested); if (!currentEp.empty()) std::cout << "\n"; } else { std::cout << "\nNo execution providers to download.\n"; diff --git a/sdk/cpp/src/audio_client.cpp b/sdk/cpp/src/audio_client.cpp index 1656834aa..90c245a16 100644 --- a/sdk/cpp/src/audio_client.cpp +++ b/sdk/cpp/src/audio_client.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -26,13 +27,19 @@ namespace foundry_local { AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio( const std::filesystem::path& audioFilePath) const { + return TranscribeAudio(audioFilePath, nullptr); + } + + AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio( + const std::filesystem::path& audioFilePath, std::function isCancellationRequested) const { nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); req.AddParam("OpenAICreateRequest", openAiReq.dump()); std::string json = req.ToJson(); - auto coreResponse = core_->call(req.Command(), *logger_, &json); + auto coreResponse = core_->call(req.Command(), *logger_, &json, nullptr, nullptr, + std::move(isCancellationRequested)); if (coreResponse.HasError()) { throw Exception("Audio transcription failed: " + coreResponse.error, *logger_); } @@ -45,6 +52,12 @@ namespace foundry_local { void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const { + TranscribeAudioStreaming(audioFilePath, onChunk, nullptr); + } + + void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + const StreamCallback& onChunk, + std::function isCancellationRequested) 
const { nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); req.AddParam("OpenAICreateRequest", openAiReq.dump()); @@ -58,7 +71,8 @@ namespace foundry_local { chunk.text = text; onChunk(chunk); }, - "Streaming audio transcription failed: "); + "Streaming audio transcription failed: ", + std::move(isCancellationRequested)); } OpenAIAudioClient::OpenAIAudioClient(const IModel& model) diff --git a/sdk/cpp/src/chat_client.cpp b/sdk/cpp/src/chat_client.cpp index 5c19a0ba5..a05418d02 100644 --- a/sdk/cpp/src/chat_client.cpp +++ b/sdk/cpp/src/chat_client.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -97,16 +98,30 @@ namespace foundry_local { return CompleteChat(messages, {}, settings); } + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings, + std::function isCancellationRequested) const { + return CompleteChat(messages, {}, settings, std::move(isCancellationRequested)); + } + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, gsl::span tools, const ChatSettings& settings) const { + return CompleteChat(messages, tools, settings, nullptr); + } + + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings, + std::function isCancellationRequested) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); CoreInteropRequest req("chat_completions"); req.AddParam("OpenAICreateRequest", openAiReqJson); std::string json = req.ToJson(); - auto response = core_->call(req.Command(), *logger_, &json); + auto response = core_->call(req.Command(), *logger_, &json, nullptr, nullptr, + std::move(isCancellationRequested)); if (response.HasError()) { throw Exception("Chat completion failed: " + response.error, *logger_); } @@ -119,9 +134,22 @@ namespace foundry_local { 
CompleteChatStreaming(messages, {}, settings, onChunk); } + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk, + std::function isCancellationRequested) const { + CompleteChatStreaming(messages, {}, settings, onChunk, std::move(isCancellationRequested)); + } + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, const ChatSettings& settings, const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, tools, settings, onChunk, nullptr); + } + + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, + gsl::span tools, const ChatSettings& settings, + const StreamCallback& onChunk, + std::function isCancellationRequested) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); CoreInteropRequest req("chat_completions"); @@ -134,7 +162,8 @@ namespace foundry_local { auto parsed = nlohmann::json::parse(chunk).get(); onChunk(parsed); }, - "Streaming chat completion failed: "); + "Streaming chat completion failed: ", + std::move(isCancellationRequested)); } OpenAIChatClient::OpenAIChatClient(const IModel& model) diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index eb598373d..c2f95df86 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -10,6 +10,10 @@ #include #include #include +#include +#include +#include +#include #ifdef _WIN32 #include @@ -136,6 +140,87 @@ namespace foundry_local { #endif } + inline bool IsUserCancellationError(const std::string& error) { + auto end = error.find_last_not_of(" \t\r\n."); + auto start = error.find_first_not_of(" \t\r\n"); + if (start == std::string::npos || end == std::string::npos || end < start) { + return false; + } + return error.substr(start, end - start + 1) == "Operation was cancelled by user"; + } + + class CancellationContextGuard { + public: + CancellationContextGuard(create_cancellation_context_fn createFn, + cancel_cancellation_context_fn 
cancelFn, + release_cancellation_context_fn releaseFn, + std::function isCancellationRequested) + : cancelFn_(cancelFn), releaseFn_(releaseFn), + isCancellationRequested_(std::move(isCancellationRequested)) { + if (!createFn || !cancelFn_ || !releaseFn_ || !isCancellationRequested_) { + return; + } + + id_ = createFn(); + if (id_ == 0) { + return; + } + + if (isCancellationRequested_()) { + cancelFn_(id_); + cancellationObserved_.store(true, std::memory_order_relaxed); + return; + } + + watcher_ = std::thread([this]() { + while (!stopWatcher_.load(std::memory_order_relaxed)) { + bool shouldCancel = false; + try { + shouldCancel = isCancellationRequested_ && isCancellationRequested_(); + } + catch (...) { + shouldCancel = true; + } + if (shouldCancel) { + cancellationObserved_.store(true, std::memory_order_relaxed); + cancelFn_(id_); + return; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + }); + } + + CancellationContextGuard(const CancellationContextGuard&) = delete; + CancellationContextGuard& operator=(const CancellationContextGuard&) = delete; + + ~CancellationContextGuard() { + stopWatcher_.store(true, std::memory_order_relaxed); + if (watcher_.joinable()) { + watcher_.join(); + } + + if (id_ != 0 && releaseFn_) { + releaseFn_(id_); + } + } + + bool IsAvailable() const noexcept { return id_ != 0; } + int64_t Id() const noexcept { return id_; } + bool CancellationObserved() const noexcept { + return cancellationObserved_.load(std::memory_order_relaxed); + } + + private: + int64_t id_ = 0; + cancel_cancellation_context_fn cancelFn_{}; + release_cancellation_context_fn releaseFn_{}; + std::function isCancellationRequested_; + std::atomic_bool stopWatcher_{false}; + std::atomic_bool cancellationObserved_{false}; + std::thread watcher_; + }; + } // namespace struct Core : Internal::IFoundryLocalCore { @@ -159,16 +244,26 @@ namespace foundry_local { void unload() override { module_.reset(); execCmd_ = nullptr; + execCancellableCmd_ = nullptr; 
execCbCmd_ = nullptr; + execCbCancellableCmd_ = nullptr; execBinaryCmd_ = nullptr; + execBinaryCancellableCmd_ = nullptr; + createCancellationContextCmd_ = nullptr; + cancelCancellationContextCmd_ = nullptr; + releaseCancellationContextCmd_ = nullptr; freeResCmd_ = nullptr; } CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, - NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + NativeCallbackFn callback = nullptr, void* data = nullptr, + std::function isCancellationRequested = nullptr) const override { if (!static_cast(module_) || !execCmd_ || !execCbCmd_ || !freeResCmd_) { throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger); } + if (isCancellationRequested && isCancellationRequested()) { + throw Exception("Operation cancelled", logger); + } RequestBuffer request{}; request.Command = command.empty() ? nullptr : command.data(); @@ -186,16 +281,38 @@ namespace foundry_local { }; std::unique_ptr responseGuard(&response, safeDeleter); + const bool canUseCancellableCommand = + createCancellationContextCmd_ && cancelCancellationContextCmd_ && releaseCancellationContextCmd_ && + ((callback != nullptr && execCbCancellableCmd_) || (callback == nullptr && execCancellableCmd_)); + CancellationContextGuard cancellationContext( + canUseCancellableCommand ? createCancellationContextCmd_ : nullptr, + canUseCancellableCommand ? cancelCancellationContextCmd_ : nullptr, + canUseCancellableCommand ? 
releaseCancellationContextCmd_ : nullptr, + std::move(isCancellationRequested)); + if (callback != nullptr) { - execCbCmd_(&request, &response, reinterpret_cast(callback), data); + if (cancellationContext.IsAvailable() && execCbCancellableCmd_) { + execCbCancellableCmd_(&request, &response, callback, data, cancellationContext.Id()); + } + else { + execCbCmd_(&request, &response, callback, data); + } } else { - execCmd_(&request, &response); + if (cancellationContext.IsAvailable() && execCancellableCmd_) { + execCancellableCmd_(&request, &response, cancellationContext.Id()); + } + else { + execCmd_(&request, &response); + } } CoreResponse result; if (response.Error && response.ErrorLength > 0) { result.error.assign(static_cast(response.Error), response.ErrorLength); + if (cancellationContext.CancellationObserved() && IsUserCancellationError(result.error)) { + result.error = "Operation cancelled"; + } return result; } @@ -256,8 +373,14 @@ namespace foundry_local { private: SharedLibHandle module_; execute_command_fn execCmd_{}; + execute_command_cancellable_fn execCancellableCmd_{}; execute_command_with_callback_fn execCbCmd_{}; + execute_command_with_callback_cancellable_fn execCbCancellableCmd_{}; execute_command_with_binary_fn execBinaryCmd_{}; + execute_command_with_binary_cancellable_fn execBinaryCancellableCmd_{}; + create_cancellation_context_fn createCancellationContextCmd_{}; + cancel_cancellation_context_fn cancelCancellationContextCmd_{}; + release_cancellation_context_fn releaseCancellationContextCmd_{}; free_response_fn freeResCmd_{}; void LoadFromPath(const std::filesystem::path& path) { @@ -271,10 +394,22 @@ namespace foundry_local { } execCmd_ = reinterpret_cast(RequireProc(m.handle, "execute_command")); + execCancellableCmd_ = reinterpret_cast( + OptionalProc(m.handle, "execute_command_cancellable")); execCbCmd_ = reinterpret_cast( RequireProc(m.handle, "execute_command_with_callback")); + execCbCancellableCmd_ = reinterpret_cast( + 
OptionalProc(m.handle, "execute_command_with_callback_cancellable")); execBinaryCmd_ = reinterpret_cast( OptionalProc(m.handle, "execute_command_with_binary")); + execBinaryCancellableCmd_ = reinterpret_cast( + OptionalProc(m.handle, "execute_command_with_binary_cancellable")); + createCancellationContextCmd_ = reinterpret_cast( + OptionalProc(m.handle, "create_cancellation_context")); + cancelCancellationContextCmd_ = reinterpret_cast( + OptionalProc(m.handle, "cancel_cancellation_context")); + releaseCancellationContextCmd_ = reinterpret_cast( + OptionalProc(m.handle, "release_cancellation_context")); freeResCmd_ = reinterpret_cast(RequireProc(m.handle, "free_response")); module_ = std::move(m); diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h index c46f294a2..55809550b 100644 --- a/sdk/cpp/src/core_helpers.h +++ b/sdk/cpp/src/core_helpers.h @@ -6,12 +6,15 @@ #pragma once +#include #include #include #include #include #include +#include #include +#include #include @@ -47,38 +50,82 @@ namespace foundry_local::detail { return core->call(command, logger, &payload, callback, userData); } + inline bool TryParseFloatToken(std::string_view token, float& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + + inline bool TryParseDoubleToken(std::string_view token, double& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + // Serialize + call with a streaming chunk handler. // Wraps the caller-supplied onChunk with the native callback boilerplate - // (null/length checks, exception capture, rethrow after the call). 
+ // (null/length checks, exception capture, cancellation, rethrow after the call). // The errorContext string is used to prefix any core-layer error message. inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, - const std::string& payload, ILogger& logger, - const std::function& onChunk, - std::string_view errorContext) { + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { struct State { - const std::function* cb; + const std::function* cb; + CancellationCallback isCancellationRequested; + bool cancellationObserved = false; std::exception_ptr exception; - } state{&onChunk, nullptr}; + } state{&onChunk, std::move(isCancellationRequested), false, nullptr}; - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) + auto nativeCallback = [](const void* data, int32_t len, void* user) -> int32_t { + auto* st = static_cast(user); + if (!st) { return 0; + } - auto* st = static_cast(user); - if (st->exception) + if (st->exception || st->cancellationObserved) { + return 1; + } + + if (!data || len <= 0) return 0; try { + if (st->isCancellationRequested && st->isCancellationRequested()) { + st->cancellationObserved = true; + return 1; + } + std::string chunk(static_cast(data), static_cast(len)); - (*(st->cb))(chunk); + if (!(*(st->cb))(chunk)) { + st->cancellationObserved = true; + return 1; + } } catch (...) 
{ st->exception = std::current_exception(); + return 1; } + return 0; }; - auto response = core->call(command, logger, &payload, +nativeCallback, &state); + auto response = core->call(command, logger, payload, +nativeCallback, &state, state.isCancellationRequested); + if (state.cancellationObserved) { + throw Exception("Operation cancelled", logger); + } + if (response.HasError()) { throw Exception(std::string(errorContext) + response.error, logger); } @@ -90,6 +137,38 @@ namespace foundry_local::detail { return response; } + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + const std::function continuingOnChunk = + [&onChunk](const std::string& chunk) { + onChunk(chunk); + return true; + }; + return CallWithStreamingCallback(core, command, payload, logger, continuingOnChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + // Overload: allow Params object directly inline CoreResponse 
CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command, const nlohmann::json& params, ILogger& logger) { diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index 2ea792b9e..f678b3c7c 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -5,10 +5,12 @@ #include #include -#ifdef _WIN32 - #define FL_CDECL __cdecl -#else - #define FL_CDECL +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif #endif extern "C" @@ -29,8 +31,9 @@ extern "C" int32_t ErrorLength; }; - // Callback signature: int(*)(void* data, int length, void* userData) — returns 0 to continue, 1 to cancel - using UserCallbackFn = int(__cdecl*)(void*, int32_t, void*); + // Callback signature: int32_t(*)(const void* data, int length, void* userData) + // Return 0 to continue, 1 to cancel. + using UserCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); struct StreamingRequestBuffer { const void* Command; @@ -43,9 +46,19 @@ extern "C" // Exported function pointer types using execute_command_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*); - using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, - void* /*userData*/); + using execute_command_cancellable_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, int64_t); + using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, + UserCallbackFn /*callback*/, + void* /*userData*/); + using execute_command_with_callback_cancellable_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, + UserCallbackFn /*callback*/, + void* /*userData*/, + int64_t); using execute_command_with_binary_fn = void(FL_CDECL*)(StreamingRequestBuffer*, ResponseBuffer*); + using execute_command_with_binary_cancellable_fn = void(FL_CDECL*)(StreamingRequestBuffer*, ResponseBuffer*, int64_t); + using create_cancellation_context_fn = int64_t(FL_CDECL*)(); + using 
cancel_cancellation_context_fn = int32_t(FL_CDECL*)(int64_t); + using release_cancellation_context_fn = int32_t(FL_CDECL*)(int64_t); using free_response_fn = void(FL_CDECL*)(ResponseBuffer*); static_assert(std::is_standard_layout::value, "RequestBuffer must be standard layout"); diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index 368096dec..5ef987a1f 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -4,15 +4,25 @@ #pragma once #include +#include #include #include #include "logger.h" +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif +#endif + namespace foundry_local { /// Native callback signature used by the core DLL interop. /// Parameters: (data, dataLength, userData). - using NativeCallbackFn = int (*)(void*, int32_t, void*); + /// Return 0 to continue, 1 to cancel the native operation. + using NativeCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); /// Value returned by IFoundryLocalCore::call(). /// On success, `data` contains the response payload and `error` is empty. 
@@ -29,8 +39,9 @@ namespace foundry_local { virtual ~IFoundryLocalCore() = default; virtual CoreResponse call(std::string_view command, ILogger& logger, - const std::string* dataArgument = nullptr, NativeCallbackFn callback = nullptr, - void* data = nullptr) const = 0; + const std::string* dataArgument = nullptr, NativeCallbackFn callback = nullptr, + void* data = nullptr, + std::function isCancellationRequested = nullptr) const = 0; virtual CoreResponse callWithBinary(std::string_view command, ILogger& logger, const std::string* dataArgument, @@ -40,4 +51,4 @@ namespace foundry_local { }; } // namespace Internal -} // namespace foundry_local \ No newline at end of file +} // namespace foundry_local diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp index 2c1e6177c..2d634f4ca 100644 --- a/sdk/cpp/src/foundry_local_manager.cpp +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include @@ -15,6 +15,7 @@ #include "foundry_local_internal_core.h" #include "foundry_local_exception.h" #include "core_interop_request.h" +#include "core_helpers.h" #include "core.h" #include "logger.h" @@ -163,39 +164,16 @@ void Manager::Cleanup() noexcept { return result; } - namespace { - struct EpCallbackContext { - EpProgressCallback* callback; - }; - - int EpProgressNativeCallback(void* data, int32_t dataLength, void* userData) { - auto* ctx = static_cast(userData); - if (!ctx || !ctx->callback || !*ctx->callback) return 0; - if (!data || dataLength <= 0) return 0; - - std::string progressStr(static_cast(data), static_cast(dataLength)); - auto sepIndex = progressStr.find('|'); - if (sepIndex != std::string::npos) { - std::string name = progressStr.substr(0, sepIndex); - // Parse percent using locale-independent std::from_chars - const auto* begin = progressStr.data() + sepIndex + 1; - const auto* end = progressStr.data() + progressStr.size(); - double percent = 0.0; - auto [ptr, ec] 
= std::from_chars(begin, end, percent); - if (ec == std::errc{}) { - (*ctx->callback)(name, percent); - } - } - return 0; - } + EpDownloadResult Manager::DownloadAndRegisterEps( + EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { + return DownloadAndRegisterEps({}, std::move(progressCallback), std::move(isCancellationRequested)); } - EpDownloadResult Manager::DownloadAndRegisterEps(EpProgressCallback progressCallback) const { - return DownloadAndRegisterEps({}, std::move(progressCallback)); - } - - EpDownloadResult Manager::DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback) const { + EpDownloadResult Manager::DownloadAndRegisterEps( + const std::vector& names, + EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { std::string requestData; std::string* requestDataPtr = nullptr; @@ -212,16 +190,32 @@ void Manager::Cleanup() noexcept { } CoreResponse response; - if (progressCallback) { - EpCallbackContext ctx{&progressCallback}; - response = core_->call("download_and_register_eps", *logger_, - requestDataPtr, EpProgressNativeCallback, &ctx); + if (progressCallback || isCancellationRequested) { + auto onChunk = [&progressCallback](const std::string& chunk) { + if (!progressCallback) { + return; + } + + const auto sep = chunk.find('|'); + if (sep == std::string::npos) { + return; + } + + double percent = 0.0; + if (detail::TryParseDoubleToken(std::string_view(chunk).substr(sep + 1), percent)) { + progressCallback(chunk.substr(0, sep), percent); + } + }; + + response = detail::CallWithStreamingCallback(core_.get(), "download_and_register_eps", + requestDataPtr, *logger_, onChunk, + "Error downloading execution providers: ", + std::move(isCancellationRequested)); } else { response = core_->call("download_and_register_eps", *logger_, requestDataPtr); - } - - if (response.HasError()) { - throw Exception(std::string("Error downloading execution providers: 
") + response.error, *logger_); + if (response.HasError()) { + throw Exception(std::string("Error downloading execution providers: ") + response.error, *logger_); + } } EpDownloadResult result; diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index e09f55414..9cc7f3672 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -67,38 +68,32 @@ namespace foundry_local { return false; } - void ModelVariant::Download(DownloadProgressCallback onProgress) { + void ModelVariant::Download(DownloadProgressCallback onProgress, + CancellationCallback isCancellationRequested) { if (IsCached()) { logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); return; } - if (onProgress) { - struct ProgressState { - DownloadProgressCallback* cb; - ILogger* logger; - } state{&onProgress, logger_}; - - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) - return 0; - auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast(len)); - try { - float value = std::stof(perc); - (*(st->cb))(value); + if (onProgress || isCancellationRequested) { + std::function onChunk = [&onProgress](const std::string& chunk) { + if (!onProgress) { + return true; } - catch (...) 
{ - st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); + + float value = 0.0f; + if (TryParseFloatToken(chunk, value)) { + if (!onProgress(value)) { + return false; + } } - return 0; + return true; }; - auto response = CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, - +nativeCallback, &state); - if (response.HasError()) { - throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_); - } + const std::string payload = MakeModelParams(info_.name).dump(); + CallWithStreamingCallback(core_, "download_model", payload, *logger_, onChunk, + "Error downloading model [" + info_.name + "]: ", + std::move(isCancellationRequested)); } else { auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); diff --git a/sdk/cpp/test/ep_test.cpp b/sdk/cpp/test/ep_test.cpp index 7649b1efd..78c9ecaf6 100644 --- a/sdk/cpp/test/ep_test.cpp +++ b/sdk/cpp/test/ep_test.cpp @@ -72,7 +72,7 @@ static EpDownloadResult TestDownloadAndRegisterEps( struct EpCallbackContext { EpProgressCallback* callback; }; - auto nativeCb = [](void* data, int32_t dataLength, void* userData) -> int { + auto nativeCb = [](const void* data, int32_t dataLength, void* userData) -> int32_t { auto* ctx = static_cast(userData); if (!ctx || !ctx->callback || !*ctx->callback) return 0; if (!data || dataLength <= 0) return 0; @@ -249,9 +249,9 @@ TEST_F(DownloadAndRegisterEpsTest, CallbackInvokedWithProgressData) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback) { std::string p1 = "WebGpuExecutionProvider|25.0"; - callback(const_cast(p1.data()), static_cast(p1.size()), userData); + callback(p1.data(), static_cast(p1.size()), userData); std::string p2 = "WebGpuExecutionProvider|100.0"; - callback(const_cast(p2.data()), static_cast(p2.size()), userData); + callback(p2.data(), static_cast(p2.size()), userData); } return 
R"({"Success": true, "Status": "OK", "RegisteredEps": ["WebGpuExecutionProvider"], "FailedEps": []})"; }); diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index e7b5f84ca..3861941db 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -56,7 +56,9 @@ namespace foundry_local::Testing { // IFoundryLocalCore implementation CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, - NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + NativeCallbackFn callback = nullptr, void* data = nullptr, + std::function isCancellationRequested = nullptr) const override { + (void)isCancellationRequested; std::string cmd(command); const_cast(this)->callCounts_[cmd]++; @@ -126,7 +128,9 @@ namespace foundry_local::Testing { } CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, - NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { + NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr, + std::function isCancellationRequested = nullptr) const override { + (void)isCancellationRequested; CoreResponse resp; diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index c631f8ff3..82bdea5c0 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -9,6 +9,7 @@ #include "foundry_local_exception.h" #include +#include using namespace foundry_local; using namespace foundry_local::Testing; @@ -136,7 +137,7 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { std::string progress = "50"; - int result = callback(progress.data(), static_cast(progress.size()), userData); + const int32_t result = callback(progress.data(), static_cast(progress.size()), userData); EXPECT_EQ(0, 
result) << "Callback should return 0 (continue), not " << result; } return ""; @@ -146,6 +147,84 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { variant.Download([&](float) { return true; }); } +TEST_F(ModelVariantTest, Download_ParsesNumericProgressChunk) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "12.5"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + std::vector progressValues; + variant.Download([&](float pct) { + progressValues.push_back(pct); + return true; + }); + + ASSERT_EQ(1u, progressValues.size()); + EXPECT_NEAR(12.5f, progressValues[0], 0.01f); +} + +TEST_F(ModelVariantTest, Download_WithCancellationRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download(nullptr, [] { return true; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_ProgressCallbackFalseRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; 
+ } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download([](float) { return false; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_CancellationAfterFinalCallbackDoesNotCancelSuccessfulDownload) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "100"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + bool cancel = false; + EXPECT_NO_THROW(variant.Download([&](float) { + cancel = true; + return true; + }, [&] { return cancel; })); + EXPECT_TRUE(cancel); +} + TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { core_.OnCall("remove_cached_model", ""); auto variant = MakeVariant("test-model"); diff --git a/sdk/cs/README.md b/sdk/cs/README.md index 276ffb716..9493eea0b 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -99,6 +99,18 @@ await mgr.DownloadAndRegisterEpsAsync((epName, percent) => Console.WriteLine(); ``` +#### Cancelling model and EP downloads + +Pass a `CancellationToken` to either download API. Cancellation is observed on the next progress update. + +```csharp +// mgr and model already initialized +using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + +await mgr.DownloadAndRegisterEpsAsync(ct: cts.Token); +await model.DownloadAsync(ct: cts.Token); +``` + Catalog access no longer blocks on EP downloads. Call `DownloadAndRegisterEpsAsync` explicitly when you need hardware-accelerated execution providers. 
## Quick Start diff --git a/sdk/cs/src/Detail/CoreInterop.Modern.cs b/sdk/cs/src/Detail/CoreInterop.Modern.cs index 1774c0d3d..252d04608 100644 --- a/sdk/cs/src/Detail/CoreInterop.Modern.cs +++ b/sdk/cs/src/Detail/CoreInterop.Modern.cs @@ -22,6 +22,12 @@ internal partial class CoreInterop [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] private static unsafe partial void CoreExecuteCommand(RequestBuffer* request, ResponseBuffer* response); + [LibraryImport(LibraryName, EntryPoint = "execute_command_cancellable")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static unsafe partial void CoreExecuteCommandCancellable(RequestBuffer* request, + ResponseBuffer* response, + long cancellationContextId); + [LibraryImport(LibraryName, EntryPoint = "execute_command_with_callback")] [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] private static unsafe partial void CoreExecuteCommandWithCallback(RequestBuffer* nativeRequest, @@ -29,11 +35,37 @@ private static unsafe partial void CoreExecuteCommandWithCallback(RequestBuffer* nint callbackPtr, nint userData); + [LibraryImport(LibraryName, EntryPoint = "execute_command_with_callback_cancellable")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static unsafe partial void CoreExecuteCommandWithCallbackCancellable(RequestBuffer* nativeRequest, + ResponseBuffer* nativeResponse, + nint callbackPtr, + nint userData, + long cancellationContextId); + [LibraryImport(LibraryName, EntryPoint = "execute_command_with_binary")] [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] private static unsafe partial void CoreExecuteCommandWithBinary(StreamingRequestBuffer* nativeRequest, ResponseBuffer* nativeResponse); + [LibraryImport(LibraryName, EntryPoint = 
"execute_command_with_binary_cancellable")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static unsafe partial void CoreExecuteCommandWithBinaryCancellable(StreamingRequestBuffer* nativeRequest, + ResponseBuffer* nativeResponse, + long cancellationContextId); + + [LibraryImport(LibraryName, EntryPoint = "create_cancellation_context")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static partial long CoreCreateCancellationContext(); + + [LibraryImport(LibraryName, EntryPoint = "cancel_cancellation_context")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static partial int CoreCancelCancellationContext(long cancellationContextId); + + [LibraryImport(LibraryName, EntryPoint = "release_cancellation_context")] + [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] + private static partial int CoreReleaseCancellationContext(long cancellationContextId); + [LibraryImport(LibraryName, EntryPoint = "audio_stream_start")] [UnmanagedCallConv(CallConvs = new[] { typeof(System.Runtime.CompilerServices.CallConvCdecl) })] private static unsafe partial void CoreAudioStreamStart(RequestBuffer* request, ResponseBuffer* response); diff --git a/sdk/cs/src/Detail/CoreInterop.NetStandard.cs b/sdk/cs/src/Detail/CoreInterop.NetStandard.cs index b96a258b1..a21514855 100644 --- a/sdk/cs/src/Detail/CoreInterop.NetStandard.cs +++ b/sdk/cs/src/Detail/CoreInterop.NetStandard.cs @@ -23,16 +23,42 @@ internal partial class CoreInterop [DllImport(LibraryName, EntryPoint = "execute_command", CallingConvention = CallingConvention.Cdecl)] private static unsafe extern void CoreExecuteCommand(RequestBuffer* request, ResponseBuffer* response); + [DllImport(LibraryName, EntryPoint = "execute_command_cancellable", CallingConvention = CallingConvention.Cdecl)] + private static 
unsafe extern void CoreExecuteCommandCancellable(RequestBuffer* request, + ResponseBuffer* response, + long cancellationContextId); + [DllImport(LibraryName, EntryPoint = "execute_command_with_callback", CallingConvention = CallingConvention.Cdecl)] private static unsafe extern void CoreExecuteCommandWithCallback(RequestBuffer* nativeRequest, ResponseBuffer* nativeResponse, nint callbackPtr, nint userData); + [DllImport(LibraryName, EntryPoint = "execute_command_with_callback_cancellable", CallingConvention = CallingConvention.Cdecl)] + private static unsafe extern void CoreExecuteCommandWithCallbackCancellable(RequestBuffer* nativeRequest, + ResponseBuffer* nativeResponse, + nint callbackPtr, + nint userData, + long cancellationContextId); + [DllImport(LibraryName, EntryPoint = "execute_command_with_binary", CallingConvention = CallingConvention.Cdecl)] private static unsafe extern void CoreExecuteCommandWithBinary(StreamingRequestBuffer* nativeRequest, ResponseBuffer* nativeResponse); + [DllImport(LibraryName, EntryPoint = "execute_command_with_binary_cancellable", CallingConvention = CallingConvention.Cdecl)] + private static unsafe extern void CoreExecuteCommandWithBinaryCancellable(StreamingRequestBuffer* nativeRequest, + ResponseBuffer* nativeResponse, + long cancellationContextId); + + [DllImport(LibraryName, EntryPoint = "create_cancellation_context", CallingConvention = CallingConvention.Cdecl)] + private static extern long CoreCreateCancellationContext(); + + [DllImport(LibraryName, EntryPoint = "cancel_cancellation_context", CallingConvention = CallingConvention.Cdecl)] + private static extern int CoreCancelCancellationContext(long cancellationContextId); + + [DllImport(LibraryName, EntryPoint = "release_cancellation_context", CallingConvention = CallingConvention.Cdecl)] + private static extern int CoreReleaseCancellationContext(long cancellationContextId); + [DllImport(LibraryName, EntryPoint = "audio_stream_start", CallingConvention = 
CallingConvention.Cdecl)] private static unsafe extern void CoreAudioStreamStart(RequestBuffer* request, ResponseBuffer* response); diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 7239a48e4..9a64fff37 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -33,6 +33,8 @@ internal partial class CoreInterop : ICoreInterop private static IntPtr genaiLibHandle = IntPtr.Zero; private static IntPtr ortLibHandle = IntPtr.Zero; private static readonly NativeCallbackFn handleCallbackDelegate = HandleCallback; + private static int cancellableCommandsUnavailable; + private const string UserCancellationError = "Operation was cancelled by user"; [UnmanagedFunctionPointer(CallingConvention.Cdecl)] private unsafe delegate void ExecuteCommandDelegate(RequestBuffer* req, ResponseBuffer* resp); @@ -47,6 +49,69 @@ public CallbackHelper(CallbackFn callback) } } + private sealed class CancellationContext : IDisposable + { + private readonly CancellationTokenRegistration _registration; + private long _id; + + private CancellationContext(long id, CancellationTokenRegistration registration) + { + _id = id; + _registration = registration; + } + + public long Id => _id; + + public static CancellationContext? 
Create(CancellationToken cancellationToken) + { + if (!cancellationToken.CanBeCanceled || !CancellableCommandsAvailable) + { + return null; + } + + long id; + try + { + id = CoreCreateCancellationContext(); + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + return null; + } + + if (id == 0) + { + MarkCancellableCommandsUnavailable(); + return null; + } + + try + { + var registration = cancellationToken.Register( + static state => CancelCancellationContext((long)state!), + id); + return new CancellationContext(id, registration); + } + catch + { + ReleaseCancellationContext(id); + throw; + } + } + + public void Dispose() + { + _registration.Dispose(); + + var id = Interlocked.Exchange(ref _id, 0); + if (id != 0) + { + ReleaseCancellationContext(id); + } + } + } + static CoreInterop() { InitializeNativeLibraryResolver(); @@ -110,6 +175,119 @@ private static void LoadOrtDllsIfInSameDir(string path) Debug.WriteLine($"Loaded GenAI: {loadedGenAI} handle={genaiLibHandle}"); } + private static bool CancellableCommandsAvailable => + Volatile.Read(ref cancellableCommandsUnavailable) == 0; + + private static void MarkCancellableCommandsUnavailable() + { + Volatile.Write(ref cancellableCommandsUnavailable, 1); + } + + private static void CancelCancellationContext(long cancellationContextId) + { + try + { + _ = CoreCancelCancellationContext(cancellationContextId); + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + } + } + + private static void ReleaseCancellationContext(long cancellationContextId) + { + try + { + _ = CoreReleaseCancellationContext(cancellationContextId); + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + } + } + + private static unsafe bool TryExecuteCommandCancellable(RequestBuffer* request, + ResponseBuffer* response, + long cancellationContextId) + { + if (!CancellableCommandsAvailable) + { + return false; + } + + try + { + 
CoreExecuteCommandCancellable(request, response, cancellationContextId); + return true; + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + return false; + } + } + + private static unsafe bool TryExecuteCommandWithCallbackCancellable(RequestBuffer* request, + ResponseBuffer* response, + nint callbackPtr, + nint userData, + long cancellationContextId) + { + if (!CancellableCommandsAvailable) + { + return false; + } + + try + { + CoreExecuteCommandWithCallbackCancellable(request, response, callbackPtr, userData, cancellationContextId); + return true; + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + return false; + } + } + + private static unsafe bool TryExecuteCommandWithBinaryCancellable(StreamingRequestBuffer* request, + ResponseBuffer* response, + long cancellationContextId) + { + if (!CancellableCommandsAvailable) + { + return false; + } + + try + { + CoreExecuteCommandWithBinaryCancellable(request, response, cancellationContextId); + return true; + } + catch (EntryPointNotFoundException) + { + MarkCancellableCommandsUnavailable(); + return false; + } + } + + private static void ThrowIfCancellationResponse(Response result, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested && + result.Error != null && + IsUserCancellationError(result.Error)) + { + throw new OperationCanceledException(cancellationToken); + } + } + + private static bool IsUserCancellationError(string error) + { + var normalized = error.Trim().TrimEnd('.'); + return string.Equals(normalized, UserCancellationError, StringComparison.OrdinalIgnoreCase); + } + private static int HandleCallback(nint data, int length, nint callbackHelper) { var callbackData = string.Empty; @@ -150,10 +328,13 @@ private static int HandleCallback(nint data, int length, nint callbackHelper) } public Response ExecuteCommandImpl(string commandName, string? commandInput, - CallbackFn? callback = null) + CallbackFn? 
callback = null, + CancellationToken cancellationToken = default) { try { + cancellationToken.ThrowIfCancellationRequested(); + byte[] commandBytes = System.Text.Encoding.UTF8.GetBytes(commandName); IntPtr commandPtr = Marshal.AllocHGlobal(commandBytes.Length); Marshal.Copy(commandBytes, 0, commandPtr, commandBytes.Length); @@ -177,6 +358,8 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, }; ResponseBuffer response = default; + Exception? callbackException = null; + using var cancellationContext = CancellationContext.Create(cancellationToken); if (callback != null) { @@ -190,24 +373,48 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, var helperHandle = GCHandle.Alloc(helper); var helperPtr = GCHandle.ToIntPtr(helperHandle); - unsafe + try { - CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + unsafe + { + if (cancellationContext != null) + { + if (!TryExecuteCommandWithCallbackCancellable( + &request, &response, funcPtr, helperPtr, cancellationContext.Id)) + { + cancellationToken.ThrowIfCancellationRequested(); + CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + } + } + else + { + CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + } + } } - - helperHandle.Free(); - - if (helper.Exception != null) + finally { - throw new FoundryLocalException("Exception in callback handler. See InnerException for details", - helper.Exception); + helperHandle.Free(); } + + callbackException = helper.Exception; } else { unsafe { - CoreExecuteCommand(&request, &response); + if (cancellationContext != null) + { + if (!TryExecuteCommandCancellable(&request, &response, cancellationContext.Id)) + { + cancellationToken.ThrowIfCancellationRequested(); + CoreExecuteCommand(&request, &response); + } + } + else + { + CoreExecuteCommand(&request, &response); + } } } @@ -239,6 +446,19 @@ public Response ExecuteCommandImpl(string commandName, string? 
commandInput, Marshal.FreeHGlobal(inputPtr!.Value); } + if (callbackException != null) + { + if (callbackException is OperationCanceledException canceledException) + { + throw canceledException; + } + + throw new FoundryLocalException("Exception in callback handler. See InnerException for details", + callbackException); + } + + ThrowIfCancellationResponse(result, cancellationToken); + return result; } catch (Exception ex) when (ex is not OperationCanceledException) @@ -265,7 +485,8 @@ public Task ExecuteCommandAsync(string commandName, CoreInteropRequest CancellationToken? cancellationToken = null) { var ct = cancellationToken ?? CancellationToken.None; - return Task.Run(() => ExecuteCommand(commandName, commandInput), ct); + var commandInputJson = commandInput?.ToJson(); + return Task.Run(() => ExecuteCommandImpl(commandName, commandInputJson, cancellationToken: ct), ct); } public Task ExecuteCommandWithCallbackAsync(string commandName, CoreInteropRequest? commandInput, @@ -273,7 +494,8 @@ public Task ExecuteCommandWithCallbackAsync(string commandName, CoreIn CancellationToken? cancellationToken = null) { var ct = cancellationToken ?? CancellationToken.None; - return Task.Run(() => ExecuteCommandWithCallback(commandName, commandInput, callback), ct); + var commandInputJson = commandInput?.ToJson(); + return Task.Run(() => ExecuteCommandImpl(commandName, commandInputJson, callback, ct), ct); } /// diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 250c601a2..442817228 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -6,6 +6,8 @@ namespace Microsoft.AI.Foundry.Local; +using System.Globalization; + using Microsoft.AI.Foundry.Local.Detail; using Microsoft.Extensions.Logging; @@ -63,8 +65,8 @@ public async Task DownloadAsync(Action? downloadProgress = null, CancellationToken? 
ct = null) { await Utils.CallWithExceptionHandling(() => DownloadImplAsync(downloadProgress, ct), - $"Error downloading model {Id}", _logger) - .ConfigureAwait(false); + $"Error downloading model {Id}", _logger) + .ConfigureAwait(false); } public async Task LoadAsync(CancellationToken? ct = null) @@ -144,16 +146,26 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, }; ICoreInterop.Response? response; + var useCallbackPath = downloadProgress != null || (ct?.CanBeCanceled ?? false); - if (downloadProgress == null) - { - response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); - } - else + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { - if (float.TryParse(progressString, out var progress)) + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + if (downloadProgress == null) + { + return; + } + + if (float.TryParse(progressString, + NumberStyles.Float, + CultureInfo.InvariantCulture, + out var progress)) { downloadProgress(progress); } @@ -162,6 +174,10 @@ private async Task DownloadImplAsync(Action? 
downloadProgress = null, response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, callback, ct).ConfigureAwait(false); } + else + { + response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); + } if (response.Error != null) { diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index b014850f6..855aed4a2 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -6,6 +6,7 @@ namespace Microsoft.AI.Foundry.Local; using System; +using System.Globalization; using System.Text.Json; using System.Threading.Tasks; @@ -373,20 +374,27 @@ private async Task DownloadAndRegisterEpsImplAsync(IEnumerable ICoreInterop.Response result; - if (progressCallback != null) + var useCallbackPath = progressCallback != null || (ct?.CanBeCanceled ?? false); + + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + var sepIndex = progressString.IndexOf('|'); if (sepIndex >= 0) { var name = progressString[..sepIndex]; if (double.TryParse(progressString[(sepIndex + 1)..], - System.Globalization.NumberStyles.Float, - System.Globalization.CultureInfo.InvariantCulture, + NumberStyles.Float, + CultureInfo.InvariantCulture, out var percent)) { - progressCallback(string.IsNullOrEmpty(name) ? "" : name, percent); + progressCallback?.Invoke(string.IsNullOrEmpty(name) ? "" : name, percent); } } }); diff --git a/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs new file mode 100644 index 000000000..04738b3b2 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs @@ -0,0 +1,121 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. 
All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using Microsoft.AI.Foundry.Local.Detail; + +using Microsoft.Extensions.Logging; + +using Moq; + +internal sealed class DownloadCancellationTests +{ + [Test] + public async Task ModelVariantDownload_WithCancellationToken_UsesCallbackPathAndPropagatesCancellation() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + using var cts = new CancellationTokenSource(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.Is(r => r != null && + r.Params != null && + r.Params.ContainsKey("Model") && + r.Params["Model"] == modelInfo.Id), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("10"); + cts.Cancel(); + callback("20"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + IModel model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + + OperationCanceledException? 
caught = null; + try + { + await model.DownloadAsync(ct: cts.Token); + } + catch (OperationCanceledException ex) + { + caught = ex; + } + + await Assert.That(caught).IsNotNull(); + coreInterop.Verify(x => x.ExecuteCommandWithCallbackAsync( + "download_model", + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Once); + coreInterop.Verify(x => x.ExecuteCommandAsync( + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Never); + } + + [Test] + public async Task ModelVariantDownload_WithProgressChunk_ParsesInvariantFloat() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("12.5"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + var progressValues = new List(); + + await model.DownloadAsync(progressValues.Add); + + await Assert.That(progressValues.Count).IsEqualTo(1); + await Assert.That(progressValues[0]).IsEqualTo(12.5f); + } +} diff --git a/sdk/js/README.md b/sdk/js/README.md index 26471cc8c..fad973353 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -77,6 +77,19 @@ await manager.downloadAndRegisterEps((epName, percent) => { process.stdout.write('\n'); ``` +#### Cancelling model and EP downloads + +Use an `AbortController` with either `downloadAndRegisterEps()` or `model.download()`. 
Aborting the signal rejects the in-progress download promise. + +```typescript +// manager and model already initialized +const controller = new AbortController(); +setTimeout(() => controller.abort(), 5000); + +await manager.downloadAndRegisterEps(controller.signal); +await model.download(controller.signal); +``` + Catalog access does not block on EP downloads. Call `downloadAndRegisterEps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -336,4 +349,4 @@ See `test/README.md` for details on prerequisites and setup. npm run example ``` -This runs the chat completion example in `examples/chat-completion.ts`. \ No newline at end of file +This runs the chat completion example in `examples/chat-completion.ts`. diff --git a/sdk/js/native/foundry_local_napi.c b/sdk/js/native/foundry_local_napi.c index d84b3f67b..6e4033ab7 100644 --- a/sdk/js/native/foundry_local_napi.c +++ b/sdk/js/native/foundry_local_napi.c @@ -79,6 +79,12 @@ typedef void (*ExecuteCommandFn)( ResponseBuffer* response ); +typedef void (*ExecuteCommandCancellableFn)( + const RequestBuffer* request, + ResponseBuffer* response, + int64_t cancellationContextId +); + typedef void (*ExecuteCommandWithCallbackFn)( const RequestBuffer* request, ResponseBuffer* response, @@ -86,11 +92,29 @@ typedef void (*ExecuteCommandWithCallbackFn)( void* userData ); +typedef void (*ExecuteCommandWithCallbackCancellableFn)( + const RequestBuffer* request, + ResponseBuffer* response, + CallbackFn callback, + void* userData, + int64_t cancellationContextId +); + typedef void (*ExecuteCommandWithBinaryFn)( const StreamingRequestBuffer* request, ResponseBuffer* response ); +typedef void (*ExecuteCommandWithBinaryCancellableFn)( + const StreamingRequestBuffer* request, + ResponseBuffer* response, + int64_t cancellationContextId +); + +typedef int64_t (*CreateCancellationContextFn)(void); +typedef int32_t (*CancelCancellationContextFn)(int64_t cancellationContextId); +typedef int32_t 
(*ReleaseCancellationContextFn)(int64_t cancellationContextId); + /* ── Module state ─────────────────────────────────────────────────────── */ static lib_handle_t g_core_lib = NULL; @@ -100,6 +124,12 @@ static size_t g_dep_lib_count = 0; static ExecuteCommandFn g_execute_command = NULL; static ExecuteCommandWithCallbackFn g_execute_command_with_callback = NULL; static ExecuteCommandWithBinaryFn g_execute_command_with_binary = NULL; +static ExecuteCommandCancellableFn g_execute_command_cancellable = NULL; +static ExecuteCommandWithCallbackCancellableFn g_execute_command_with_callback_cancellable = NULL; +static ExecuteCommandWithBinaryCancellableFn g_execute_command_with_binary_cancellable = NULL; +static CreateCancellationContextFn g_create_cancellation_context = NULL; +static CancelCancellationContextFn g_cancel_cancellation_context = NULL; +static ReleaseCancellationContextFn g_release_cancellation_context = NULL; /* ── Platform-specific memory deallocation ────────────────────────────── */ @@ -235,6 +265,21 @@ static void cleanup_loaded_libs(void) { g_execute_command = NULL; g_execute_command_with_callback = NULL; g_execute_command_with_binary = NULL; + g_execute_command_cancellable = NULL; + g_execute_command_with_callback_cancellable = NULL; + g_execute_command_with_binary_cancellable = NULL; + g_create_cancellation_context = NULL; + g_cancel_cancellation_context = NULL; + g_release_cancellation_context = NULL; +} + +static int cancellable_commands_available(void) { + return g_create_cancellation_context && + g_cancel_cancellation_context && + g_release_cancellation_context && + g_execute_command_cancellable && + g_execute_command_with_callback_cancellable && + g_execute_command_with_binary_cancellable; } /* ── Helper: extract response and free native buffers ─────────────────── */ @@ -397,6 +442,91 @@ static napi_value napi_load_library(napi_env env, napi_callback_info info) { return NULL; } + g_create_cancellation_context = 
(CreateCancellationContextFn)LIB_SYM( + g_core_lib, "create_cancellation_context"); + g_cancel_cancellation_context = (CancelCancellationContextFn)LIB_SYM( + g_core_lib, "cancel_cancellation_context"); + g_release_cancellation_context = (ReleaseCancellationContextFn)LIB_SYM( + g_core_lib, "release_cancellation_context"); + g_execute_command_cancellable = (ExecuteCommandCancellableFn)LIB_SYM( + g_core_lib, "execute_command_cancellable"); + g_execute_command_with_callback_cancellable = (ExecuteCommandWithCallbackCancellableFn)LIB_SYM( + g_core_lib, "execute_command_with_callback_cancellable"); + g_execute_command_with_binary_cancellable = (ExecuteCommandWithBinaryCancellableFn)LIB_SYM( + g_core_lib, "execute_command_with_binary_cancellable"); + + napi_value undefined; + NAPI_CALL(env, napi_get_undefined(env, &undefined)); + return undefined; +} + +/* ── Cancellable command context helpers ──────────────────────────────── */ + +static napi_value napi_has_cancellable_commands(napi_env env, + napi_callback_info info) { + (void)info; + napi_value result; + NAPI_CALL(env, napi_get_boolean(env, cancellable_commands_available(), &result)); + return result; +} + +static napi_value napi_create_cancellation_context(napi_env env, + napi_callback_info info) { + (void)info; + if (!cancellable_commands_available()) { + napi_throw_error(env, NULL, "Cancellable commands are not supported by this native library"); + return NULL; + } + + int64_t id = g_create_cancellation_context(); + napi_value result; + NAPI_CALL(env, napi_create_int64(env, id, &result)); + return result; +} + +static napi_value napi_cancel_cancellation_context(napi_env env, + napi_callback_info info) { + if (!g_cancel_cancellation_context) { + napi_throw_error(env, NULL, "Cancellable commands are not supported by this native library"); + return NULL; + } + + size_t argc = 1; + napi_value argv[1]; + NAPI_CALL(env, napi_get_cb_info(env, info, &argc, argv, NULL, NULL)); + if (argc < 1) { + napi_throw_error(env, NULL, 
"cancelCancellationContext requires 1 argument (contextId)"); + return NULL; + } + + int64_t id = 0; + NAPI_CALL(env, napi_get_value_int64(env, argv[0], &id)); + (void)g_cancel_cancellation_context(id); + + napi_value undefined; + NAPI_CALL(env, napi_get_undefined(env, &undefined)); + return undefined; +} + +static napi_value napi_release_cancellation_context(napi_env env, + napi_callback_info info) { + if (!g_release_cancellation_context) { + napi_throw_error(env, NULL, "Cancellable commands are not supported by this native library"); + return NULL; + } + + size_t argc = 1; + napi_value argv[1]; + NAPI_CALL(env, napi_get_cb_info(env, info, &argc, argv, NULL, NULL)); + if (argc < 1) { + napi_throw_error(env, NULL, "releaseCancellationContext requires 1 argument (contextId)"); + return NULL; + } + + int64_t id = 0; + NAPI_CALL(env, napi_get_value_int64(env, argv[0], &id)); + (void)g_release_cancellation_context(id); + napi_value undefined; NAPI_CALL(env, napi_get_undefined(env, &undefined)); return undefined; @@ -554,6 +684,7 @@ typedef struct { size_t command_length; char* data; size_t data_length; + int64_t cancellation_context_id; ResponseBuffer response; napi_deferred deferred; napi_async_work work; @@ -575,7 +706,12 @@ static void async_execute(napi_env env, void* data) { work_data->response.Error = NULL; work_data->response.ErrorLength = 0; - g_execute_command(&req, &work_data->response); + if (work_data->cancellation_context_id != 0 && g_execute_command_cancellable) { + g_execute_command_cancellable(&req, &work_data->response, + work_data->cancellation_context_id); + } else { + g_execute_command(&req, &work_data->response); + } } /* Runs on the JS main thread after async_execute completes */ @@ -628,21 +764,21 @@ static void async_complete(napi_env env, napi_status status, void* data) { free(work_data); } -/* executeCommandAsync(command, dataJson) → Promise */ +/* executeCommandAsync(command, dataJson, cancellationContextId?) 
→ Promise */ static napi_value napi_execute_command_async(napi_env env, - napi_callback_info info) { + napi_callback_info info) { if (!g_execute_command) { napi_throw_error(env, NULL, "Native library not loaded. Call loadLibrary() first."); return NULL; } - size_t argc = 2; - napi_value argv[2]; + size_t argc = 3; + napi_value argv[3]; NAPI_CALL(env, napi_get_cb_info(env, info, &argc, argv, NULL, NULL)); if (argc < 2) { napi_throw_error(env, NULL, - "executeCommandAsync requires 2 arguments (command, dataJson)"); + "executeCommandAsync requires at least 2 arguments (command, dataJson)"); return NULL; } @@ -679,6 +815,21 @@ static napi_value napi_execute_command_async(napi_env env, work_data->data = data_str; work_data->data_length = data_len; + if (argc >= 3) { + napi_valuetype vt; + NAPI_CALL(env, napi_typeof(env, argv[2], &vt)); + if (vt != napi_undefined && vt != napi_null) { + NAPI_CALL(env, napi_get_value_int64(env, argv[2], &work_data->cancellation_context_id)); + if (work_data->cancellation_context_id != 0 && !g_execute_command_cancellable) { + free(cmd); + free(data_str); + free(work_data); + napi_throw_error(env, NULL, "execute_command_cancellable is not supported by this native library"); + return NULL; + } + } + } + /* Create promise */ napi_value promise; napi_status st = napi_create_promise(env, &work_data->deferred, &promise); @@ -734,6 +885,7 @@ struct StreamingWorkData { size_t command_length; char* data; size_t data_length; + int64_t cancellation_context_id; /* Threadsafe function for streaming callback */ napi_threadsafe_function tsfn; @@ -836,9 +988,16 @@ static void streaming_execute(napi_env env, void* data) { work_data->response.Error = NULL; work_data->response.ErrorLength = 0; - g_execute_command_with_callback( - &req, &work_data->response, - streaming_native_callback, work_data); + if (work_data->cancellation_context_id != 0 && g_execute_command_with_callback_cancellable) { + g_execute_command_with_callback_cancellable( + &req, 
&work_data->response, + streaming_native_callback, work_data, + work_data->cancellation_context_id); + } else { + g_execute_command_with_callback( + &req, &work_data->response, + streaming_native_callback, work_data); + } } /* Runs on the JS main thread after streaming_execute completes */ @@ -971,7 +1130,7 @@ static bool streaming_setup(napi_env env, napi_value js_callback, return true; } -/* ── executeCommandStreaming(command, dataJson, callback) → Promise ───── */ +/* ── executeCommandStreaming(command, dataJson, callback, cancellationContextId?) → Promise ───── */ static napi_value napi_execute_command_streaming(napi_env env, napi_callback_info info) { @@ -980,13 +1139,13 @@ static napi_value napi_execute_command_streaming(napi_env env, return NULL; } - size_t argc = 3; - napi_value argv[3]; + size_t argc = 4; + napi_value argv[4]; NAPI_CALL(env, napi_get_cb_info(env, info, &argc, argv, NULL, NULL)); if (argc < 3) { napi_throw_error(env, NULL, - "executeCommandStreaming requires 3 arguments (command, dataJson, callback)"); + "executeCommandStreaming requires at least 3 arguments (command, dataJson, callback)"); return NULL; } @@ -1031,6 +1190,20 @@ static napi_value napi_execute_command_streaming(napi_env env, work_data->data = data_str; work_data->data_length = data_len; + if (argc >= 4) { + napi_valuetype vt; + NAPI_CALL(env, napi_typeof(env, argv[3], &vt)); + if (vt != napi_undefined && vt != napi_null) { + NAPI_CALL(env, napi_get_value_int64(env, argv[3], &work_data->cancellation_context_id)); + if (work_data->cancellation_context_id != 0 && !g_execute_command_with_callback_cancellable) { + streaming_cleanup(work_data, false); + napi_throw_error(env, NULL, + "execute_command_with_callback_cancellable is not supported by this native library"); + return NULL; + } + } + } + /* Setup phase: use manual status checks instead of NAPI_CALL so we can clean up work_data on failure. Once async work is queued successfully, streaming_complete owns all cleanup. 
*/ @@ -1048,6 +1221,14 @@ static napi_value init(napi_env env, napi_value exports) { napi_property_descriptor props[] = { { "loadLibrary", NULL, napi_load_library, NULL, NULL, NULL, napi_default, NULL }, + { "hasCancellableCommands", NULL, napi_has_cancellable_commands, + NULL, NULL, NULL, napi_default, NULL }, + { "createCancellationContext", NULL, napi_create_cancellation_context, + NULL, NULL, NULL, napi_default, NULL }, + { "cancelCancellationContext", NULL, napi_cancel_cancellation_context, + NULL, NULL, NULL, napi_default, NULL }, + { "releaseCancellationContext", NULL, napi_release_cancellation_context, + NULL, NULL, NULL, napi_default, NULL }, { "executeCommand", NULL, napi_execute_command, NULL, NULL, NULL, napi_default, NULL }, { "executeCommandAsync", NULL, napi_execute_command_async, NULL, diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 72013815c..fea872895 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -12,10 +12,30 @@ const require = createRequire(import.meta.url); interface NativeAddon { loadLibrary(corePath: string, depPaths?: string[]): void; + hasCancellableCommands?: () => boolean; + createCancellationContext?: () => number; + cancelCancellationContext?: (contextId: number) => void; + releaseCancellationContext?: (contextId: number) => void; executeCommand(command: string, dataJson: string): string; - executeCommandAsync(command: string, dataJson: string): Promise; + executeCommandAsync(command: string, dataJson: string, cancellationContextId?: number): Promise; executeCommandWithBinary(command: string, dataJson: string, binaryBuffer: Buffer): string; - executeCommandStreaming(command: string, dataJson: string, callback: (chunk: string) => void): Promise; + executeCommandStreaming( + command: string, + dataJson: string, + callback: (chunk: string) => void, + cancellationContextId?: number + ): Promise; +} + +function createAbortError(): Error { + const error = new 
Error('Operation cancelled'); + error.name = 'AbortError'; + return error; +} + +function isUserCancellationError(error: unknown): boolean { + const message = error instanceof Error ? error.message : String(error); + return message.includes('Operation was cancelled by user'); } function loadAddon(): NativeAddon { @@ -120,9 +140,45 @@ export class CoreInterop { * Asynchronously execute a native command without blocking the event loop. * Runs the native call on a libuv worker thread. */ - public executeCommandAsync(command: string, params?: any): Promise { + public async executeCommandAsync(command: string, params?: any, signal?: AbortSignal): Promise { + if (signal?.aborted) { + throw createAbortError(); + } + const dataStr = params ? JSON.stringify(params) : ''; - return this.addon.executeCommandAsync(command, dataStr); + const cancellationContextId = this.createCancellationContext(signal); + let abortListener: (() => void) | undefined; + let aborted = false; + + if (signal && cancellationContextId !== undefined) { + abortListener = () => { + aborted = true; + this.addon.cancelCancellationContext?.(cancellationContextId); + }; + signal.addEventListener('abort', abortListener, { once: true }); + } + + try { + const result = await this.addon.executeCommandAsync(command, dataStr, cancellationContextId); + if (aborted) { + throw createAbortError(); + } + + return result; + } catch (error) { + if (signal?.aborted && isUserCancellationError(error)) { + throw createAbortError(); + } + + throw error; + } finally { + if (signal && abortListener) { + signal.removeEventListener('abort', abortListener); + } + if (cancellationContextId !== undefined) { + this.addon.releaseCancellationContext?.(cancellationContextId); + } + } } /** @@ -136,9 +192,72 @@ export class CoreInterop { return this.addon.executeCommandWithBinary(command, dataStr, binBuf); } - public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { + public async 
executeCommandStreaming( + command: string, + params: any, + callback: (chunk: string) => void, + signal?: AbortSignal + ): Promise { + if (signal?.aborted) { + throw createAbortError(); + } + const dataStr = params ? JSON.stringify(params) : ''; - return this.addon.executeCommandStreaming(command, dataStr, callback); + let cancelled = false; + const cancellationContextId = this.createCancellationContext(signal); + let abortListener: (() => void) | undefined; + + if (signal && cancellationContextId !== undefined) { + abortListener = () => { + cancelled = true; + this.addon.cancelCancellationContext?.(cancellationContextId); + }; + signal.addEventListener('abort', abortListener, { once: true }); + } + + const wrappedCallback = (chunk: string) => { + if (signal?.aborted) { + cancelled = true; + throw createAbortError(); + } + + callback(chunk); + }; + + try { + const result = await this.addon.executeCommandStreaming( + command, + dataStr, + wrappedCallback, + cancellationContextId + ); + if (cancelled) { + throw createAbortError(); + } + + return result; + } catch (error) { + if (cancelled || (signal?.aborted && isUserCancellationError(error))) { + throw createAbortError(); + } + + throw error; + } finally { + if (signal && abortListener) { + signal.removeEventListener('abort', abortListener); + } + if (cancellationContextId !== undefined) { + this.addon.releaseCancellationContext?.(cancellationContextId); + } + } + } + + private createCancellationContext(signal?: AbortSignal): number | undefined { + if (!signal || !this.addon.hasCancellableCommands?.()) { + return undefined; + } + + return this.addon.createCancellationContext?.(); } } diff --git a/sdk/js/src/detail/model.ts b/sdk/js/src/detail/model.ts index ffd962db5..e70c0703c 100644 --- a/sdk/js/src/detail/model.ts +++ b/sdk/js/src/detail/model.ts @@ -125,10 +125,14 @@ export class Model implements IModel { /** * Downloads the currently selected variant. 
- * @param progressCallback - Optional callback to report download progress. + * @param progressCallbackOrSignal - Optional progress callback or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. */ - public download(progressCallback?: (progress: number) => void): Promise { - return this.selectedVariant.download(progressCallback); + public download( + progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal + ): Promise { + return this.selectedVariant.download(progressCallbackOrSignal, signal); } /** @@ -202,4 +206,4 @@ export class Model implements IModel { public createResponsesClient(baseUrl: string): ResponsesClient { return this.selectedVariant.createResponsesClient(baseUrl); } -} \ No newline at end of file +} diff --git a/sdk/js/src/detail/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts index af150bb81..7f78353ac 100644 --- a/sdk/js/src/detail/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -107,19 +107,38 @@ export class ModelVariant implements IModel { /** * Downloads the model variant. - * @param progressCallback - Optional callback to report download progress (0-100). - */ - public async download(progressCallback?: (progress: number) => void): Promise { + * @param progressCallbackOrSignal - Optional progress callback (0-100) or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. + */ + public async download( + progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal + ): Promise { + const progressCallback = typeof progressCallbackOrSignal === 'function' + ? progressCallbackOrSignal + : undefined; + const abortSignal = typeof progressCallbackOrSignal === 'function' + ? signal + : progressCallbackOrSignal ?? 
signal; const request = { Params: { Model: this._modelInfo.id } }; - if (!progressCallback) { + if (!progressCallback && !abortSignal) { await this.coreInterop.executeCommandAsync("download_model", request); } else { + // Use the streaming path when progress or cancellation is needed. + // Provide a no-op callback when only cancellation is requested so + // the native callback mechanism is engaged. + const cb = progressCallback ?? (() => {}); await this.coreInterop.executeCommandStreaming("download_model", request, (chunk: string) => { - const progress = parseFloat(chunk); - if (!isNaN(progress)) { - progressCallback(progress); + const progressChunk = chunk.trim(); + if (progressChunk.length === 0) { + return; + } + + const progress = Number(progressChunk); + if (!Number.isNaN(progress)) { + cb(progress); } - }); + }, abortSignal); } } diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index f3224e656..dc8c7d712 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -5,6 +5,13 @@ import { Catalog } from './catalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; import { EpInfo, EpDownloadResult } from './types.js'; +function isAbortSignal(value: unknown): value is AbortSignal { + return typeof value === 'object' + && value !== null + && 'aborted' in value + && typeof (value as AbortSignal).aborted === 'boolean'; +} + /** * The main entry point for the Foundry Local SDK. * Manages the initialization of the core system and provides access to the Catalog and ModelLoadManager. @@ -178,18 +185,38 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(): Promise; + /** + * Downloads and registers execution providers. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. 
+ */ + public downloadAndRegisterEps(signal: AbortSignal): Promise; /** * Downloads and registers execution providers. * @param names - Array of EP names to download. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[]): Promise; + /** + * Downloads and registers execution providers. + * @param names - Array of EP names to download. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: string[], signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param names - Array of EP names to download. @@ -197,16 +224,45 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. 
+ * @param names - Array of EP names to download. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: undefined, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. 
+ */ + public downloadAndRegisterEps(names: undefined, progressCallback: (epName: string, percent: number) => void, signal?: AbortSignal): Promise; public async downloadAndRegisterEps( - namesOrCallback?: string[] | ((epName: string, percent: number) => void), - progressCallback?: (epName: string, percent: number) => void + namesOrCallbackOrSignal?: string[] | ((epName: string, percent: number) => void) | AbortSignal, + progressCallbackOrSignal?: ((epName: string, percent: number) => void) | AbortSignal, + maybeSignal?: AbortSignal ): Promise { - let names: string[] | undefined; - if (typeof namesOrCallback === 'function') { - progressCallback = namesOrCallback; - } else { - names = namesOrCallback; - } + const names = Array.isArray(namesOrCallbackOrSignal) ? namesOrCallbackOrSignal : undefined; + const progressCallback = typeof namesOrCallbackOrSignal === 'function' + ? namesOrCallbackOrSignal + : typeof progressCallbackOrSignal === 'function' + ? progressCallbackOrSignal + : undefined; + const signal = isAbortSignal(namesOrCallbackOrSignal) + ? namesOrCallbackOrSignal + : isAbortSignal(progressCallbackOrSignal) + ? progressCallbackOrSignal + : maybeSignal; const params: { Params?: { Names: string } } = {}; if (names && names.length > 0) { @@ -221,11 +277,17 @@ export class FoundryLocalManager { }; let response: string; + const commandParams = Object.keys(params).length > 0 ? params : undefined; - if (progressCallback) { + if (!progressCallback && !signal) { + response = await this.coreInterop.executeCommandAsync( + "download_and_register_eps", + commandParams + ); + } else if (progressCallback) { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", - Object.keys(params).length > 0 ? 
params : undefined, + commandParams, (chunk: string) => { const sepIndex = chunk.indexOf('|'); if (sepIndex >= 0) { @@ -235,13 +297,15 @@ export class FoundryLocalManager { progressCallback(epName || '', percent); } } - } + }, + signal ); } else { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", - Object.keys(params).length > 0 ? params : undefined, - () => {} // no-op callback + commandParams, + () => {}, // no-op callback + signal ); } diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 7a8a79e35..72fdc4d8b 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -17,7 +17,13 @@ export interface IModel { get capabilities(): string | null; get supportsToolCalling(): boolean | null; - download(progressCallback?: (progress: number) => void): Promise; + /** + * Download the model to local cache if not already present. + * @param progressCallbackOrSignal - Optional callback for download progress (0-100), or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. + */ + download(progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal): Promise; get path(): string; load(): Promise; removeFromCache(): void; diff --git a/sdk/js/src/openai/audioClient.ts b/sdk/js/src/openai/audioClient.ts index 5e3d4f326..ba9830bbb 100644 --- a/sdk/js/src/openai/audioClient.ts +++ b/sdk/js/src/openai/audioClient.ts @@ -34,6 +34,10 @@ export class AudioClientSettings { } } +export interface AudioRequestOptions { + signal?: AbortSignal; +} + /** * Client for performing audio operations (transcription, translation) with a loaded model. * Follows the OpenAI Audio API structure. @@ -81,7 +85,7 @@ export class AudioClient { * @returns The transcription result. * @throws Error - If audioFilePath is invalid or transcription fails. 
*/ - public async transcribe(audioFilePath: string): Promise { + public async transcribe(audioFilePath: string, options?: AudioRequestOptions): Promise { this.validateAudioFilePath(audioFilePath); const request = { Model: this.modelId, @@ -90,9 +94,17 @@ export class AudioClient { }; try { - const response = this.coreInterop.executeCommand("audio_transcribe", { Params: { OpenAICreateRequest: JSON.stringify(request) } }); + const response = await this.coreInterop.executeCommandAsync( + "audio_transcribe", + { Params: { OpenAICreateRequest: JSON.stringify(request) } }, + options?.signal + ); return JSON.parse(response); } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + throw error; + } + throw new Error(`Audio transcription failed for model '${this.modelId}': ${error instanceof Error ? error.message : String(error)}`, { cause: error }); } } @@ -110,7 +122,7 @@ export class AudioClient { * } * ``` */ - public transcribeStreaming(audioFilePath: string): AsyncIterable { + public transcribeStreaming(audioFilePath: string, options?: AudioRequestOptions): AsyncIterable { this.validateAudioFilePath(audioFilePath); const request = { @@ -141,6 +153,13 @@ export class AudioClient { let error: Error | null = null; let resolve: (() => void) | null = null; let nextInFlight = false; + const abortController = new AbortController(); + const abortFromExternalSignal = () => abortController.abort(); + if (options?.signal?.aborted) { + abortController.abort(); + } else { + options?.signal?.addEventListener('abort', abortFromExternalSignal, { once: true }); + } const streamingPromise = coreInterop.executeCommandStreaming( "audio_transcribe", @@ -166,10 +185,12 @@ export class AudioClient { resolve = null; r(); } - } + }, + abortController.signal // When the native stream completes, mark done and wake up any // pending next() call so it can see that iteration has ended. 
).then(() => { + options?.signal?.removeEventListener('abort', abortFromExternalSignal); done = true; if (resolve) { const r = resolve; @@ -177,6 +198,7 @@ export class AudioClient { r(); // resolve the pending next() promise } }).catch((err) => { + options?.signal?.removeEventListener('abort', abortFromExternalSignal); if (!error) { const underlyingError = err instanceof Error ? err : new Error(String(err)); error = new Error( @@ -228,12 +250,8 @@ export class AudioClient { } }, async return(): Promise> { - // Mark cancelled so the callback stops buffering. - // Note: the underlying native stream cannot be cancelled - // (CoreInterop.executeCommandStreaming has no abort support), - // so the koffi callback may still fire but will no-op due - // to the cancelled guard above. cancelled = true; + abortController.abort(); chunks.length = 0; head = 0; if (resolve) { diff --git a/sdk/js/src/openai/chatClient.ts b/sdk/js/src/openai/chatClient.ts index e61efcfa5..3ef92808a 100644 --- a/sdk/js/src/openai/chatClient.ts +++ b/sdk/js/src/openai/chatClient.ts @@ -105,6 +105,10 @@ export class ChatClientSettings { } } +export interface ChatRequestOptions { + signal?: AbortSignal; +} + /** * Client for performing chat completions with a loaded model. * Follows the OpenAI Chat Completion API structure. @@ -186,9 +190,16 @@ export class ChatClient { * @returns The chat completion response object. * @throws Error - If messages or tools are invalid or completion fails. 
*/ - public async completeChat(messages: any[]): Promise; - public async completeChat(messages: any[], tools: any[]): Promise; - public async completeChat(messages: any[], tools?: any[]): Promise { + public async completeChat(messages: any[], options?: ChatRequestOptions): Promise; + public async completeChat(messages: any[], tools: any[], options?: ChatRequestOptions): Promise; + public async completeChat( + messages: any[], + toolsOrOptions?: any[] | ChatRequestOptions, + options?: ChatRequestOptions + ): Promise { + const tools = Array.isArray(toolsOrOptions) ? toolsOrOptions : undefined; + const signal = Array.isArray(toolsOrOptions) ? options?.signal : toolsOrOptions?.signal; + this.validateMessages(messages); this.validateTools(tools); @@ -201,11 +212,15 @@ export class ChatClient { }; try { - const response = this.coreInterop.executeCommand('chat_completions', { + const response = await this.coreInterop.executeCommandAsync('chat_completions', { Params: { OpenAICreateRequest: JSON.stringify(request) } - }); + }, signal); return JSON.parse(response); } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + throw error; + } + throw new Error( `Chat completion failed for model '${this.modelId}': ${error instanceof Error ? 
error.message : String(error)}`, { cause: error } @@ -235,9 +250,16 @@ export class ChatClient { * } * ``` */ - public completeStreamingChat(messages: any[]): AsyncIterable; - public completeStreamingChat(messages: any[], tools: any[]): AsyncIterable; - public completeStreamingChat(messages: any[], tools?: any[]): AsyncIterable { + public completeStreamingChat(messages: any[], options?: ChatRequestOptions): AsyncIterable; + public completeStreamingChat(messages: any[], tools: any[], options?: ChatRequestOptions): AsyncIterable; + public completeStreamingChat( + messages: any[], + toolsOrOptions?: any[] | ChatRequestOptions, + options?: ChatRequestOptions + ): AsyncIterable { + const tools = Array.isArray(toolsOrOptions) ? toolsOrOptions : undefined; + const signal = Array.isArray(toolsOrOptions) ? options?.signal : toolsOrOptions?.signal; + this.validateMessages(messages); this.validateTools(tools); @@ -271,6 +293,13 @@ export class ChatClient { let error: Error | null = null; let resolve: (() => void) | null = null; let nextInFlight = false; + const abortController = new AbortController(); + const abortFromExternalSignal = () => abortController.abort(); + if (signal?.aborted) { + abortController.abort(); + } else { + signal?.addEventListener('abort', abortFromExternalSignal, { once: true }); + } const streamingPromise = coreInterop.executeCommandStreaming( 'chat_completions', @@ -296,10 +325,12 @@ export class ChatClient { resolve = null; r(); } - } + }, + abortController.signal // When the native stream completes, mark done and wake up any // pending next() call so it can see that iteration has ended. ).then(() => { + signal?.removeEventListener('abort', abortFromExternalSignal); done = true; if (resolve) { const r = resolve; @@ -307,6 +338,7 @@ export class ChatClient { r(); // resolve the pending next() promise } }).catch((err) => { + signal?.removeEventListener('abort', abortFromExternalSignal); if (!error) { const underlyingError = err instanceof Error ? 
err : new Error(String(err)); error = new Error( @@ -358,12 +390,8 @@ export class ChatClient { } }, async return(): Promise> { - // Mark cancelled so the callback stops buffering. - // Note: the underlying native stream cannot be cancelled - // (CoreInterop.executeCommandStreaming has no abort support), - // so the koffi callback may still fire but will no-op due - // to the cancelled guard above. cancelled = true; + abortController.abort(); chunks.length = 0; head = 0; if (resolve) { diff --git a/sdk/js/src/openai/embeddingClient.ts b/sdk/js/src/openai/embeddingClient.ts index ab415e0fe..2c26721f6 100644 --- a/sdk/js/src/openai/embeddingClient.ts +++ b/sdk/js/src/openai/embeddingClient.ts @@ -1,5 +1,9 @@ import { CoreInterop } from '../detail/coreInterop.js'; +export interface EmbeddingRequestOptions { + signal?: AbortSignal; +} + /** * Client for generating text embeddings with a loaded model. * Follows the OpenAI Embeddings API structure. @@ -45,18 +49,22 @@ export class EmbeddingClient { * Sends an embedding request and parses the response. * @internal */ - private executeRequest(input: string | string[]): any { + private async executeRequest(input: string | string[], options?: EmbeddingRequestOptions): Promise { const request = { model: this.modelId, input, }; try { - const response = this.coreInterop.executeCommand('embeddings', { + const response = await this.coreInterop.executeCommandAsync('embeddings', { Params: { OpenAICreateRequest: JSON.stringify(request) } - }); + }, options?.signal); return JSON.parse(response); } catch (error: any) { + if (error instanceof Error && error.name === 'AbortError') { + throw error; + } + throw new Error( `Embedding generation failed for model '${this.modelId}': ${error instanceof Error ? error.message : String(error)}`, { cause: error } @@ -69,9 +77,9 @@ export class EmbeddingClient { * @param input - The text to generate embeddings for. * @returns The embedding response containing the embedding vector. 
*/ - public async generateEmbedding(input: string): Promise { + public async generateEmbedding(input: string, options?: EmbeddingRequestOptions): Promise { this.validateInput(input); - return this.executeRequest(input); + return this.executeRequest(input, options); } /** @@ -79,8 +87,8 @@ export class EmbeddingClient { * @param inputs - The texts to generate embeddings for. * @returns The embedding response containing one embedding vector per input. */ - public async generateEmbeddings(inputs: string[]): Promise { + public async generateEmbeddings(inputs: string[], options?: EmbeddingRequestOptions): Promise { this.validateInputs(inputs); - return this.executeRequest(inputs); + return this.executeRequest(inputs, options); } } diff --git a/sdk/js/test/detail/coreInterop.test.ts b/sdk/js/test/detail/coreInterop.test.ts new file mode 100644 index 000000000..2adae14a1 --- /dev/null +++ b/sdk/js/test/detail/coreInterop.test.ts @@ -0,0 +1,198 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { CoreInterop } from '../../src/detail/coreInterop.js'; + +describe('CoreInterop Tests', () => { + it('executeCommandAsync should reject without calling native interop when signal is already aborted', async function() { + const controller = new AbortController(); + controller.abort(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandAsync: () => { + throw new Error('native interop should not be called for an already aborted signal'); + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandAsync.call( + interop, + 'chat_completions', + undefined, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandAsync should cancel and release a native cancellation context', async function() { + const controller = new AbortController(); + const 
calls: string[] = []; + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + hasCancellableCommands: () => true, + createCancellationContext: () => { + calls.push('create'); + return 42; + }, + cancelCancellationContext: (id: number) => { + calls.push(`cancel:${id}`); + }, + releaseCancellationContext: (id: number) => { + calls.push(`release:${id}`); + }, + executeCommandAsync: async (_command: string, _dataJson: string, contextId?: number) => { + calls.push(`execute:${contextId}`); + controller.abort(); + throw new Error("Command 'chat_completions' failed: Operation was cancelled by user"); + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandAsync.call( + interop, + 'chat_completions', + undefined, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(calls).to.deep.equal(['create', 'execute:42', 'cancel:42', 'release:42']); + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandStreaming should reject without calling native interop when signal is already aborted', async function() { + const controller = new AbortController(); + controller.abort(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: () => { + throw new Error('native interop should not be called for an already aborted signal'); + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + () => {}, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandStreaming should reject when signal is aborted before the next callback', async function() { + const controller = new AbortController(); + const chunks: string[] = []; + const interop = 
Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: async (_command: string, _dataJson: string, callback: (chunk: string) => void) => { + callback('50'); + callback('60'); + return 'ok'; + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + (chunk: string) => { + chunks.push(chunk); + controller.abort(); + }, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(chunks).to.deep.equal(['50']); + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandStreaming should not reject when signal aborts after the final observed callback', async function() { + const controller = new AbortController(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: async (_command: string, _dataJson: string, callback: (chunk: string) => void) => { + callback('100'); + return 'ok'; + } + }; + + const result = await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + () => controller.abort(), + controller.signal + ); + + expect(result).to.equal('ok'); + }); + + it('executeCommandStreaming should cancel and release a native cancellation context', async function() { + const controller = new AbortController(); + const calls: string[] = []; + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + hasCancellableCommands: () => true, + createCancellationContext: () => { + calls.push('create'); + return 7; + }, + cancelCancellationContext: (id: number) => { + calls.push(`cancel:${id}`); + }, + releaseCancellationContext: (id: number) => { + calls.push(`release:${id}`); + }, + executeCommandStreaming: async ( + _command: string, + _dataJson: string, + _callback: (chunk: string) => void, + contextId?: number + ) => { + 
calls.push(`execute:${contextId}`); + controller.abort(); + throw new Error("Command 'chat_completions' failed: Operation was cancelled by user"); + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'chat_completions', + undefined, + () => {}, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(calls).to.deep.equal(['create', 'execute:7', 'cancel:7', 'release:7']); + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); +}); diff --git a/sdk/js/test/foundryLocalManager.test.ts b/sdk/js/test/foundryLocalManager.test.ts index 48adcff40..8f7e9b6e4 100644 --- a/sdk/js/test/foundryLocalManager.test.ts +++ b/sdk/js/test/foundryLocalManager.test.ts @@ -1,6 +1,7 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager } from './testUtils.js'; +import { FoundryLocalManager } from '../src/foundryLocalManager.js'; describe('Foundry Local Manager Tests', () => { it('should initialize successfully', function() { @@ -18,64 +19,159 @@ describe('Foundry Local Manager Tests', () => { }); it('downloadAndRegisterEps should call command without params when names are omitted', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; - - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: true, - Status: 'All providers registered', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: [] - })); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandAsync: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + 
FailedEps: [] + })); + }, + executeCommandStreaming: () => { + throw new Error('download should not use streaming interop without progress or cancellation'); + } + }; + manager._catalog = { + invalidateCache: () => {} }; - try { - const result = await manager.downloadAndRegisterEps(); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.be.undefined; - expect(result).to.deep.equal({ - success: true, - status: 'All providers registered', - registeredEps: ['CUDAExecutionProvider'], - failedEps: [] - }); - } finally { - manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + const result = await manager.downloadAndRegisterEps(); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.be.undefined; + expect(result).to.deep.equal({ + success: true, + status: 'All providers registered', + registeredEps: ['CUDAExecutionProvider'], + failedEps: [] + }); }); it('downloadAndRegisterEps should send Names param when subset is provided', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandAsync: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: false, + Status: 'Some providers failed', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: ['OpenVINOExecutionProvider'] + })); + }, + executeCommandStreaming: () => { + throw new Error('download should not use streaming interop without progress or cancellation'); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); + expect(calls.length).to.equal(1); + 
expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); + expect(result).to.deep.equal({ + success: false, + status: 'Some providers failed', + registeredEps: ['CUDAExecutionProvider'], + failedEps: ['OpenVINOExecutionProvider'] + }); + }); + + it('downloadAndRegisterEps should pass AbortSignal through to streaming interop', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + ['CUDAExecutionProvider'], + controller.signal + ); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('downloadAndRegisterEps should honor progress callback when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const progress: Array<[string, number]> = []; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + const callback = args[2] as (chunk: string) => void; + callback('CUDAExecutionProvider|42.5'); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, 
+ undefined, + (epName: string, percent: number) => progress.push([epName, percent]) + ); - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: false, - Status: 'Some providers failed', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: ['OpenVINOExecutionProvider'] - })); + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(progress).to.deep.equal([['CUDAExecutionProvider', 42.5]]); + }); + + it('downloadAndRegisterEps should pass AbortSignal when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + undefined, + controller.signal + ); - try { - const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); - expect(result).to.deep.equal({ - success: false, - status: 'Some providers failed', - registeredEps: ['CUDAExecutionProvider'], - failedEps: ['OpenVINOExecutionProvider'] - }); - } finally { - manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(calls[0][3]).to.equal(controller.signal); }); }); diff --git 
a/sdk/js/test/model.test.ts b/sdk/js/test/model.test.ts index 4048d9a11..7203e5668 100644 --- a/sdk/js/test/model.test.ts +++ b/sdk/js/test/model.test.ts @@ -1,6 +1,9 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager, TEST_MODEL_ALIAS } from './testUtils.js'; +import { Model } from '../src/detail/model.js'; +import { ModelVariant } from '../src/detail/modelVariant.js'; +import { DeviceType, type ModelInfo } from '../src/types.js'; describe('Model Tests', () => { it('should verify cached models from test-data-shared', async function() { @@ -58,4 +61,114 @@ describe('Model Tests', () => { await model.unload(); expect(await model.isLoaded()).to.be.false; }); -}); \ No newline at end of file + + it('download should use streaming interop when only an AbortSignal is provided', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should not use executeCommandAsync when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(controller.signal); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('download should preserve undefined progress callback with AbortSignal overload', async function() { + const 
calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should not use executeCommandAsync when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(undefined, controller.signal); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('download should parse a numeric progress chunk', async function() { + const progress: number[] = []; + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should use streaming interop when progress is provided'); + }, + executeCommandStreaming: async ( + _command: string, + _request: unknown, + callback: (chunk: string) => void + ) => { + callback('12.5'); + return ''; + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await 
model.download(progress.push.bind(progress)); + + expect(progress).to.deep.equal([12.5]); + }); +}); diff --git a/sdk/python/README.md b/sdk/python/README.md index 2a121411e..55a6f8d17 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -108,6 +108,21 @@ manager.download_and_register_eps(progress_callback=on_progress) print() ``` +### Cancelling model and EP downloads + +Pass a `threading.Event` as `cancel_event` to either download API. Set the event from another thread or handler to cancel the in-progress download. + +```python +import threading + +# manager and model already initialized +cancel_event = threading.Event() +threading.Timer(5.0, cancel_event.set).start() + +manager.download_and_register_eps(cancel_event=cancel_event) +model.download(cancel_event=cancel_event) +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -328,4 +343,4 @@ See [test/README.md](test/README.md) for detailed test setup and structure. 
```bash python examples/chat_completion.py -``` \ No newline at end of file +``` diff --git a/sdk/python/src/detail/core_interop.py b/sdk/python/src/detail/core_interop.py index f93b79f03..bb15ef344 100644 --- a/sdk/python/src/detail/core_interop.py +++ b/sdk/python/src/detail/core_interop.py @@ -10,6 +10,7 @@ import logging import os import sys +import threading from dataclasses import dataclass from pathlib import Path @@ -84,6 +85,56 @@ class Response: error: Optional[str] = None +class CancelledException(Exception): + """Raised internally when a download or streaming operation is cancelled.""" + + +def _is_user_cancellation_error(error: Optional[str]) -> bool: + return error is not None and error.strip().rstrip(".").lower() == "operation was cancelled by user" + + +class CancellationContext: + """Context manager for a Core cancellation context tied to a threading.Event.""" + + _POLL_SECONDS = 0.01 + + def __init__(self, lib, cancel_event: Optional['threading.Event']): + self._lib = lib + self._cancel_event = cancel_event + self._stop_event = threading.Event() + self._watcher: Optional[threading.Thread] = None + self.context_id: Optional[int] = None + + def __enter__(self): + if self._cancel_event is None or not CoreInterop._cancellable_commands_available: + return self + + self.context_id = self._lib.create_cancellation_context() + + if self._cancel_event.is_set(): + self._lib.cancel_cancellation_context(self.context_id) + return self + + def _watch() -> None: + while not self._stop_event.is_set(): + if self._cancel_event.wait(self._POLL_SECONDS): + self._lib.cancel_cancellation_context(self.context_id) + return + + self._watcher = threading.Thread(target=_watch, daemon=True) + self._watcher.start() + return self + + def __exit__(self, exc_type, exc, tb): + self._stop_event.set() + if self._watcher is not None: + self._watcher.join() + + if self.context_id is not None: + self._lib.release_cancellation_context(self.context_id) + self.context_id = None + + class 
CallbackHelper: """Internal helper class to convert the callback from ctypes to a str and call the python callback.""" @staticmethod @@ -92,18 +143,27 @@ def callback(data_ptr, length, self_ptr): try: self = ctypes.cast(self_ptr, ctypes.POINTER(ctypes.py_object)).contents.value + # Check for cancellation before processing the callback data. + if self._cancel_event is not None and self._cancel_event.is_set(): + raise CancelledException("Operation cancelled") + # convert to a string and pass to the python callback data_bytes = ctypes.string_at(data_ptr, length) data_str = data_bytes.decode('utf-8') self._py_callback(data_str) return 0 # continue + except CancelledException as e: + if self is not None and self.exception is None: + self.exception = e + return 1 # cancel except Exception as e: if self is not None and self.exception is None: self.exception = e # keep the first only as they are likely all the same return 1 # cancel on error - def __init__(self, py_callback: Callable[[str], None]): + def __init__(self, py_callback: Callable[[str], None], cancel_event: Optional['threading.Event'] = None): self._py_callback = py_callback + self._cancel_event = cancel_event self.exception = None @@ -118,6 +178,7 @@ class CoreInterop: _flcore_library = None _genai_library = None _ort_library = None + _cancellable_commands_available = False instance = None @@ -200,6 +261,36 @@ def _initialize_native_libraries() -> 'NativeBinaryPaths': logger.debug("execute_command_with_binary not exported by Core — " "live audio streaming will not be available until Core is updated") + try: + lib.create_cancellation_context.argtypes = [] + lib.create_cancellation_context.restype = ctypes.c_int64 + lib.cancel_cancellation_context.argtypes = [ctypes.c_int64] + lib.cancel_cancellation_context.restype = ctypes.c_int + lib.release_cancellation_context.argtypes = [ctypes.c_int64] + lib.release_cancellation_context.restype = ctypes.c_int + lib.execute_command_cancellable.argtypes = 
[ctypes.POINTER(RequestBuffer), + ctypes.POINTER(ResponseBuffer), + ctypes.c_int64] + lib.execute_command_cancellable.restype = None + lib.execute_command_with_callback_cancellable.argtypes = [ + ctypes.POINTER(RequestBuffer), + ctypes.POINTER(ResponseBuffer), + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int64, + ] + lib.execute_command_with_callback_cancellable.restype = None + lib.execute_command_with_binary_cancellable.argtypes = [ + ctypes.POINTER(StreamingRequestBuffer), + ctypes.POINTER(ResponseBuffer), + ctypes.c_int64, + ] + lib.execute_command_with_binary_cancellable.restype = None + CoreInterop._cancellable_commands_available = True + except AttributeError: + CoreInterop._cancellable_commands_available = False + logger.debug("Cancellable command symbols not exported by Core — falling back to legacy command paths") + return paths @staticmethod @@ -252,46 +343,78 @@ def __init__(self, config: Configuration): logger.info("Foundry.Local.Core initialized successfully: %s", response.data) def _execute_command(self, command: str, interop_request: InteropRequest = None, - callback: CoreInterop.CALLBACK_TYPE = None): + callback: CoreInterop.CALLBACK_TYPE = None, + cancel_event: Optional[threading.Event] = None): + if cancel_event is not None and cancel_event.is_set(): + raise FoundryLocalException("Operation cancelled") + cmd_ptr, cmd_len, cmd_buf = CoreInterop._to_c_buffer(command) data_ptr, data_len, data_buf = CoreInterop._to_c_buffer(interop_request.to_json() if interop_request else None) req = RequestBuffer(Command=cmd_ptr, CommandLength=cmd_len, Data=data_ptr, DataLength=data_len) resp = ResponseBuffer() lib = CoreInterop._flcore_library - - if (callback is not None): - # If a callback is provided, use the execute_command_with_callback method - # We need a helper to do the initial conversion from ctypes to Python and pass it through to the - # provided callback function - callback_helper = CallbackHelper(callback) - callback_py_obj = 
ctypes.py_object(callback_helper) - callback_helper_ptr = ctypes.cast(ctypes.pointer(callback_py_obj), ctypes.c_void_p) - callback_fn = CoreInterop.CALLBACK_TYPE(CallbackHelper.callback) - - lib.execute_command_with_callback(ctypes.byref(req), ctypes.byref(resp), callback_fn, callback_helper_ptr) - - if callback_helper.exception is not None: - raise callback_helper.exception - else: - lib.execute_command(ctypes.byref(req), ctypes.byref(resp)) + callback_exception = None + + with CancellationContext(lib, cancel_event) as cancellation_context: + if (callback is not None): + # If a callback is provided, use the execute_command_with_callback method + # We need a helper to do the initial conversion from ctypes to Python and pass it through to the + # provided callback function + callback_helper = CallbackHelper(callback, cancel_event) + callback_py_obj = ctypes.py_object(callback_helper) + callback_helper_ptr = ctypes.cast(ctypes.pointer(callback_py_obj), ctypes.c_void_p) + callback_fn = CoreInterop.CALLBACK_TYPE(CallbackHelper.callback) + + if cancellation_context.context_id is not None: + lib.execute_command_with_callback_cancellable( + ctypes.byref(req), + ctypes.byref(resp), + callback_fn, + callback_helper_ptr, + cancellation_context.context_id, + ) + else: + lib.execute_command_with_callback( + ctypes.byref(req), ctypes.byref(resp), callback_fn, callback_helper_ptr + ) + callback_exception = callback_helper.exception + else: + if cancellation_context.context_id is not None: + lib.execute_command_cancellable( + ctypes.byref(req), ctypes.byref(resp), cancellation_context.context_id + ) + else: + lib.execute_command(ctypes.byref(req), ctypes.byref(resp)) req = None # Free Python reference to request - response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None - error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None - - # C# owns the memory in the response so we need to free it 
explicitly - lib.free_response(resp) + try: + response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None + error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None + finally: + # C# owns the memory in the response so we need to free it explicitly. + # Do this before surfacing callback exceptions so cancellation does not leak native buffers. + lib.free_response(resp) + + if callback_exception is not None: + if isinstance(callback_exception, CancelledException): + raise FoundryLocalException("Operation cancelled") + raise callback_exception + + if cancel_event is not None and cancel_event.is_set() and _is_user_cancellation_error(error_str): + raise FoundryLocalException("Operation cancelled") return Response(data=response_str, error=error_str) - def execute_command(self, command_name: str, command_input: Optional[InteropRequest] = None) -> Response: + def execute_command(self, command_name: str, command_input: Optional[InteropRequest] = None, + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command synchronously. Args: command_name: The native command name (e.g. ``"get_model_list"``). command_input: Optional request parameters. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. 
@@ -299,32 +422,43 @@ def execute_command(self, command_name: str, command_input: Optional[InteropRequ logger.debug("Executing command: %s Input: %s", command_name, command_input.params if command_input else None) - response = self._execute_command(command_name, command_input) + response = self._execute_command(command_name, command_input, cancel_event=cancel_event) return response def execute_command_with_callback(self, command_name: str, command_input: Optional[InteropRequest], - callback: Callable[[str], None]) -> Response: + callback: Callable[[str], None], + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command with a streaming callback. The ``callback`` receives incremental string data from the native layer (e.g. streaming chat tokens or download progress). + If ``cancel_event`` is provided and is set, the native call will be + cancelled at the next callback invocation and a ``FoundryLocalException`` + with message ``"Operation cancelled"`` will be raised. + Args: command_name: The native command name. command_input: Optional request parameters. callback: Called with each incremental string response. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. + + Raises: + FoundryLocalException: If the operation is cancelled or fails. """ logger.debug("Executing command with callback: %s Input: %s", command_name, command_input.params if command_input else None) - response = self._execute_command(command_name, command_input, callback) + response = self._execute_command(command_name, command_input, callback, cancel_event) return response def execute_command_with_binary(self, command_name: str, command_input: Optional[InteropRequest], - binary_data: bytes) -> Response: + binary_data: bytes, + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command with both JSON parameters and a raw binary payload. 
Used for operations like pushing PCM audio data alongside JSON metadata. @@ -333,6 +467,7 @@ def execute_command_with_binary(self, command_name: str, command_name: The native command name (e.g. ``"audio_stream_push"``). command_input: Optional request parameters (serialized as JSON). binary_data: Raw binary payload (e.g. PCM audio bytes). + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. @@ -340,6 +475,9 @@ def execute_command_with_binary(self, command_name: str, logger.debug("Executing command with binary: %s Input: %s BinaryLen: %d", command_name, command_input.params if command_input else None, len(binary_data)) + if cancel_event is not None and cancel_event.is_set(): + raise FoundryLocalException("Operation cancelled") + cmd_ptr, cmd_len, cmd_buf = CoreInterop._to_c_buffer(command_name) data_ptr, data_len, data_buf = CoreInterop._to_c_buffer( command_input.to_json() if command_input else None @@ -357,7 +495,13 @@ def execute_command_with_binary(self, command_name: str, resp = ResponseBuffer() lib = CoreInterop._flcore_library - lib.execute_command_with_binary(ctypes.byref(req), ctypes.byref(resp)) + with CancellationContext(lib, cancel_event) as cancellation_context: + if cancellation_context.context_id is not None: + lib.execute_command_with_binary_cancellable( + ctypes.byref(req), ctypes.byref(resp), cancellation_context.context_id + ) + else: + lib.execute_command_with_binary(ctypes.byref(req), ctypes.byref(resp)) req = None # Free Python reference to request @@ -366,6 +510,9 @@ def execute_command_with_binary(self, command_name: str, lib.free_response(resp) + if cancel_event is not None and cancel_event.is_set() and _is_user_cancellation_error(error_str): + raise FoundryLocalException("Operation cancelled") + return Response(data=response_str, error=error_str) # --- Audio streaming session support --- diff --git a/sdk/python/src/detail/model.py 
b/sdk/python/src/detail/model.py index 6d60b7a2f..a71b1dba5 100644 --- a/sdk/python/src/detail/model.py +++ b/sdk/python/src/detail/model.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -115,9 +116,10 @@ def is_loaded(self) -> bool: """Is the currently selected variant loaded in memory?""" return self._selected_variant.is_loaded - def download(self, progress_callback: Optional[Callable[[float], None]] = None) -> None: + def download(self, progress_callback: Optional[Callable[[float], None]] = None, + cancel_event: Optional[Event] = None) -> None: """Download the currently selected variant.""" - self._selected_variant.download(progress_callback) + self._selected_variant.download(progress_callback, cancel_event) def get_path(self) -> str: """Get the path to the currently selected variant.""" diff --git a/sdk/python/src/detail/model_variant.py b/sdk/python/src/detail/model_variant.py index 76efb05cd..a563baabd 100644 --- a/sdk/python/src/detail/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -112,20 +113,40 @@ def is_loaded(self) -> bool: loaded_model_ids = self._model_load_manager.list_loaded() return self.id in loaded_model_ids - def download(self, progress_callback: Callable[[float], None] = None): + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None): """Download this variant to the local cache. Args: progress_callback: Optional callback receiving download progress as a percentage (0.0 to 100.0). + cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. 
""" + self._download_impl(progress_callback, cancel_event) + + def _download_impl(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: request = InteropRequest(params={"Model": self.id}) - if progress_callback is None: + if progress_callback is None and cancel_event is None: response = self._core_interop.execute_command("download_model", request) else: + # Use the callback path when either progress or cancellation is needed. + # Ignore invalid progress chunks so cancellation-only downloads + # still tolerate any non-progress output from the native layer. + def _on_chunk(chunk: str) -> None: + if progress_callback is None: + return + + try: + progress_callback(float(chunk)) + except ValueError: + pass + response = self._core_interop.execute_command_with_callback( "download_model", request, - lambda pct_str: progress_callback(float(pct_str)) + _on_chunk, + cancel_event, ) logger.info("Download response: %s", response) diff --git a/sdk/python/src/foundry_local_manager.py b/sdk/python/src/foundry_local_manager.py index a649f8e56..e47569ecc 100644 --- a/sdk/python/src/foundry_local_manager.py +++ b/sdk/python/src/foundry_local_manager.py @@ -101,6 +101,7 @@ def download_and_register_eps( self, names: Optional[list[str]] = None, progress_callback: Optional[Callable[[str, float], None]] = None, + cancel_event: Optional[threading.Event] = None, ) -> EpDownloadResult: """Download and register execution providers. @@ -109,6 +110,8 @@ def download_and_register_eps( all discoverable EPs are downloaded. progress_callback: Optional callback ``(ep_name: str, percent: float) -> None`` invoked as each EP downloads. ``percent`` is 0-100. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. The download will be cancelled at the next progress update. Returns: ``EpDownloadResult`` describing operation status and per-EP outcomes. 
@@ -120,19 +123,20 @@ def download_and_register_eps( if names is not None and len(names) > 0: request = InteropRequest(params={"Names": ",".join(names)}) - if progress_callback is not None: + if progress_callback is not None or cancel_event is not None: def _on_chunk(chunk: str) -> None: - sep = chunk.find("|") - if sep >= 0: - ep_name = chunk[:sep] or "" - try: - percent = float(chunk[sep + 1:]) - progress_callback(ep_name, percent) - except ValueError: - pass + if progress_callback is not None: + sep = chunk.find("|") + if sep >= 0: + ep_name = chunk[:sep] or "" + try: + percent = float(chunk[sep + 1:]) + progress_callback(ep_name, percent) + except ValueError: + pass response = self._core_interop.execute_command_with_callback( - "download_and_register_eps", request, _on_chunk + "download_and_register_eps", request, _on_chunk, cancel_event ) else: response = self._core_interop.execute_command("download_and_register_eps", request) diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index f723e514a..fc63f3747 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from threading import Event from typing import Callable, List, Optional from .openai.chat_client import ChatClient @@ -76,10 +77,13 @@ def supports_tool_calling(self) -> Optional[bool]: pass @abstractmethod - def download(self, progress_callback: Callable[[float], None] = None) -> None: + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: """ Download the model to local cache if not already present. :param progress_callback: Optional callback function for download progress as a percentage (0.0 to 100.0). + :param cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. 
""" pass diff --git a/sdk/python/src/openai/audio_client.py b/sdk/python/src/openai/audio_client.py index a203a21a3..29afea8a6 100644 --- a/sdk/python/src/openai/audio_client.py +++ b/sdk/python/src/openai/audio_client.py @@ -109,11 +109,16 @@ def _create_request_json(self, audio_file_path: str) -> str: return json.dumps(request) - def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: + def transcribe( + self, + audio_file_path: str, + cancel_event: Optional[threading.Event] = None, + ) -> AudioTranscriptionResponse: """Transcribe an audio file (non-streaming). Args: audio_file_path: Path to the audio file to transcribe. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: An ``AudioTranscriptionResponse`` containing the transcribed text. @@ -127,7 +132,7 @@ def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: request_json = self._create_request_json(audio_file_path) request = InteropRequest(params={"OpenAICreateRequest": request_json}) - response = self._core_interop.execute_command("audio_transcribe", request) + response = self._core_interop.execute_command("audio_transcribe", request, cancel_event) if response.error is not None: raise FoundryLocalException( f"Audio transcription failed for model '{self.model_id}': {response.error}" @@ -136,7 +141,11 @@ def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: data = json.loads(response.data) return AudioTranscriptionResponse(text=data.get("text", "")) - def _stream_chunks(self, request_json: str) -> Generator[AudioTranscriptionResponse, None, None]: + def _stream_chunks( + self, + request_json: str, + cancel_event: Optional[threading.Event] = None, + ) -> Generator[AudioTranscriptionResponse, None, None]: """Background-thread generator that yields parsed chunks from the native streaming call.""" _SENTINEL = object() chunk_queue: queue.Queue = queue.Queue() @@ -152,6 +161,7 @@ def _run() -> None: "audio_transcribe", 
InteropRequest(params={"OpenAICreateRequest": request_json}), _on_chunk, + cancel_event, ) if resp.error is not None: errors.append( @@ -173,6 +183,7 @@ def _run() -> None: def transcribe_streaming( self, audio_file_path: str, + cancel_event: Optional[threading.Event] = None, ) -> Generator[AudioTranscriptionResponse, None, None]: """Transcribe an audio file with streaming chunks. @@ -183,6 +194,7 @@ def transcribe_streaming( Args: audio_file_path: Path to the audio file to transcribe. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A generator of ``AudioTranscriptionResponse`` objects. @@ -194,4 +206,4 @@ def transcribe_streaming( self._validate_audio_file_path(audio_file_path) request_json = self._create_request_json(audio_file_path) - return self._stream_chunks(request_json) \ No newline at end of file + return self._stream_chunks(request_json, cancel_event) diff --git a/sdk/python/src/openai/chat_client.py b/sdk/python/src/openai/chat_client.py index 0b0d58bcd..fa5660c98 100644 --- a/sdk/python/src/openai/chat_client.py +++ b/sdk/python/src/openai/chat_client.py @@ -192,12 +192,18 @@ def _create_request( return json.dumps(chat_request) - def complete_chat(self, messages: List[ChatCompletionMessageParam], tools: Optional[List[Dict[str, Any]]] = None): + def complete_chat( + self, + messages: List[ChatCompletionMessageParam], + tools: Optional[List[Dict[str, Any]]] = None, + cancel_event: Optional[threading.Event] = None, + ): """Perform a non-streaming chat completion. Args: messages: Conversation history as a list of OpenAI message dicts. tools: Optional list of tool definitions for function calling. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A ``ChatCompletion`` response. 
@@ -212,7 +218,7 @@ def complete_chat(self, messages: List[ChatCompletionMessageParam], tools: Optio # Send the request to the chat API request = InteropRequest(params={"OpenAICreateRequest": chat_request_json}) - response = self._core_interop.execute_command("chat_completions", request) + response = self._core_interop.execute_command("chat_completions", request, cancel_event) if response.error is not None: raise FoundryLocalException(f"Error during chat completion: {response.error}") @@ -220,7 +226,11 @@ def complete_chat(self, messages: List[ChatCompletionMessageParam], tools: Optio return completion - def _stream_chunks(self, chat_request_json: str) -> Generator[ChatCompletionChunk, None, None]: + def _stream_chunks( + self, + chat_request_json: str, + cancel_event: Optional[threading.Event] = None, + ) -> Generator[ChatCompletionChunk, None, None]: """Background-thread generator that yields parsed chunks from the native streaming call.""" _SENTINEL = object() chunk_queue: queue.Queue = queue.Queue() @@ -246,6 +256,7 @@ def _run() -> None: "chat_completions", InteropRequest(params={"OpenAICreateRequest": chat_request_json}), _on_chunk, + cancel_event, ) if resp.error is not None: errors.append(FoundryLocalException(f"Error during streaming chat completion: {resp.error}")) @@ -264,6 +275,7 @@ def complete_streaming_chat( self, messages: List[ChatCompletionMessageParam], tools: Optional[List[Dict[str, Any]]] = None, + cancel_event: Optional[threading.Event] = None, ) -> Generator[ChatCompletionChunk, None, None]: """Perform a streaming chat completion, yielding chunks as they arrive. @@ -276,6 +288,7 @@ def complete_streaming_chat( Args: messages: Conversation history as a list of OpenAI message dicts. tools: Optional list of tool definitions for function calling. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A generator of ``ChatCompletionChunk`` objects. 
@@ -287,4 +300,4 @@ def complete_streaming_chat( self._validate_messages(messages) self._validate_tools(tools) chat_request_json = self._create_request(messages, streaming=True, tools=tools) - return self._stream_chunks(chat_request_json) + return self._stream_chunks(chat_request_json, cancel_event) diff --git a/sdk/python/src/openai/embedding_client.py b/sdk/python/src/openai/embedding_client.py index 89a3b8e55..43f82afd2 100644 --- a/sdk/python/src/openai/embedding_client.py +++ b/sdk/python/src/openai/embedding_client.py @@ -7,6 +7,7 @@ import json import logging +import threading from typing import List, Union from ..detail.core_interop import CoreInterop, InteropRequest @@ -46,12 +47,16 @@ def _create_request_json(self, input_value: Union[str, List[str]]) -> str: return json.dumps(embedding_request) - def _execute_embedding_request(self, input_value: Union[str, List[str]]) -> CreateEmbeddingResponse: + def _execute_embedding_request( + self, + input_value: Union[str, List[str]], + cancel_event: threading.Event | None = None, + ) -> CreateEmbeddingResponse: """Send an embedding request and parse the response.""" request_json = self._create_request_json(input_value) request = InteropRequest(params={"OpenAICreateRequest": request_json}) - response = self._core_interop.execute_command("embeddings", request) + response = self._core_interop.execute_command("embeddings", request, cancel_event) if response.error is not None: raise FoundryLocalException( f"Embedding generation failed for model '{self.model_id}': {response.error}" @@ -69,11 +74,16 @@ def _execute_embedding_request(self, input_value: Union[str, List[str]]) -> Crea return CreateEmbeddingResponse.model_validate(data) - def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: + def generate_embedding( + self, + input_text: str, + cancel_event: threading.Event | None = None, + ) -> CreateEmbeddingResponse: """Generate embeddings for a single input text. 
Args: input_text: The text to generate embeddings for. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A ``CreateEmbeddingResponse`` containing the embedding vector. @@ -83,13 +93,18 @@ def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: FoundryLocalException: If the underlying native embeddings command fails. """ self._validate_input(input_text) - return self._execute_embedding_request(input_text) + return self._execute_embedding_request(input_text, cancel_event) - def generate_embeddings(self, inputs: List[str]) -> CreateEmbeddingResponse: + def generate_embeddings( + self, + inputs: List[str], + cancel_event: threading.Event | None = None, + ) -> CreateEmbeddingResponse: """Generate embeddings for multiple input texts in a single request. Args: inputs: The texts to generate embeddings for. + cancel_event: Optional ``threading.Event`` that signals cancellation. Returns: A ``CreateEmbeddingResponse`` containing one embedding vector per input. @@ -104,4 +119,4 @@ def generate_embeddings(self, inputs: List[str]) -> CreateEmbeddingResponse: for text in inputs: self._validate_input(text) - return self._execute_embedding_request(inputs) + return self._execute_embedding_request(inputs, cancel_event) diff --git a/sdk/python/test/test_core_interop.py b/sdk/python/test/test_core_interop.py new file mode 100644 index 000000000..1ddddd610 --- /dev/null +++ b/sdk/python/test/test_core_interop.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Tests for CoreInterop callback helpers.""" + +from __future__ import annotations + +import ctypes +import threading +import time + +import pytest + +from foundry_local_sdk.detail.core_interop import ( + CallbackHelper, + CancelledException, + CoreInterop, + InteropRequest, + ResponseBuffer, +) +from foundry_local_sdk.exception import FoundryLocalException + + +class FakeCoreLibrary: + def __init__(self): + self.calls: list[str] = [] + self.buffers = [] + + def create_cancellation_context(self): + self.calls.append("create") + return 123 + + def cancel_cancellation_context(self, context_id): + self.calls.append(f"cancel:{context_id}") + return 0 + + def release_cancellation_context(self, context_id): + self.calls.append(f"release:{context_id}") + return 0 + + def execute_command(self, _req, resp): + self.calls.append("execute") + self._set_response(resp, data=b"ok") + + def execute_command_cancellable(self, _req, resp, context_id): + self.calls.append(f"execute_cancellable:{context_id}") + self.cancel_event.set() + time.sleep(0.05) + self._set_response(resp, error=b"Operation was cancelled by user") + + def free_response(self, _resp): + self.calls.append("free") + + def _set_response(self, resp_arg, data: bytes | None = None, error: bytes | None = None): + resp = ctypes.cast(resp_arg, ctypes.POINTER(ResponseBuffer)).contents + if data is not None: + data_buffer = ctypes.create_string_buffer(data) + self.buffers.append(data_buffer) + resp.Data = ctypes.cast(data_buffer, ctypes.c_void_p).value + resp.DataLength = len(data) + if error is not None: + error_buffer = ctypes.create_string_buffer(error) + self.buffers.append(error_buffer) + resp.Error = ctypes.cast(error_buffer, ctypes.c_void_p).value + resp.ErrorLength = len(error) + + +class TestCoreInterop: + def test_callback_helper_returns_cancel_when_cancel_event_is_set(self): + """Callback helper should return 1 without invoking Python 
callback when cancelled.""" + cancel_event = threading.Event() + cancel_event.set() + called = False + + def _callback(_chunk: str) -> None: + nonlocal called + called = True + + helper = CallbackHelper(_callback, cancel_event) + helper_ref = ctypes.py_object(helper) + helper_ptr = ctypes.cast(ctypes.pointer(helper_ref), ctypes.c_void_p) + data = ctypes.create_string_buffer(b"50") + + result = CallbackHelper.callback(data, 2, helper_ptr) + + assert result == 1 + assert called is False + assert isinstance(helper.exception, CancelledException) + + def test_execute_command_uses_cancellable_context_when_cancel_event_fires(self): + fake_lib = FakeCoreLibrary() + fake_lib.cancel_event = threading.Event() + original_lib = CoreInterop._flcore_library + original_available = CoreInterop._cancellable_commands_available + CoreInterop._flcore_library = fake_lib + CoreInterop._cancellable_commands_available = True + interop = object.__new__(CoreInterop) + + try: + with pytest.raises(FoundryLocalException, match="Operation cancelled"): + interop.execute_command( + "chat_completions", + InteropRequest(params={"OpenAICreateRequest": "{}"}), + fake_lib.cancel_event, + ) + finally: + CoreInterop._flcore_library = original_lib + CoreInterop._cancellable_commands_available = original_available + + assert fake_lib.calls == [ + "create", + "execute_cancellable:123", + "cancel:123", + "release:123", + "free", + ] + + def test_execute_command_falls_back_when_cancellable_context_is_unavailable(self): + fake_lib = FakeCoreLibrary() + original_lib = CoreInterop._flcore_library + original_available = CoreInterop._cancellable_commands_available + CoreInterop._flcore_library = fake_lib + CoreInterop._cancellable_commands_available = False + interop = object.__new__(CoreInterop) + + try: + response = interop.execute_command( + "get_model_list", + InteropRequest(params={}), + threading.Event(), + ) + finally: + CoreInterop._flcore_library = original_lib + 
CoreInterop._cancellable_commands_available = original_available + + assert response.data == "ok" + assert fake_lib.calls == ["execute", "free"] diff --git a/sdk/python/test/test_foundry_local_manager.py b/sdk/python/test/test_foundry_local_manager.py index 315288912..3abb37f64 100644 --- a/sdk/python/test/test_foundry_local_manager.py +++ b/sdk/python/test/test_foundry_local_manager.py @@ -6,6 +6,10 @@ from __future__ import annotations +import threading + +from foundry_local_sdk.foundry_local_manager import FoundryLocalManager + class _Response: def __init__(self, data=None, error=None): @@ -22,6 +26,12 @@ def execute_command(self, command_name, command_input=None): self.calls.append((command_name, command_input)) return self._responses[command_name] + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return self._responses[command_name] + class TestFoundryLocalManager: """Foundry Local Manager Tests.""" @@ -81,3 +91,36 @@ def test_download_and_register_eps_returns_result(self, manager): assert result.status == "ok" assert result.registered_eps == ["CUDAExecutionProvider"] assert result.failed_eps == [] + + def test_download_and_register_eps_uses_callback_path_when_cancel_event_is_provided(self): + fake_core = _FakeCoreInterop( + { + "download_and_register_eps": _Response( + data=( + '{"Success":true,"Status":"ok",' + '"RegisteredEps":["CUDAExecutionProvider"],"FailedEps":[]}' + ), + error=None, + ) + } + ) + manager = FoundryLocalManager.__new__(FoundryLocalManager) + manager._core_interop = fake_core + manager.catalog = type( + "_FakeCatalog", + (), + {"_invalidate_cache": staticmethod(lambda: None)}, + )() + cancel_event = threading.Event() + + result = manager.download_and_register_eps( + ["CUDAExecutionProvider"], cancel_event=cancel_event + ) + + assert result.success is True + assert len(fake_core.calls) == 1 + 
command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_and_register_eps" + assert command_input.params == {"Names": "CUDAExecutionProvider"} + assert callable(callback) + assert seen_cancel_event is cancel_event diff --git a/sdk/python/test/test_model.py b/sdk/python/test/test_model.py index e2ea15090..3d83a44ec 100644 --- a/sdk/python/test/test_model.py +++ b/sdk/python/test/test_model.py @@ -6,6 +6,12 @@ from __future__ import annotations +import threading + +from types import SimpleNamespace + +from foundry_local_sdk.detail.model_variant import ModelVariant + from .conftest import TEST_MODEL_ALIAS, AUDIO_MODEL_ALIAS @@ -86,3 +92,75 @@ def test_should_expose_supports_tool_calling(self, catalog): assert model is not None stc = model.supports_tool_calling assert stc is None or isinstance(stc, bool) + + def test_download_should_use_callback_path_when_cancel_event_is_provided(self): + """Model download should route through callback interop when cancellation is enabled.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def __init__(self): + self.calls = [] + + def execute_command(self, command_name, command_input=None): + raise AssertionError( + "download should not use execute_command when cancel_event is provided" + ) + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return _Response(data="", error=None) + + fake_core = _FakeCoreInterop() + cancel_event = threading.Event() + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = fake_core + variant._model_load_manager = None + + 
variant.download(cancel_event=cancel_event) + + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_model" + assert command_input.params == {"Model": "test-model-cpu:1"} + assert callable(callback) + assert seen_cancel_event is cancel_event + callback("50") + + def test_download_should_parse_numeric_progress_chunk(self): + """Model download progress parsing should parse the numeric native chunk.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def execute_command(self, command_name, command_input=None): + raise AssertionError("download should use callback interop when progress is provided") + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + callback("12.5") + return _Response(data="", error=None) + + progress = [] + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = _FakeCoreInterop() + variant._model_load_manager = None + + variant.download(progress_callback=progress.append) + + assert progress == [12.5] diff --git a/sdk/rust/README.md b/sdk/rust/README.md index ce97a7dd0..d017ce5e2 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -107,6 +107,28 @@ manager.download_and_register_eps_with_progress(None, move |ep_name: &str, perce println!(); ``` +#### Cancelling model and EP downloads + +Use a shared `Arc<AtomicBool>` with the cancellable download APIs. Set the flag from another task or signal handler to stop the in-progress download. + +```rust +use std::sync::{ + Arc, + atomic::AtomicBool, +}; + +// manager and model already initialized +let cancel_flag = Arc::new(AtomicBool::new(false)); +// call cancel_flag.store(true, ...)
from another task or signal handler to cancel + +manager + .download_and_register_eps_cancellable(None, Arc::clone(&cancel_flag)) + .await?; +model + .download_cancellable(None::, Arc::clone(&cancel_flag)) + .await?; +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps` when you need hardware-accelerated execution providers. ## Quick Start diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 0d17fe62d..1f81b56c9 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -9,14 +9,31 @@ use std::ffi::CString; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::thread; +use std::time::Duration; use libloading::{Library, Symbol}; use serde_json::Value; +use tokio_util::sync::CancellationToken; use crate::configuration::Configuration; use crate::error::{FoundryLocalError, Result}; +fn checked_i32_length(name: &str, len: usize) -> Result<i32> { + i32::try_from(len).map_err(|_| FoundryLocalError::CommandExecution { + reason: format!("{name} length {len} exceeds i32::MAX"), + }) +} + +fn is_user_cancellation_error(error: &str) -> bool { + error + .trim() + .trim_end_matches('.') + .eq_ignore_ascii_case("Operation was cancelled by user") +} + // ── FFI types ──────────────────────────────────────────────────────────────── /// Request buffer passed to the native library. @@ -64,6 +81,10 @@ struct StreamingRequestBuffer { /// Signature for `execute_command`. type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer); +/// Signature for `execute_command_cancellable`. +type ExecuteCommandCancellableFn = + unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer, i64); + /// Signature for the streaming callback invoked by the native library. /// Returns 0 to continue, 1 to cancel.
type CallbackFn = unsafe extern "C" fn(*const u8, i32, *mut std::ffi::c_void) -> i32; @@ -76,10 +97,27 @@ type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( *mut std::ffi::c_void, ); +/// Signature for `execute_command_with_callback_cancellable`. +type ExecuteCommandWithCallbackCancellableFn = unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, + i64, +); + /// Signature for `execute_command_with_binary`. type ExecuteCommandWithBinaryFn = unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer); +/// Signature for `execute_command_with_binary_cancellable`. +type ExecuteCommandWithBinaryCancellableFn = + unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer, i64); + +type CreateCancellationContextFn = unsafe extern "C" fn() -> i64; +type CancelCancellationContextFn = unsafe extern "C" fn(i64) -> i32; +type ReleaseCancellationContextFn = unsafe extern "C" fn(i64) -> i32; + // ── Library name helpers ───────────────────────────────────────────────────── #[cfg(target_os = "windows")] @@ -143,6 +181,8 @@ unsafe fn free_native_buffer(ptr: *mut u8) { struct StreamingCallbackState<'a> { callback: &'a mut dyn FnMut(&str), buf: Vec, + cancel_flag: Option>, + cancelled_observed: bool, } impl<'a> StreamingCallbackState<'a> { @@ -150,9 +190,37 @@ impl<'a> StreamingCallbackState<'a> { Self { callback, buf: Vec::new(), + cancel_flag: None, + cancelled_observed: false, } } + fn with_cancel(callback: &'a mut dyn FnMut(&str), cancel_flag: Arc) -> Self { + Self { + callback, + buf: Vec::new(), + cancel_flag: Some(cancel_flag), + cancelled_observed: false, + } + } + + /// Records and returns `true` only when this callback invocation observes a cancellation request. 
+ fn mark_cancelled_if_requested(&mut self) -> bool { + let cancelled = self + .cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)); + if cancelled { + self.cancelled_observed = true; + } + + cancelled + } + + fn cancellation_observed(&self) -> bool { + self.cancelled_observed + } + /// Append raw bytes, decode as much valid UTF-8 as possible, and forward /// complete text to the callback. Any trailing incomplete multi-byte /// sequence is kept in the buffer for the next call. Invalid byte @@ -193,9 +261,13 @@ impl<'a> StreamingCallbackState<'a> { } } - /// Flush any remaining bytes as lossy UTF-8 (called once after the native - /// call completes). + /// Flush any remaining bytes as lossy UTF-8 after a completed native call. fn flush(&mut self) { + if self.cancelled_observed { + self.buf.clear(); + return; + } + if !self.buf.is_empty() { let text = String::from_utf8_lossy(&self.buf).into_owned(); (self.callback)(&text); @@ -225,15 +297,93 @@ unsafe extern "C" fn streaming_trampoline( // by the caller of `execute_command_with_callback` for the duration of // the native call. let state = &mut *(user_data as *mut StreamingCallbackState<'_>); + + // Check for cancellation before processing the chunk. + if state.mark_cancelled_if_requested() { + return 1; // cancel + } + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native // core's callback contract. 
let slice = std::slice::from_raw_parts(data, length as usize); state.push(slice); + 0 // continue })); - if result.is_err() { - 1 - } else { - 0 + result.unwrap_or(1) +} + +struct CancellationContextGuard { + id: i64, + release_fn: ReleaseCancellationContextFn, + stop_watcher: Arc, + watcher: Option>, +} + +impl CancellationContextGuard { + fn new( + id: i64, + cancel_fn: CancelCancellationContextFn, + release_fn: ReleaseCancellationContextFn, + token: CancellationToken, + ) -> Self { + Self::new_with_watcher(id, cancel_fn, release_fn, move || token.is_cancelled()) + } + + fn new_for_flag( + id: i64, + cancel_fn: CancelCancellationContextFn, + release_fn: ReleaseCancellationContextFn, + cancel_flag: Arc, + ) -> Self { + Self::new_with_watcher(id, cancel_fn, release_fn, move || { + cancel_flag.load(Ordering::Relaxed) + }) + } + + fn new_with_watcher( + id: i64, + cancel_fn: CancelCancellationContextFn, + release_fn: ReleaseCancellationContextFn, + is_cancelled: F, + ) -> Self + where + F: Fn() -> bool + Send + 'static, + { + let stop_watcher = Arc::new(AtomicBool::new(false)); + let watcher_stop = Arc::clone(&stop_watcher); + let watcher = thread::spawn(move || { + while !watcher_stop.load(Ordering::Relaxed) { + if is_cancelled() { + // SAFETY: `cancel_fn` was loaded from the native core with the expected C ABI. + unsafe { + cancel_fn(id); + } + return; + } + thread::sleep(Duration::from_millis(10)); + } + }); + + Self { + id, + release_fn, + stop_watcher, + watcher: Some(watcher), + } + } +} + +impl Drop for CancellationContextGuard { + fn drop(&mut self) { + self.stop_watcher.store(true, Ordering::Relaxed); + if let Some(watcher) = self.watcher.take() { + let _ = watcher.join(); + } + + // SAFETY: `release_fn` was loaded from the native core with the expected C ABI. 
+ unsafe { + (self.release_fn)(self.id); + } } } @@ -248,14 +398,30 @@ pub(crate) struct CoreInterop { #[cfg(target_os = "windows")] _dependency_libs: Vec, execute_command: unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer), + execute_command_cancellable: + Option, execute_command_with_callback: unsafe extern "C" fn( *const RequestBuffer, *mut ResponseBuffer, CallbackFn, *mut std::ffi::c_void, ), + execute_command_with_callback_cancellable: Option< + unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, + i64, + ), + >, execute_command_with_binary: Option, + execute_command_with_binary_cancellable: + Option, + create_cancellation_context: Option i64>, + cancel_cancellation_context: Option i32>, + release_cancellation_context: Option i32>, } impl std::fmt::Debug for CoreInterop { @@ -316,6 +482,13 @@ impl CoreInterop { *sym }; + let execute_command_cancellable: Option = unsafe { + library + .get::(b"execute_command_cancellable\0") + .ok() + .map(|sym| *sym) + }; + // SAFETY: Same as above — symbol must match `ExecuteCommandWithCallbackFn`. let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { let sym: Symbol = library @@ -326,6 +499,17 @@ impl CoreInterop { *sym }; + let execute_command_with_callback_cancellable: Option< + ExecuteCommandWithCallbackCancellableFn, + > = unsafe { + library + .get::( + b"execute_command_with_callback_cancellable\0", + ) + .ok() + .map(|sym| *sym) + }; + // SAFETY: Same as above — symbol must match `ExecuteCommandWithBinaryFn`. // Optional: older native cores may not export this symbol (used for audio streaming). 
let execute_command_with_binary: Option = unsafe { @@ -335,22 +519,167 @@ impl CoreInterop { .map(|sym| *sym) }; + let execute_command_with_binary_cancellable: Option = unsafe { + library + .get::( + b"execute_command_with_binary_cancellable\0", + ) + .ok() + .map(|sym| *sym) + }; + + let create_cancellation_context: Option = unsafe { + library + .get::(b"create_cancellation_context\0") + .ok() + .map(|sym| *sym) + }; + let cancel_cancellation_context: Option = unsafe { + library + .get::(b"cancel_cancellation_context\0") + .ok() + .map(|sym| *sym) + }; + let release_cancellation_context: Option = unsafe { + library + .get::(b"release_cancellation_context\0") + .ok() + .map(|sym| *sym) + }; + Ok(Self { _library: library, #[cfg(target_os = "windows")] _dependency_libs, execute_command, + execute_command_cancellable, execute_command_with_callback, + execute_command_with_callback_cancellable, execute_command_with_binary, + execute_command_with_binary_cancellable, + create_cancellation_context, + cancel_cancellation_context, + release_cancellation_context, }) } + fn cancellation_context_available(&self) -> bool { + self.create_cancellation_context.is_some() + && self.cancel_cancellation_context.is_some() + && self.release_cancellation_context.is_some() + } + + fn create_cancellation_context( + &self, + cancellation_token: Option<&CancellationToken>, + can_use_cancellable_command: bool, + ) -> Result> { + let Some(token) = cancellation_token else { + return Ok(None); + }; + + if token.is_cancelled() { + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + + if !can_use_cancellable_command || !self.cancellation_context_available() { + return Ok(None); + } + + let create_fn = self.create_cancellation_context.unwrap(); + let cancel_fn = self.cancel_cancellation_context.unwrap(); + let release_fn = self.release_cancellation_context.unwrap(); + + // SAFETY: Function pointers were loaded from the native core with the expected C 
ABI. + let id = unsafe { create_fn() }; + if id == 0 { + return Err(FoundryLocalError::CommandExecution { + reason: "Failed to create native cancellation context".into(), + }); + } + + Ok(Some(CancellationContextGuard::new( + id, + cancel_fn, + release_fn, + token.clone(), + ))) + } + + fn create_cancellation_context_for_flag( + &self, + cancel_flag: Option<&Arc>, + can_use_cancellable_command: bool, + ) -> Result> { + let Some(flag) = cancel_flag else { + return Ok(None); + }; + + if flag.load(Ordering::Relaxed) { + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + + if !can_use_cancellable_command || !self.cancellation_context_available() { + return Ok(None); + } + + let create_fn = self.create_cancellation_context.unwrap(); + let cancel_fn = self.cancel_cancellation_context.unwrap(); + let release_fn = self.release_cancellation_context.unwrap(); + + // SAFETY: Function pointers were loaded from the native core with the expected C ABI. + let id = unsafe { create_fn() }; + if id == 0 { + return Err(FoundryLocalError::CommandExecution { + reason: "Failed to create native cancellation context".into(), + }); + } + + Ok(Some(CancellationContextGuard::new_for_flag( + id, + cancel_fn, + release_fn, + Arc::clone(flag), + ))) + } + /// Execute a synchronous command against the native core. /// /// `command` is the operation name (e.g. `"initialize"`, `"load_model"`). /// `params` is an optional JSON value that will be serialised and sent as /// the data payload. 
pub fn execute_command(&self, command: &str, params: Option<&Value>) -> Result { + self.execute_command_impl(command, params, None) + } + + pub fn execute_command_cancellable( + &self, + command: &str, + params: Option<&Value>, + cancellation_token: CancellationToken, + ) -> Result { + self.execute_command_impl(command, params, Some(cancellation_token)) + } + + fn execute_command_impl( + &self, + command: &str, + params: Option<&Value>, + cancellation_token: Option, + ) -> Result { + if cancellation_token + .as_ref() + .is_some_and(|t| t.is_cancelled()) + { + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { reason: format!("Invalid command string: {e}"), })?; @@ -366,21 +695,40 @@ impl CoreInterop { let request = RequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, }; let mut response = ResponseBuffer::new(); + let cancellation_context = self.create_cancellation_context( + cancellation_token.as_ref(), + self.execute_command_cancellable.is_some(), + )?; // SAFETY: `request` fields point into `cmd` and `data_cstr` which are // alive for the duration of this call. The native function writes into // `response` using its documented C ABI. 
unsafe { - (self.execute_command)(&request, &mut response); + if let (Some(context), Some(execute_cancellable)) = ( + cancellation_context.as_ref(), + self.execute_command_cancellable, + ) { + execute_cancellable(&request, &mut response, context.id); + } else { + if cancellation_token + .as_ref() + .is_some_and(|t| t.is_cancelled()) + { + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + (self.execute_command)(&request, &mut response); + } } - Self::process_response(response) + Self::process_response_with_cancellation(response, cancellation_token.as_ref()) } /// Execute a command with an additional binary payload. @@ -392,6 +740,16 @@ impl CoreInterop { command: &str, params: Option<&Value>, binary_data: &[u8], + ) -> Result { + self.execute_command_with_binary_impl(command, params, binary_data, None) + } + + fn execute_command_with_binary_impl( + &self, + command: &str, + params: Option<&Value>, + binary_data: &[u8], + cancellation_token: Option, ) -> Result { let native_fn = self.execute_command_with_binary.ok_or_else(|| { FoundryLocalError::CommandExecution { @@ -416,26 +774,37 @@ impl CoreInterop { let request = StreamingRequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, binary_data: if binary_data.is_empty() { std::ptr::null() } else { binary_data.as_ptr() }, - binary_data_length: binary_data.len() as i32, + binary_data_length: checked_i32_length("binary data", binary_data.len())?, }; let mut response = ResponseBuffer::new(); + let cancellation_context = self.create_cancellation_context( + cancellation_token.as_ref(), + self.execute_command_with_binary_cancellable.is_some(), + )?; // SAFETY: `request` fields point into `cmd`, `data_cstr`, and // `binary_data` 
which are all alive for the duration of this call. unsafe { - (native_fn)(&request, &mut response); + if let (Some(context), Some(native_cancellable_fn)) = ( + cancellation_context.as_ref(), + self.execute_command_with_binary_cancellable, + ) { + native_cancellable_fn(&request, &mut response, context.id); + } else { + native_fn(&request, &mut response); + } } - Self::process_response(response) + Self::process_response_with_cancellation(response, cancellation_token.as_ref()) } /// Execute a command that streams results back via `callback`. @@ -452,6 +821,32 @@ impl CoreInterop { where F: FnMut(&str), { + self.execute_command_streaming_impl(command, params, &mut callback, None) + } + + /// Like [`Self::execute_command_streaming`], but accepts a cancellation + /// flag. When `cancel_flag` is set to `true`, the native call will be + /// cancelled at the next callback invocation and an error is returned. + pub fn execute_command_streaming_cancellable( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str), + { + self.execute_command_streaming_impl(command, params, &mut callback, Some(cancel_flag)) + } + + fn execute_command_streaming_impl( + &self, + command: &str, + params: Option<&Value>, + callback: &mut dyn FnMut(&str), + cancel_flag: Option>, + ) -> Result { let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { reason: format!("Invalid command string: {e}"), })?; @@ -467,17 +862,23 @@ impl CoreInterop { let request = RequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, }; let mut response = ResponseBuffer::new(); + let cancellation_context = self.create_cancellation_context_for_flag( + cancel_flag.as_ref(), + 
self.execute_command_with_callback_cancellable.is_some(), + )?; // Wrap the closure in a StreamingCallbackState that handles partial // UTF-8 sequences split across native callbacks. - let mut cb = |chunk: &str| callback(chunk); - let mut state = StreamingCallbackState::new(&mut cb); + let mut state = match cancel_flag { + Some(ref flag) => StreamingCallbackState::with_cancel(callback, Arc::clone(flag)), + None => StreamingCallbackState::new(callback), + }; let user_data = &mut state as *mut StreamingCallbackState<'_> as *mut std::ffi::c_void; // SAFETY: `request` fields point into `cmd` and `data_cstr` which are @@ -486,18 +887,46 @@ impl CoreInterop { // `streaming_trampoline` will only cast `user_data` back to // `StreamingCallbackState`. unsafe { - (self.execute_command_with_callback)( - &request, - &mut response, - streaming_trampoline, - user_data, - ); + if let (Some(context), Some(execute_cancellable)) = ( + cancellation_context.as_ref(), + self.execute_command_with_callback_cancellable, + ) { + execute_cancellable( + &request, + &mut response, + streaming_trampoline, + user_data, + context.id, + ); + } else { + (self.execute_command_with_callback)( + &request, + &mut response, + streaming_trampoline, + user_data, + ); + } } - // Flush any trailing partial UTF-8 bytes. + let cancelled = state.cancellation_observed(); + + // Flush any trailing partial UTF-8 bytes unless cancellation was observed. state.flush(); - Self::process_response(response) + if cancelled { + // Free native response memory before returning the error. + Self::process_response(response).ok(); + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".to_string(), + }); + } + + Self::process_response_with_cancellation_requested( + response, + cancel_flag + .as_ref() + .is_some_and(|flag| flag.load(Ordering::Relaxed)), + ) } /// Async version of [`Self::execute_command`]. @@ -517,6 +946,22 @@ impl CoreInterop { })? 
} + pub async fn execute_command_async_cancellable( + self: &Arc, + command: String, + params: Option, + cancellation_token: CancellationToken, + ) -> Result { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_cancellable(&command, params.as_ref(), cancellation_token) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + /// Async version of [`Self::execute_command_streaming`]. /// /// The `callback` is invoked on the blocking thread – it must be @@ -540,6 +985,36 @@ impl CoreInterop { })? } + /// Async version of [`Self::execute_command_streaming_cancellable`]. + /// + /// Accepts a shared cancellation flag (`Arc`). When the flag + /// is set to `true`, the native call will be cancelled at the next + /// callback invocation and an error is returned. + pub async fn execute_command_streaming_cancellable_async( + self: &Arc, + command: String, + params: Option, + callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming_cancellable( + &command, + params.as_ref(), + callback, + cancel_flag, + ) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + /// Async streaming variant that bridges the FFI callback into a /// [`tokio::sync::mpsc`] channel. /// @@ -601,6 +1076,23 @@ impl CoreInterop { /// /// Takes the buffer by value so it can only be consumed once. 
fn process_response(response: ResponseBuffer) -> Result { + Self::process_response_with_cancellation(response, None) + } + + fn process_response_with_cancellation( + response: ResponseBuffer, + cancellation_token: Option<&CancellationToken>, + ) -> Result { + Self::process_response_with_cancellation_requested( + response, + cancellation_token.is_some_and(|token| token.is_cancelled()), + ) + } + + fn process_response_with_cancellation_requested( + response: ResponseBuffer, + cancellation_requested: bool, + ) -> Result { // SAFETY: response fields are either null or valid native-allocated // pointers filled by the preceding FFI call. let error_str = unsafe { Self::read_native_buffer(response.error, response.error_length) }; @@ -617,6 +1109,12 @@ impl CoreInterop { // Return error or data. if let Some(err) = error_str { + if cancellation_requested && is_user_cancellation_error(&err) { + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + Err(FoundryLocalError::CommandExecution { reason: err }) } else { Ok(data_str.unwrap_or_default()) @@ -702,3 +1200,68 @@ impl CoreInterop { Ok(libs) } } + +#[cfg(test)] +mod tests { + use super::{checked_i32_length, StreamingCallbackState}; + use crate::error::FoundryLocalError; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + + #[test] + fn cancellation_request_after_callback_is_not_observed_until_next_callback() { + let cancel_flag = Arc::new(AtomicBool::new(false)); + let mut callback = |_chunk: &str| {}; + let mut state = + StreamingCallbackState::with_cancel(&mut callback, Arc::clone(&cancel_flag)); + + state.push(b"100"); + cancel_flag.store(true, Ordering::Relaxed); + + assert!(!state.cancellation_observed()); + } + + #[test] + fn cancellation_is_recorded_when_callback_observes_cancel_flag() { + let cancel_flag = Arc::new(AtomicBool::new(true)); + let mut callback = |_chunk: &str| {}; + let mut state = StreamingCallbackState::with_cancel(&mut callback, 
cancel_flag); + + assert!(state.mark_cancelled_if_requested()); + assert!(state.cancellation_observed()); + } + + #[test] + fn flush_drops_buffer_after_cancellation_without_callback() { + let cancel_flag = Arc::new(AtomicBool::new(true)); + let mut chunks = Vec::new(); + + { + let mut callback = |chunk: &str| chunks.push(chunk.to_owned()); + let mut state = StreamingCallbackState::with_cancel(&mut callback, cancel_flag); + + state.push(&[0xE2]); + assert!(state.mark_cancelled_if_requested()); + state.flush(); + } + + assert!(chunks.is_empty()); + } + + #[test] + fn checked_i32_length_rejects_too_large_values() { + assert_eq!( + checked_i32_length("data", i32::MAX as usize).unwrap(), + i32::MAX + ); + + match checked_i32_length("data", i32::MAX as usize + 1).unwrap_err() { + FoundryLocalError::CommandExecution { reason } => { + assert!(reason.contains("exceeds i32::MAX")); + } + err => panic!("unexpected error: {err:?}"), + } + } +} diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index 08288aee8..5921fbcde 100644 --- a/sdk/rust/src/detail/model.rs +++ b/sdk/rust/src/detail/model.rs @@ -6,7 +6,7 @@ use std::fmt; use std::path::PathBuf; -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; use std::sync::Arc; use super::core_interop::CoreInterop; @@ -213,6 +213,23 @@ impl Model { self.selected_variant().download(progress).await } + /// Like [`Self::download`], but accepts a shared cancellation flag + /// (`Arc`). When the flag is set to `true`, the download + /// will be cancelled at the next progress callback and an error is + /// returned. + pub async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.selected_variant() + .download_cancellable(progress, cancel_flag) + .await + } + /// Return the local file-system path of the (selected) variant. 
pub async fn path(&self) -> Result { self.selected_variant().path().await diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index 1f8ce7d5b..e8d437131 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -5,6 +5,7 @@ use std::fmt; use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use serde_json::json; @@ -88,26 +89,61 @@ impl ModelVariant { } pub(crate) async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, None).await + } + + /// Like [`Self::download`], but accepts a shared cancellation flag. + /// When `cancel_flag` is set to `true`, the download will be cancelled at + /// the next progress callback. + pub(crate) async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, Some(cancel_flag)).await + } + + async fn download_impl( + &self, + progress: Option, + cancel_flag: Option>, + ) -> Result<()> where F: FnMut(f64) + Send + 'static, { let params = json!({ "Params": { "Model": self.info.id } }); - match progress { - Some(mut cb) => { - let wrapper = move |chunk: &str| { - for token in chunk.split_whitespace() { - if let Ok(pct) = token.parse::() { - cb(pct); - } + if progress.is_none() && cancel_flag.is_none() { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await?; + } else { + let mut progress = progress; + let wrapper = move |chunk: &str| { + if let Some(cb) = progress.as_mut() { + if let Ok(pct) = chunk.trim().parse::() { + cb(pct); } - }; + } + }; + + if let Some(flag) = cancel_flag { self.core - .execute_command_streaming_async("download_model".into(), Some(params), wrapper) + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + wrapper, + flag, + ) .await?; - } - None => { + } else { 
self.core - .execute_command_async("download_model".into(), Some(params)) + .execute_command_streaming_async("download_model".into(), Some(params), wrapper) .await?; } } diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index 0c22ef154..a14b42b75 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ b/sdk/rust/src/foundry_local_manager.rs @@ -4,6 +4,7 @@ //! library, provides access to the model [`Catalog`], and can start / stop //! the local web service. +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex, OnceLock}; use serde_json::json; @@ -150,7 +151,19 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, ) -> Result { - self.download_and_register_eps_impl(names, None::) + self.download_and_register_eps_impl(names, None::, None) + .await + } + + /// Like [`Self::download_and_register_eps`], but accepts a shared + /// cancellation flag (`Arc`). When the flag is set to `true`, + /// the download will be cancelled at the next progress callback. + pub async fn download_and_register_eps_cancellable( + &self, + names: Option<&[&str]>, + cancel_flag: Arc, + ) -> Result { + self.download_and_register_eps_impl(names, None::, Some(cancel_flag)) .await } @@ -169,7 +182,23 @@ impl FoundryLocalManager { where F: FnMut(&str, f64) + Send + 'static, { - self.download_and_register_eps_impl(names, Some(progress_callback)) + self.download_and_register_eps_impl(names, Some(progress_callback), None) + .await + } + + /// Like [`Self::download_and_register_eps_with_progress`], but accepts a + /// shared cancellation flag (`Arc`). When the flag is set to + /// `true`, the download will be cancelled at the next progress callback. 
+ pub async fn download_and_register_eps_with_progress_cancellable( + &self, + names: Option<&[&str]>, + progress_callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str, f64) + Send + 'static, + { + self.download_and_register_eps_impl(names, Some(progress_callback), Some(cancel_flag)) .await } @@ -177,6 +206,7 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, progress_callback: Option, + cancel_flag: Option>, ) -> Result where F: FnMut(&str, f64) + Send + 'static, @@ -186,8 +216,28 @@ impl FoundryLocalManager { _ => None, }; - let raw = match progress_callback { - Some(cb) => { + let raw = match (progress_callback, cancel_flag) { + (Some(cb), Some(flag)) => { + let mut callback = cb; + let wrapper = move |chunk: &str| { + if let Some(sep) = chunk.find('|') { + let name = &chunk[..sep]; + if let Ok(percent) = chunk[sep + 1..].parse::() { + callback(if name.is_empty() { "" } else { name }, percent); + } + } + }; + + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + wrapper, + flag, + ) + .await? + } + (Some(cb), None) => { let mut callback = cb; let wrapper = move |chunk: &str| { if let Some(sep) = chunk.find('|') { @@ -206,7 +256,17 @@ impl FoundryLocalManager { ) .await? } - None => { + (None, Some(flag)) => { + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + |_chunk: &str| {}, + flag, + ) + .await? + } + (None, None) => { self.core .execute_command_async("download_and_register_eps".into(), params) .await? 
diff --git a/sdk/rust/src/openai/audio_client.rs b/sdk/rust/src/openai/audio_client.rs index c48e24282..b078e63e6 100644 --- a/sdk/rust/src/openai/audio_client.rs +++ b/sdk/rust/src/openai/audio_client.rs @@ -4,6 +4,7 @@ use std::path::Path; use std::sync::Arc; use serde_json::{json, Value}; +use tokio_util::sync::CancellationToken; use crate::detail::core_interop::CoreInterop; use crate::error::{FoundryLocalError, Result}; @@ -141,6 +142,24 @@ impl AudioClient { pub async fn transcribe( &self, audio_file_path: impl AsRef, + ) -> Result { + self.transcribe_impl(audio_file_path, None).await + } + + /// Transcribe an audio file with native command cancellation. + pub async fn transcribe_with_cancellation( + &self, + audio_file_path: impl AsRef, + cancellation_token: CancellationToken, + ) -> Result { + self.transcribe_impl(audio_file_path, Some(cancellation_token)) + .await + } + + async fn transcribe_impl( + &self, + audio_file_path: impl AsRef, + cancellation_token: Option, ) -> Result { let path_str = audio_file_path @@ -158,10 +177,22 @@ impl AudioClient { } }); - let raw = self - .core - .execute_command_async("audio_transcribe".into(), Some(params)) - .await?; + let raw = match cancellation_token { + Some(token) => { + self.core + .execute_command_async_cancellable( + "audio_transcribe".into(), + Some(params), + token, + ) + .await? + } + None => { + self.core + .execute_command_async("audio_transcribe".into(), Some(params)) + .await? 
+ } + }; let parsed: AudioTranscriptionResponse = serde_json::from_str(&raw)?; Ok(parsed) } diff --git a/sdk/rust/src/openai/chat_client.rs b/sdk/rust/src/openai/chat_client.rs index 6597de826..16c49b8ab 100644 --- a/sdk/rust/src/openai/chat_client.rs +++ b/sdk/rust/src/openai/chat_client.rs @@ -8,6 +8,7 @@ use async_openai::types::chat::{ CreateChatCompletionStreamResponse, }; use serde_json::{json, Value}; +use tokio_util::sync::CancellationToken; use crate::detail::core_interop::CoreInterop; use crate::error::{FoundryLocalError, Result}; @@ -205,6 +206,26 @@ impl ChatClient { &self, messages: &[ChatCompletionRequestMessage], tools: Option<&[ChatCompletionTools]>, + ) -> Result { + self.complete_chat_impl(messages, tools, None).await + } + + /// Perform a non-streaming chat completion with native command cancellation. + pub async fn complete_chat_with_cancellation( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + cancellation_token: CancellationToken, + ) -> Result { + self.complete_chat_impl(messages, tools, Some(cancellation_token)) + .await + } + + async fn complete_chat_impl( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + cancellation_token: Option, ) -> Result { if messages.is_empty() { return Err(FoundryLocalError::Validation { @@ -219,10 +240,22 @@ impl ChatClient { } }); - let raw = self - .core - .execute_command_async("chat_completions".into(), Some(params)) - .await?; + let raw = match cancellation_token { + Some(token) => { + self.core + .execute_command_async_cancellable( + "chat_completions".into(), + Some(params), + token, + ) + .await? + } + None => { + self.core + .execute_command_async("chat_completions".into(), Some(params)) + .await? 
+ } + }; let parsed: CreateChatCompletionResponse = serde_json::from_str(&raw)?; Ok(parsed) } diff --git a/sdk/rust/src/openai/embedding_client.rs b/sdk/rust/src/openai/embedding_client.rs index 5de080a0c..ff5fa8236 100644 --- a/sdk/rust/src/openai/embedding_client.rs +++ b/sdk/rust/src/openai/embedding_client.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use async_openai::types::embeddings::CreateEmbeddingResponse; use serde_json::{json, Value}; +use tokio_util::sync::CancellationToken; use crate::detail::core_interop::CoreInterop; use crate::error::{FoundryLocalError, Result}; @@ -26,7 +27,19 @@ impl EmbeddingClient { pub async fn generate_embedding(&self, input: &str) -> Result { Self::validate_input(input)?; let request = self.build_request(json!(input)); - self.execute_request(request).await + self.execute_request(request, None).await + } + + /// Generate embeddings for a single input text with native command cancellation. + pub async fn generate_embedding_with_cancellation( + &self, + input: &str, + cancellation_token: CancellationToken, + ) -> Result { + Self::validate_input(input)?; + let request = self.build_request(json!(input)); + self.execute_request(request, Some(cancellation_token)) + .await } /// Generate embeddings for multiple input texts in a single request. @@ -40,20 +53,51 @@ impl EmbeddingClient { Self::validate_input(input)?; } let request = self.build_request(json!(inputs)); - self.execute_request(request).await + self.execute_request(request, None).await } - async fn execute_request(&self, request: Value) -> Result { + /// Generate embeddings for multiple input texts with native command cancellation. 
+ pub async fn generate_embeddings_with_cancellation( + &self, + inputs: &[&str], + cancellation_token: CancellationToken, + ) -> Result { + if inputs.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "inputs must be a non-empty array".into(), + }); + } + for input in inputs { + Self::validate_input(input)?; + } + let request = self.build_request(json!(inputs)); + self.execute_request(request, Some(cancellation_token)) + .await + } + + async fn execute_request( + &self, + request: Value, + cancellation_token: Option, + ) -> Result { let params = json!({ "Params": { "OpenAICreateRequest": serde_json::to_string(&request)? } }); - let raw = self - .core - .execute_command_async("embeddings".into(), Some(params)) - .await?; + let raw = match cancellation_token { + Some(token) => { + self.core + .execute_command_async_cancellable("embeddings".into(), Some(params), token) + .await? + } + None => { + self.core + .execute_command_async("embeddings".into(), Some(params)) + .await? + } + }; // Patch the response to add fields required by async_openai types // that the server doesn't return (object on each item, usage)