Agent retries on connection loss, sends connection info (v4 v6) Uses system CA for mTLS. Removed server endpoints.
This commit is contained in:
10
API_DOCS.md
10
API_DOCS.md
@@ -13,3 +13,13 @@ Authentication:
|
|||||||
Notes:
|
Notes:
|
||||||
- Base URL for v1 endpoints is `/api/v1`.
|
- Base URL for v1 endpoints is `/api/v1`.
|
||||||
- Admin-only routes return `403 Forbidden` when the token user is not staff/superuser.
|
- Admin-only routes return `403 Forbidden` when the token user is not staff/superuser.
|
||||||
|
|
||||||
|
Example: update server display name (admin-only)
|
||||||
|
|
||||||
|
PATCH `/api/v1/servers/{server_id}`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"display_name": "Keywarden Prod"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|||||||
@@ -20,4 +20,6 @@ You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environ
|
|||||||
|
|
||||||
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
|
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
|
||||||
|
|
||||||
|
If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate.
|
||||||
|
|
||||||
See `config.example.json`.
|
See `config.example.json`.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"keywarden/agent/internal/client"
|
"keywarden/agent/internal/client"
|
||||||
"keywarden/agent/internal/config"
|
"keywarden/agent/internal/config"
|
||||||
|
"keywarden/agent/internal/host"
|
||||||
"keywarden/agent/internal/logs"
|
"keywarden/agent/internal/logs"
|
||||||
"keywarden/agent/internal/version"
|
"keywarden/agent/internal/version"
|
||||||
)
|
)
|
||||||
@@ -74,11 +75,22 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
|
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
|
||||||
|
if err := reportHost(ctx, apiClient, cfg); err != nil {
|
||||||
|
if client.IsRetriable(err) {
|
||||||
|
log.Printf("host update deferred; will retry: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("host update error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
||||||
log.Printf("sync accounts error: %v", err)
|
log.Printf("sync accounts error: %v", err)
|
||||||
}
|
}
|
||||||
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
||||||
log.Printf("log shipping error: %v", err)
|
if client.IsRetriable(err) {
|
||||||
|
log.Printf("log shipping deferred; will retry: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("log shipping error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +106,9 @@ func ensureDirs(cfg *config.Config) error {
|
|||||||
|
|
||||||
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||||
send := func(payload []byte) error {
|
send := func(payload []byte) error {
|
||||||
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
|
||||||
|
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
|
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -128,6 +142,17 @@ func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||||
|
info := host.Detect()
|
||||||
|
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
|
||||||
|
return apiClient.UpdateHost(ctx, cfg.ServerID, client.HeartbeatRequest{
|
||||||
|
Host: info.Hostname,
|
||||||
|
IPv4: info.IPv4,
|
||||||
|
IPv6: info.IPv6,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func pickServerURL(flagValue string) string {
|
func pickServerURL(flagValue string) string {
|
||||||
if flagValue != "" {
|
if flagValue != "" {
|
||||||
return flagValue
|
return flagValue
|
||||||
@@ -159,11 +184,14 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hostname, _ := os.Hostname()
|
info := host.Detect()
|
||||||
|
hostname := info.Hostname
|
||||||
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
|
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
|
||||||
Token: enrollToken,
|
Token: enrollToken,
|
||||||
CSRPEM: csrPEM,
|
CSRPEM: csrPEM,
|
||||||
Host: hostname,
|
Host: hostname,
|
||||||
|
IPv4: info.IPv4,
|
||||||
|
IPv6: info.IPv6,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -181,6 +209,28 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func retry(ctx context.Context, delays []time.Duration, fn func() error) error {
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt <= len(delays); attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
if !client.IsRetriable(lastErr) {
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(delays[attempt-1]):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := fn(); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
func generateKey(path string) error {
|
func generateKey(path string) error {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"server_url": "https://keywarden.dev.ntbx.io/api/v1",
|
"server_url": "https://keywarden.dev.ntbx.io/api/v1",
|
||||||
"server_id": "4",
|
"server_id": "4",
|
||||||
|
"server_ca_path": "",
|
||||||
"sync_interval_seconds": 30,
|
"sync_interval_seconds": 30,
|
||||||
"log_batch_size": 500,
|
"log_batch_size": 500,
|
||||||
"state_dir": "/var/lib/keywarden-agent",
|
"state_dir": "/var/lib/keywarden-agent",
|
||||||
@@ -11,4 +12,4 @@
|
|||||||
"create_home": true,
|
"create_home": true,
|
||||||
"lock_on_revoke": true
|
"lock_on_revoke": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("load client cert: %w", err)
|
return nil, fmt.Errorf("load client cert: %w", err)
|
||||||
}
|
}
|
||||||
caData, err := os.ReadFile(cfg.CACertPath())
|
caPool, err := x509.SystemCertPool()
|
||||||
if err != nil {
|
if err != nil || caPool == nil {
|
||||||
return nil, fmt.Errorf("read ca cert: %w", err)
|
caPool = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
caPool := x509.NewCertPool()
|
if cfg.ServerCAPath != "" {
|
||||||
if !caPool.AppendCertsFromPEM(caData) {
|
caData, err := os.ReadFile(cfg.ServerCAPath)
|
||||||
return nil, errors.New("parse ca cert")
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read server ca cert: %w", err)
|
||||||
|
}
|
||||||
|
if !caPool.AppendCertsFromPEM(caData) {
|
||||||
|
return nil, errors.New("parse server ca cert")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
@@ -63,14 +68,16 @@ type EnrollRequest struct {
|
|||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
CSRPEM string `json:"csr_pem"`
|
CSRPEM string `json:"csr_pem"`
|
||||||
Host string `json:"host"`
|
Host string `json:"host"`
|
||||||
|
IPv4 string `json:"ipv4,omitempty"`
|
||||||
|
IPv6 string `json:"ipv6,omitempty"`
|
||||||
AgentID string `json:"agent_id,omitempty"`
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EnrollResponse struct {
|
type EnrollResponse struct {
|
||||||
ServerID string `json:"server_id"`
|
ServerID string `json:"server_id"`
|
||||||
ClientCert string `json:"client_cert_pem"`
|
ClientCert string `json:"client_cert_pem"`
|
||||||
CACert string `json:"ca_cert_pem"`
|
CACert string `json:"ca_cert_pem"`
|
||||||
SyncProfile string `json:"sync_profile,omitempty"`
|
SyncProfile string `json:"sync_profile,omitempty"`
|
||||||
DisplayName string `json:"display_name,omitempty"`
|
DisplayName string `json:"display_name,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode >= 300 {
|
if resp.StatusCode >= 300 {
|
||||||
return fmt.Errorf("log batch failed: status %s", resp.Status)
|
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeartbeatRequest struct {
|
||||||
|
Host string `json:"host,omitempty"`
|
||||||
|
IPv4 string `json:"ipv4,omitempty"`
|
||||||
|
IPv6 string `json:"ipv6,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encode host update: %w", err)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build host update: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("send host update: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
36
agent/internal/client/errors.go
Normal file
36
agent/internal/client/errors.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HTTPStatusError struct {
|
||||||
|
StatusCode int
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *HTTPStatusError) Error() string {
|
||||||
|
return "remote status " + e.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsRetriable(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var statusErr *HTTPStatusError
|
||||||
|
if errors.As(err, &statusErr) {
|
||||||
|
switch statusErr.StatusCode {
|
||||||
|
case 404, 408, 429, 500, 502, 503, 504:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
return errors.As(err, &netErr)
|
||||||
|
}
|
||||||
@@ -29,6 +29,7 @@ type AccountPolicy struct {
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
ServerURL string `json:"server_url"`
|
ServerURL string `json:"server_url"`
|
||||||
ServerID string `json:"server_id,omitempty"`
|
ServerID string `json:"server_id,omitempty"`
|
||||||
|
ServerCAPath string `json:"server_ca_path,omitempty"`
|
||||||
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
|
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
|
||||||
LogBatchSize int `json:"log_batch_size,omitempty"`
|
LogBatchSize int `json:"log_batch_size,omitempty"`
|
||||||
StateDir string `json:"state_dir,omitempty"`
|
StateDir string `json:"state_dir,omitempty"`
|
||||||
@@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
|
|||||||
if serverURL == "" {
|
if serverURL == "" {
|
||||||
return nil, errors.New("server url required for first boot")
|
return nil, errors.New("server url required for first boot")
|
||||||
}
|
}
|
||||||
cfg := &Config{ServerURL: serverURL}
|
cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
|
||||||
applyDefaults(cfg)
|
applyDefaults(cfg)
|
||||||
if err := validate(cfg, false); err != nil {
|
if err := validate(cfg, false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
|
|||||||
if err := json.Unmarshal(data, cfg); err != nil {
|
if err := json.Unmarshal(data, cfg); err != nil {
|
||||||
return nil, fmt.Errorf("parse config: %w", err)
|
return nil, fmt.Errorf("parse config: %w", err)
|
||||||
}
|
}
|
||||||
|
if cfg.ServerCAPath == "" {
|
||||||
|
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
|
||||||
|
}
|
||||||
applyDefaults(cfg)
|
applyDefaults(cfg)
|
||||||
if err := validate(cfg, false); err != nil {
|
if err := validate(cfg, false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
57
agent/internal/host/host.go
Normal file
57
agent/internal/host/host.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package host
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Info struct {
|
||||||
|
Hostname string
|
||||||
|
IPv4 string
|
||||||
|
IPv6 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func Detect() Info {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
info := Info{Hostname: hostname}
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
if info.IPv4 == "" {
|
||||||
|
info.IPv4 = ip4.String()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip.To16() != nil && info.IPv6 == "" {
|
||||||
|
info.IPv6 = ip.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.IPv4 != "" && info.IPv6 != "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -6,6 +6,7 @@ from cryptography.hazmat.primitives import hashes, serialization
|
|||||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.core.validators import validate_ipv4_address, validate_ipv6_address
|
||||||
from django.db import IntegrityError, models, transaction
|
from django.db import IntegrityError, models, transaction
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
@@ -44,6 +45,8 @@ class AgentEnrollIn(Schema):
|
|||||||
token: str
|
token: str
|
||||||
csr_pem: str
|
csr_pem: str
|
||||||
host: Optional[str] = None
|
host: Optional[str] = None
|
||||||
|
ipv4: Optional[str] = None
|
||||||
|
ipv6: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentEnrollOut(Schema):
|
class AgentEnrollOut(Schema):
|
||||||
@@ -73,6 +76,12 @@ class LogIngestOut(Schema):
|
|||||||
accepted: int
|
accepted: int
|
||||||
|
|
||||||
|
|
||||||
|
class AgentHeartbeatIn(Schema):
|
||||||
|
host: Optional[str] = None
|
||||||
|
ipv4: Optional[str] = None
|
||||||
|
ipv6: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def build_router() -> Router:
|
def build_router() -> Router:
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
@@ -99,11 +108,18 @@ def build_router() -> Router:
|
|||||||
hostname = host
|
hostname = host
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
hostname = None
|
hostname = None
|
||||||
|
ipv4 = _normalize_ip(payload.ipv4, 4)
|
||||||
|
ipv6 = _normalize_ip(payload.ipv6, 6)
|
||||||
|
|
||||||
csr = _load_csr((payload.csr_pem or "").strip())
|
csr = _load_csr((payload.csr_pem or "").strip())
|
||||||
try:
|
try:
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
server = Server.objects.create(display_name=display_name, hostname=hostname)
|
server = Server.objects.create(
|
||||||
|
display_name=display_name,
|
||||||
|
hostname=hostname,
|
||||||
|
ipv4=ipv4,
|
||||||
|
ipv6=ipv6,
|
||||||
|
)
|
||||||
token.mark_used(server)
|
token.mark_used(server)
|
||||||
token.save(update_fields=["used_at", "server"])
|
token.save(update_fields=["used_at", "server"])
|
||||||
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
|
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
|
||||||
@@ -189,6 +205,38 @@ def build_router() -> Router:
|
|||||||
# TODO: enqueue to Valkey and persist to SQLite slices.
|
# TODO: enqueue to Valkey and persist to SQLite slices.
|
||||||
return LogIngestOut(status="accepted", accepted=len(payload))
|
return LogIngestOut(status="accepted", accepted=len(payload))
|
||||||
|
|
||||||
|
@router.post("/servers/{server_id}/heartbeat", response=SyncReportOut, auth=None)
|
||||||
|
@csrf_exempt
|
||||||
|
def heartbeat(request: HttpRequest, server_id: int, payload: AgentHeartbeatIn = Body(...)):
|
||||||
|
"""Update server host metadata (mTLS required at the edge)."""
|
||||||
|
try:
|
||||||
|
server = Server.objects.get(id=server_id)
|
||||||
|
except Server.DoesNotExist:
|
||||||
|
raise HttpError(404, "Server not found")
|
||||||
|
updates: dict[str, str] = {}
|
||||||
|
host = (payload.host or "").strip()[:253]
|
||||||
|
if host:
|
||||||
|
try:
|
||||||
|
hostname_validator(host)
|
||||||
|
if server.hostname != host:
|
||||||
|
updates["hostname"] = host
|
||||||
|
except ValidationError:
|
||||||
|
pass
|
||||||
|
ipv4 = _normalize_ip(payload.ipv4, 4)
|
||||||
|
if ipv4 and server.ipv4 != ipv4:
|
||||||
|
updates["ipv4"] = ipv4
|
||||||
|
ipv6 = _normalize_ip(payload.ipv6, 6)
|
||||||
|
if ipv6 and server.ipv6 != ipv6:
|
||||||
|
updates["ipv6"] = ipv6
|
||||||
|
if updates:
|
||||||
|
for field, value in updates.items():
|
||||||
|
setattr(server, field, value)
|
||||||
|
try:
|
||||||
|
server.save(update_fields=list(updates.keys()))
|
||||||
|
except IntegrityError:
|
||||||
|
raise HttpError(409, "Server address already in use")
|
||||||
|
return SyncReportOut(status="ok")
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|
||||||
@@ -250,4 +298,17 @@ def _issue_client_cert(
|
|||||||
return cert_pem, ca_pem, fingerprint, serial
|
return cert_pem, ca_pem, fingerprint, serial
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_ip(value: Optional[str], version: int) -> Optional[str]:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
if version == 4:
|
||||||
|
validate_ipv4_address(value)
|
||||||
|
else:
|
||||||
|
validate_ipv6_address(value)
|
||||||
|
except ValidationError:
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
router = build_router()
|
router = build_router()
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from django.db import IntegrityError
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from ninja import File, Form, Router, Schema
|
from ninja import Router, Schema
|
||||||
from ninja.files import UploadedFile
|
|
||||||
from ninja.errors import HttpError
|
from ninja.errors import HttpError
|
||||||
from guardian.shortcuts import get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
from apps.core.rbac import require_perms
|
from apps.core.rbac import require_perms
|
||||||
@@ -22,18 +20,8 @@ class ServerOut(Schema):
|
|||||||
initial: str
|
initial: str
|
||||||
|
|
||||||
|
|
||||||
class ServerCreate(Schema):
|
|
||||||
display_name: str
|
|
||||||
hostname: Optional[str] = None
|
|
||||||
ipv4: Optional[str] = None
|
|
||||||
ipv6: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ServerUpdate(Schema):
|
class ServerUpdate(Schema):
|
||||||
display_name: Optional[str] = None
|
display_name: Optional[str] = None
|
||||||
hostname: Optional[str] = None
|
|
||||||
ipv4: Optional[str] = None
|
|
||||||
ipv6: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def build_router() -> Router:
|
def build_router() -> Router:
|
||||||
@@ -87,55 +75,21 @@ def build_router() -> Router:
|
|||||||
"initial": server.initial,
|
"initial": server.initial,
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/", response=ServerOut)
|
|
||||||
def create_server_json(request: HttpRequest, payload: ServerCreate):
|
|
||||||
"""Create a server using JSON payload (admin only)."""
|
|
||||||
require_perms(request, "servers.add_server")
|
|
||||||
raise HttpError(403, "Servers are created via agent enrollment tokens.")
|
|
||||||
|
|
||||||
@router.post("/upload", response=ServerOut)
|
|
||||||
def create_server_multipart(
|
|
||||||
request: HttpRequest,
|
|
||||||
display_name: str = Form(...),
|
|
||||||
hostname: Optional[str] = Form(None),
|
|
||||||
ipv4: Optional[str] = Form(None),
|
|
||||||
ipv6: Optional[str] = Form(None),
|
|
||||||
image: Optional[UploadedFile] = File(None),
|
|
||||||
):
|
|
||||||
"""Create a server with optional image upload (admin only)."""
|
|
||||||
require_perms(request, "servers.add_server")
|
|
||||||
raise HttpError(403, "Servers are created via agent enrollment tokens.")
|
|
||||||
|
|
||||||
@router.patch("/{server_id}", response=ServerOut)
|
@router.patch("/{server_id}", response=ServerOut)
|
||||||
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
|
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
|
||||||
"""Update server fields (admin only)."""
|
"""Update server display name (admin only)."""
|
||||||
require_perms(request, "servers.change_server")
|
require_perms(request, "servers.change_server")
|
||||||
if (
|
if payload.display_name is None:
|
||||||
payload.display_name is None
|
|
||||||
and payload.hostname is None
|
|
||||||
and payload.ipv4 is None
|
|
||||||
and payload.ipv6 is None
|
|
||||||
):
|
|
||||||
raise HttpError(422, {"detail": "No fields provided."})
|
raise HttpError(422, {"detail": "No fields provided."})
|
||||||
try:
|
try:
|
||||||
server = Server.objects.get(id=server_id)
|
server = Server.objects.get(id=server_id)
|
||||||
except Server.DoesNotExist:
|
except Server.DoesNotExist:
|
||||||
raise HttpError(404, "Not Found")
|
raise HttpError(404, "Not Found")
|
||||||
if payload.display_name is not None:
|
display_name = payload.display_name.strip()
|
||||||
display_name = payload.display_name.strip()
|
if not display_name:
|
||||||
if not display_name:
|
raise HttpError(422, {"display_name": ["Display name cannot be empty."]})
|
||||||
raise HttpError(422, {"display_name": ["Display name cannot be empty."]})
|
server.display_name = display_name
|
||||||
server.display_name = display_name
|
server.save(update_fields=["display_name"])
|
||||||
if payload.hostname is not None:
|
|
||||||
server.hostname = (payload.hostname or "").strip() or None
|
|
||||||
if payload.ipv4 is not None:
|
|
||||||
server.ipv4 = (payload.ipv4 or "").strip() or None
|
|
||||||
if payload.ipv6 is not None:
|
|
||||||
server.ipv6 = (payload.ipv6 or "").strip() or None
|
|
||||||
try:
|
|
||||||
server.save()
|
|
||||||
except IntegrityError:
|
|
||||||
raise HttpError(422, {"detail": "Unique constraint violated."})
|
|
||||||
return {
|
return {
|
||||||
"id": server.id,
|
"id": server.id,
|
||||||
"display_name": server.display_name,
|
"display_name": server.display_name,
|
||||||
@@ -146,17 +100,6 @@ def build_router() -> Router:
|
|||||||
"initial": server.initial,
|
"initial": server.initial,
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.delete("/{server_id}", response={204: None})
|
|
||||||
def delete_server(request: HttpRequest, server_id: int):
|
|
||||||
"""Delete a server by id (admin only)."""
|
|
||||||
require_perms(request, "servers.delete_server")
|
|
||||||
try:
|
|
||||||
server = Server.objects.get(id=server_id)
|
|
||||||
except Server.DoesNotExist:
|
|
||||||
raise HttpError(404, "Not Found")
|
|
||||||
server.delete()
|
|
||||||
return 204, None
|
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user