diff --git a/security.py b/security.py
index c15f3dc..44507a4 100644
--- a/security.py
+++ b/security.py
@@ -7,12 +7,17 @@ Uses an allowlist approach - only explicitly permitted commands can run.
 """
 
 import os
+import re
 import shlex
 from pathlib import Path
 from typing import Optional
 
 import yaml
 
+# Regex pattern for valid pkill process names (no regex metacharacters allowed)
+# Matches alphanumeric names with dots, underscores, and hyphens
+VALID_PROCESS_NAME_PATTERN = re.compile(r"^[A-Za-z0-9._-]+$")
+
 # Allowed commands for development tasks
 # Minimal set needed for the autonomous coding demo
 ALLOWED_COMMANDS = {
@@ -219,23 +224,37 @@ def extract_commands(command_string: str) -> list[str]:
     return commands
 
 
-def validate_pkill_command(command_string: str) -> tuple[bool, str]:
+# Default pkill process names (hardcoded baseline, always available)
+DEFAULT_PKILL_PROCESSES = {
+    "node",
+    "npm",
+    "npx",
+    "vite",
+    "next",
+}
+
+
+def validate_pkill_command(
+    command_string: str,
+    extra_processes: Optional[set[str]] = None
+) -> tuple[bool, str]:
     """
     Validate pkill commands - only allow killing dev-related processes.
 
     Uses shlex to parse the command, avoiding regex bypass vulnerabilities.
 
+    Args:
+        command_string: The pkill command to validate
+        extra_processes: Optional set of additional process names to allow
+                         (from org/project config pkill_processes)
+
     Returns:
         Tuple of (is_allowed, reason_if_blocked)
     """
-    # Allowed process names for pkill
-    allowed_process_names = {
-        "node",
-        "npm",
-        "npx",
-        "vite",
-        "next",
-    }
+    # Merge default processes with any extra configured processes
+    allowed_process_names = DEFAULT_PKILL_PROCESSES.copy()
+    if extra_processes:
+        allowed_process_names |= extra_processes
 
     try:
         tokens = shlex.split(command_string)
@@ -254,17 +273,19 @@ def validate_pkill_command(command_string: str) -> tuple[bool, str]:
     if not args:
         return False, "pkill requires a process name"
 
-    # The target is typically the last non-flag argument
-    target = args[-1]
+    # Validate every non-flag argument (pkill accepts multiple patterns on BSD)
+    # This defensively ensures no disallowed process can be targeted
+    targets = []
+    for arg in args:
+        # For -f flag (full command line match), take the first word as process name
+        # e.g., "pkill -f 'node server.js'" -> target is "node server.js", process is "node"
+        t = arg.split()[0] if " " in arg else arg
+        targets.append(t)
 
-    # For -f flag (full command line match), extract the first word as process name
-    # e.g., "pkill -f 'node server.js'" -> target is "node server.js", process is "node"
-    if " " in target:
-        target = target.split()[0]
-
-    if target in allowed_process_names:
+    disallowed = [t for t in targets if t not in allowed_process_names]
+    if not disallowed:
         return True, ""
-    return False, f"pkill only allowed for dev processes: {allowed_process_names}"
+    return False, f"pkill only allowed for processes: {sorted(allowed_process_names)}"
 
 
 def validate_chmod_command(command_string: str) -> tuple[bool, str]:
@@ -455,6 +476,23 @@ def load_org_config() -> Optional[dict]:
                 if not isinstance(cmd, str):
                     return None
 
+        # Validate pkill_processes if present
+        if "pkill_processes" in config:
+            processes = config["pkill_processes"]
+            if not isinstance(processes, list):
+                return None
+            # Normalize and validate each process name against safe pattern
+            normalized = []
+            for proc in processes:
+                if not isinstance(proc, str):
+                    return None
+                proc = proc.strip()
+                # Block empty strings and regex metacharacters
+                if not proc or not VALID_PROCESS_NAME_PATTERN.fullmatch(proc):
+                    return None
+                normalized.append(proc)
+            config["pkill_processes"] = normalized
+
         return config
 
     except (yaml.YAMLError, IOError, OSError):
@@ -508,6 +546,23 @@ def load_project_commands(project_dir: Path) -> Optional[dict]:
             if not isinstance(cmd["name"], str):
                 return None
 
+        # Validate pkill_processes if present
+        if "pkill_processes" in config:
+            processes = config["pkill_processes"]
+            if not isinstance(processes, list):
+                return None
+            # Normalize and validate each process name against safe pattern
+            normalized = []
+            for proc in processes:
+                if not isinstance(proc, str):
+                    return None
+                proc = proc.strip()
+                # Block empty strings and regex metacharacters
+                if not proc or not VALID_PROCESS_NAME_PATTERN.fullmatch(proc):
+                    return None
+                normalized.append(proc)
+            config["pkill_processes"] = normalized
+
         return config
 
     except (yaml.YAMLError, IOError, OSError):
@@ -628,6 +683,42 @@ def get_project_allowed_commands(project_dir: Optional[Path]) -> set[str]:
     return allowed
 
 
+def get_effective_pkill_processes(project_dir: Optional[Path]) -> set[str]:
+    """
+    Get effective pkill process names after hierarchy resolution.
+
+    Merges processes from:
+    1. DEFAULT_PKILL_PROCESSES (hardcoded baseline)
+    2. Org config pkill_processes
+    3. Project config pkill_processes
+
+    Args:
+        project_dir: Path to the project directory, or None
+
+    Returns:
+        Set of allowed process names for pkill
+    """
+    # Start with default processes
+    processes = DEFAULT_PKILL_PROCESSES.copy()
+
+    # Add org-level pkill_processes
+    org_config = load_org_config()
+    if org_config:
+        org_processes = org_config.get("pkill_processes", [])
+        if isinstance(org_processes, list):
+            processes |= {p for p in org_processes if isinstance(p, str) and p.strip()}
+
+    # Add project-level pkill_processes
+    if project_dir:
+        project_config = load_project_commands(project_dir)
+        if project_config:
+            proj_processes = project_config.get("pkill_processes", [])
+            if isinstance(proj_processes, list):
+                processes |= {p for p in proj_processes if isinstance(p, str) and p.strip()}
+
+    return processes
+
+
 def is_command_allowed(command: str, allowed_commands: set[str]) -> bool:
     """
     Check if a command is allowed (supports patterns).
@@ -692,6 +783,9 @@ async def bash_security_hook(input_data, tool_use_id=None, context=None):
     # Get effective commands using hierarchy resolution
     allowed_commands, blocked_commands = get_effective_commands(project_dir)
 
+    # Get effective pkill processes (includes org/project config)
+    pkill_processes = get_effective_pkill_processes(project_dir)
+
     # Split into segments for per-command validation
     segments = split_command_segments(command)
 
@@ -725,7 +819,9 @@ async def bash_security_hook(input_data, tool_use_id=None, context=None):
                 cmd_segment = command  # Fallback to full command
 
             if cmd == "pkill":
-                allowed, reason = validate_pkill_command(cmd_segment)
+                # Pass configured extra processes (beyond defaults)
+                extra_procs = pkill_processes - DEFAULT_PKILL_PROCESSES
+                allowed, reason = validate_pkill_command(cmd_segment, extra_procs if extra_procs else None)
                 if not allowed:
                     return {"decision": "block", "reason": reason}
             elif cmd == "chmod":
diff --git a/test_security.py b/test_security.py
index 5b46cfe..5068e1e 100644
--- a/test_security.py
+++ b/test_security.py
@@ -15,14 +15,17 @@ from contextlib import contextmanager
 from pathlib import Path
 
 from security import (
+    DEFAULT_PKILL_PROCESSES,
     bash_security_hook,
     extract_commands,
     get_effective_commands,
+    get_effective_pkill_processes,
     load_org_config,
     load_project_commands,
     matches_pattern,
     validate_chmod_command,
     validate_init_script,
+    validate_pkill_command,
     validate_project_command,
 )
 
@@ -670,6 +673,240 @@ blocked_commands:
     return passed, failed
 
 
+def test_pkill_extensibility():
+    """Test that pkill processes can be extended via config."""
+    print("\nTesting pkill process extensibility:\n")
+    passed = 0
+    failed = 0
+
+    # Test 1: Default processes work without config
+    allowed, reason = validate_pkill_command("pkill node")
+    if allowed:
+        print(" PASS: Default process 'node' allowed")
+        passed += 1
+    else:
+        print(f" FAIL: Default process 'node' should be allowed: {reason}")
+        failed += 1
+
+    # Test 2: Non-default process blocked without config
+    allowed, reason = validate_pkill_command("pkill python")
+    if not allowed:
+        print(" PASS: Non-default process 'python' blocked without config")
+        passed += 1
+    else:
+        print(" FAIL: Non-default process 'python' should be blocked without config")
+        failed += 1
+
+    # Test 3: Extra processes allowed when passed
+    allowed, reason = validate_pkill_command("pkill python", extra_processes={"python"})
+    if allowed:
+        print(" PASS: Extra process 'python' allowed when configured")
+        passed += 1
+    else:
+        print(f" FAIL: Extra process 'python' should be allowed when configured: {reason}")
+        failed += 1
+
+    # Test 4: Default processes still work with extra processes
+    allowed, reason = validate_pkill_command("pkill npm", extra_processes={"python"})
+    if allowed:
+        print(" PASS: Default process 'npm' still works with extra processes")
+        passed += 1
+    else:
+        print(f" FAIL: Default process should still work: {reason}")
+        failed += 1
+
+    # Test 5: Test get_effective_pkill_processes with org config
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                org_dir = Path(tmphome) / ".autocoder"
+                org_dir.mkdir()
+                org_config_path = org_dir / "config.yaml"
+
+                # Create org config with extra pkill processes
+                org_config_path.write_text("""version: 1
+pkill_processes:
+  - python
+  - uvicorn
+""")
+
+                project_dir = Path(tmpproject)
+                processes = get_effective_pkill_processes(project_dir)
+
+                # Should include defaults + org processes
+                if "node" in processes and "python" in processes and "uvicorn" in processes:
+                    print(" PASS: Org pkill_processes merged with defaults")
+                    passed += 1
+                else:
+                    print(f" FAIL: Expected node, python, uvicorn in {processes}")
+                    failed += 1
+
+    # Test 6: Test get_effective_pkill_processes with project config
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                project_dir = Path(tmpproject)
+                project_autocoder = project_dir / ".autocoder"
+                project_autocoder.mkdir()
+                project_config = project_autocoder / "allowed_commands.yaml"
+
+                # Create project config with extra pkill processes
+                project_config.write_text("""version: 1
+commands: []
+pkill_processes:
+  - gunicorn
+  - flask
+""")
+
+                processes = get_effective_pkill_processes(project_dir)
+
+                # Should include defaults + project processes
+                if "node" in processes and "gunicorn" in processes and "flask" in processes:
+                    print(" PASS: Project pkill_processes merged with defaults")
+                    passed += 1
+                else:
+                    print(f" FAIL: Expected node, gunicorn, flask in {processes}")
+                    failed += 1
+
+    # Test 7: Integration test - pkill python blocked by default
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                project_dir = Path(tmpproject)
+                input_data = {"tool_name": "Bash", "tool_input": {"command": "pkill python"}}
+                context = {"project_dir": str(project_dir)}
+                result = asyncio.run(bash_security_hook(input_data, context=context))
+
+                if result.get("decision") == "block":
+                    print(" PASS: pkill python blocked without config")
+                    passed += 1
+                else:
+                    print(" FAIL: pkill python should be blocked without config")
+                    failed += 1
+
+    # Test 8: Integration test - pkill python allowed with org config
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                org_dir = Path(tmphome) / ".autocoder"
+                org_dir.mkdir()
+                org_config_path = org_dir / "config.yaml"
+
+                org_config_path.write_text("""version: 1
+pkill_processes:
+  - python
+""")
+
+                project_dir = Path(tmpproject)
+                input_data = {"tool_name": "Bash", "tool_input": {"command": "pkill python"}}
+                context = {"project_dir": str(project_dir)}
+                result = asyncio.run(bash_security_hook(input_data, context=context))
+
+                if result.get("decision") != "block":
+                    print(" PASS: pkill python allowed with org config")
+                    passed += 1
+                else:
+                    print(f" FAIL: pkill python should be allowed with org config: {result}")
+                    failed += 1
+
+    # Test 9: Regex metacharacters should be rejected in pkill_processes
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                org_dir = Path(tmphome) / ".autocoder"
+                org_dir.mkdir()
+                org_config_path = org_dir / "config.yaml"
+
+                # Try to register a regex pattern (should be rejected)
+                org_config_path.write_text("""version: 1
+pkill_processes:
+  - ".*"
+""")
+
+                config = load_org_config()
+                if config is None:
+                    print(" PASS: Regex pattern '.*' rejected in pkill_processes")
+                    passed += 1
+                else:
+                    print(" FAIL: Regex pattern '.*' should be rejected")
+                    failed += 1
+
+    # Test 10: Valid process names with dots/underscores/hyphens should be accepted
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                org_dir = Path(tmphome) / ".autocoder"
+                org_dir.mkdir()
+                org_config_path = org_dir / "config.yaml"
+
+                # Valid names with special chars
+                org_config_path.write_text("""version: 1
+pkill_processes:
+  - my-app
+  - app_server
+  - node.js
+""")
+
+                config = load_org_config()
+                if config is not None and config.get("pkill_processes") == ["my-app", "app_server", "node.js"]:
+                    print(" PASS: Valid process names with dots/underscores/hyphens accepted")
+                    passed += 1
+                else:
+                    print(f" FAIL: Valid process names should be accepted: {config}")
+                    failed += 1
+
+    # Test 11: Names with spaces should be rejected
+    with tempfile.TemporaryDirectory() as tmphome:
+        with tempfile.TemporaryDirectory() as tmpproject:
+            with temporary_home(tmphome):
+                org_dir = Path(tmphome) / ".autocoder"
+                org_dir.mkdir()
+                org_config_path = org_dir / "config.yaml"
+
+                org_config_path.write_text("""version: 1
+pkill_processes:
+  - "my app"
+""")
+
+                config = load_org_config()
+                if config is None:
+                    print(" PASS: Process name with space rejected")
+                    passed += 1
+                else:
+                    print(" FAIL: Process name with space should be rejected")
+                    failed += 1
+
+    # Test 12: Multiple patterns - all must be allowed (BSD behavior)
+    # On BSD, "pkill node sshd" would kill both, so we must validate all patterns
+    allowed, reason = validate_pkill_command("pkill node npm")
+    if allowed:
+        print(" PASS: Multiple allowed patterns accepted")
+        passed += 1
+    else:
+        print(f" FAIL: Multiple allowed patterns should be accepted: {reason}")
+        failed += 1
+
+    # Test 13: Multiple patterns - block if any is disallowed
+    allowed, reason = validate_pkill_command("pkill node sshd")
+    if not allowed:
+        print(" PASS: Multiple patterns blocked when one is disallowed")
+        passed += 1
+    else:
+        print(" FAIL: Should block when any pattern is disallowed")
+        failed += 1
+
+    # Test 14: Multiple patterns - only first allowed, second disallowed
+    allowed, reason = validate_pkill_command("pkill npm python")
+    if not allowed:
+        print(" PASS: Multiple patterns blocked (first allowed, second not)")
+        passed += 1
+    else:
+        print(" FAIL: Should block when second pattern is disallowed")
+        failed += 1
+
+    return passed, failed
+
+
 def main():
     print("=" * 70)
     print(" SECURITY HOOK TESTS")
@@ -733,6 +970,11 @@ def main():
     passed += org_block_passed
     failed += org_block_failed
 
+    # Test pkill process extensibility
+    pkill_passed, pkill_failed = test_pkill_extensibility()
+    passed += pkill_passed
+    failed += pkill_failed
+
     # Commands that SHOULD be blocked
     print("\nCommands that should be BLOCKED:\n")
     dangerous = [