mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-02-01 15:03:36 +00:00
fix: add Windows compatibility to security unit tests
Add cross-platform temporary_home() context manager to handle environment variable differences between Unix and Windows systems. Changes: - Add temporary_home() context manager that handles both HOME (Unix) and USERPROFILE/HOMEDRIVE/HOMEPATH (Windows) environment variables - Update test_org_config_loading() to use temporary_home() - Update test_hierarchy_resolution() to use temporary_home() - Update test_org_blocklist_enforcement() to use temporary_home() - Add missing imports: os, contextmanager Why: The unit tests for org config loading were failing on Windows because they only set the HOME environment variable, but Windows uses USERPROFILE instead. The integration tests already had this fix via a similar context manager. Result: All 148 unit tests now pass on both Windows and Unix systems. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
324
test_security.py
324
test_security.py
@@ -8,8 +8,10 @@ Run with: python test_security.py
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from security import (
|
from security import (
|
||||||
@@ -25,6 +27,48 @@ from security import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temporary_home(home_path):
|
||||||
|
"""
|
||||||
|
Context manager to temporarily set HOME (and Windows equivalents).
|
||||||
|
|
||||||
|
Saves original environment variables and restores them on exit,
|
||||||
|
even if an exception occurs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
home_path: Path to use as temporary home directory
|
||||||
|
"""
|
||||||
|
# Save original values for Unix and Windows
|
||||||
|
saved_env = {
|
||||||
|
"HOME": os.environ.get("HOME"),
|
||||||
|
"USERPROFILE": os.environ.get("USERPROFILE"),
|
||||||
|
"HOMEDRIVE": os.environ.get("HOMEDRIVE"),
|
||||||
|
"HOMEPATH": os.environ.get("HOMEPATH"),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set new home directory for both Unix and Windows
|
||||||
|
os.environ["HOME"] = str(home_path)
|
||||||
|
if sys.platform == "win32":
|
||||||
|
os.environ["USERPROFILE"] = str(home_path)
|
||||||
|
# Note: HOMEDRIVE and HOMEPATH are typically set by Windows
|
||||||
|
# but we update them for consistency
|
||||||
|
drive, path = os.path.splitdrive(str(home_path))
|
||||||
|
if drive:
|
||||||
|
os.environ["HOMEDRIVE"] = drive
|
||||||
|
os.environ["HOMEPATH"] = path
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original values
|
||||||
|
for key, value in saved_env.items():
|
||||||
|
if value is None:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
else:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
def check_hook(command: str, should_block: bool) -> bool:
|
def check_hook(command: str, should_block: bool) -> bool:
|
||||||
"""Check a single command against the security hook (helper function)."""
|
"""Check a single command against the security hook (helper function)."""
|
||||||
input_data = {"tool_name": "Bash", "tool_input": {"command": command}}
|
input_data = {"tool_name": "Bash", "tool_input": {"command": command}}
|
||||||
@@ -416,20 +460,15 @@ def test_org_config_loading():
|
|||||||
passed = 0
|
passed = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
|
|
||||||
# Save original org config path
|
|
||||||
original_home = Path.home()
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
# Temporarily override home directory for testing
|
# Use temporary_home for cross-platform compatibility
|
||||||
import os
|
with temporary_home(tmpdir):
|
||||||
os.environ["HOME"] = tmpdir
|
org_dir = Path(tmpdir) / ".autocoder"
|
||||||
|
org_dir.mkdir()
|
||||||
|
org_config_path = org_dir / "config.yaml"
|
||||||
|
|
||||||
org_dir = Path(tmpdir) / ".autocoder"
|
# Test 1: Valid org config
|
||||||
org_dir.mkdir()
|
org_config_path.write_text("""version: 1
|
||||||
org_config_path = org_dir / "config.yaml"
|
|
||||||
|
|
||||||
# Test 1: Valid org config
|
|
||||||
org_config_path.write_text("""version: 1
|
|
||||||
allowed_commands:
|
allowed_commands:
|
||||||
- name: jq
|
- name: jq
|
||||||
description: JSON processor
|
description: JSON processor
|
||||||
@@ -437,76 +476,73 @@ blocked_commands:
|
|||||||
- aws
|
- aws
|
||||||
- kubectl
|
- kubectl
|
||||||
""")
|
""")
|
||||||
config = load_org_config()
|
config = load_org_config()
|
||||||
if config and config["version"] == 1:
|
if config and config["version"] == 1:
|
||||||
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
|
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
|
||||||
print(" PASS: Load valid org config")
|
print(" PASS: Load valid org config")
|
||||||
|
passed += 1
|
||||||
|
else:
|
||||||
|
print(" FAIL: Load valid org config (wrong counts)")
|
||||||
|
failed += 1
|
||||||
|
else:
|
||||||
|
print(" FAIL: Load valid org config")
|
||||||
|
print(f" Got: {config}")
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
# Test 2: Missing file returns None
|
||||||
|
org_config_path.unlink()
|
||||||
|
config = load_org_config()
|
||||||
|
if config is None:
|
||||||
|
print(" PASS: Missing org config returns None")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Load valid org config (wrong counts)")
|
print(" FAIL: Missing org config returns None")
|
||||||
failed += 1
|
failed += 1
|
||||||
else:
|
|
||||||
print(" FAIL: Load valid org config")
|
|
||||||
print(f" Got: {config}")
|
|
||||||
failed += 1
|
|
||||||
|
|
||||||
# Test 2: Missing file returns None
|
# Test 3: Non-string command name is rejected
|
||||||
org_config_path.unlink()
|
org_config_path.write_text("""version: 1
|
||||||
config = load_org_config()
|
|
||||||
if config is None:
|
|
||||||
print(" PASS: Missing org config returns None")
|
|
||||||
passed += 1
|
|
||||||
else:
|
|
||||||
print(" FAIL: Missing org config returns None")
|
|
||||||
failed += 1
|
|
||||||
|
|
||||||
# Test 3: Non-string command name is rejected
|
|
||||||
org_config_path.write_text("""version: 1
|
|
||||||
allowed_commands:
|
allowed_commands:
|
||||||
- name: 123
|
- name: 123
|
||||||
description: Invalid numeric name
|
description: Invalid numeric name
|
||||||
""")
|
""")
|
||||||
config = load_org_config()
|
config = load_org_config()
|
||||||
if config is None:
|
if config is None:
|
||||||
print(" PASS: Non-string command name rejected")
|
print(" PASS: Non-string command name rejected")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Non-string command name rejected")
|
print(" FAIL: Non-string command name rejected")
|
||||||
print(f" Got: {config}")
|
print(f" Got: {config}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 4: Empty command name is rejected
|
# Test 4: Empty command name is rejected
|
||||||
org_config_path.write_text("""version: 1
|
org_config_path.write_text("""version: 1
|
||||||
allowed_commands:
|
allowed_commands:
|
||||||
- name: ""
|
- name: ""
|
||||||
description: Empty name
|
description: Empty name
|
||||||
""")
|
""")
|
||||||
config = load_org_config()
|
config = load_org_config()
|
||||||
if config is None:
|
if config is None:
|
||||||
print(" PASS: Empty command name rejected")
|
print(" PASS: Empty command name rejected")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Empty command name rejected")
|
print(" FAIL: Empty command name rejected")
|
||||||
print(f" Got: {config}")
|
print(f" Got: {config}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 5: Whitespace-only command name is rejected
|
# Test 5: Whitespace-only command name is rejected
|
||||||
org_config_path.write_text("""version: 1
|
org_config_path.write_text("""version: 1
|
||||||
allowed_commands:
|
allowed_commands:
|
||||||
- name: " "
|
- name: " "
|
||||||
description: Whitespace name
|
description: Whitespace name
|
||||||
""")
|
""")
|
||||||
config = load_org_config()
|
config = load_org_config()
|
||||||
if config is None:
|
if config is None:
|
||||||
print(" PASS: Whitespace-only command name rejected")
|
print(" PASS: Whitespace-only command name rejected")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Whitespace-only command name rejected")
|
print(" FAIL: Whitespace-only command name rejected")
|
||||||
print(f" Got: {config}")
|
print(f" Got: {config}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Restore HOME
|
|
||||||
os.environ["HOME"] = str(original_home)
|
|
||||||
|
|
||||||
return passed, failed
|
return passed, failed
|
||||||
|
|
||||||
@@ -519,17 +555,14 @@ def test_hierarchy_resolution():
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmphome:
|
with tempfile.TemporaryDirectory() as tmphome:
|
||||||
with tempfile.TemporaryDirectory() as tmpproject:
|
with tempfile.TemporaryDirectory() as tmpproject:
|
||||||
# Setup fake home directory
|
# Use temporary_home for cross-platform compatibility
|
||||||
import os
|
with temporary_home(tmphome):
|
||||||
original_home = os.environ.get("HOME")
|
org_dir = Path(tmphome) / ".autocoder"
|
||||||
os.environ["HOME"] = tmphome
|
org_dir.mkdir()
|
||||||
|
org_config_path = org_dir / "config.yaml"
|
||||||
|
|
||||||
org_dir = Path(tmphome) / ".autocoder"
|
# Create org config with allowed and blocked commands
|
||||||
org_dir.mkdir()
|
org_config_path.write_text("""version: 1
|
||||||
org_config_path = org_dir / "config.yaml"
|
|
||||||
|
|
||||||
# Create org config with allowed and blocked commands
|
|
||||||
org_config_path.write_text("""version: 1
|
|
||||||
allowed_commands:
|
allowed_commands:
|
||||||
- name: jq
|
- name: jq
|
||||||
description: JSON processor
|
description: JSON processor
|
||||||
@@ -540,66 +573,60 @@ blocked_commands:
|
|||||||
- kubectl
|
- kubectl
|
||||||
""")
|
""")
|
||||||
|
|
||||||
project_dir = Path(tmpproject)
|
project_dir = Path(tmpproject)
|
||||||
project_autocoder = project_dir / ".autocoder"
|
project_autocoder = project_dir / ".autocoder"
|
||||||
project_autocoder.mkdir()
|
project_autocoder.mkdir()
|
||||||
project_config = project_autocoder / "allowed_commands.yaml"
|
project_config = project_autocoder / "allowed_commands.yaml"
|
||||||
|
|
||||||
# Create project config
|
# Create project config
|
||||||
project_config.write_text("""version: 1
|
project_config.write_text("""version: 1
|
||||||
commands:
|
commands:
|
||||||
- name: swift
|
- name: swift
|
||||||
description: Swift compiler
|
description: Swift compiler
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Test 1: Org allowed commands are included
|
# Test 1: Org allowed commands are included
|
||||||
allowed, blocked = get_effective_commands(project_dir)
|
allowed, blocked = get_effective_commands(project_dir)
|
||||||
if "jq" in allowed and "python3" in allowed:
|
if "jq" in allowed and "python3" in allowed:
|
||||||
print(" PASS: Org allowed commands included")
|
print(" PASS: Org allowed commands included")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Org allowed commands included")
|
print(" FAIL: Org allowed commands included")
|
||||||
print(f" jq in allowed: {'jq' in allowed}")
|
print(f" jq in allowed: {'jq' in allowed}")
|
||||||
print(f" python3 in allowed: {'python3' in allowed}")
|
print(f" python3 in allowed: {'python3' in allowed}")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 2: Org blocked commands are in blocklist
|
# Test 2: Org blocked commands are in blocklist
|
||||||
if "terraform" in blocked and "kubectl" in blocked:
|
if "terraform" in blocked and "kubectl" in blocked:
|
||||||
print(" PASS: Org blocked commands in blocklist")
|
print(" PASS: Org blocked commands in blocklist")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Org blocked commands in blocklist")
|
print(" FAIL: Org blocked commands in blocklist")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 3: Project commands are included
|
# Test 3: Project commands are included
|
||||||
if "swift" in allowed:
|
if "swift" in allowed:
|
||||||
print(" PASS: Project commands included")
|
print(" PASS: Project commands included")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Project commands included")
|
print(" FAIL: Project commands included")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 4: Global commands are included
|
# Test 4: Global commands are included
|
||||||
if "npm" in allowed and "git" in allowed:
|
if "npm" in allowed and "git" in allowed:
|
||||||
print(" PASS: Global commands included")
|
print(" PASS: Global commands included")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Global commands included")
|
print(" FAIL: Global commands included")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Test 5: Hardcoded blocklist cannot be overridden
|
# Test 5: Hardcoded blocklist cannot be overridden
|
||||||
if "sudo" in blocked and "shutdown" in blocked:
|
if "sudo" in blocked and "shutdown" in blocked:
|
||||||
print(" PASS: Hardcoded blocklist enforced")
|
print(" PASS: Hardcoded blocklist enforced")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Hardcoded blocklist enforced")
|
print(" FAIL: Hardcoded blocklist enforced")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Restore HOME
|
|
||||||
if original_home:
|
|
||||||
os.environ["HOME"] = original_home
|
|
||||||
else:
|
|
||||||
del os.environ["HOME"]
|
|
||||||
|
|
||||||
return passed, failed
|
return passed, failed
|
||||||
|
|
||||||
@@ -612,42 +639,33 @@ def test_org_blocklist_enforcement():
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmphome:
|
with tempfile.TemporaryDirectory() as tmphome:
|
||||||
with tempfile.TemporaryDirectory() as tmpproject:
|
with tempfile.TemporaryDirectory() as tmpproject:
|
||||||
# Setup fake home directory
|
# Use temporary_home for cross-platform compatibility
|
||||||
import os
|
with temporary_home(tmphome):
|
||||||
original_home = os.environ.get("HOME")
|
org_dir = Path(tmphome) / ".autocoder"
|
||||||
os.environ["HOME"] = tmphome
|
org_dir.mkdir()
|
||||||
|
org_config_path = org_dir / "config.yaml"
|
||||||
|
|
||||||
org_dir = Path(tmphome) / ".autocoder"
|
# Create org config that blocks terraform
|
||||||
org_dir.mkdir()
|
org_config_path.write_text("""version: 1
|
||||||
org_config_path = org_dir / "config.yaml"
|
|
||||||
|
|
||||||
# Create org config that blocks terraform
|
|
||||||
org_config_path.write_text("""version: 1
|
|
||||||
blocked_commands:
|
blocked_commands:
|
||||||
- terraform
|
- terraform
|
||||||
""")
|
""")
|
||||||
|
|
||||||
project_dir = Path(tmpproject)
|
project_dir = Path(tmpproject)
|
||||||
project_autocoder = project_dir / ".autocoder"
|
project_autocoder = project_dir / ".autocoder"
|
||||||
project_autocoder.mkdir()
|
project_autocoder.mkdir()
|
||||||
|
|
||||||
# Try to use terraform (should be blocked)
|
# Try to use terraform (should be blocked)
|
||||||
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
|
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
|
||||||
context = {"project_dir": str(project_dir)}
|
context = {"project_dir": str(project_dir)}
|
||||||
result = asyncio.run(bash_security_hook(input_data, context=context))
|
result = asyncio.run(bash_security_hook(input_data, context=context))
|
||||||
|
|
||||||
if result.get("decision") == "block":
|
if result.get("decision") == "block":
|
||||||
print(" PASS: Org blocked command 'terraform' rejected")
|
print(" PASS: Org blocked command 'terraform' rejected")
|
||||||
passed += 1
|
passed += 1
|
||||||
else:
|
else:
|
||||||
print(" FAIL: Org blocked command 'terraform' should be rejected")
|
print(" FAIL: Org blocked command 'terraform' should be rejected")
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Restore HOME
|
|
||||||
if original_home:
|
|
||||||
os.environ["HOME"] = original_home
|
|
||||||
else:
|
|
||||||
del os.environ["HOME"]
|
|
||||||
|
|
||||||
return passed, failed
|
return passed, failed
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user