diff --git a/python/sqlcommenter-python/google/cloud/sqlcommenter/fastapi.py b/python/sqlcommenter-python/google/cloud/sqlcommenter/fastapi.py index 3013492a..0fcc5f36 100644 --- a/python/sqlcommenter-python/google/cloud/sqlcommenter/fastapi.py +++ b/python/sqlcommenter-python/google/cloud/sqlcommenter/fastapi.py @@ -95,5 +95,5 @@ def _get_fastapi_route(fastapi_app: FastAPI, scope) -> Optional[Route]: # and return the route name if found. match, child_scope = route.matches(scope) if match == Match.FULL: - return child_scope["route"] + return child_scope.get("route") return None diff --git a/python/sqlcommenter-python/tests/fastapi/app.py b/python/sqlcommenter-python/tests/fastapi/app.py index 24dcc8ce..6bb88625 100644 --- a/python/sqlcommenter-python/tests/fastapi/app.py +++ b/python/sqlcommenter-python/tests/fastapi/app.py @@ -5,7 +5,9 @@ from google.cloud.sqlcommenter.fastapi import ( SQLCommenterMiddleware, get_fastapi_info, ) +from starlette.applications import Starlette from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.routing import Route app = FastAPI(title="SQLCommenter") @@ -28,3 +30,15 @@ async def custom_http_exception_handler(request, exc): status_code=status.HTTP_404_NOT_FOUND, content=get_fastapi_info(), ) + + +def starlette_endpoint(_): + return JSONResponse({"from": "starlette"}) + + +starlette_subapi = Starlette(routes=[ + Route("/", starlette_endpoint), +]) + + +app.mount("/starlette", starlette_subapi) diff --git a/python/sqlcommenter-python/tests/fastapi/tests.py b/python/sqlcommenter-python/tests/fastapi/tests.py index d2edb9f5..2fb84945 100644 --- a/python/sqlcommenter-python/tests/fastapi/tests.py +++ b/python/sqlcommenter-python/tests/fastapi/tests.py @@ -52,3 +52,13 @@ def test_get_fastapi_info_in_404_error_context(client): def test_get_fastapi_info_outside_request_context(client): assert get_fastapi_info() == {} + + +def test_get_openapi_does_not_throw_an_error(client): + resp = client.get(app.docs_url) + assert resp.status_code == 200 + + +def test_get_starlette_endpoints_does_not_throw_an_error(client): + resp = client.get("/starlette") + assert resp.status_code == 200