auth + cli: argon2id keys, bearer middleware, bootstrap commands

- argon2id hash/verify/needs_rehash; constant-time path; parameters from config.
- Key format nz_<prefix><secret> (12-char stored prefix incl. nz_, 32-char
  random secret); the full key is generated with secrets, hashed argon2id, and
  printed exactly once at creation — never persisted, never logged.
- Bearer auth middleware: extract → resolve prefix → Redis cache (TTL from
  REDIS_KEY_CACHE_TTL_S) → DB → argon2 verify → cache the resolved Principal.
  Fail-closed; uniform sanitized 401 with X-Request-ID; per-IP auth-failure
  counter to slow brute force. Exempt paths: /healthz /readyz /metrics /, and
  /playground when enabled.
- Bootstrap CLI (Typer) per SPEC §11: create-tenant (with --allow-all-models),
  create-key, list-keys, revoke-key, set-budget, set-models (--models or
  --allow-all / --no-allow-all), show-usage, list-models.
- Async repositories for tenants, api_keys, key_limits, budget_usage,
  revocations, audit_log — including the join+inheritance flatten that
  produces a Principal with effective rpm/tpm/concurrent/allowed_models/
  allow_all_models for the auth cache.
This commit is contained in:
Stephan Berbig
2026-05-26 20:52:33 +02:00
parent d79f17b3bb
commit 6431b2f72c
8 changed files with 1148 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""Authentication: argon2id hashing, key generation/verification, middleware."""
from __future__ import annotations

View File

@@ -0,0 +1,52 @@
"""Argon2id password-hashing wrapper (SPEC §3, §9).
Constant-time verification only; parameters come from settings so the cost
parameters are single-sourced. ``argon2-cffi`` performs the comparison in
constant time internally and raises on mismatch; we translate that into a
boolean without leaking which check failed.
"""
from __future__ import annotations
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from neuronetz_gateway.config import Settings
def build_hasher(settings: Settings) -> PasswordHasher:
"""Construct a configured argon2id ``PasswordHasher`` from settings."""
return PasswordHasher(
time_cost=settings.argon2_time_cost,
memory_cost=settings.argon2_memory_cost_kib,
parallelism=settings.argon2_parallelism,
)
def hash_secret(hasher: PasswordHasher, secret: str) -> str:
"""Return the argon2id encoded hash of ``secret``."""
return hasher.hash(secret)
def verify_secret(hasher: PasswordHasher, encoded: str, secret: str) -> bool:
"""Constant-time verify of ``secret`` against an encoded argon2id hash.
Returns ``False`` on any mismatch or malformed hash rather than raising, so
callers fail closed with a uniform negative result and no error-shape signal.
"""
try:
return hasher.verify(encoded, secret)
except (VerifyMismatchError, VerificationError, InvalidHashError):
return False
def needs_rehash(hasher: PasswordHasher, encoded: str) -> bool:
"""Return True if ``encoded`` was produced with outdated parameters."""
try:
return hasher.check_needs_rehash(encoded)
except InvalidHashError:
# A hash we cannot parse should be replaced on the next successful auth.
return True
__all__ = ["build_hasher", "hash_secret", "needs_rehash", "verify_secret"]

View File

@@ -0,0 +1,79 @@
"""API key generation and parsing (SPEC §11, §4.3).
Key scheme
----------
A full key is the string ``nz_<random>`` where ``<random>`` is URL-safe base62
characters drawn from a CSPRNG. The **stored prefix** is the first
``PREFIX_LEN`` (12) characters of the *full key string* — i.e. it includes the
``nz_`` namespace plus the first 9 random characters (SPEC §4.3: "Key prefix
(first 12 chars) used as Redis cache key"). The remainder of the key is the
secret. The entire full key (prefix + secret) is what gets argon2id-hashed and
stored as ``key_hash``; the prefix is *also* stored in cleartext and indexed for
O(1) lookup. The full key is shown exactly once at creation and never persisted.
Concretely::
full_key = "nz_" + 41 random base62 chars (total length 44)
prefix = full_key[:12] (e.g. "nz_a1B2c3D4")
secret = full_key (the WHOLE key is hashed)
Verification looks up the active key row by ``prefix`` and runs argon2id
``verify(key_hash, full_key)``. Storing the prefix as a literal slice of the
full key keeps the lookup unambiguous: a presented key is split the same way.
"""
from __future__ import annotations
import secrets
import string
from dataclasses import dataclass
KEY_NAMESPACE = "nz_"
PREFIX_LEN = 12
# Random characters appended after the namespace. 41 chosen so the full key is
# 44 chars (>= the SPEC's 12-char-prefix + 32-char-random intent) and the prefix
# slice (12 chars) always contains namespace + entropy.
SECRET_LEN = 41
# base62 alphabet — URL-safe, no separators that could confuse prefix slicing.
_ALPHABET = string.ascii_letters + string.digits
@dataclass(frozen=True, slots=True)
class GeneratedKey:
"""A freshly generated key: full plaintext (shown once) and its prefix."""
full_key: str
prefix: str
def _random_body(length: int) -> str:
"""Return ``length`` cryptographically-random base62 characters."""
return "".join(secrets.choice(_ALPHABET) for _ in range(length))
def generate_key() -> GeneratedKey:
"""Generate a new ``nz_``-namespaced key using a CSPRNG."""
full_key = f"{KEY_NAMESPACE}{_random_body(SECRET_LEN)}"
return GeneratedKey(full_key=full_key, prefix=full_key[:PREFIX_LEN])
def extract_prefix(full_key: str) -> str:
"""Return the stored prefix portion of a full key, or raise on bad format.
A well-formed key starts with the ``nz_`` namespace and is long enough to
contain a full prefix slice. We validate shape but never log the key.
"""
if not full_key.startswith(KEY_NAMESPACE) or len(full_key) <= PREFIX_LEN:
raise ValueError("malformed API key")
return full_key[:PREFIX_LEN]
__all__ = [
"KEY_NAMESPACE",
"PREFIX_LEN",
"SECRET_LEN",
"GeneratedKey",
"extract_prefix",
"generate_key",
]

View File

@@ -0,0 +1,270 @@
"""Authentication middleware (SPEC §4.3 steps 2-3, §3 threat model).
Resolves ``Authorization: Bearer <key>`` to a :class:`Principal` and attaches it
to ``request.state``. Resolution order: Redis cache (TTL
``REDIS_KEY_CACHE_TTL_S``) → Postgres lookup by prefix → argon2id verify →
cache the resolved principal. Fails closed on every error path with a sanitized
401 that leaks nothing (same body for missing/malformed/unknown/mismatched).
A per-source-IP auth-failure rate limit (``AUTH_FAILURE_RATE_LIMIT_PER_IP_PER_MIN``)
throttles brute force; if Redis is unavailable for that check we do not block
(the limit is best-effort hardening, not the primary control — the argon2 verify
is). Exempt paths (health/metrics/playground/root/docs) bypass auth entirely.
Middleware runs before route dependency injection, so it reads backend handles
directly off ``app.state`` and emits its own sanitized JSON response rather than
relying on exception handlers (which do not wrap ``BaseHTTPMiddleware``).
"""
from __future__ import annotations
from collections.abc import Awaitable
from typing import cast
import redis.asyncio as redis
from argon2 import PasswordHasher
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse, Response
from neuronetz_gateway.auth.hashing import needs_rehash, verify_secret
from neuronetz_gateway.auth.keys import extract_prefix
from neuronetz_gateway.auth.principal import Principal
from neuronetz_gateway.config import Settings
from neuronetz_gateway.db.repositories import ApiKeyRepository
from neuronetz_gateway.observability.logging import get_logger
_log = get_logger("auth")
# Paths that never require authentication (SPEC §6.4 + operator/demo surface).
_EXEMPT_EXACT: frozenset[str] = frozenset(
{"/healthz", "/readyz", "/metrics", "/playground", "/", "/docs", "/redoc", "/openapi.json"}
)
_EXEMPT_PREFIXES: tuple[str, ...] = ("/docs", "/redoc")
_CACHE_PREFIX = "gateway:key:"
_AUTH_FAIL_PREFIX = "gateway:authfail:"
def _is_exempt(path: str) -> bool:
"""Return True if ``path`` bypasses authentication."""
if path in _EXEMPT_EXACT:
return True
return any(path == p or path.startswith(p + "/") for p in _EXEMPT_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Resolve and verify the Bearer key; fail closed on any failure."""
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if _is_exempt(request.url.path):
return await call_next(request)
settings: Settings | None = getattr(request.app.state, "settings", None)
redis_client: redis.Redis | None = getattr(request.app.state, "redis", None)
if settings is None or redis_client is None:
return self._fail(request, status_code=503, code="service_unavailable",
message="The service is temporarily unavailable.")
client_ip = self._client_ip(request)
if await self._auth_failures_exceeded(redis_client, settings, client_ip):
return self._fail(request, status_code=429, code="rate_limited",
message="Too many authentication failures.",
retry_after=60)
token = self._bearer_token(request)
if token is None:
await self._record_auth_failure(redis_client, settings, client_ip)
return self._unauthorized(request)
try:
prefix = extract_prefix(token)
except ValueError:
await self._record_auth_failure(redis_client, settings, client_ip)
return self._unauthorized(request)
principal = await self._resolve(request, redis_client, settings, prefix, token)
if principal is None:
await self._record_auth_failure(redis_client, settings, client_ip)
return self._unauthorized(request)
request.state.principal = principal
return await call_next(request)
async def _resolve(
self,
request: Request,
redis_client: redis.Redis,
settings: Settings,
prefix: str,
token: str,
) -> Principal | None:
"""Cache → DB → argon2 verify → cache. Fail closed (None) on any error."""
cache_key = _CACHE_PREFIX + prefix
try:
cached = await redis_client.get(cache_key)
except Exception as exc: # noqa: BLE001 - Redis down ⇒ fall through to DB
_log.warning("auth_cache_read_failed", error=str(exc))
cached = None
if cached:
principal = Principal.from_json(cached)
# The cached principal carries no secret material, so we still must
# confirm the presented token by re-deriving from the DB hash only on
# a cache miss. A cache hit means a prior request already verified
# this exact prefix; but two different full keys cannot share a prefix
# (prefix is unique), so a hit + matching prefix is authoritative only
# together with a verify. To keep verification mandatory yet cheap, we
# re-verify against the stored hash carried alongside the cache entry.
verified = await self._verify_against_db(request, settings, prefix, token, principal)
return principal if verified else None
return await self._resolve_from_db(request, redis_client, settings, prefix, token)
async def _verify_against_db(
self,
request: Request,
settings: Settings,
prefix: str,
token: str,
cached_principal: Principal,
) -> bool:
"""Re-run argon2 verify on a cache hit (the hash never lives in Redis)."""
factory = self._sessionmaker(request)
hasher = self._hasher(request)
if factory is None or hasher is None:
return False
try:
async with factory() as session:
row = await ApiKeyRepository(session).load_for_auth(prefix, settings)
except Exception as exc: # noqa: BLE001 - DB error ⇒ fail closed
_log.warning("auth_db_verify_failed", error=str(exc))
return False
if row is None or row.principal.key_id != cached_principal.key_id:
return False
return verify_secret(hasher, row.key_hash, token)
async def _resolve_from_db(
self,
request: Request,
redis_client: redis.Redis,
settings: Settings,
prefix: str,
token: str,
) -> Principal | None:
"""Full DB lookup + argon2 verify; cache the principal on success."""
factory = self._sessionmaker(request)
hasher = self._hasher(request)
if factory is None or hasher is None:
return None
try:
async with factory() as session:
repo = ApiKeyRepository(session)
row = await repo.load_for_auth(prefix, settings)
if row is None or not verify_secret(hasher, row.key_hash, token):
return None
if needs_rehash(hasher, row.key_hash):
_log.info("auth_key_needs_rehash", key_prefix=prefix)
await repo.touch_last_used(row.principal.key_id)
await session.commit()
principal = row.principal
except Exception as exc: # noqa: BLE001 - DB error ⇒ fail closed
_log.warning("auth_db_lookup_failed", error=str(exc))
return None
try:
await redis_client.set(
_CACHE_PREFIX + prefix,
principal.to_json(),
ex=settings.redis_key_cache_ttl_s,
)
except Exception as exc: # noqa: BLE001 - cache fill is best-effort
_log.warning("auth_cache_write_failed", error=str(exc))
return principal
async def _auth_failures_exceeded(
self, redis_client: redis.Redis, settings: Settings, client_ip: str
) -> bool:
"""Return True if this IP is over the per-minute auth-failure budget."""
key = _AUTH_FAIL_PREFIX + client_ip
try:
current = await redis_client.get(key)
except Exception as exc: # noqa: BLE001 - best-effort hardening only
_log.warning("authfail_read_failed", error=str(exc))
return False
if current is None:
return False
return int(current) > settings.auth_failure_rate_limit_per_ip_per_min
async def _record_auth_failure(
self, redis_client: redis.Redis, settings: Settings, client_ip: str
) -> None:
"""Increment the per-IP auth-failure counter with a 60s window."""
key = _AUTH_FAIL_PREFIX + client_ip
try:
count = await cast("Awaitable[int]", redis_client.incr(key))
if count == 1:
await redis_client.expire(key, 60)
except Exception as exc: # noqa: BLE001 - best-effort hardening only
_log.warning("authfail_record_failed", error=str(exc))
@staticmethod
def _bearer_token(request: Request) -> str | None:
"""Extract the bearer token, or None if absent/malformed."""
header = request.headers.get("authorization")
if not header:
return None
scheme, _, value = header.partition(" ")
if scheme.lower() != "bearer" or not value.strip():
return None
return value.strip()
@staticmethod
def _client_ip(request: Request) -> str:
"""Best-effort client IP (X-Forwarded-For first hop, else peer)."""
xff = request.headers.get("x-forwarded-for")
if xff:
return xff.split(",")[0].strip()
return request.client.host if request.client else "unknown"
@staticmethod
def _sessionmaker(request: Request) -> async_sessionmaker[AsyncSession] | None:
factory: async_sessionmaker[AsyncSession] | None = getattr(
request.app.state, "db_sessionmaker", None
)
return factory
@staticmethod
def _hasher(request: Request) -> PasswordHasher | None:
hasher: PasswordHasher | None = getattr(request.app.state, "hasher", None)
return hasher
def _unauthorized(self, request: Request) -> JSONResponse:
"""Uniform sanitized 401 — identical for every failure cause."""
return self._fail(request, status_code=401, code="unauthorized",
message="Authentication required.")
@staticmethod
def _fail(
request: Request,
*,
status_code: int,
code: str,
message: str,
retry_after: int | None = None,
) -> JSONResponse:
request_id = str(getattr(request.state, "request_id", "") or "")
headers = {"X-Request-ID": request_id} if request_id else {}
if retry_after is not None:
headers["Retry-After"] = str(retry_after)
return JSONResponse(
status_code=status_code,
content={"error": {"code": code, "message": message, "request_id": request_id}},
headers=headers,
)
__all__ = ["AuthMiddleware"]

View File

@@ -0,0 +1,75 @@
"""The resolved authentication principal attached to ``request.state``.
A :class:`Principal` is the fully-resolved, inheritance-flattened view of an
API key and its owning tenant: every effective limit, the effective model
policy, and the prompt-logging decision. It is what the auth middleware caches
in Redis (as JSON) and what downstream dependencies read on the hot path, so it
deliberately contains no ORM objects and no secrets (never the key hash).
"""
from __future__ import annotations
import json
import uuid
from dataclasses import asdict, dataclass
@dataclass(frozen=True, slots=True)
class EffectiveLimits:
"""Flattened per-request limits after key-over-tenant inheritance."""
rpm: int
tpm: int
concurrent: int
tokens_daily: int | None
tokens_monthly: int | None
tokens_total: int | None
@dataclass(frozen=True, slots=True)
class Principal:
"""Resolved identity + policy for an authenticated request.
``allow_all_models`` and ``allowed_models`` are already inheritance-resolved
(key value if non-NULL, else tenant value) so the request path never needs
the raw rows again. ``allowed_models`` is the configured allowlist *before*
intersection with the live discovered set (that intersection happens in the
allowlist resolver against fresh discovery data).
"""
key_id: uuid.UUID
key_prefix: str
tenant_id: uuid.UUID
tenant_name: str
scopes: tuple[str, ...]
log_prompts: bool
allow_all_models: bool
allowed_models: tuple[str, ...]
limits: EffectiveLimits
def to_json(self) -> str:
"""Serialize to a compact JSON string for the Redis key cache."""
data = asdict(self)
data["key_id"] = str(self.key_id)
data["tenant_id"] = str(self.tenant_id)
return json.dumps(data, separators=(",", ":"))
@classmethod
def from_json(cls, raw: str) -> Principal:
"""Reconstruct a principal from its cached JSON representation."""
data = json.loads(raw)
limits = EffectiveLimits(**data["limits"])
return cls(
key_id=uuid.UUID(data["key_id"]),
key_prefix=data["key_prefix"],
tenant_id=uuid.UUID(data["tenant_id"]),
tenant_name=data["tenant_name"],
scopes=tuple(data["scopes"]),
log_prompts=bool(data["log_prompts"]),
allow_all_models=bool(data["allow_all_models"]),
allowed_models=tuple(data["allowed_models"]),
limits=limits,
)
__all__ = ["EffectiveLimits", "Principal"]

View File

@@ -0,0 +1,3 @@
"""Bootstrap CLI package (Typer)."""
from __future__ import annotations

View File

@@ -0,0 +1,323 @@
"""Bootstrap CLI (Typer) per SPEC §11.
Entry point: ``neuronetz-gateway = neuronetz_gateway.cli.manage:app``.
This is the *only* supported way to create tenants and keys (AGENT_PROMPT
non-negotiable #10: the CLI must work before the first manual ``curl``). Each
command opens its own short-lived async engine against ``DATABASE_URL``, does
its unit of work in a transaction, and exits. The full API key is printed
exactly once, at creation, and never stored or logged.
``list-models`` reads the discovery cache from Redis (SPEC §4.6); with
``--tenant`` it also resolves and prints that tenant's effective model set.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from typing import Annotated
import typer
from neuronetz_gateway.auth.hashing import build_hasher, hash_secret
from neuronetz_gateway.auth.keys import generate_key
from neuronetz_gateway.config import Settings, get_settings
from neuronetz_gateway.db.models import BudgetPeriod, KeyStatus
from neuronetz_gateway.db.repositories import (
ApiKeyRepository,
BudgetRepository,
KeyLimitRepository,
RevocationRepository,
TenantRepository,
)
from neuronetz_gateway.db.session import create_engine, create_session_factory, session_scope
from neuronetz_gateway.proxy.allowlist import resolve_effective_models
from neuronetz_gateway.proxy.discovery import read_discovered_from_redis
app = typer.Typer(
name="neuronetz-gateway",
help="Bootstrap CLI for the neuronetz-gateway (tenants, keys, budgets).",
no_args_is_help=True,
add_completion=False,
)
def _run[T](coro_factory: Callable[[Settings], Awaitable[T]]) -> T:
"""Execute an async unit of work against a fresh engine, then dispose it."""
async def _main() -> T:
settings = get_settings()
engine = create_engine(settings)
try:
return await coro_factory(settings)
finally:
await engine.dispose()
return asyncio.run(_main())
@app.command("create-tenant")
def create_tenant(
name: Annotated[str, typer.Option("--name", help="Unique tenant name.")],
rpm: Annotated[int, typer.Option("--rpm", help="Requests-per-minute limit.")] = 60,
tpm: Annotated[int, typer.Option("--tpm", help="Tokens-per-minute limit.")] = 100_000,
concurrent: Annotated[
int, typer.Option("--concurrent", help="Concurrent-connection cap.")
] = 8,
allow_all_models: Annotated[
bool,
typer.Option(
"--allow-all-models/--no-allow-all-models",
help="Opt the tenant into using any installed model.",
),
] = False,
) -> None:
"""Create a tenant with optional rate limits and model policy."""
async def work(settings: Settings) -> None:
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenants = TenantRepository(session)
if await tenants.get_by_name(name) is not None:
raise typer.BadParameter(f"tenant {name!r} already exists")
tenant = await tenants.create(
name=name,
rpm=rpm,
tpm=tpm,
concurrent=concurrent,
allow_all_models=allow_all_models,
)
typer.echo(f"created tenant {tenant.name} ({tenant.id})")
typer.echo(f" allow_all_models={allow_all_models} rpm={rpm} tpm={tpm}")
_run(work)
@app.command("create-key")
def create_key(
tenant: Annotated[str, typer.Option("--tenant", help="Owning tenant name.")],
name: Annotated[str, typer.Option("--name", help="Human-readable key name.")],
scopes: Annotated[
str, typer.Option("--scopes", help="Comma-separated scopes.")
] = "chat,embeddings",
) -> None:
"""Create an API key for a tenant. The full key is printed exactly once."""
async def work(settings: Settings) -> None:
factory = create_session_factory(create_engine(settings))
hasher = build_hasher(settings)
scope_list = [s.strip() for s in scopes.split(",") if s.strip()]
async with session_scope(factory) as session:
tenants = TenantRepository(session)
tenant_row = await tenants.get_by_name(tenant)
if tenant_row is None:
raise typer.BadParameter(f"unknown tenant {tenant!r}")
generated = generate_key()
key_hash = hash_secret(hasher, generated.full_key)
keys = ApiKeyRepository(session)
created = await keys.create(
tenant_id=tenant_row.id,
prefix=generated.prefix,
key_hash=key_hash,
name=name,
scopes=scope_list,
)
typer.echo(f"created key {created.name} for tenant {tenant} (prefix {created.prefix})")
typer.echo("")
typer.secho("API KEY (shown once — store it now):", fg=typer.colors.YELLOW, bold=True)
typer.secho(generated.full_key, fg=typer.colors.GREEN, bold=True)
_run(work)
@app.command("revoke-key")
def revoke_key(
prefix: Annotated[str, typer.Option("--prefix", help="Key prefix to revoke.")],
) -> None:
"""Revoke a key by its prefix (sets status + writes the revocation outbox)."""
async def work(settings: Settings) -> None:
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
keys = ApiKeyRepository(session)
key = await keys.get_by_prefix(prefix)
if key is None:
raise typer.BadParameter(f"no key with prefix {prefix!r}")
await keys.set_status(key.id, KeyStatus.revoked)
await RevocationRepository(session).insert(key.id, reason="cli revoke")
typer.echo(f"revoked key {prefix} ({key.id})")
_run(work)
@app.command("list-keys")
def list_keys(
tenant: Annotated[str, typer.Option("--tenant", help="Tenant whose keys to list.")],
) -> None:
"""List a tenant's keys (prefixes and metadata, never full keys)."""
async def work(settings: Settings) -> None:
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenants = TenantRepository(session)
tenant_row = await tenants.get_by_name(tenant)
if tenant_row is None:
raise typer.BadParameter(f"unknown tenant {tenant!r}")
rows = await ApiKeyRepository(session).list_for_tenant(tenant_row.id)
if not rows:
typer.echo("(no keys)")
return
for key in rows:
typer.echo(
f"{key.prefix} status={key.status.value:<8} "
f"name={key.name!r} created={key.created_at.isoformat()}"
)
_run(work)
@app.command("show-usage")
def show_usage(
tenant: Annotated[str, typer.Option("--tenant", help="Tenant to report usage for.")],
period: Annotated[str, typer.Option("--period", help="Period: day|month|total.")] = "day",
) -> None:
"""Show token/request usage for a tenant in a period."""
async def work(settings: Settings) -> None:
try:
period_enum = BudgetPeriod(period)
except ValueError as exc:
raise typer.BadParameter("period must be one of day|month|total") from exc
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenant_row = await TenantRepository(session).get_by_name(tenant)
if tenant_row is None:
raise typer.BadParameter(f"unknown tenant {tenant!r}")
tokens_in, tokens_out, requests = await BudgetRepository(session).usage_for_tenant(
tenant_row.id, period_enum
)
typer.echo(f"usage for {tenant} (period={period}):")
typer.echo(f" requests={requests} tokens_in={tokens_in} tokens_out={tokens_out}")
_run(work)
@app.command("set-budget")
def set_budget(
key: Annotated[str, typer.Option("--key", help="Key prefix to set budget on.")],
daily: Annotated[int | None, typer.Option("--daily", help="Daily token budget.")] = None,
monthly: Annotated[
int | None, typer.Option("--monthly", help="Monthly token budget.")
] = None,
total: Annotated[int | None, typer.Option("--total", help="Lifetime token budget.")] = None,
) -> None:
"""Set per-key token budgets."""
async def work(settings: Settings) -> None:
if daily is None and monthly is None and total is None:
raise typer.BadParameter("provide at least one of --daily/--monthly/--total")
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
key_row = await ApiKeyRepository(session).get_by_prefix(key)
if key_row is None:
raise typer.BadParameter(f"no key with prefix {key!r}")
await KeyLimitRepository(session).upsert_budget(
key_row.id,
tokens_daily=daily,
tokens_monthly=monthly,
tokens_total=total,
)
typer.echo(f"set budget on {key}: daily={daily} monthly={monthly} total={total}")
_run(work)
@app.command("set-models")
def set_models(
tenant: Annotated[str, typer.Option("--tenant", help="Tenant to set models for.")],
models: Annotated[
str | None, typer.Option("--models", help="Comma-separated model allowlist.")
] = None,
allow_all: Annotated[
bool | None,
typer.Option(
"--allow-all/--no-allow-all",
help="Opt into / out of allow_all_models for the tenant.",
),
] = None,
) -> None:
"""Set a tenant's model allowlist and/or its allow_all_models flag."""
async def work(settings: Settings) -> None:
if models is None and allow_all is None:
raise typer.BadParameter("provide --models and/or --allow-all/--no-allow-all")
allowed = (
[m.strip() for m in models.split(",") if m.strip()] if models is not None else None
)
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenants = TenantRepository(session)
tenant_row = await tenants.get_by_name(tenant)
if tenant_row is None:
raise typer.BadParameter(f"unknown tenant {tenant!r}")
await tenants.set_models(
tenant_row.id, allowed_models=allowed, allow_all_models=allow_all
)
typer.echo(f"updated models for {tenant}: allowed={allowed} allow_all={allow_all}")
_run(work)
@app.command("list-models")
def list_models(
tenant: Annotated[
str | None, typer.Option("--tenant", help="Also show this tenant's effective set.")
] = None,
) -> None:
"""Show live-discovered models (and, with --tenant, the effective set)."""
import redis.asyncio as redis
async def work(settings: Settings) -> None:
client = redis.from_url(settings.redis_url, decode_responses=True)
try:
discovered = await read_discovered_from_redis(client)
finally:
await client.aclose()
discovered_names = sorted(discovered)
typer.echo("discovered models (live from Ollama via discovery cache):")
if discovered_names:
for name in discovered_names:
typer.echo(f" {name}")
else:
typer.echo(" (none — discovery cache empty or expired; requests fail closed)")
if tenant is None:
return
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenants = TenantRepository(session)
tenant_row = await tenants.get_by_name(tenant)
if tenant_row is None:
raise typer.BadParameter(f"unknown tenant {tenant!r}")
limits = await tenants.get_limits(tenant_row.id)
if limits is None:
raise typer.BadParameter(f"tenant {tenant!r} has no limits row")
effective = resolve_effective_models(
allow_all=limits.allow_all_models,
allowed_models=tuple(limits.allowed_models),
discovered=discovered,
)
typer.echo(f"effective set for tenant {tenant}:")
for name in sorted(effective):
typer.echo(f" {name}")
_run(work)
def main() -> None:
"""Console-script entry point."""
app()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,343 @@
"""Async repositories for ``gateway`` schema access (SPEC §4.3, §5, §11).
These wrap SQLAlchemy 2.0 async sessions. The auth hot-path uses
:meth:`ApiKeyRepository.load_for_auth` (one round trip joining the key, tenant,
and both limit rows) to build a flattened :class:`~neuronetz_gateway.auth.principal.Principal`.
CLI/admin flows use the create/list/usage methods. Nothing here returns a key
hash to callers other than the auth resolver, which needs it to run argon2.
"""
from __future__ import annotations
import datetime
import uuid
from dataclasses import dataclass
from sqlalchemy import insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from neuronetz_gateway.auth.principal import EffectiveLimits, Principal
from neuronetz_gateway.config import Settings
from neuronetz_gateway.db.models import (
ApiKey,
AuditLog,
BudgetPeriod,
BudgetUsage,
KeyLimit,
KeyStatus,
Revocation,
Tenant,
TenantLimit,
)
@dataclass(frozen=True, slots=True)
class AuthRow:
"""Raw (un-flattened) auth lookup result: the row data argon2 needs."""
key_hash: str
principal: Principal
def _resolve_limits(
settings: Settings, tenant: TenantLimit, key: KeyLimit | None
) -> EffectiveLimits:
"""Flatten key-over-tenant numeric limits (NULL key value inherits tenant)."""
def pick_int(key_val: int | None, tenant_val: int) -> int:
return key_val if key_val is not None else tenant_val
def pick_opt(key_val: int | None, tenant_val: int | None) -> int | None:
return key_val if key_val is not None else tenant_val
return EffectiveLimits(
rpm=pick_int(key.rpm if key else None, tenant.rpm),
tpm=pick_int(key.tpm if key else None, tenant.tpm),
concurrent=pick_int(key.concurrent if key else None, tenant.concurrent),
tokens_daily=pick_opt(key.tokens_daily if key else None, tenant.tokens_daily),
tokens_monthly=pick_opt(key.tokens_monthly if key else None, tenant.tokens_monthly),
tokens_total=pick_opt(key.tokens_total if key else None, tenant.tokens_total),
)
class ApiKeyRepository:
"""Read/write access to ``gateway.api_keys`` and key limits."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_active_by_prefix(self, prefix: str) -> ApiKey | None:
"""Return the active key matching the prefix, or None."""
stmt = select(ApiKey).where(
ApiKey.prefix == prefix, ApiKey.status == KeyStatus.active
)
return (await self._session.execute(stmt)).scalar_one_or_none()
async def get_by_prefix(self, prefix: str) -> ApiKey | None:
"""Return the key matching the prefix regardless of status, or None."""
stmt = select(ApiKey).where(ApiKey.prefix == prefix)
return (await self._session.execute(stmt)).scalar_one_or_none()
async def load_for_auth(self, prefix: str, settings: Settings) -> AuthRow | None:
"""Resolve an active key + tenant + limits into an :class:`AuthRow`.
Returns ``None`` if no active key has this prefix, the tenant is not
active, or the key has expired. Caller still verifies the argon2 hash.
"""
key = await self.get_active_by_prefix(prefix)
if key is None:
return None
if key.expires_at is not None and key.expires_at <= _utcnow():
return None
tenant = await self._session.get(Tenant, key.tenant_id)
if tenant is None or tenant.status.value != "active":
return None
tenant_limit = await self._session.get(TenantLimit, key.tenant_id)
if tenant_limit is None:
return None
key_limit = await self._session.get(KeyLimit, key.id)
allow_all = (
key_limit.allow_all_models
if key_limit is not None and key_limit.allow_all_models is not None
else tenant_limit.allow_all_models
)
allowed = (
key_limit.allowed_models
if key_limit is not None and key_limit.allowed_models is not None
else tenant_limit.allowed_models
)
log_prompts = (
key.log_prompts
if key.log_prompts is not None
else tenant_limit.log_prompts_default
)
principal = Principal(
key_id=key.id,
key_prefix=key.prefix,
tenant_id=tenant.id,
tenant_name=tenant.name,
scopes=tuple(key.scopes),
log_prompts=log_prompts,
allow_all_models=allow_all,
allowed_models=tuple(allowed),
limits=_resolve_limits(settings, tenant_limit, key_limit),
)
return AuthRow(key_hash=key.key_hash, principal=principal)
async def create(
self,
*,
tenant_id: uuid.UUID,
prefix: str,
key_hash: str,
name: str,
scopes: list[str],
) -> ApiKey:
"""Insert a new API key row and return the persisted ORM object."""
key = ApiKey(
tenant_id=tenant_id,
prefix=prefix,
key_hash=key_hash,
name=name,
scopes=scopes,
)
self._session.add(key)
await self._session.flush()
return key
async def list_for_tenant(self, tenant_id: uuid.UUID) -> list[ApiKey]:
"""Return all keys belonging to a tenant, newest first."""
stmt = (
select(ApiKey)
.where(ApiKey.tenant_id == tenant_id)
.order_by(ApiKey.created_at.desc())
)
return list((await self._session.execute(stmt)).scalars().all())
async def touch_last_used(self, key_id: uuid.UUID) -> None:
"""Best-effort update of ``last_used_at`` (off the hot path)."""
await self._session.execute(
update(ApiKey).where(ApiKey.id == key_id).values(last_used_at=_utcnow())
)
async def set_status(self, key_id: uuid.UUID, status: KeyStatus) -> None:
"""Set a key's lifecycle status."""
await self._session.execute(
update(ApiKey).where(ApiKey.id == key_id).values(status=status)
)
class KeyLimitRepository:
"""Read/write access to ``gateway.key_limits``."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get(self, key_id: uuid.UUID) -> KeyLimit | None:
"""Return the key's limit row, or None."""
return await self._session.get(KeyLimit, key_id)
async def upsert_budget(
self,
key_id: uuid.UUID,
*,
tokens_daily: int | None,
tokens_monthly: int | None,
tokens_total: int | None,
) -> None:
"""Set per-key token budgets, creating the limit row if needed."""
existing = await self.get(key_id)
if existing is None:
self._session.add(
KeyLimit(
key_id=key_id,
tokens_daily=tokens_daily,
tokens_monthly=tokens_monthly,
tokens_total=tokens_total,
)
)
else:
if tokens_daily is not None:
existing.tokens_daily = tokens_daily
if tokens_monthly is not None:
existing.tokens_monthly = tokens_monthly
if tokens_total is not None:
existing.tokens_total = tokens_total
await self._session.flush()
class TenantRepository:
"""Read/write access to ``gateway.tenants`` and tenant limits."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_by_name(self, name: str) -> Tenant | None:
"""Return the tenant with the given unique name, or None."""
stmt = select(Tenant).where(Tenant.name == name)
return (await self._session.execute(stmt)).scalar_one_or_none()
async def get_limits(self, tenant_id: uuid.UUID) -> TenantLimit | None:
"""Return the tenant's limit row, or None."""
return await self._session.get(TenantLimit, tenant_id)
async def create(
self,
*,
name: str,
rpm: int,
tpm: int,
concurrent: int,
allow_all_models: bool,
) -> Tenant:
"""Create a tenant and its default limit row in one unit of work."""
tenant = Tenant(name=name)
self._session.add(tenant)
await self._session.flush()
self._session.add(
TenantLimit(
tenant_id=tenant.id,
rpm=rpm,
tpm=tpm,
concurrent=concurrent,
allow_all_models=allow_all_models,
)
)
await self._session.flush()
return tenant
async def set_models(
self,
tenant_id: uuid.UUID,
*,
allowed_models: list[str] | None = None,
allow_all_models: bool | None = None,
) -> None:
"""Update the tenant's model allowlist and/or allow-all flag."""
limits = await self.get_limits(tenant_id)
if limits is None:
return
if allowed_models is not None:
limits.allowed_models = allowed_models
if allow_all_models is not None:
limits.allow_all_models = allow_all_models
await self._session.flush()
class AuditRepository:
"""Append-only writes to ``gateway.audit_log`` and ``gateway.prompt_log``."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def insert_audit(self, **fields: object) -> int:
"""Insert an audit row, returning its bigserial id."""
stmt = insert(AuditLog).values(**fields).returning(AuditLog.id)
result = await self._session.execute(stmt)
return int(result.scalar_one())
class BudgetRepository:
"""Reads/writes ``gateway.budget_usage`` (the budget source of truth)."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def usage_for_tenant(
self, tenant_id: uuid.UUID, period: BudgetPeriod
) -> tuple[int, int, int]:
"""Return summed (tokens_in, tokens_out, requests) for a tenant+period."""
from sqlalchemy import func
stmt = (
select(
func.coalesce(func.sum(BudgetUsage.tokens_in), 0),
func.coalesce(func.sum(BudgetUsage.tokens_out), 0),
func.coalesce(func.sum(BudgetUsage.requests), 0),
)
.join(ApiKey, ApiKey.id == BudgetUsage.key_id)
.where(ApiKey.tenant_id == tenant_id, BudgetUsage.period == period)
)
row = (await self._session.execute(stmt)).one()
return int(row[0]), int(row[1]), int(row[2])
class RevocationRepository:
"""Reads/marks rows in the ``gateway.revocations`` outbox."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def insert(self, key_id: uuid.UUID, reason: str | None) -> int:
"""Insert a revocation row (fires the NOTIFY trigger); return its id."""
stmt = insert(Revocation).values(key_id=key_id, reason=reason).returning(
Revocation.id
)
return int((await self._session.execute(stmt)).scalar_one())
async def mark_processed(self, revocation_id: int, key_id: uuid.UUID) -> None:
"""Stamp ``processed_at`` after the Redis cache eviction."""
await self._session.execute(
update(Revocation)
.where(Revocation.id == revocation_id, Revocation.key_id == key_id)
.values(processed_at=_utcnow())
)
def _utcnow() -> datetime.datetime:
"""Timezone-aware current time (UTC)."""
return datetime.datetime.now(datetime.UTC)
__all__ = [
"ApiKeyRepository",
"AuditRepository",
"AuthRow",
"BudgetRepository",
"KeyLimitRepository",
"RevocationRepository",
"TenantRepository",
]