Commit bdb82411 authored by edy's avatar edy

feat(chat): support cancelling active streams

parent 5abfbb86
Pipeline #18469 failed
......@@ -6,6 +6,7 @@ import {
IPC_CHANNELS,
type AppConfig,
type ChatAttachment,
type ChatCancelStreamResult,
type ChatMessage,
type ChatStreamEvent,
type ConfigSecretId,
......@@ -497,6 +498,15 @@ function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
interface ActiveChatStream {
requestId: string;
assistantMessageId: string;
sessionId?: string;
runId?: string;
cancelled: boolean;
markStopped: () => Promise<void>;
}
export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc {
const {
appVersion,
......@@ -1598,6 +1608,7 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
};
const streamListeners = new Set<(payload: ChatStreamEvent) => void>();
const activeChatStreams = new Map<string, ActiveChatStream>();
const broadcastChatStreamEvent = (payload: ChatStreamEvent, sender?: WebContents) => {
if (sender && !sender.isDestroyed()) {
......@@ -2010,9 +2021,33 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
return transcriptWriteChain;
};
const updateAssistantTranscript = (updater: (current: ChatMessage) => ChatMessage) => {
if (activeChatStream.cancelled) {
return transcriptWriteChain;
}
const nextMessage = updater(assistantTranscript);
return queueAssistantTranscriptWrite(nextMessage);
};
const activeChatStream: ActiveChatStream = {
requestId,
assistantMessageId,
sessionId: executionSessionId,
cancelled: false,
markStopped: async () => {
await queueAssistantTranscriptWrite({
...assistantTranscript,
content: assistantTranscript.content,
streamState: undefined,
statusLabel: "已停止",
statusDetail: undefined
});
}
};
const finishActiveChatStream = () => {
if (activeChatStreams.get(requestId) === activeChatStream) {
activeChatStreams.delete(requestId);
}
};
activeChatStreams.set(requestId, activeChatStream);
const queueProjectContextRefresh = () => {
if (contextRefreshQueued || !shouldScheduleContextRefresh || !refreshProjectId) {
return;
......@@ -2026,6 +2061,9 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
};
const queueOrSend = (payload: ChatStreamEvent) => {
if (activeChatStream.cancelled) {
return;
}
if (!ready) {
if (payload.type === "started") {
startedEvent = payload;
......@@ -2038,6 +2076,9 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
};
const flushQueuedEvents = (fallbackStarted?: ChatStreamEvent) => {
setTimeout(() => {
if (activeChatStream.cancelled) {
return;
}
if (startedEvent) {
broadcastChatStreamEvent(startedEvent, sender);
} else if (fallbackStarted) {
......@@ -2068,6 +2109,8 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
}));
runtimeCloudSupervisor.noteMessageReceived(executionSessionId, prompt, undefined);
const runId = randomUUID();
activeChatStream.sessionId = executionSessionId;
activeChatStream.runId = runId;
queueOrSend({
type: "status",
requestId,
......@@ -2087,10 +2130,14 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
void (async () => {
try {
const replyContent = await requestHomeImageChatCompletion(prompt, normalizedAttachments);
if (activeChatStream.cancelled) {
return;
}
const reply = createChatMessage("assistant", replyContent, {
id: assistantMessageId
});
settled = true;
finishActiveChatStream();
await updateAssistantTranscript((current) => ({
...current,
content: reply.content,
......@@ -2110,7 +2157,11 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
executionPolicy: executionPolicy ?? undefined
});
} catch (error) {
if (activeChatStream.cancelled) {
return;
}
settled = true;
finishActiveChatStream();
const message = error instanceof Error ? error.message : String(error);
await updateAssistantTranscript((current) => ({
...current,
......@@ -2229,6 +2280,8 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
}
}, {
onStarted: (runId) => {
activeChatStream.sessionId = executionSessionId;
activeChatStream.runId = runId;
queueOrSend({
type: "started",
requestId,
......@@ -2238,6 +2291,9 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onStatus: (stage, label, detail) => {
if (activeChatStream.cancelled) {
return;
}
void updateAssistantTranscript((current) => ({
...current,
streamState: "streaming",
......@@ -2254,6 +2310,11 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onDelta: (textDelta, fullText, runId) => {
if (activeChatStream.cancelled) {
return;
}
activeChatStream.sessionId = executionSessionId;
activeChatStream.runId = runId;
void updateAssistantTranscript((current) => ({
...current,
content: fullText && fullText.length >= current.content.length
......@@ -2274,6 +2335,9 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
}
});
if ("handoff" in result) {
if (activeChatStream.cancelled) {
return;
}
executionPolicy = await resolveExecutionPolicy(
preparedExecution.sessionState.projectId,
undefined,
......@@ -2298,6 +2362,8 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
execute: () => gatewayClient.streamPrompt(executionSessionId, result.handoff.content, {
onStarted: ({ sessionId: nextSessionId, runId }) => {
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
queueOrSend({
type: "started",
requestId,
......@@ -2307,7 +2373,12 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onStatus: ({ sessionId: nextSessionId, runId, stage, label, detail }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
void updateAssistantTranscript((current) => ({
...current,
streamState: "streaming",
......@@ -2325,7 +2396,12 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onDelta: ({ sessionId: nextSessionId, runId, textDelta, fullText }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
void updateAssistantTranscript((current) => ({
...current,
content: fullText && fullText.length >= current.content.length
......@@ -2345,8 +2421,14 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onCompleted: ({ sessionId: nextSessionId, runId, reply }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
settled = true;
finishActiveChatStream();
void (async () => {
await updateAssistantTranscript((current) => ({
...current,
......@@ -2378,8 +2460,14 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
queueProjectContextRefresh();
},
onError: ({ sessionId: nextSessionId, runId, error }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
settled = true;
finishActiveChatStream();
const errorCategory = typeof (error as Error & { errorCategory?: unknown }).errorCategory === "string"
? String((error as Error & { errorCategory?: unknown }).errorCategory).trim()
: "";
......@@ -2408,7 +2496,13 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
return;
}
if (activeChatStream.cancelled) {
return;
}
settled = true;
activeChatStream.sessionId = executionSessionId;
activeChatStream.runId = result.runId;
finishActiveChatStream();
await updateAssistantTranscript((current) => ({
...current,
content: result.reply.content,
......@@ -2436,7 +2530,11 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
executionPolicy: executionPolicy ?? undefined
});
} catch (error) {
if (activeChatStream.cancelled) {
return;
}
settled = true;
finishActiveChatStream();
const message = error instanceof Error ? error.message : String(error);
const errorCategory = error instanceof Error && typeof (error as Error & { errorCategory?: unknown }).errorCategory === "string"
? String((error as Error & { errorCategory?: unknown }).errorCategory).trim()
......@@ -2476,6 +2574,8 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
execute: () => gatewayClient.streamPrompt(executionSessionId, preparedExecution.gatewayPrompt ?? prompt, {
onStarted: ({ sessionId: nextSessionId, runId }) => {
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
queueOrSend({
type: "started",
requestId,
......@@ -2485,7 +2585,12 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onStatus: ({ sessionId: nextSessionId, runId, stage, label, detail }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
void updateAssistantTranscript((current) => ({
...current,
streamState: "streaming",
......@@ -2503,7 +2608,12 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onDelta: ({ sessionId: nextSessionId, runId, textDelta, fullText }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
void updateAssistantTranscript((current) => ({
...current,
content: fullText && fullText.length >= current.content.length
......@@ -2523,8 +2633,14 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
});
},
onCompleted: ({ sessionId: nextSessionId, runId, reply }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
settled = true;
finishActiveChatStream();
void (async () => {
await updateAssistantTranscript((current) => ({
...current,
......@@ -2548,8 +2664,14 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
queueProjectContextRefresh();
},
onError: ({ sessionId: nextSessionId, runId, error }) => {
if (activeChatStream.cancelled) {
return;
}
executionSessionId = nextSessionId;
activeChatStream.sessionId = nextSessionId;
activeChatStream.runId = runId;
settled = true;
finishActiveChatStream();
const errorCategory = typeof (error as Error & { errorCategory?: unknown }).errorCategory === "string"
? String((error as Error & { errorCategory?: unknown }).errorCategory).trim()
: "";
......@@ -2577,6 +2699,8 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
})
});
ready = true;
activeChatStream.sessionId = stream.sessionId;
activeChatStream.runId = stream.runId;
flushQueuedEvents({
type: "started",
requestId,
......@@ -2593,6 +2717,17 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
executionPolicy: executionPolicy ?? undefined
};
} catch (error) {
if (activeChatStream.cancelled) {
return {
requestId,
sessionId: activeChatStream.sessionId ?? executionSessionId,
runId: activeChatStream.runId,
userMessageId,
assistantMessageId,
executionPolicy: executionPolicy ?? undefined
};
}
finishActiveChatStream();
const message = error instanceof Error ? error.message : String(error);
if (!settled) {
runtimeCloudSupervisor.noteError("chat_stream_failed", message, {
......@@ -2604,6 +2739,53 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
throw error;
}
};
const cancelStream = async (requestId: string, runId?: string, sessionId?: string, sender?: WebContents): Promise<ChatCancelStreamResult> => {
const activeStream = activeChatStreams.get(requestId)
?? [...activeChatStreams.values()].find((stream) => Boolean(runId) && stream.runId === runId);
if (!activeStream) {
return {
requestId,
sessionId,
runId,
localCancelled: false,
remoteCancelled: false,
message: "No active stream found."
};
}
activeStream.cancelled = true;
activeChatStreams.delete(activeStream.requestId);
await activeStream.markStopped().catch(() => undefined);
const effectiveRunId = runId ?? activeStream.runId;
const effectiveSessionId = sessionId ?? activeStream.sessionId;
let remoteCancelled = false;
if (effectiveRunId) {
const cancelResult = await gatewayClient.cancelChatRun(effectiveRunId).catch(() => null);
remoteCancelled = Boolean(cancelResult?.remoteCancelled);
}
const result: ChatCancelStreamResult = {
requestId: activeStream.requestId,
sessionId: effectiveSessionId,
runId: effectiveRunId,
localCancelled: true,
remoteCancelled,
message: "已停止"
};
broadcastChatStreamEvent({
type: "cancelled",
requestId: activeStream.requestId,
sessionId: effectiveSessionId,
runId: effectiveRunId,
message: result.message,
remoteCancelled
}, sender);
return result;
};
ipcMain.handle(IPC_CHANNELS.workspaceGetSummary, async () => buildWorkspaceSummary());
ipcMain.handle(IPC_CHANNELS.workspaceWarmup, async () => queueWorkspaceWarmup("workspace-warmup", { action: "init" }));
ipcMain.handle(IPC_CHANNELS.windowMinimize, async (event) => {
......@@ -2744,6 +2926,9 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
ipcMain.handle(IPC_CHANNELS.chatStreamPrompt, async (event, sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]) => {
return streamPrompt(sessionId, prompt, skillId, attachments, event.sender);
});
ipcMain.handle(IPC_CHANNELS.chatCancelStream, async (event, requestId: string, runId?: string, sessionId?: string) => {
return cancelStream(requestId, runId, sessionId, event.sender);
});
ipcMain.handle(IPC_CHANNELS.diagnosticsOpenControlUi, async () => {
const config = await getEffectiveConfig();
await shell.openExternal(toControlUiUrl(config.gatewayUrl));
......@@ -2863,6 +3048,7 @@ export function registerDesktopIpc(services: MainServices): RegisteredDesktopIpc
readImageAttachmentDataUrl: async (attachment: ChatAttachment) => readImageAttachmentDataUrl(attachment),
sendPrompt: async (sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]) => sendPrompt(sessionId, prompt, skillId, attachments),
streamPrompt: async (sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]) => streamPrompt(sessionId, prompt, skillId, attachments),
cancelStream: (requestId: string, runId?: string, sessionId?: string) => cancelStream(requestId, runId, sessionId),
onStreamEvent: (listener) => {
streamListeners.add(listener);
return () => {
......
......@@ -95,6 +95,7 @@ const desktopApi: DesktopApi = {
readImageAttachmentDataUrl: (attachment: ChatAttachment) => ipcRenderer.invoke(IPC_CHANNELS.chatReadImageAttachmentDataUrl, attachment),
sendPrompt: (sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]) => ipcRenderer.invoke(IPC_CHANNELS.chatSendPrompt, sessionId, prompt, skillId, attachments),
streamPrompt: (sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]) => ipcRenderer.invoke(IPC_CHANNELS.chatStreamPrompt, sessionId, prompt, skillId, attachments),
cancelStream: (requestId: string, runId?: string, sessionId?: string) => ipcRenderer.invoke(IPC_CHANNELS.chatCancelStream, requestId, runId, sessionId),
onStreamEvent: (listener: ChatStreamListener) => {
const wrapped = (_event: Electron.IpcRendererEvent, payload: Parameters<ChatStreamListener>[0]) => {
listener(payload);
......
import test from "node:test"
import assert from "node:assert/strict"
import { readFileSync } from "node:fs"
const ipcSource = readFileSync(new URL("../src/main/ipc.ts", import.meta.url), "utf8")
const preloadSource = readFileSync(new URL("../src/preload/index.ts", import.meta.url), "utf8")
test("desktop IPC registers chat cancel stream handler", () => {
assert.match(ipcSource, /activeChatStreams/)
assert.match(ipcSource, /cancelStream\s*=\s*async/)
assert.match(ipcSource, /IPC_CHANNELS\.chatCancelStream/)
assert.match(preloadSource, /cancelStream: \(requestId: string, runId\?: string, sessionId\?: string\)/)
})
test("desktop cancel marks assistant stream as stopped and broadcasts cancelled event", () => {
assert.match(ipcSource, /statusLabel:\s*"已停止"/)
assert.match(ipcSource, /type:\s*"cancelled"/)
assert.match(ipcSource, /gatewayClient\.cancelChatRun/)
})
......@@ -31,6 +31,7 @@ import {
NavIcon,
RedBookIcon,
RefreshIcon,
StopIcon,
ThumbIcon,
TrashIcon,
getIntentSuggestionIcon,
......@@ -703,7 +704,8 @@ export default function App() {
sending,
streamSmoke,
activeStreamRef,
submitPrompt
submitPrompt,
cancelActiveStream
} = useChatStreamingController({
desktopApi,
viewMode,
......@@ -784,9 +786,9 @@ export default function App() {
normalizeError: err
});
const sendButtonLabel = sendPhase === "preparing"
? ui.preparing
? "停止生成"
: sendPhase === "streaming" || sendPhase === "finalizing"
? ui.generating
? "停止生成"
: !isBound
? ui.bindFirst
: ui.send;
......@@ -1429,8 +1431,9 @@ export default function App() {
skills: effectiveSkills,
skillMenuOpen,
attachmentIcon: <AttachmentIcon />,
submitIcon: <ArrowUpIcon />,
submitIcon: sendPhase !== "idle" ? <StopIcon /> : <ArrowUpIcon />,
onSubmit: sendPrompt,
onCancel: cancelActiveStream,
onPromptChange: setPrompt,
onTextareaKeyDown: handleComposerKeyDown,
onAttachmentSelection: handleAttachmentSelection,
......
......@@ -328,6 +328,14 @@ export function ArrowUpIcon() {
);
}
export function StopIcon() {
return (
<svg viewBox="0 0 24 24" fill="none" aria-hidden="true" focusable="false">
<rect x="7" y="7" width="10" height="10" rx="1.8" fill="currentColor" />
</svg>
);
}
export function RefreshIcon() {
return (
<svg viewBox="0 0 24 24" fill="none" aria-hidden="true" focusable="false">
......
......@@ -40,6 +40,7 @@ interface ChatComposerProps {
attachmentIcon: ReactNode
submitIcon: ReactNode
onSubmit: () => void | Promise<void>
onCancel: () => void | Promise<void>
onPromptChange: (value: string) => void
onTextareaKeyDown: (event: ReactKeyboardEvent<HTMLTextAreaElement>) => void | Promise<void>
onAttachmentSelection: (event: ChangeEvent<HTMLInputElement>) => void
......@@ -74,6 +75,7 @@ export function ChatComposer({
attachmentIcon,
submitIcon,
onSubmit,
onCancel,
onPromptChange,
onTextareaKeyDown,
onAttachmentSelection,
......@@ -159,13 +161,14 @@ export function ChatComposer({
</button>
</div>
<button
type="submit"
type={sending ? "button" : "submit"}
className={"composer-submit" + (sending ? " is-busy" : "")}
disabled={!canSend}
disabled={sending ? false : !canSend}
onClick={sending ? onCancel : undefined}
aria-label={sendButtonLabel}
title={sendButtonLabel}
>
{sending ? <span className="composer-submit-spinner" aria-hidden="true" /> : submitIcon}
{submitIcon}
<span className="visually-hidden">{sendButtonLabel}</span>
</button>
</div>
......
......@@ -130,6 +130,7 @@ interface ConversationWorkspaceViewProps {
attachmentIcon: ReactNode
submitIcon: ReactNode
onSubmit: () => void | Promise<void>
onCancel: () => void | Promise<void>
onPromptChange: (value: string) => void
onTextareaKeyDown: (event: ReactKeyboardEvent<HTMLTextAreaElement>) => void | Promise<void>
onAttachmentSelection: (event: ChangeEvent<HTMLInputElement>) => void
......@@ -221,6 +222,7 @@ export function ConversationWorkspaceView({
attachmentIcon,
submitIcon,
onSubmit,
onCancel,
onPromptChange,
onTextareaKeyDown,
onAttachmentSelection,
......@@ -387,6 +389,7 @@ export function ConversationWorkspaceView({
attachmentIcon={attachmentIcon}
submitIcon={submitIcon}
onSubmit={onSubmit}
onCancel={onCancel}
onPromptChange={onPromptChange}
onTextareaKeyDown={onTextareaKeyDown}
onAttachmentSelection={onAttachmentSelection}
......
......@@ -13,6 +13,7 @@ interface ActiveStreamState {
assistantMessageId: string
sessionId: string
originSessionId: string
runId?: string
targetText: string
renderedText: string
finalReply?: ChatMessage
......@@ -135,6 +136,10 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
const messagesBySessionRef = useRef(messagesBySession)
const scrollMessageListToBottomRef = useRef(scrollMessageListToBottom)
const clearComposerAttachmentRef = useRef(clearComposerAttachment)
const stoppedRequestIdsRef = useRef(new Set<string>())
const cancelledSubmissionIdsRef = useRef(new Set<string>())
const pendingSubmissionIdRef = useRef<string | undefined>(undefined)
const pendingSubmissionAssistantMessageIdsRef = useRef(new Map<string, string>())
useEffect(() => {
workspaceRef.current = workspace
......@@ -324,8 +329,62 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
setErrorText(message)
}, [cancelTypewriter, setErrorText, setMessageTraceExpanded, updateMessageById])
const completeWithFallback = useCallback(async (sessionId: string, promptText: string, skillId: string | undefined, assistantMessageId: string, attachments?: ChatAttachment[]) => {
const cancelActiveStream = useCallback(async () => {
const activeStream = activeStreamRef.current
if (!activeStream) {
const pendingSubmissionId = pendingSubmissionIdRef.current
if (pendingSubmissionId) {
cancelledSubmissionIdsRef.current.add(pendingSubmissionId)
}
const pendingAssistantMessageId = pendingSubmissionId
? pendingSubmissionAssistantMessageIdsRef.current.get(pendingSubmissionId)
: undefined
if (pendingAssistantMessageId) {
updateMessageById(pendingAssistantMessageId, (message) => ({
...message,
streamState: undefined,
statusLabel: "已停止",
statusDetail: undefined
}))
appendTrace(pendingAssistantMessageId, "cancelled", "已停止", undefined, "info")
}
setSendPhase("idle")
return
}
stoppedRequestIdsRef.current.add(activeStream.requestId)
cancelTypewriter()
const stoppedContent = activeStream.renderedText || activeStream.targetText
updateMessageById(activeStream.assistantMessageId, (message) => ({
...message,
content: stoppedContent || message.content,
streamState: undefined,
statusLabel: "已停止",
statusDetail: undefined
}))
appendTrace(activeStream.assistantMessageId, "cancelled", "已停止", undefined, "info")
updateStreamSmoke((current) => current ? {
...current,
phase: "cancelled",
renderedContent: stoppedContent || current.renderedContent,
finalContent: stoppedContent || current.finalContent,
latestStatusLabel: "已停止",
statusLabels: appendSmokeStatusLabel(current.statusLabels, "已停止")
} : current)
activeStreamRef.current = null
setSendPhase("idle")
await desktopApi.chat.cancelStream(activeStream.requestId, activeStream.runId, activeStream.sessionId).catch(() => undefined)
stoppedRequestIdsRef.current.delete(activeStream.requestId)
void syncChatAfterSend()
}, [appendTrace, cancelTypewriter, desktopApi.chat, syncChatAfterSend, updateMessageById])
const completeWithFallback = useCallback(async (sessionId: string, promptText: string, skillId: string | undefined, assistantMessageId: string, attachments?: ChatAttachment[], isCancelled?: () => boolean) => {
const result = await desktopApi.chat.sendPrompt(sessionId, promptText, skillId, attachments)
if (isCancelled?.()) {
setSendPhase("idle")
return
}
cancelTypewriter()
activeStreamRef.current = null
updateMessageById(assistantMessageId, (message) => ({
......@@ -451,6 +510,18 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
const renderedPrompt = trimmedPrompt || (attachmentsToSend?.length ? buildAttachmentPromptSummary(attachmentsToSend) : "")
const userMessage = buildUserMessage(renderedPrompt, attachmentsToSend)
const assistantMessage = buildAssistantPlaceholder(ui.preparingReply)
const submissionId = createClientMessageId("submission")
cancelledSubmissionIdsRef.current.delete(submissionId)
pendingSubmissionIdRef.current = submissionId
pendingSubmissionAssistantMessageIdsRef.current.set(submissionId, assistantMessage.id)
const isSubmissionCancelled = () => cancelledSubmissionIdsRef.current.has(submissionId)
const clearPendingSubmission = () => {
if (pendingSubmissionIdRef.current === submissionId) {
pendingSubmissionIdRef.current = undefined
}
pendingSubmissionAssistantMessageIdsRef.current.delete(submissionId)
cancelledSubmissionIdsRef.current.delete(submissionId)
}
setSendPhase("preparing")
setErrorText("")
......@@ -475,6 +546,10 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
try {
const confirmedWorkspace = await ensureChatAvailable(assistantMessage.id)
if (isSubmissionCancelled()) {
clearPendingSubmission()
return
}
const effectiveProjectId = forcedProjectId
?? (viewModeRef.current === "chat"
? sessionScopeProjectIdRef.current
......@@ -496,6 +571,10 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
sessionId = createdSession.id
upsertSession(createdSession)
}
if (isSubmissionCancelled()) {
clearPendingSubmission()
return
}
if (!optimisticSessionId) {
updateSessionMessages(sessionId, (current) => [...current, userMessage, assistantMessage])
scrollMessageListToBottomRef.current({ force: true, behavior: "smooth" })
......@@ -537,11 +616,27 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
}
userMessageId = stream.userMessageId ?? userMessageId
assistantMessageId = stream.assistantMessageId ?? assistantMessageId
if (isSubmissionCancelled()) {
stoppedRequestIdsRef.current.add(stream.requestId)
clearPendingSubmission()
updateMessageById(assistantMessageId, (message) => ({
...message,
streamState: undefined,
statusLabel: "已停止",
statusDetail: undefined
}))
await desktopApi.chat.cancelStream(stream.requestId, stream.runId, stream.sessionId).catch(() => undefined)
stoppedRequestIdsRef.current.delete(stream.requestId)
setSendPhase("idle")
return
}
clearPendingSubmission()
activeStreamRef.current = {
requestId: stream.requestId,
assistantMessageId,
sessionId: stream.sessionId,
originSessionId: sessionId,
runId: stream.runId,
targetText: "",
renderedText: ""
}
......@@ -556,13 +651,25 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
executionPolicyModel: stream.executionPolicy?.modelLabel ?? current.executionPolicyModel
} : current)
} catch {
if (isSubmissionCancelled()) {
clearPendingSubmission()
setSendPhase("idle")
return
}
setSendPhase("finalizing")
appendTrace(assistantMessageId, "fallback", ui.fallbackReply)
updateAssistantStatus(assistantMessageId, ui.generating)
await completeWithFallback(sessionId, trimmedPrompt, skillId, assistantMessageId, attachmentsToSend)
await completeWithFallback(sessionId, trimmedPrompt, skillId, assistantMessageId, attachmentsToSend, isSubmissionCancelled)
clearPendingSubmission()
clearComposerAttachmentRef.current()
}
} catch (error) {
if (isSubmissionCancelled()) {
clearPendingSubmission()
setSendPhase("idle")
return
}
clearPendingSubmission()
setSendPhase("idle")
const message = error instanceof Error ? error.message : String(error)
setMessageTraceExpanded(assistantMessageId, true)
......@@ -618,6 +725,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
ui.preparingReply,
ui.waitingReply,
updateAssistantStatus,
updateMessageById,
updateSessionMessages,
upsertSession,
visibleSessionIdRef
......@@ -625,6 +733,25 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
useEffect(() => {
const unsubscribe = desktopApi.chat.onStreamEvent((event) => {
if (event.type === "cancelled") {
stoppedRequestIdsRef.current.add(event.requestId)
const activeStream = activeStreamRef.current
if (activeStream && event.requestId === activeStream.requestId) {
void cancelActiveStream()
}
stoppedRequestIdsRef.current.delete(event.requestId)
return
}
if ((event.type === "completed" || event.type === "error") && stoppedRequestIdsRef.current.has(event.requestId)) {
stoppedRequestIdsRef.current.delete(event.requestId)
return
}
if (stoppedRequestIdsRef.current.has(event.requestId)) {
return
}
const activeStream = activeStreamRef.current
if (!activeStream || event.requestId !== activeStream.requestId) {
return
......@@ -632,6 +759,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
if (event.type === "started") {
activeStream.sessionId = event.sessionId
activeStream.runId = event.runId ?? activeStream.runId
setSendPhase("streaming")
appendTrace(activeStream.assistantMessageId, "started", ui.replyStarted)
updateAssistantStatus(activeStream.assistantMessageId, ui.thinking)
......@@ -650,6 +778,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
if (event.type === "status") {
activeStream.sessionId = event.sessionId
activeStream.runId = event.runId ?? activeStream.runId
appendTrace(activeStream.assistantMessageId, event.stage, event.label, event.detail)
updateAssistantStatus(activeStream.assistantMessageId, event.label, event.detail)
updateStreamSmoke((current) => current ? {
......@@ -665,6 +794,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
if (event.type === "delta") {
activeStream.sessionId = event.sessionId
activeStream.runId = event.runId
setSendPhase("streaming")
activeStream.targetText = event.fullText && event.fullText.length >= activeStream.targetText.length
? event.fullText
......@@ -681,7 +811,9 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
}
if (event.type === "completed") {
stoppedRequestIdsRef.current.delete(event.requestId)
activeStream.sessionId = event.sessionId
activeStream.runId = event.runId
activeStream.finalReply = event.reply
if (event.reply.content.length >= activeStream.targetText.length) {
activeStream.targetText = event.reply.content
......@@ -712,6 +844,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
}
if (event.type === "error") {
stoppedRequestIdsRef.current.delete(event.requestId)
const normalizedMessage = normalizeAssistantErrorMessage(event.message, event.errorCategory)
appendTrace(activeStream.assistantMessageId, "error", ui.replyFailed, normalizedMessage, "error")
updateStreamSmoke((current) => current ? {
......@@ -731,7 +864,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
cancelTypewriter()
activeStreamRef.current = null
}
}, [desktopApi.chat])
}, [cancelActiveStream, desktopApi.chat])
useEffect(() => {
updateStreamSmoke((current) => {
......@@ -766,6 +899,7 @@ export function useChatStreamingController(deps: UseChatStreamingControllerDeps)
streamSmoke,
activeStreamRef,
submitPrompt,
cancelActiveStream,
ensureChatAvailable
}
}
export type SmokeStreamPhase = "idle" | "requested" | "started" | "streaming" | "completed" | "fallback" | "error"
export type SmokeStreamPhase = "idle" | "requested" | "started" | "streaming" | "completed" | "fallback" | "error" | "cancelled"
export interface SmokeStreamSnapshot {
phase: SmokeStreamPhase
......
......@@ -24,6 +24,7 @@ const mockUi = {
waitingReply: "已收到问题,正在组织回答"
} as const
const mockChatStreamListeners = new Set<ChatStreamListener>();
const mockChatStreamTimers = new Map<string, number[]>();
function emitMockChatStreamEvent(event: ChatStreamEvent) {
for (const listener of mockChatStreamListeners) {
......@@ -401,21 +402,28 @@ export const mockDesktopApi = {
const executionPolicy = { source: "client-config" as const, modelId: "qwen3.6-plus", modelLabel: "qwen3.6-plus", routingMode: "platform-managed" as const, skillId, skillName: skillId, message: "mock" };
const replyText = "Mock: " + prompt;
const chunks = replyText.match(/.{1,6}/g) ?? [replyText];
const timers: number[] = [];
const scheduleStreamTimer = (handler: () => void, delay: number) => {
const timer = window.setTimeout(handler, delay);
timers.push(timer);
};
mockChatStreamTimers.set(requestId, timers);
let fullText = "";
window.setTimeout(() => {
scheduleStreamTimer(() => {
emitMockChatStreamEvent({ type: "status", requestId, sessionId, runId, stage: "prepare-request", label: mockUi.preparingReply });
emitMockChatStreamEvent({ type: "started", requestId, sessionId, runId, executionPolicy });
}, 0);
window.setTimeout(() => {
scheduleStreamTimer(() => {
emitMockChatStreamEvent({ type: "status", requestId, sessionId, runId, stage: "await-model", label: mockUi.waitingReply });
}, 30);
chunks.forEach((chunk, index) => {
window.setTimeout(() => {
scheduleStreamTimer(() => {
fullText += chunk;
emitMockChatStreamEvent({ type: "delta", requestId, sessionId, runId, textDelta: chunk, fullText });
}, 90 * (index + 1));
});
window.setTimeout(() => {
scheduleStreamTimer(() => {
mockChatStreamTimers.delete(requestId);
emitMockChatStreamEvent({
type: "completed",
requestId,
......@@ -427,6 +435,28 @@ export const mockDesktopApi = {
}, 90 * (chunks.length + 1));
return { requestId, sessionId, runId, userMessageId, assistantMessageId, executionPolicy };
},
cancelStream: async (requestId: string, runId?: string, sessionId?: string) => {
for (const timer of mockChatStreamTimers.get(requestId) ?? []) {
window.clearTimeout(timer);
}
mockChatStreamTimers.delete(requestId);
emitMockChatStreamEvent({
type: "cancelled",
requestId,
sessionId,
runId,
message: "已停止",
remoteCancelled: false
});
return {
requestId,
sessionId,
runId,
localCancelled: true,
remoteCancelled: false,
message: "已停止"
};
},
onStreamEvent: (listener: ChatStreamListener) => {
mockChatStreamListeners.add(listener);
return () => {
......
import test from "node:test"
import assert from "node:assert/strict"
import { readFileSync } from "node:fs"
const controllerSource = readFileSync(new URL("../src/features/chat/useChatStreamingController.ts", import.meta.url), "utf8")
const composerSource = readFileSync(new URL("../src/features/chat/ChatComposer.tsx", import.meta.url), "utf8")
const mockSource = readFileSync(new URL("../src/lib/mock-desktop-api.ts", import.meta.url), "utf8")
test("chat streaming controller exposes cancelActiveStream and ignores later events", () => {
assert.match(controllerSource, /cancelActiveStream/)
assert.match(controllerSource, /desktopApi\.chat\.cancelStream/)
assert.match(controllerSource, /stoppedRequestIdsRef/)
assert.match(controllerSource, /event\.type === "cancelled"/)
})
test("chat streaming controller scopes preparing cancellation per submission", () => {
assert.match(controllerSource, /cancelledSubmissionIdsRef/)
assert.match(controllerSource, /pendingSubmissionIdRef/)
assert.doesNotMatch(controllerSource, /preStreamCancelRequestedRef/)
})
test("chat streaming controller clears stopped request ids on terminal events", () => {
assert.match(controllerSource, /stoppedRequestIdsRef\.current\.delete\(event\.requestId\)/)
})
test("composer can submit a stop action while sending", () => {
assert.match(composerSource, /onCancel/)
assert.match(composerSource, /type=\{sending \? "button" : "submit"\}/)
assert.match(composerSource, /onClick=\{sending \? onCancel : undefined\}/)
})
test("mock desktop API cancels pending stream timers", () => {
assert.match(mockSource, /mockChatStreamTimers/)
assert.match(mockSource, /cancelStream/)
assert.match(mockSource, /window\.clearTimeout/)
})
......@@ -114,6 +114,12 @@ export interface GatewayPromptStreamStart {
completion: Promise<ChatMessage>;
}
export interface GatewayCancelChatRunResult {
runId: string;
localCancelled: boolean;
remoteCancelled: boolean;
}
export interface GatewayPromptStreamDelta {
sessionId: string;
runId: string;
......@@ -483,6 +489,45 @@ export class GatewayClient {
return { sessionId, runId, completion };
}
async cancelChatRun(runId: string): Promise<GatewayCancelChatRunResult> {
const pending = this.pendingChatRuns.get(runId);
if (pending) {
clearTimeout(pending.timer);
this.pendingChatRuns.delete(runId);
pending.resolve({
id: `${pending.sessionKey}:${runId}:cancelled`,
role: "assistant",
content: pending.accumulatedText,
createdAt: new Date().toISOString()
});
}
const availableMethods = this.statusSnapshot.availableMethods ?? [];
if (!availableMethods.includes("chat.cancel")) {
return {
runId,
localCancelled: Boolean(pending),
remoteCancelled: false
};
}
try {
await this.request("chat.cancel", { runId });
return {
runId,
localCancelled: Boolean(pending),
remoteCancelled: true
};
} catch (error) {
this.appendLog("warn", `Gateway chat.cancel failed for ${runId}: ${error instanceof Error ? error.message : String(error)}`);
return {
runId,
localCancelled: Boolean(pending),
remoteCancelled: false
};
}
}
private async handleEvent(frame: Record<string, unknown>): Promise<void> {
const eventName = String(frame.event ?? "unknown");
......@@ -1187,4 +1232,3 @@ export class GatewayClient {
return message;
}
}
import test from "node:test"
import assert from "node:assert/strict"
import { readFileSync } from "node:fs"
const gatewaySource = readFileSync(new URL("../src/index.ts", import.meta.url), "utf8")
test("gateway client cancels local pending run even when remote cancel is unavailable", () => {
assert.match(gatewaySource, /async cancelChatRun\(runId: string\)/)
assert.match(gatewaySource, /remoteCancelled: false/)
assert.match(gatewaySource, /this\.pendingChatRuns\.delete\(runId\)/)
})
test("gateway client only sends cancel RPC when gateway advertises chat cancel", () => {
assert.match(gatewaySource, /availableMethods.*chat\.cancel/s)
assert.match(gatewaySource, /this\.request\("chat\.cancel"/)
})
......@@ -38,6 +38,7 @@
chatReadImageAttachmentDataUrl: "chat:read-image-attachment-data-url",
chatSendPrompt: "chat:send-prompt",
chatStreamPrompt: "chat:stream-prompt",
chatCancelStream: "chat:cancel-stream",
chatStreamEvent: "chat:stream-event",
diagnosticsOpenControlUi: "diagnostics:open-control-ui",
diagnosticsExportSnapshot: "diagnostics:export-snapshot",
......@@ -509,6 +510,21 @@ export interface ChatStreamPromptResult {
executionPolicy?: ChatExecutionPolicy;
}
export interface ChatCancelStreamRequest {
requestId: string;
runId?: string;
sessionId?: string;
}
export interface ChatCancelStreamResult {
requestId: string;
sessionId?: string;
runId?: string;
localCancelled: boolean;
remoteCancelled: boolean;
message: string;
}
export interface ChatStreamStartedEvent {
type: "started";
requestId: string;
......@@ -554,7 +570,16 @@ export interface ChatStreamErrorEvent {
errorCategory?: string;
}
export type ChatStreamEvent = ChatStreamStartedEvent | ChatStreamStatusEvent | ChatStreamDeltaEvent | ChatStreamCompletedEvent | ChatStreamErrorEvent;
export interface ChatStreamCancelledEvent {
type: "cancelled";
requestId: string;
sessionId?: string;
runId?: string;
message: string;
remoteCancelled: boolean;
}
export type ChatStreamEvent = ChatStreamStartedEvent | ChatStreamStatusEvent | ChatStreamDeltaEvent | ChatStreamCompletedEvent | ChatStreamErrorEvent | ChatStreamCancelledEvent;
export type ChatStreamListener = (event: ChatStreamEvent) => void;
......@@ -965,6 +990,7 @@ export interface DesktopApi {
readImageAttachmentDataUrl(attachment: ChatAttachment): Promise<string | null>;
sendPrompt(sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]): Promise<PromptResult>;
streamPrompt(sessionId: string, prompt: string, skillId?: string, attachments?: ChatAttachment[]): Promise<ChatStreamPromptResult>;
cancelStream(requestId: string, runId?: string, sessionId?: string): Promise<ChatCancelStreamResult>;
onStreamEvent(listener: ChatStreamListener): () => void;
};
diagnostics: {
......
import test from "node:test"
import assert from "node:assert/strict"
import { readFileSync } from "node:fs"
const sharedTypesSource = readFileSync(new URL("../src/index.ts", import.meta.url), "utf8")
test("shared desktop chat API exposes cancel stream contract", () => {
assert.match(sharedTypesSource, /chatCancelStream:\s*"chat:cancel-stream"/)
assert.match(sharedTypesSource, /export interface ChatCancelStreamRequest/)
assert.match(sharedTypesSource, /export interface ChatCancelStreamResult/)
assert.match(sharedTypesSource, /cancelStream\(requestId: string, runId\?: string, sessionId\?: string\): Promise<ChatCancelStreamResult>/)
})
test("shared stream events include cancelled payload", () => {
assert.match(sharedTypesSource, /export interface ChatStreamCancelledEvent/)
assert.match(sharedTypesSource, /type:\s*"cancelled"/)
assert.match(sharedTypesSource, /ChatStreamEvent = .*ChatStreamCancelledEvent/s)
})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment