test: add safe environment variable handling in integration tests

Changes:
- Add temporary_home() context manager for safe HOME manipulation
- Handle both Unix (HOME) and Windows (USERPROFILE, HOMEDRIVE, HOMEPATH)
- Update test_org_blocklist_enforcement to use context manager
- Update test_org_allowlist_inheritance to use context manager

Benefits:
- Environment variables always restored, even on exceptions
- Prevents test pollution across test runs
- Cross-platform compatibility (Windows + Unix)

All 9 integration tests passing.
This commit is contained in:
Marian Paul
2026-01-22 16:31:50 +01:00
parent 996ac0065c
commit edff398fe6

View File

@@ -19,11 +19,56 @@ import asyncio
import os
import sys
import tempfile
from contextlib import contextmanager
from pathlib import Path
from security import bash_security_hook
@contextmanager
def temporary_home(home_path):
"""
Context manager to temporarily set HOME (and Windows equivalents).
Saves original environment variables and restores them on exit,
even if an exception occurs.
Args:
home_path: Path to use as temporary home directory
"""
# Save original values for Unix and Windows
saved_env = {
"HOME": os.environ.get("HOME"),
"USERPROFILE": os.environ.get("USERPROFILE"),
"HOMEDRIVE": os.environ.get("HOMEDRIVE"),
"HOMEPATH": os.environ.get("HOMEPATH"),
}
try:
# Set new home directory for both Unix and Windows
os.environ["HOME"] = str(home_path)
if sys.platform == "win32":
os.environ["USERPROFILE"] = str(home_path)
# Note: HOMEDRIVE and HOMEPATH are typically set by Windows
# but we update them for consistency
drive, path = os.path.splitdrive(str(home_path))
if drive:
os.environ["HOMEDRIVE"] = drive
os.environ["HOMEPATH"] = path
yield
finally:
# Restore all original values
for key, value in saved_env.items():
if value is None:
# Remove if it didn't exist before
os.environ.pop(key, None)
else:
# Restore original value
os.environ[key] = value
def test_blocked_command_via_hook():
"""Test that hardcoded blocked commands are rejected by the security hook."""
print("\n" + "=" * 70)
@@ -200,52 +245,44 @@ def test_org_blocklist_enforcement():
with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory with org config
original_home = os.environ.get("HOME")
os.environ["HOME"] = tmphome
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
(org_dir / "config.yaml").write_text("""version: 1
# Use context manager to safely set and restore HOME
with temporary_home(tmphome):
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
(org_dir / "config.yaml").write_text("""version: 1
allowed_commands: []
blocked_commands:
- terraform
- kubectl
""")
project_dir = Path(tmpproject)
autocoder_dir = project_dir / ".autocoder"
autocoder_dir.mkdir()
project_dir = Path(tmpproject)
autocoder_dir = project_dir / ".autocoder"
autocoder_dir.mkdir()
# Try to allow terraform in project config (should fail - org blocked)
(autocoder_dir / "allowed_commands.yaml").write_text("""version: 1
# Try to allow terraform in project config (should fail - org blocked)
(autocoder_dir / "allowed_commands.yaml").write_text("""version: 1
commands:
- name: terraform
description: Infrastructure as code
""")
# Try to run terraform (should be blocked by org config)
input_data = {
"tool_name": "Bash",
"tool_input": {"command": "terraform apply"},
}
context = {"project_dir": str(project_dir)}
# Try to run terraform (should be blocked by org config)
input_data = {
"tool_name": "Bash",
"tool_input": {"command": "terraform apply"},
}
context = {"project_dir": str(project_dir)}
result = asyncio.run(bash_security_hook(input_data, context=context))
result = asyncio.run(bash_security_hook(input_data, context=context))
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
if result.get("decision") == "block":
print("✅ PASS: terraform blocked by org config (cannot override)")
print(f" Reason: {result.get('reason', 'N/A')[:80]}...")
return True
else:
print("❌ FAIL: terraform should have been blocked by org config")
return False
if result.get("decision") == "block":
print("✅ PASS: terraform blocked by org config (cannot override)")
print(f" Reason: {result.get('reason', 'N/A')[:80]}...")
return True
else:
print("❌ FAIL: terraform should have been blocked by org config")
return False
def test_org_allowlist_inheritance():
@@ -256,45 +293,37 @@ def test_org_allowlist_inheritance():
with tempfile.TemporaryDirectory() as tmphome:
with tempfile.TemporaryDirectory() as tmpproject:
# Setup fake home directory with org config
original_home = os.environ.get("HOME")
os.environ["HOME"] = tmphome
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
(org_dir / "config.yaml").write_text("""version: 1
# Use context manager to safely set and restore HOME
with temporary_home(tmphome):
org_dir = Path(tmphome) / ".autocoder"
org_dir.mkdir()
(org_dir / "config.yaml").write_text("""version: 1
allowed_commands:
- name: jq
description: JSON processor
blocked_commands: []
""")
project_dir = Path(tmpproject)
autocoder_dir = project_dir / ".autocoder"
autocoder_dir.mkdir()
(autocoder_dir / "allowed_commands.yaml").write_text(
"version: 1\ncommands: []"
)
project_dir = Path(tmpproject)
autocoder_dir = project_dir / ".autocoder"
autocoder_dir.mkdir()
(autocoder_dir / "allowed_commands.yaml").write_text(
"version: 1\ncommands: []"
)
# Try to run jq (should be allowed via org config)
input_data = {"tool_name": "Bash", "tool_input": {"command": "jq '.data'"}}
context = {"project_dir": str(project_dir)}
# Try to run jq (should be allowed via org config)
input_data = {"tool_name": "Bash", "tool_input": {"command": "jq '.data'"}}
context = {"project_dir": str(project_dir)}
result = asyncio.run(bash_security_hook(input_data, context=context))
result = asyncio.run(bash_security_hook(input_data, context=context))
# Restore HOME
if original_home:
os.environ["HOME"] = original_home
else:
del os.environ["HOME"]
if result.get("decision") != "block":
print("✅ PASS: jq allowed via org config")
return True
else:
print("❌ FAIL: jq should have been allowed via org config")
print(f" Reason: {result.get('reason', 'N/A')}")
return False
if result.get("decision") != "block":
print("✅ PASS: jq allowed via org config")
return True
else:
print("❌ FAIL: jq should have been allowed via org config")
print(f" Reason: {result.get('reason', 'N/A')}")
return False
def test_invalid_yaml_ignored():