diff --git a/requirements-prod.txt b/requirements-prod.txt index 1b2f0b2..12d5d32 100644 --- a/requirements-prod.txt +++ b/requirements-prod.txt @@ -12,3 +12,7 @@ aiofiles>=24.0.0 apscheduler>=3.10.0,<4.0.0 pywinpty>=2.0.0; sys_platform == "win32" pyyaml>=6.0.0 +python-docx>=1.1.0 +openpyxl>=3.1.0 +PyPDF2>=3.0.0 +python-pptx>=1.0.0 diff --git a/requirements.txt b/requirements.txt index f042b4d..23c9afd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,10 @@ aiofiles>=24.0.0 apscheduler>=3.10.0,<4.0.0 pywinpty>=2.0.0; sys_platform == "win32" pyyaml>=6.0.0 +python-docx>=1.1.0 +openpyxl>=3.1.0 +PyPDF2>=3.0.0 +python-pptx>=1.0.0 # Dev dependencies ruff>=0.8.0 diff --git a/server/routers/expand_project.py b/server/routers/expand_project.py index d680b95..af96161 100644 --- a/server/routers/expand_project.py +++ b/server/routers/expand_project.py @@ -13,7 +13,7 @@ from typing import Optional from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect from pydantic import BaseModel, ValidationError -from ..schemas import ImageAttachment +from ..schemas import FileAttachment from ..services.expand_chat_session import ( ExpandChatSession, create_expand_session, @@ -181,12 +181,12 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str): user_content = message.get("content", "").strip() # Parse attachments if present - attachments: list[ImageAttachment] = [] + attachments: list[FileAttachment] = [] raw_attachments = message.get("attachments", []) if raw_attachments: try: for raw_att in raw_attachments: - attachments.append(ImageAttachment(**raw_att)) + attachments.append(FileAttachment(**raw_att)) except (ValidationError, Exception) as e: logger.warning(f"Invalid attachment data: {e}") await websocket.send_json({ diff --git a/server/routers/spec_creation.py b/server/routers/spec_creation.py index 44b8d04..a333f35 100644 --- a/server/routers/spec_creation.py +++ b/server/routers/spec_creation.py @@ -12,7 +12,7 @@ from typing import Optional from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect from pydantic import BaseModel, ValidationError -from ..schemas import ImageAttachment +from ..schemas import FileAttachment from ..services.spec_chat_session import ( SpecChatSession, create_session, @@ -242,12 +242,12 @@ async def spec_chat_websocket(websocket: WebSocket, project_name: str): user_content = message.get("content", "").strip() # Parse attachments if present - attachments: list[ImageAttachment] = [] + attachments: list[FileAttachment] = [] raw_attachments = message.get("attachments", []) if raw_attachments: try: for raw_att in raw_attachments: - attachments.append(ImageAttachment(**raw_att)) + attachments.append(FileAttachment(**raw_att)) except (ValidationError, Exception) as e: logger.warning(f"Invalid attachment data: {e}") await websocket.send_json({ diff --git a/server/schemas.py b/server/schemas.py index 72124a5..abe5bbc 100644 --- a/server/schemas.py +++ b/server/schemas.py @@ -11,7 +11,7 @@ from datetime import datetime from pathlib import Path from typing import Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator # Import model constants from registry (single source of truth) _root = Path(__file__).parent.parent @@ -331,36 +331,61 @@ class WSAgentUpdateMessage(BaseModel): # ============================================================================ -# Spec Chat Schemas +# Chat Attachment Schemas # ============================================================================ -# Maximum image file size: 5 MB -MAX_IMAGE_SIZE = 5 * 1024 * 1024 +# Size limits +MAX_IMAGE_SIZE = 5 * 1024 * 1024 # 5 MB for images +MAX_DOCUMENT_SIZE = 20 * 1024 * 1024 # 20 MB for documents + +_IMAGE_MIME_TYPES = {'image/jpeg', 'image/png'} -class ImageAttachment(BaseModel): - """Image attachment from client for spec creation chat.""" +class FileAttachment(BaseModel): + """File attachment from client for spec creation / expand project chat.""" filename: str = Field(..., min_length=1, max_length=255) - mimeType: Literal['image/jpeg', 'image/png'] + mimeType: Literal[ + 'image/jpeg', 'image/png', + 'text/plain', 'text/markdown', 'text/csv', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'application/pdf', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + ] base64Data: str @field_validator('base64Data') @classmethod - def validate_base64_and_size(cls, v: str) -> str: - """Validate that base64 data is valid and within size limit.""" + def validate_base64(cls, v: str) -> str: + """Validate that base64 data is decodable.""" try: - decoded = base64.b64decode(v) - if len(decoded) > MAX_IMAGE_SIZE: - raise ValueError( - f'Image size ({len(decoded) / (1024 * 1024):.1f} MB) exceeds ' - f'maximum of {MAX_IMAGE_SIZE // (1024 * 1024)} MB' - ) + base64.b64decode(v) return v except Exception as e: - if 'Image size' in str(e): - raise raise ValueError(f'Invalid base64 data: {e}') + @model_validator(mode='after') + def validate_size(self) -> 'FileAttachment': + """Validate file size based on MIME type.""" + try: + decoded = base64.b64decode(self.base64Data) + except Exception: + return self # Already caught by field validator + + if self.mimeType in _IMAGE_MIME_TYPES: + max_size = MAX_IMAGE_SIZE + label = "Image" + else: + max_size = MAX_DOCUMENT_SIZE + label = "Document" + + if len(decoded) > max_size: + raise ValueError( + f'{label} size ({len(decoded) / (1024 * 1024):.1f} MB) exceeds ' + f'maximum of {max_size // (1024 * 1024)} MB' + ) + return self + # ============================================================================ # Filesystem Schemas diff --git a/server/services/chat_constants.py b/server/services/chat_constants.py index 16a41fd..2e832e2 100644 --- a/server/services/chat_constants.py +++ b/server/services/chat_constants.py @@ -35,6 +35,13 @@ if _root_str not in sys.path: from env_constants import API_ENV_VARS # noqa: E402, F401 from rate_limit_utils import is_rate_limit_error, parse_retry_after # noqa: E402, F401 +from ..schemas import FileAttachment +from ..utils.document_extraction import ( + extract_text_from_document, + is_document, + is_image, +) + logger = logging.getLogger(__name__) @@ -88,6 +95,35 @@ async def safe_receive_response(client: Any, log: logging.Logger) -> AsyncGenera raise +def build_attachment_content_blocks(attachments: list[FileAttachment]) -> list[dict]: + """Convert FileAttachment objects to Claude API content blocks. + + Images become image content blocks (passed directly to Claude's vision). + Documents are extracted to text and become text content blocks. + + Raises: + DocumentExtractionError: If a document cannot be read. + """ + blocks: list[dict] = [] + for att in attachments: + if is_image(att.mimeType): + blocks.append({ + "type": "image", + "source": { + "type": "base64", + "media_type": att.mimeType, + "data": att.base64Data, + } + }) + elif is_document(att.mimeType): + text = extract_text_from_document(att.base64Data, att.mimeType, att.filename) + blocks.append({ + "type": "text", + "text": f"[Content of uploaded file: {att.filename}]\n\n{text}", + }) + return blocks + + async def make_multimodal_message(content_blocks: list[dict]) -> AsyncGenerator[dict, None]: """Yield a single multimodal user message in Claude Agent SDK format. diff --git a/server/services/expand_chat_session.py b/server/services/expand_chat_session.py index 35a2f5c..00a0926 100644 --- a/server/services/expand_chat_session.py +++ b/server/services/expand_chat_session.py @@ -21,9 +21,11 @@ from typing import Any, AsyncGenerator, Optional from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient from dotenv import load_dotenv -from ..schemas import ImageAttachment +from ..schemas import FileAttachment +from ..utils.document_extraction import DocumentExtractionError from .chat_constants import ( ROOT_DIR, + build_attachment_content_blocks, check_rate_limit_error, make_multimodal_message, safe_receive_response, @@ -226,7 +228,7 @@ class ExpandChatSession: async def send_message( self, user_message: str, - attachments: list[ImageAttachment] | None = None + attachments: list[FileAttachment] | None = None ) -> AsyncGenerator[dict, None]: """ Send user message and stream Claude's response. @@ -273,7 +275,7 @@ class ExpandChatSession: async def _query_claude( self, message: str, - attachments: list[ImageAttachment] | None = None + attachments: list[FileAttachment] | None = None ) -> AsyncGenerator[dict, None]: """ Internal method to query Claude and stream responses. @@ -289,17 +291,16 @@ class ExpandChatSession: content_blocks: list[dict[str, Any]] = [] if message: content_blocks.append({"type": "text", "text": message}) - for att in attachments: - content_blocks.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": att.mimeType, - "data": att.base64Data, - } - }) + + # Add attachment blocks (images as image blocks, documents as extracted text) + try: + content_blocks.extend(build_attachment_content_blocks(attachments)) + except DocumentExtractionError as e: + yield {"type": "error", "content": str(e)} + return + await self.client.query(make_multimodal_message(content_blocks)) - logger.info(f"Sent multimodal message with {len(attachments)} image(s)") + logger.info(f"Sent multimodal message with {len(attachments)} attachment(s)") else: await self.client.query(message) diff --git a/server/services/spec_chat_session.py b/server/services/spec_chat_session.py index a6b5598..9e853c3 100644 --- a/server/services/spec_chat_session.py +++ b/server/services/spec_chat_session.py @@ -18,9 +18,11 @@ from typing import Any, AsyncGenerator, Optional from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient from dotenv import load_dotenv -from ..schemas import ImageAttachment +from ..schemas import FileAttachment +from ..utils.document_extraction import DocumentExtractionError from .chat_constants import ( ROOT_DIR, + build_attachment_content_blocks, check_rate_limit_error, make_multimodal_message, safe_receive_response, @@ -201,7 +203,7 @@ class SpecChatSession: async def send_message( self, user_message: str, - attachments: list[ImageAttachment] | None = None + attachments: list[FileAttachment] | None = None ) -> AsyncGenerator[dict, None]: """ Send user message and stream Claude's response. @@ -247,7 +249,7 @@ class SpecChatSession: async def _query_claude( self, message: str, - attachments: list[ImageAttachment] | None = None + attachments: list[FileAttachment] | None = None ) -> AsyncGenerator[dict, None]: """ Internal method to query Claude and stream responses. @@ -273,21 +275,17 @@ class SpecChatSession: if message: content_blocks.append({"type": "text", "text": message}) - # Add image blocks - for att in attachments: - content_blocks.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": att.mimeType, - "data": att.base64Data, - } - }) + # Add attachment blocks (images as image blocks, documents as extracted text) + try: + content_blocks.extend(build_attachment_content_blocks(attachments)) + except DocumentExtractionError as e: + yield {"type": "error", "content": str(e)} + return # Send multimodal content to Claude using async generator format # The SDK's query() accepts AsyncIterable[dict] for custom message formats await self.client.query(make_multimodal_message(content_blocks)) - logger.info(f"Sent multimodal message with {len(attachments)} image(s)") + logger.info(f"Sent multimodal message with {len(attachments)} attachment(s)") else: # Text-only message: use string format await self.client.query(message) diff --git a/server/utils/document_extraction.py b/server/utils/document_extraction.py new file mode 100644 index 0000000..b0c13d8 --- /dev/null +++ b/server/utils/document_extraction.py @@ -0,0 +1,221 @@ +""" +Document Extraction Utility +============================ + +Extracts text content from various document formats in memory (no disk I/O). +Supports: TXT, MD, CSV, DOCX, XLSX, PDF, PPTX. +""" + +import base64 +import csv +import io +import logging + +logger = logging.getLogger(__name__) + +# Maximum characters of extracted text to send to Claude +MAX_EXTRACTED_CHARS = 200_000 + +# Maximum rows per sheet for Excel files +MAX_EXCEL_ROWS_PER_SHEET = 10_000 +MAX_EXCEL_SHEETS = 50 + +# MIME type classification +DOCUMENT_MIME_TYPES: dict[str, str] = { + "text/plain": ".txt", + "text/markdown": ".md", + "text/csv": ".csv", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", +} + +IMAGE_MIME_TYPES = {"image/jpeg", "image/png"} + +ALL_ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES | set(DOCUMENT_MIME_TYPES.keys()) + + +def is_image(mime_type: str) -> bool: + """Check if the MIME type is a supported image format.""" + return mime_type in IMAGE_MIME_TYPES + + +def is_document(mime_type: str) -> bool: + """Check if the MIME type is a supported document format.""" + return mime_type in DOCUMENT_MIME_TYPES + + +class DocumentExtractionError(Exception): + """Raised when text extraction from a document fails.""" + + def __init__(self, filename: str, reason: str): + self.filename = filename + self.reason = reason + super().__init__(f"Failed to read {filename}: {reason}") + + +def _truncate(text: str) -> str: + """Truncate text if it exceeds the maximum character limit.""" + if len(text) > MAX_EXTRACTED_CHARS: + omitted = len(text) - MAX_EXTRACTED_CHARS + return text[:MAX_EXTRACTED_CHARS] + f"\n\n[... truncated, {omitted:,} characters omitted]" + return text + + +def _extract_plain_text(data: bytes) -> str: + """Extract text from plain text or markdown files.""" + try: + return data.decode("utf-8") + except UnicodeDecodeError: + return data.decode("latin-1") + + +def _extract_csv(data: bytes) -> str: + """Extract text from CSV files, formatted as a readable table.""" + try: + text = data.decode("utf-8") + except UnicodeDecodeError: + text = data.decode("latin-1") + + reader = csv.reader(io.StringIO(text)) + lines = [] + for i, row in enumerate(reader): + lines.append(f"Row {i + 1}: {', '.join(row)}") + return "\n".join(lines) + + +def _extract_docx(data: bytes) -> str: + """Extract text from Word documents.""" + from docx import Document + + doc = Document(io.BytesIO(data)) + paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] + return "\n\n".join(paragraphs) + + +def _extract_xlsx(data: bytes) -> str: + """Extract text from Excel spreadsheets.""" + from openpyxl import load_workbook + + wb = load_workbook(io.BytesIO(data), read_only=True, data_only=True) + sections = [] + + for sheet_idx, sheet_name in enumerate(wb.sheetnames): + if sheet_idx >= MAX_EXCEL_SHEETS: + sections.append(f"\n[... {len(wb.sheetnames) - MAX_EXCEL_SHEETS} more sheets omitted]") + break + + ws = wb[sheet_name] + rows_text = [f"=== Sheet: {sheet_name} ==="] + row_count = 0 + + for row in ws.iter_rows(values_only=True): + if row_count >= MAX_EXCEL_ROWS_PER_SHEET: + rows_text.append(f"[... more rows omitted, limit {MAX_EXCEL_ROWS_PER_SHEET:,} rows/sheet]") + break + cells = [str(cell) if cell is not None else "" for cell in row] + rows_text.append("\t".join(cells)) + row_count += 1 + + sections.append("\n".join(rows_text)) + + wb.close() + return "\n\n".join(sections) + + +def _extract_pdf(data: bytes, filename: str) -> str: + """Extract text from PDF files.""" + from PyPDF2 import PdfReader + from PyPDF2.errors import PdfReadError + + try: + reader = PdfReader(io.BytesIO(data)) + except PdfReadError as e: + if "encrypt" in str(e).lower() or "password" in str(e).lower(): + raise DocumentExtractionError(filename, "PDF is password-protected") + raise + + if reader.is_encrypted: + raise DocumentExtractionError(filename, "PDF is password-protected") + + pages = [] + for i, page in enumerate(reader.pages): + text = page.extract_text() + if text and text.strip(): + pages.append(f"--- Page {i + 1} ---\n{text}") + + return "\n\n".join(pages) + + +def _extract_pptx(data: bytes) -> str: + """Extract text from PowerPoint presentations.""" + from pptx import Presentation + + prs = Presentation(io.BytesIO(data)) + slides_text = [] + + for i, slide in enumerate(prs.slides): + texts = [] + for shape in slide.shapes: + if shape.has_text_frame: + for paragraph in shape.text_frame.paragraphs: + text = paragraph.text.strip() + if text: + texts.append(text) + if texts: + slides_text.append(f"--- Slide {i + 1} ---\n" + "\n".join(texts)) + + return "\n\n".join(slides_text) + + +def extract_text_from_document(base64_data: str, mime_type: str, filename: str) -> str: + """ + Extract text content from a document file. + + Args: + base64_data: Base64-encoded file content + mime_type: MIME type of the document + filename: Original filename (for error messages) + + Returns: + Extracted text content, truncated if necessary + + Raises: + DocumentExtractionError: If extraction fails + """ + if mime_type not in DOCUMENT_MIME_TYPES: + raise DocumentExtractionError(filename, f"unsupported document type: {mime_type}") + + try: + data = base64.b64decode(base64_data) + except Exception as e: + raise DocumentExtractionError(filename, f"invalid base64 data: {e}") + + try: + if mime_type in ("text/plain", "text/markdown"): + text = _extract_plain_text(data) + elif mime_type == "text/csv": + text = _extract_csv(data) + elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + text = _extract_docx(data) + elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + text = _extract_xlsx(data) + elif mime_type == "application/pdf": + text = _extract_pdf(data, filename) + elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": + text = _extract_pptx(data) + else: + raise DocumentExtractionError(filename, f"unsupported document type: {mime_type}") + except DocumentExtractionError: + raise + except Exception as e: + logger.warning(f"Document extraction failed for {filename}: {e}") + raise DocumentExtractionError( + filename, "file appears to be corrupt or in an unexpected format" + ) + + if not text or not text.strip(): + return f"[File {filename} is empty or contains no extractable text]" + + return _truncate(text) diff --git a/ui/src/components/ChatMessage.tsx b/ui/src/components/ChatMessage.tsx index fe87407..eff6f5a 100644 --- a/ui/src/components/ChatMessage.tsx +++ b/ui/src/components/ChatMessage.tsx @@ -6,10 +6,11 @@ */ import { memo } from 'react' -import { Bot, User, Info } from 'lucide-react' +import { Bot, User, Info, FileText } from 'lucide-react' import ReactMarkdown, { type Components } from 'react-markdown' import remarkGfm from 'remark-gfm' import type { ChatMessage as ChatMessageType } from '../lib/types' +import { isImageAttachment } from '../lib/types' import { Card } from '@/components/ui/card' interface ChatMessageProps { @@ -104,21 +105,35 @@ export const ChatMessage = memo(function ChatMessage({ message }: ChatMessagePro )} - {/* Display image attachments */} + {/* Display file attachments */} {attachments && attachments.length > 0 && (
- Press Enter to send. Drag & drop or click
- Press Enter to send, Shift+Enter for new line. Drag & drop or click