From 6431b2f72c9a0bfe6fb0b38f956700f4d2bbf998 Mon Sep 17 00:00:00 2001 From: Stephan Berbig Date: Tue, 26 May 2026 20:52:33 +0200 Subject: [PATCH] auth + cli: argon2id keys, bearer middleware, bootstrap commands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - argon2id hash/verify/needs_rehash; constant-time path; parameters from config. - Key format nz_ (12-char stored prefix incl. nz_, 32-char random secret); the full key is generated with secrets, hashed argon2id, and printed exactly once at creation — never persisted, never logged. - Bearer auth middleware: extract → resolve prefix → Redis cache (TTL from REDIS_KEY_CACHE_TTL_S) → DB → argon2 verify → cache the resolved Principal. Fail-closed; uniform sanitized 401 with X-Request-ID; per-IP auth-failure counter to slow brute force. Exempt paths: /healthz /readyz /metrics /, and /playground when enabled. - Bootstrap CLI (Typer) per SPEC §11: create-tenant (with --allow-all-models), create-key, list-keys, revoke-key, set-budget, set-models (--models or --allow-all / --no-allow-all), show-usage, list-models. - Async repositories for tenants, api_keys, key_limits, budget_usage, revocations, audit_log — including the join+inheritance flatten that produces a Principal with effective rpm/tpm/concurrent/allowed_models/ allow_all_models for the auth cache. --- src/neuronetz_gateway/auth/__init__.py | 3 + src/neuronetz_gateway/auth/hashing.py | 52 ++++ src/neuronetz_gateway/auth/keys.py | 79 ++++++ src/neuronetz_gateway/auth/middleware.py | 270 ++++++++++++++++++ src/neuronetz_gateway/auth/principal.py | 75 +++++ src/neuronetz_gateway/cli/__init__.py | 3 + src/neuronetz_gateway/cli/manage.py | 323 +++++++++++++++++++++ src/neuronetz_gateway/db/repositories.py | 343 +++++++++++++++++++++++ 8 files changed, 1148 insertions(+) create mode 100644 src/neuronetz_gateway/auth/__init__.py create mode 100644 src/neuronetz_gateway/auth/hashing.py create mode 100644 src/neuronetz_gateway/auth/keys.py create mode 100644 src/neuronetz_gateway/auth/middleware.py create mode 100644 src/neuronetz_gateway/auth/principal.py create mode 100644 src/neuronetz_gateway/cli/__init__.py create mode 100644 src/neuronetz_gateway/cli/manage.py create mode 100644 src/neuronetz_gateway/db/repositories.py diff --git a/src/neuronetz_gateway/auth/__init__.py b/src/neuronetz_gateway/auth/__init__.py new file mode 100644 index 0000000..d1e4be3 --- /dev/null +++ b/src/neuronetz_gateway/auth/__init__.py @@ -0,0 +1,3 @@ +"""Authentication: argon2id hashing, key generation/verification, middleware.""" + +from __future__ import annotations diff --git a/src/neuronetz_gateway/auth/hashing.py b/src/neuronetz_gateway/auth/hashing.py new file mode 100644 index 0000000..7aaf146 --- /dev/null +++ b/src/neuronetz_gateway/auth/hashing.py @@ -0,0 +1,52 @@ +"""Argon2id password-hashing wrapper (SPEC §3, §9). + +Constant-time verification only; parameters come from settings so the cost +parameters are single-sourced. ``argon2-cffi`` performs the comparison in +constant time internally and raises on mismatch; we translate that into a +boolean without leaking which check failed. +""" + +from __future__ import annotations + +from argon2 import PasswordHasher +from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError + +from neuronetz_gateway.config import Settings + + +def build_hasher(settings: Settings) -> PasswordHasher: + """Construct a configured argon2id ``PasswordHasher`` from settings.""" + return PasswordHasher( + time_cost=settings.argon2_time_cost, + memory_cost=settings.argon2_memory_cost_kib, + parallelism=settings.argon2_parallelism, + ) + + +def hash_secret(hasher: PasswordHasher, secret: str) -> str: + """Return the argon2id encoded hash of ``secret``.""" + return hasher.hash(secret) + + +def verify_secret(hasher: PasswordHasher, encoded: str, secret: str) -> bool: + """Constant-time verify of ``secret`` against an encoded argon2id hash. + + Returns ``False`` on any mismatch or malformed hash rather than raising, so + callers fail closed with a uniform negative result and no error-shape signal. + """ + try: + return hasher.verify(encoded, secret) + except (VerifyMismatchError, VerificationError, InvalidHashError): + return False + + +def needs_rehash(hasher: PasswordHasher, encoded: str) -> bool: + """Return True if ``encoded`` was produced with outdated parameters.""" + try: + return hasher.check_needs_rehash(encoded) + except InvalidHashError: + # A hash we cannot parse should be replaced on the next successful auth. + return True + + +__all__ = ["build_hasher", "hash_secret", "needs_rehash", "verify_secret"] diff --git a/src/neuronetz_gateway/auth/keys.py b/src/neuronetz_gateway/auth/keys.py new file mode 100644 index 0000000..7a4a5b6 --- /dev/null +++ b/src/neuronetz_gateway/auth/keys.py @@ -0,0 +1,79 @@ +"""API key generation and parsing (SPEC §11, §4.3). + +Key scheme +---------- +A full key is the string ``nz_`` where ```` is URL-safe base62 +characters drawn from a CSPRNG. The **stored prefix** is the first +``PREFIX_LEN`` (12) characters of the *full key string* — i.e. it includes the +``nz_`` namespace plus the first 9 random characters (SPEC §4.3: "Key prefix +(first 12 chars) used as Redis cache key"). The remainder of the key is the +secret. The entire full key (prefix + secret) is what gets argon2id-hashed and +stored as ``key_hash``; the prefix is *also* stored in cleartext and indexed for +O(1) lookup. The full key is shown exactly once at creation and never persisted. + +Concretely:: + + full_key = "nz_" + 41 random base62 chars (total length 44) + prefix = full_key[:12] (e.g. "nz_a1B2c3D4") + secret = full_key (the WHOLE key is hashed) + +Verification looks up the active key row by ``prefix`` and runs argon2id +``verify(key_hash, full_key)``. Storing the prefix as a literal slice of the +full key keeps the lookup unambiguous: a presented key is split the same way. +""" + +from __future__ import annotations + +import secrets +import string +from dataclasses import dataclass + +KEY_NAMESPACE = "nz_" +PREFIX_LEN = 12 +# Random characters appended after the namespace. 41 chosen so the full key is +# 44 chars (>= the SPEC's 12-char-prefix + 32-char-random intent) and the prefix +# slice (12 chars) always contains namespace + entropy. +SECRET_LEN = 41 + +# base62 alphabet — URL-safe, no separators that could confuse prefix slicing. +_ALPHABET = string.ascii_letters + string.digits + + +@dataclass(frozen=True, slots=True) +class GeneratedKey: + """A freshly generated key: full plaintext (shown once) and its prefix.""" + + full_key: str + prefix: str + + +def _random_body(length: int) -> str: + """Return ``length`` cryptographically-random base62 characters.""" + return "".join(secrets.choice(_ALPHABET) for _ in range(length)) + + +def generate_key() -> GeneratedKey: + """Generate a new ``nz_``-namespaced key using a CSPRNG.""" + full_key = f"{KEY_NAMESPACE}{_random_body(SECRET_LEN)}" + return GeneratedKey(full_key=full_key, prefix=full_key[:PREFIX_LEN]) + + +def extract_prefix(full_key: str) -> str: + """Return the stored prefix portion of a full key, or raise on bad format. + + A well-formed key starts with the ``nz_`` namespace and is long enough to + contain a full prefix slice. We validate shape but never log the key. + """ + if not full_key.startswith(KEY_NAMESPACE) or len(full_key) <= PREFIX_LEN: + raise ValueError("malformed API key") + return full_key[:PREFIX_LEN] + + +__all__ = [ + "KEY_NAMESPACE", + "PREFIX_LEN", + "SECRET_LEN", + "GeneratedKey", + "extract_prefix", + "generate_key", +] diff --git a/src/neuronetz_gateway/auth/middleware.py b/src/neuronetz_gateway/auth/middleware.py new file mode 100644 index 0000000..80a60fc --- /dev/null +++ b/src/neuronetz_gateway/auth/middleware.py @@ -0,0 +1,270 @@ +"""Authentication middleware (SPEC §4.3 steps 2-3, §3 threat model). + +Resolves ``Authorization: Bearer `` to a :class:`Principal` and attaches it +to ``request.state``. Resolution order: Redis cache (TTL +``REDIS_KEY_CACHE_TTL_S``) → Postgres lookup by prefix → argon2id verify → +cache the resolved principal. Fails closed on every error path with a sanitized +401 that leaks nothing (same body for missing/malformed/unknown/mismatched). + +A per-source-IP auth-failure rate limit (``AUTH_FAILURE_RATE_LIMIT_PER_IP_PER_MIN``) +throttles brute force; if Redis is unavailable for that check we do not block +(the limit is best-effort hardening, not the primary control — the argon2 verify +is). Exempt paths (health/metrics/playground/root/docs) bypass auth entirely. + +Middleware runs before route dependency injection, so it reads backend handles +directly off ``app.state`` and emits its own sanitized JSON response rather than +relying on exception handlers (which do not wrap ``BaseHTTPMiddleware``). +""" + +from __future__ import annotations + +from collections.abc import Awaitable +from typing import cast + +import redis.asyncio as redis +from argon2 import PasswordHasher +from fastapi import Request +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.responses import JSONResponse, Response + +from neuronetz_gateway.auth.hashing import needs_rehash, verify_secret +from neuronetz_gateway.auth.keys import extract_prefix +from neuronetz_gateway.auth.principal import Principal +from neuronetz_gateway.config import Settings +from neuronetz_gateway.db.repositories import ApiKeyRepository +from neuronetz_gateway.observability.logging import get_logger + +_log = get_logger("auth") + +# Paths that never require authentication (SPEC §6.4 + operator/demo surface). +_EXEMPT_EXACT: frozenset[str] = frozenset( + {"/healthz", "/readyz", "/metrics", "/playground", "/", "/docs", "/redoc", "/openapi.json"} +) +_EXEMPT_PREFIXES: tuple[str, ...] = ("/docs", "/redoc") + +_CACHE_PREFIX = "gateway:key:" +_AUTH_FAIL_PREFIX = "gateway:authfail:" + + +def _is_exempt(path: str) -> bool: + """Return True if ``path`` bypasses authentication.""" + if path in _EXEMPT_EXACT: + return True + return any(path == p or path.startswith(p + "/") for p in _EXEMPT_PREFIXES) + + +class AuthMiddleware(BaseHTTPMiddleware): + """Resolve and verify the Bearer key; fail closed on any failure.""" + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + if _is_exempt(request.url.path): + return await call_next(request) + + settings: Settings | None = getattr(request.app.state, "settings", None) + redis_client: redis.Redis | None = getattr(request.app.state, "redis", None) + if settings is None or redis_client is None: + return self._fail(request, status_code=503, code="service_unavailable", + message="The service is temporarily unavailable.") + + client_ip = self._client_ip(request) + if await self._auth_failures_exceeded(redis_client, settings, client_ip): + return self._fail(request, status_code=429, code="rate_limited", + message="Too many authentication failures.", + retry_after=60) + + token = self._bearer_token(request) + if token is None: + await self._record_auth_failure(redis_client, settings, client_ip) + return self._unauthorized(request) + + try: + prefix = extract_prefix(token) + except ValueError: + await self._record_auth_failure(redis_client, settings, client_ip) + return self._unauthorized(request) + + principal = await self._resolve(request, redis_client, settings, prefix, token) + if principal is None: + await self._record_auth_failure(redis_client, settings, client_ip) + return self._unauthorized(request) + + request.state.principal = principal + return await call_next(request) + + async def _resolve( + self, + request: Request, + redis_client: redis.Redis, + settings: Settings, + prefix: str, + token: str, + ) -> Principal | None: + """Cache → DB → argon2 verify → cache. Fail closed (None) on any error.""" + cache_key = _CACHE_PREFIX + prefix + try: + cached = await redis_client.get(cache_key) + except Exception as exc: # noqa: BLE001 - Redis down ⇒ fall through to DB + _log.warning("auth_cache_read_failed", error=str(exc)) + cached = None + + if cached: + principal = Principal.from_json(cached) + # The cached principal carries no secret material, so we still must + # confirm the presented token by re-deriving from the DB hash only on + # a cache miss. A cache hit means a prior request already verified + # this exact prefix; but two different full keys cannot share a prefix + # (prefix is unique), so a hit + matching prefix is authoritative only + # together with a verify. To keep verification mandatory yet cheap, we + # re-verify against the stored hash carried alongside the cache entry. + verified = await self._verify_against_db(request, settings, prefix, token, principal) + return principal if verified else None + + return await self._resolve_from_db(request, redis_client, settings, prefix, token) + + async def _verify_against_db( + self, + request: Request, + settings: Settings, + prefix: str, + token: str, + cached_principal: Principal, + ) -> bool: + """Re-run argon2 verify on a cache hit (the hash never lives in Redis).""" + factory = self._sessionmaker(request) + hasher = self._hasher(request) + if factory is None or hasher is None: + return False + try: + async with factory() as session: + row = await ApiKeyRepository(session).load_for_auth(prefix, settings) + except Exception as exc: # noqa: BLE001 - DB error ⇒ fail closed + _log.warning("auth_db_verify_failed", error=str(exc)) + return False + if row is None or row.principal.key_id != cached_principal.key_id: + return False + return verify_secret(hasher, row.key_hash, token) + + async def _resolve_from_db( + self, + request: Request, + redis_client: redis.Redis, + settings: Settings, + prefix: str, + token: str, + ) -> Principal | None: + """Full DB lookup + argon2 verify; cache the principal on success.""" + factory = self._sessionmaker(request) + hasher = self._hasher(request) + if factory is None or hasher is None: + return None + try: + async with factory() as session: + repo = ApiKeyRepository(session) + row = await repo.load_for_auth(prefix, settings) + if row is None or not verify_secret(hasher, row.key_hash, token): + return None + if needs_rehash(hasher, row.key_hash): + _log.info("auth_key_needs_rehash", key_prefix=prefix) + await repo.touch_last_used(row.principal.key_id) + await session.commit() + principal = row.principal + except Exception as exc: # noqa: BLE001 - DB error ⇒ fail closed + _log.warning("auth_db_lookup_failed", error=str(exc)) + return None + + try: + await redis_client.set( + _CACHE_PREFIX + prefix, + principal.to_json(), + ex=settings.redis_key_cache_ttl_s, + ) + except Exception as exc: # noqa: BLE001 - cache fill is best-effort + _log.warning("auth_cache_write_failed", error=str(exc)) + return principal + + async def _auth_failures_exceeded( + self, redis_client: redis.Redis, settings: Settings, client_ip: str + ) -> bool: + """Return True if this IP is over the per-minute auth-failure budget.""" + key = _AUTH_FAIL_PREFIX + client_ip + try: + current = await redis_client.get(key) + except Exception as exc: # noqa: BLE001 - best-effort hardening only + _log.warning("authfail_read_failed", error=str(exc)) + return False + if current is None: + return False + return int(current) > settings.auth_failure_rate_limit_per_ip_per_min + + async def _record_auth_failure( + self, redis_client: redis.Redis, settings: Settings, client_ip: str + ) -> None: + """Increment the per-IP auth-failure counter with a 60s window.""" + key = _AUTH_FAIL_PREFIX + client_ip + try: + count = await cast("Awaitable[int]", redis_client.incr(key)) + if count == 1: + await redis_client.expire(key, 60) + except Exception as exc: # noqa: BLE001 - best-effort hardening only + _log.warning("authfail_record_failed", error=str(exc)) + + @staticmethod + def _bearer_token(request: Request) -> str | None: + """Extract the bearer token, or None if absent/malformed.""" + header = request.headers.get("authorization") + if not header: + return None + scheme, _, value = header.partition(" ") + if scheme.lower() != "bearer" or not value.strip(): + return None + return value.strip() + + @staticmethod + def _client_ip(request: Request) -> str: + """Best-effort client IP (X-Forwarded-For first hop, else peer).""" + xff = request.headers.get("x-forwarded-for") + if xff: + return xff.split(",")[0].strip() + return request.client.host if request.client else "unknown" + + @staticmethod + def _sessionmaker(request: Request) -> async_sessionmaker[AsyncSession] | None: + factory: async_sessionmaker[AsyncSession] | None = getattr( + request.app.state, "db_sessionmaker", None + ) + return factory + + @staticmethod + def _hasher(request: Request) -> PasswordHasher | None: + hasher: PasswordHasher | None = getattr(request.app.state, "hasher", None) + return hasher + + def _unauthorized(self, request: Request) -> JSONResponse: + """Uniform sanitized 401 — identical for every failure cause.""" + return self._fail(request, status_code=401, code="unauthorized", + message="Authentication required.") + + @staticmethod + def _fail( + request: Request, + *, + status_code: int, + code: str, + message: str, + retry_after: int | None = None, + ) -> JSONResponse: + request_id = str(getattr(request.state, "request_id", "") or "") + headers = {"X-Request-ID": request_id} if request_id else {} + if retry_after is not None: + headers["Retry-After"] = str(retry_after) + return JSONResponse( + status_code=status_code, + content={"error": {"code": code, "message": message, "request_id": request_id}}, + headers=headers, + ) + + +__all__ = ["AuthMiddleware"] diff --git a/src/neuronetz_gateway/auth/principal.py b/src/neuronetz_gateway/auth/principal.py new file mode 100644 index 0000000..c679009 --- /dev/null +++ b/src/neuronetz_gateway/auth/principal.py @@ -0,0 +1,75 @@ +"""The resolved authentication principal attached to ``request.state``. + +A :class:`Principal` is the fully-resolved, inheritance-flattened view of an +API key and its owning tenant: every effective limit, the effective model +policy, and the prompt-logging decision. It is what the auth middleware caches +in Redis (as JSON) and what downstream dependencies read on the hot path, so it +deliberately contains no ORM objects and no secrets (never the key hash). +""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import asdict, dataclass + + +@dataclass(frozen=True, slots=True) +class EffectiveLimits: + """Flattened per-request limits after key-over-tenant inheritance.""" + + rpm: int + tpm: int + concurrent: int + tokens_daily: int | None + tokens_monthly: int | None + tokens_total: int | None + + +@dataclass(frozen=True, slots=True) +class Principal: + """Resolved identity + policy for an authenticated request. + + ``allow_all_models`` and ``allowed_models`` are already inheritance-resolved + (key value if non-NULL, else tenant value) so the request path never needs + the raw rows again. ``allowed_models`` is the configured allowlist *before* + intersection with the live discovered set (that intersection happens in the + allowlist resolver against fresh discovery data). + """ + + key_id: uuid.UUID + key_prefix: str + tenant_id: uuid.UUID + tenant_name: str + scopes: tuple[str, ...] + log_prompts: bool + allow_all_models: bool + allowed_models: tuple[str, ...] + limits: EffectiveLimits + + def to_json(self) -> str: + """Serialize to a compact JSON string for the Redis key cache.""" + data = asdict(self) + data["key_id"] = str(self.key_id) + data["tenant_id"] = str(self.tenant_id) + return json.dumps(data, separators=(",", ":")) + + @classmethod + def from_json(cls, raw: str) -> Principal: + """Reconstruct a principal from its cached JSON representation.""" + data = json.loads(raw) + limits = EffectiveLimits(**data["limits"]) + return cls( + key_id=uuid.UUID(data["key_id"]), + key_prefix=data["key_prefix"], + tenant_id=uuid.UUID(data["tenant_id"]), + tenant_name=data["tenant_name"], + scopes=tuple(data["scopes"]), + log_prompts=bool(data["log_prompts"]), + allow_all_models=bool(data["allow_all_models"]), + allowed_models=tuple(data["allowed_models"]), + limits=limits, + ) + + +__all__ = ["EffectiveLimits", "Principal"] diff --git a/src/neuronetz_gateway/cli/__init__.py b/src/neuronetz_gateway/cli/__init__.py new file mode 100644 index 0000000..0ae80da --- /dev/null +++ b/src/neuronetz_gateway/cli/__init__.py @@ -0,0 +1,3 @@ +"""Bootstrap CLI package (Typer).""" + +from __future__ import annotations diff --git a/src/neuronetz_gateway/cli/manage.py b/src/neuronetz_gateway/cli/manage.py new file mode 100644 index 0000000..661bcd5 --- /dev/null +++ b/src/neuronetz_gateway/cli/manage.py @@ -0,0 +1,323 @@ +"""Bootstrap CLI (Typer) per SPEC §11. + +Entry point: ``neuronetz-gateway = neuronetz_gateway.cli.manage:app``. + +This is the *only* supported way to create tenants and keys (AGENT_PROMPT +non-negotiable #10: the CLI must work before the first manual ``curl``). Each +command opens its own short-lived async engine against ``DATABASE_URL``, does +its unit of work in a transaction, and exits. The full API key is printed +exactly once, at creation, and never stored or logged. + +``list-models`` reads the discovery cache from Redis (SPEC §4.6); with +``--tenant`` it also resolves and prints that tenant's effective model set. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Annotated + +import typer + +from neuronetz_gateway.auth.hashing import build_hasher, hash_secret +from neuronetz_gateway.auth.keys import generate_key +from neuronetz_gateway.config import Settings, get_settings +from neuronetz_gateway.db.models import BudgetPeriod, KeyStatus +from neuronetz_gateway.db.repositories import ( + ApiKeyRepository, + BudgetRepository, + KeyLimitRepository, + RevocationRepository, + TenantRepository, +) +from neuronetz_gateway.db.session import create_engine, create_session_factory, session_scope +from neuronetz_gateway.proxy.allowlist import resolve_effective_models +from neuronetz_gateway.proxy.discovery import read_discovered_from_redis + +app = typer.Typer( + name="neuronetz-gateway", + help="Bootstrap CLI for the neuronetz-gateway (tenants, keys, budgets).", + no_args_is_help=True, + add_completion=False, +) + +def _run[T](coro_factory: Callable[[Settings], Awaitable[T]]) -> T: + """Execute an async unit of work against a fresh engine, then dispose it.""" + + async def _main() -> T: + settings = get_settings() + engine = create_engine(settings) + try: + return await coro_factory(settings) + finally: + await engine.dispose() + + return asyncio.run(_main()) + + +@app.command("create-tenant") +def create_tenant( + name: Annotated[str, typer.Option("--name", help="Unique tenant name.")], + rpm: Annotated[int, typer.Option("--rpm", help="Requests-per-minute limit.")] = 60, + tpm: Annotated[int, typer.Option("--tpm", help="Tokens-per-minute limit.")] = 100_000, + concurrent: Annotated[ + int, typer.Option("--concurrent", help="Concurrent-connection cap.") + ] = 8, + allow_all_models: Annotated[ + bool, + typer.Option( + "--allow-all-models/--no-allow-all-models", + help="Opt the tenant into using any installed model.", + ), + ] = False, +) -> None: + """Create a tenant with optional rate limits and model policy.""" + + async def work(settings: Settings) -> None: + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + tenants = TenantRepository(session) + if await tenants.get_by_name(name) is not None: + raise typer.BadParameter(f"tenant {name!r} already exists") + tenant = await tenants.create( + name=name, + rpm=rpm, + tpm=tpm, + concurrent=concurrent, + allow_all_models=allow_all_models, + ) + typer.echo(f"created tenant {tenant.name} ({tenant.id})") + typer.echo(f" allow_all_models={allow_all_models} rpm={rpm} tpm={tpm}") + + _run(work) + + +@app.command("create-key") +def create_key( + tenant: Annotated[str, typer.Option("--tenant", help="Owning tenant name.")], + name: Annotated[str, typer.Option("--name", help="Human-readable key name.")], + scopes: Annotated[ + str, typer.Option("--scopes", help="Comma-separated scopes.") + ] = "chat,embeddings", +) -> None: + """Create an API key for a tenant. The full key is printed exactly once.""" + + async def work(settings: Settings) -> None: + factory = create_session_factory(create_engine(settings)) + hasher = build_hasher(settings) + scope_list = [s.strip() for s in scopes.split(",") if s.strip()] + async with session_scope(factory) as session: + tenants = TenantRepository(session) + tenant_row = await tenants.get_by_name(tenant) + if tenant_row is None: + raise typer.BadParameter(f"unknown tenant {tenant!r}") + generated = generate_key() + key_hash = hash_secret(hasher, generated.full_key) + keys = ApiKeyRepository(session) + created = await keys.create( + tenant_id=tenant_row.id, + prefix=generated.prefix, + key_hash=key_hash, + name=name, + scopes=scope_list, + ) + typer.echo(f"created key {created.name} for tenant {tenant} (prefix {created.prefix})") + typer.echo("") + typer.secho("API KEY (shown once — store it now):", fg=typer.colors.YELLOW, bold=True) + typer.secho(generated.full_key, fg=typer.colors.GREEN, bold=True) + + _run(work) + + +@app.command("revoke-key") +def revoke_key( + prefix: Annotated[str, typer.Option("--prefix", help="Key prefix to revoke.")], +) -> None: + """Revoke a key by its prefix (sets status + writes the revocation outbox).""" + + async def work(settings: Settings) -> None: + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + keys = ApiKeyRepository(session) + key = await keys.get_by_prefix(prefix) + if key is None: + raise typer.BadParameter(f"no key with prefix {prefix!r}") + await keys.set_status(key.id, KeyStatus.revoked) + await RevocationRepository(session).insert(key.id, reason="cli revoke") + typer.echo(f"revoked key {prefix} ({key.id})") + + _run(work) + + +@app.command("list-keys") +def list_keys( + tenant: Annotated[str, typer.Option("--tenant", help="Tenant whose keys to list.")], +) -> None: + """List a tenant's keys (prefixes and metadata, never full keys).""" + + async def work(settings: Settings) -> None: + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + tenants = TenantRepository(session) + tenant_row = await tenants.get_by_name(tenant) + if tenant_row is None: + raise typer.BadParameter(f"unknown tenant {tenant!r}") + rows = await ApiKeyRepository(session).list_for_tenant(tenant_row.id) + if not rows: + typer.echo("(no keys)") + return + for key in rows: + typer.echo( + f"{key.prefix} status={key.status.value:<8} " + f"name={key.name!r} created={key.created_at.isoformat()}" + ) + + _run(work) + + +@app.command("show-usage") +def show_usage( + tenant: Annotated[str, typer.Option("--tenant", help="Tenant to report usage for.")], + period: Annotated[str, typer.Option("--period", help="Period: day|month|total.")] = "day", +) -> None: + """Show token/request usage for a tenant in a period.""" + + async def work(settings: Settings) -> None: + try: + period_enum = BudgetPeriod(period) + except ValueError as exc: + raise typer.BadParameter("period must be one of day|month|total") from exc + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + tenant_row = await TenantRepository(session).get_by_name(tenant) + if tenant_row is None: + raise typer.BadParameter(f"unknown tenant {tenant!r}") + tokens_in, tokens_out, requests = await BudgetRepository(session).usage_for_tenant( + tenant_row.id, period_enum + ) + typer.echo(f"usage for {tenant} (period={period}):") + typer.echo(f" requests={requests} tokens_in={tokens_in} tokens_out={tokens_out}") + + _run(work) + + +@app.command("set-budget") +def set_budget( + key: Annotated[str, typer.Option("--key", help="Key prefix to set budget on.")], + daily: Annotated[int | None, typer.Option("--daily", help="Daily token budget.")] = None, + monthly: Annotated[ + int | None, typer.Option("--monthly", help="Monthly token budget.") + ] = None, + total: Annotated[int | None, typer.Option("--total", help="Lifetime token budget.")] = None, +) -> None: + """Set per-key token budgets.""" + + async def work(settings: Settings) -> None: + if daily is None and monthly is None and total is None: + raise typer.BadParameter("provide at least one of --daily/--monthly/--total") + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + key_row = await ApiKeyRepository(session).get_by_prefix(key) + if key_row is None: + raise typer.BadParameter(f"no key with prefix {key!r}") + await KeyLimitRepository(session).upsert_budget( + key_row.id, + tokens_daily=daily, + tokens_monthly=monthly, + tokens_total=total, + ) + typer.echo(f"set budget on {key}: daily={daily} monthly={monthly} total={total}") + + _run(work) + + +@app.command("set-models") +def set_models( + tenant: Annotated[str, typer.Option("--tenant", help="Tenant to set models for.")], + models: Annotated[ + str | None, typer.Option("--models", help="Comma-separated model allowlist.") + ] = None, + allow_all: Annotated[ + bool | None, + typer.Option( + "--allow-all/--no-allow-all", + help="Opt into / out of allow_all_models for the tenant.", + ), + ] = None, +) -> None: + """Set a tenant's model allowlist and/or its allow_all_models flag.""" + + async def work(settings: Settings) -> None: + if models is None and allow_all is None: + raise typer.BadParameter("provide --models and/or --allow-all/--no-allow-all") + allowed = ( + [m.strip() for m in models.split(",") if m.strip()] if models is not None else None + ) + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + tenants = TenantRepository(session) + tenant_row = await tenants.get_by_name(tenant) + if tenant_row is None: + raise typer.BadParameter(f"unknown tenant {tenant!r}") + await tenants.set_models( + tenant_row.id, allowed_models=allowed, allow_all_models=allow_all + ) + typer.echo(f"updated models for {tenant}: allowed={allowed} allow_all={allow_all}") + + _run(work) + + +@app.command("list-models") +def list_models( + tenant: Annotated[ + str | None, typer.Option("--tenant", help="Also show this tenant's effective set.") + ] = None, +) -> None: + """Show live-discovered models (and, with --tenant, the effective set).""" + import redis.asyncio as redis + + async def work(settings: Settings) -> None: + client = redis.from_url(settings.redis_url, decode_responses=True) + try: + discovered = await read_discovered_from_redis(client) + finally: + await client.aclose() + discovered_names = sorted(discovered) + typer.echo("discovered models (live from Ollama via discovery cache):") + if discovered_names: + for name in discovered_names: + typer.echo(f" {name}") + else: + typer.echo(" (none — discovery cache empty or expired; requests fail closed)") + + if tenant is None: + return + factory = create_session_factory(create_engine(settings)) + async with session_scope(factory) as session: + tenants = TenantRepository(session) + tenant_row = await tenants.get_by_name(tenant) + if tenant_row is None: + raise typer.BadParameter(f"unknown tenant {tenant!r}") + limits = await tenants.get_limits(tenant_row.id) + if limits is None: + raise typer.BadParameter(f"tenant {tenant!r} has no limits row") + effective = resolve_effective_models( + allow_all=limits.allow_all_models, + allowed_models=tuple(limits.allowed_models), + discovered=discovered, + ) + typer.echo(f"effective set for tenant {tenant}:") + for name in sorted(effective): + typer.echo(f" {name}") + + _run(work) + + +def main() -> None: + """Console-script entry point.""" + app() + + +if __name__ == "__main__": + main() diff --git a/src/neuronetz_gateway/db/repositories.py b/src/neuronetz_gateway/db/repositories.py new file mode 100644 index 0000000..961f250 --- /dev/null +++ b/src/neuronetz_gateway/db/repositories.py @@ -0,0 +1,343 @@ +"""Async repositories for ``gateway`` schema access (SPEC §4.3, §5, §11). + +These wrap SQLAlchemy 2.0 async sessions. The auth hot-path uses +:meth:`ApiKeyRepository.load_for_auth` (one round trip joining the key, tenant, +and both limit rows) to build a flattened :class:`~neuronetz_gateway.auth.principal.Principal`. +CLI/admin flows use the create/list/usage methods. Nothing here returns a key +hash to callers other than the auth resolver, which needs it to run argon2. +""" + +from __future__ import annotations + +import datetime +import uuid +from dataclasses import dataclass + +from sqlalchemy import insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from neuronetz_gateway.auth.principal import EffectiveLimits, Principal +from neuronetz_gateway.config import Settings +from neuronetz_gateway.db.models import ( + ApiKey, + AuditLog, + BudgetPeriod, + BudgetUsage, + KeyLimit, + KeyStatus, + Revocation, + Tenant, + TenantLimit, +) + + +@dataclass(frozen=True, slots=True) +class AuthRow: + """Raw (un-flattened) auth lookup result: the row data argon2 needs.""" + + key_hash: str + principal: Principal + + +def _resolve_limits( + settings: Settings, tenant: TenantLimit, key: KeyLimit | None +) -> EffectiveLimits: + """Flatten key-over-tenant numeric limits (NULL key value inherits tenant).""" + + def pick_int(key_val: int | None, tenant_val: int) -> int: + return key_val if key_val is not None else tenant_val + + def pick_opt(key_val: int | None, tenant_val: int | None) -> int | None: + return key_val if key_val is not None else tenant_val + + return EffectiveLimits( + rpm=pick_int(key.rpm if key else None, tenant.rpm), + tpm=pick_int(key.tpm if key else None, tenant.tpm), + concurrent=pick_int(key.concurrent if key else None, tenant.concurrent), + tokens_daily=pick_opt(key.tokens_daily if key else None, tenant.tokens_daily), + tokens_monthly=pick_opt(key.tokens_monthly if key else None, tenant.tokens_monthly), + tokens_total=pick_opt(key.tokens_total if key else None, tenant.tokens_total), + ) + + +class ApiKeyRepository: + """Read/write access to ``gateway.api_keys`` and key limits.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get_active_by_prefix(self, prefix: str) -> ApiKey | None: + """Return the active key matching the prefix, or None.""" + stmt = select(ApiKey).where( + ApiKey.prefix == prefix, ApiKey.status == KeyStatus.active + ) + return (await self._session.execute(stmt)).scalar_one_or_none() + + async def get_by_prefix(self, prefix: str) -> ApiKey | None: + """Return the key matching the prefix regardless of status, or None.""" + stmt = select(ApiKey).where(ApiKey.prefix == prefix) + return (await self._session.execute(stmt)).scalar_one_or_none() + + async def load_for_auth(self, prefix: str, settings: Settings) -> AuthRow | None: + """Resolve an active key + tenant + limits into an :class:`AuthRow`. + + Returns ``None`` if no active key has this prefix, the tenant is not + active, or the key has expired. Caller still verifies the argon2 hash. + """ + key = await self.get_active_by_prefix(prefix) + if key is None: + return None + if key.expires_at is not None and key.expires_at <= _utcnow(): + return None + + tenant = await self._session.get(Tenant, key.tenant_id) + if tenant is None or tenant.status.value != "active": + return None + tenant_limit = await self._session.get(TenantLimit, key.tenant_id) + if tenant_limit is None: + return None + key_limit = await self._session.get(KeyLimit, key.id) + + allow_all = ( + key_limit.allow_all_models + if key_limit is not None and key_limit.allow_all_models is not None + else tenant_limit.allow_all_models + ) + allowed = ( + key_limit.allowed_models + if key_limit is not None and key_limit.allowed_models is not None + else tenant_limit.allowed_models + ) + log_prompts = ( + key.log_prompts + if key.log_prompts is not None + else tenant_limit.log_prompts_default + ) + + principal = Principal( + key_id=key.id, + key_prefix=key.prefix, + tenant_id=tenant.id, + tenant_name=tenant.name, + scopes=tuple(key.scopes), + log_prompts=log_prompts, + allow_all_models=allow_all, + allowed_models=tuple(allowed), + limits=_resolve_limits(settings, tenant_limit, key_limit), + ) + return AuthRow(key_hash=key.key_hash, principal=principal) + + async def create( + self, + *, + tenant_id: uuid.UUID, + prefix: str, + key_hash: str, + name: str, + scopes: list[str], + ) -> ApiKey: + """Insert a new API key row and return the persisted ORM object.""" + key = ApiKey( + tenant_id=tenant_id, + prefix=prefix, + key_hash=key_hash, + name=name, + scopes=scopes, + ) + self._session.add(key) + await self._session.flush() + return key + + async def list_for_tenant(self, tenant_id: uuid.UUID) -> list[ApiKey]: + """Return all keys belonging to a tenant, newest first.""" + stmt = ( + select(ApiKey) + .where(ApiKey.tenant_id == tenant_id) + .order_by(ApiKey.created_at.desc()) + ) + return list((await self._session.execute(stmt)).scalars().all()) + + async def touch_last_used(self, key_id: uuid.UUID) -> None: + """Best-effort update of ``last_used_at`` (off the hot path).""" + await self._session.execute( + update(ApiKey).where(ApiKey.id == key_id).values(last_used_at=_utcnow()) + ) + + async def set_status(self, key_id: uuid.UUID, status: KeyStatus) -> None: + """Set a key's lifecycle status.""" + await self._session.execute( + update(ApiKey).where(ApiKey.id == key_id).values(status=status) + ) + + +class KeyLimitRepository: + """Read/write access to ``gateway.key_limits``.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get(self, key_id: uuid.UUID) -> KeyLimit | None: + """Return the key's limit row, or None.""" + return await self._session.get(KeyLimit, key_id) + + async def upsert_budget( + self, + key_id: uuid.UUID, + *, + tokens_daily: int | None, + tokens_monthly: int | None, + tokens_total: int | None, + ) -> None: + """Set per-key token budgets, creating the limit row if needed.""" + existing = await self.get(key_id) + if existing is None: + self._session.add( + KeyLimit( + key_id=key_id, + tokens_daily=tokens_daily, + tokens_monthly=tokens_monthly, + tokens_total=tokens_total, + ) + ) + else: + if tokens_daily is not None: + existing.tokens_daily = tokens_daily + if tokens_monthly is not None: + existing.tokens_monthly = tokens_monthly + if tokens_total is not None: + existing.tokens_total = tokens_total + await self._session.flush() + + +class TenantRepository: + """Read/write access to ``gateway.tenants`` and tenant limits.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get_by_name(self, name: str) -> Tenant | None: + """Return the tenant with the given unique name, or None.""" + stmt = select(Tenant).where(Tenant.name == name) + return (await self._session.execute(stmt)).scalar_one_or_none() + + async def get_limits(self, tenant_id: uuid.UUID) -> TenantLimit | None: + """Return the tenant's limit row, or None.""" + return await self._session.get(TenantLimit, tenant_id) + + async def create( + self, + *, + name: str, + rpm: int, + tpm: int, + concurrent: int, + allow_all_models: bool, + ) -> Tenant: + """Create a tenant and its default limit row in one unit of work.""" + tenant = Tenant(name=name) + self._session.add(tenant) + await self._session.flush() + self._session.add( + TenantLimit( + tenant_id=tenant.id, + rpm=rpm, + tpm=tpm, + concurrent=concurrent, + allow_all_models=allow_all_models, + ) + ) + await self._session.flush() + return tenant + + async def set_models( + self, + tenant_id: uuid.UUID, + *, + allowed_models: list[str] | None = None, + allow_all_models: bool | None = None, + ) -> None: + """Update the tenant's model allowlist and/or allow-all flag.""" + limits = await self.get_limits(tenant_id) + if limits is None: + return + if allowed_models is not None: + limits.allowed_models = allowed_models + if allow_all_models is not None: + limits.allow_all_models = allow_all_models + await self._session.flush() + + +class AuditRepository: + """Append-only writes to ``gateway.audit_log`` and ``gateway.prompt_log``.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def insert_audit(self, **fields: object) -> int: + """Insert an audit row, returning its bigserial id.""" + stmt = insert(AuditLog).values(**fields).returning(AuditLog.id) + result = await self._session.execute(stmt) + return int(result.scalar_one()) + + +class BudgetRepository: + """Reads/writes ``gateway.budget_usage`` (the budget source of truth).""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def usage_for_tenant( + self, tenant_id: uuid.UUID, period: BudgetPeriod + ) -> tuple[int, int, int]: + """Return summed (tokens_in, tokens_out, requests) for a tenant+period.""" + from sqlalchemy import func + + stmt = ( + select( + func.coalesce(func.sum(BudgetUsage.tokens_in), 0), + func.coalesce(func.sum(BudgetUsage.tokens_out), 0), + func.coalesce(func.sum(BudgetUsage.requests), 0), + ) + .join(ApiKey, ApiKey.id == BudgetUsage.key_id) + .where(ApiKey.tenant_id == tenant_id, BudgetUsage.period == period) + ) + row = (await self._session.execute(stmt)).one() + return int(row[0]), int(row[1]), int(row[2]) + + +class RevocationRepository: + """Reads/marks rows in the ``gateway.revocations`` outbox.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def insert(self, key_id: uuid.UUID, reason: str | None) -> int: + """Insert a revocation row (fires the NOTIFY trigger); return its id.""" + stmt = insert(Revocation).values(key_id=key_id, reason=reason).returning( + Revocation.id + ) + return int((await self._session.execute(stmt)).scalar_one()) + + async def mark_processed(self, revocation_id: int, key_id: uuid.UUID) -> None: + """Stamp ``processed_at`` after the Redis cache eviction.""" + await self._session.execute( + update(Revocation) + .where(Revocation.id == revocation_id, Revocation.key_id == key_id) + .values(processed_at=_utcnow()) + ) + + +def _utcnow() -> datetime.datetime: + """Timezone-aware current time (UTC).""" + return datetime.datetime.now(datetime.UTC) + + +__all__ = [ + "ApiKeyRepository", + "AuditRepository", + "AuthRow", + "BudgetRepository", + "KeyLimitRepository", + "RevocationRepository", + "TenantRepository", +]