diff --git a/test_security.py b/test_security.py index 985a1d9..5b46cfe 100644 --- a/test_security.py +++ b/test_security.py @@ -8,8 +8,10 @@ Run with: python test_security.py """ import asyncio +import os import sys import tempfile +from contextlib import contextmanager from pathlib import Path from security import ( @@ -25,6 +27,48 @@ from security import ( ) +@contextmanager +def temporary_home(home_path): + """ + Context manager to temporarily set HOME (and Windows equivalents). + + Saves original environment variables and restores them on exit, + even if an exception occurs. + + Args: + home_path: Path to use as temporary home directory + """ + # Save original values for Unix and Windows + saved_env = { + "HOME": os.environ.get("HOME"), + "USERPROFILE": os.environ.get("USERPROFILE"), + "HOMEDRIVE": os.environ.get("HOMEDRIVE"), + "HOMEPATH": os.environ.get("HOMEPATH"), + } + + try: + # Set new home directory for both Unix and Windows + os.environ["HOME"] = str(home_path) + if sys.platform == "win32": + os.environ["USERPROFILE"] = str(home_path) + # Note: HOMEDRIVE and HOMEPATH are typically set by Windows + # but we update them for consistency + drive, path = os.path.splitdrive(str(home_path)) + if drive: + os.environ["HOMEDRIVE"] = drive + os.environ["HOMEPATH"] = path + + yield + + finally: + # Restore original values + for key, value in saved_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def check_hook(command: str, should_block: bool) -> bool: """Check a single command against the security hook (helper function).""" input_data = {"tool_name": "Bash", "tool_input": {"command": command}} @@ -416,20 +460,15 @@ def test_org_config_loading(): passed = 0 failed = 0 - # Save original org config path - original_home = Path.home() - with tempfile.TemporaryDirectory() as tmpdir: - # Temporarily override home directory for testing - import os - os.environ["HOME"] = tmpdir + # Use temporary_home for cross-platform compatibility + with temporary_home(tmpdir): + org_dir = Path(tmpdir) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmpdir) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Test 1: Valid org config - org_config_path.write_text("""version: 1 + # Test 1: Valid org config + org_config_path.write_text("""version: 1 allowed_commands: - name: jq description: JSON processor @@ -437,76 +476,73 @@ blocked_commands: - aws - kubectl """) - config = load_org_config() - if config and config["version"] == 1: - if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2: - print(" PASS: Load valid org config") + config = load_org_config() + if config and config["version"] == 1: + if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2: + print(" PASS: Load valid org config") + passed += 1 + else: + print(" FAIL: Load valid org config (wrong counts)") + failed += 1 + else: + print(" FAIL: Load valid org config") + print(f" Got: {config}") + failed += 1 + + # Test 2: Missing file returns None + org_config_path.unlink() + config = load_org_config() + if config is None: + print(" PASS: Missing org config returns None") passed += 1 else: - print(" FAIL: Load valid org config (wrong counts)") + print(" FAIL: Missing org config returns None") failed += 1 - else: - print(" FAIL: Load valid org config") - print(f" Got: {config}") - failed += 1 - # Test 2: Missing file returns None - org_config_path.unlink() - config = load_org_config() - if config is None: - print(" PASS: Missing org config returns None") - passed += 1 - else: - print(" FAIL: Missing org config returns None") - failed += 1 - - # Test 3: Non-string command name is rejected - org_config_path.write_text("""version: 1 + # Test 3: Non-string command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: 123 description: Invalid numeric name """) - config = load_org_config() - if config is None: - print(" PASS: Non-string command name rejected") - passed += 1 - else: - print(" FAIL: Non-string command name rejected") - print(f" Got: {config}") - failed += 1 + config = load_org_config() + if config is None: + print(" PASS: Non-string command name rejected") + passed += 1 + else: + print(" FAIL: Non-string command name rejected") + print(f" Got: {config}") + failed += 1 - # Test 4: Empty command name is rejected - org_config_path.write_text("""version: 1 + # Test 4: Empty command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: "" description: Empty name """) - config = load_org_config() - if config is None: - print(" PASS: Empty command name rejected") - passed += 1 - else: - print(" FAIL: Empty command name rejected") - print(f" Got: {config}") - failed += 1 + config = load_org_config() + if config is None: + print(" PASS: Empty command name rejected") + passed += 1 + else: + print(" FAIL: Empty command name rejected") + print(f" Got: {config}") + failed += 1 - # Test 5: Whitespace-only command name is rejected - org_config_path.write_text("""version: 1 + # Test 5: Whitespace-only command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: " " description: Whitespace name """) - config = load_org_config() - if config is None: - print(" PASS: Whitespace-only command name rejected") - passed += 1 - else: - print(" FAIL: Whitespace-only command name rejected") - print(f" Got: {config}") - failed += 1 - - # Restore HOME - os.environ["HOME"] = str(original_home) + config = load_org_config() + if config is None: + print(" PASS: Whitespace-only command name rejected") + passed += 1 + else: + print(" FAIL: Whitespace-only command name rejected") + print(f" Got: {config}") + failed += 1 return passed, failed @@ -519,17 +555,14 @@ def test_hierarchy_resolution(): with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmpproject: - # Setup fake home directory - import os - original_home = os.environ.get("HOME") - os.environ["HOME"] = tmphome + # Use temporary_home for cross-platform compatibility + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmphome) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Create org config with allowed and blocked commands - org_config_path.write_text("""version: 1 + # Create org config with allowed and blocked commands + org_config_path.write_text("""version: 1 allowed_commands: - name: jq description: JSON processor @@ -540,66 +573,60 @@ blocked_commands: - kubectl """) - project_dir = Path(tmpproject) - project_autocoder = project_dir / ".autocoder" - project_autocoder.mkdir() - project_config = project_autocoder / "allowed_commands.yaml" + project_dir = Path(tmpproject) + project_autocoder = project_dir / ".autocoder" + project_autocoder.mkdir() + project_config = project_autocoder / "allowed_commands.yaml" - # Create project config - project_config.write_text("""version: 1 + # Create project config + project_config.write_text("""version: 1 commands: - name: swift description: Swift compiler """) - # Test 1: Org allowed commands are included - allowed, blocked = get_effective_commands(project_dir) - if "jq" in allowed and "python3" in allowed: - print(" PASS: Org allowed commands included") - passed += 1 - else: - print(" FAIL: Org allowed commands included") - print(f" jq in allowed: {'jq' in allowed}") - print(f" python3 in allowed: {'python3' in allowed}") - failed += 1 + # Test 1: Org allowed commands are included + allowed, blocked = get_effective_commands(project_dir) + if "jq" in allowed and "python3" in allowed: + print(" PASS: Org allowed commands included") + passed += 1 + else: + print(" FAIL: Org allowed commands included") + print(f" jq in allowed: {'jq' in allowed}") + print(f" python3 in allowed: {'python3' in allowed}") + failed += 1 - # Test 2: Org blocked commands are in blocklist - if "terraform" in blocked and "kubectl" in blocked: - print(" PASS: Org blocked commands in blocklist") - passed += 1 - else: - print(" FAIL: Org blocked commands in blocklist") - failed += 1 + # Test 2: Org blocked commands are in blocklist + if "terraform" in blocked and "kubectl" in blocked: + print(" PASS: Org blocked commands in blocklist") + passed += 1 + else: + print(" FAIL: Org blocked commands in blocklist") + failed += 1 - # Test 3: Project commands are included - if "swift" in allowed: - print(" PASS: Project commands included") - passed += 1 - else: - print(" FAIL: Project commands included") - failed += 1 + # Test 3: Project commands are included + if "swift" in allowed: + print(" PASS: Project commands included") + passed += 1 + else: + print(" FAIL: Project commands included") + failed += 1 - # Test 4: Global commands are included - if "npm" in allowed and "git" in allowed: - print(" PASS: Global commands included") - passed += 1 - else: - print(" FAIL: Global commands included") - failed += 1 + # Test 4: Global commands are included + if "npm" in allowed and "git" in allowed: + print(" PASS: Global commands included") + passed += 1 + else: + print(" FAIL: Global commands included") + failed += 1 - # Test 5: Hardcoded blocklist cannot be overridden - if "sudo" in blocked and "shutdown" in blocked: - print(" PASS: Hardcoded blocklist enforced") - passed += 1 - else: - print(" FAIL: Hardcoded blocklist enforced") - failed += 1 - - # Restore HOME - if original_home: - os.environ["HOME"] = original_home - else: - del os.environ["HOME"] + # Test 5: Hardcoded blocklist cannot be overridden + if "sudo" in blocked and "shutdown" in blocked: + print(" PASS: Hardcoded blocklist enforced") + passed += 1 + else: + print(" FAIL: Hardcoded blocklist enforced") + failed += 1 return passed, failed @@ -612,42 +639,33 @@ def test_org_blocklist_enforcement(): with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmpproject: - # Setup fake home directory - import os - original_home = os.environ.get("HOME") - os.environ["HOME"] = tmphome + # Use temporary_home for cross-platform compatibility + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmphome) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Create org config that blocks terraform - org_config_path.write_text("""version: 1 + # Create org config that blocks terraform + org_config_path.write_text("""version: 1 blocked_commands: - terraform """) - project_dir = Path(tmpproject) - project_autocoder = project_dir / ".autocoder" - project_autocoder.mkdir() + project_dir = Path(tmpproject) + project_autocoder = project_dir / ".autocoder" + project_autocoder.mkdir() - # Try to use terraform (should be blocked) - input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}} - context = {"project_dir": str(project_dir)} - result = asyncio.run(bash_security_hook(input_data, context=context)) + # Try to use terraform (should be blocked) + input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}} + context = {"project_dir": str(project_dir)} + result = asyncio.run(bash_security_hook(input_data, context=context)) - if result.get("decision") == "block": - print(" PASS: Org blocked command 'terraform' rejected") - passed += 1 - else: - print(" FAIL: Org blocked command 'terraform' should be rejected") - failed += 1 - - # Restore HOME - if original_home: - os.environ["HOME"] = original_home - else: - del os.environ["HOME"] + if result.get("decision") == "block": + print(" PASS: Org blocked command 'terraform' rejected") + passed += 1 + else: + print(" FAIL: Org blocked command 'terraform' should be rejected") + failed += 1 return passed, failed