refactor: improve Vertex AI model conversion and add tests

- Rename compute_mode -> convert_model_for_vertex for clarity
- Move `import re` to module top-level (stdlib convention)
- Use greedy regex quantifier for more readable pattern matching
- Restore PEP 8 double blank line between top-level definitions
- Add test_client.py with 10 unit tests covering:
  - Vertex disabled (env unset, "0", empty)
  - Standard conversions (Opus, Sonnet, Haiku)
  - Edge cases (already-converted, non-Claude, no date suffix, empty)

Follow-up improvements from PR #129 review.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Auto
2026-01-30 10:20:03 +02:00
parent 813fcde18b
commit 79d02a1410
2 changed files with 110 additions and 4 deletions

View File

@@ -7,6 +7,7 @@ Functions for creating and configuring the Claude Agent SDK client.
import json
import os
import re
import shutil
import sys
from pathlib import Path
@@ -68,7 +69,7 @@ EXTRA_READ_PATHS_BLOCKLIST = {
".netrc",
}
def compute_mode(model: str) -> str:
def convert_model_for_vertex(model: str) -> str:
"""
Convert model name format for Vertex AI compatibility.
@@ -89,8 +90,7 @@ def compute_mode(model: str) -> str:
# Pattern: claude-{name}-{version}-{date} -> claude-{name}-{version}@{date}
# Example: claude-opus-4-5-20251101 -> claude-opus-4-5@20251101
# The date is always 8 digits at the end
import re
match = re.match(r'^(claude-[a-z0-9-]+?)-(\d{8})$', model)
match = re.match(r'^(claude-.+)-(\d{8})$', model)
if match:
base_name, date = match.groups()
return f"{base_name}@{date}"
@@ -208,6 +208,7 @@ def get_extra_read_paths() -> list[Path]:
return validated_paths
# Feature MCP tools for feature/test management
FEATURE_MCP_TOOLS = [
# Core feature operations
@@ -438,7 +439,7 @@ def create_client(
is_vertex = sdk_env.get("CLAUDE_CODE_USE_VERTEX") == "1"
is_alternative_api = bool(base_url) or is_vertex
is_ollama = "localhost:11434" in base_url or "127.0.0.1:11434" in base_url
model = compute_mode(model)
model = convert_model_for_vertex(model)
if sdk_env:
print(f" - API overrides: {', '.join(sdk_env.keys())}")
if is_vertex: