Agent retries on connection loss and sends connection info (IPv4/IPv6). Uses system CA for mTLS. Removed server endpoints.

This commit is contained in:
2026-01-26 01:13:51 +00:00
parent e7d20360a2
commit 69802f3ece
11 changed files with 278 additions and 80 deletions

View File

@@ -13,3 +13,13 @@ Authentication:
Notes: Notes:
- Base URL for v1 endpoints is `/api/v1`. - Base URL for v1 endpoints is `/api/v1`.
- Admin-only routes return `403 Forbidden` when the token user is not staff/superuser. - Admin-only routes return `403 Forbidden` when the token user is not staff/superuser.
Example: update server display name (admin-only)
PATCH `/api/v1/servers/{server_id}`
```json
{
"display_name": "Keywarden Prod"
}
```

View File

@@ -20,4 +20,6 @@ You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environ
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping. On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate.
See `config.example.json`. See `config.example.json`.

View File

@@ -18,6 +18,7 @@ import (
"keywarden/agent/internal/client" "keywarden/agent/internal/client"
"keywarden/agent/internal/config" "keywarden/agent/internal/config"
"keywarden/agent/internal/host"
"keywarden/agent/internal/logs" "keywarden/agent/internal/logs"
"keywarden/agent/internal/version" "keywarden/agent/internal/version"
) )
@@ -74,13 +75,24 @@ func main() {
} }
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) { func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
if err := reportHost(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("host update deferred; will retry: %v", err)
} else {
log.Printf("host update error: %v", err)
}
}
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
log.Printf("sync accounts error: %v", err) log.Printf("sync accounts error: %v", err)
} }
if err := shipLogs(ctx, apiClient, cfg); err != nil { if err := shipLogs(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("log shipping deferred; will retry: %v", err)
} else {
log.Printf("log shipping error: %v", err) log.Printf("log shipping error: %v", err)
} }
} }
}
func ensureDirs(cfg *config.Config) error { func ensureDirs(cfg *config.Config) error {
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil { if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
@@ -94,7 +106,9 @@ func ensureDirs(cfg *config.Config) error {
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error { func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
send := func(payload []byte) error { send := func(payload []byte) error {
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload) return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
})
} }
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil { if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
return err return err
@@ -128,6 +142,17 @@ func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config)
return nil return nil
} }
// reportHost detects the local hostname and addresses and sends them to the
// server as a heartbeat, retrying retriable failures with a short backoff.
// It returns the last error if every attempt fails.
func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
	detected := host.Detect()
	heartbeat := client.HeartbeatRequest{
		Host: detected.Hostname,
		IPv4: detected.IPv4,
		IPv6: detected.IPv6,
	}
	backoff := []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}
	send := func() error {
		return apiClient.UpdateHost(ctx, cfg.ServerID, heartbeat)
	}
	return retry(ctx, backoff, send)
}
func pickServerURL(flagValue string) string { func pickServerURL(flagValue string) string {
if flagValue != "" { if flagValue != "" {
return flagValue return flagValue
@@ -159,11 +184,14 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
if err != nil { if err != nil {
return err return err
} }
hostname, _ := os.Hostname() info := host.Detect()
hostname := info.Hostname
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{ resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
Token: enrollToken, Token: enrollToken,
CSRPEM: csrPEM, CSRPEM: csrPEM,
Host: hostname, Host: hostname,
IPv4: info.IPv4,
IPv6: info.IPv6,
}) })
if err != nil { if err != nil {
return err return err
@@ -181,6 +209,28 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
return nil return nil
} }
// retry calls fn until it succeeds, fails with a non-retriable error, or the
// backoff schedule is exhausted. fn runs at most len(delays)+1 times, with
// delays[i] slept between attempt i and attempt i+1.
//
// It returns nil on success, the last error from fn when giving up, or
// ctx.Err() if the context is cancelled while waiting between attempts.
func retry(ctx context.Context, delays []time.Duration, fn func() error) error {
	for attempt := 0; ; attempt++ {
		err := fn()
		if err == nil {
			return nil
		}
		if attempt >= len(delays) || !client.IsRetriable(err) {
			return err
		}

		// Use an explicit timer so it can be released as soon as the context
		// is cancelled instead of lingering until the delay elapses.
		timer := time.NewTimer(delays[attempt])
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}
func generateKey(path string) error { func generateKey(path string) error {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
{ {
"server_url": "https://keywarden.dev.ntbx.io/api/v1", "server_url": "https://keywarden.dev.ntbx.io/api/v1",
"server_id": "4", "server_id": "4",
"server_ca_path": "",
"sync_interval_seconds": 30, "sync_interval_seconds": 30,
"log_batch_size": 500, "log_batch_size": 500,
"state_dir": "/var/lib/keywarden-agent", "state_dir": "/var/lib/keywarden-agent",

View File

@@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("load client cert: %w", err) return nil, fmt.Errorf("load client cert: %w", err)
} }
caData, err := os.ReadFile(cfg.CACertPath()) caPool, err := x509.SystemCertPool()
if err != nil { if err != nil || caPool == nil {
return nil, fmt.Errorf("read ca cert: %w", err) caPool = x509.NewCertPool()
}
if cfg.ServerCAPath != "" {
caData, err := os.ReadFile(cfg.ServerCAPath)
if err != nil {
return nil, fmt.Errorf("read server ca cert: %w", err)
} }
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) { if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse ca cert") return nil, errors.New("parse server ca cert")
}
} }
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@@ -63,6 +68,8 @@ type EnrollRequest struct {
Token string `json:"token"` Token string `json:"token"`
CSRPEM string `json:"csr_pem"` CSRPEM string `json:"csr_pem"`
Host string `json:"host"` Host string `json:"host"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
AgentID string `json:"agent_id,omitempty"` AgentID string `json:"agent_id,omitempty"`
} }
@@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 300 { if resp.StatusCode >= 300 {
return fmt.Errorf("log batch failed: status %s", resp.Status) return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
// HeartbeatRequest is the JSON body sent by UpdateHost. Every field is tagged
// omitempty, so empty values are left out of the encoded payload.
type HeartbeatRequest struct {
	Host string `json:"host,omitempty"`
	IPv4 string `json:"ipv4,omitempty"`
	IPv6 string `json:"ipv6,omitempty"`
}
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
body, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("encode host update: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build host update: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send host update: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
} }
return nil return nil
} }

View File

@@ -0,0 +1,36 @@
package client
import (
"context"
"errors"
"net"
)
// HTTPStatusError reports an HTTP response whose status indicated failure.
type HTTPStatusError struct {
	StatusCode int    // numeric status code, e.g. 503
	Status     string // status text as reported by the response, e.g. "503 Service Unavailable"
}

// Error implements the error interface.
func (e *HTTPStatusError) Error() string {
	return "remote status " + e.Status
}

// IsRetriable reports whether err describes a failure that may succeed if the
// same request is attempted again later.
//
// It returns false for nil and for cancelled contexts, true for the HTTP
// statuses 404, 408, 429, 500, 502, 503 and 504, true for deadline expiry,
// and true for any error in the chain that implements net.Error. All other
// HTTP statuses and errors are treated as permanent.
func IsRetriable(err error) bool {
	if err == nil {
		return false
	}
	// Cancellation is a deliberate stop, not a transient fault. It must be
	// checked explicitly: http.Client.Do reports it wrapped in *url.Error,
	// which itself satisfies net.Error and would otherwise match below.
	if errors.Is(err, context.Canceled) {
		return false
	}
	var statusErr *HTTPStatusError
	if errors.As(err, &statusErr) {
		switch statusErr.StatusCode {
		case 404, 408, 429, 500, 502, 503, 504:
			return true
		default:
			return false
		}
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	var netErr net.Error
	return errors.As(err, &netErr)
}

View File

@@ -29,6 +29,7 @@ type AccountPolicy struct {
type Config struct { type Config struct {
ServerURL string `json:"server_url"` ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"` ServerID string `json:"server_id,omitempty"`
ServerCAPath string `json:"server_ca_path,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"` SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"` LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"` StateDir string `json:"state_dir,omitempty"`
@@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if serverURL == "" { if serverURL == "" {
return nil, errors.New("server url required for first boot") return nil, errors.New("server url required for first boot")
} }
cfg := &Config{ServerURL: serverURL} cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
applyDefaults(cfg) applyDefaults(cfg)
if err := validate(cfg, false); err != nil { if err := validate(cfg, false); err != nil {
return nil, err return nil, err
@@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if err := json.Unmarshal(data, cfg); err != nil { if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err) return nil, fmt.Errorf("parse config: %w", err)
} }
if cfg.ServerCAPath == "" {
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
}
applyDefaults(cfg) applyDefaults(cfg)
if err := validate(cfg, false); err != nil { if err := validate(cfg, false); err != nil {
return nil, err return nil, err

View File

@@ -0,0 +1,57 @@
package host
import (
"net"
"os"
)
// Info holds the host identity discovered by Detect.
type Info struct {
	Hostname string // value of os.Hostname; empty if the lookup failed
	IPv4     string // first usable IPv4 address found, or ""
	IPv6     string // first usable IPv6 address found, or ""
}

// Detect returns the hostname together with the first usable IPv4 and IPv6
// addresses found on interfaces that are up and not loopback. Loopback,
// unspecified and link-local addresses are ignored. Lookup failures are not
// reported; the affected fields are simply left empty.
func Detect() Info {
	name, _ := os.Hostname() // best effort: an empty hostname is acceptable
	result := Info{Hostname: name}

	interfaces, err := net.Interfaces()
	if err != nil {
		return result
	}
	for _, nic := range interfaces {
		isUp := nic.Flags&net.FlagUp != 0
		isLoopback := nic.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue
		}
		addresses, err := nic.Addrs()
		if err != nil {
			continue
		}
		for _, address := range addresses {
			candidate := addrIP(address)
			if !reportable(candidate) {
				continue
			}
			switch v4 := candidate.To4(); {
			case v4 != nil:
				if result.IPv4 == "" {
					result.IPv4 = v4.String()
				}
			case candidate.To16() != nil && result.IPv6 == "":
				result.IPv6 = candidate.String()
			}
		}
		// Stop scanning further interfaces once both families are known.
		if result.IPv4 != "" && result.IPv6 != "" {
			break
		}
	}
	return result
}

// addrIP extracts the IP from an interface address, or nil for address types
// that carry none.
func addrIP(address net.Addr) net.IP {
	switch v := address.(type) {
	case *net.IPNet:
		return v.IP
	case *net.IPAddr:
		return v.IP
	}
	return nil
}

// reportable reports whether ip is set and is neither loopback, unspecified,
// nor link-local.
func reportable(ip net.IP) bool {
	return ip != nil &&
		!ip.IsLoopback() &&
		!ip.IsUnspecified() &&
		!ip.IsLinkLocalUnicast() &&
		!ip.IsLinkLocalMulticast()
}

Binary file not shown.

View File

@@ -6,6 +6,7 @@ from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv4_address, validate_ipv6_address
from django.db import IntegrityError, models, transaction from django.db import IntegrityError, models, transaction
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import timezone from django.utils import timezone
@@ -44,6 +45,8 @@ class AgentEnrollIn(Schema):
token: str token: str
csr_pem: str csr_pem: str
host: Optional[str] = None host: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
class AgentEnrollOut(Schema): class AgentEnrollOut(Schema):
@@ -73,6 +76,12 @@ class LogIngestOut(Schema):
accepted: int accepted: int
class AgentHeartbeatIn(Schema):
    # Request body accepted by the agent heartbeat endpoint. Every field is
    # optional; the handler only acts on values that are present and valid.
    host: Optional[str] = None  # hostname reported by the agent
    ipv4: Optional[str] = None  # IPv4 address reported by the agent
    ipv6: Optional[str] = None  # IPv6 address reported by the agent
def build_router() -> Router: def build_router() -> Router:
router = Router() router = Router()
@@ -99,11 +108,18 @@ def build_router() -> Router:
hostname = host hostname = host
except ValidationError: except ValidationError:
hostname = None hostname = None
ipv4 = _normalize_ip(payload.ipv4, 4)
ipv6 = _normalize_ip(payload.ipv6, 6)
csr = _load_csr((payload.csr_pem or "").strip()) csr = _load_csr((payload.csr_pem or "").strip())
try: try:
with transaction.atomic(): with transaction.atomic():
server = Server.objects.create(display_name=display_name, hostname=hostname) server = Server.objects.create(
display_name=display_name,
hostname=hostname,
ipv4=ipv4,
ipv6=ipv6,
)
token.mark_used(server) token.mark_used(server)
token.save(update_fields=["used_at", "server"]) token.save(update_fields=["used_at", "server"])
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id) cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
@@ -189,6 +205,38 @@ def build_router() -> Router:
# TODO: enqueue to Valkey and persist to SQLite slices. # TODO: enqueue to Valkey and persist to SQLite slices.
return LogIngestOut(status="accepted", accepted=len(payload)) return LogIngestOut(status="accepted", accepted=len(payload))
@router.post("/servers/{server_id}/heartbeat", response=SyncReportOut, auth=None)
@csrf_exempt
def heartbeat(request: HttpRequest, server_id: int, payload: AgentHeartbeatIn = Body(...)):
    """Update server host metadata (mTLS required at the edge)."""
    # NOTE(review): auth=None means nothing in this handler ties the caller to
    # server_id; confirm the edge proxy enforces that binding.
    server = Server.objects.filter(id=server_id).first()
    if server is None:
        raise HttpError(404, "Server not found")

    # Collect only the fields whose validated value differs from what is stored.
    changed: dict[str, str] = {}

    reported_host = (payload.host or "").strip()[:253]
    if reported_host:
        try:
            hostname_validator(reported_host)
        except ValidationError:
            # An invalid hostname is ignored rather than rejected.
            pass
        else:
            if server.hostname != reported_host:
                changed["hostname"] = reported_host

    for field, version, raw in (("ipv4", 4, payload.ipv4), ("ipv6", 6, payload.ipv6)):
        address = _normalize_ip(raw, version)
        if address and getattr(server, field) != address:
            changed[field] = address

    if not changed:
        return SyncReportOut(status="ok")

    for field, value in changed.items():
        setattr(server, field, value)
    try:
        server.save(update_fields=list(changed))
    except IntegrityError:
        raise HttpError(409, "Server address already in use")
    return SyncReportOut(status="ok")
return router return router
@@ -250,4 +298,17 @@ def _issue_client_cert(
return cert_pem, ca_pem, fingerprint, serial return cert_pem, ca_pem, fingerprint, serial
def _normalize_ip(value: Optional[str], version: int) -> Optional[str]:
if not value:
return None
try:
if version == 4:
validate_ipv4_address(value)
else:
validate_ipv6_address(value)
except ValidationError:
return None
return value
router = build_router() router = build_router()

View File

@@ -2,10 +2,8 @@ from __future__ import annotations
from typing import List, Optional from typing import List, Optional
from django.db import IntegrityError
from django.http import HttpRequest from django.http import HttpRequest
from ninja import File, Form, Router, Schema from ninja import Router, Schema
from ninja.files import UploadedFile
from ninja.errors import HttpError from ninja.errors import HttpError
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from apps.core.rbac import require_perms from apps.core.rbac import require_perms
@@ -22,18 +20,8 @@ class ServerOut(Schema):
initial: str initial: str
class ServerCreate(Schema):
display_name: str
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
class ServerUpdate(Schema): class ServerUpdate(Schema):
display_name: Optional[str] = None display_name: Optional[str] = None
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
def build_router() -> Router: def build_router() -> Router:
@@ -87,55 +75,21 @@ def build_router() -> Router:
"initial": server.initial, "initial": server.initial,
} }
@router.post("/", response=ServerOut)
def create_server_json(request: HttpRequest, payload: ServerCreate):
"""Create a server using JSON payload (admin only)."""
require_perms(request, "servers.add_server")
raise HttpError(403, "Servers are created via agent enrollment tokens.")
@router.post("/upload", response=ServerOut)
def create_server_multipart(
request: HttpRequest,
display_name: str = Form(...),
hostname: Optional[str] = Form(None),
ipv4: Optional[str] = Form(None),
ipv6: Optional[str] = Form(None),
image: Optional[UploadedFile] = File(None),
):
"""Create a server with optional image upload (admin only)."""
require_perms(request, "servers.add_server")
raise HttpError(403, "Servers are created via agent enrollment tokens.")
@router.patch("/{server_id}", response=ServerOut) @router.patch("/{server_id}", response=ServerOut)
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate): def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
"""Update server fields (admin only).""" """Update server display name (admin only)."""
require_perms(request, "servers.change_server") require_perms(request, "servers.change_server")
if ( if payload.display_name is None:
payload.display_name is None
and payload.hostname is None
and payload.ipv4 is None
and payload.ipv6 is None
):
raise HttpError(422, {"detail": "No fields provided."}) raise HttpError(422, {"detail": "No fields provided."})
try: try:
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if payload.display_name is not None:
display_name = payload.display_name.strip() display_name = payload.display_name.strip()
if not display_name: if not display_name:
raise HttpError(422, {"display_name": ["Display name cannot be empty."]}) raise HttpError(422, {"display_name": ["Display name cannot be empty."]})
server.display_name = display_name server.display_name = display_name
if payload.hostname is not None: server.save(update_fields=["display_name"])
server.hostname = (payload.hostname or "").strip() or None
if payload.ipv4 is not None:
server.ipv4 = (payload.ipv4 or "").strip() or None
if payload.ipv6 is not None:
server.ipv6 = (payload.ipv6 or "").strip() or None
try:
server.save()
except IntegrityError:
raise HttpError(422, {"detail": "Unique constraint violated."})
return { return {
"id": server.id, "id": server.id,
"display_name": server.display_name, "display_name": server.display_name,
@@ -146,17 +100,6 @@ def build_router() -> Router:
"initial": server.initial, "initial": server.initial,
} }
@router.delete("/{server_id}", response={204: None})
def delete_server(request: HttpRequest, server_id: int):
"""Delete a server by id (admin only)."""
require_perms(request, "servers.delete_server")
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Not Found")
server.delete()
return 204, None
return router return router