Files
autocoder/test_security.py
Marian Paul 996ac0065c fix: improve path matching and org config validation
Changes:
- Support path patterns without ./ prefix (e.g., 'scripts/test.sh')
- Reject non-string or empty command names in org config
- Add 8 new test cases (5 for path patterns, 3 for validation)

Details:
- matches_pattern() now treats any pattern with '/' as a path pattern
- load_org_config() validates that cmd['name'] is a non-empty string
- All 148 unit tests + 9 integration tests passing

Security hardening: Prevents invalid command names from reaching
pattern matching logic, reducing attack surface.
2026-01-22 15:35:00 +01:00

837 lines
27 KiB
Python

#!/usr/bin/env python3
"""
Security Hook Tests
===================
Tests for the bash command security validation logic.
Run with: python test_security.py
"""
import asyncio
import sys
import tempfile
from pathlib import Path
from security import (
bash_security_hook,
extract_commands,
get_effective_commands,
load_org_config,
load_project_commands,
matches_pattern,
validate_chmod_command,
validate_init_script,
validate_project_command,
)
def check_hook(command: str, should_block: bool) -> bool:
    """Run one command through the security hook and report the outcome.

    Args:
        command: Bash command line handed to the hook as the tool input.
        should_block: True when the hook is expected to block the command.

    Returns:
        True when the hook's decision matches the expectation, False otherwise.
    """
    payload = {"tool_name": "Bash", "tool_input": {"command": command}}
    result = asyncio.run(bash_security_hook(payload))
    was_blocked = result.get("decision") == "block"

    matched = was_blocked == should_block
    status = "PASS" if matched else "FAIL"
    print(f" {status}: {command!r}")
    if matched:
        return True

    # Mismatch: show what was expected and, if the hook gave one, its reason.
    expected = "blocked" if should_block else "allowed"
    actual = "blocked" if was_blocked else "allowed"
    print(f" Expected: {expected}, Got: {actual}")
    reason = result.get("reason", "")
    if reason:
        print(f" Reason: {reason}")
    return False
def test_extract_commands():
    """Check that extract_commands() yields the expected command names.

    Returns:
        A ``(passed, failed)`` tuple counting the table cases.
    """
    print("\nTesting command extraction:\n")

    # Each case: (command line, expected list of extracted command names)
    cases = [
        ("ls -la", ["ls"]),
        ("npm install && npm run build", ["npm", "npm"]),
        ("cat file.txt | grep pattern", ["cat", "grep"]),
        ("/usr/bin/node script.js", ["node"]),
        ("VAR=value ls", ["ls"]),
        ("git status || git init", ["git", "git"]),
    ]

    passed = 0
    failed = 0
    for command_line, expected in cases:
        actual = extract_commands(command_line)
        if actual != expected:
            print(f" FAIL: {command_line!r}")
            print(f" Expected: {expected}, Got: {actual}")
            failed += 1
            continue
        print(f" PASS: {command_line!r} -> {actual}")
        passed += 1
    return passed, failed
def test_validate_chmod():
    """Run validate_chmod_command() over a table of chmod invocations.

    Returns:
        A ``(passed, failed)`` tuple counting the table cases.
    """
    print("\nTesting chmod validation:\n")

    # Each case: (command, should_be_allowed, description)
    cases = [
        # Allowed cases
        ("chmod +x init.sh", True, "basic +x"),
        ("chmod +x script.sh", True, "+x on any script"),
        ("chmod u+x init.sh", True, "user +x"),
        ("chmod a+x init.sh", True, "all +x"),
        ("chmod ug+x init.sh", True, "user+group +x"),
        ("chmod +x file1.sh file2.sh", True, "multiple files"),
        # Blocked cases
        ("chmod 777 init.sh", False, "numeric mode"),
        ("chmod 755 init.sh", False, "numeric mode 755"),
        ("chmod +w init.sh", False, "write permission"),
        ("chmod +r init.sh", False, "read permission"),
        ("chmod -x init.sh", False, "remove execute"),
        ("chmod -R +x dir/", False, "recursive flag"),
        ("chmod --recursive +x dir/", False, "long recursive flag"),
        ("chmod +x", False, "missing file"),
    ]

    passed = 0
    failed = 0
    for command, should_allow, description in cases:
        allowed, reason = validate_chmod_command(command)
        if allowed != should_allow:
            expected = "allowed" if should_allow else "blocked"
            actual = "allowed" if allowed else "blocked"
            print(f" FAIL: {command!r} ({description})")
            print(f" Expected: {expected}, Got: {actual}")
            if reason:
                print(f" Reason: {reason}")
            failed += 1
            continue
        print(f" PASS: {command!r} ({description})")
        passed += 1
    return passed, failed
def test_validate_init_script():
    """Run validate_init_script() over allowed and blocked invocations.

    Returns:
        A ``(passed, failed)`` tuple counting the table cases.
    """
    print("\nTesting init.sh validation:\n")

    # Each case: (command, should_be_allowed, description)
    cases = [
        # Allowed cases
        ("./init.sh", True, "basic ./init.sh"),
        ("./init.sh arg1 arg2", True, "with arguments"),
        ("/path/to/init.sh", True, "absolute path"),
        ("../dir/init.sh", True, "relative path with init.sh"),
        # Blocked cases
        ("./setup.sh", False, "different script name"),
        ("./init.py", False, "python script"),
        ("bash init.sh", False, "bash invocation"),
        ("sh init.sh", False, "sh invocation"),
        ("./malicious.sh", False, "malicious script"),
        ("./init.sh; rm -rf /", False, "command injection attempt"),
    ]

    tally = {"passed": 0, "failed": 0}
    for command, should_allow, description in cases:
        allowed, reason = validate_init_script(command)
        if allowed == should_allow:
            print(f" PASS: {command!r} ({description})")
            tally["passed"] += 1
        else:
            expected = "allowed" if should_allow else "blocked"
            actual = "allowed" if allowed else "blocked"
            print(f" FAIL: {command!r} ({description})")
            print(f" Expected: {expected}, Got: {actual}")
            if reason:
                print(f" Reason: {reason}")
            tally["failed"] += 1
    return tally["passed"], tally["failed"]
def test_pattern_matching():
    """Run matches_pattern() over exact, wildcard and path-style patterns.

    Returns:
        A ``(passed, failed)`` tuple counting the table cases.
    """
    print("\nTesting pattern matching:\n")

    # Each case: (command, pattern, should_match, description)
    cases = [
        # Exact matches
        ("swift", "swift", True, "exact match"),
        ("npm", "npm", True, "exact npm"),
        ("xcodebuild", "xcodebuild", True, "exact xcodebuild"),
        # Prefix wildcards
        ("swiftc", "swift*", True, "swiftc matches swift*"),
        ("swiftlint", "swift*", True, "swiftlint matches swift*"),
        ("swiftformat", "swift*", True, "swiftformat matches swift*"),
        ("swift", "swift*", True, "swift matches swift*"),
        ("npm", "swift*", False, "npm doesn't match swift*"),
        # Bare wildcard (security: should NOT match anything)
        ("npm", "*", False, "bare wildcard doesn't match npm"),
        ("sudo", "*", False, "bare wildcard doesn't match sudo"),
        ("anything", "*", False, "bare wildcard doesn't match anything"),
        # Local script paths (with ./ prefix)
        ("build.sh", "./scripts/build.sh", True, "script name matches path"),
        ("./scripts/build.sh", "./scripts/build.sh", True, "exact script path"),
        ("scripts/build.sh", "./scripts/build.sh", True, "relative script path"),
        ("/abs/path/scripts/build.sh", "./scripts/build.sh", True, "absolute path matches"),
        ("test.sh", "./scripts/build.sh", False, "different script name"),
        # Path patterns (without ./ prefix - new behavior)
        ("test.sh", "scripts/test.sh", True, "script name matches path pattern"),
        ("scripts/test.sh", "scripts/test.sh", True, "exact path pattern match"),
        ("/abs/path/scripts/test.sh", "scripts/test.sh", True, "absolute path matches pattern"),
        ("build.sh", "scripts/test.sh", False, "different script name in pattern"),
        ("integration.test.js", "tests/integration.test.js", True, "script with dots matches"),
        # Non-matches
        ("go", "swift*", False, "go doesn't match swift*"),
        ("rustc", "swift*", False, "rustc doesn't match swift*"),
    ]

    passed = 0
    failed = 0
    for command, pattern, should_match, description in cases:
        matched = matches_pattern(command, pattern)
        if matched != should_match:
            expected = "match" if should_match else "no match"
            actual = "match" if matched else "no match"
            print(f" FAIL: {command!r} vs {pattern!r} ({description})")
            print(f" Expected: {expected}, Got: {actual}")
            failed += 1
            continue
        print(f" PASS: {command!r} vs {pattern!r} ({description})")
        passed += 1
    return passed, failed
def test_yaml_loading():
    """Exercise load_project_commands() against on-disk YAML fixtures.

    Four checks are run inside a throwaway project directory: a valid
    three-command config, a missing file, malformed YAML, and a config
    holding 101 commands (one over the limit named by the test).

    Returns:
        A ``(passed, failed)`` tuple counting the checks.
    """
    print("\nTesting YAML loading:\n")
    passed = 0
    failed = 0

    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir)
        autocoder_dir = project_dir / ".autocoder"
        autocoder_dir.mkdir()
        config_path = autocoder_dir / "allowed_commands.yaml"

        # Test 1: Valid YAML
        # The fixture is spelled out line by line so the nesting is explicit:
        # each ``description`` must sit inside its list item, otherwise the
        # document is not valid YAML at all.
        config_path.write_text(
            "version: 1\n"
            "commands:\n"
            "  - name: swift\n"
            "    description: Swift compiler\n"
            "  - name: xcodebuild\n"
            "    description: Xcode build\n"
            "  - name: swift*\n"
            "    description: All Swift tools\n"
        )
        config = load_project_commands(project_dir)
        if config and config["version"] == 1 and len(config["commands"]) == 3:
            print(" PASS: Load valid YAML")
            passed += 1
        else:
            print(" FAIL: Load valid YAML")
            print(f" Got: {config}")
            failed += 1

        # Test 2: Missing file returns None
        config_path.unlink()
        config = load_project_commands(project_dir)
        if config is None:
            print(" PASS: Missing file returns None")
            passed += 1
        else:
            print(" FAIL: Missing file returns None")
            print(f" Got: {config}")
            failed += 1

        # Test 3: Invalid YAML returns None
        config_path.write_text("invalid: yaml: content:")
        config = load_project_commands(project_dir)
        if config is None:
            print(" PASS: Invalid YAML returns None")
            passed += 1
        else:
            print(" FAIL: Invalid YAML returns None")
            print(f" Got: {config}")
            failed += 1

        # Test 4: Over limit (100 commands)
        # The entries must be well-formed YAML so that a rejection here is
        # caused by the command count, not by a parse failure.
        commands = [
            f"  - name: cmd{i}\n    description: Command {i}" for i in range(101)
        ]
        config_path.write_text("version: 1\ncommands:\n" + "\n".join(commands))
        config = load_project_commands(project_dir)
        if config is None:
            print(" PASS: Over limit rejected")
            passed += 1
        else:
            print(" FAIL: Over limit rejected")
            print(f" Got: {config}")
            failed += 1

    return passed, failed
def test_command_validation():
    """Run validate_project_command() over valid and invalid entries.

    Returns:
        A ``(passed, failed)`` tuple counting the table cases.
    """
    print("\nTesting command validation:\n")

    # Each case: (cmd_config, should_be_valid, description)
    cases = [
        # Valid commands
        ({"name": "swift", "description": "Swift compiler"}, True, "valid command"),
        ({"name": "swift"}, True, "command without description"),
        ({"name": "swift*", "description": "All Swift tools"}, True, "pattern command"),
        ({"name": "./scripts/build.sh", "description": "Build script"}, True, "local script"),
        # Invalid commands
        ({}, False, "missing name"),
        ({"description": "No name"}, False, "missing name field"),
        ({"name": ""}, False, "empty name"),
        ({"name": 123}, False, "non-string name"),
        # Security: Bare wildcard not allowed
        ({"name": "*"}, False, "bare wildcard rejected"),
        # Blocklisted commands
        ({"name": "sudo"}, False, "blocklisted sudo"),
        ({"name": "shutdown"}, False, "blocklisted shutdown"),
        ({"name": "dd"}, False, "blocklisted dd"),
    ]

    passed = 0
    failed = 0
    for entry, should_be_valid, description in cases:
        valid, error = validate_project_command(entry)
        if valid != should_be_valid:
            expected = "valid" if should_be_valid else "invalid"
            actual = "valid" if valid else "invalid"
            print(f" FAIL: {description}")
            print(f" Expected: {expected}, Got: {actual}")
            if error:
                print(f" Error: {error}")
            failed += 1
            continue
        print(f" PASS: {description}")
        passed += 1
    return passed, failed
def test_blocklist_enforcement():
    """Verify the security hook blocks a set of blocklisted commands.

    Returns:
        A ``(passed, failed)`` tuple counting the commands checked.
    """
    print("\nTesting blocklist enforcement:\n")

    # All blocklisted commands should be rejected
    blocklisted = ("sudo apt install", "shutdown now", "dd if=/dev/zero", "aws s3 ls")

    passed = 0
    failed = 0
    for command_line in blocklisted:
        payload = {"tool_name": "Bash", "tool_input": {"command": command_line}}
        result = asyncio.run(bash_security_hook(payload))
        program = command_line.split()[0]
        if result.get("decision") != "block":
            print(f" FAIL: Should block {program}")
            failed += 1
            continue
        print(f" PASS: Blocked {program}")
        passed += 1
    return passed, failed
def test_project_commands():
    """Check that project-level allowed commands are honoured by the hook.

    A temporary project with an ``allowed_commands.yaml`` listing Swift
    tooling is created, then three commands are sent through
    bash_security_hook() with that project as context: an exact name, a
    name covered by the ``swift*`` pattern, and a name not listed at all.

    Returns:
        A ``(passed, failed)`` tuple counting the checks.
    """
    print("\nTesting project-specific commands:\n")
    passed = 0
    failed = 0

    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir)
        autocoder_dir = project_dir / ".autocoder"
        autocoder_dir.mkdir()

        # Create a config with Swift commands. The fixture is written as
        # explicit lines so each ``description`` is nested inside its list
        # item; without that nesting the document is not valid YAML.
        config_path = autocoder_dir / "allowed_commands.yaml"
        config_path.write_text(
            "version: 1\n"
            "commands:\n"
            "  - name: swift\n"
            "    description: Swift compiler\n"
            "  - name: xcodebuild\n"
            "    description: Xcode build\n"
            "  - name: swift*\n"
            "    description: All Swift tools\n"
        )

        context = {"project_dir": str(project_dir)}

        def run_hook(command):
            """Send one command through the hook with the project context."""
            input_data = {"tool_name": "Bash", "tool_input": {"command": command}}
            return asyncio.run(bash_security_hook(input_data, context=context))

        # Test 1: Project command should be allowed
        result = run_hook("swift --version")
        if result.get("decision") != "block":
            print(" PASS: Project command 'swift' allowed")
            passed += 1
        else:
            print(" FAIL: Project command 'swift' should be allowed")
            print(f" Reason: {result.get('reason')}")
            failed += 1

        # Test 2: Pattern match should work
        result = run_hook("swiftlint")
        if result.get("decision") != "block":
            print(" PASS: Pattern 'swift*' matches 'swiftlint'")
            passed += 1
        else:
            print(" FAIL: Pattern 'swift*' should match 'swiftlint'")
            print(f" Reason: {result.get('reason')}")
            failed += 1

        # Test 3: Non-allowed command should be blocked
        result = run_hook("rustc")
        if result.get("decision") == "block":
            print(" PASS: Non-allowed command 'rustc' blocked")
            passed += 1
        else:
            print(" FAIL: Non-allowed command 'rustc' should be blocked")
            failed += 1

    return passed, failed
def test_org_config_loading():
    """Exercise load_org_config() against configs under a temporary HOME.

    HOME is pointed at a throwaway directory containing
    ``.autocoder/config.yaml``. Five checks are run: a valid config, a
    missing file, a numeric command name, an empty command name and a
    whitespace-only command name. HOME is restored afterwards even if a
    check raises.

    Returns:
        A ``(passed, failed)`` tuple counting the checks.
    """
    import os

    print("\nTesting org config loading:\n")
    passed = 0
    failed = 0

    # Remember the raw variable rather than Path.home(), so that an unset
    # HOME is restored to "unset" instead of being created.
    original_home = os.environ.get("HOME")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Temporarily override home directory for testing
        # NOTE(review): presumably load_org_config() locates the org config
        # through HOME - confirm against the security module.
        os.environ["HOME"] = tmpdir
        try:
            org_dir = Path(tmpdir) / ".autocoder"
            org_dir.mkdir()
            org_config_path = org_dir / "config.yaml"

            # Test 1: Valid org config
            org_config_path.write_text(
                "version: 1\n"
                "allowed_commands:\n"
                "  - name: jq\n"
                "    description: JSON processor\n"
                "blocked_commands:\n"
                "  - aws\n"
                "  - kubectl\n"
            )
            config = load_org_config()
            if config and config["version"] == 1:
                if len(config["allowed_commands"]) == 1 and len(config["blocked_commands"]) == 2:
                    print(" PASS: Load valid org config")
                    passed += 1
                else:
                    print(" FAIL: Load valid org config (wrong counts)")
                    failed += 1
            else:
                print(" FAIL: Load valid org config")
                print(f" Got: {config}")
                failed += 1

            # Test 2: Missing file returns None
            org_config_path.unlink()
            config = load_org_config()
            if config is None:
                print(" PASS: Missing org config returns None")
                passed += 1
            else:
                print(" FAIL: Missing org config returns None")
                failed += 1

            # Test 3: Non-string command name is rejected
            # (an unquoted 123 is parsed by YAML as an integer)
            org_config_path.write_text(
                "version: 1\n"
                "allowed_commands:\n"
                "  - name: 123\n"
                "    description: Invalid numeric name\n"
            )
            config = load_org_config()
            if config is None:
                print(" PASS: Non-string command name rejected")
                passed += 1
            else:
                print(" FAIL: Non-string command name rejected")
                print(f" Got: {config}")
                failed += 1

            # Test 4: Empty command name is rejected
            org_config_path.write_text(
                "version: 1\n"
                "allowed_commands:\n"
                '  - name: ""\n'
                "    description: Empty name\n"
            )
            config = load_org_config()
            if config is None:
                print(" PASS: Empty command name rejected")
                passed += 1
            else:
                print(" FAIL: Empty command name rejected")
                print(f" Got: {config}")
                failed += 1

            # Test 5: Whitespace-only command name is rejected
            org_config_path.write_text(
                "version: 1\n"
                "allowed_commands:\n"
                '  - name: " "\n'
                "    description: Whitespace name\n"
            )
            config = load_org_config()
            if config is None:
                print(" PASS: Whitespace-only command name rejected")
                passed += 1
            else:
                print(" FAIL: Whitespace-only command name rejected")
                print(f" Got: {config}")
                failed += 1
        finally:
            # Restore HOME, including the case where it was not set at all.
            if original_home is not None:
                os.environ["HOME"] = original_home
            else:
                os.environ.pop("HOME", None)

    return passed, failed
def test_hierarchy_resolution():
    """Check how get_effective_commands() merges the config layers.

    A temporary HOME supplies an org config (allowed and blocked commands)
    and a temporary project supplies a project config. The resulting
    ``allowed`` and ``blocked`` collections are then checked for org,
    project, global and hardcoded-blocklist entries. HOME is restored
    afterwards even if a check raises.

    Returns:
        A ``(passed, failed)`` tuple counting the checks.
    """
    import os

    print("\nTesting hierarchy resolution:\n")
    passed = 0
    failed = 0

    with tempfile.TemporaryDirectory() as tmphome, tempfile.TemporaryDirectory() as tmpproject:
        # Setup fake home directory
        original_home = os.environ.get("HOME")
        os.environ["HOME"] = tmphome
        try:
            org_dir = Path(tmphome) / ".autocoder"
            org_dir.mkdir()
            org_config_path = org_dir / "config.yaml"

            # Create org config with allowed and blocked commands. Each
            # ``description`` is nested inside its list item; at the same
            # column as the dash the document would not be valid YAML.
            org_config_path.write_text(
                "version: 1\n"
                "allowed_commands:\n"
                "  - name: jq\n"
                "    description: JSON processor\n"
                "  - name: python3\n"
                "    description: Python interpreter\n"
                "blocked_commands:\n"
                "  - terraform\n"
                "  - kubectl\n"
            )

            project_dir = Path(tmpproject)
            project_autocoder = project_dir / ".autocoder"
            project_autocoder.mkdir()
            project_config = project_autocoder / "allowed_commands.yaml"

            # Create project config
            project_config.write_text(
                "version: 1\n"
                "commands:\n"
                "  - name: swift\n"
                "    description: Swift compiler\n"
            )

            allowed, blocked = get_effective_commands(project_dir)

            # Test 1: Org allowed commands are included
            if "jq" in allowed and "python3" in allowed:
                print(" PASS: Org allowed commands included")
                passed += 1
            else:
                print(" FAIL: Org allowed commands included")
                print(f" jq in allowed: {'jq' in allowed}")
                print(f" python3 in allowed: {'python3' in allowed}")
                failed += 1

            # Test 2: Org blocked commands are in blocklist
            if "terraform" in blocked and "kubectl" in blocked:
                print(" PASS: Org blocked commands in blocklist")
                passed += 1
            else:
                print(" FAIL: Org blocked commands in blocklist")
                failed += 1

            # Test 3: Project commands are included
            if "swift" in allowed:
                print(" PASS: Project commands included")
                passed += 1
            else:
                print(" FAIL: Project commands included")
                failed += 1

            # Test 4: Global commands are included
            if "npm" in allowed and "git" in allowed:
                print(" PASS: Global commands included")
                passed += 1
            else:
                print(" FAIL: Global commands included")
                failed += 1

            # Test 5: Hardcoded blocklist cannot be overridden
            if "sudo" in blocked and "shutdown" in blocked:
                print(" PASS: Hardcoded blocklist enforced")
                passed += 1
            else:
                print(" FAIL: Hardcoded blocklist enforced")
                failed += 1
        finally:
            # Restore HOME; ``is not None`` keeps an empty-string HOME intact.
            if original_home is not None:
                os.environ["HOME"] = original_home
            else:
                os.environ.pop("HOME", None)

    return passed, failed
def test_org_blocklist_enforcement():
    """Check that an org-level blocked command is rejected by the hook.

    A temporary HOME supplies an org config that blocks ``terraform``; the
    command ``terraform apply`` is then sent through bash_security_hook()
    with an otherwise empty temporary project as context. HOME is restored
    afterwards even if the check raises.

    Returns:
        A ``(passed, failed)`` tuple counting the single check.
    """
    import os

    print("\nTesting org blocklist enforcement:\n")
    passed = 0
    failed = 0

    with tempfile.TemporaryDirectory() as tmphome, tempfile.TemporaryDirectory() as tmpproject:
        # Setup fake home directory
        original_home = os.environ.get("HOME")
        os.environ["HOME"] = tmphome
        try:
            org_dir = Path(tmphome) / ".autocoder"
            org_dir.mkdir()
            org_config_path = org_dir / "config.yaml"

            # Create org config that blocks terraform
            org_config_path.write_text(
                "version: 1\n"
                "blocked_commands:\n"
                "- terraform\n"
            )

            project_dir = Path(tmpproject)
            project_autocoder = project_dir / ".autocoder"
            project_autocoder.mkdir()

            # Try to use terraform (should be blocked)
            input_data = {"tool_name": "Bash", "tool_input": {"command": "terraform apply"}}
            context = {"project_dir": str(project_dir)}
            result = asyncio.run(bash_security_hook(input_data, context=context))
            if result.get("decision") == "block":
                print(" PASS: Org blocked command 'terraform' rejected")
                passed += 1
            else:
                print(" FAIL: Org blocked command 'terraform' should be rejected")
                failed += 1
        finally:
            # Restore HOME; ``is not None`` keeps an empty-string HOME intact.
            if original_home is not None:
                os.environ["HOME"] = original_home
            else:
                os.environ.pop("HOME", None)

    return passed, failed
def main():
    """Run every test suite plus the end-to-end hook checks and summarise.

    Returns:
        0 when no check failed, 1 otherwise (suitable as an exit status).
    """
    banner = "=" * 70
    print(banner)
    print(" SECURITY HOOK TESTS")
    print(banner)

    passed = 0
    failed = 0

    # Suites run in this order; each returns its own (passed, failed) pair.
    suites = (
        test_extract_commands,  # command extraction
        test_validate_chmod,  # chmod validation
        test_validate_init_script,  # init.sh validation
        test_pattern_matching,  # pattern matching (Phase 1)
        test_yaml_loading,  # YAML loading (Phase 1)
        test_command_validation,  # command validation (Phase 1)
        test_blocklist_enforcement,  # blocklist enforcement (Phase 1)
        test_project_commands,  # project commands (Phase 1)
        test_org_config_loading,  # org config loading (Phase 2)
        test_hierarchy_resolution,  # hierarchy resolution (Phase 2)
        test_org_blocklist_enforcement,  # org blocklist enforcement (Phase 2)
    )
    for suite in suites:
        suite_passed, suite_failed = suite()
        passed += suite_passed
        failed += suite_failed

    # Commands that SHOULD be blocked
    dangerous = [
        # Not in allowlist - dangerous system commands
        "shutdown now",
        "reboot",
        "dd if=/dev/zero of=/dev/sda",
        # Not in allowlist - common commands excluded from minimal set
        "wget https://example.com",
        "python app.py",
        "killall node",
        # pkill with non-dev processes
        "pkill bash",
        "pkill chrome",
        "pkill python",
        # Shell injection attempts
        "$(echo pkill) node",
        'eval "pkill node"',
        # chmod with disallowed modes
        "chmod 777 file.sh",
        "chmod 755 file.sh",
        "chmod +w file.sh",
        "chmod -R +x dir/",
        # Non-init.sh scripts
        "./setup.sh",
        "./malicious.sh",
    ]

    # Commands that SHOULD be allowed
    safe = [
        # File inspection
        "ls -la",
        "cat README.md",
        "head -100 file.txt",
        "tail -20 log.txt",
        "wc -l file.txt",
        "grep -r pattern src/",
        # File operations
        "cp file1.txt file2.txt",
        "mkdir newdir",
        "mkdir -p path/to/dir",
        "touch file.txt",
        "rm -rf temp/",
        "mv old.txt new.txt",
        # Directory
        "pwd",
        # Output
        "echo hello",
        # Node.js development
        "npm install",
        "npm run build",
        "node server.js",
        # Version control
        "git status",
        "git commit -m 'test'",
        "git add . && git commit -m 'msg'",
        # Process management
        "ps aux",
        "lsof -i :3000",
        "sleep 2",
        "kill 12345",
        # Allowed pkill patterns for dev servers
        "pkill node",
        "pkill npm",
        "pkill -f node",
        "pkill -f 'node server.js'",
        "pkill vite",
        # Network/API testing
        "curl https://example.com",
        # Shell scripts (bash/sh in allowlist)
        "bash script.sh",
        "sh script.sh",
        'bash -c "echo hello"',
        # Chained commands
        "npm install && npm run build",
        "ls | grep test",
        # Full paths
        "/usr/local/bin/node app.js",
        # chmod +x (allowed)
        "chmod +x init.sh",
        "chmod +x script.sh",
        "chmod u+x init.sh",
        "chmod a+x init.sh",
        # init.sh execution (allowed)
        "./init.sh",
        "./init.sh --production",
        "/path/to/init.sh",
        # Combined chmod and init.sh
        "chmod +x init.sh && ./init.sh",
    ]

    # Each group: (heading, commands, whether the hook should block them)
    hook_groups = (
        ("\nCommands that should be BLOCKED:\n", dangerous, True),
        ("\nCommands that should be ALLOWED:\n", safe, False),
    )
    for heading, commands, expect_block in hook_groups:
        print(heading)
        for cmd in commands:
            if check_hook(cmd, should_block=expect_block):
                passed += 1
            else:
                failed += 1

    # Summary
    rule = "-" * 70
    print("\n" + rule)
    print(f" Results: {passed} passed, {failed} failed")
    print(rule)

    if failed:
        print(f"\n {failed} TEST(S) FAILED")
        return 1
    print("\n ALL TESTS PASSED")
    return 0
# Script entry point: main() returns 0 when every check passed and 1
# otherwise, and that value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())