From b21d2e3adc223cc0670406068209d3e2cfd0f513 Mon Sep 17 00:00:00 2001 From: Auto Date: Fri, 23 Jan 2026 12:24:50 +0200 Subject: [PATCH] fix: add Windows compatibility to security unit tests Add cross-platform temporary_home() context manager to handle environment variable differences between Unix and Windows systems. Changes: - Add temporary_home() context manager that handles both HOME (Unix) and USERPROFILE/HOMEDRIVE/HOMEPATH (Windows) environment variables - Update test_org_config_loading() to use temporary_home() - Update test_hierarchy_resolution() to use temporary_home() - Update test_org_blocklist_enforcement() to use temporary_home() - Add missing imports: os, contextmanager Why: The unit tests for org config loading were failing on Windows because they only set the HOME environment variable, but Windows uses USERPROFILE instead. The integration tests already had this fix via a similar context manager. Result: All 148 unit tests now pass on both Windows and Unix systems. Co-Authored-By: Claude Opus 4.5 --- test_security.py | 324 +++++++++++++++++++++++++---------------------- 1 file changed, 171 insertions(+), 153 deletions(-) diff --git a/test_security.py b/test_security.py index 985a1d9..5b46cfe 100644 --- a/test_security.py +++ b/test_security.py @@ -8,8 +8,10 @@ Run with: python test_security.py """ import asyncio +import os import sys import tempfile +from contextlib import contextmanager from pathlib import Path from security import ( @@ -25,6 +27,48 @@ from security import ( ) +@contextmanager +def temporary_home(home_path): + """ + Context manager to temporarily set HOME (and Windows equivalents). + + Saves original environment variables and restores them on exit, + even if an exception occurs. + + Args: + home_path: Path to use as temporary home directory + """ + # Save original values for Unix and Windows + saved_env = { + "HOME": os.environ.get("HOME"), + "USERPROFILE": os.environ.get("USERPROFILE"), + "HOMEDRIVE": os.environ.get("HOMEDRIVE"), + "HOMEPATH": os.environ.get("HOMEPATH"), + } + + try: + # Set new home directory for both Unix and Windows + os.environ["HOME"] = str(home_path) + if sys.platform == "win32": + os.environ["USERPROFILE"] = str(home_path) + # Note: HOMEDRIVE and HOMEPATH are typically set by Windows + # but we update them for consistency + drive, path = os.path.splitdrive(str(home_path)) + if drive: + os.environ["HOMEDRIVE"] = drive + os.environ["HOMEPATH"] = path + + yield + + finally: + # Restore original values + for key, value in saved_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def check_hook(command: str, should_block: bool) -> bool: """Check a single command against the security hook (helper function).""" input_data = {"tool_name": "Bash", "tool_input": {"command": command}} @@ -416,20 +460,15 @@ def test_org_config_loading(): passed = 0 failed = 0 - # Save original org config path - original_home = Path.home() - with tempfile.TemporaryDirectory() as tmpdir: - # Temporarily override home directory for testing - import os - os.environ["HOME"] = tmpdir + # Use temporary_home for cross-platform compatibility + with temporary_home(tmpdir): + org_dir = Path(tmpdir) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmpdir) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Test 1: Valid org config - org_config_path.write_text("""version: 1 + # Test 1: Valid org config + org_config_path.write_text("""version: 1 allowed_commands: - name: jq description: JSON processor @@ -437,76 +476,73 @@ blocked_commands: - aws - kubectl """) - config = load_org_config() - if config and config["version"] == 1: - if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2: - print(" PASS: Load valid org config") + config = load_org_config() + if config and config["version"] == 1: + if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2: + print(" PASS: Load valid org config") + passed += 1 + else: + print(" FAIL: Load valid org config (wrong counts)") + failed += 1 + else: + print(" FAIL: Load valid org config") + print(f" Got: {config}") + failed += 1 + + # Test 2: Missing file returns None + org_config_path.unlink() + config = load_org_config() + if config is None: + print(" PASS: Missing org config returns None") passed += 1 else: - print(" FAIL: Load valid org config (wrong counts)") + print(" FAIL: Missing org config returns None") failed += 1 - else: - print(" FAIL: Load valid org config") - print(f" Got: {config}") - failed += 1 - # Test 2: Missing file returns None - org_config_path.unlink() - config = load_org_config() - if config is None: - print(" PASS: Missing org config returns None") - passed += 1 - else: - print(" FAIL: Missing org config returns None") - failed += 1 - - # Test 3: Non-string command name is rejected - org_config_path.write_text("""version: 1 + # Test 3: Non-string command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: 123 description: Invalid numeric name """) - config = load_org_config() - if config is None: - print(" PASS: Non-string command name rejected") - passed += 1 - else: - print(" FAIL: Non-string command name rejected") - print(f" Got: {config}") - failed += 1 + config = load_org_config() + if config is None: + print(" PASS: Non-string command name rejected") + passed += 1 + else: + print(" FAIL: Non-string command name rejected") + print(f" Got: {config}") + failed += 1 - # Test 4: Empty command name is rejected - org_config_path.write_text("""version: 1 + # Test 4: Empty command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: "" description: Empty name """) - config = load_org_config() - if config is None: - print(" PASS: Empty command name rejected") - passed += 1 - else: - print(" FAIL: Empty command name rejected") - print(f" Got: {config}") - failed += 1 + config = load_org_config() + if config is None: + print(" PASS: Empty command name rejected") + passed += 1 + else: + print(" FAIL: Empty command name rejected") + print(f" Got: {config}") + failed += 1 - # Test 5: Whitespace-only command name is rejected - org_config_path.write_text("""version: 1 + # Test 5: Whitespace-only command name is rejected + org_config_path.write_text("""version: 1 allowed_commands: - name: " " description: Whitespace name """) - config = load_org_config() - if config is None: - print(" PASS: Whitespace-only command name rejected") - passed += 1 - else: - print(" FAIL: Whitespace-only command name rejected") - print(f" Got: {config}") - failed += 1 - - # Restore HOME - os.environ["HOME"] = str(original_home) + config = load_org_config() + if config is None: + print(" PASS: Whitespace-only command name rejected") + passed += 1 + else: + print(" FAIL: Whitespace-only command name rejected") + print(f" Got: {config}") + failed += 1 return passed, failed @@ -519,17 +555,14 @@ def test_hierarchy_resolution(): with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmpproject: - # Setup fake home directory - import os - original_home = os.environ.get("HOME") - os.environ["HOME"] = tmphome + # Use temporary_home for cross-platform compatibility + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmphome) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Create org config with allowed and blocked commands - org_config_path.write_text("""version: 1 + # Create org config with allowed and blocked commands + org_config_path.write_text("""version: 1 allowed_commands: - name: jq description: JSON processor @@ -540,66 +573,60 @@ blocked_commands: - kubectl """) - project_dir = Path(tmpproject) - project_autocoder = project_dir / ".autocoder" - project_autocoder.mkdir() - project_config = project_autocoder / "allowed_commands.yaml" + project_dir = Path(tmpproject) + project_autocoder = project_dir / ".autocoder" + project_autocoder.mkdir() + project_config = project_autocoder / "allowed_commands.yaml" - # Create project config - project_config.write_text("""version: 1 + # Create project config + project_config.write_text("""version: 1 commands: - name: swift description: Swift compiler """) - # Test 1: Org allowed commands are included - allowed, blocked = get_effective_commands(project_dir) - if "jq" in allowed and "python3" in allowed: - print(" PASS: Org allowed commands included") - passed += 1 - else: - print(" FAIL: Org allowed commands included") - print(f" jq in allowed: {'jq' in allowed}") - print(f" python3 in allowed: {'python3' in allowed}") - failed += 1 + # Test 1: Org allowed commands are included + allowed, blocked = get_effective_commands(project_dir) + if "jq" in allowed and "python3" in allowed: + print(" PASS: Org allowed commands included") + passed += 1 + else: + print(" FAIL: Org allowed commands included") + print(f" jq in allowed: {'jq' in allowed}") + print(f" python3 in allowed: {'python3' in allowed}") + failed += 1 - # Test 2: Org blocked commands are in blocklist - if "terraform" in blocked and "kubectl" in blocked: - print(" PASS: Org blocked commands in blocklist") - passed += 1 - else: - print(" FAIL: Org blocked commands in blocklist") - failed += 1 + # Test 2: Org blocked commands are in blocklist + if "terraform" in blocked and "kubectl" in blocked: + print(" PASS: Org blocked commands in blocklist") + passed += 1 + else: + print(" FAIL: Org blocked commands in blocklist") + failed += 1 - # Test 3: Project commands are included - if "swift" in allowed: - print(" PASS: Project commands included") - passed += 1 - else: - print(" FAIL: Project commands included") - failed += 1 + # Test 3: Project commands are included + if "swift" in allowed: + print(" PASS: Project commands included") + passed += 1 + else: + print(" FAIL: Project commands included") + failed += 1 - # Test 4: Global commands are included - if "npm" in allowed and "git" in allowed: - print(" PASS: Global commands included") - passed += 1 - else: - print(" FAIL: Global commands included") - failed += 1 + # Test 4: Global commands are included + if "npm" in allowed and "git" in allowed: + print(" PASS: Global commands included") + passed += 1 + else: + print(" FAIL: Global commands included") + failed += 1 - # Test 5: Hardcoded blocklist cannot be overridden - if "sudo" in blocked and "shutdown" in blocked: - print(" PASS: Hardcoded blocklist enforced") - passed += 1 - else: - print(" FAIL: Hardcoded blocklist enforced") - failed += 1 - - # Restore HOME - if original_home: - os.environ["HOME"] = original_home - else: - del os.environ["HOME"] + # Test 5: Hardcoded blocklist cannot be overridden + if "sudo" in blocked and "shutdown" in blocked: + print(" PASS: Hardcoded blocklist enforced") + passed += 1 + else: + print(" FAIL: Hardcoded blocklist enforced") + failed += 1 return passed, failed @@ -612,42 +639,33 @@ def test_org_blocklist_enforcement(): with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmpproject: - # Setup fake home directory - import os - original_home = os.environ.get("HOME") - os.environ["HOME"] = tmphome + # Use temporary_home for cross-platform compatibility + with temporary_home(tmphome): + org_dir = Path(tmphome) / ".autocoder" + org_dir.mkdir() + org_config_path = org_dir / "config.yaml" - org_dir = Path(tmphome) / ".autocoder" - org_dir.mkdir() - org_config_path = org_dir / "config.yaml" - - # Create org config that blocks terraform - org_config_path.write_text("""version: 1 + # Create org config that blocks terraform + org_config_path.write_text("""version: 1 blocked_commands: - terraform """) - project_dir = Path(tmpproject) - project_autocoder = project_dir / ".autocoder" - project_autocoder.mkdir() + project_dir = Path(tmpproject) + project_autocoder = project_dir / ".autocoder" + project_autocoder.mkdir() - # Try to use terraform (should be blocked) - input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}} - context = {"project_dir": str(project_dir)} - result = asyncio.run(bash_security_hook(input_data, context=context)) + # Try to use terraform (should be blocked) + input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}} + context = {"project_dir": str(project_dir)} + result = asyncio.run(bash_security_hook(input_data, context=context)) - if result.get("decision") == "block": - print(" PASS: Org blocked command 'terraform' rejected") - passed += 1 - else: - print(" FAIL: Org blocked command 'terraform' should be rejected") - failed += 1 - - # Restore HOME - if original_home: - os.environ["HOME"] = original_home - else: - del os.environ["HOME"] + if result.get("decision") == "block": + print(" PASS: Org blocked command 'terraform' rejected") + passed += 1 + else: + print(" FAIL: Org blocked command 'terraform' should be rejected") + failed += 1 return passed, failed