diff --git a/src/api/middleware/logging_middleware.py b/src/api/middleware/logging_middleware.py
index 305ecee..fde2eb2 100644
--- a/src/api/middleware/logging_middleware.py
+++ b/src/api/middleware/logging_middleware.py
@@ -4,141 +4,218 @@ Logging middleware for request/response debugging.
 import json
 import logging
 import time
-from typing import Callable
-from fastapi import Request, Response
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import StreamingResponse
+from typing import Callable, Awaitable
+
+from starlette.types import ASGIApp, Scope, Receive, Send, Message
+from starlette.requests import Request
+from starlette.datastructures import Headers
 
 from core.config import get_settings
 
 logger = logging.getLogger(__name__)
 
 
-class LoggingMiddleware(BaseHTTPMiddleware):
+class LoggingMiddleware:
     """
-    Middleware to log detailed request and response information.
+    Pure ASGI middleware to log detailed request and response information.
     Activated when LOG_REQUESTS or LOG_RESPONSES is enabled.
+
+    Uses pure ASGI interface to avoid compatibility issues with streaming responses.
     """
 
-    async def dispatch(self, request: Request, call_next: Callable) -> Response:
-        """Process request and log details."""
-        settings = get_settings()
+    def __init__(self, app: ASGIApp):
+        self.app = app
 
-        # Generate request ID for tracking
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        """Process ASGI request."""
+        if scope["type"] != "http":
+            # Only handle HTTP requests
+            await self.app(scope, receive, send)
+            return
+
+        settings = get_settings()
+        if not (settings.log_requests or settings.log_responses or settings.debug):
+            # Logging disabled, pass through
+            await self.app(scope, receive, send)
+            return
+
+        # Generate request ID
         request_id = f"req_{int(time.time() * 1000)}"
 
-        # Log request details
-        if settings.log_requests or settings.debug:
-            await self._log_request(request, request_id)
+        # Cache request body for logging
+        body_cache = None
+        body_received = False
 
-        # Record start time
+        async def receive_wrapper() -> Message:
+            """Wrap receive to cache request body."""
+            nonlocal body_cache, body_received
+            message = await receive()
+
+            if message["type"] == "http.request" and not body_received:
+                body_received = True
+                body_cache = message.get("body", b"")
+
+                # Log request after receiving body
+                if settings.log_requests or settings.debug:
+                    await self._log_request_with_body(scope, body_cache, request_id)
+
+            return message
+
+        # Track response
         start_time = time.time()
+        response_started = False
+        status_code = None
+        response_headers = None
+        is_streaming = False
+        response_body_chunks = []  # Accumulate response body chunks
 
-        # Process request
-        response = await call_next(request)
+        async def send_wrapper(message: Message) -> None:
+            """Wrap send to capture response details."""
+            nonlocal response_started, status_code, response_headers, is_streaming, response_body_chunks
 
-        # Calculate processing time
-        process_time = time.time() - start_time
+            if message["type"] == "http.response.start":
+                response_started = True
+                status_code = message["status"]
+                response_headers = Headers(raw=message["headers"])
 
-        # Log response details (only for non-streaming responses)
-        if (settings.log_responses or settings.debug) and not isinstance(response, StreamingResponse):
-            await self._log_response(response, request_id, process_time)
-        elif isinstance(response, StreamingResponse):
-            logger.debug(f"[{request_id}] Response: Streaming response (not logged)")
+                # Check if streaming by content-type
+                content_type = response_headers.get("content-type", "")
+                is_streaming = any(
+                    st in content_type
+                    for st in ["text/event-stream", "application/x-ndjson", "application/stream+json"]
+                )
 
-        return response
+                # Log response start (for streaming, initial log)
+                if (settings.log_responses or settings.debug) and is_streaming and not settings.log_streaming:
+                    # Only log headers if LOG_STREAMING is disabled
+                    process_time = time.time() - start_time
+                    logger.debug("=" * 80)
+                    logger.debug(f"[{request_id}] 📤 OUTGOING RESPONSE (Streaming)")
+                    logger.debug("=" * 80)
+                    logger.debug(f"Status Code: {status_code}")
+                    logger.debug(f"Processing Time: {process_time:.3f}s")
+                    logger.debug(f"Content-Type: {content_type}")
+                    logger.debug("Response Body: [Streaming - not logged (set LOG_STREAMING=true to enable)]")
+                    logger.debug("=" * 80)
 
-    async def _log_request(self, request: Request, request_id: str):
-        """Log detailed request information."""
+            elif message["type"] == "http.response.body":
+                # Accumulate body chunks if logging streaming content
+                if is_streaming and settings.log_streaming:
+                    body_chunk = message.get("body", b"")
+                    if body_chunk:
+                        response_body_chunks.append(body_chunk)
+
+                # Check if this is the last chunk
+                more_body = message.get("more_body", False)
+                if not more_body and is_streaming and settings.log_streaming and response_body_chunks:
+                    # Log complete streaming response
+                    process_time = time.time() - start_time
+                    full_body = b"".join(response_body_chunks)
+                    await self._log_streaming_response(
+                        request_id, status_code, response_headers, full_body, process_time
+                    )
+
+            await send(message)
+
+        # Call app with wrapped receive/send
+        await self.app(scope, receive_wrapper, send_wrapper)
+
+    async def _log_request_with_body(self, scope: Scope, body: bytes, request_id: str):
+        """Log detailed request information with cached body."""
         try:
+            # Extract request info from scope
+            method = scope.get("method", "")
+            path = scope.get("path", "")
+            query_string = scope.get("query_string", b"").decode()
+
             # Basic request info
             logger.debug("=" * 80)
             logger.debug(f"[{request_id}] 📨 INCOMING REQUEST")
             logger.debug("=" * 80)
-            logger.debug(f"Method: {request.method}")
-            logger.debug(f"URL: {request.url.path}")
-            logger.debug(f"Query Params: {dict(request.query_params)}")
+            logger.debug(f"Method: {method}")
+            logger.debug(f"URL: {path}")
+            if query_string:
+                logger.debug(f"Query String: {query_string}")
 
             # Headers (filter sensitive data)
-            headers = dict(request.headers)
+            headers = {}
+            for name, value in scope.get("headers", []):
+                name_str = name.decode("latin1")
+                value_str = value.decode("latin1")
+                headers[name_str] = value_str
+
             if "authorization" in headers:
                 headers["authorization"] = "Bearer ***"
             logger.debug(f"Headers: {json.dumps(headers, indent=2)}")
 
             # Request body
-            if request.method in ["POST", "PUT", "PATCH"]:
-                body = await self._read_body(request)
-                if body:
+            if method in ["POST", "PUT", "PATCH"] and body:
+                try:
+                    # Try to parse and pretty-print JSON
+                    body_json = json.loads(body.decode())
+                    logger.debug(f"Request Body:\n{json.dumps(body_json, indent=2, ensure_ascii=False)}")
+                except (json.JSONDecodeError, UnicodeDecodeError):
+                    # Log raw body if not JSON
                     try:
-                        # Try to parse and pretty-print JSON
-                        body_json = json.loads(body)
-                        logger.debug(f"Request Body:\n{json.dumps(body_json, indent=2, ensure_ascii=False)}")
-                    except json.JSONDecodeError:
-                        # Log raw body if not JSON
-                        logger.debug(f"Request Body (raw): {body[:1000]}...")  # Limit to 1000 chars
+                        logger.debug(f"Request Body (raw): {body.decode()[:1000]}...")
+                    except Exception:
+                        logger.debug(f"Request Body (binary): {len(body)} bytes")
 
             logger.debug("=" * 80)
 
         except Exception as e:
             logger.error(f"Error logging request: {e}")
 
-    async def _log_response(self, response: Response, request_id: str, process_time: float):
-        """Log detailed response information."""
+    async def _log_streaming_response(
+        self, request_id: str, status_code: int, headers: Headers, body: bytes, process_time: float
+    ):
+        """Log streaming response with full content."""
         try:
             logger.debug("=" * 80)
-            logger.debug(f"[{request_id}] 📤 OUTGOING RESPONSE")
+            logger.debug(f"[{request_id}] 📤 OUTGOING RESPONSE (Streaming - Complete)")
             logger.debug("=" * 80)
-            logger.debug(f"Status Code: {response.status_code}")
+            logger.debug(f"Status Code: {status_code}")
             logger.debug(f"Processing Time: {process_time:.3f}s")
+            logger.debug(f"Content-Type: {headers.get('content-type', 'N/A')}")
+            logger.debug(f"Total Size: {len(body)} bytes")
 
-            # Headers
-            headers = dict(response.headers)
-            logger.debug(f"Headers: {json.dumps(headers, indent=2)}")
+            # Parse and log SSE events
+            try:
+                body_str = body.decode("utf-8")
+                # Split by SSE event boundaries
+                events = [e.strip() for e in body_str.split("\n\n") if e.strip()]
+                logger.debug(f"SSE Events Count: {len(events)}")
 
-            # Response body
-            body = b""
-            async for chunk in response.body_iterator:
-                body += chunk
+                # Log first few events (limit to avoid huge logs)
+                max_events_to_log = 10
+                logger.debug("Response Body (SSE Events):")
+                for i, event in enumerate(events[:max_events_to_log]):
+                    if event.startswith("data: "):
+                        data_content = event[6:]  # Remove "data: " prefix
+                        if data_content == "[DONE]":
+                            logger.debug(f"  Event {i+1}: [DONE]")
+                        else:
+                            try:
+                                # Try to parse and pretty-print JSON
+                                event_json = json.loads(data_content)
+                                logger.debug(f"  Event {i+1}:")
+                                logger.debug(f"    {json.dumps(event_json, ensure_ascii=False)}")
+                            except json.JSONDecodeError:
+                                logger.debug(f"  Event {i+1}: {data_content[:200]}...")
+                    else:
+                        logger.debug(f"  Event {i+1}: {event[:200]}...")
 
-            # Recreate response body iterator
-            response.body_iterator = self._create_body_iterator(body)
+                if len(events) > max_events_to_log:
+                    logger.debug(f"  ... and {len(events) - max_events_to_log} more events")
 
-            if body:
-                try:
-                    # Try to parse and pretty-print JSON
-                    body_json = json.loads(body.decode())
-                    logger.debug(f"Response Body:\n{json.dumps(body_json, indent=2, ensure_ascii=False)}")
-                except (json.JSONDecodeError, UnicodeDecodeError):
-                    # Log raw body if not JSON
-                    logger.debug(f"Response Body (raw): {body[:1000]}...")  # Limit to 1000 chars
+            except Exception as e:
+                logger.debug(f"Response Body (raw): {body[:1000]}...")
+                logger.debug(f"(Could not parse SSE events: {e})")
 
             logger.debug("=" * 80)
 
         except Exception as e:
-            logger.error(f"Error logging response: {e}")
-
-    async def _read_body(self, request: Request) -> str:
-        """Read request body without consuming it."""
-        try:
-            body = await request.body()
-
-            # Create new receive callable to preserve body for downstream handlers
-            async def receive():
-                return {"type": "http.request", "body": body}
-
-            # Replace request's receive
-            request._receive = receive
-
-            return body.decode()
-        except Exception as e:
-            logger.error(f"Error reading request body: {e}")
-            return ""
-
-    def _create_body_iterator(self, body: bytes):
-        """Create an async iterator for response body."""
-        async def iterator():
-            yield body
-        return iterator()
+            logger.error(f"Error logging streaming response: {e}")
 
 
 def setup_logging_middleware(app):
@@ -148,9 +225,13 @@ def setup_logging_middleware(app):
     """
     settings = get_settings()
     if settings.log_requests or settings.log_responses or settings.debug:
+        # Add pure ASGI middleware
         app.add_middleware(LoggingMiddleware)
+
         logger.info("🔍 Request/Response logging middleware enabled")
         if settings.log_requests or settings.debug:
             logger.info("  - Request logging: ON")
         if settings.log_responses or settings.debug:
             logger.info("  - Response logging: ON")
+        if settings.log_streaming:
+            logger.info("  - Streaming content logging: ON (⚠️ increases memory usage)")
diff --git a/src/core/config.py b/src/core/config.py
index b5f6ad5..c9f9107 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -58,6 +58,7 @@ class Settings(BaseSettings):
     log_level: str = "INFO"
     log_requests: bool = False
     log_responses: bool = False
+    log_streaming: bool = False  # Log streaming response content (may increase memory usage)
     log_file: Optional[str] = None
     log_file_max_size: int = 10  # MB
     log_file_backup_count: int = 5