Agent retries on connection loss, sends connection info (v4 v6) Uses system CA for mTLS. Removed server endpoints.

This commit is contained in:
2026-01-26 01:13:51 +00:00
parent e7d20360a2
commit 69802f3ece
11 changed files with 278 additions and 80 deletions

View File

@@ -13,3 +13,13 @@ Authentication:
Notes:
- Base URL for v1 endpoints is `/api/v1`.
- Admin-only routes return `403 Forbidden` when the token user is not staff/superuser.
Example: update server display name (admin-only)
PATCH `/api/v1/servers/{server_id}`
```json
{
"display_name": "Keywarden Prod"
}
```

View File

@@ -20,4 +20,6 @@ You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environ
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate.
See `config.example.json`.

View File

@@ -18,6 +18,7 @@ import (
"keywarden/agent/internal/client"
"keywarden/agent/internal/config"
"keywarden/agent/internal/host"
"keywarden/agent/internal/logs"
"keywarden/agent/internal/version"
)
@@ -74,12 +75,23 @@ func main() {
}
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
if err := reportHost(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("host update deferred; will retry: %v", err)
} else {
log.Printf("host update error: %v", err)
}
}
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
log.Printf("sync accounts error: %v", err)
}
if err := shipLogs(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("log shipping deferred; will retry: %v", err)
} else {
log.Printf("log shipping error: %v", err)
}
}
}
func ensureDirs(cfg *config.Config) error {
@@ -94,7 +106,9 @@ func ensureDirs(cfg *config.Config) error {
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
send := func(payload []byte) error {
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
})
}
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
return err
@@ -128,6 +142,17 @@ func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config)
return nil
}
func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
info := host.Detect()
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
return apiClient.UpdateHost(ctx, cfg.ServerID, client.HeartbeatRequest{
Host: info.Hostname,
IPv4: info.IPv4,
IPv6: info.IPv6,
})
})
}
func pickServerURL(flagValue string) string {
if flagValue != "" {
return flagValue
@@ -159,11 +184,14 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
if err != nil {
return err
}
hostname, _ := os.Hostname()
info := host.Detect()
hostname := info.Hostname
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
Token: enrollToken,
CSRPEM: csrPEM,
Host: hostname,
IPv4: info.IPv4,
IPv6: info.IPv6,
})
if err != nil {
return err
@@ -181,6 +209,28 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
return nil
}
func retry(ctx context.Context, delays []time.Duration, fn func() error) error {
var lastErr error
for attempt := 0; attempt <= len(delays); attempt++ {
if attempt > 0 {
if !client.IsRetriable(lastErr) {
return lastErr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delays[attempt-1]):
}
}
if err := fn(); err != nil {
lastErr = err
continue
}
return nil
}
return lastErr
}
func generateKey(path string) error {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {

View File

@@ -1,6 +1,7 @@
{
"server_url": "https://keywarden.dev.ntbx.io/api/v1",
"server_id": "4",
"server_ca_path": "",
"sync_interval_seconds": 30,
"log_batch_size": 500,
"state_dir": "/var/lib/keywarden-agent",

View File

@@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("load client cert: %w", err)
}
caData, err := os.ReadFile(cfg.CACertPath())
if err != nil {
return nil, fmt.Errorf("read ca cert: %w", err)
caPool, err := x509.SystemCertPool()
if err != nil || caPool == nil {
caPool = x509.NewCertPool()
}
if cfg.ServerCAPath != "" {
caData, err := os.ReadFile(cfg.ServerCAPath)
if err != nil {
return nil, fmt.Errorf("read server ca cert: %w", err)
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse ca cert")
return nil, errors.New("parse server ca cert")
}
}
tlsConfig := &tls.Config{
@@ -63,6 +68,8 @@ type EnrollRequest struct {
Token string `json:"token"`
CSRPEM string `json:"csr_pem"`
Host string `json:"host"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
AgentID string `json:"agent_id,omitempty"`
}
@@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return fmt.Errorf("log batch failed: status %s", resp.Status)
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
type HeartbeatRequest struct {
Host string `json:"host,omitempty"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
}
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
body, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("encode host update: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build host update: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send host update: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}

View File

@@ -0,0 +1,36 @@
package client
import (
"context"
"errors"
"net"
)
type HTTPStatusError struct {
StatusCode int
Status string
}
func (e *HTTPStatusError) Error() string {
return "remote status " + e.Status
}
func IsRetriable(err error) bool {
if err == nil {
return false
}
var statusErr *HTTPStatusError
if errors.As(err, &statusErr) {
switch statusErr.StatusCode {
case 404, 408, 429, 500, 502, 503, 504:
return true
default:
return false
}
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr)
}

View File

@@ -29,6 +29,7 @@ type AccountPolicy struct {
type Config struct {
ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"`
ServerCAPath string `json:"server_ca_path,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"`
@@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if serverURL == "" {
return nil, errors.New("server url required for first boot")
}
cfg := &Config{ServerURL: serverURL}
cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
@@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.ServerCAPath == "" {
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err

View File

@@ -0,0 +1,57 @@
package host
import (
"net"
"os"
)
type Info struct {
Hostname string
IPv4 string
IPv6 string
}
func Detect() Info {
hostname, _ := os.Hostname()
info := Info{Hostname: hostname}
ifaces, err := net.Interfaces()
if err != nil {
return info
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
if ip4 := ip.To4(); ip4 != nil {
if info.IPv4 == "" {
info.IPv4 = ip4.String()
}
continue
}
if ip.To16() != nil && info.IPv6 == "" {
info.IPv6 = ip.String()
}
}
if info.IPv4 != "" && info.IPv6 != "" {
break
}
}
return info
}

Binary file not shown.

View File

@@ -6,6 +6,7 @@ from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv4_address, validate_ipv6_address
from django.db import IntegrityError, models, transaction
from django.http import HttpRequest
from django.utils import timezone
@@ -44,6 +45,8 @@ class AgentEnrollIn(Schema):
token: str
csr_pem: str
host: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
class AgentEnrollOut(Schema):
@@ -73,6 +76,12 @@ class LogIngestOut(Schema):
accepted: int
class AgentHeartbeatIn(Schema):
host: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
def build_router() -> Router:
router = Router()
@@ -99,11 +108,18 @@ def build_router() -> Router:
hostname = host
except ValidationError:
hostname = None
ipv4 = _normalize_ip(payload.ipv4, 4)
ipv6 = _normalize_ip(payload.ipv6, 6)
csr = _load_csr((payload.csr_pem or "").strip())
try:
with transaction.atomic():
server = Server.objects.create(display_name=display_name, hostname=hostname)
server = Server.objects.create(
display_name=display_name,
hostname=hostname,
ipv4=ipv4,
ipv6=ipv6,
)
token.mark_used(server)
token.save(update_fields=["used_at", "server"])
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
@@ -189,6 +205,38 @@ def build_router() -> Router:
# TODO: enqueue to Valkey and persist to SQLite slices.
return LogIngestOut(status="accepted", accepted=len(payload))
@router.post("/servers/{server_id}/heartbeat", response=SyncReportOut, auth=None)
@csrf_exempt
def heartbeat(request: HttpRequest, server_id: int, payload: AgentHeartbeatIn = Body(...)):
"""Update server host metadata (mTLS required at the edge)."""
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
updates: dict[str, str] = {}
host = (payload.host or "").strip()[:253]
if host:
try:
hostname_validator(host)
if server.hostname != host:
updates["hostname"] = host
except ValidationError:
pass
ipv4 = _normalize_ip(payload.ipv4, 4)
if ipv4 and server.ipv4 != ipv4:
updates["ipv4"] = ipv4
ipv6 = _normalize_ip(payload.ipv6, 6)
if ipv6 and server.ipv6 != ipv6:
updates["ipv6"] = ipv6
if updates:
for field, value in updates.items():
setattr(server, field, value)
try:
server.save(update_fields=list(updates.keys()))
except IntegrityError:
raise HttpError(409, "Server address already in use")
return SyncReportOut(status="ok")
return router
@@ -250,4 +298,17 @@ def _issue_client_cert(
return cert_pem, ca_pem, fingerprint, serial
def _normalize_ip(value: Optional[str], version: int) -> Optional[str]:
if not value:
return None
try:
if version == 4:
validate_ipv4_address(value)
else:
validate_ipv6_address(value)
except ValidationError:
return None
return value
router = build_router()

View File

@@ -2,10 +2,8 @@ from __future__ import annotations
from typing import List, Optional
from django.db import IntegrityError
from django.http import HttpRequest
from ninja import File, Form, Router, Schema
from ninja.files import UploadedFile
from ninja import Router, Schema
from ninja.errors import HttpError
from guardian.shortcuts import get_objects_for_user
from apps.core.rbac import require_perms
@@ -22,18 +20,8 @@ class ServerOut(Schema):
initial: str
class ServerCreate(Schema):
display_name: str
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
class ServerUpdate(Schema):
display_name: Optional[str] = None
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
def build_router() -> Router:
@@ -87,55 +75,21 @@ def build_router() -> Router:
"initial": server.initial,
}
@router.post("/", response=ServerOut)
def create_server_json(request: HttpRequest, payload: ServerCreate):
"""Create a server using JSON payload (admin only)."""
require_perms(request, "servers.add_server")
raise HttpError(403, "Servers are created via agent enrollment tokens.")
@router.post("/upload", response=ServerOut)
def create_server_multipart(
request: HttpRequest,
display_name: str = Form(...),
hostname: Optional[str] = Form(None),
ipv4: Optional[str] = Form(None),
ipv6: Optional[str] = Form(None),
image: Optional[UploadedFile] = File(None),
):
"""Create a server with optional image upload (admin only)."""
require_perms(request, "servers.add_server")
raise HttpError(403, "Servers are created via agent enrollment tokens.")
@router.patch("/{server_id}", response=ServerOut)
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
"""Update server fields (admin only)."""
"""Update server display name (admin only)."""
require_perms(request, "servers.change_server")
if (
payload.display_name is None
and payload.hostname is None
and payload.ipv4 is None
and payload.ipv6 is None
):
if payload.display_name is None:
raise HttpError(422, {"detail": "No fields provided."})
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Not Found")
if payload.display_name is not None:
display_name = payload.display_name.strip()
if not display_name:
raise HttpError(422, {"display_name": ["Display name cannot be empty."]})
server.display_name = display_name
if payload.hostname is not None:
server.hostname = (payload.hostname or "").strip() or None
if payload.ipv4 is not None:
server.ipv4 = (payload.ipv4 or "").strip() or None
if payload.ipv6 is not None:
server.ipv6 = (payload.ipv6 or "").strip() or None
try:
server.save()
except IntegrityError:
raise HttpError(422, {"detail": "Unique constraint violated."})
server.save(update_fields=["display_name"])
return {
"id": server.id,
"display_name": server.display_name,
@@ -146,17 +100,6 @@ def build_router() -> Router:
"initial": server.initial,
}
@router.delete("/{server_id}", response={204: None})
def delete_server(request: HttpRequest, server_id: int):
"""Delete a server by id (admin only)."""
require_perms(request, "servers.delete_server")
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Not Found")
server.delete()
return 204, None
return router