fix: add Windows compatibility to security unit tests

Add cross-platform temporary_home() context manager to handle
environment variable differences between Unix and Windows systems.

Changes:
- Add temporary_home() context manager that handles both HOME (Unix)
  and USERPROFILE/HOMEDRIVE/HOMEPATH (Windows) environment variables
- Update test_org_config_loading() to use temporary_home()
- Update test_hierarchy_resolution() to use temporary_home()
- Update test_org_blocklist_enforcement() to use temporary_home()
- Add missing imports: os, contextmanager

Why: The unit tests for org config loading were failing on Windows
because they only set the HOME environment variable, but Windows
uses USERPROFILE instead. The integration tests already had this
fix via a similar context manager.

Result: All 148 unit tests now pass on both Windows and Unix systems.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Auto
2026-01-23 12:24:50 +02:00
parent 1fe47736cc
commit b21d2e3adc

View File

@@ -8,8 +8,10 @@ Run with: python test_security.py
""" """
import asyncio import asyncio
import os
import sys import sys
import tempfile import tempfile
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from security import ( from security import (
@@ -25,6 +27,48 @@ from security import (
) )
@contextmanager
def temporary_home(home_path):
"""
Context manager to temporarily set HOME (and Windows equivalents).
Saves original environment variables and restores them on exit,
even if an exception occurs.
Args:
home_path: Path to use as temporary home directory
"""
# Save original values for Unix and Windows
saved_env = {
"HOME": os.environ.get("HOME"),
"USERPROFILE": os.environ.get("USERPROFILE"),
"HOMEDRIVE": os.environ.get("HOMEDRIVE"),
"HOMEPATH": os.environ.get("HOMEPATH"),
}
try:
# Set new home directory for both Unix and Windows
os.environ["HOME"] = str(home_path)
if sys.platform == "win32":
os.environ["USERPROFILE"] = str(home_path)
# Note: HOMEDRIVE and HOMEPATH are typically set by Windows
# but we update them for consistency
drive, path = os.path.splitdrive(str(home_path))
if drive:
os.environ["HOMEDRIVE"] = drive
os.environ["HOMEPATH"] = path
yield
finally:
# Restore original values
for key, value in saved_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def check_hook(command: str, should_block: bool) -> bool: def check_hook(command: str, should_block: bool) -> bool:
"""Check a single command against the security hook (helper function).""" """Check a single command against the security hook (helper function)."""
input_data = {"tool_name": "Bash", "tool_input": {"command": command}} input_data = {"tool_name": "Bash", "tool_input": {"command": command}}
@@ -416,20 +460,15 @@ def test_org_config_loading():
passed = 0 passed = 0
failed = 0 failed = 0
# Save original org config path
original_home = Path.home()
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
# Temporarily override home directory for testing # Use temporary_home for cross-platform compatibility
import os with temporary_home(tmpdir):
os.environ["HOME"] = tmpdir org_dir = Path(tmpdir) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmpdir) / ".autocoder" # Test 1: Valid org config
org_dir.mkdir() org_config_path.write_text("""version: 1
org_config_path = org_dir / "config.yaml"
# Test 1: Valid org config
org_config_path.write_text("""version: 1
allowed_commands: allowed_commands:
- name: jq - name: jq
description: JSON processor description: JSON processor
@@ -437,76 +476,73 @@ blocked_commands:
- aws - aws
- kubectl - kubectl
""") """)
config = load_org_config() config = load_org_config()
if config and config["version"] == 1: if config and config["version"] == 1:
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2: if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
print(" PASS: Load valid org config") print(" PASS: Load valid org config")
passed += 1
else:
print(" FAIL: Load valid org config (wrong counts)")
failed += 1
else:
print(" FAIL: Load valid org config")
print(f" Got: {config}")
failed += 1
# Test 2: Missing file returns None
org_config_path.unlink()
config = load_org_config()
if config is None:
print(" PASS: Missing org config returns None")
passed += 1 passed += 1
else: else:
print(" FAIL: Load valid org config (wrong counts)") print(" FAIL: Missing org config returns None")
failed += 1 failed += 1
else:
print(" FAIL: Load valid org config")
print(f" Got: {config}")
failed += 1
# Test 2: Missing file returns None # Test 3: Non-string command name is rejected
org_config_path.unlink() org_config_path.write_text("""version: 1
config = load_org_config()
if config is None:
print(" PASS: Missing org config returns None")
passed += 1
else:
print(" FAIL: Missing org config returns None")
failed += 1
# Test 3: Non-string command name is rejected
org_config_path.write_text("""version: 1
allowed_commands: allowed_commands:
- name: 123 - name: 123
description: Invalid numeric name description: Invalid numeric name
""") """)
config = load_org_config() config = load_org_config()
if config is None: if config is None:
print(" PASS: Non-string command name rejected") print(" PASS: Non-string command name rejected")
passed += 1 passed += 1
else: else:
print(" FAIL: Non-string command name rejected") print(" FAIL: Non-string command name rejected")
print(f" Got: {config}") print(f" Got: {config}")
failed += 1 failed += 1
# Test 4: Empty command name is rejected # Test 4: Empty command name is rejected
org_config_path.write_text("""version: 1 org_config_path.write_text("""version: 1
allowed_commands: allowed_commands:
- name: "" - name: ""
description: Empty name description: Empty name
""") """)
config = load_org_config() config = load_org_config()
if config is None: if config is None:
print(" PASS: Empty command name rejected") print(" PASS: Empty command name rejected")
passed += 1 passed += 1
else: else:
print(" FAIL: Empty command name rejected") print(" FAIL: Empty command name rejected")
print(f" Got: {config}") print(f" Got: {config}")
failed += 1 failed += 1
# Test 5: Whitespace-only command name is rejected # Test 5: Whitespace-only command name is rejected
org_config_path.write_text("""version: 1 org_config_path.write_text("""version: 1
allowed_commands: allowed_commands:
- name: " " - name: " "
description: Whitespace name description: Whitespace name
""") """)
config = load_org_config() config = load_org_config()
if config is None: if config is None:
print(" PASS: Whitespace-only command name rejected") print(" PASS: Whitespace-only command name rejected")
passed += 1 passed += 1
else: else:
print(" FAIL: Whitespace-only command name rejected") print(" FAIL: Whitespace-only command name rejected")
print(f" Got: {config}") print(f" Got: {config}")
failed += 1 failed += 1
# Restore HOME
os.environ["HOME"] = str(original_home)
return passed, failed return passed, failed
@@ -519,17 +555,14 @@ def test_hierarchy_resolution():
with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject: with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory # Use temporary_home for cross-platform compatibility
import os with temporary_home(tmphome):
original_home = os.environ.get("HOME") org_dir = Path(tmphome) / ".autocoder"
os.environ["HOME"] = tmphome org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmphome) / ".autocoder" # Create org config with allowed and blocked commands
org_dir.mkdir() org_config_path.write_text("""version: 1
org_config_path = org_dir / "config.yaml"
# Create org config with allowed and blocked commands
org_config_path.write_text("""version: 1
allowed_commands: allowed_commands:
- name: jq - name: jq
description: JSON processor description: JSON processor
@@ -540,66 +573,60 @@ blocked_commands:
- kubectl - kubectl
""") """)
project_dir = Path(tmpproject) project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder" project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir() project_autocoder.mkdir()
project_config = project_autocoder / "allowed_commands.yaml" project_config = project_autocoder / "allowed_commands.yaml"
# Create project config # Create project config
project_config.write_text("""version: 1 project_config.write_text("""version: 1
commands: commands:
- name: swift - name: swift
description: Swift compiler description: Swift compiler
""") """)
# Test 1: Org allowed commands are included # Test 1: Org allowed commands are included
allowed, blocked = get_effective_commands(project_dir) allowed, blocked = get_effective_commands(project_dir)
if "jq" in allowed and "python3" in allowed: if "jq" in allowed and "python3" in allowed:
print(" PASS: Org allowed commands included") print(" PASS: Org allowed commands included")
passed += 1 passed += 1
else: else:
print(" FAIL: Org allowed commands included") print(" FAIL: Org allowed commands included")
print(f" jq in allowed: {'jq' in allowed}") print(f" jq in allowed: {'jq' in allowed}")
print(f" python3 in allowed: {'python3' in allowed}") print(f" python3 in allowed: {'python3' in allowed}")
failed += 1 failed += 1
# Test 2: Org blocked commands are in blocklist # Test 2: Org blocked commands are in blocklist
if "terraform" in blocked and "kubectl" in blocked: if "terraform" in blocked and "kubectl" in blocked:
print(" PASS: Org blocked commands in blocklist") print(" PASS: Org blocked commands in blocklist")
passed += 1 passed += 1
else: else:
print(" FAIL: Org blocked commands in blocklist") print(" FAIL: Org blocked commands in blocklist")
failed += 1 failed += 1
# Test 3: Project commands are included # Test 3: Project commands are included
if "swift" in allowed: if "swift" in allowed:
print(" PASS: Project commands included") print(" PASS: Project commands included")
passed += 1 passed += 1
else: else:
print(" FAIL: Project commands included") print(" FAIL: Project commands included")
failed += 1 failed += 1
# Test 4: Global commands are included # Test 4: Global commands are included
if "npm" in allowed and "git" in allowed: if "npm" in allowed and "git" in allowed:
print(" PASS: Global commands included") print(" PASS: Global commands included")
passed += 1 passed += 1
else: else:
print(" FAIL: Global commands included") print(" FAIL: Global commands included")
failed += 1 failed += 1
# Test 5: Hardcoded blocklist cannot be overridden # Test 5: Hardcoded blocklist cannot be overridden
if "sudo" in blocked and "shutdown" in blocked: if "sudo" in blocked and "shutdown" in blocked:
print(" PASS: Hardcoded blocklist enforced") print(" PASS: Hardcoded blocklist enforced")
passed += 1 passed += 1
else: else:
print(" FAIL: Hardcoded blocklist enforced") print(" FAIL: Hardcoded blocklist enforced")
failed += 1 failed += 1
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
return passed, failed return passed, failed
@@ -612,42 +639,33 @@ def test_org_blocklist_enforcement():
with tempfile.TemporaryDirectory() as tmphome: with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject: with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory # Use temporary_home for cross-platform compatibility
import os with temporary_home(tmphome):
original_home = os.environ.get("HOME") org_dir = Path(tmphome) / ".autocoder"
os.environ["HOME"] = tmphome org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmphome) / ".autocoder" # Create org config that blocks terraform
org_dir.mkdir() org_config_path.write_text("""version: 1
org_config_path = org_dir / "config.yaml"
# Create org config that blocks terraform
org_config_path.write_text("""version: 1
blocked_commands: blocked_commands:
- terraform - terraform
""") """)
project_dir = Path(tmpproject) project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder" project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir() project_autocoder.mkdir()
# Try to use terraform (should be blocked) # Try to use terraform (should be blocked)
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}} input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
context = {"project_dir": str(project_dir)} context = {"project_dir": str(project_dir)}
result = asyncio.run(bash_security_hook(input_data, context=context)) result = asyncio.run(bash_security_hook(input_data, context=context))
if result.get("decision") == "block": if result.get("decision") == "block":
print(" PASS: Org blocked command 'terraform' rejected") print(" PASS: Org blocked command 'terraform' rejected")
passed += 1 passed += 1
else: else:
print(" FAIL: Org blocked command 'terraform' should be rejected") print(" FAIL: Org blocked command 'terraform' should be rejected")
failed += 1 failed += 1
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
return passed, failed return passed, failed