diff --git a/API_DOCS.md b/API_DOCS.md index 7f161ba..ed689be 100644 --- a/API_DOCS.md +++ b/API_DOCS.md @@ -13,3 +13,13 @@ Authentication: Notes: - Base URL for v1 endpoints is `/api/v1`. - Admin-only routes return `403 Forbidden` when the token user is not staff/superuser. + +Example: update server display name (admin-only) + +PATCH `/api/v1/servers/{server_id}` + +```json +{ + "display_name": "Keywarden Prod" +} +``` diff --git a/agent/README.md b/agent/README.md index 6659403..b88b694 100644 --- a/agent/README.md +++ b/agent/README.md @@ -20,4 +20,6 @@ You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environ On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping. +If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate. + See `config.example.json`. diff --git a/agent/cmd/keywarden-agent/main.go b/agent/cmd/keywarden-agent/main.go index d331ab4..5eb6621 100644 --- a/agent/cmd/keywarden-agent/main.go +++ b/agent/cmd/keywarden-agent/main.go @@ -18,6 +18,7 @@ import ( "keywarden/agent/internal/client" "keywarden/agent/internal/config" + "keywarden/agent/internal/host" "keywarden/agent/internal/logs" "keywarden/agent/internal/version" ) @@ -74,11 +75,22 @@ func main() { } func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) { + if err := reportHost(ctx, apiClient, cfg); err != nil { + if client.IsRetriable(err) { + log.Printf("host update deferred; will retry: %v", err) + } else { + log.Printf("host update error: %v", err) + } + } if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { log.Printf("sync accounts error: %v", err) } if err := shipLogs(ctx, apiClient, cfg); err != nil { - log.Printf("log shipping error: %v", err) + if client.IsRetriable(err) { + log.Printf("log shipping deferred; will retry: %v", err) + } else { + log.Printf("log shipping error: %v", err) + } } } @@ -94,7 +106,9 @@ func ensureDirs(cfg *config.Config) error { func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error { send := func(payload []byte) error { - return apiClient.SendLogBatch(ctx, cfg.ServerID, payload) + return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error { + return apiClient.SendLogBatch(ctx, cfg.ServerID, payload) + }) } if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil { return err @@ -128,6 +142,17 @@ func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) return nil } +func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error { + info := host.Detect() + return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error { + return apiClient.UpdateHost(ctx, cfg.ServerID, client.HeartbeatRequest{ + Host: info.Hostname, + IPv4: info.IPv4, + IPv6: info.IPv6, + }) + }) +} + func pickServerURL(flagValue string) string { if flagValue != "" { return flagValue @@ -159,11 +184,14 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string if err != nil { return err } - hostname, _ := os.Hostname() + info := host.Detect() + hostname := info.Hostname resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{ Token: enrollToken, CSRPEM: csrPEM, Host: hostname, + IPv4: info.IPv4, + IPv6: info.IPv6, }) if err != nil { return err @@ -181,6 +209,28 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string return nil } +func retry(ctx context.Context, delays []time.Duration, fn func() error) error { + var lastErr error + for attempt := 0; attempt <= len(delays); attempt++ { + if attempt > 0 { + if !client.IsRetriable(lastErr) { + return lastErr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delays[attempt-1]): + } + } + if err := fn(); err != nil { + lastErr = err + continue + } + return nil + } + return lastErr +} + func generateKey(path string) error { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { diff --git a/agent/config.example.json b/agent/config.example.json index 3b25335..1c3f8b7 100644 --- a/agent/config.example.json +++ b/agent/config.example.json @@ -1,6 +1,7 @@ { "server_url": "https://keywarden.dev.ntbx.io/api/v1", "server_id": "4", + "server_ca_path": "", "sync_interval_seconds": 30, "log_batch_size": 500, "state_dir": "/var/lib/keywarden-agent", @@ -11,4 +12,4 @@ "create_home": true, "lock_on_revoke": true } -} \ No newline at end of file +} diff --git a/agent/internal/client/client.go b/agent/internal/client/client.go index 4e4b09d..69b5950 100644 --- a/agent/internal/client/client.go +++ b/agent/internal/client/client.go @@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) { if err != nil { return nil, fmt.Errorf("load client cert: %w", err) } - caData, err := os.ReadFile(cfg.CACertPath()) - if err != nil { - return nil, fmt.Errorf("read ca cert: %w", err) + caPool, err := x509.SystemCertPool() + if err != nil || caPool == nil { + caPool = x509.NewCertPool() } - caPool := x509.NewCertPool() - if !caPool.AppendCertsFromPEM(caData) { - return nil, errors.New("parse ca cert") + if cfg.ServerCAPath != "" { + caData, err := os.ReadFile(cfg.ServerCAPath) + if err != nil { + return nil, fmt.Errorf("read server ca cert: %w", err) + } + if !caPool.AppendCertsFromPEM(caData) { + return nil, errors.New("parse server ca cert") + } } tlsConfig := &tls.Config{ @@ -63,14 +68,16 @@ type EnrollRequest struct { Token string `json:"token"` CSRPEM string `json:"csr_pem"` Host string `json:"host"` + IPv4 string `json:"ipv4,omitempty"` + IPv6 string `json:"ipv6,omitempty"` AgentID string `json:"agent_id,omitempty"` } type EnrollResponse struct { - ServerID string `json:"server_id"` + ServerID string `json:"server_id"` ClientCert string `json:"client_cert_pem"` CACert string `json:"ca_cert_pem"` - SyncProfile string `json:"sync_profile,omitempty"` + SyncProfile string `json:"sync_profile,omitempty"` DisplayName string `json:"display_name,omitempty"` } @@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by } defer resp.Body.Close() if resp.StatusCode >= 300 { - return fmt.Errorf("log batch failed: status %s", resp.Status) + return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } + return nil +} + +type HeartbeatRequest struct { + Host string `json:"host,omitempty"` + IPv4 string `json:"ipv4,omitempty"` + IPv6 string `json:"ipv6,omitempty"` +} + +func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error { + body, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("encode host update: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("build host update: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("send host update: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} } return nil } diff --git a/agent/internal/client/errors.go b/agent/internal/client/errors.go new file mode 100644 index 0000000..d82ff14 --- /dev/null +++ b/agent/internal/client/errors.go @@ -0,0 +1,36 @@ +package client + +import ( + "context" + "errors" + "net" +) + +type HTTPStatusError struct { + StatusCode int + Status string +} + +func (e *HTTPStatusError) Error() string { + return "remote status " + e.Status +} + +func IsRetriable(err error) bool { + if err == nil { + return false + } + var statusErr *HTTPStatusError + if errors.As(err, &statusErr) { + switch statusErr.StatusCode { + case 404, 408, 429, 500, 502, 503, 504: + return true + default: + return false + } + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) +} diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index c09dc03..1701ab1 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -29,6 +29,7 @@ type AccountPolicy struct { type Config struct { ServerURL string `json:"server_url"` ServerID string `json:"server_id,omitempty"` + ServerCAPath string `json:"server_ca_path,omitempty"` SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"` LogBatchSize int `json:"log_batch_size,omitempty"` StateDir string `json:"state_dir,omitempty"` @@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) { if serverURL == "" { return nil, errors.New("server url required for first boot") } - cfg := &Config{ServerURL: serverURL} + cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")} applyDefaults(cfg) if err := validate(cfg, false); err != nil { return nil, err @@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) { if err := json.Unmarshal(data, cfg); err != nil { return nil, fmt.Errorf("parse config: %w", err) } + if cfg.ServerCAPath == "" { + cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH") + } applyDefaults(cfg) if err := validate(cfg, false); err != nil { return nil, err diff --git a/agent/internal/host/host.go b/agent/internal/host/host.go new file mode 100644 index 0000000..13c083c --- /dev/null +++ b/agent/internal/host/host.go @@ -0,0 +1,57 @@ +package host + +import ( + "net" + "os" +) + +type Info struct { + Hostname string + IPv4 string + IPv6 string +} + +func Detect() Info { + hostname, _ := os.Hostname() + info := Info{Hostname: hostname} + ifaces, err := net.Interfaces() + if err != nil { + return info + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + continue + } + if ip4 := ip.To4(); ip4 != nil { + if info.IPv4 == "" { + info.IPv4 = ip4.String() + } + continue + } + if ip.To16() != nil && info.IPv6 == "" { + info.IPv6 = ip.String() + } + } + if info.IPv4 != "" && info.IPv6 != "" { + break + } + } + return info +} diff --git a/agent/keywarden-agent b/agent/keywarden-agent index 65d6dc5..c724980 100755 Binary files a/agent/keywarden-agent and b/agent/keywarden-agent differ diff --git a/app/keywarden/api/routers/agent.py b/app/keywarden/api/routers/agent.py index 74dc6bd..6a9870b 100644 --- a/app/keywarden/api/routers/agent.py +++ b/app/keywarden/api/routers/agent.py @@ -6,6 +6,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID from django.conf import settings from django.core.exceptions import ValidationError +from django.core.validators import validate_ipv4_address, validate_ipv6_address from django.db import IntegrityError, models, transaction from django.http import HttpRequest from django.utils import timezone @@ -44,6 +45,8 @@ class AgentEnrollIn(Schema): token: str csr_pem: str host: Optional[str] = None + ipv4: Optional[str] = None + ipv6: Optional[str] = None class AgentEnrollOut(Schema): @@ -73,6 +76,12 @@ class LogIngestOut(Schema): accepted: int +class AgentHeartbeatIn(Schema): + host: Optional[str] = None + ipv4: Optional[str] = None + ipv6: Optional[str] = None + + def build_router() -> Router: router = Router() @@ -99,11 +108,18 @@ def build_router() -> Router: hostname = host except ValidationError: hostname = None + ipv4 = _normalize_ip(payload.ipv4, 4) + ipv6 = _normalize_ip(payload.ipv6, 6) csr = _load_csr((payload.csr_pem or "").strip()) try: with transaction.atomic(): - server = Server.objects.create(display_name=display_name, hostname=hostname) + server = Server.objects.create( + display_name=display_name, + hostname=hostname, + ipv4=ipv4, + ipv6=ipv6, + ) token.mark_used(server) token.save(update_fields=["used_at", "server"]) cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id) @@ -189,6 +205,38 @@ def build_router() -> Router: # TODO: enqueue to Valkey and persist to SQLite slices. return LogIngestOut(status="accepted", accepted=len(payload)) + @router.post("/servers/{server_id}/heartbeat", response=SyncReportOut, auth=None) + @csrf_exempt + def heartbeat(request: HttpRequest, server_id: int, payload: AgentHeartbeatIn = Body(...)): + """Update server host metadata (mTLS required at the edge).""" + try: + server = Server.objects.get(id=server_id) + except Server.DoesNotExist: + raise HttpError(404, "Server not found") + updates: dict[str, str] = {} + host = (payload.host or "").strip()[:253] + if host: + try: + hostname_validator(host) + if server.hostname != host: + updates["hostname"] = host + except ValidationError: + pass + ipv4 = _normalize_ip(payload.ipv4, 4) + if ipv4 and server.ipv4 != ipv4: + updates["ipv4"] = ipv4 + ipv6 = _normalize_ip(payload.ipv6, 6) + if ipv6 and server.ipv6 != ipv6: + updates["ipv6"] = ipv6 + if updates: + for field, value in updates.items(): + setattr(server, field, value) + try: + server.save(update_fields=list(updates.keys())) + except IntegrityError: + raise HttpError(409, "Server address already in use") + return SyncReportOut(status="ok") + return router @@ -250,4 +298,17 @@ def _issue_client_cert( return cert_pem, ca_pem, fingerprint, serial +def _normalize_ip(value: Optional[str], version: int) -> Optional[str]: + if not value: + return None + try: + if version == 4: + validate_ipv4_address(value) + else: + validate_ipv6_address(value) + except ValidationError: + return None + return value + + router = build_router() diff --git a/app/keywarden/api/routers/servers.py b/app/keywarden/api/routers/servers.py index 8d8aa25..581f495 100644 --- a/app/keywarden/api/routers/servers.py +++ b/app/keywarden/api/routers/servers.py @@ -2,10 +2,8 @@ from __future__ import annotations from typing import List, Optional -from django.db import IntegrityError from django.http import HttpRequest -from ninja import File, Form, Router, Schema -from ninja.files import UploadedFile +from ninja import Router, Schema from ninja.errors import HttpError from guardian.shortcuts import get_objects_for_user from apps.core.rbac import require_perms @@ -22,18 +20,8 @@ class ServerOut(Schema): initial: str -class ServerCreate(Schema): - display_name: str - hostname: Optional[str] = None - ipv4: Optional[str] = None - ipv6: Optional[str] = None - - class ServerUpdate(Schema): display_name: Optional[str] = None - hostname: Optional[str] = None - ipv4: Optional[str] = None - ipv6: Optional[str] = None def build_router() -> Router: @@ -87,55 +75,21 @@ def build_router() -> Router: "initial": server.initial, } - @router.post("/", response=ServerOut) - def create_server_json(request: HttpRequest, payload: ServerCreate): - """Create a server using JSON payload (admin only).""" - require_perms(request, "servers.add_server") - raise HttpError(403, "Servers are created via agent enrollment tokens.") - - @router.post("/upload", response=ServerOut) - def create_server_multipart( - request: HttpRequest, - display_name: str = Form(...), - hostname: Optional[str] = Form(None), - ipv4: Optional[str] = Form(None), - ipv6: Optional[str] = Form(None), - image: Optional[UploadedFile] = File(None), - ): - """Create a server with optional image upload (admin only).""" - require_perms(request, "servers.add_server") - raise HttpError(403, "Servers are created via agent enrollment tokens.") - @router.patch("/{server_id}", response=ServerOut) def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate): - """Update server fields (admin only).""" + """Update server display name (admin only).""" require_perms(request, "servers.change_server") - if ( - payload.display_name is None - and payload.hostname is None - and payload.ipv4 is None - and payload.ipv6 is None - ): + if payload.display_name is None: raise HttpError(422, {"detail": "No fields provided."}) try: server = Server.objects.get(id=server_id) except Server.DoesNotExist: raise HttpError(404, "Not Found") - if payload.display_name is not None: - display_name = payload.display_name.strip() - if not display_name: - raise HttpError(422, {"display_name": ["Display name cannot be empty."]}) - server.display_name = display_name - if payload.hostname is not None: - server.hostname = (payload.hostname or "").strip() or None - if payload.ipv4 is not None: - server.ipv4 = (payload.ipv4 or "").strip() or None - if payload.ipv6 is not None: - server.ipv6 = (payload.ipv6 or "").strip() or None - try: - server.save() - except IntegrityError: - raise HttpError(422, {"detail": "Unique constraint violated."}) + display_name = payload.display_name.strip() + if not display_name: + raise HttpError(422, {"display_name": ["Display name cannot be empty."]}) + server.display_name = display_name + server.save(update_fields=["display_name"]) return { "id": server.id, "display_name": server.display_name, @@ -146,17 +100,6 @@ def build_router() -> Router: "initial": server.initial, } - @router.delete("/{server_id}", response={204: None}) - def delete_server(request: HttpRequest, server_id: int): - """Delete a server by id (admin only).""" - require_perms(request, "servers.delete_server") - try: - server = Server.objects.get(id=server_id) - except Server.DoesNotExist: - raise HttpError(404, "Not Found") - server.delete() - return 204, None - return router