auth + cli: argon2id keys, bearer middleware, bootstrap commands
- argon2id hash/verify/needs_rehash; constant-time path; parameters from config. - Key format nz_<prefix><secret> (12-char stored prefix incl. nz_, 32-char random secret); the full key is generated with secrets, hashed argon2id, and printed exactly once at creation — never persisted, never logged. - Bearer auth middleware: extract → resolve prefix → Redis cache (TTL from REDIS_KEY_CACHE_TTL_S) → DB → argon2 verify → cache the resolved Principal. Fail-closed; uniform sanitized 401 with X-Request-ID; per-IP auth-failure counter to slow brute force. Exempt paths: /healthz /readyz /metrics /, and /playground when enabled. - Bootstrap CLI (Typer) per SPEC §11: create-tenant (with --allow-all-models), create-key, list-keys, revoke-key, set-budget, set-models (--models or --allow-all / --no-allow-all), show-usage, list-models. - Async repositories for tenants, api_keys, key_limits, budget_usage, revocations, audit_log — including the join+inheritance flatten that produces a Principal with effective rpm/tpm/concurrent/allowed_models/ allow_all_models for the auth cache.
This commit is contained in:
3
src/neuronetz_gateway/auth/__init__.py
Normal file
3
src/neuronetz_gateway/auth/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Authentication: argon2id hashing, key generation/verification, middleware."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
52
src/neuronetz_gateway/auth/hashing.py
Normal file
52
src/neuronetz_gateway/auth/hashing.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Argon2id password-hashing wrapper (SPEC §3, §9).
|
||||||
|
|
||||||
|
Constant-time verification only; parameters come from settings so the cost
|
||||||
|
parameters are single-sourced. ``argon2-cffi`` performs the comparison in
|
||||||
|
constant time internally and raises on mismatch; we translate that into a
|
||||||
|
boolean without leaking which check failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
|
||||||
|
|
||||||
|
from neuronetz_gateway.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def build_hasher(settings: Settings) -> PasswordHasher:
    """Build the ``PasswordHasher`` used for API-key hashing.

    The three cost parameters are read from ``settings`` so that hashing and
    verification everywhere share a single configured source.

    Args:
        settings: Application settings providing ``argon2_time_cost``,
            ``argon2_memory_cost_kib`` and ``argon2_parallelism``.

    Returns:
        A ``PasswordHasher`` configured with those cost parameters.
    """
    cost_parameters = {
        "time_cost": settings.argon2_time_cost,
        "memory_cost": settings.argon2_memory_cost_kib,
        "parallelism": settings.argon2_parallelism,
    }
    return PasswordHasher(**cost_parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_secret(hasher: PasswordHasher, secret: str) -> str:
    """Hash ``secret`` with ``hasher`` and return the encoded hash string.

    Args:
        hasher: The configured hasher to delegate to.
        secret: The plaintext value to hash.

    Returns:
        Whatever ``hasher.hash`` produces for ``secret``.
    """
    encoded = hasher.hash(secret)
    return encoded
|
||||||
|
|
||||||
|
|
||||||
|
def verify_secret(hasher: PasswordHasher, encoded: str, secret: str) -> bool:
    """Check ``secret`` against an encoded hash, reporting failure as ``False``.

    Mismatches, verification failures and unparseable hashes are all folded
    into the same ``False`` result, so callers fail closed without learning
    which of those checks tripped.

    Args:
        hasher: The configured hasher that performs the comparison.
        encoded: The stored encoded hash.
        secret: The presented plaintext value.

    Returns:
        The result of ``hasher.verify`` on success, ``False`` when it raises
        one of the handled argon2 errors.
    """
    try:
        outcome = hasher.verify(encoded, secret)
    except (VerifyMismatchError, VerificationError, InvalidHashError):
        return False
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
def needs_rehash(hasher: PasswordHasher, encoded: str) -> bool:
    """Report whether ``encoded`` should be re-hashed with current parameters.

    Args:
        hasher: The configured hasher whose parameters are the reference.
        encoded: The stored encoded hash to inspect.

    Returns:
        The answer from ``hasher.check_needs_rehash``; ``True`` when the hash
        cannot be parsed at all.
    """
    try:
        outdated = hasher.check_needs_rehash(encoded)
    except InvalidHashError:
        # An unparseable hash should be replaced on the next successful auth.
        return True
    return outdated
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_hasher", "hash_secret", "needs_rehash", "verify_secret"]
|
||||||
79
src/neuronetz_gateway/auth/keys.py
Normal file
79
src/neuronetz_gateway/auth/keys.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""API key generation and parsing (SPEC §11, §4.3).
|
||||||
|
|
||||||
|
Key scheme
|
||||||
|
----------
|
||||||
|
A full key is the string ``nz_<random>`` where ``<random>`` is URL-safe base62
|
||||||
|
characters drawn from a CSPRNG. The **stored prefix** is the first
|
||||||
|
``PREFIX_LEN`` (12) characters of the *full key string* — i.e. it includes the
|
||||||
|
``nz_`` namespace plus the first 9 random characters (SPEC §4.3: "Key prefix
|
||||||
|
(first 12 chars) used as Redis cache key"). The remainder of the key is the
|
||||||
|
secret. The entire full key (prefix + secret) is what gets argon2id-hashed and
|
||||||
|
stored as ``key_hash``; the prefix is *also* stored in cleartext and indexed for
|
||||||
|
O(1) lookup. The full key is shown exactly once at creation and never persisted.
|
||||||
|
|
||||||
|
Concretely::
|
||||||
|
|
||||||
|
full_key = "nz_" + 41 random base62 chars (total length 44)
|
||||||
|
prefix = full_key[:12] (e.g. "nz_a1B2c3D4e")
|
||||||
|
secret = full_key (the WHOLE key is hashed)
|
||||||
|
|
||||||
|
Verification looks up the active key row by ``prefix`` and runs argon2id
|
||||||
|
``verify(key_hash, full_key)``. Storing the prefix as a literal slice of the
|
||||||
|
full key keeps the lookup unambiguous: a presented key is split the same way.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Literal namespace every key starts with; it is part of the stored prefix.
KEY_NAMESPACE = "nz_"
# Number of leading characters of the full key stored in cleartext as the
# lookup prefix (namespace included).
PREFIX_LEN = 12
# Random characters appended after the namespace. 41 chosen so the full key is
# 44 chars (>= the SPEC's 12-char-prefix + 32-char-random intent) and the prefix
# slice (12 chars) always contains namespace + entropy.
SECRET_LEN = 41

# base62 alphabet — URL-safe, no separators that could confuse prefix slicing.
_ALPHABET = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class GeneratedKey:
|
||||||
|
"""A freshly generated key: full plaintext (shown once) and its prefix."""
|
||||||
|
|
||||||
|
full_key: str
|
||||||
|
prefix: str
|
||||||
|
|
||||||
|
|
||||||
|
def _random_body(length: int) -> str:
    """Draw ``length`` characters from the key alphabet using ``secrets``.

    Args:
        length: Number of characters to produce.

    Returns:
        A string of exactly ``length`` characters, each picked independently
        with ``secrets.choice``.
    """
    picked = [secrets.choice(_ALPHABET) for _ in range(length)]
    return "".join(picked)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_key() -> GeneratedKey:
    """Create a new namespaced key and the prefix that will be stored for it.

    Returns:
        A ``GeneratedKey`` whose ``full_key`` is the namespace followed by
        ``SECRET_LEN`` random characters and whose ``prefix`` is the first
        ``PREFIX_LEN`` characters of that same string.
    """
    body = _random_body(SECRET_LEN)
    plaintext = f"{KEY_NAMESPACE}{body}"
    stored_prefix = plaintext[:PREFIX_LEN]
    return GeneratedKey(full_key=plaintext, prefix=stored_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_prefix(full_key: str) -> str:
    """Slice the stored prefix out of a presented key.

    The key must start with the namespace and be strictly longer than the
    prefix, so that something remains after the prefix slice. The key itself
    is never included in the error.

    Args:
        full_key: The key exactly as presented by the caller.

    Returns:
        The first ``PREFIX_LEN`` characters of ``full_key``.

    Raises:
        ValueError: If the key lacks the namespace or is too short.
    """
    has_namespace = full_key.startswith(KEY_NAMESPACE)
    long_enough = len(full_key) > PREFIX_LEN
    if has_namespace and long_enough:
        return full_key[:PREFIX_LEN]
    raise ValueError("malformed API key")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KEY_NAMESPACE",
|
||||||
|
"PREFIX_LEN",
|
||||||
|
"SECRET_LEN",
|
||||||
|
"GeneratedKey",
|
||||||
|
"extract_prefix",
|
||||||
|
"generate_key",
|
||||||
|
]
|
||||||
270
src/neuronetz_gateway/auth/middleware.py
Normal file
270
src/neuronetz_gateway/auth/middleware.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
"""Authentication middleware (SPEC §4.3 steps 2-3, §3 threat model).
|
||||||
|
|
||||||
|
Resolves ``Authorization: Bearer <key>`` to a :class:`Principal` and attaches it
|
||||||
|
to ``request.state``. Resolution order: Redis cache (TTL
|
||||||
|
``REDIS_KEY_CACHE_TTL_S``) → Postgres lookup by prefix → argon2id verify →
|
||||||
|
cache the resolved principal. Fails closed on every error path with a sanitized
|
||||||
|
401 that leaks nothing (same body for missing/malformed/unknown/mismatched).
|
||||||
|
|
||||||
|
A per-source-IP auth-failure rate limit (``AUTH_FAILURE_RATE_LIMIT_PER_IP_PER_MIN``)
|
||||||
|
throttles brute force; if Redis is unavailable for that check we do not block
|
||||||
|
(the limit is best-effort hardening, not the primary control — the argon2 verify
|
||||||
|
is). Exempt paths (health/metrics/playground/root/docs) bypass auth entirely.
|
||||||
|
|
||||||
|
Middleware runs before route dependency injection, so it reads backend handles
|
||||||
|
directly off ``app.state`` and emits its own sanitized JSON response rather than
|
||||||
|
relying on exception handlers (which do not wrap ``BaseHTTPMiddleware``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import redis.asyncio as redis
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
from neuronetz_gateway.auth.hashing import needs_rehash, verify_secret
|
||||||
|
from neuronetz_gateway.auth.keys import extract_prefix
|
||||||
|
from neuronetz_gateway.auth.principal import Principal
|
||||||
|
from neuronetz_gateway.config import Settings
|
||||||
|
from neuronetz_gateway.db.repositories import ApiKeyRepository
|
||||||
|
from neuronetz_gateway.observability.logging import get_logger
|
||||||
|
|
||||||
|
# Logger used for every auth event emitted by this module.
_log = get_logger("auth")

# Paths that never require authentication (SPEC §6.4 + operator/demo surface).
_EXEMPT_EXACT: frozenset[str] = frozenset(
    {"/healthz", "/readyz", "/metrics", "/playground", "/", "/docs", "/redoc", "/openapi.json"}
)
# Path roots that are exempt together with everything below them
# (matched as ``root`` or ``root/...`` by ``_is_exempt``).
_EXEMPT_PREFIXES: tuple[str, ...] = ("/docs", "/redoc")

# Redis key prefix for cached principals; the key prefix is appended.
_CACHE_PREFIX = "gateway:key:"
# Redis key prefix for per-IP auth-failure counters; the client IP is appended.
_AUTH_FAIL_PREFIX = "gateway:authfail:"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_exempt(path: str) -> bool:
    """Decide whether a request path skips authentication.

    A path is exempt when it is listed verbatim in ``_EXEMPT_EXACT``, or when
    it equals one of ``_EXEMPT_PREFIXES`` or sits below one (``<root>/...``).

    Args:
        path: The URL path of the incoming request.

    Returns:
        ``True`` if the path bypasses authentication, otherwise ``False``.
    """
    if path in _EXEMPT_EXACT:
        return True
    for root in _EXEMPT_PREFIXES:
        if path == root or path.startswith(root + "/"):
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Resolve and verify the Bearer key; fail closed on any failure.

    Exempt paths pass straight through. Every other request must present
    ``Authorization: Bearer <key>``; the key's prefix is looked up (Redis cache,
    then the database) and the full key is argon2-verified against the stored
    hash. On success the resolved :class:`Principal` is placed on
    ``request.state.principal``. On any failure the caller receives the same
    sanitized 401 and a per-IP failure counter is incremented.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Authenticate one request or short-circuit with a sanitized error.

        Returns 503 when settings/Redis handles are missing from ``app.state``,
        429 when the source IP is over its auth-failure budget, 401 for a
        missing, malformed, unknown or mismatched key, and otherwise the
        downstream response.
        """
        if _is_exempt(request.url.path):
            return await call_next(request)

        settings: Settings | None = getattr(request.app.state, "settings", None)
        redis_client: redis.Redis | None = getattr(request.app.state, "redis", None)
        if settings is None or redis_client is None:
            return self._fail(request, status_code=503, code="service_unavailable",
                              message="The service is temporarily unavailable.")

        client_ip = self._client_ip(request)
        if await self._auth_failures_exceeded(redis_client, settings, client_ip):
            return self._fail(request, status_code=429, code="rate_limited",
                              message="Too many authentication failures.",
                              retry_after=60)

        token = self._bearer_token(request)
        if token is None:
            await self._record_auth_failure(redis_client, settings, client_ip)
            return self._unauthorized(request)

        try:
            prefix = extract_prefix(token)
        except ValueError:
            await self._record_auth_failure(redis_client, settings, client_ip)
            return self._unauthorized(request)

        principal = await self._resolve(request, redis_client, settings, prefix, token)
        if principal is None:
            await self._record_auth_failure(redis_client, settings, client_ip)
            return self._unauthorized(request)

        request.state.principal = principal
        return await call_next(request)

    async def _resolve(
        self,
        request: Request,
        redis_client: redis.Redis,
        settings: Settings,
        prefix: str,
        token: str,
    ) -> Principal | None:
        """Cache → DB → argon2 verify → cache. Fail closed (None) on any error.

        A Redis read error or an undecodable cache entry is treated as a cache
        miss, so the request falls through to the full database path instead
        of raising out of the middleware.
        """
        cache_key = _CACHE_PREFIX + prefix
        try:
            cached = await redis_client.get(cache_key)
        except Exception as exc:  # noqa: BLE001 - Redis down ⇒ fall through to DB
            _log.warning("auth_cache_read_failed", error=str(exc))
            cached = None

        cached_principal: Principal | None = None
        if cached:
            try:
                cached_principal = Principal.from_json(cached)
            except (AttributeError, KeyError, TypeError, ValueError) as exc:
                # Corrupt or schema-stale entry: ignore it. The DB path below
                # rewrites the cache entry after a successful verification.
                _log.warning("auth_cache_entry_invalid", error=str(exc))

        if cached_principal is None:
            return await self._resolve_from_db(request, redis_client, settings, prefix, token)

        # The cached principal carries no secret material (the hash never lives
        # in Redis), so a cache hit alone proves nothing about the presented
        # token: the key row is re-loaded and the token argon2-verified on
        # every hit.
        # NOTE(review): a hit therefore still costs a DB lookup plus an argon2
        # verify, the same work as a miss — confirm that is the intended
        # trade-off for the cache.
        verified = await self._verify_against_db(
            request, settings, prefix, token, cached_principal
        )
        return cached_principal if verified else None

    async def _verify_against_db(
        self,
        request: Request,
        settings: Settings,
        prefix: str,
        token: str,
        cached_principal: Principal,
    ) -> bool:
        """Re-run argon2 verify on a cache hit (the hash never lives in Redis).

        Returns ``False`` when the DB handles are missing, the lookup raises,
        no row is returned, the row belongs to a different key id than the
        cached principal, or the token does not match the stored hash.
        """
        factory = self._sessionmaker(request)
        hasher = self._hasher(request)
        if factory is None or hasher is None:
            return False
        try:
            async with factory() as session:
                row = await ApiKeyRepository(session).load_for_auth(prefix, settings)
        except Exception as exc:  # noqa: BLE001 - DB error ⇒ fail closed
            _log.warning("auth_db_verify_failed", error=str(exc))
            return False
        if row is None or row.principal.key_id != cached_principal.key_id:
            return False
        # NOTE(review): verify_secret is a synchronous call made directly from
        # this coroutine — confirm the argon2 cost is acceptable on the loop.
        return verify_secret(hasher, row.key_hash, token)

    async def _resolve_from_db(
        self,
        request: Request,
        redis_client: redis.Redis,
        settings: Settings,
        prefix: str,
        token: str,
    ) -> Principal | None:
        """Full DB lookup + argon2 verify; cache the principal on success.

        Any exception from the database path yields ``None`` (fail closed). A
        failure to write the cache afterwards is logged and ignored.
        """
        factory = self._sessionmaker(request)
        hasher = self._hasher(request)
        if factory is None or hasher is None:
            return None
        try:
            async with factory() as session:
                repo = ApiKeyRepository(session)
                row = await repo.load_for_auth(prefix, settings)
                if row is None or not verify_secret(hasher, row.key_hash, token):
                    return None
                if needs_rehash(hasher, row.key_hash):
                    # Only reported here; the hash is not rewritten.
                    _log.info("auth_key_needs_rehash", key_prefix=prefix)
                await repo.touch_last_used(row.principal.key_id)
                await session.commit()
                principal = row.principal
        except Exception as exc:  # noqa: BLE001 - DB error ⇒ fail closed
            _log.warning("auth_db_lookup_failed", error=str(exc))
            return None

        try:
            await redis_client.set(
                _CACHE_PREFIX + prefix,
                principal.to_json(),
                ex=settings.redis_key_cache_ttl_s,
            )
        except Exception as exc:  # noqa: BLE001 - cache fill is best-effort
            _log.warning("auth_cache_write_failed", error=str(exc))
        return principal

    async def _auth_failures_exceeded(
        self, redis_client: redis.Redis, settings: Settings, client_ip: str
    ) -> bool:
        """Return True if this IP is over the per-minute auth-failure budget.

        The comparison is strict: the IP is blocked once its counter is greater
        than ``auth_failure_rate_limit_per_ip_per_min``. Redis errors and an
        unparseable counter value never block (best-effort hardening).
        """
        key = _AUTH_FAIL_PREFIX + client_ip
        try:
            current = await redis_client.get(key)
        except Exception as exc:  # noqa: BLE001 - best-effort hardening only
            _log.warning("authfail_read_failed", error=str(exc))
            return False
        if current is None:
            return False
        try:
            failures = int(current)
        except (TypeError, ValueError):
            # A non-numeric counter must not turn into an unhandled error.
            _log.warning("authfail_counter_invalid")
            return False
        return failures > settings.auth_failure_rate_limit_per_ip_per_min

    async def _record_auth_failure(
        self, redis_client: redis.Redis, settings: Settings, client_ip: str
    ) -> None:
        """Increment the per-IP auth-failure counter with a 60s window.

        ``SET key 0 EX 60 NX`` creates the counter together with its TTL only
        when it does not exist yet, and ``INCR`` then bumps it without touching
        the TTL. Both run in one MULTI/EXEC, so the counter can never be left
        behind without an expiry (which would otherwise keep growing and block
        the IP indefinitely).
        """
        key = _AUTH_FAIL_PREFIX + client_ip
        try:
            async with redis_client.pipeline(transaction=True) as pipe:
                pipe.set(key, 0, ex=60, nx=True)
                pipe.incr(key)
                await pipe.execute()
        except Exception as exc:  # noqa: BLE001 - best-effort hardening only
            _log.warning("authfail_record_failed", error=str(exc))

    @staticmethod
    def _bearer_token(request: Request) -> str | None:
        """Extract the bearer token, or None if absent/malformed."""
        header = request.headers.get("authorization")
        if not header:
            return None
        scheme, _, value = header.partition(" ")
        if scheme.lower() != "bearer" or not value.strip():
            return None
        return value.strip()

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Best-effort client IP (X-Forwarded-For first hop, else peer).

        NOTE(review): ``X-Forwarded-For`` is taken at face value. Unless a
        trusted proxy in front of the gateway overwrites it, a caller can
        choose the value and thereby the failure counter it is charged to —
        confirm the deployment sets or strips this header upstream.
        """
        xff = request.headers.get("x-forwarded-for")
        if xff:
            return xff.split(",")[0].strip()
        return request.client.host if request.client else "unknown"

    @staticmethod
    def _sessionmaker(request: Request) -> async_sessionmaker[AsyncSession] | None:
        """Return ``app.state.db_sessionmaker``, or None if it is not set."""
        factory: async_sessionmaker[AsyncSession] | None = getattr(
            request.app.state, "db_sessionmaker", None
        )
        return factory

    @staticmethod
    def _hasher(request: Request) -> PasswordHasher | None:
        """Return ``app.state.hasher``, or None if it is not set."""
        hasher: PasswordHasher | None = getattr(request.app.state, "hasher", None)
        return hasher

    def _unauthorized(self, request: Request) -> JSONResponse:
        """Uniform sanitized 401 — identical for every failure cause."""
        return self._fail(request, status_code=401, code="unauthorized",
                          message="Authentication required.")

    @staticmethod
    def _fail(
        request: Request,
        *,
        status_code: int,
        code: str,
        message: str,
        retry_after: int | None = None,
    ) -> JSONResponse:
        """Build the sanitized JSON error response.

        The body is ``{"error": {"code", "message", "request_id"}}``. The
        ``X-Request-ID`` header is set only when ``request.state`` carries a
        request id, and ``Retry-After`` only when ``retry_after`` is given.
        """
        request_id = str(getattr(request.state, "request_id", "") or "")
        headers = {"X-Request-ID": request_id} if request_id else {}
        if retry_after is not None:
            headers["Retry-After"] = str(retry_after)
        return JSONResponse(
            status_code=status_code,
            content={"error": {"code": code, "message": message, "request_id": request_id}},
            headers=headers,
        )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AuthMiddleware"]
|
||||||
75
src/neuronetz_gateway/auth/principal.py
Normal file
75
src/neuronetz_gateway/auth/principal.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""The resolved authentication principal attached to ``request.state``.
|
||||||
|
|
||||||
|
A :class:`Principal` is the fully-resolved, inheritance-flattened view of an
|
||||||
|
API key and its owning tenant: every effective limit, the effective model
|
||||||
|
policy, and the prompt-logging decision. It is what the auth middleware caches
|
||||||
|
in Redis (as JSON) and what downstream dependencies read on the hot path, so it
|
||||||
|
deliberately contains no ORM objects and no secrets (never the key hash).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
class EffectiveLimits:
    """Flattened per-request limits after key-over-tenant inheritance.

    Instances are immutable and are serialized as the ``limits`` object of a
    cached :class:`Principal`.
    """

    # Requests-per-minute limit.
    rpm: int
    # Tokens-per-minute limit.
    tpm: int
    # Concurrent-connection cap.
    concurrent: int
    # Token budgets per period. NOTE(review): None presumably means "no budget
    # configured" — confirm against the budget enforcement code.
    tokens_daily: int | None
    tokens_monthly: int | None
    tokens_total: int | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class Principal:
|
||||||
|
"""Resolved identity + policy for an authenticated request.
|
||||||
|
|
||||||
|
``allow_all_models`` and ``allowed_models`` are already inheritance-resolved
|
||||||
|
(key value if non-NULL, else tenant value) so the request path never needs
|
||||||
|
the raw rows again. ``allowed_models`` is the configured allowlist *before*
|
||||||
|
intersection with the live discovered set (that intersection happens in the
|
||||||
|
allowlist resolver against fresh discovery data).
|
||||||
|
"""
|
||||||
|
|
||||||
|
key_id: uuid.UUID
|
||||||
|
key_prefix: str
|
||||||
|
tenant_id: uuid.UUID
|
||||||
|
tenant_name: str
|
||||||
|
scopes: tuple[str, ...]
|
||||||
|
log_prompts: bool
|
||||||
|
allow_all_models: bool
|
||||||
|
allowed_models: tuple[str, ...]
|
||||||
|
limits: EffectiveLimits
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Serialize to a compact JSON string for the Redis key cache."""
|
||||||
|
data = asdict(self)
|
||||||
|
data["key_id"] = str(self.key_id)
|
||||||
|
data["tenant_id"] = str(self.tenant_id)
|
||||||
|
return json.dumps(data, separators=(",", ":"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, raw: str) -> Principal:
|
||||||
|
"""Reconstruct a principal from its cached JSON representation."""
|
||||||
|
data = json.loads(raw)
|
||||||
|
limits = EffectiveLimits(**data["limits"])
|
||||||
|
return cls(
|
||||||
|
key_id=uuid.UUID(data["key_id"]),
|
||||||
|
key_prefix=data["key_prefix"],
|
||||||
|
tenant_id=uuid.UUID(data["tenant_id"]),
|
||||||
|
tenant_name=data["tenant_name"],
|
||||||
|
scopes=tuple(data["scopes"]),
|
||||||
|
log_prompts=bool(data["log_prompts"]),
|
||||||
|
allow_all_models=bool(data["allow_all_models"]),
|
||||||
|
allowed_models=tuple(data["allowed_models"]),
|
||||||
|
limits=limits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["EffectiveLimits", "Principal"]
|
||||||
3
src/neuronetz_gateway/cli/__init__.py
Normal file
3
src/neuronetz_gateway/cli/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Bootstrap CLI package (Typer)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
323
src/neuronetz_gateway/cli/manage.py
Normal file
323
src/neuronetz_gateway/cli/manage.py
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
"""Bootstrap CLI (Typer) per SPEC §11.
|
||||||
|
|
||||||
|
Entry point: ``neuronetz-gateway = neuronetz_gateway.cli.manage:app``.
|
||||||
|
|
||||||
|
This is the *only* supported way to create tenants and keys (AGENT_PROMPT
|
||||||
|
non-negotiable #10: the CLI must work before the first manual ``curl``). Each
|
||||||
|
command opens its own short-lived async engine against ``DATABASE_URL``, does
|
||||||
|
its unit of work in a transaction, and exits. The full API key is printed
|
||||||
|
exactly once, at creation, and never stored or logged.
|
||||||
|
|
||||||
|
``list-models`` reads the discovery cache from Redis (SPEC §4.6); with
|
||||||
|
``--tenant`` it also resolves and prints that tenant's effective model set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from neuronetz_gateway.auth.hashing import build_hasher, hash_secret
|
||||||
|
from neuronetz_gateway.auth.keys import generate_key
|
||||||
|
from neuronetz_gateway.config import Settings, get_settings
|
||||||
|
from neuronetz_gateway.db.models import BudgetPeriod, KeyStatus
|
||||||
|
from neuronetz_gateway.db.repositories import (
|
||||||
|
ApiKeyRepository,
|
||||||
|
BudgetRepository,
|
||||||
|
KeyLimitRepository,
|
||||||
|
RevocationRepository,
|
||||||
|
TenantRepository,
|
||||||
|
)
|
||||||
|
from neuronetz_gateway.db.session import create_engine, create_session_factory, session_scope
|
||||||
|
from neuronetz_gateway.proxy.allowlist import resolve_effective_models
|
||||||
|
from neuronetz_gateway.proxy.discovery import read_discovered_from_redis
|
||||||
|
|
||||||
|
# The Typer application that the commands below register on via
# ``@app.command``. Invoking it with no arguments prints help, and shell
# completion installation is turned off.
app = typer.Typer(
    name="neuronetz-gateway",
    help="Bootstrap CLI for the neuronetz-gateway (tenants, keys, budgets).",
    no_args_is_help=True,
    add_completion=False,
)
|
||||||
|
|
||||||
|
def _run[T](coro_factory: Callable[[Settings], Awaitable[T]]) -> T:
|
||||||
|
"""Execute an async unit of work against a fresh engine, then dispose it."""
|
||||||
|
|
||||||
|
async def _main() -> T:
|
||||||
|
settings = get_settings()
|
||||||
|
engine = create_engine(settings)
|
||||||
|
try:
|
||||||
|
return await coro_factory(settings)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
return asyncio.run(_main())
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("create-tenant")
def create_tenant(
    name: Annotated[str, typer.Option("--name", help="Unique tenant name.")],
    rpm: Annotated[int, typer.Option("--rpm", help="Requests-per-minute limit.")] = 60,
    tpm: Annotated[int, typer.Option("--tpm", help="Tokens-per-minute limit.")] = 100_000,
    concurrent: Annotated[
        int, typer.Option("--concurrent", help="Concurrent-connection cap.")
    ] = 8,
    allow_all_models: Annotated[
        bool,
        typer.Option(
            "--allow-all-models/--no-allow-all-models",
            help="Opt the tenant into using any installed model.",
        ),
    ] = False,
) -> None:
    """Create a tenant with optional rate limits and model policy.

    Fails with a ``BadParameter`` error if a tenant with this name already
    exists.
    """

    async def work(settings: Settings) -> None:
        # This command builds its own engine, so it must dispose it itself.
        engine = create_engine(settings)
        try:
            factory = create_session_factory(engine)
            async with session_scope(factory) as session:
                tenants = TenantRepository(session)
                if await tenants.get_by_name(name) is not None:
                    raise typer.BadParameter(f"tenant {name!r} already exists")
                tenant = await tenants.create(
                    name=name,
                    rpm=rpm,
                    tpm=tpm,
                    concurrent=concurrent,
                    allow_all_models=allow_all_models,
                )
                # Read what we report while the session is still open.
                tenant_name, tenant_id = tenant.name, tenant.id
        finally:
            await engine.dispose()
        # Report success only after the session scope exited without raising.
        typer.echo(f"created tenant {tenant_name} ({tenant_id})")
        typer.echo(f" allow_all_models={allow_all_models} rpm={rpm} tpm={tpm}")

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("create-key")
def create_key(
    tenant: Annotated[str, typer.Option("--tenant", help="Owning tenant name.")],
    name: Annotated[str, typer.Option("--name", help="Human-readable key name.")],
    scopes: Annotated[
        str, typer.Option("--scopes", help="Comma-separated scopes.")
    ] = "chat,embeddings",
) -> None:
    """Create an API key for a tenant. The full key is printed exactly once.

    Only the key's prefix and its hash are handed to the repository; the
    plaintext is echoed to the terminal after the session scope has exited
    without raising, so a key is never shown for a write that failed.
    """

    async def work(settings: Settings) -> None:
        hasher = build_hasher(settings)
        # Blank entries (e.g. from "chat,,embeddings") are dropped.
        scope_list = [s.strip() for s in scopes.split(",") if s.strip()]
        # This command builds its own engine, so it must dispose it itself.
        engine = create_engine(settings)
        try:
            factory = create_session_factory(engine)
            async with session_scope(factory) as session:
                tenants = TenantRepository(session)
                tenant_row = await tenants.get_by_name(tenant)
                if tenant_row is None:
                    raise typer.BadParameter(f"unknown tenant {tenant!r}")
                generated = generate_key()
                key_hash = hash_secret(hasher, generated.full_key)
                keys = ApiKeyRepository(session)
                created = await keys.create(
                    tenant_id=tenant_row.id,
                    prefix=generated.prefix,
                    key_hash=key_hash,
                    name=name,
                    scopes=scope_list,
                )
                # Read what we report while the session is still open.
                key_name, key_prefix = created.name, created.prefix
        finally:
            await engine.dispose()
        typer.echo(f"created key {key_name} for tenant {tenant} (prefix {key_prefix})")
        typer.echo("")
        typer.secho("API KEY (shown once — store it now):", fg=typer.colors.YELLOW, bold=True)
        typer.secho(generated.full_key, fg=typer.colors.GREEN, bold=True)

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("revoke-key")
def revoke_key(
    prefix: Annotated[str, typer.Option("--prefix", help="Key prefix to revoke.")],
) -> None:
    """Revoke a key by its prefix (sets status + writes the revocation outbox).

    Fails with a ``BadParameter`` error if no key has this prefix. The
    confirmation line is printed only after the session scope has exited
    without raising.
    """

    async def work(settings: Settings) -> None:
        # This command builds its own engine, so it must dispose it itself.
        engine = create_engine(settings)
        try:
            factory = create_session_factory(engine)
            async with session_scope(factory) as session:
                keys = ApiKeyRepository(session)
                key = await keys.get_by_prefix(prefix)
                if key is None:
                    raise typer.BadParameter(f"no key with prefix {prefix!r}")
                await keys.set_status(key.id, KeyStatus.revoked)
                await RevocationRepository(session).insert(key.id, reason="cli revoke")
                # Read what we report while the session is still open.
                key_id = key.id
        finally:
            await engine.dispose()
        typer.echo(f"revoked key {prefix} ({key_id})")

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("list-keys")
def list_keys(
    tenant: Annotated[str, typer.Option("--tenant", help="Tenant whose keys to list.")],
) -> None:
    """List a tenant's keys (prefixes and metadata, never full keys).

    Prints one line per key with prefix, status, name and creation time, or
    ``(no keys)`` when the tenant has none. Fails with a ``BadParameter``
    error for an unknown tenant.
    """

    async def work(settings: Settings) -> None:
        # This command builds its own engine, so it must dispose it itself.
        engine = create_engine(settings)
        try:
            factory = create_session_factory(engine)
            async with session_scope(factory) as session:
                tenants = TenantRepository(session)
                tenant_row = await tenants.get_by_name(tenant)
                if tenant_row is None:
                    raise typer.BadParameter(f"unknown tenant {tenant!r}")
                rows = await ApiKeyRepository(session).list_for_tenant(tenant_row.id)
                if not rows:
                    typer.echo("(no keys)")
                    return
                # Rows are printed while the session is still open.
                for key in rows:
                    typer.echo(
                        f"{key.prefix} status={key.status.value:<8} "
                        f"name={key.name!r} created={key.created_at.isoformat()}"
                    )
        finally:
            await engine.dispose()

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("show-usage")
def show_usage(
    tenant: Annotated[str, typer.Option("--tenant", help="Tenant to report usage for.")],
    period: Annotated[str, typer.Option("--period", help="Period: day|month|total.")] = "day",
) -> None:
    """Show token/request usage for a tenant in a period.

    Fails with a ``BadParameter`` error when ``period`` is not a valid
    ``BudgetPeriod`` value or the tenant is unknown.
    """

    async def work(settings: Settings) -> None:
        # Validate the period before touching the database.
        try:
            period_enum = BudgetPeriod(period)
        except ValueError as exc:
            raise typer.BadParameter("period must be one of day|month|total") from exc
        # This command builds its own engine, so it must dispose it itself.
        engine = create_engine(settings)
        try:
            factory = create_session_factory(engine)
            async with session_scope(factory) as session:
                tenant_row = await TenantRepository(session).get_by_name(tenant)
                if tenant_row is None:
                    raise typer.BadParameter(f"unknown tenant {tenant!r}")
                tokens_in, tokens_out, requests = await BudgetRepository(
                    session
                ).usage_for_tenant(tenant_row.id, period_enum)
                typer.echo(f"usage for {tenant} (period={period}):")
                typer.echo(f" requests={requests} tokens_in={tokens_in} tokens_out={tokens_out}")
        finally:
            await engine.dispose()

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("set-budget")
def set_budget(
    key: Annotated[str, typer.Option("--key", help="Key prefix to set budget on.")],
    daily: Annotated[int | None, typer.Option("--daily", help="Daily token budget.")] = None,
    monthly: Annotated[
        int | None, typer.Option("--monthly", help="Monthly token budget.")
    ] = None,
    total: Annotated[int | None, typer.Option("--total", help="Lifetime token budget.")] = None,
) -> None:
    """Set per-key token budgets."""
    # NOTE: Typer surfaces the docstring above as the command's help text, so it
    # is kept verbatim; implementation notes live in comments.

    async def work(settings: Settings) -> None:
        # At least one budget must be supplied, otherwise there is nothing to set.
        if all(amount is None for amount in (daily, monthly, total)):
            raise typer.BadParameter("provide at least one of --daily/--monthly/--total")

        session_factory = create_session_factory(create_engine(settings))
        async with session_scope(session_factory) as session:
            # The key is looked up by prefix regardless of its status.
            target = await ApiKeyRepository(session).get_by_prefix(key)
            if target is None:
                raise typer.BadParameter(f"no key with prefix {key!r}")

            limits_repo = KeyLimitRepository(session)
            await limits_repo.upsert_budget(
                target.id,
                tokens_daily=daily,
                tokens_monthly=monthly,
                tokens_total=total,
            )
            typer.echo(f"set budget on {key}: daily={daily} monthly={monthly} total={total}")

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("set-models")
def set_models(
    tenant: Annotated[str, typer.Option("--tenant", help="Tenant to set models for.")],
    models: Annotated[
        str | None, typer.Option("--models", help="Comma-separated model allowlist.")
    ] = None,
    allow_all: Annotated[
        bool | None,
        typer.Option(
            "--allow-all/--no-allow-all",
            help="Opt into / out of allow_all_models for the tenant.",
        ),
    ] = None,
) -> None:
    """Set a tenant's model allowlist and/or its allow_all_models flag."""
    # NOTE: Typer surfaces the docstring above as the command's help text, so it
    # is kept verbatim; implementation notes live in comments.
    #
    # Raises typer.BadParameter when neither option is given, when the tenant is
    # unknown, or when the tenant has no limits row to update.

    async def work(settings: Settings) -> None:
        if models is None and allow_all is None:
            raise typer.BadParameter("provide --models and/or --allow-all/--no-allow-all")

        # Split the comma-separated list, dropping empty/whitespace-only entries.
        # ``None`` means "leave the allowlist untouched".
        allowed = (
            [m.strip() for m in models.split(",") if m.strip()] if models is not None else None
        )

        factory = create_session_factory(create_engine(settings))
        async with session_scope(factory) as session:
            tenants = TenantRepository(session)
            tenant_row = await tenants.get_by_name(tenant)
            if tenant_row is None:
                raise typer.BadParameter(f"unknown tenant {tenant!r}")

            # TenantRepository.set_models returns without doing anything when the
            # tenant has no limits row. Check up front so the command does not
            # report success for an update that never happened (same error as
            # list-models uses for this condition).
            if await tenants.get_limits(tenant_row.id) is None:
                raise typer.BadParameter(f"tenant {tenant!r} has no limits row")

            await tenants.set_models(
                tenant_row.id, allowed_models=allowed, allow_all_models=allow_all
            )
            typer.echo(f"updated models for {tenant}: allowed={allowed} allow_all={allow_all}")

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("list-models")
def list_models(
    tenant: Annotated[
        str | None, typer.Option("--tenant", help="Also show this tenant's effective set.")
    ] = None,
) -> None:
    """Show live-discovered models (and, with --tenant, the effective set)."""
    # NOTE: Typer surfaces the docstring above as the command's help text, so it
    # is kept verbatim; implementation notes live in comments.
    import redis.asyncio as redis

    async def work(settings: Settings) -> None:
        # Read the discovery cache, always closing the Redis client afterwards.
        cache = redis.from_url(settings.redis_url, decode_responses=True)
        try:
            discovered = await read_discovered_from_redis(cache)
        finally:
            await cache.aclose()

        names = sorted(discovered)
        typer.echo("discovered models (live from Ollama via discovery cache):")
        if not names:
            typer.echo(" (none — discovery cache empty or expired; requests fail closed)")
        for model_name in names:
            typer.echo(f" {model_name}")

        # Without --tenant the report ends here; no database access is needed.
        if tenant is None:
            return

        session_factory = create_session_factory(create_engine(settings))
        async with session_scope(session_factory) as session:
            repo = TenantRepository(session)
            owner = await repo.get_by_name(tenant)
            if owner is None:
                raise typer.BadParameter(f"unknown tenant {tenant!r}")

            tenant_limits = await repo.get_limits(owner.id)
            if tenant_limits is None:
                raise typer.BadParameter(f"tenant {tenant!r} has no limits row")

            # Combine the tenant's allow-all flag / allowlist with what was discovered.
            effective = resolve_effective_models(
                allow_all=tenant_limits.allow_all_models,
                allowed_models=tuple(tenant_limits.allowed_models),
                discovered=discovered,
            )
            typer.echo(f"effective set for tenant {tenant}:")
            for model_name in sorted(effective):
                typer.echo(f" {model_name}")

    _run(work)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Invoke the Typer application; used as the console-script entry point."""
    app()


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
343
src/neuronetz_gateway/db/repositories.py
Normal file
343
src/neuronetz_gateway/db/repositories.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
"""Async repositories for ``gateway`` schema access (SPEC §4.3, §5, §11).
|
||||||
|
|
||||||
|
These wrap SQLAlchemy 2.0 async sessions. The auth hot-path uses
|
||||||
|
:meth:`ApiKeyRepository.load_for_auth` (which loads the key, its tenant, and
both limit rows) to build a flattened :class:`~neuronetz_gateway.auth.principal.Principal`.
|
||||||
|
CLI/admin flows use the create/list/usage methods. Nothing here returns a key
|
||||||
|
hash to callers other than the auth resolver, which needs it to run argon2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqlalchemy import insert, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from neuronetz_gateway.auth.principal import EffectiveLimits, Principal
|
||||||
|
from neuronetz_gateway.config import Settings
|
||||||
|
from neuronetz_gateway.db.models import (
|
||||||
|
ApiKey,
|
||||||
|
AuditLog,
|
||||||
|
BudgetPeriod,
|
||||||
|
BudgetUsage,
|
||||||
|
KeyLimit,
|
||||||
|
KeyStatus,
|
||||||
|
Revocation,
|
||||||
|
Tenant,
|
||||||
|
TenantLimit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class AuthRow:
|
||||||
|
"""Raw (un-flattened) auth lookup result: the row data argon2 needs."""
|
||||||
|
|
||||||
|
key_hash: str
|
||||||
|
principal: Principal
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_limits(
    settings: Settings, tenant: TenantLimit, key: KeyLimit | None
) -> EffectiveLimits:
    """Merge key-level and tenant-level numeric limits into one value object.

    For every field the key's value wins when it is set; a missing key row or
    a ``None`` key value falls back to the tenant's value.

    Args:
        settings: Accepted for interface compatibility; not read here.
        tenant: The tenant's limit row (supplies the fallback values).
        key: The key's limit row, or ``None`` when the key has none.

    Returns:
        The flattened :class:`EffectiveLimits`.
    """

    def override(field: str) -> int | None:
        # Key-level value for ``field``; None when there is no key row.
        if not key:
            return None
        value: int | None = getattr(key, field)
        return value

    def required(field: str) -> int:
        # Fields for which the tenant value is annotated as a plain int.
        own = override(field)
        base: int = getattr(tenant, field)
        return base if own is None else own

    def optional(field: str) -> int | None:
        # Fields for which the tenant value may itself be None.
        own = override(field)
        base: int | None = getattr(tenant, field)
        return base if own is None else own

    return EffectiveLimits(
        rpm=required("rpm"),
        tpm=required("tpm"),
        concurrent=required("concurrent"),
        tokens_daily=optional("tokens_daily"),
        tokens_monthly=optional("tokens_monthly"),
        tokens_total=optional("tokens_total"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyRepository:
    """Data access for API keys: prefix lookups, auth resolution, lifecycle writes."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get_active_by_prefix(self, prefix: str) -> ApiKey | None:
        """Look up the key with this prefix whose status is ``active``.

        Returns ``None`` when no such key exists.
        """
        query = select(ApiKey).where(
            ApiKey.prefix == prefix, ApiKey.status == KeyStatus.active
        )
        result = await self._session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_prefix(self, prefix: str) -> ApiKey | None:
        """Look up the key with this prefix without filtering on status."""
        query = select(ApiKey).where(ApiKey.prefix == prefix)
        result = await self._session.execute(query)
        return result.scalar_one_or_none()

    async def load_for_auth(self, prefix: str, settings: Settings) -> AuthRow | None:
        """Build the :class:`AuthRow` for an active, unexpired key.

        Loads the key, then its tenant, the tenant's limit row and the key's
        limit row, and flattens them into a ``Principal``. Key-level values
        override tenant-level ones whenever they are not ``None``.

        Returns:
            ``None`` when there is no active key with this prefix, the key has
            expired, the tenant is missing or not active, or the tenant has no
            limit row. The hash itself is not verified here.
        """
        api_key = await self.get_active_by_prefix(prefix)
        if api_key is None:
            return None
        expiry = api_key.expires_at
        if expiry is not None and expiry <= _utcnow():
            return None

        owner = await self._session.get(Tenant, api_key.tenant_id)
        if owner is None or owner.status.value != "active":
            return None
        tenant_limits = await self._session.get(TenantLimit, api_key.tenant_id)
        if tenant_limits is None:
            return None
        key_limits = await self._session.get(KeyLimit, api_key.id)

        # Model policy: the key's setting wins when present, else inherit the tenant's.
        key_allow_all = None if key_limits is None else key_limits.allow_all_models
        if key_allow_all is None:
            allow_all = tenant_limits.allow_all_models
        else:
            allow_all = key_allow_all

        key_allowed = None if key_limits is None else key_limits.allowed_models
        if key_allowed is None:
            allowed = tenant_limits.allowed_models
        else:
            allowed = key_allowed

        # Prompt logging follows the same key-over-tenant inheritance.
        key_log_prompts = api_key.log_prompts
        if key_log_prompts is None:
            log_prompts = tenant_limits.log_prompts_default
        else:
            log_prompts = key_log_prompts

        resolved = Principal(
            key_id=api_key.id,
            key_prefix=api_key.prefix,
            tenant_id=owner.id,
            tenant_name=owner.name,
            scopes=tuple(api_key.scopes),
            log_prompts=log_prompts,
            allow_all_models=allow_all,
            allowed_models=tuple(allowed),
            limits=_resolve_limits(settings, tenant_limits, key_limits),
        )
        return AuthRow(key_hash=api_key.key_hash, principal=resolved)

    async def create(
        self,
        *,
        tenant_id: uuid.UUID,
        prefix: str,
        key_hash: str,
        name: str,
        scopes: list[str],
    ) -> ApiKey:
        """Add a new key row to the session, flush it, and return the ORM object."""
        record = ApiKey(
            tenant_id=tenant_id,
            prefix=prefix,
            key_hash=key_hash,
            name=name,
            scopes=scopes,
        )
        self._session.add(record)
        await self._session.flush()
        return record

    async def list_for_tenant(self, tenant_id: uuid.UUID) -> list[ApiKey]:
        """Return every key of the tenant, ordered by ``created_at`` descending."""
        query = select(ApiKey).where(ApiKey.tenant_id == tenant_id)
        query = query.order_by(ApiKey.created_at.desc())
        result = await self._session.execute(query)
        return list(result.scalars().all())

    async def touch_last_used(self, key_id: uuid.UUID) -> None:
        """Stamp the key's ``last_used_at`` with the current UTC time."""
        stamp = update(ApiKey).where(ApiKey.id == key_id).values(last_used_at=_utcnow())
        await self._session.execute(stamp)

    async def set_status(self, key_id: uuid.UUID, status: KeyStatus) -> None:
        """Write a new lifecycle status onto the key row."""
        change = update(ApiKey).where(ApiKey.id == key_id).values(status=status)
        await self._session.execute(change)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyLimitRepository:
    """Data access for per-key limit rows."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get(self, key_id: uuid.UUID) -> KeyLimit | None:
        """Fetch the limit row for ``key_id`` by primary key, or ``None``."""
        return await self._session.get(KeyLimit, key_id)

    async def upsert_budget(
        self,
        key_id: uuid.UUID,
        *,
        tokens_daily: int | None,
        tokens_monthly: int | None,
        tokens_total: int | None,
    ) -> None:
        """Store token budgets for a key, inserting the limit row if absent.

        On an existing row only the budgets passed as non-``None`` are
        overwritten; a ``None`` argument leaves the stored value as it is.
        A newly created row receives all three values as given.
        """
        row = await self.get(key_id)
        if row is not None:
            requested = {
                "tokens_daily": tokens_daily,
                "tokens_monthly": tokens_monthly,
                "tokens_total": tokens_total,
            }
            for column, amount in requested.items():
                if amount is not None:
                    setattr(row, column, amount)
        else:
            fresh = KeyLimit(
                key_id=key_id,
                tokens_daily=tokens_daily,
                tokens_monthly=tokens_monthly,
                tokens_total=tokens_total,
            )
            self._session.add(fresh)
        await self._session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
class TenantRepository:
    """Data access for tenants and their limit rows."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get_by_name(self, name: str) -> Tenant | None:
        """Look up a tenant by name; ``None`` when there is no match."""
        query = select(Tenant).where(Tenant.name == name)
        result = await self._session.execute(query)
        return result.scalar_one_or_none()

    async def get_limits(self, tenant_id: uuid.UUID) -> TenantLimit | None:
        """Fetch the tenant's limit row by primary key, or ``None``."""
        return await self._session.get(TenantLimit, tenant_id)

    async def create(
        self,
        *,
        name: str,
        rpm: int,
        tpm: int,
        concurrent: int,
        allow_all_models: bool,
    ) -> Tenant:
        """Insert a tenant together with its limit row and return the tenant.

        The tenant is flushed first so that its id is available for the limit
        row; the limit row is then flushed as well.
        """
        tenant = Tenant(name=name)
        self._session.add(tenant)
        await self._session.flush()

        limit_row = TenantLimit(
            tenant_id=tenant.id,
            rpm=rpm,
            tpm=tpm,
            concurrent=concurrent,
            allow_all_models=allow_all_models,
        )
        self._session.add(limit_row)
        await self._session.flush()
        return tenant

    async def set_models(
        self,
        tenant_id: uuid.UUID,
        *,
        allowed_models: list[str] | None = None,
        allow_all_models: bool | None = None,
    ) -> None:
        """Change the tenant's allowlist and/or allow-all flag.

        Arguments left as ``None`` are not applied. When the tenant has no
        limit row the call returns without changing anything.
        """
        row = await self.get_limits(tenant_id)
        if row is None:
            return
        requested = {
            "allowed_models": allowed_models,
            "allow_all_models": allow_all_models,
        }
        for column, value in requested.items():
            if value is not None:
                setattr(row, column, value)
        await self._session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
class AuditRepository:
    """Insert-only access to the audit log table."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def insert_audit(self, **fields: object) -> int:
        """Insert one audit row built from ``fields`` and return its id.

        The keyword arguments are passed straight through as column values.
        """
        statement = insert(AuditLog).values(**fields).returning(AuditLog.id)
        outcome = await self._session.execute(statement)
        new_id = outcome.scalar_one()
        return int(new_id)
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetRepository:
    """Aggregate reads over the budget usage table."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def usage_for_tenant(
        self, tenant_id: uuid.UUID, period: BudgetPeriod
    ) -> tuple[int, int, int]:
        """Sum usage across all of a tenant's keys for one period kind.

        Returns:
            ``(tokens_in, tokens_out, requests)``; each total is 0 when there
            are no matching rows.
        """
        from sqlalchemy import func

        # COALESCE turns the NULL that SUM yields over zero rows into 0.
        totals = [
            func.coalesce(func.sum(column), 0)
            for column in (
                BudgetUsage.tokens_in,
                BudgetUsage.tokens_out,
                BudgetUsage.requests,
            )
        ]
        query = (
            select(*totals)
            .join(ApiKey, ApiKey.id == BudgetUsage.key_id)
            .where(ApiKey.tenant_id == tenant_id, BudgetUsage.period == period)
        )
        outcome = await self._session.execute(query)
        tokens_in, tokens_out, requests = outcome.one()
        return int(tokens_in), int(tokens_out), int(requests)
|
||||||
|
|
||||||
|
|
||||||
|
class RevocationRepository:
    """Inserts and updates rows in the revocations table."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def insert(self, key_id: uuid.UUID, reason: str | None) -> int:
        """Record a revocation for ``key_id`` and return the new row's id.

        NOTE(review): the original docstring says this insert fires a NOTIFY
        trigger; that trigger is not visible from this module.
        """
        statement = insert(Revocation).values(key_id=key_id, reason=reason)
        statement = statement.returning(Revocation.id)
        outcome = await self._session.execute(statement)
        return int(outcome.scalar_one())

    async def mark_processed(self, revocation_id: int, key_id: uuid.UUID) -> None:
        """Set ``processed_at`` to the current UTC time on the matching row.

        The row must match both ``revocation_id`` and ``key_id``.
        """
        matches = (Revocation.id == revocation_id, Revocation.key_id == key_id)
        stamp = update(Revocation).where(*matches).values(processed_at=_utcnow())
        await self._session.execute(stamp)
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime.datetime:
|
||||||
|
"""Timezone-aware current time (UTC)."""
|
||||||
|
return datetime.datetime.now(datetime.UTC)
|
||||||
|
|
||||||
|
|
||||||
|
# Public names exported by ``from ... import *`` (kept in alphabetical order).
__all__ = [
    "ApiKeyRepository",
    "AuditRepository",
    "AuthRow",
    "BudgetRepository",
    "KeyLimitRepository",
    "RevocationRepository",
    "TenantRepository",
]
|
||||||
Reference in New Issue
Block a user