From aac59c2b3af3a6909767add51c3c2050f3576f30 Mon Sep 17 00:00:00 2001 From: Shirone Date: Wed, 21 Jan 2026 14:57:26 +0100 Subject: [PATCH] feat(ui): enhance WebSocket event handling and polling logic - Introduced a new `useEventRecency` hook to track the recency of WebSocket events, allowing for conditional polling based on event activity. - Updated `AgentInfoPanel` to utilize the new hook, adjusting polling intervals based on WebSocket activity. - Implemented debounced invalidation for auto mode events to optimize query updates during rapid event streams. - Added utility functions for managing event recency checks in various query hooks, improving overall responsiveness and reducing unnecessary polling. - Introduced debounce and throttle utilities for better control over function execution rates. This enhancement improves the application's performance by reducing polling when real-time updates are available, ensuring a more efficient use of resources. --- .../kanban-card/agent-info-panel.tsx | 60 +++- apps/ui/src/hooks/index.ts | 9 + apps/ui/src/hooks/queries/use-features.ts | 14 +- apps/ui/src/hooks/queries/use-spec.ts | 5 +- apps/ui/src/hooks/use-event-recency.ts | 176 ++++++++++ apps/ui/src/hooks/use-query-invalidation.ts | 139 +++++++- libs/utils/src/debounce.ts | 280 +++++++++++++++ libs/utils/src/index.ts | 9 + libs/utils/tests/debounce.test.ts | 330 ++++++++++++++++++ 9 files changed, 1000 insertions(+), 22 deletions(-) create mode 100644 apps/ui/src/hooks/use-event-recency.ts create mode 100644 libs/utils/src/debounce.ts create mode 100644 libs/utils/tests/debounce.test.ts diff --git a/apps/ui/src/components/views/board-view/components/kanban-card/agent-info-panel.tsx b/apps/ui/src/components/views/board-view/components/kanban-card/agent-info-panel.tsx index 9cd9d793..fe77b6e5 100644 --- a/apps/ui/src/components/views/board-view/components/kanban-card/agent-info-panel.tsx +++ b/apps/ui/src/components/views/board-view/components/kanban-card/agent-info-panel.tsx @@ -1,4 +1,4 @@ -import { memo, useEffect, useState, useMemo } from 'react'; +import { memo, useEffect, useState, useMemo, useRef } from 'react'; import { Feature, ThinkingLevel, ParsedTask } from '@/store/app-store'; import type { ReasoningEffort } from '@automaker/types'; import { getProviderFromModel } from '@/lib/utils'; @@ -69,21 +69,70 @@ export const AgentInfoPanel = memo(function AgentInfoPanel({ const [taskStatusMap, setTaskStatusMap] = useState< Map >(new Map()); + // Track last WebSocket event timestamp to know if we're receiving real-time updates + const [lastWsEventTimestamp, setLastWsEventTimestamp] = useState(null); // Determine if we should poll for updates - const shouldPoll = isCurrentAutoTask || feature.status === 'in_progress'; const shouldFetchData = feature.status !== 'backlog'; + // Track whether we're receiving WebSocket events (within threshold) + // Use a state to trigger re-renders when the WebSocket connection becomes stale + const [isReceivingWsEvents, setIsReceivingWsEvents] = useState(false); + const wsEventTimeoutRef = useRef | null>(null); + + // WebSocket activity threshold in ms - if no events within this time, consider WS inactive + const WS_ACTIVITY_THRESHOLD = 10000; + + // Update isReceivingWsEvents when we get new WebSocket events + useEffect(() => { + if (lastWsEventTimestamp !== null) { + // We just received an event, mark as active + setIsReceivingWsEvents(true); + + // Clear any existing timeout + if (wsEventTimeoutRef.current) { + clearTimeout(wsEventTimeoutRef.current); + } + + // Set a timeout to mark as inactive if no new events + wsEventTimeoutRef.current = setTimeout(() => { + setIsReceivingWsEvents(false); + }, WS_ACTIVITY_THRESHOLD); + } + + return () => { + if (wsEventTimeoutRef.current) { + clearTimeout(wsEventTimeoutRef.current); + } + }; + }, [lastWsEventTimestamp]); + + // Polling interval logic: + // - If receiving WebSocket events: use longer interval (10s) as a fallback + // - If not receiving WebSocket events but in_progress: use normal interval (3s) + // - Otherwise: no polling + const pollingInterval = useMemo((): number | false => { + if (!(isCurrentAutoTask || feature.status === 'in_progress')) { + return false; + } + // If receiving WebSocket events, use longer polling interval as fallback + if (isReceivingWsEvents) { + return 10000; + } + // Default polling interval + return 3000; + }, [isCurrentAutoTask, feature.status, isReceivingWsEvents]); + // Fetch fresh feature data for planSpec (store data can be stale for task progress) const { data: freshFeature } = useFeature(projectPath, feature.id, { enabled: shouldFetchData && !contextContent, - pollingInterval: shouldPoll ? 3000 : false, + pollingInterval, }); // Fetch agent output for parsing const { data: agentOutputContent } = useAgentOutput(projectPath, feature.id, { enabled: shouldFetchData && !contextContent, - pollingInterval: shouldPoll ? 3000 : false, + pollingInterval, }); // Parse agent output into agentInfo @@ -174,6 +223,9 @@ export const AgentInfoPanel = memo(function AgentInfoPanel({ // Only handle events for this feature if (!('featureId' in event) || event.featureId !== feature.id) return; + // Update timestamp for any event related to this feature + setLastWsEventTimestamp(Date.now()); + switch (event.type) { case 'auto_mode_task_started': if ('taskId' in event) { diff --git a/apps/ui/src/hooks/index.ts b/apps/ui/src/hooks/index.ts index 8a354b3d..6d7e2bad 100644 --- a/apps/ui/src/hooks/index.ts +++ b/apps/ui/src/hooks/index.ts @@ -1,6 +1,15 @@ export { useAutoMode } from './use-auto-mode'; export { useBoardBackgroundSettings } from './use-board-background-settings'; export { useElectronAgent } from './use-electron-agent'; +export { + useEventRecorder, + useEventRecency, + useEventRecencyStore, + getGlobalEventsRecent, + getEventsRecent, + createSmartPollingInterval, + EVENT_RECENCY_THRESHOLD, +} from './use-event-recency'; export { useGuidedPrompts } from './use-guided-prompts'; export { useKeyboardShortcuts } from './use-keyboard-shortcuts'; export { useMessageQueue } from './use-message-queue'; diff --git a/apps/ui/src/hooks/queries/use-features.ts b/apps/ui/src/hooks/queries/use-features.ts index 78db6101..85eb701c 100644 --- a/apps/ui/src/hooks/queries/use-features.ts +++ b/apps/ui/src/hooks/queries/use-features.ts @@ -10,6 +10,7 @@ import { useQuery } from '@tanstack/react-query'; import { getElectronAPI } from '@/lib/electron'; import { queryKeys } from '@/lib/query-keys'; import { STALE_TIMES } from '@/lib/query-client'; +import { getGlobalEventsRecent } from '@/hooks/use-event-recency'; import type { Feature } from '@/store/app-store'; const FEATURES_REFETCH_ON_FOCUS = false; @@ -79,7 +80,11 @@ export function useFeature( }, enabled: !!projectPath && !!featureId && enabled, staleTime: STALE_TIMES.FEATURES, - refetchInterval: pollingInterval, + // When a polling interval is specified, disable it if WebSocket events are recent + refetchInterval: + pollingInterval === false || pollingInterval === undefined + ? pollingInterval + : () => (getGlobalEventsRecent() ? false : pollingInterval), refetchOnWindowFocus: FEATURES_REFETCH_ON_FOCUS, refetchOnReconnect: FEATURES_REFETCH_ON_RECONNECT, }); @@ -119,11 +124,16 @@ export function useAgentOutput( }, enabled: !!projectPath && !!featureId && enabled, staleTime: STALE_TIMES.AGENT_OUTPUT, - // Use provided polling interval or default behavior + // Use provided polling interval or default smart behavior refetchInterval: pollingInterval !== undefined ? pollingInterval : (query) => { + // Disable polling when WebSocket events are recent (within 5s) + // WebSocket invalidation handles updates in real-time + if (getGlobalEventsRecent()) { + return false; + } // Only poll if we have data and it's not empty (indicating active task) if (query.state.data && query.state.data.length > 0) { return 5000; // 5 seconds diff --git a/apps/ui/src/hooks/queries/use-spec.ts b/apps/ui/src/hooks/queries/use-spec.ts index c81dea34..d2cce124 100644 --- a/apps/ui/src/hooks/queries/use-spec.ts +++ b/apps/ui/src/hooks/queries/use-spec.ts @@ -8,6 +8,7 @@ import { useQuery } from '@tanstack/react-query'; import { getElectronAPI } from '@/lib/electron'; import { queryKeys } from '@/lib/query-keys'; import { STALE_TIMES } from '@/lib/query-client'; +import { getGlobalEventsRecent } from '@/hooks/use-event-recency'; interface SpecFileResult { content: string; @@ -98,6 +99,8 @@ export function useSpecRegenerationStatus(projectPath: string | undefined, enabl }, enabled: !!projectPath && enabled, staleTime: 5000, // Check every 5 seconds when active - refetchInterval: enabled ? 5000 : false, + // Disable polling when WebSocket events are recent (within 5s) + // WebSocket invalidation handles updates in real-time + refetchInterval: enabled ? () => (getGlobalEventsRecent() ? false : 5000) : false, }); } diff --git a/apps/ui/src/hooks/use-event-recency.ts b/apps/ui/src/hooks/use-event-recency.ts new file mode 100644 index 00000000..d3a56139 --- /dev/null +++ b/apps/ui/src/hooks/use-event-recency.ts @@ -0,0 +1,176 @@ +/** + * Event Recency Hook + * + * Tracks the timestamp of the last WebSocket event received. + * Used to conditionally disable polling when events are flowing + * through WebSocket (indicating the connection is healthy). + */ + +import { useEffect, useCallback } from 'react'; +import { create } from 'zustand'; + +/** + * Time threshold (ms) to consider events as "recent" + * If an event was received within this time, WebSocket is considered healthy + * and polling can be safely disabled. + */ +export const EVENT_RECENCY_THRESHOLD = 5000; // 5 seconds + +/** + * Store for tracking event timestamps per query key + * This allows fine-grained control over which queries have received recent events + */ +interface EventRecencyState { + /** Map of query key (stringified) -> last event timestamp */ + eventTimestamps: Record; + /** Global last event timestamp (for any event) */ + lastGlobalEventTimestamp: number; + /** Record an event for a specific query key */ + recordEvent: (queryKey: string) => void; + /** Record a global event (useful for general WebSocket health) */ + recordGlobalEvent: () => void; + /** Check if events are recent for a specific query key */ + areEventsRecent: (queryKey: string) => boolean; + /** Check if any global events are recent */ + areGlobalEventsRecent: () => boolean; +} + +export const useEventRecencyStore = create((set, get) => ({ + eventTimestamps: {}, + lastGlobalEventTimestamp: 0, + + recordEvent: (queryKey: string) => { + const now = Date.now(); + set((state) => ({ + eventTimestamps: { + ...state.eventTimestamps, + [queryKey]: now, + }, + lastGlobalEventTimestamp: now, + })); + }, + + recordGlobalEvent: () => { + set({ lastGlobalEventTimestamp: Date.now() }); + }, + + areEventsRecent: (queryKey: string) => { + const { eventTimestamps } = get(); + const lastEventTime = eventTimestamps[queryKey]; + if (!lastEventTime) return false; + return Date.now() - lastEventTime < EVENT_RECENCY_THRESHOLD; + }, + + areGlobalEventsRecent: () => { + const { lastGlobalEventTimestamp } = get(); + if (!lastGlobalEventTimestamp) return false; + return Date.now() - lastGlobalEventTimestamp < EVENT_RECENCY_THRESHOLD; + }, +})); + +/** + * Hook to record event timestamps when WebSocket events are received. + * Should be called from WebSocket event handlers. + * + * @returns Functions to record events + * + * @example + * ```tsx + * const { recordEvent, recordGlobalEvent } = useEventRecorder(); + * + * // In WebSocket event handler: + * api.autoMode.onEvent((event) => { + * recordGlobalEvent(); + * if (event.featureId) { + * recordEvent(`features:${event.featureId}`); + * } + * }); + * ``` + */ +export function useEventRecorder() { + const recordEvent = useEventRecencyStore((state) => state.recordEvent); + const recordGlobalEvent = useEventRecencyStore((state) => state.recordGlobalEvent); + + return { recordEvent, recordGlobalEvent }; +} + +/** + * Hook to check if WebSocket events are recent, used by queries + * to decide whether to enable/disable polling. + * + * @param queryKey - Optional specific query key to check + * @returns Object with recency check result and timestamp + * + * @example + * ```tsx + * const { areEventsRecent, areGlobalEventsRecent } = useEventRecency(); + * + * // In query options: + * refetchInterval: areGlobalEventsRecent() ? false : 5000, + * ``` + */ +export function useEventRecency(queryKey?: string) { + const areEventsRecent = useEventRecencyStore((state) => state.areEventsRecent); + const areGlobalEventsRecent = useEventRecencyStore((state) => state.areGlobalEventsRecent); + const lastGlobalEventTimestamp = useEventRecencyStore((state) => state.lastGlobalEventTimestamp); + + const checkRecency = useCallback( + (key?: string) => { + if (key) { + return areEventsRecent(key); + } + return areGlobalEventsRecent(); + }, + [areEventsRecent, areGlobalEventsRecent] + ); + + return { + areEventsRecent: queryKey ? () => areEventsRecent(queryKey) : areEventsRecent, + areGlobalEventsRecent, + checkRecency, + lastGlobalEventTimestamp, + }; +} + +/** + * Utility function to create a refetchInterval that respects event recency. + * Returns false (no polling) if events are recent, otherwise returns the interval. + * + * @param defaultInterval - The polling interval to use when events aren't recent + * @returns A function suitable for React Query's refetchInterval option + * + * @example + * ```tsx + * const { data } = useQuery({ + * queryKey: ['features'], + * queryFn: fetchFeatures, + * refetchInterval: createSmartPollingInterval(5000), + * }); + * ``` + */ +export function createSmartPollingInterval(defaultInterval: number) { + return () => { + const { areGlobalEventsRecent } = useEventRecencyStore.getState(); + return areGlobalEventsRecent() ? false : defaultInterval; + }; +} + +/** + * Helper function to get current event recency state (for use outside React) + * Useful in query configurations where hooks can't be used directly. + * + * @returns Whether global events are recent + */ +export function getGlobalEventsRecent(): boolean { + return useEventRecencyStore.getState().areGlobalEventsRecent(); +} + +/** + * Helper function to get event recency for a specific query key (for use outside React) + * + * @param queryKey - The query key to check + * @returns Whether events for that query key are recent + */ +export function getEventsRecent(queryKey: string): boolean { + return useEventRecencyStore.getState().areEventsRecent(queryKey); +} diff --git a/apps/ui/src/hooks/use-query-invalidation.ts b/apps/ui/src/hooks/use-query-invalidation.ts index eb0bfb4d..88625bcb 100644 --- a/apps/ui/src/hooks/use-query-invalidation.ts +++ b/apps/ui/src/hooks/use-query-invalidation.ts @@ -5,12 +5,48 @@ * ensuring the UI stays in sync with server-side changes without manual refetching. */ -import { useEffect } from 'react'; -import { useQueryClient } from '@tanstack/react-query'; +import { useEffect, useRef } from 'react'; +import { useQueryClient, QueryClient } from '@tanstack/react-query'; import { getElectronAPI } from '@/lib/electron'; import { queryKeys } from '@/lib/query-keys'; import type { AutoModeEvent, SpecRegenerationEvent } from '@/types/electron'; import type { IssueValidationEvent } from '@automaker/types'; +import { debounce, DebouncedFunction } from '@automaker/utils'; +import { useEventRecencyStore } from './use-event-recency'; + +/** + * Debounce configuration for auto_mode_progress invalidations + * - wait: 150ms delay to batch rapid consecutive progress events + * - maxWait: 2000ms ensures UI updates at least every 2 seconds during streaming + */ +const PROGRESS_DEBOUNCE_WAIT = 150; +const PROGRESS_DEBOUNCE_MAX_WAIT = 2000; + +/** + * Creates a unique key for per-feature debounce tracking + */ +function getFeatureKey(projectPath: string, featureId: string): string { + return `${projectPath}:${featureId}`; +} + +/** + * Creates a debounced invalidation function for a specific feature's agent output + */ +function createDebouncedInvalidation( + queryClient: QueryClient, + projectPath: string, + featureId: string +): DebouncedFunction<() => void> { + return debounce( + () => { + queryClient.invalidateQueries({ + queryKey: queryKeys.features.agentOutput(projectPath, featureId), + }); + }, + PROGRESS_DEBOUNCE_WAIT, + { maxWait: PROGRESS_DEBOUNCE_MAX_WAIT } + ); +} /** * Invalidate queries based on auto mode events @@ -31,12 +67,54 @@ import type { IssueValidationEvent } from '@automaker/types'; */ export function useAutoModeQueryInvalidation(projectPath: string | undefined) { const queryClient = useQueryClient(); + const recordGlobalEvent = useEventRecencyStore((state) => state.recordGlobalEvent); + + // Store per-feature debounced invalidation functions + // Using a ref to persist across renders without causing re-subscriptions + const debouncedInvalidationsRef = useRef void>>>(new Map()); useEffect(() => { if (!projectPath) return; + // Capture projectPath in a const to satisfy TypeScript's type narrowing + const currentProjectPath = projectPath; + const debouncedInvalidations = debouncedInvalidationsRef.current; + + /** + * Get or create a debounced invalidation function for a specific feature + */ + function getDebouncedInvalidation(featureId: string): DebouncedFunction<() => void> { + const key = getFeatureKey(currentProjectPath, featureId); + let debouncedFn = debouncedInvalidations.get(key); + + if (!debouncedFn) { + debouncedFn = createDebouncedInvalidation(queryClient, currentProjectPath, featureId); + debouncedInvalidations.set(key, debouncedFn); + } + + return debouncedFn; + } + + /** + * Clean up debounced function for a feature (flush pending and remove) + */ + function cleanupFeatureDebounce(featureId: string): void { + const key = getFeatureKey(currentProjectPath, featureId); + const debouncedFn = debouncedInvalidations.get(key); + + if (debouncedFn) { + // Flush any pending invalidation before cleanup + debouncedFn.flush(); + debouncedInvalidations.delete(key); + } + } + const api = getElectronAPI(); const unsubscribe = api.autoMode.onEvent((event: AutoModeEvent) => { + // Record that we received a WebSocket event (for event recency tracking) + // This allows polling to be disabled when WebSocket events are flowing + recordGlobalEvent(); + // Invalidate features when agent completes, errors, or receives plan approval if ( event.type === 'auto_mode_feature_complete' || @@ -47,7 +125,7 @@ export function useAutoModeQueryInvalidation(projectPath: string | undefined) { event.type === 'pipeline_step_complete' ) { queryClient.invalidateQueries({ - queryKey: queryKeys.features.all(projectPath), + queryKey: queryKeys.features.all(currentProjectPath), }); } @@ -72,30 +150,49 @@ export function useAutoModeQueryInvalidation(projectPath: string | undefined) { 'featureId' in event ) { queryClient.invalidateQueries({ - queryKey: queryKeys.features.single(projectPath, event.featureId), + queryKey: queryKeys.features.single(currentProjectPath, event.featureId), }); } - // Invalidate agent output during progress updates + // Invalidate agent output during progress updates (DEBOUNCED) + // Uses per-feature debouncing to batch rapid progress events during streaming if (event.type === 'auto_mode_progress' && 'featureId' in event) { - queryClient.invalidateQueries({ - queryKey: queryKeys.features.agentOutput(projectPath, event.featureId), - }); + const debouncedInvalidation = getDebouncedInvalidation(event.featureId); + debouncedInvalidation(); + } + + // Clean up debounced functions when feature completes or errors + // This ensures we flush any pending invalidations and free memory + if ( + (event.type === 'auto_mode_feature_complete' || event.type === 'auto_mode_error') && + 'featureId' in event && + event.featureId + ) { + cleanupFeatureDebounce(event.featureId); } // Invalidate worktree queries when feature completes (may have created worktree) if (event.type === 'auto_mode_feature_complete' && 'featureId' in event) { queryClient.invalidateQueries({ - queryKey: queryKeys.worktrees.all(projectPath), + queryKey: queryKeys.worktrees.all(currentProjectPath), }); queryClient.invalidateQueries({ - queryKey: queryKeys.worktrees.single(projectPath, event.featureId), + queryKey: queryKeys.worktrees.single(currentProjectPath, event.featureId), }); } }); - return unsubscribe; - }, [projectPath, queryClient]); + // Cleanup on unmount: flush and clear all debounced functions + return () => { + unsubscribe(); + + // Flush all pending invalidations before cleanup + for (const debouncedFn of debouncedInvalidations.values()) { + debouncedFn.flush(); + } + debouncedInvalidations.clear(); + }; + }, [projectPath, queryClient, recordGlobalEvent]); } /** @@ -105,6 +202,7 @@ export function useAutoModeQueryInvalidation(projectPath: string | undefined) { */ export function useSpecRegenerationQueryInvalidation(projectPath: string | undefined) { const queryClient = useQueryClient(); + const recordGlobalEvent = useEventRecencyStore((state) => state.recordGlobalEvent); useEffect(() => { if (!projectPath) return; @@ -114,6 +212,9 @@ export function useSpecRegenerationQueryInvalidation(projectPath: string | undef // Only handle events for the current project if (event.projectPath !== projectPath) return; + // Record that we received a WebSocket event + recordGlobalEvent(); + if (event.type === 'spec_regeneration_complete') { // Invalidate features as new ones may have been generated queryClient.invalidateQueries({ @@ -128,7 +229,7 @@ export function useSpecRegenerationQueryInvalidation(projectPath: string | undef }); return unsubscribe; - }, [projectPath, queryClient]); + }, [projectPath, queryClient, recordGlobalEvent]); } /** @@ -138,6 +239,7 @@ export function useSpecRegenerationQueryInvalidation(projectPath: string | undef */ export function useGitHubValidationQueryInvalidation(projectPath: string | undefined) { const queryClient = useQueryClient(); + const recordGlobalEvent = useEventRecencyStore((state) => state.recordGlobalEvent); useEffect(() => { if (!projectPath) return; @@ -150,6 +252,9 @@ export function useGitHubValidationQueryInvalidation(projectPath: string | undef } const unsubscribe = api.github.onValidationEvent((event: IssueValidationEvent) => { + // Record that we received a WebSocket event + recordGlobalEvent(); + if (event.type === 'validation_complete' || event.type === 'validation_error') { // Invalidate all validations for this project queryClient.invalidateQueries({ @@ -166,7 +271,7 @@ export function useGitHubValidationQueryInvalidation(projectPath: string | undef }); return unsubscribe; - }, [projectPath, queryClient]); + }, [projectPath, queryClient, recordGlobalEvent]); } /** @@ -176,6 +281,7 @@ export function useGitHubValidationQueryInvalidation(projectPath: string | undef */ export function useSessionQueryInvalidation(sessionId: string | undefined) { const queryClient = useQueryClient(); + const recordGlobalEvent = useEventRecencyStore((state) => state.recordGlobalEvent); useEffect(() => { if (!sessionId) return; @@ -185,6 +291,9 @@ export function useSessionQueryInvalidation(sessionId: string | undefined) { // Only handle events for the current session if ('sessionId' in event && event.sessionId !== sessionId) return; + // Record that we received a WebSocket event + recordGlobalEvent(); + // Invalidate session history when a message is complete if (event.type === 'complete' || event.type === 'message') { queryClient.invalidateQueries({ @@ -201,7 +310,7 @@ export function useSessionQueryInvalidation(sessionId: string | undefined) { }); return unsubscribe; - }, [sessionId, queryClient]); + }, [sessionId, queryClient, recordGlobalEvent]); } /** diff --git a/libs/utils/src/debounce.ts b/libs/utils/src/debounce.ts new file mode 100644 index 00000000..211fb8ff --- /dev/null +++ b/libs/utils/src/debounce.ts @@ -0,0 +1,280 @@ +/** + * Debounce and throttle utilities for rate-limiting function calls + */ + +/** + * Options for the debounce function + */ +export interface DebounceOptions { + /** + * If true, call the function immediately on the first invocation (leading edge) + * @default false + */ + leading?: boolean; + + /** + * If true, call the function after the delay on the last invocation (trailing edge) + * @default true + */ + trailing?: boolean; + + /** + * Maximum time to wait before forcing invocation (useful for continuous events) + * If set, the function will be called at most every `maxWait` milliseconds + */ + maxWait?: number; +} + +/** + * The return type of the debounce function with additional control methods + */ +export interface DebouncedFunction unknown> { + /** + * Call the debounced function + */ + (...args: Parameters): void; + + /** + * Cancel any pending invocation + */ + cancel(): void; + + /** + * Immediately invoke any pending function call + */ + flush(): void; + + /** + * Check if there's a pending invocation + */ + pending(): boolean; +} + +/** + * Creates a debounced version of a function that delays invoking the function + * until after `wait` milliseconds have elapsed since the last time the debounced + * function was invoked. + * + * Useful for rate-limiting events like window resize, scroll, or input changes. + * + * @param fn - The function to debounce + * @param wait - The number of milliseconds to delay + * @param options - Optional configuration + * @returns A debounced version of the function with cancel, flush, and pending methods + * + * @example + * // Basic usage - save input after user stops typing for 300ms + * const saveInput = debounce((value: string) => { + * api.save(value); + * }, 300); + * + * input.addEventListener('input', (e) => saveInput(e.target.value)); + * + * @example + * // With leading edge - execute immediately on first call + * const handleClick = debounce(() => { + * submitForm(); + * }, 1000, { leading: true, trailing: false }); + * + * @example + * // With maxWait - ensure function runs at least every 5 seconds during continuous input + * const autoSave = debounce((content: string) => { + * saveToServer(content); + * }, 1000, { maxWait: 5000 }); + */ +export function debounce unknown>( + fn: T, + wait: number, + options: DebounceOptions = {} +): DebouncedFunction { + const { leading = false, trailing = true, maxWait } = options; + + let timeoutId: ReturnType | null = null; + let maxTimeoutId: ReturnType | null = null; + let lastArgs: Parameters | null = null; + let lastCallTime: number | null = null; + let lastInvokeTime = 0; + + // Validate options + if (maxWait !== undefined && maxWait < wait) { + throw new Error('maxWait must be greater than or equal to wait'); + } + + function invokeFunc(): void { + const args = lastArgs; + lastArgs = null; + lastInvokeTime = Date.now(); + + if (args !== null) { + fn(...args); + } + } + + function shouldInvoke(time: number): boolean { + const timeSinceLastCall = lastCallTime === null ? 0 : time - lastCallTime; + const timeSinceLastInvoke = time - lastInvokeTime; + + // First call, or wait time has passed, or maxWait exceeded + return ( + lastCallTime === null || + timeSinceLastCall >= wait || + timeSinceLastCall < 0 || + (maxWait !== undefined && timeSinceLastInvoke >= maxWait) + ); + } + + function timerExpired(): void { + const time = Date.now(); + + if (shouldInvoke(time)) { + trailingEdge(); + return; + } + + // Restart the timer with remaining time + const timeSinceLastCall = lastCallTime === null ? 0 : time - lastCallTime; + const timeSinceLastInvoke = time - lastInvokeTime; + const timeWaiting = wait - timeSinceLastCall; + + const remainingWait = + maxWait !== undefined ? Math.min(timeWaiting, maxWait - timeSinceLastInvoke) : timeWaiting; + + timeoutId = setTimeout(timerExpired, remainingWait); + } + + function trailingEdge(): void { + timeoutId = null; + + if (trailing && lastArgs !== null) { + invokeFunc(); + } + + lastArgs = null; + } + + function leadingEdge(time: number): void { + lastInvokeTime = time; + + // Start timer for trailing edge + timeoutId = setTimeout(timerExpired, wait); + + // Invoke leading edge + if (leading) { + invokeFunc(); + } + } + + function cancel(): void { + if (timeoutId !== null) { + clearTimeout(timeoutId); + timeoutId = null; + } + if (maxTimeoutId !== null) { + clearTimeout(maxTimeoutId); + maxTimeoutId = null; + } + lastArgs = null; + lastCallTime = null; + lastInvokeTime = 0; + } + + function flush(): void { + if (timeoutId !== null) { + invokeFunc(); + cancel(); + } + } + + function pending(): boolean { + return timeoutId !== null; + } + + function debounced(...args: Parameters): void { + const time = Date.now(); + const isInvoking = shouldInvoke(time); + + lastArgs = args; + lastCallTime = time; + + if (isInvoking) { + if (timeoutId === null) { + leadingEdge(time); + return; + } + + // Handle maxWait case + if (maxWait !== undefined) { + timeoutId = setTimeout(timerExpired, wait); + invokeFunc(); + return; + } + } + + if (timeoutId === null) { + timeoutId = setTimeout(timerExpired, wait); + } + } + + debounced.cancel = cancel; + debounced.flush = flush; + debounced.pending = pending; + + return debounced; +} + +/** + * Options for the throttle function + */ +export interface ThrottleOptions { + /** + * If true, call the function on the leading edge + * @default true + */ + leading?: boolean; + + /** + * If true, call the function on the trailing edge + * @default true + */ + trailing?: boolean; +} + +/** + * Creates a throttled version of a function that only invokes the function + * at most once per every `wait` milliseconds. + * + * Useful for rate-limiting events like scroll or mousemove where you want + * regular updates but not on every event. + * + * @param fn - The function to throttle + * @param wait - The number of milliseconds to throttle invocations to + * @param options - Optional configuration + * @returns A throttled version of the function with cancel, flush, and pending methods + * + * @example + * // Throttle scroll handler to run at most every 100ms + * const handleScroll = throttle(() => { + * updateScrollPosition(); + * }, 100); + * + * window.addEventListener('scroll', handleScroll); + * + * @example + * // Throttle with leading edge only (no trailing call) + * const submitOnce = throttle(() => { + * submitForm(); + * }, 1000, { trailing: false }); + */ +export function throttle unknown>( + fn: T, + wait: number, + options: ThrottleOptions = {} +): DebouncedFunction { + const { leading = true, trailing = true } = options; + + return debounce(fn, wait, { + leading, + trailing, + maxWait: wait, + }); +} diff --git a/libs/utils/src/index.ts b/libs/utils/src/index.ts index e5e7ea16..4a2b2dd6 100644 --- a/libs/utils/src/index.ts +++ b/libs/utils/src/index.ts @@ -105,3 +105,12 @@ export { type LearningEntry, type SimpleMemoryFile, } from './memory-loader.js'; + +// Debounce and throttle utilities +export { + debounce, + throttle, + type DebounceOptions, + type ThrottleOptions, + type DebouncedFunction, +} from './debounce.js'; diff --git a/libs/utils/tests/debounce.test.ts b/libs/utils/tests/debounce.test.ts new file mode 100644 index 00000000..bf04f6b0 --- /dev/null +++ b/libs/utils/tests/debounce.test.ts @@ -0,0 +1,330 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { debounce, throttle } from '../src/debounce.js'; + +describe('debounce', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('should delay function execution', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + expect(fn).not.toHaveBeenCalled(); + + vi.advanceTimersByTime(50); + expect(fn).not.toHaveBeenCalled(); + + vi.advanceTimersByTime(50); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should reset timer on subsequent calls', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + vi.advanceTimersByTime(50); + debounced(); // Reset timer + vi.advanceTimersByTime(50); + expect(fn).not.toHaveBeenCalled(); + + vi.advanceTimersByTime(50); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should pass arguments to the function', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced('arg1', 'arg2'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledWith('arg1', 'arg2'); + }); + + it('should use the latest arguments when called multiple times', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced('first'); + debounced('second'); + debounced('third'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('third'); + }); + + describe('leading option', () => { + it('should call function immediately when leading is true', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100, { leading: true }); + + debounced(); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should not call again until wait time has passed', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100, { leading: true, trailing: false }); + + debounced(); + debounced(); + debounced(); + expect(fn).toHaveBeenCalledTimes(1); + + vi.advanceTimersByTime(100); + debounced(); + expect(fn).toHaveBeenCalledTimes(2); + }); + + it('should call both leading and trailing when both are true', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100, { leading: true, trailing: true }); + + debounced('leading'); + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenLastCalledWith('leading'); + + debounced('trailing'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(2); + expect(fn).toHaveBeenLastCalledWith('trailing'); + }); + }); + + describe('trailing option', () => { + it('should not call on trailing edge when trailing is false', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100, { trailing: false }); + + debounced(); + vi.advanceTimersByTime(100); + + expect(fn).not.toHaveBeenCalled(); + }); + }); + + describe('maxWait option', () => { + it('should invoke function after maxWait even with continuous calls', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100, { maxWait: 200 }); + + // Call continuously every 50ms + debounced(); + vi.advanceTimersByTime(50); + debounced(); + vi.advanceTimersByTime(50); + debounced(); + vi.advanceTimersByTime(50); + debounced(); + vi.advanceTimersByTime(50); + + // After 200ms, maxWait should trigger + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should throw error if maxWait is less than wait', () => { + const fn = vi.fn(); + expect(() => debounce(fn, 100, { maxWait: 50 })).toThrow( + 'maxWait must be greater than or equal to wait' + ); + }); + }); + + describe('cancel method', () => { + it('should cancel pending invocation', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + debounced.cancel(); + vi.advanceTimersByTime(100); + + expect(fn).not.toHaveBeenCalled(); + }); + + it('should reset state after cancel', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced('first'); + debounced.cancel(); + debounced('second'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('second'); + }); + }); + + describe('flush method', () => { + it('should immediately invoke pending function', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced('value'); + expect(fn).not.toHaveBeenCalled(); + + debounced.flush(); + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('value'); + }); + + it('should not invoke if no pending call', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced.flush(); + expect(fn).not.toHaveBeenCalled(); + }); + + it('should cancel timer after flush', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + debounced.flush(); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(1); + }); + }); + + describe('pending method', () => { + it('should return true when there is a pending invocation', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + expect(debounced.pending()).toBe(false); + debounced(); + expect(debounced.pending()).toBe(true); + }); + + it('should return false after invocation', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + vi.advanceTimersByTime(100); + expect(debounced.pending()).toBe(false); + }); + + it('should return false after cancel', () => { + const fn = vi.fn(); + const debounced = debounce(fn, 100); + + debounced(); + debounced.cancel(); + expect(debounced.pending()).toBe(false); + }); + }); +}); + +describe('throttle', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('should invoke function immediately by default', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100); + + throttled(); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should not invoke again before wait time', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100); + + throttled(); + throttled(); + throttled(); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should invoke on trailing edge with latest args', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100); + + throttled('first'); + expect(fn).toHaveBeenCalledWith('first'); + + throttled('second'); + throttled('third'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(2); + expect(fn).toHaveBeenLastCalledWith('third'); + }); + + it('should respect leading option', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100, { leading: false }); + + throttled(); + expect(fn).not.toHaveBeenCalled(); + + vi.advanceTimersByTime(100); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it('should respect trailing option', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100, { trailing: false }); + + throttled('first'); + throttled('second'); + vi.advanceTimersByTime(100); + + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('first'); + }); + + it('should invoke at regular intervals during continuous calls', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100); + + // Simulate continuous calls every 25ms for 250ms + for (let i = 0; i < 10; i++) { + throttled(i); + vi.advanceTimersByTime(25); + } + + // Should be called at: 0ms (leading), 100ms, 200ms + // Plus one trailing call after the loop + expect(fn.mock.calls.length).toBeGreaterThanOrEqual(3); + }); + + it('should have cancel, flush, and pending methods', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100); + + expect(typeof throttled.cancel).toBe('function'); + expect(typeof throttled.flush).toBe('function'); + expect(typeof throttled.pending).toBe('function'); + }); + + it('should cancel pending invocation', () => { + const fn = vi.fn(); + const throttled = throttle(fn, 100, { leading: false }); + + throttled(); + throttled.cancel(); + vi.advanceTimersByTime(100); + + expect(fn).not.toHaveBeenCalled(); + }); +});