Source code for sqlspec.extensions.fastapi.extension

from typing import TYPE_CHECKING, Any, overload

from fastapi import FastAPI, Request

from sqlspec.base import SQLSpec
from sqlspec.extensions.fastapi.providers import DEPENDENCY_DEFAULTS
from sqlspec.extensions.fastapi.providers import provide_filters as _provide_filters
from sqlspec.extensions.starlette.extension import SQLSpecPlugin as _StarlettePlugin

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig
    from sqlspec.core import FilterTypes
    from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
    from sqlspec.extensions.fastapi.providers import DependencyDefaults, FilterConfig

    # Type aliases for static analysis: type checkers and IDEs see the real
    # driver adapter types in the dependency functions' return annotations.
    _AsyncSession = AsyncDriverAdapterBase
    _SyncSession = SyncDriverAdapterBase
    _Session = AsyncDriverAdapterBase | SyncDriverAdapterBase
else:
    # Runtime fallback: the driver types above are only imported under
    # TYPE_CHECKING, yet the dependency functions use these aliases as
    # unquoted return annotations. Binding them to Any keeps those annotations
    # resolvable at runtime (avoiding NameError), so FastAPI sees ``Any``.
    _AsyncSession = Any
    _SyncSession = Any
    _Session = Any

# Public API of this module.
__all__ = ("SQLSpecPlugin",)


class SQLSpecPlugin(_StarlettePlugin):
    """SQLSpec integration for FastAPI applications.

    Extends the Starlette integration with dependency injection helpers for
    FastAPI's ``Depends()`` system.
    """

    def __init__(self, sqlspec: SQLSpec, app: "FastAPI | None" = None) -> None:
        """Initialize SQLSpec FastAPI extension.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional FastAPI application to initialize immediately.
        """
        super().__init__(sqlspec, app)

    @staticmethod
    def _coerce_str_key(key: Any) -> "str | None":
        """Return ``key`` when it is a string key, otherwise ``None``.

        The public ``provide_session``/``provide_connection`` helpers also accept
        config types and config instances, but only for static type narrowing;
        at runtime anything that is not a ``str`` selects the default key.
        """
        return key if isinstance(key, str) else None

    def _extract_extension_settings(self, config: Any) -> "dict[str, Any]":
        """Extract FastAPI settings from ``config.extension_config["fastapi"]``.

        Args:
            config: Database configuration instance.

        Returns:
            Dictionary of FastAPI-specific settings with defaults applied.
        """
        # ``or {}`` also tolerates an explicit ``"fastapi": None`` entry.
        fastapi_config = config.extension_config.get("fastapi") or {}
        connection_key = fastapi_config.get("connection_key", "db_connection")
        pool_key = fastapi_config.get("pool_key", "db_pool")
        session_key = fastapi_config.get("session_key", "db_session")
        commit_mode = fastapi_config.get("commit_mode", "manual")

        # Non-pooling configs that keep the default pool key get a key derived
        # from the config object's id (presumably so several such configs do not
        # share one pool slot).
        if not config.supports_connection_pooling and pool_key == "db_pool":
            pool_key = f"_db_pool_{id(config)}"

        correlation_headers = fastapi_config.get("correlation_headers")
        if correlation_headers is not None:
            if isinstance(correlation_headers, str):
                # A single header name must become a 1-tuple; ``tuple("x-id")``
                # would split it into individual characters.
                correlation_headers = (correlation_headers,)
            else:
                correlation_headers = tuple(correlation_headers)

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": fastapi_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": fastapi_config.get("extra_rollback_statuses"),
            "disable_di": fastapi_config.get("disable_di", False),
            "enable_correlation_middleware": fastapi_config.get("enable_correlation_middleware", False),
            "correlation_header": fastapi_config.get("correlation_header", "x-request-id"),
            "correlation_headers": correlation_headers,
            "auto_trace_headers": fastapi_config.get("auto_trace_headers", True),
            "enable_sqlcommenter_middleware": fastapi_config.get("enable_sqlcommenter_middleware", True),
            "sqlcommenter_framework": fastapi_config.get("sqlcommenter_framework", "fastapi"),
        }

    @overload
    def provide_session(
        self, key: None = None
    ) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: str) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "type[AsyncDatabaseConfig]") -> "Callable[[Request], AsyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "type[SyncDatabaseConfig]") -> "Callable[[Request], SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "AsyncDatabaseConfig") -> "Callable[[Request], AsyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "SyncDatabaseConfig") -> "Callable[[Request], SyncDriverAdapterBase]": ...

    def provide_session(
        self,
        key: "str | type[AsyncDatabaseConfig | SyncDatabaseConfig] | AsyncDatabaseConfig | SyncDatabaseConfig | None" = None,
    ) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]":
        """Create dependency factory for session injection.

        Returns a callable that can be used with FastAPI's ``Depends()`` to inject
        a database session into route handlers. The callable delegates to
        ``self.get_session(request, key)``.

        Args:
            key: Optional session key (str), config type or config instance, or
                None. Config types/instances only narrow the static return type;
                at runtime they select the default session (key ``None``).

        Returns:
            Dependency callable for FastAPI ``Depends()``.
        """
        session_key = self._coerce_str_key(key)

        def dependency(request: Request) -> _Session:
            return self.get_session(request, session_key)  # type: ignore[no-any-return]

        return dependency

    def provide_async_session(self, key: "str | None" = None) -> "Callable[[Request], AsyncDriverAdapterBase]":
        """Create dependency factory for async session injection.

        Type-narrowed version of ``provide_session()`` that returns
        ``AsyncDriverAdapterBase``. Useful when using string keys and you know the
        config is async. No runtime check is performed.

        Args:
            key: Optional session key for multi-database configurations.

        Returns:
            Dependency callable that returns ``AsyncDriverAdapterBase``.
        """

        def dependency(request: Request) -> _AsyncSession:
            return self.get_session(request, key)  # type: ignore[no-any-return]

        return dependency

    def provide_sync_session(self, key: "str | None" = None) -> "Callable[[Request], SyncDriverAdapterBase]":
        """Create dependency factory for sync session injection.

        Type-narrowed version of ``provide_session()`` that returns
        ``SyncDriverAdapterBase``. Useful when using string keys and you know the
        config is sync. No runtime check is performed.

        Args:
            key: Optional session key for multi-database configurations.

        Returns:
            Dependency callable that returns ``SyncDriverAdapterBase``.
        """

        def dependency(request: Request) -> _SyncSession:
            return self.get_session(request, key)  # type: ignore[no-any-return]

        return dependency

    @overload
    def provide_connection(self, key: None = None) -> "Callable[[Request], Any]": ...

    @overload
    def provide_connection(self, key: str) -> "Callable[[Request], Any]": ...

    @overload
    def provide_connection(self, key: "type[AsyncDatabaseConfig]") -> "Callable[[Request], Any]": ...

    @overload
    def provide_connection(self, key: "type[SyncDatabaseConfig]") -> "Callable[[Request], Any]": ...

    @overload
    def provide_connection(self, key: "AsyncDatabaseConfig") -> "Callable[[Request], Any]": ...

    @overload
    def provide_connection(self, key: "SyncDatabaseConfig") -> "Callable[[Request], Any]": ...

    def provide_connection(
        self,
        key: "str | type[AsyncDatabaseConfig | SyncDatabaseConfig] | AsyncDatabaseConfig | SyncDatabaseConfig | None" = None,
    ) -> "Callable[[Request], Any]":
        """Create dependency factory for connection injection.

        Returns a callable that can be used with FastAPI's ``Depends()`` to inject
        a database connection into route handlers. The callable delegates to
        ``self.get_connection(request, key)``.

        Args:
            key: Optional session key (str), config type or config instance, or
                None. Config types/instances only narrow the static type; at
                runtime they select the default connection (key ``None``).

        Returns:
            Dependency callable for FastAPI ``Depends()``.
        """
        connection_key = self._coerce_str_key(key)

        def dependency(request: Request) -> Any:
            return self.get_connection(request, connection_key)

        return dependency

    def provide_async_connection(self, key: "str | None" = None) -> "Callable[[Request], Any]":
        """Create dependency factory for async connection injection.

        Type-narrowed version of ``provide_connection()`` for async connections.
        Useful when using string keys and you know the config is async.

        Args:
            key: Optional session key for multi-database configurations.

        Returns:
            Dependency callable for async connection.
        """

        def dependency(request: Request) -> Any:
            return self.get_connection(request, key)

        return dependency

    def provide_sync_connection(self, key: "str | None" = None) -> "Callable[[Request], Any]":
        """Create dependency factory for sync connection injection.

        Type-narrowed version of ``provide_connection()`` for sync connections.
        Useful when using string keys and you know the config is sync.

        Args:
            key: Optional session key for multi-database configurations.

        Returns:
            Dependency callable for sync connection.
        """

        def dependency(request: Request) -> Any:
            return self.get_connection(request, key)

        return dependency

    @staticmethod
    def provide_filters(
        config: "FilterConfig", dep_defaults: "DependencyDefaults | None" = None
    ) -> "Callable[..., list[FilterTypes]]":
        """Create filter dependency for FastAPI routes.

        Dynamically generates a FastAPI dependency function that parses query
        parameters into SQLSpec filter objects. The returned callable can be used
        with FastAPI's ``Depends()`` for automatic filter injection.

        Args:
            config: Filter configuration specifying which filters to enable.
            dep_defaults: Optional dependency defaults for customization; falls
                back to ``DEPENDENCY_DEFAULTS`` when None.

        Returns:
            Callable for use with ``Depends()`` that returns list of filters.
        """
        if dep_defaults is None:
            dep_defaults = DEPENDENCY_DEFAULTS
        return _provide_filters(config, dep_defaults=dep_defaults)