fix: add Windows compatibility to security unit tests

Add cross-platform temporary_home() context manager to handle
environment variable differences between Unix and Windows systems.

Changes:
- Add temporary_home() context manager that handles both HOME (Unix)
  and USERPROFILE/HOMEDRIVE/HOMEPATH (Windows) environment variables
- Update test_org_config_loading() to use temporary_home()
- Update test_hierarchy_resolution() to use temporary_home()
- Update test_org_blocklist_enforcement() to use temporary_home()
- Add missing imports: os, contextmanager

Why: The unit tests for org config loading were failing on Windows
because they only set the HOME environment variable, but Windows
uses USERPROFILE instead. The integration tests already had this
fix via a similar context manager.

Result: All 148 unit tests now pass on both Windows and Unix systems.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Auto
2026-01-23 12:24:50 +02:00
parent 1fe47736cc
commit b21d2e3adc

View File

@@ -8,8 +8,10 @@ Run with: python test_security.py
"""
import asyncio
import os
import sys
import tempfile
from contextlib import contextmanager
from pathlib import Path
from security import (
@@ -25,6 +27,48 @@ from security import (
)
@contextmanager
def temporary_home(home_path):
"""
Context manager to temporarily set HOME (and Windows equivalents).
Saves original environment variables and restores them on exit,
even if an exception occurs.
Args:
home_path: Path to use as temporary home directory
"""
# Save original values for Unix and Windows
saved_env = {
"HOME": os.environ.get("HOME"),
"USERPROFILE": os.environ.get("USERPROFILE"),
"HOMEDRIVE": os.environ.get("HOMEDRIVE"),
"HOMEPATH": os.environ.get("HOMEPATH"),
}
try:
# Set new home directory for both Unix and Windows
os.environ["HOME"] = str(home_path)
if sys.platform == "win32":
os.environ["USERPROFILE"] = str(home_path)
# Note: HOMEDRIVE and HOMEPATH are typically set by Windows
# but we update them for consistency
drive, path = os.path.splitdrive(str(home_path))
if drive:
os.environ["HOMEDRIVE"] = drive
os.environ["HOMEPATH"] = path
yield
finally:
# Restore original values
for key, value in saved_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def check_hook(command: str, should_block: bool) -> bool:
"""Check a single command against the security hook (helper function)."""
input_data = {"tool_name": "Bash", "tool_input": {"command": command}}
@@ -416,20 +460,15 @@ def test_org_config_loading():
passed = 0
failed = 0
# Save original org config path
original_home = Path.home()
with tempfile.TemporaryDirectory() as tmpdir:
# Temporarily override home directory for testing
import os
os.environ["HOME"] = tmpdir
# Use temporary_home for cross-platform compatibility
with temporary_home(tmpdir):
org_dir = Path(tmpdir) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmpdir) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
# Test 1: Valid org config
org_config_path.write_text("""version: 1
# Test 1: Valid org config
org_config_path.write_text("""version: 1
allowed_commands:
- name: jq
description: JSON processor
@@ -437,76 +476,73 @@ blocked_commands:
- aws
- kubectl
""")
config = load_org_config()
if config and config["version"] == 1:
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
print(" PASS: Load valid org config")
config = load_org_config()
if config and config["version"] == 1:
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
print(" PASS: Load valid org config")
passed += 1
else:
print(" FAIL: Load valid org config (wrong counts)")
failed += 1
else:
print(" FAIL: Load valid org config")
print(f" Got: {config}")
failed += 1
# Test 2: Missing file returns None
org_config_path.unlink()
config = load_org_config()
if config is None:
print(" PASS: Missing org config returns None")
passed += 1
else:
print(" FAIL: Load valid org config (wrong counts)")
print(" FAIL: Missing org config returns None")
failed += 1
else:
print(" FAIL: Load valid org config")
print(f" Got: {config}")
failed += 1
# Test 2: Missing file returns None
org_config_path.unlink()
config = load_org_config()
if config is None:
print(" PASS: Missing org config returns None")
passed += 1
else:
print(" FAIL: Missing org config returns None")
failed += 1
# Test 3: Non-string command name is rejected
org_config_path.write_text("""version: 1
# Test 3: Non-string command name is rejected
org_config_path.write_text("""version: 1
allowed_commands:
- name: 123
description: Invalid numeric name
""")
config = load_org_config()
if config is None:
print(" PASS: Non-string command name rejected")
passed += 1
else:
print(" FAIL: Non-string command name rejected")
print(f" Got: {config}")
failed += 1
config = load_org_config()
if config is None:
print(" PASS: Non-string command name rejected")
passed += 1
else:
print(" FAIL: Non-string command name rejected")
print(f" Got: {config}")
failed += 1
# Test 4: Empty command name is rejected
org_config_path.write_text("""version: 1
# Test 4: Empty command name is rejected
org_config_path.write_text("""version: 1
allowed_commands:
- name: ""
description: Empty name
""")
config = load_org_config()
if config is None:
print(" PASS: Empty command name rejected")
passed += 1
else:
print(" FAIL: Empty command name rejected")
print(f" Got: {config}")
failed += 1
config = load_org_config()
if config is None:
print(" PASS: Empty command name rejected")
passed += 1
else:
print(" FAIL: Empty command name rejected")
print(f" Got: {config}")
failed += 1
# Test 5: Whitespace-only command name is rejected
org_config_path.write_text("""version: 1
# Test 5: Whitespace-only command name is rejected
org_config_path.write_text("""version: 1
allowed_commands:
- name: " "
description: Whitespace name
""")
config = load_org_config()
if config is None:
print(" PASS: Whitespace-only command name rejected")
passed += 1
else:
print(" FAIL: Whitespace-only command name rejected")
print(f" Got: {config}")
failed += 1
# Restore HOME
os.environ["HOME"] = str(original_home)
config = load_org_config()
if config is None:
print(" PASS: Whitespace-only command name rejected")
passed += 1
else:
print(" FAIL: Whitespace-only command name rejected")
print(f" Got: {config}")
failed += 1
return passed, failed
@@ -519,17 +555,14 @@ def test_hierarchy_resolution():
with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory
import os
original_home = os.environ.get("HOME")
os.environ["HOME"] = tmphome
# Use temporary_home for cross-platform compatibility
with temporary_home(tmphome):
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
# Create org config with allowed and blocked commands
org_config_path.write_text("""version: 1
# Create org config with allowed and blocked commands
org_config_path.write_text("""version: 1
allowed_commands:
- name: jq
description: JSON processor
@@ -540,66 +573,60 @@ blocked_commands:
- kubectl
""")
project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir()
project_config = project_autocoder / "allowed_commands.yaml"
project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir()
project_config = project_autocoder / "allowed_commands.yaml"
# Create project config
project_config.write_text("""version: 1
# Create project config
project_config.write_text("""version: 1
commands:
- name: swift
description: Swift compiler
""")
# Test 1: Org allowed commands are included
allowed, blocked = get_effective_commands(project_dir)
if "jq" in allowed and "python3" in allowed:
print(" PASS: Org allowed commands included")
passed += 1
else:
print(" FAIL: Org allowed commands included")
print(f" jq in allowed: {'jq' in allowed}")
print(f" python3 in allowed: {'python3' in allowed}")
failed += 1
# Test 1: Org allowed commands are included
allowed, blocked = get_effective_commands(project_dir)
if "jq" in allowed and "python3" in allowed:
print(" PASS: Org allowed commands included")
passed += 1
else:
print(" FAIL: Org allowed commands included")
print(f" jq in allowed: {'jq' in allowed}")
print(f" python3 in allowed: {'python3' in allowed}")
failed += 1
# Test 2: Org blocked commands are in blocklist
if "terraform" in blocked and "kubectl" in blocked:
print(" PASS: Org blocked commands in blocklist")
passed += 1
else:
print(" FAIL: Org blocked commands in blocklist")
failed += 1
# Test 2: Org blocked commands are in blocklist
if "terraform" in blocked and "kubectl" in blocked:
print(" PASS: Org blocked commands in blocklist")
passed += 1
else:
print(" FAIL: Org blocked commands in blocklist")
failed += 1
# Test 3: Project commands are included
if "swift" in allowed:
print(" PASS: Project commands included")
passed += 1
else:
print(" FAIL: Project commands included")
failed += 1
# Test 3: Project commands are included
if "swift" in allowed:
print(" PASS: Project commands included")
passed += 1
else:
print(" FAIL: Project commands included")
failed += 1
# Test 4: Global commands are included
if "npm" in allowed and "git" in allowed:
print(" PASS: Global commands included")
passed += 1
else:
print(" FAIL: Global commands included")
failed += 1
# Test 4: Global commands are included
if "npm" in allowed and "git" in allowed:
print(" PASS: Global commands included")
passed += 1
else:
print(" FAIL: Global commands included")
failed += 1
# Test 5: Hardcoded blocklist cannot be overridden
if "sudo" in blocked and "shutdown" in blocked:
print(" PASS: Hardcoded blocklist enforced")
passed += 1
else:
print(" FAIL: Hardcoded blocklist enforced")
failed += 1
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
# Test 5: Hardcoded blocklist cannot be overridden
if "sudo" in blocked and "shutdown" in blocked:
print(" PASS: Hardcoded blocklist enforced")
passed += 1
else:
print(" FAIL: Hardcoded blocklist enforced")
failed += 1
return passed, failed
@@ -612,42 +639,33 @@ def test_org_blocklist_enforcement():
with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory
import os
original_home = os.environ.get("HOME")
os.environ["HOME"] = tmphome
# Use temporary_home for cross-platform compatibility
with temporary_home(tmphome):
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
org_config_path = org_dir / "config.yaml"
# Create org config that blocks terraform
org_config_path.write_text("""version: 1
# Create org config that blocks terraform
org_config_path.write_text("""version: 1
blocked_commands:
- terraform
""")
project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir()
project_dir = Path(tmpproject)
project_autocoder = project_dir / ".autocoder"
project_autocoder.mkdir()
# Try to use terraform (should be blocked)
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
context = {"project_dir": str(project_dir)}
result = asyncio.run(bash_security_hook(input_data, context=context))
# Try to use terraform (should be blocked)
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
context = {"project_dir": str(project_dir)}
result = asyncio.run(bash_security_hook(input_data, context=context))
if result.get("decision") == "block":
print(" PASS: Org blocked command 'terraform' rejected")
passed += 1
else:
print(" FAIL: Org blocked command 'terraform' should be rejected")
failed += 1
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
if result.get("decision") == "block":
print(" PASS: Org blocked command 'terraform' rejected")
passed += 1
else:
print(" FAIL: Org blocked command 'terraform' should be rejected")
failed += 1
return passed, failed