scaffold: project skeleton, schema, healthz/readyz, CI

Initial project structure for neuronetz-gateway per scope-docs/SPEC.md:

- Python 3.12 / FastAPI / SQLAlchemy 2.0 (async) / Redis / Postgres stack
  managed by uv. Multi-stage non-root Dockerfile, prod + dev compose files
  (ollama service is NEVER published in either), Caddyfile + systemd unit,
  justfile, GitHub Actions CI (ruff, mypy --strict, pytest, bandit, pip-audit).
- Pydantic-Settings config covering every env var from SPEC §7, including the
  MODEL_DISCOVERY_* keys for the dynamic-discovery feature (§4.6).
- Alembic 0001_initial creates the full gateway schema (8 tables, 3 enums,
  notify_key_revoked() trigger), incl. allow_all_models on tenant_limits and
  key_limits for the per-tenant auto-grant toggle.
- Working /healthz, /readyz (fail-closed when deps unreachable), and a
  Prometheus /metrics stub. Sanitizing error handlers that attach X-Request-ID
  to every response and never leak upstream internals.
- SPEC + AGENT_PROMPT included under scope-docs/ (source of truth).
This commit is contained in:
Stephan Berbig
2026-05-26 20:50:35 +02:00
commit d79f17b3bb
32 changed files with 3610 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""neuronetz-gateway: secure multi-tenant API gateway in front of Ollama."""
from __future__ import annotations
__version__ = "0.1.0"
__all__ = ["__version__"]

View File

@@ -0,0 +1,28 @@
"""Uvicorn entry point: ``python -m neuronetz_gateway``.
Binds the app to ``GATEWAY_BIND_HOST``:``GATEWAY_BIND_PORT`` (default
0.0.0.0:8080). The factory string is passed to uvicorn so the app is built in
the worker process.
"""
from __future__ import annotations
import uvicorn
from neuronetz_gateway.config import get_settings
def main() -> None:
"""Run the gateway under uvicorn using the configured bind address."""
settings = get_settings()
uvicorn.run(
"neuronetz_gateway.app:create_app",
factory=True,
host=settings.gateway_bind_host,
port=settings.gateway_bind_port,
log_level=settings.gateway_log_level.lower(),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,111 @@
"""FastAPI application factory.
``create_app()`` is the shared contract entry point: other agents (DevOps, QA)
import and serve this. It configures logging, installs the request-id and auth
middleware, registers the sanitizing exception handlers, mounts routers, and
binds the lifespan that manages backend handles + background tasks.
Production safety: FastAPI's ``/docs`` + ``/openapi.json`` are disabled by
default (enabled only via ``DOCS_ENABLED``). The ``/playground`` route is served
only when ``PLAYGROUND_ENABLED`` is true and ``PLAYGROUND_FILE`` exists.
"""
from __future__ import annotations
import uuid
from pathlib import Path
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import HTMLResponse, Response
from starlette.types import ASGIApp
from neuronetz_gateway import __version__
from neuronetz_gateway.auth.middleware import AuthMiddleware
from neuronetz_gateway.config import Settings, get_settings
from neuronetz_gateway.errors import register_exception_handlers
from neuronetz_gateway.lifespan import lifespan
from neuronetz_gateway.observability.logging import configure_logging
from neuronetz_gateway.routes import health, ollama_native, openai_compat
class RequestIDMiddleware(BaseHTTPMiddleware):
"""Assign/propagate a request id and expose it on ``request.state``.
Honours an inbound ``X-Request-ID`` from a trusted proxy; otherwise mints a
fresh UUID. The id is echoed on the response and used by error handlers.
"""
def __init__(self, app: ASGIApp, header_name: str) -> None:
super().__init__(app)
self._header = header_name
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
incoming = request.headers.get(self._header)
request_id = incoming or str(uuid.uuid4())
request.state.request_id = request_id
response = await call_next(request)
response.headers[self._header] = request_id
return response
def _register_playground(app: FastAPI, cfg: Settings) -> None:
"""Add the flag-gated ``/playground`` route (HTML asset, owned by docs agent).
The file is read off the event loop via ``asyncio.to_thread`` so a slow disk
cannot stall request handling. Missing-file is a simple 404, never an error.
"""
import asyncio as _asyncio
def _load(path_str: str) -> str | None:
p = Path(path_str)
if not p.is_file():
return None
return p.read_text(encoding="utf-8")
@app.get("/playground", include_in_schema=False)
async def playground() -> Response:
content = await _asyncio.to_thread(_load, cfg.playground_file)
if content is None:
return Response(status_code=404, content="Not found")
return HTMLResponse(content)
def create_app(settings: Settings | None = None) -> FastAPI:
"""Build and return the configured FastAPI application."""
cfg = settings or get_settings()
configure_logging(level=cfg.gateway_log_level, fmt=cfg.gateway_log_format)
app = FastAPI(
title="neuronetz-gateway",
version=__version__,
lifespan=lifespan,
docs_url="/docs" if cfg.docs_enabled else None,
redoc_url="/redoc" if cfg.docs_enabled else None,
openapi_url="/openapi.json" if cfg.docs_enabled else None,
)
# Settings are needed by the auth middleware before lifespan runs in some
# test setups; lifespan also sets this. Setting here is idempotent.
app.state.settings = cfg
# Auth runs inside RequestID so a request id is always available for the
# sanitized 401 the auth middleware emits. add_middleware wraps outermost
# last, so add Auth first then RequestID.
app.add_middleware(AuthMiddleware)
app.add_middleware(RequestIDMiddleware, header_name=cfg.gateway_request_id_header)
register_exception_handlers(app)
app.include_router(health.router)
app.include_router(openai_compat.router)
app.include_router(ollama_native.router)
if cfg.playground_enabled:
_register_playground(app, cfg)
return app
__all__ = ["RequestIDMiddleware", "create_app"]

View File

@@ -0,0 +1,86 @@
"""Application configuration via Pydantic Settings v2.
Reads every environment variable documented in SPEC §7 with the documented
defaults. Boot fails loudly (ValidationError) on invalid config.
"""
from __future__ import annotations
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Gateway runtime configuration. All fields map to SPEC §7 env vars."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
case_sensitive=False,
)
# --- Service ---
gateway_bind_host: str = Field(default="0.0.0.0") # noqa: S104 - bind-all is intended in container
gateway_bind_port: int = Field(default=8080)
gateway_log_level: str = Field(default="INFO")
gateway_log_format: str = Field(default="json") # json|console
gateway_request_id_header: str = Field(default="X-Request-ID")
gateway_trusted_proxies: str = Field(default="127.0.0.1,caddy")
# --- Upstream (Ollama) ---
ollama_base_url: str = Field(default="http://ollama:11434")
ollama_connect_timeout_s: int = Field(default=5)
ollama_read_timeout_s: int = Field(default=600)
ollama_max_connections: int = Field(default=64)
# --- Model discovery (SPEC §4.6) ---
model_discovery_refresh_s: int = Field(default=60)
model_discovery_cache_ttl_s: int = Field(default=120)
# --- Database ---
database_url: str = Field(
default="postgresql+asyncpg://gateway:gateway@postgres:5432/neuronetz",
)
database_pool_size: int = Field(default=10)
database_pool_overflow: int = Field(default=20)
# --- Redis ---
redis_url: str = Field(default="redis://redis:6379/0")
redis_key_cache_ttl_s: int = Field(default=60)
# --- Limits ---
default_rpm: int = Field(default=60)
default_tpm: int = Field(default=100_000)
default_concurrent: int = Field(default=8)
max_request_body_bytes: int = Field(default=262_144)
max_num_predict: int = Field(default=4096)
# --- Security ---
argon2_time_cost: int = Field(default=3)
argon2_memory_cost_kib: int = Field(default=65_536)
argon2_parallelism: int = Field(default=4)
auth_failure_rate_limit_per_ip_per_min: int = Field(default=20)
# --- Audit ---
audit_buffer_size: int = Field(default=1000)
prompt_log_default_retention_days: int = Field(default=30)
audit_log_default_retention_days: int = Field(default=365)
# --- Playground / docs (prod-safe defaults: both OFF) ---
playground_enabled: bool = Field(default=False)
playground_file: str = Field(default="/app/playground/index.html")
docs_enabled: bool = Field(default=False)
@property
def trusted_proxies_list(self) -> list[str]:
"""Parse the comma-separated trusted-proxy list into individual hosts."""
return [p.strip() for p in self.gateway_trusted_proxies.split(",") if p.strip()]
@lru_cache(maxsize=1)
def get_settings() -> Settings:
"""Return a cached Settings instance, constructed from the environment."""
return Settings()

View File

@@ -0,0 +1,3 @@
"""Database access layer: SQLAlchemy models, session factory, repositories."""
from __future__ import annotations

View File

@@ -0,0 +1,292 @@
"""SQLAlchemy 2.0 (async) ORM models for schema ``gateway`` per SPEC §5.
These mirror the migration in ``alembic/versions/0001_initial.py`` exactly.
The migration is the authoritative DDL; these models are for application use.
"""
from __future__ import annotations
import datetime
import enum
import uuid
from sqlalchemy import (
BigInteger,
Boolean,
ForeignKey,
Integer,
MetaData,
String,
Text,
text,
)
from sqlalchemy.dialects.postgresql import ARRAY, ENUM, INET, JSONB, TIMESTAMP, UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
GATEWAY_SCHEMA = "gateway"
# Stable naming convention so Alembic autogenerate and ad-hoc DDL agree.
_NAMING_CONVENTION = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase):
"""Declarative base; all tables live in the ``gateway`` schema."""
metadata = MetaData(schema=GATEWAY_SCHEMA, naming_convention=_NAMING_CONVENTION)
class KeyStatus(enum.StrEnum):
"""Lifecycle states for an API key (SPEC §5 ``gateway.key_status``)."""
active = "active"
disabled = "disabled"
revoked = "revoked"
class TenantStatus(enum.StrEnum):
"""Lifecycle states for a tenant (SPEC §5 ``gateway.tenant_status``)."""
active = "active"
suspended = "suspended"
closed = "closed"
class BudgetPeriod(enum.StrEnum):
"""Budget accounting periods (SPEC §5 ``gateway.budget_period``)."""
day = "day"
month = "month"
total = "total"
# Reuse existing Postgres enum types (the migration creates them); do not let
# SQLAlchemy try to CREATE TYPE again at runtime.
_key_status_enum = ENUM(KeyStatus, name="key_status", schema=GATEWAY_SCHEMA, create_type=False)
_tenant_status_enum = ENUM(
TenantStatus, name="tenant_status", schema=GATEWAY_SCHEMA, create_type=False
)
_budget_period_enum = ENUM(
BudgetPeriod, name="budget_period", schema=GATEWAY_SCHEMA, create_type=False
)
class Tenant(Base):
"""A tenant: the top-level isolation and ownership boundary."""
__tablename__ = "tenants"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")
)
name: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
status: Mapped[TenantStatus] = mapped_column(
_tenant_status_enum, nullable=False, server_default=text("'active'")
)
created_at: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
)
tenant_metadata: Mapped[dict[str, object]] = mapped_column(
"metadata", JSONB, nullable=False, server_default=text("'{}'::jsonb")
)
class TenantLimit(Base):
"""Per-tenant default limits and retention policy."""
__tablename__ = "tenant_limits"
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="CASCADE"),
primary_key=True,
)
rpm: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("60"))
tpm: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("100000"))
concurrent: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("8"))
tokens_daily: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
tokens_monthly: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
tokens_total: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
allowed_models: Mapped[list[str]] = mapped_column(
ARRAY(Text), nullable=False, server_default=text("'{}'")
)
# When true, the tenant may use ANY model currently installed on the Ollama
# backend (resolved live via model discovery). When false (default), access is
# default-deny and restricted to ``allowed_models`` intersected with the live set.
allow_all_models: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
log_prompts_default: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
prompt_retention_days: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("30")
)
audit_retention_days: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("365")
)
class ApiKey(Base):
"""An API key belonging to a tenant. The full key is never stored."""
__tablename__ = "api_keys"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")
)
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
)
prefix: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
key_hash: Mapped[str] = mapped_column(Text, nullable=False)
name: Mapped[str] = mapped_column(Text, nullable=False)
status: Mapped[KeyStatus] = mapped_column(
_key_status_enum, nullable=False, server_default=text("'active'")
)
scopes: Mapped[list[str]] = mapped_column(
ARRAY(Text), nullable=False, server_default=text("'{chat,embeddings}'")
)
created_at: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
)
last_used_at: Mapped[datetime.datetime | None] = mapped_column(
TIMESTAMP(timezone=True), nullable=True
)
expires_at: Mapped[datetime.datetime | None] = mapped_column(
TIMESTAMP(timezone=True), nullable=True
)
log_prompts: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
key_metadata: Mapped[dict[str, object]] = mapped_column(
"metadata", JSONB, nullable=False, server_default=text("'{}'::jsonb")
)
class KeyLimit(Base):
"""Per-key overrides; NULL columns inherit the tenant value."""
__tablename__ = "key_limits"
key_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("api_keys.id", ondelete="CASCADE"),
primary_key=True,
)
rpm: Mapped[int | None] = mapped_column(Integer, nullable=True)
tpm: Mapped[int | None] = mapped_column(Integer, nullable=True)
concurrent: Mapped[int | None] = mapped_column(Integer, nullable=True)
tokens_daily: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
tokens_monthly: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
tokens_total: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
allowed_models: Mapped[list[str] | None] = mapped_column(ARRAY(Text), nullable=True)
# NULL = inherit tenant's allow_all_models; otherwise overrides it for this key.
allow_all_models: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
class BudgetUsage(Base):
"""Token/request accounting per key, period, and period start."""
__tablename__ = "budget_usage"
key_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("api_keys.id", ondelete="CASCADE"),
primary_key=True,
)
period: Mapped[BudgetPeriod] = mapped_column(_budget_period_enum, primary_key=True)
period_start: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), primary_key=True
)
tokens_in: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("0"))
tokens_out: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("0"))
requests: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("0"))
class AuditLog(Base):
"""Always-on append-only request metadata log."""
__tablename__ = "audit_log"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
ts: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
)
request_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
tenant_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
key_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
key_prefix: Mapped[str | None] = mapped_column(Text, nullable=True)
method: Mapped[str] = mapped_column(Text, nullable=False)
path: Mapped[str] = mapped_column(Text, nullable=False)
model: Mapped[str | None] = mapped_column(Text, nullable=True)
tokens_in: Mapped[int | None] = mapped_column(Integer, nullable=True)
tokens_out: Mapped[int | None] = mapped_column(Integer, nullable=True)
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
status: Mapped[int] = mapped_column(Integer, nullable=False)
client_ip: Mapped[str | None] = mapped_column(INET, nullable=True)
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
error_code: Mapped[str | None] = mapped_column(Text, nullable=True)
class PromptLog(Base):
"""Opt-in, TTL'd capture of request/response bodies."""
__tablename__ = "prompt_log"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
audit_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("audit_log.id", ondelete="CASCADE"),
nullable=False,
)
ts: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
)
key_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
request_body: Mapped[dict[str, object]] = mapped_column(JSONB, nullable=False)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
retention_until: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False
)
class Revocation(Base):
"""Outbox table written by console (or gateway) to revoke a key.
An ``AFTER INSERT`` trigger fires ``pg_notify('key_revoked', key_id)``.
"""
__tablename__ = "revocations"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
key_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
ts: Mapped[datetime.datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
)
reason: Mapped[str | None] = mapped_column(String, nullable=True)
processed_at: Mapped[datetime.datetime | None] = mapped_column(
TIMESTAMP(timezone=True), nullable=True
)
__all__ = [
"GATEWAY_SCHEMA",
"ApiKey",
"AuditLog",
"Base",
"BudgetPeriod",
"BudgetUsage",
"KeyLimit",
"KeyStatus",
"PromptLog",
"Revocation",
"Tenant",
"TenantLimit",
"TenantStatus",
]

View File

@@ -0,0 +1,53 @@
"""Async SQLAlchemy engine and session factory construction.
Phase 1 provides the wiring only; the lifespan owns the engine instance and
stores it on ``app.state``. Business-logic callers should depend on the
session factory via ``deps.py``.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from neuronetz_gateway.config import Settings
def create_engine(settings: Settings) -> AsyncEngine:
"""Build the async engine from settings (asyncpg driver, pooled)."""
return create_async_engine(
settings.database_url,
pool_size=settings.database_pool_size,
max_overflow=settings.database_pool_overflow,
pool_pre_ping=True,
future=True,
)
def create_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
"""Build a session factory bound to the given engine."""
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
@asynccontextmanager
async def session_scope(
factory: async_sessionmaker[AsyncSession],
) -> AsyncIterator[AsyncSession]:
"""Provide a transactional session scope, committing on success."""
async with factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
__all__ = ["create_engine", "create_session_factory", "session_scope"]

View File

@@ -0,0 +1,180 @@
"""FastAPI dependency-injection providers.
Exposes typed accessors for the handles placed on ``app.state`` by the lifespan
(Redis, the upstream httpx client, the DB session factory, the discovery cache)
plus the request principal and the proxy client.
QA override contract
--------------------
Routes obtain the upstream proxy via :func:`get_ollama_client`. Tests override
the *Ollama backend* by overriding this provider::
from neuronetz_gateway.deps import get_ollama_client
from neuronetz_gateway.proxy.ollama import OllamaClient
import httpx
from tests.integration.mock_ollama import create_mock_ollama
transport = httpx.ASGITransport(app=create_mock_ollama())
mock_http = httpx.AsyncClient(transport=transport, base_url="http://ollama")
app.dependency_overrides[get_ollama_client] = lambda: OllamaClient(mock_http)
Because ``get_ollama_client`` returns a fully-built :class:`OllamaClient`, an
override needs no access to ``app.state`` and can point at the in-process mock.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Annotated
import httpx
import redis.asyncio as redis
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from neuronetz_gateway.audit.writer import AuditWriter
from neuronetz_gateway.auth.principal import Principal
from neuronetz_gateway.budget.counter import BudgetCounter
from neuronetz_gateway.config import Settings, get_settings
from neuronetz_gateway.errors import AuthenticationError, DependencyUnavailableError
from neuronetz_gateway.proxy.discovery import DiscoveryCache
from neuronetz_gateway.proxy.ollama import OllamaClient
from neuronetz_gateway.proxy.pipeline import Pipeline
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
def get_config() -> Settings:
"""Provide the cached application settings."""
return get_settings()
def get_redis(request: Request) -> redis.Redis:
"""Provide the shared Redis client, failing closed if unavailable."""
client: redis.Redis | None = getattr(request.app.state, "redis", None)
if client is None:
raise DependencyUnavailableError(internal_detail="redis client not initialised")
return client
def get_http_client(request: Request) -> httpx.AsyncClient:
"""Provide the shared upstream httpx client."""
client: httpx.AsyncClient | None = getattr(request.app.state, "http_client", None)
if client is None:
raise DependencyUnavailableError(internal_detail="http client not initialised")
return client
def get_ollama_client(request: Request) -> OllamaClient:
"""Provide the upstream Ollama proxy client (override target for tests)."""
return OllamaClient(get_http_client(request))
def get_discovery_cache(request: Request) -> DiscoveryCache:
"""Provide the in-process discovery cache; fail closed if absent."""
cache: DiscoveryCache | None = getattr(request.app.state, "discovery_cache", None)
if cache is None:
raise DependencyUnavailableError(internal_detail="discovery cache not initialised")
return cache
def get_principal(request: Request) -> Principal:
"""Return the authenticated principal placed on ``request.state``.
The auth middleware attaches it before routing; its absence on a non-exempt
route is a programming error, so we fail closed with a 401.
"""
principal: Principal | None = getattr(request.state, "principal", None)
if principal is None:
raise AuthenticationError(internal_detail="principal missing on authenticated route")
return principal
def get_audit_writer(request: Request) -> AuditWriter:
"""Provide the shared buffered audit writer; fail closed if absent."""
writer: AuditWriter | None = getattr(request.app.state, "audit_writer", None)
if writer is None:
raise DependencyUnavailableError(internal_detail="audit writer not initialised")
return writer
def get_pipeline(
request: Request,
principal: Annotated[Principal, Depends(get_principal)],
settings: Annotated[Settings, Depends(get_config)],
ollama: Annotated[OllamaClient, Depends(get_ollama_client)],
discovery: Annotated[DiscoveryCache, Depends(get_discovery_cache)],
redis_client: Annotated[redis.Redis, Depends(get_redis)],
audit: Annotated[AuditWriter, Depends(get_audit_writer)],
) -> Pipeline:
"""Assemble a per-request enforcement + proxy pipeline.
The pipeline owns all hot-path checks (rate limit, budget, concurrency,
model/endpoint allowlist) and the streaming-with-bookkeeping contract.
Audit deny-mode flips this to fail closed at the route layer.
"""
sessionmaker: async_sessionmaker[AsyncSession] | None = getattr(
request.app.state, "db_sessionmaker", None
)
return Pipeline(
request=request,
principal=principal,
settings=settings,
ollama=ollama,
discovery=discovery,
rate_limiter=SlidingWindowLimiter(redis_client),
concurrency=ConcurrencyLimiter(redis_client),
budget=BudgetCounter(redis_client),
audit=audit,
sessionmaker=sessionmaker,
)
def _get_sessionmaker(request: Request) -> async_sessionmaker[AsyncSession]:
"""Return the session factory or fail closed if the engine is absent."""
factory: async_sessionmaker[AsyncSession] | None = getattr(
request.app.state, "db_sessionmaker", None
)
if factory is None:
raise DependencyUnavailableError(internal_detail="db session factory not initialised")
return factory
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
"""Provide a request-scoped async DB session."""
factory = _get_sessionmaker(request)
async with factory() as session:
yield session
ConfigDep = Annotated[Settings, Depends(get_config)]
RedisDep = Annotated[redis.Redis, Depends(get_redis)]
HttpClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)]
OllamaClientDep = Annotated[OllamaClient, Depends(get_ollama_client)]
DiscoveryCacheDep = Annotated[DiscoveryCache, Depends(get_discovery_cache)]
PrincipalDep = Annotated[Principal, Depends(get_principal)]
AuditWriterDep = Annotated[AuditWriter, Depends(get_audit_writer)]
PipelineDep = Annotated[Pipeline, Depends(get_pipeline)]
DbSessionDep = Annotated[AsyncSession, Depends(get_db_session)]
__all__ = [
"AuditWriterDep",
"ConfigDep",
"DbSessionDep",
"DiscoveryCacheDep",
"HttpClientDep",
"OllamaClientDep",
"PipelineDep",
"PrincipalDep",
"RedisDep",
"get_audit_writer",
"get_config",
"get_db_session",
"get_discovery_cache",
"get_http_client",
"get_ollama_client",
"get_pipeline",
"get_principal",
"get_redis",
]

View File

@@ -0,0 +1,179 @@
"""Exception types and FastAPI exception handlers.
Hard rule (SPEC §3, AGENT_PROMPT non-negotiable #4): never leak upstream or
internal error details to the client. Every error response is a generic,
sanitized JSON body carrying only a stable ``error.code``, a safe message, and
the request id. Detailed context is logged server-side, never returned.
"""
from __future__ import annotations
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from neuronetz_gateway.observability.logging import get_logger
_log = get_logger("errors")
class GatewayError(Exception):
"""Base class for gateway errors that map to a sanitized HTTP response.
``message`` MUST be safe to return to clients. Anything sensitive belongs
in ``internal_detail`` which is logged but never serialized to the client.
"""
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
code: str = "internal_error"
message: str = "An internal error occurred."
def __init__(self, message: str | None = None, *, internal_detail: str | None = None) -> None:
super().__init__(message or self.message)
if message is not None:
self.message = message
self.internal_detail = internal_detail
class AuthenticationError(GatewayError):
"""Missing/invalid credentials. Fail closed, no detail."""
status_code = status.HTTP_401_UNAUTHORIZED
code = "unauthorized"
message = "Authentication required."
class AuthorizationError(GatewayError):
"""Authenticated but not permitted (scope/model/endpoint denied)."""
status_code = status.HTTP_403_FORBIDDEN
code = "forbidden"
message = "This request is not permitted."
class RateLimitError(GatewayError):
"""Rate limit exceeded. Handler attaches ``Retry-After`` when known."""
status_code = status.HTTP_429_TOO_MANY_REQUESTS
code = "rate_limited"
message = "Rate limit exceeded."
def __init__(
self,
message: str | None = None,
*,
retry_after: int | None = None,
internal_detail: str | None = None,
) -> None:
super().__init__(message, internal_detail=internal_detail)
self.retry_after = retry_after
class BudgetExceededError(GatewayError):
"""Token budget exhausted for the active period."""
status_code = status.HTTP_429_TOO_MANY_REQUESTS
code = "budget_exceeded"
message = "Token budget exhausted for the current period."
class RequestTooLargeError(GatewayError):
"""Request body exceeds the configured limit."""
status_code = status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
code = "request_too_large"
message = "Request body is too large."
class UpstreamUnavailableError(GatewayError):
"""Ollama (or another dependency) is unreachable. Fail closed."""
status_code = status.HTTP_502_BAD_GATEWAY
code = "upstream_unavailable"
message = "The upstream service is temporarily unavailable."
class DependencyUnavailableError(GatewayError):
"""A required backend (DB/Redis) is unavailable; serve 503, fail closed."""
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
code = "service_unavailable"
message = "The service is temporarily unavailable."
def _request_id(request: Request) -> str:
"""Extract the request id placed on ``request.state`` by middleware."""
rid = getattr(request.state, "request_id", None)
return str(rid) if rid else ""
def _error_response(
request: Request,
*,
status_code: int,
code: str,
message: str,
extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
"""Build a sanitized JSON error response with the request id header."""
request_id = _request_id(request)
headers = {"X-Request-ID": request_id} if request_id else {}
if extra_headers:
headers.update(extra_headers)
return JSONResponse(
status_code=status_code,
content={"error": {"code": code, "message": message, "request_id": request_id}},
headers=headers,
)
async def _gateway_error_handler(request: Request, exc: GatewayError) -> JSONResponse:
"""Render a ``GatewayError`` as a sanitized response."""
if exc.internal_detail:
_log.warning(
"gateway_error",
code=exc.code,
status_code=exc.status_code,
internal_detail=exc.internal_detail,
)
extra: dict[str, str] | None = None
if isinstance(exc, RateLimitError) and exc.retry_after is not None:
extra = {"Retry-After": str(exc.retry_after)}
return _error_response(
request,
status_code=exc.status_code,
code=exc.code,
message=exc.message,
extra_headers=extra,
)
async def _unhandled_error_handler(request: Request, exc: Exception) -> JSONResponse:
"""Catch-all: log the real exception, return a generic 500. No leakage."""
_log.error("unhandled_exception", exc_info=exc)
return _error_response(
request,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
code="internal_error",
message="An internal error occurred.",
)
def register_exception_handlers(app: FastAPI) -> None:
"""Attach the gateway's sanitizing exception handlers to the app."""
# mypy: FastAPI's add_exception_handler accepts these handler signatures;
# the stubs are intentionally broad, so casts are unnecessary here.
app.add_exception_handler(GatewayError, _gateway_error_handler) # type: ignore[arg-type] # handler typed for GatewayError subclass
app.add_exception_handler(Exception, _unhandled_error_handler)
__all__ = [
"AuthenticationError",
"AuthorizationError",
"BudgetExceededError",
"DependencyUnavailableError",
"GatewayError",
"RateLimitError",
"RequestTooLargeError",
"UpstreamUnavailableError",
"register_exception_handlers",
]

View File

@@ -0,0 +1,131 @@
"""Application lifespan: connect/dispose backends and run background tasks.
Startup connects Postgres + Redis + the upstream httpx client, builds the
argon2 hasher and the buffered audit writer, and launches the background tasks:
the model-discovery poller (SPEC §4.6) and the Postgres revocation NOTIFY
listener (SPEC §4.5). Connection failures are tolerated so ``/healthz`` always
serves; ``/readyz`` reports true readiness. All handles live on ``app.state``.
"""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
import httpx
import redis.asyncio as redis
from neuronetz_gateway.audit.writer import AuditWriter
from neuronetz_gateway.auth.hashing import build_hasher
from neuronetz_gateway.config import Settings, get_settings
from neuronetz_gateway.db.session import create_engine, create_session_factory
from neuronetz_gateway.observability.logging import get_logger
from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop
from neuronetz_gateway.revocation import revocation_listener
if TYPE_CHECKING:
from fastapi import FastAPI
_log = get_logger("lifespan")
def _build_http_client(settings: Settings) -> httpx.AsyncClient:
"""Construct the shared httpx client used to reach Ollama."""
timeout = httpx.Timeout(
connect=settings.ollama_connect_timeout_s,
read=settings.ollama_read_timeout_s,
write=settings.ollama_read_timeout_s,
pool=settings.ollama_connect_timeout_s,
)
limits = httpx.Limits(max_connections=settings.ollama_max_connections)
return httpx.AsyncClient(base_url=settings.ollama_base_url, timeout=timeout, limits=limits)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage startup/shutdown of all backends and background tasks."""
settings: Settings = get_settings()
app.state.settings = settings
app.state.hasher = build_hasher(settings)
app.state.discovery_cache = DiscoveryCache()
tasks: list[asyncio.Task[None]] = []
try:
engine = create_engine(settings)
app.state.db_engine = engine
app.state.db_sessionmaker = create_session_factory(engine)
except Exception as exc: # noqa: BLE001 - tolerate so /healthz still serves
_log.error("db_engine_init_failed", error=str(exc))
app.state.db_engine = None
app.state.db_sessionmaker = None
try:
app.state.redis = redis.from_url(settings.redis_url, decode_responses=True)
except Exception as exc: # noqa: BLE001 - tolerate so /healthz still serves
_log.error("redis_init_failed", error=str(exc))
app.state.redis = None
app.state.http_client = _build_http_client(settings)
audit_writer = AuditWriter(settings.audit_buffer_size, app.state.db_sessionmaker)
audit_writer.start()
app.state.audit_writer = audit_writer
# Background tasks (cancelled on shutdown).
tasks.append(
asyncio.create_task(
discovery_loop(
app.state.http_client, app.state.redis, app.state.discovery_cache, settings
)
)
)
if app.state.redis is not None and app.state.db_sessionmaker is not None:
tasks.append(
asyncio.create_task(
revocation_listener(settings, app.state.redis, app.state.db_sessionmaker)
)
)
app.state.background_tasks = tasks
_log.info("gateway_startup_complete")
try:
yield
finally:
await _shutdown(app, tasks, audit_writer)
async def _shutdown(
app: FastAPI, tasks: list[asyncio.Task[None]], audit_writer: AuditWriter
) -> None:
"""Cancel background tasks and dispose of all backend handles."""
for task in tasks:
task.cancel()
for task in tasks:
with contextlib.suppress(asyncio.CancelledError):
await task
with contextlib.suppress(Exception):
await audit_writer.stop()
http_client: httpx.AsyncClient | None = getattr(app.state, "http_client", None)
if http_client is not None:
with contextlib.suppress(Exception):
await http_client.aclose()
redis_client = getattr(app.state, "redis", None)
if redis_client is not None:
with contextlib.suppress(Exception):
await redis_client.aclose()
engine = getattr(app.state, "db_engine", None)
if engine is not None:
with contextlib.suppress(Exception):
await engine.dispose()
_log.info("gateway_shutdown_complete")
__all__ = ["lifespan"]

View File

@@ -0,0 +1,3 @@
"""Observability: structured logging and Prometheus metrics."""
from __future__ import annotations

View File

@@ -0,0 +1,48 @@
"""structlog configuration.
Renders JSON in production (``GATEWAY_LOG_FORMAT=json``) and a human-friendly
console format in development. No secrets are ever logged; processors here
must not introduce any.
"""
from __future__ import annotations
import logging
from typing import Any
import structlog
def configure_logging(level: str = "INFO", fmt: str = "json") -> None:
"""Configure stdlib logging and structlog according to settings."""
log_level = getattr(logging, level.upper(), logging.INFO)
logging.basicConfig(format="%(message)s", level=log_level)
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
]
renderer: structlog.types.Processor
if fmt == "console":
renderer = structlog.dev.ConsoleRenderer()
else:
renderer = structlog.processors.JSONRenderer()
structlog.configure(
processors=[*shared_processors, renderer],
wrapper_class=structlog.make_filtering_bound_logger(log_level),
logger_factory=structlog.PrintLoggerFactory(),
cache_logger_on_first_use=True,
)
def get_logger(name: str | None = None) -> Any: # noqa: ANN401 - structlog returns a dynamic proxy
"""Return a bound structlog logger."""
return structlog.get_logger(name)
__all__ = ["configure_logging", "get_logger"]

View File

@@ -0,0 +1,3 @@
"""HTTP route modules: health, native Ollama passthrough, OpenAI-compat."""
from __future__ import annotations

View File

@@ -0,0 +1,114 @@
"""Health, readiness, and metrics endpoints (SPEC §6.4).
- ``GET /healthz`` : liveness — always 200 if the process can respond.
- ``GET /readyz`` : readiness — 200 only if Postgres + Redis + Ollama are all
reachable; otherwise 503 with which dependencies are down.
In Phase 1 dev there is no Ollama, so 503 is expected.
- ``GET /metrics`` : Prometheus exposition. (Loopback-only IP check deferred.)
None of these endpoints require auth and none leak secrets or internal detail.
"""
from __future__ import annotations
from collections.abc import Awaitable
from typing import Literal, cast
import httpx
import redis.asyncio as redis
from fastapi import APIRouter, Request, Response, status
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from neuronetz_gateway.observability.logging import get_logger
from neuronetz_gateway.observability.metrics import CONTENT_TYPE_LATEST, render_latest
router = APIRouter(tags=["health"])
_log = get_logger("health")
class HealthResponse(BaseModel):
"""Liveness response body."""
status: Literal["ok"] = "ok"
class ReadyResponse(BaseModel):
"""Readiness response body. ``checks`` maps dependency -> reachable bool."""
status: Literal["ready", "not_ready"]
checks: dict[str, bool]
@router.get("/healthz", response_model=HealthResponse, status_code=status.HTTP_200_OK)
async def healthz() -> HealthResponse:
"""Liveness probe — always returns 200 while the process is responsive."""
return HealthResponse()
async def _check_postgres(app_state: object) -> bool:
"""Return True if a trivial query succeeds against Postgres."""
factory: async_sessionmaker[AsyncSession] | None = getattr(
app_state, "db_sessionmaker", None
)
if factory is None:
return False
try:
async with factory() as session:
await session.execute(text("SELECT 1"))
return True
except Exception as exc: # noqa: BLE001 - any failure means not ready
_log.warning("readyz_postgres_unreachable", error=str(exc))
return False
async def _check_redis(app_state: object) -> bool:
"""Return True if Redis answers PING."""
client: redis.Redis | None = getattr(app_state, "redis", None)
if client is None:
return False
try:
# redis-py types ping() as Awaitable[bool] | bool (sync+async share stubs);
# the asyncio client always returns an awaitable at runtime.
return bool(await cast("Awaitable[bool]", client.ping()))
except Exception as exc: # noqa: BLE001 - any failure means not ready
_log.warning("readyz_redis_unreachable", error=str(exc))
return False
async def _check_ollama(app_state: object) -> bool:
"""Return True if Ollama's root endpoint is reachable."""
client: httpx.AsyncClient | None = getattr(app_state, "http_client", None)
if client is None:
return False
try:
resp = await client.get("/")
return resp.status_code < 500
except Exception as exc: # noqa: BLE001 - any failure means not ready
_log.warning("readyz_ollama_unreachable", error=str(exc))
return False
@router.get("/readyz", response_model=ReadyResponse)
async def readyz(request: Request, response: Response) -> ReadyResponse:
"""Readiness probe — 200 only if every dependency is reachable, else 503."""
app_state = request.app.state
checks = {
"postgres": await _check_postgres(app_state),
"redis": await _check_redis(app_state),
"ollama": await _check_ollama(app_state),
}
all_ready = all(checks.values())
if not all_ready:
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
return ReadyResponse(status="ready" if all_ready else "not_ready", checks=checks)
@router.get("/metrics")
async def metrics() -> Response:
"""Prometheus exposition. Loopback-only enforcement is deferred to Phase 4."""
return Response(content=render_latest(), media_type=CONTENT_TYPE_LATEST)
__all__ = ["router"]