diff --git a/security.py b/security.py index c15f3dc..195aa92 100644 --- a/security.py +++ b/security.py @@ -219,23 +219,37 @@ def extract_commands(command_string: str) -> list[str]: return commands -def validate_pkill_command(command_string: str) -> tuple[bool, str]: +# Default pkill process names (hardcoded baseline, always available) +DEFAULT_PKILL_PROCESSES = { + "node", + "npm", + "npx", + "vite", + "next", +} + + +def validate_pkill_command( + command_string: str, + extra_processes: Optional[set[str]] = None +) -> tuple[bool, str]: """ Validate pkill commands - only allow killing dev-related processes. Uses shlex to parse the command, avoiding regex bypass vulnerabilities. + Args: + command_string: The pkill command to validate + extra_processes: Optional set of additional process names to allow + (from org/project config pkill_processes) + Returns: Tuple of (is_allowed, reason_if_blocked) """ - # Allowed process names for pkill - allowed_process_names = { - "node", - "npm", - "npx", - "vite", - "next", - } + # Merge default processes with any extra configured processes + allowed_process_names = DEFAULT_PKILL_PROCESSES.copy() + if extra_processes: + allowed_process_names |= extra_processes try: tokens = shlex.split(command_string) @@ -264,7 +278,7 @@ def validate_pkill_command(command_string: str) -> tuple[bool, str]: if target in allowed_process_names: return True, "" - return False, f"pkill only allowed for dev processes: {allowed_process_names}" + return False, f"pkill only allowed for processes: {sorted(allowed_process_names)}" def validate_chmod_command(command_string: str) -> tuple[bool, str]: @@ -455,6 +469,15 @@ def load_org_config() -> Optional[dict]: if not isinstance(cmd, str): return None + # Validate pkill_processes if present + if "pkill_processes" in config: + processes = config["pkill_processes"] + if not isinstance(processes, list): + return None + for proc in processes: + if not isinstance(proc, str) or proc.strip() == "": + return None + return config except (yaml.YAMLError, IOError, OSError): @@ -508,6 +531,15 @@ def load_project_commands(project_dir: Path) -> Optional[dict]: if not isinstance(cmd["name"], str): return None + # Validate pkill_processes if present + if "pkill_processes" in config: + processes = config["pkill_processes"] + if not isinstance(processes, list): + return None + for proc in processes: + if not isinstance(proc, str) or proc.strip() == "": + return None + return config except (yaml.YAMLError, IOError, OSError): @@ -628,6 +660,42 @@ def get_project_allowed_commands(project_dir: Optional[Path]) -> set[str]: return allowed +def get_effective_pkill_processes(project_dir: Optional[Path]) -> set[str]: + """ + Get effective pkill process names after hierarchy resolution. + + Merges processes from: + 1. DEFAULT_PKILL_PROCESSES (hardcoded baseline) + 2. Org config pkill_processes + 3. Project config pkill_processes + + Args: + project_dir: Path to the project directory, or None + + Returns: + Set of allowed process names for pkill + """ + # Start with default processes + processes = DEFAULT_PKILL_PROCESSES.copy() + + # Add org-level pkill_processes + org_config = load_org_config() + if org_config: + org_processes = org_config.get("pkill_processes", []) + if isinstance(org_processes, list): + processes |= {p for p in org_processes if isinstance(p, str) and p.strip()} + + # Add project-level pkill_processes + if project_dir: + project_config = load_project_commands(project_dir) + if project_config: + proj_processes = project_config.get("pkill_processes", []) + if isinstance(proj_processes, list): + processes |= {p for p in proj_processes if isinstance(p, str) and p.strip()} + + return processes + + def is_command_allowed(command: str, allowed_commands: set[str]) -> bool: """ Check if a command is allowed (supports patterns). @@ -692,6 +760,9 @@ async def bash_security_hook(input_data, tool_use_id=None, context=None): # Get effective commands using hierarchy resolution allowed_commands, blocked_commands = get_effective_commands(project_dir) + # Get effective pkill processes (includes org/project config) + pkill_processes = get_effective_pkill_processes(project_dir) + # Split into segments for per-command validation segments = split_command_segments(command) @@ -725,7 +796,9 @@ async def bash_security_hook(input_data, tool_use_id=None, context=None): cmd_segment = command # Fallback to full command if cmd == "pkill": - allowed, reason = validate_pkill_command(cmd_segment) + # Pass configured extra processes (beyond defaults) + extra_procs = pkill_processes - DEFAULT_PKILL_PROCESSES + allowed, reason = validate_pkill_command(cmd_segment, extra_procs if extra_procs else None) if not allowed: return {"decision": "block", "reason": reason} elif cmd == "chmod": diff --git a/test_security.py b/test_security.py index 5b46cfe..5bd1867 100644 --- a/test_security.py +++ b/test_security.py @@ -15,14 +15,17 @@ from contextlib import contextmanager from pathlib import Path from security import ( + DEFAULT_PKILL_PROCESSES, bash_security_hook, extract_commands, get_effective_commands, + get_effective_pkill_processes, load_org_config, load_project_commands, matches_pattern, validate_chmod_command, validate_init_script, + validate_pkill_command, validate_project_command, ) @@ -670,6 +673,145 @@ blocked_commands: return passed, failed +def test_pkill_extensibility(): + """Test that pkill processes can be extended via config.""" + print("\nTesting pkill process extensibility:\n") + passed = 0 + failed = 0 + + # Test 1: Default processes work without config + allowed, reason = validate_pkill_command("pkill node") + if allowed: + print(" PASS: Default process 'node' allowed") + passed += 1 + else: + print(f" FAIL: Default process 'node' should be allowed: {reason}") + failed += 1 + + # Test 2: Non-default process blocked without config + allowed, reason = validate_pkill_command("pkill python") + if not allowed: + print(" PASS: Non-default process 'python' blocked without config") + passed += 1 + else: + print(" FAIL: Non-default process 'python' should be blocked without config") + failed += 1 + + # Test 3: Extra processes allowed when passed + allowed, reason = validate_pkill_command("pkill python", extra_processes={"python"}) + if allowed: + print(" PASS: Extra process 'python' allowed when configured") + passed += 1 + else: + print(f" FAIL: Extra process 'python' should be allowed when configured: {reason}") + failed += 1 + + # Test 4: Default processes still work with extra processes + allowed, reason = validate_pkill_command("pkill npm", extra_processes={"python"}) + if allowed: + print(" PASS: Default process 'npm' still works with extra processes") + passed += 1 + else: + print(f" FAIL: Default process should still work: {reason}") + failed += 1 + + # Test 5: Test get_effective_pkill_processes with org config + with tempfile.TemporaryDirectory() as tmphome: + with tempfile.TemporaryDirectory() as tmpproject: + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" + + # Create org config with extra pkill processes + org_config_path.write_text("""version: 1 +pkill_processes: + - python + - uvicorn +""") + + project_dir = Path(tmpproject) + processes = get_effective_pkill_processes(project_dir) + + # Should include defaults + org processes + if "node" in processes and "python" in processes and "uvicorn" in processes: + print(" PASS: Org pkill_processes merged with defaults") + passed += 1 + else: + print(f" FAIL: Expected node, python, uvicorn in {processes}") + failed += 1 + + # Test 6: Test get_effective_pkill_processes with project config + with tempfile.TemporaryDirectory() as tmphome: + with tempfile.TemporaryDirectory() as tmpproject: + with temporary_home(tmphome): + project_dir = Path(tmpproject) + project_autocoder = project_dir / ".autocoder" + project_autocoder.mkdir() + project_config = project_autocoder / "allowed_commands.yaml" + + # Create project config with extra pkill processes + project_config.write_text("""version: 1 +commands: [] +pkill_processes: + - gunicorn + - flask +""") + + processes = get_effective_pkill_processes(project_dir) + + # Should include defaults + project processes + if "node" in processes and "gunicorn" in processes and "flask" in processes: + print(" PASS: Project pkill_processes merged with defaults") + passed += 1 + else: + print(f" FAIL: Expected node, gunicorn, flask in {processes}") + failed += 1 + + # Test 7: Integration test - pkill python blocked by default + with tempfile.TemporaryDirectory() as tmphome: + with tempfile.TemporaryDirectory() as tmpproject: + with temporary_home(tmphome): + project_dir = Path(tmpproject) + input_data = {"tool_name": "Bash", "tool_input": {"command": "pkill python"}} + context = {"project_dir": str(project_dir)} + result = asyncio.run(bash_security_hook(input_data, context=context)) + + if result.get("decision") == "block": + print(" PASS: pkill python blocked without config") + passed += 1 + else: + print(" FAIL: pkill python should be blocked without config") + failed += 1 + + # Test 8: Integration test - pkill python allowed with org config + with tempfile.TemporaryDirectory() as tmphome: + with tempfile.TemporaryDirectory() as tmpproject: + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" + + org_config_path.write_text("""version: 1 +pkill_processes: + - python +""") + + project_dir = Path(tmpproject) + input_data = {"tool_name": "Bash", "tool_input": {"command": "pkill python"}} + context = {"project_dir": str(project_dir)} + result = asyncio.run(bash_security_hook(input_data, context=context)) + + if result.get("decision") != "block": + print(" PASS: pkill python allowed with org config") + passed += 1 + else: + print(f" FAIL: pkill python should be allowed with org config: {result}") + failed += 1 + + return passed, failed + + def main(): print("=" * 70) print(" SECURITY HOOK TESTS") @@ -733,6 +875,11 @@ def main(): passed += org_block_passed failed += org_block_failed + # Test pkill process extensibility + pkill_passed, pkill_failed = test_pkill_extensibility() + passed += pkill_passed + failed += pkill_failed + # Commands that SHOULD be blocked print("\nCommands that should be BLOCKED:\n") dangerous = [