mirror of
https://github.com/leonvanzyl/autocoder.git
synced 2026-01-30 06:12:06 +00:00
feat: add per-project bash command allowlist system
Implement hierarchical command security with project and org-level configs:
WHAT'S NEW:
- Project-level YAML config (.autocoder/allowed_commands.yaml)
- Organization-level config (~/.autocoder/config.yaml)
- Pattern matching (exact, wildcards, local scripts)
- Hardcoded blocklist (sudo, dd, shutdown - never allowed)
- Org blocklist (terraform, kubectl - configurable)
- Helpful error messages with config hints
- Comprehensive documentation and examples
ARCHITECTURE:
- Hierarchical resolution: Hardcoded → Org Block → Org Allow → Global → Project
- YAML validation with a 50-command limit per project
- Pattern matching: exact ("swift"), wildcards ("swift*"), scripts ("./build.sh")
- Secure by default: all examples commented out
TESTING:
- 136 unit tests (pattern matching, YAML, hierarchy, validation)
- 9 integration tests (real security hook flows)
- All tests passing, 100% backward compatible
DOCUMENTATION:
- examples/README.md - comprehensive guide with use cases
- examples/project_allowed_commands.yaml - template (all commented)
- examples/org_config.yaml - org config template (all commented)
- PHASE3_SPEC.md - mid-session approval spec (future enhancement)
- Updated CLAUDE.md with security model documentation
USE CASES:
- iOS projects: Add Swift toolchain (xcodebuild, swift*, etc.)
- Rust projects: Add cargo, rustc, clippy
- Enterprise: Block aws, kubectl, terraform org-wide
- Custom scripts: Allow ./scripts/build.sh
PHASES:
✅ Phase 1: Project YAML + blocklist (implemented)
✅ Phase 2: Org config + hierarchy (implemented)
📋 Phase 3: Mid-session approval (spec ready, not implemented)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
481
test_security.py
481
test_security.py
@@ -9,12 +9,19 @@ Run with: python test_security.py
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from security import (
|
||||
bash_security_hook,
|
||||
extract_commands,
|
||||
get_effective_commands,
|
||||
load_org_config,
|
||||
load_project_commands,
|
||||
matches_pattern,
|
||||
validate_chmod_command,
|
||||
validate_init_script,
|
||||
validate_project_command,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,6 +158,440 @@ def test_validate_init_script():
|
||||
return passed, failed
|
||||
|
||||
|
||||
def test_pattern_matching():
    """Run matches_pattern over a table of command/pattern pairs.

    The table covers exact names, trailing-``*`` wildcards, local script
    paths, and pairs that are expected not to match. One PASS or FAIL line
    is printed per row.

    Returns:
        tuple: ``(passed, failed)`` counts over the table rows.
    """
    print("\nTesting pattern matching:\n")

    # Each row: (command, pattern, should_match, description)
    scenarios = (
        # Exact matches
        ("swift", "swift", True, "exact match"),
        ("npm", "npm", True, "exact npm"),
        ("xcodebuild", "xcodebuild", True, "exact xcodebuild"),
        # Prefix wildcards
        ("swiftc", "swift*", True, "swiftc matches swift*"),
        ("swiftlint", "swift*", True, "swiftlint matches swift*"),
        ("swiftformat", "swift*", True, "swiftformat matches swift*"),
        ("swift", "swift*", True, "swift matches swift*"),
        ("npm", "swift*", False, "npm doesn't match swift*"),
        # Local script paths
        ("build.sh", "./scripts/build.sh", True, "script name matches path"),
        ("./scripts/build.sh", "./scripts/build.sh", True, "exact script path"),
        ("scripts/build.sh", "./scripts/build.sh", True, "relative script path"),
        ("/abs/path/scripts/build.sh", "./scripts/build.sh", True, "absolute path matches"),
        ("test.sh", "./scripts/build.sh", False, "different script name"),
        # Non-matches
        ("go", "swift*", False, "go doesn't match swift*"),
        ("rustc", "swift*", False, "rustc doesn't match swift*"),
    )

    # Human-readable labels for the expected/actual columns of a FAIL report.
    labels = {True: "match", False: "no match"}

    ok_count = 0
    bad_count = 0
    for command, pattern, should_match, description in scenarios:
        outcome = matches_pattern(command, pattern)

        if outcome != should_match:
            print(f" FAIL: {command!r} vs {pattern!r} ({description})")
            print(
                f" Expected: {labels[bool(should_match)]}, "
                f"Got: {labels[bool(outcome)]}"
            )
            bad_count += 1
            continue

        print(f" PASS: {command!r} vs {pattern!r} ({description})")
        ok_count += 1

    return ok_count, bad_count
|
||||
|
||||
|
||||
def test_yaml_loading():
    """Check load_project_commands against valid, missing, broken and oversized configs.

    A throwaway project directory is created and its
    ``.autocoder/allowed_commands.yaml`` is rewritten for each scenario:

    1. a well-formed file with three commands is expected to load,
    2. a missing file is expected to yield ``None``,
    3. syntactically invalid YAML is expected to yield ``None``,
    4. a well-formed file with 51 commands is expected to yield ``None``.

    The YAML fixtures are written with explicit nesting indentation so that
    scenario 1 really is a list of mappings and scenario 4 is rejected for
    its size rather than for being malformed.

    Returns:
        tuple: ``(passed, failed)`` counts over the four scenarios.
    """
    print("\nTesting YAML loading:\n")

    def _report(ok, label, got):
        # Print one PASS/FAIL line (plus the loaded value on failure) and
        # hand the boolean back so the caller can tally it.
        if ok:
            print(f" PASS: {label}")
        else:
            print(f" FAIL: {label}")
            print(f" Got: {got}")
        return ok

    outcomes = []

    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir)
        autocoder_dir = project_dir / ".autocoder"
        autocoder_dir.mkdir()
        config_path = autocoder_dir / "allowed_commands.yaml"

        # Test 1: Valid YAML
        config_path.write_text(
            "version: 1\n"
            "commands:\n"
            "  - name: swift\n"
            "    description: Swift compiler\n"
            "  - name: xcodebuild\n"
            "    description: Xcode build\n"
            "  - name: swift*\n"
            "    description: All Swift tools\n"
        )
        config = load_project_commands(project_dir)
        loaded_ok = bool(
            config and config["version"] == 1 and len(config["commands"]) == 3
        )
        outcomes.append(_report(loaded_ok, "Load valid YAML", config))

        # Test 2: Missing file returns None
        config_path.unlink()
        config = load_project_commands(project_dir)
        outcomes.append(_report(config is None, "Missing file returns None", config))

        # Test 3: Invalid YAML returns None
        config_path.write_text("invalid: yaml: content:")
        config = load_project_commands(project_dir)
        outcomes.append(_report(config is None, "Invalid YAML returns None", config))

        # Test 4: Over limit (50 commands). The entries must be valid YAML,
        # otherwise the loader would reject the file for the wrong reason.
        entries = [
            f"  - name: cmd{i}\n    description: Command {i}" for i in range(51)
        ]
        config_path.write_text("version: 1\ncommands:\n" + "\n".join(entries))
        config = load_project_commands(project_dir)
        outcomes.append(_report(config is None, "Over limit rejected", config))

    passed = sum(outcomes)
    failed = len(outcomes) - passed
    return passed, failed
|
||||
|
||||
|
||||
def test_command_validation():
    """Run validate_project_command over a table of command config dicts.

    The table holds configs expected to be accepted (plain names, a
    wildcard, a local script path), configs expected to be rejected for a
    missing/empty/non-string ``name``, and configs naming blocklisted
    commands. One PASS or FAIL line is printed per row.

    Returns:
        tuple: ``(passed, failed)`` counts over the table rows.
    """
    print("\nTesting command validation:\n")

    # Each row: (cmd_config, should_be_valid, description)
    scenarios = (
        # Valid commands
        ({"name": "swift", "description": "Swift compiler"}, True, "valid command"),
        ({"name": "swift"}, True, "command without description"),
        ({"name": "swift*", "description": "All Swift tools"}, True, "pattern command"),
        ({"name": "./scripts/build.sh", "description": "Build script"}, True, "local script"),
        # Invalid commands
        ({}, False, "missing name"),
        ({"description": "No name"}, False, "missing name field"),
        ({"name": ""}, False, "empty name"),
        ({"name": 123}, False, "non-string name"),
        # Blocklisted commands
        ({"name": "sudo"}, False, "blocklisted sudo"),
        ({"name": "shutdown"}, False, "blocklisted shutdown"),
        ({"name": "dd"}, False, "blocklisted dd"),
    )

    # Labels for the expected/actual columns of a FAIL report.
    labels = {True: "valid", False: "invalid"}

    ok_count = 0
    bad_count = 0
    for cmd_config, should_be_valid, description in scenarios:
        valid, error = validate_project_command(cmd_config)

        if valid != should_be_valid:
            print(f" FAIL: {description}")
            print(
                f" Expected: {labels[bool(should_be_valid)]}, "
                f"Got: {labels[bool(valid)]}"
            )
            # The validator's message is only shown when it supplied one.
            if error:
                print(f" Error: {error}")
            bad_count += 1
            continue

        print(f" PASS: {description}")
        ok_count += 1

    return ok_count, bad_count
|
||||
|
||||
|
||||
def test_blocklist_enforcement():
    """Check that bash_security_hook blocks a fixed set of command lines.

    Each command line is wrapped in a ``Bash`` tool-call payload and passed
    to the hook with no context; the hook's result is expected to carry
    ``decision == "block"``.

    Returns:
        tuple: ``(passed, failed)`` counts over the command lines tried.
    """
    print("\nTesting blocklist enforcement:\n")

    # All blocklisted commands should be rejected
    blocked_invocations = (
        "sudo apt install",
        "shutdown now",
        "dd if=/dev/zero",
        "aws s3 ls",
    )

    ok_count = 0
    bad_count = 0
    for invocation in blocked_invocations:
        payload = {"tool_name": "Bash", "tool_input": {"command": invocation}}
        verdict = asyncio.run(bash_security_hook(payload))

        # Only the program name (first word) is shown in the report line.
        program = invocation.split()[0]
        if verdict.get("decision") != "block":
            print(f" FAIL: Should block {program}")
            bad_count += 1
        else:
            print(f" PASS: Blocked {program}")
            ok_count += 1

    return ok_count, bad_count
|
||||
|
||||
|
||||
def test_project_commands():
    """Check that bash_security_hook honours a project's allowed_commands.yaml.

    A temporary project is created whose ``.autocoder/allowed_commands.yaml``
    lists ``swift``, ``xcodebuild`` and the wildcard ``swift*``. The hook is
    then called with that project as ``context["project_dir"]`` for:

    1. ``swift --version``  - expected not to be blocked (exact entry),
    2. ``swiftlint``        - expected not to be blocked (wildcard entry),
    3. ``rustc``            - expected to be blocked (not listed).

    The YAML fixture is written with explicit nesting indentation so that it
    parses as a list of three command mappings.

    Returns:
        tuple: ``(passed, failed)`` counts over the three checks.
    """
    print("\nTesting project-specific commands:\n")
    passed = 0
    failed = 0

    def _run_hook(command_line, hook_context):
        # Wrap a shell command line in a Bash tool-call payload and run the hook.
        payload = {"tool_name": "Bash", "tool_input": {"command": command_line}}
        return asyncio.run(bash_security_hook(payload, context=hook_context))

    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir)
        autocoder_dir = project_dir / ".autocoder"
        autocoder_dir.mkdir()

        # Create a config with Swift commands
        config_path = autocoder_dir / "allowed_commands.yaml"
        config_path.write_text(
            "version: 1\n"
            "commands:\n"
            "  - name: swift\n"
            "    description: Swift compiler\n"
            "  - name: xcodebuild\n"
            "    description: Xcode build\n"
            "  - name: swift*\n"
            "    description: All Swift tools\n"
        )

        context = {"project_dir": str(project_dir)}

        # Test 1: Project command should be allowed
        result = _run_hook("swift --version", context)
        if result.get("decision") != "block":
            print(" PASS: Project command 'swift' allowed")
            passed += 1
        else:
            print(" FAIL: Project command 'swift' should be allowed")
            print(f" Reason: {result.get('reason')}")
            failed += 1

        # Test 2: Pattern match should work
        result = _run_hook("swiftlint", context)
        if result.get("decision") != "block":
            print(" PASS: Pattern 'swift*' matches 'swiftlint'")
            passed += 1
        else:
            print(" FAIL: Pattern 'swift*' should match 'swiftlint'")
            print(f" Reason: {result.get('reason')}")
            failed += 1

        # Test 3: Non-allowed command should be blocked
        result = _run_hook("rustc", context)
        if result.get("decision") == "block":
            print(" PASS: Non-allowed command 'rustc' blocked")
            passed += 1
        else:
            print(" FAIL: Non-allowed command 'rustc' should be blocked")
            failed += 1

    return passed, failed
|
||||
|
||||
|
||||
def test_org_config_loading():
|
||||
"""Test organization-level config loading."""
|
||||
print("\nTesting org config loading:\n")
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
# Save original org config path
|
||||
original_home = Path.home()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Temporarily override home directory for testing
|
||||
import os
|
||||
os.environ["HOME"] = tmpdir
|
||||
|
||||
org_dir = Path(tmpdir) / ".autocoder"
|
||||
org_dir.mkdir()
|
||||
org_config_path = org_dir / "config.yaml"
|
||||
|
||||
# Test 1: Valid org config
|
||||
org_config_path.write_text("""version: 1
|
||||
allowed_commands:
|
||||
- name: jq
|
||||
description: JSON processor
|
||||
blocked_commands:
|
||||
- aws
|
||||
- kubectl
|
||||
""")
|
||||
config = load_org_config()
|
||||
if config and config["version"] == 1:
|
||||
if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
|
||||
print(" PASS: Load valid org config")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Load valid org config (wrong counts)")
|
||||
failed += 1
|
||||
else:
|
||||
print(" FAIL: Load valid org config")
|
||||
print(f" Got: {config}")
|
||||
failed += 1
|
||||
|
||||
# Test 2: Missing file returns None
|
||||
org_config_path.unlink()
|
||||
config = load_org_config()
|
||||
if config is None:
|
||||
print(" PASS: Missing org config returns None")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Missing org config returns None")
|
||||
failed += 1
|
||||
|
||||
# Restore HOME
|
||||
os.environ["HOME"] = str(original_home)
|
||||
|
||||
return passed, failed
|
||||
|
||||
|
||||
def test_hierarchy_resolution():
|
||||
"""Test command hierarchy resolution."""
|
||||
print("\nTesting hierarchy resolution:\n")
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmphome:
|
||||
with tempfile.TemporaryDirectory() as tmpproject:
|
||||
# Setup fake home directory
|
||||
import os
|
||||
original_home = os.environ.get("HOME")
|
||||
os.environ["HOME"] = tmphome
|
||||
|
||||
org_dir = Path(tmphome) / ".autocoder"
|
||||
org_dir.mkdir()
|
||||
org_config_path = org_dir / "config.yaml"
|
||||
|
||||
# Create org config with allowed and blocked commands
|
||||
org_config_path.write_text("""version: 1
|
||||
allowed_commands:
|
||||
- name: jq
|
||||
description: JSON processor
|
||||
- name: python3
|
||||
description: Python interpreter
|
||||
blocked_commands:
|
||||
- terraform
|
||||
- kubectl
|
||||
""")
|
||||
|
||||
project_dir = Path(tmpproject)
|
||||
project_autocoder = project_dir / ".autocoder"
|
||||
project_autocoder.mkdir()
|
||||
project_config = project_autocoder / "allowed_commands.yaml"
|
||||
|
||||
# Create project config
|
||||
project_config.write_text("""version: 1
|
||||
commands:
|
||||
- name: swift
|
||||
description: Swift compiler
|
||||
""")
|
||||
|
||||
# Test 1: Org allowed commands are included
|
||||
allowed, blocked = get_effective_commands(project_dir)
|
||||
if "jq" in allowed and "python3" in allowed:
|
||||
print(" PASS: Org allowed commands included")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Org allowed commands included")
|
||||
print(f" jq in allowed: {'jq' in allowed}")
|
||||
print(f" python3 in allowed: {'python3' in allowed}")
|
||||
failed += 1
|
||||
|
||||
# Test 2: Org blocked commands are in blocklist
|
||||
if "terraform" in blocked and "kubectl" in blocked:
|
||||
print(" PASS: Org blocked commands in blocklist")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Org blocked commands in blocklist")
|
||||
failed += 1
|
||||
|
||||
# Test 3: Project commands are included
|
||||
if "swift" in allowed:
|
||||
print(" PASS: Project commands included")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Project commands included")
|
||||
failed += 1
|
||||
|
||||
# Test 4: Global commands are included
|
||||
if "npm" in allowed and "git" in allowed:
|
||||
print(" PASS: Global commands included")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Global commands included")
|
||||
failed += 1
|
||||
|
||||
# Test 5: Hardcoded blocklist cannot be overridden
|
||||
if "sudo" in blocked and "shutdown" in blocked:
|
||||
print(" PASS: Hardcoded blocklist enforced")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Hardcoded blocklist enforced")
|
||||
failed += 1
|
||||
|
||||
# Restore HOME
|
||||
if original_home:
|
||||
os.environ["HOME"] = original_home
|
||||
else:
|
||||
del os.environ["HOME"]
|
||||
|
||||
return passed, failed
|
||||
|
||||
|
||||
def test_org_blocklist_enforcement():
|
||||
"""Test that org-level blocked commands cannot be used."""
|
||||
print("\nTesting org blocklist enforcement:\n")
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmphome:
|
||||
with tempfile.TemporaryDirectory() as tmpproject:
|
||||
# Setup fake home directory
|
||||
import os
|
||||
original_home = os.environ.get("HOME")
|
||||
os.environ["HOME"] = tmphome
|
||||
|
||||
org_dir = Path(tmphome) / ".autocoder"
|
||||
org_dir.mkdir()
|
||||
org_config_path = org_dir / "config.yaml"
|
||||
|
||||
# Create org config that blocks terraform
|
||||
org_config_path.write_text("""version: 1
|
||||
blocked_commands:
|
||||
- terraform
|
||||
""")
|
||||
|
||||
project_dir = Path(tmpproject)
|
||||
project_autocoder = project_dir / ".autocoder"
|
||||
project_autocoder.mkdir()
|
||||
|
||||
# Try to use terraform (should be blocked)
|
||||
input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
|
||||
context = {"project_dir": str(project_dir)}
|
||||
result = asyncio.run(bash_security_hook(input_data, context=context))
|
||||
|
||||
if result.get("decision") == "block":
|
||||
print(" PASS: Org blocked command 'terraform' rejected")
|
||||
passed += 1
|
||||
else:
|
||||
print(" FAIL: Org blocked command 'terraform' should be rejected")
|
||||
failed += 1
|
||||
|
||||
# Restore HOME
|
||||
if original_home:
|
||||
os.environ["HOME"] = original_home
|
||||
else:
|
||||
del os.environ["HOME"]
|
||||
|
||||
return passed, failed
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 70)
|
||||
print(" SECURITY HOOK TESTS")
|
||||
@@ -174,6 +615,46 @@ def main():
|
||||
passed += init_passed
|
||||
failed += init_failed
|
||||
|
||||
# Test pattern matching (Phase 1)
|
||||
pattern_passed, pattern_failed = test_pattern_matching()
|
||||
passed += pattern_passed
|
||||
failed += pattern_failed
|
||||
|
||||
# Test YAML loading (Phase 1)
|
||||
yaml_passed, yaml_failed = test_yaml_loading()
|
||||
passed += yaml_passed
|
||||
failed += yaml_failed
|
||||
|
||||
# Test command validation (Phase 1)
|
||||
validation_passed, validation_failed = test_command_validation()
|
||||
passed += validation_passed
|
||||
failed += validation_failed
|
||||
|
||||
# Test blocklist enforcement (Phase 1)
|
||||
blocklist_passed, blocklist_failed = test_blocklist_enforcement()
|
||||
passed += blocklist_passed
|
||||
failed += blocklist_failed
|
||||
|
||||
# Test project commands (Phase 1)
|
||||
project_passed, project_failed = test_project_commands()
|
||||
passed += project_passed
|
||||
failed += project_failed
|
||||
|
||||
# Test org config loading (Phase 2)
|
||||
org_loading_passed, org_loading_failed = test_org_config_loading()
|
||||
passed += org_loading_passed
|
||||
failed += org_loading_failed
|
||||
|
||||
# Test hierarchy resolution (Phase 2)
|
||||
hierarchy_passed, hierarchy_failed = test_hierarchy_resolution()
|
||||
passed += hierarchy_passed
|
||||
failed += hierarchy_failed
|
||||
|
||||
# Test org blocklist enforcement (Phase 2)
|
||||
org_block_passed, org_block_failed = test_org_blocklist_enforcement()
|
||||
passed += org_block_passed
|
||||
failed += org_block_failed
|
||||
|
||||
# Commands that SHOULD be blocked
|
||||
print("\nCommands that should be BLOCKED:\n")
|
||||
dangerous = [
|
||||
|
||||
Reference in New Issue
Block a user