Initial linux agent and api functionality for enrolling servers

This commit is contained in:
2026-01-25 22:24:20 +00:00
parent 66ffa3d3fb
commit 4885622d6a
23 changed files with 1351 additions and 50 deletions

View File

@@ -21,6 +21,7 @@ KEYWARDEN_ADMIN_PASSWORD=password
# Auth mode: native | oidc | hybrid # Auth mode: native | oidc | hybrid
KEYWARDEN_AUTH_MODE=native KEYWARDEN_AUTH_MODE=native
# OIDC (optional) # OIDC (optional)
# KEYWARDEN_OIDC_CLIENT_ID= # KEYWARDEN_OIDC_CLIENT_ID=
# KEYWARDEN_OIDC_CLIENT_SECRET= # KEYWARDEN_OIDC_CLIENT_SECRET=

23
agent/README.md Normal file
View File

@@ -0,0 +1,23 @@
# keywarden-agent
Minimal Go agent scaffold for Keywarden.
## Build
```
go build -o keywarden-agent ./cmd/keywarden-agent
```
## Run
```
./keywarden-agent -config /etc/keywarden/agent.json -server-url https://keywarden.example.com -enroll-token <token>
```
You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environment variables.
## Config
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
See `config.example.json`.

View File

@@ -0,0 +1,223 @@
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"time"
"keywarden/agent/internal/client"
"keywarden/agent/internal/config"
"keywarden/agent/internal/logs"
"keywarden/agent/internal/version"
)
func main() {
configPath := flag.String("config", config.DefaultConfigPath, "Path to agent config JSON")
serverURL := flag.String("server-url", "", "Keywarden server URL (first boot)")
enrollToken := flag.String("enroll-token", "", "Enrollment token (first boot)")
showVersion := flag.Bool("version", false, "Print version and exit")
flag.Parse()
if *showVersion {
fmt.Printf("keywarden-agent %s (commit %s, built %s)\n", version.Version, version.Commit, version.BuildDate)
return
}
cfg, err := config.LoadOrInit(*configPath, pickServerURL(*serverURL))
if err != nil {
log.Fatalf("config error: %v", err)
}
if err := ensureDirs(cfg); err != nil {
log.Fatalf("state dir error: %v", err)
}
if err := bootstrapIfNeeded(cfg, *configPath, pickEnrollToken(*enrollToken)); err != nil {
log.Fatalf("bootstrap error: %v", err)
}
apiClient, err := client.New(cfg)
if err != nil {
log.Fatalf("client error: %v", err)
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
interval := time.Duration(cfg.SyncIntervalSeconds) * time.Second
log.Printf("keywarden-agent started: server_id=%s interval=%s", cfg.ServerID, interval)
ticker := time.NewTicker(interval)
defer ticker.Stop()
runOnce(ctx, apiClient, cfg)
for {
select {
case <-ctx.Done():
log.Printf("shutdown requested")
return
case <-ticker.C:
runOnce(ctx, apiClient, cfg)
}
}
}
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
log.Printf("sync accounts error: %v", err)
}
if err := shipLogs(ctx, apiClient, cfg); err != nil {
log.Printf("log shipping error: %v", err)
}
}
func ensureDirs(cfg *config.Config) error {
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
return err
}
if err := os.MkdirAll(cfg.LogSpoolDir(), 0o700); err != nil {
return err
}
return nil
}
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
send := func(payload []byte) error {
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
}
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
return err
}
cursor, err := logs.ReadCursor(cfg.LogCursorPath())
if err != nil {
return err
}
collector := logs.NewCollector()
events, nextCursor, err := collector.Collect(ctx, cursor, cfg.LogBatchSize)
if err != nil {
return err
}
if len(events) == 0 {
return nil
}
payload, err := json.Marshal(events)
if err != nil {
return err
}
if err := send(payload); err != nil {
if spoolErr := logs.SaveSpool(cfg.LogSpoolDir(), payload); spoolErr != nil {
return spoolErr
}
return err
}
if err := logs.WriteCursor(cfg.LogCursorPath(), nextCursor); err != nil {
return err
}
return nil
}
func pickServerURL(flagValue string) string {
if flagValue != "" {
return flagValue
}
return os.Getenv("KEYWARDEN_SERVER_URL")
}
func pickEnrollToken(flagValue string) string {
if flagValue != "" {
return flagValue
}
return os.Getenv("KEYWARDEN_ENROLL_TOKEN")
}
func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string) error {
if cfg.ServerID != "" && fileExists(cfg.ClientCertPath()) && fileExists(cfg.CACertPath()) {
return nil
}
if enrollToken == "" {
return fmt.Errorf("missing enrollment token; set KEYWARDEN_ENROLL_TOKEN or -enroll-token")
}
keyPath := cfg.ClientKeyPath()
if !fileExists(keyPath) {
if err := generateKey(keyPath); err != nil {
return err
}
}
csrPEM, err := buildCSR(keyPath)
if err != nil {
return err
}
hostname, _ := os.Hostname()
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
Token: enrollToken,
CSRPEM: csrPEM,
Host: hostname,
})
if err != nil {
return err
}
if err := os.WriteFile(cfg.ClientCertPath(), []byte(resp.ClientCert), 0o600); err != nil {
return err
}
if err := os.WriteFile(cfg.CACertPath(), []byte(resp.CACert), 0o600); err != nil {
return err
}
cfg.ServerID = resp.ServerID
if err := config.Save(configPath, cfg); err != nil {
return err
}
return nil
}
func generateKey(path string) error {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
keyDER := x509.MarshalPKCS1PrivateKey(key)
block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}
data := pem.EncodeToMemory(block)
return os.WriteFile(path, data, 0o600)
}
func buildCSR(keyPath string) (string, error) {
keyData, err := os.ReadFile(keyPath)
if err != nil {
return "", err
}
block, _ := pem.Decode(keyData)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return "", fmt.Errorf("invalid private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", err
}
csrTemplate := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "keywarden-agent"}}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
if err != nil {
return "", err
}
csrBlock := &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER}
return string(pem.EncodeToMemory(csrBlock)), nil
}
func fileExists(path string) bool {
info, err := os.Stat(path)
if err != nil {
return false
}
return !info.IsDir()
}

14
agent/config.example.json Normal file
View File

@@ -0,0 +1,14 @@
{
"server_url": "https://keywarden.example.com",
"server_id": "",
"sync_interval_seconds": 30,
"log_batch_size": 500,
"state_dir": "/var/lib/keywarden-agent",
"account_policy": {
"username_template": "{{username}}_{{user_id}}",
"default_shell": "/bin/bash",
"admin_group": "sudo",
"create_home": true,
"lock_on_revoke": true
}
}

7
agent/go.mod Normal file
View File

@@ -0,0 +1,7 @@
module keywarden/agent
go 1.22
require (
github.com/coreos/go-systemd/v22 v22.5.0
)

3
agent/go.sum Normal file
View File

@@ -0,0 +1,3 @@
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=

View File

@@ -0,0 +1,132 @@
package client
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"keywarden/agent/internal/config"
)
const defaultTimeout = 15 * time.Second
type Client struct {
baseURL string
http *http.Client
}
func New(cfg *config.Config) (*Client, error) {
baseURL := strings.TrimRight(cfg.ServerURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath())
if err != nil {
return nil, fmt.Errorf("load client cert: %w", err)
}
caData, err := os.ReadFile(cfg.CACertPath())
if err != nil {
return nil, fmt.Errorf("read ca cert: %w", err)
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse ca cert")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caPool,
MinVersion: tls.VersionTLS12,
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
}
httpClient := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
return &Client{baseURL: baseURL, http: httpClient}, nil
}
type EnrollRequest struct {
Token string `json:"token"`
CSRPEM string `json:"csr_pem"`
Host string `json:"host"`
AgentID string `json:"agent_id,omitempty"`
}
type EnrollResponse struct {
ServerID string `json:"server_id"`
ClientCert string `json:"client_cert_pem"`
CACert string `json:"ca_cert_pem"`
SyncProfile string `json:"sync_profile,omitempty"`
DisplayName string `json:"display_name,omitempty"`
}
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
baseURL := strings.TrimRight(serverURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("encode enroll request: %w", err)
}
httpClient := &http.Client{Timeout: defaultTimeout}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build enroll request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("enroll request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("enroll failed: status %s", resp.Status)
}
var out EnrollResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode enroll response: %w", err)
}
if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" {
return nil, errors.New("enroll response missing required fields")
}
return &out, nil
}
func (c *Client) SyncAccounts(ctx context.Context, serverID string) error {
_ = ctx
_ = serverID
// TODO: call API to fetch account policy + approved access list.
return nil
}
func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload))
if err != nil {
return fmt.Errorf("build log request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send log batch: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return fmt.Errorf("log batch failed: status %s", resp.Status)
}
return nil
}

View File

@@ -0,0 +1,148 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
)
const (
DefaultConfigPath = "/etc/keywarden/agent.json"
DefaultStateDir = "/var/lib/keywarden-agent"
DefaultSyncIntervalSeconds = 30
DefaultLogBatchSize = 500
DefaultUsernameTemplate = "{{username}}_{{user_id}}"
DefaultShell = "/bin/bash"
DefaultAdminGroup = "sudo"
)
type AccountPolicy struct {
UsernameTemplate string `json:"username_template"`
DefaultShell string `json:"default_shell"`
AdminGroup string `json:"admin_group"`
CreateHome bool `json:"create_home"`
LockOnRevoke bool `json:"lock_on_revoke"`
}
type Config struct {
ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"`
AccountPolicy AccountPolicy `json:"account_policy,omitempty"`
}
func LoadOrInit(path string, serverURL string) (*Config, error) {
if path == "" {
path = DefaultConfigPath
}
data, err := os.ReadFile(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("read config: %w", err)
}
if serverURL == "" {
return nil, errors.New("server url required for first boot")
}
cfg := &Config{ServerURL: serverURL}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
if err := Save(path, cfg); err != nil {
return nil, err
}
return cfg, nil
}
cfg := &Config{}
if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
return cfg, nil
}
func Save(path string, cfg *Config) error {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("encode config: %w", err)
}
if err := os.MkdirAll(dir(path), 0o755); err != nil {
return fmt.Errorf("create config dir: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write config: %w", err)
}
return nil
}
func applyDefaults(cfg *Config) {
if cfg.SyncIntervalSeconds <= 0 {
cfg.SyncIntervalSeconds = DefaultSyncIntervalSeconds
}
if cfg.LogBatchSize <= 0 {
cfg.LogBatchSize = DefaultLogBatchSize
}
if cfg.StateDir == "" {
cfg.StateDir = DefaultStateDir
}
if cfg.AccountPolicy.UsernameTemplate == "" {
cfg.AccountPolicy.UsernameTemplate = DefaultUsernameTemplate
}
if cfg.AccountPolicy.DefaultShell == "" {
cfg.AccountPolicy.DefaultShell = DefaultShell
}
if cfg.AccountPolicy.AdminGroup == "" {
cfg.AccountPolicy.AdminGroup = DefaultAdminGroup
}
}
func validate(cfg *Config, requireServerID bool) error {
var missing []string
if cfg.ServerURL == "" {
missing = append(missing, "server_url")
}
if requireServerID && cfg.ServerID == "" {
missing = append(missing, "server_id")
}
if len(missing) > 0 {
return fmt.Errorf("missing required config fields: %v", missing)
}
if cfg.SyncIntervalSeconds < 5 {
return errors.New("sync_interval_seconds must be >= 5")
}
return nil
}
func (c *Config) ClientCertPath() string {
return c.StateDir + "/agent.crt"
}
func (c *Config) ClientKeyPath() string {
return c.StateDir + "/agent.key"
}
func (c *Config) CACertPath() string {
return c.StateDir + "/ca.crt"
}
func (c *Config) LogCursorPath() string {
return c.StateDir + "/journal.cursor"
}
func (c *Config) LogSpoolDir() string {
return c.StateDir + "/spool"
}
func dir(path string) string {
if idx := strings.LastIndex(path, string(os.PathSeparator)); idx != -1 {
return path[:idx]
}
return "."
}

View File

@@ -0,0 +1,177 @@
package logs
import (
"context"
"strings"
"time"
"github.com/coreos/go-systemd/v22/sdjournal"
)
const defaultLimit = 500
type Collector struct {
matches []string
}
func NewCollector() *Collector {
return &Collector{matches: defaultMatches()}
}
func (c *Collector) Collect(ctx context.Context, cursor string, limit int) ([]Event, string, error) {
if limit <= 0 {
limit = defaultLimit
}
j, err := sdjournal.NewJournal()
if err != nil {
return nil, "", err
}
defer j.Close()
for i, match := range c.matches {
if i > 0 {
if err := j.AddDisjunction(); err != nil {
return nil, "", err
}
}
if err := j.AddMatch(match); err != nil {
return nil, "", err
}
}
if cursor != "" {
if err := j.SeekCursor(cursor); err == nil {
_, _ = j.Next()
}
} else {
_ = j.SeekTail()
_, _ = j.Next()
}
var events []Event
var nextCursor string
for len(events) < limit {
select {
case <-ctx.Done():
return events, nextCursor, ctx.Err()
default:
}
n, err := j.Next()
if err != nil {
return events, nextCursor, err
}
if n == 0 {
break
}
entry, err := j.GetEntry()
if err != nil {
return events, nextCursor, err
}
event := fromEntry(entry)
events = append(events, event)
nextCursor = entry.Cursor
}
return events, nextCursor, nil
}
func defaultMatches() []string {
return []string{
"_SYSTEMD_UNIT=sshd.service",
"_SYSTEMD_UNIT=sudo.service",
"_SYSTEMD_UNIT=systemd-networkd.service",
"_SYSTEMD_UNIT=NetworkManager.service",
"_SYSTEMD_UNIT=systemd-logind.service",
"_TRANSPORT=kernel",
}
}
func fromEntry(entry *sdjournal.JournalEntry) Event {
ts := time.Unix(0, int64(entry.RealtimeTimestamp)*int64(time.Microsecond))
event := NewEvent(ts)
fields := entry.Fields
unit := fields["_SYSTEMD_UNIT"]
message := fields["MESSAGE"]
identifier := fields["SYSLOG_IDENTIFIER"]
event.Unit = unit
event.Message = message
event.Priority = fields["PRIORITY"]
event.Hostname = fields["_HOSTNAME"]
event.Fields = fields
event.Category = categorize(unit, identifier, fields)
event.EventType, event.Username, event.SourceIP, event.SessionID = parseMessage(event.Category, message)
if event.EventType == "" {
event.EventType = defaultEventType(event.Category)
}
return event
}
func categorize(unit string, identifier string, fields map[string]string) string {
switch {
case unit == "sshd.service" || identifier == "sshd":
return "access"
case unit == "sudo.service" || identifier == "sudo":
return "auth"
case unit == "systemd-networkd.service" || identifier == "NetworkManager":
return "network"
case fields["_TRANSPORT"] == "kernel":
return "system"
default:
return "system"
}
}
func defaultEventType(category string) string {
switch category {
case "access":
return "ssh"
case "auth":
return "auth"
case "network":
return "network"
default:
return "system"
}
}
func parseMessage(category string, msg string) (eventType string, username string, sourceIP string, sessionID string) {
if msg == "" {
return "", "", "", ""
}
lower := strings.ToLower(msg)
if category == "access" {
switch {
case strings.Contains(lower, "accepted"):
eventType = "ssh.login.success"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "failed password"):
eventType = "ssh.login.fail"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "session opened"):
eventType = "ssh.session.open"
username = extractBetween(msg, "for user ", " by")
case strings.Contains(lower, "session closed"):
eventType = "ssh.session.close"
username = extractBetween(msg, "for user ", " by")
}
}
return eventType, strings.TrimSpace(username), strings.TrimSpace(sourceIP), strings.TrimSpace(sessionID)
}
func extractBetween(msg string, start string, end string) string {
startIdx := strings.Index(msg, start)
if startIdx == -1 {
return ""
}
startIdx += len(start)
rest := msg[startIdx:]
endIdx := strings.Index(rest, end)
if endIdx == -1 {
return strings.TrimSpace(rest)
}
return strings.TrimSpace(rest[:endIdx])
}

View File

@@ -0,0 +1,24 @@
package logs
import (
"os"
"strings"
)
func ReadCursor(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func WriteCursor(path string, cursor string) error {
if cursor == "" {
return nil
}
return os.WriteFile(path, []byte(cursor+"\n"), 0o600)
}

View File

@@ -0,0 +1,53 @@
package logs
import (
"fmt"
"os"
"path/filepath"
"sort"
"time"
)
func SaveSpool(dir string, payload []byte) error {
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
name := fmt.Sprintf("%d.json", time.Now().UnixNano())
tmp := filepath.Join(dir, name+".tmp")
final := filepath.Join(dir, name)
if err := os.WriteFile(tmp, payload, 0o600); err != nil {
return err
}
return os.Rename(tmp, final)
}
func DrainSpool(dir string, send func([]byte) error) error {
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var files []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
files = append(files, filepath.Join(dir, entry.Name()))
}
sort.Strings(files)
for _, path := range files {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if err := send(data); err != nil {
return err
}
if err := os.Remove(path); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,23 @@
package logs
import "time"
type Event struct {
Timestamp string `json:"timestamp"`
Category string `json:"category"`
EventType string `json:"event_type"`
Unit string `json:"unit,omitempty"`
Priority string `json:"priority,omitempty"`
Hostname string `json:"hostname,omitempty"`
Username string `json:"username,omitempty"`
Principal string `json:"principal,omitempty"`
SourceIP string `json:"source_ip,omitempty"`
SessionID string `json:"session_id,omitempty"`
Message string `json:"message,omitempty"`
Raw string `json:"raw,omitempty"`
Fields map[string]string `json:"fields,omitempty"`
}
func NewEvent(ts time.Time) Event {
return Event{Timestamp: ts.UTC().Format(time.RFC3339Nano)}
}

View File

@@ -0,0 +1,7 @@
package version
var (
Version = "0.0.1-dev"
Commit = ""
BuildDate = ""
)

BIN
agent/keywarden-agent Executable file

Binary file not shown.

View File

@@ -1,17 +1,27 @@
from django.contrib import admin from django.contrib import admin
from guardian.admin import GuardedModelAdmin
from django.utils.html import format_html from django.utils.html import format_html
from .models import Server from guardian.admin import GuardedModelAdmin
from .models import AgentCertificateAuthority, EnrollmentToken, Server
@admin.register(Server) @admin.register(Server)
class ServerAdmin(GuardedModelAdmin): class ServerAdmin(GuardedModelAdmin):
list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "created_at") list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "agent_enrolled_at", "created_at")
list_display_links = ("display_name",) list_display_links = ("display_name",)
search_fields = ("display_name", "hostname", "ipv4", "ipv6") search_fields = ("display_name", "hostname", "ipv4", "ipv6")
list_filter = ("created_at",) list_filter = ("created_at",)
readonly_fields = ("created_at", "updated_at") readonly_fields = ("created_at", "updated_at", "agent_enrolled_at")
fields = ("display_name", "hostname", "ipv4", "ipv6", "image", "created_at", "updated_at") fields = (
"display_name",
"hostname",
"ipv4",
"ipv6",
"image",
"agent_enrolled_at",
"created_at",
"updated_at",
)
def avatar(self, obj: Server): def avatar(self, obj: Server):
if obj.image_url: if obj.image_url:
@@ -27,3 +37,52 @@ class ServerAdmin(GuardedModelAdmin):
) )
avatar.short_description = "" avatar.short_description = ""
@admin.register(EnrollmentToken)
class EnrollmentTokenAdmin(admin.ModelAdmin):
list_display = ("token", "created_at", "expires_at", "used_at", "server")
list_filter = ("created_at", "used_at")
search_fields = ("token", "server__display_name", "server__hostname")
readonly_fields = ("token", "created_at", "used_at", "server", "created_by")
fields = ("token", "expires_at", "created_by", "created_at", "used_at", "server")
def save_model(self, request, obj, form, change) -> None:
if not obj.pk:
obj.ensure_token()
if request.user and request.user.is_authenticated and not obj.created_by_id:
obj.created_by = request.user
super().save_model(request, obj, form, change)
@admin.register(AgentCertificateAuthority)
class AgentCertificateAuthorityAdmin(admin.ModelAdmin):
list_display = ("name", "is_active", "created_at", "revoked_at")
list_filter = ("is_active", "created_at", "revoked_at")
search_fields = ("name", "fingerprint")
readonly_fields = ("fingerprint", "serial", "created_at", "revoked_at", "created_by")
fields = (
"name",
"is_active",
"cert_pem",
"key_pem",
"fingerprint",
"serial",
"created_by",
"created_at",
"revoked_at",
)
actions = ["revoke_selected"]
def save_model(self, request, obj, form, change) -> None:
if request.user and request.user.is_authenticated and not obj.created_by_id:
obj.created_by = request.user
obj.ensure_material()
if obj.is_active:
AgentCertificateAuthority.objects.exclude(pk=obj.pk).update(is_active=False)
super().save_model(request, obj, form, change)
@admin.action(description="Revoke selected CAs")
def revoke_selected(self, request, queryset):
for ca in queryset:
ca.revoke()
ca.save(update_fields=["is_active", "revoked_at"])

View File

@@ -0,0 +1,73 @@
from django.conf import settings
from django.db import migrations, models
import django.utils.timezone
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("servers", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddField(
model_name="server",
name="agent_cert_fingerprint",
field=models.CharField(blank=True, max_length=128, null=True),
),
migrations.AddField(
model_name="server",
name="agent_cert_serial",
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AddField(
model_name="server",
name="agent_enrolled_at",
field=models.DateTimeField(blank=True, null=True),
),
migrations.CreateModel(
name="EnrollmentToken",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("token", models.CharField(max_length=128, unique=True)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("expires_at", models.DateTimeField(blank=True, null=True)),
("used_at", models.DateTimeField(blank=True, null=True)),
(
"created_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="server_enrollment_tokens",
to=settings.AUTH_USER_MODEL,
),
),
(
"server",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="enrollment_tokens",
to="servers.server",
),
),
],
options={
"verbose_name": "Enrollment token",
"verbose_name_plural": "Enrollment tokens",
"ordering": ["-created_at"],
},
),
migrations.AddIndex(
model_name="enrollmenttoken",
index=models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
),
migrations.AddIndex(
model_name="enrollmenttoken",
index=models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
),
]

View File

@@ -0,0 +1,44 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("servers", "0002_agent_enrollment"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="AgentCertificateAuthority",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(default="Keywarden Agent CA", max_length=128)),
("cert_pem", models.TextField()),
("key_pem", models.TextField()),
("fingerprint", models.CharField(blank=True, max_length=128)),
("serial", models.CharField(blank=True, max_length=64)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("revoked_at", models.DateTimeField(blank=True, null=True)),
("is_active", models.BooleanField(db_index=True, default=True)),
(
"created_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="agent_certificate_authorities",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "Agent certificate authority",
"verbose_name_plural": "Agent certificate authorities",
"ordering": ["-created_at"],
},
),
]

View File

@@ -1,8 +1,16 @@
from __future__ import annotations from __future__ import annotations
import secrets
from datetime import datetime, timedelta
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from django.conf import settings
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.db import models from django.db import models
from django.utils.text import slugify from django.utils import timezone
hostname_validator = RegexValidator( hostname_validator = RegexValidator(
@@ -17,6 +25,9 @@ class Server(models.Model):
ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True) ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True)
ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True) ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True)
image = models.ImageField(upload_to="servers/", null=True, blank=True) image = models.ImageField(upload_to="servers/", null=True, blank=True)
agent_enrolled_at = models.DateTimeField(null=True, blank=True)
agent_cert_fingerprint = models.CharField(max_length=128, null=True, blank=True)
agent_cert_serial = models.CharField(max_length=64, null=True, blank=True)
created_at = models.DateTimeField(auto_now_add=True) created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True) updated_at = models.DateTimeField(auto_now=True)
@@ -41,3 +52,108 @@ class Server(models.Model):
return (self.display_name or "?").strip()[:1].upper() or "?" return (self.display_name or "?").strip()[:1].upper() or "?"
class EnrollmentToken(models.Model):
token = models.CharField(max_length=128, unique=True)
created_at = models.DateTimeField(default=timezone.now, editable=False)
expires_at = models.DateTimeField(null=True, blank=True)
created_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="server_enrollment_tokens",
)
used_at = models.DateTimeField(null=True, blank=True)
server = models.ForeignKey(
Server, null=True, blank=True, on_delete=models.SET_NULL, related_name="enrollment_tokens"
)
class Meta:
verbose_name = "Enrollment token"
verbose_name_plural = "Enrollment tokens"
indexes = [
models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
]
ordering = ["-created_at"]
def __str__(self) -> str:
return f"{self.token[:8]}... ({'used' if self.used_at else 'unused'})"
def ensure_token(self) -> None:
if not self.token:
self.token = secrets.token_urlsafe(32)
def is_valid(self) -> bool:
if self.used_at:
return False
if self.expires_at and self.expires_at <= timezone.now():
return False
return True
def mark_used(self, server: Server) -> None:
self.used_at = timezone.now()
self.server = server
def save(self, *args, **kwargs):
self.ensure_token()
super().save(*args, **kwargs)
class AgentCertificateAuthority(models.Model):
name = models.CharField(max_length=128, default="Keywarden Agent CA")
cert_pem = models.TextField()
key_pem = models.TextField()
fingerprint = models.CharField(max_length=128, blank=True)
serial = models.CharField(max_length=64, blank=True)
created_at = models.DateTimeField(default=timezone.now, editable=False)
revoked_at = models.DateTimeField(null=True, blank=True)
is_active = models.BooleanField(default=True, db_index=True)
created_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="agent_certificate_authorities",
)
class Meta:
verbose_name = "Agent certificate authority"
verbose_name_plural = "Agent certificate authorities"
ordering = ["-created_at"]
def __str__(self) -> str:
status = "active" if self.is_active and not self.revoked_at else "revoked"
return f"{self.name} ({status})"
def revoke(self) -> None:
self.is_active = False
self.revoked_at = timezone.now()
def ensure_material(self) -> None:
if self.cert_pem and self.key_pem:
return
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.name)])
now = datetime.utcnow()
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now - timedelta(minutes=5))
.not_valid_after(now + timedelta(days=3650))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.sign(key, hashes.SHA256())
)
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
key_pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
self.cert_pem = cert_pem
self.key_pem = key_pem
self.fingerprint = cert.fingerprint(hashes.SHA256()).hex()
self.serial = format(cert.serial_number, "x")

View File

@@ -70,7 +70,14 @@ def build_router() -> Router:
@router.get("/", response=List[AccessRequestOut]) @router.get("/", response=List[AccessRequestOut])
def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)): def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)):
"""List access requests for the user, or all if admin/operator.""" """List access requests with pagination and filters.
Auth: required.
Permissions:
- If user has global `access.view_accessrequest`, returns all requests.
- Otherwise, returns only objects with `access.view_accessrequest` object permission.
Filters: status, server_id, requester_id (requester_id is honored only with global view).
"""
require_authenticated(request) require_authenticated(request)
user = request.user user = request.user
if _has_global_perm(request, "access.view_accessrequest"): if _has_global_perm(request, "access.view_accessrequest"):
@@ -94,7 +101,12 @@ def build_router() -> Router:
@router.post("/", response=AccessRequestOut) @router.post("/", response=AccessRequestOut)
def create_request(request: HttpRequest, payload: AccessRequestCreateIn): def create_request(request: HttpRequest, payload: AccessRequestCreateIn):
"""Create a new access request for a server.""" """Create a new access request for the current user.
Auth: required.
Permissions: requires global `access.add_accessrequest`.
Side effects: grants owner object perms on the new request.
"""
require_authenticated(request) require_authenticated(request)
if not request.user.has_perm("access.add_accessrequest"): if not request.user.has_perm("access.add_accessrequest"):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
@@ -116,7 +128,11 @@ def build_router() -> Router:
@router.get("/{request_id}", response=AccessRequestOut) @router.get("/{request_id}", response=AccessRequestOut)
def get_request(request: HttpRequest, request_id: int): def get_request(request: HttpRequest, request_id: int):
"""Get an access request if permitted.""" """Get a single access request by id.
Auth: required.
Permissions: requires `access.view_accessrequest` on the object.
"""
require_authenticated(request) require_authenticated(request)
try: try:
access_request = AccessRequest.objects.get(id=request_id) access_request = AccessRequest.objects.get(id=request_id)
@@ -128,7 +144,15 @@ def build_router() -> Router:
@router.patch("/{request_id}", response=AccessRequestOut) @router.patch("/{request_id}", response=AccessRequestOut)
def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn): def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn):
"""Update request status or expiry (admin/operator or owner with restrictions).""" """Update request status or expiry.
Auth: required.
Permissions: requires `access.change_accessrequest` on the object.
Rules:
- Admin/operator (global change) can set status to approved/denied/revoked/cancelled and
update expires_at.
- Non-admin can only set status to cancelled, and only while pending.
"""
require_authenticated(request) require_authenticated(request)
try: try:
access_request = AccessRequest.objects.get(id=request_id) access_request = AccessRequest.objects.get(id=request_id)
@@ -171,7 +195,11 @@ def build_router() -> Router:
@router.delete("/{request_id}", response={204: None}) @router.delete("/{request_id}", response={204: None})
def delete_request(request: HttpRequest, request_id: int): def delete_request(request: HttpRequest, request_id: int):
"""Delete an access request if permitted.""" """Delete an access request.
Auth: required.
Permissions: requires `access.delete_accessrequest` on the object.
"""
require_authenticated(request) require_authenticated(request)
try: try:
access_request = AccessRequest.objects.get(id=request_id) access_request = AccessRequest.objects.get(id=request_id)

View File

@@ -1,7 +1,13 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import timezone from django.utils import timezone
@@ -12,7 +18,7 @@ from pydantic import Field
from apps.core.rbac import require_perms from apps.core.rbac import require_perms
from apps.access.models import AccessRequest from apps.access.models import AccessRequest
from apps.keys.models import SSHKey from apps.keys.models import SSHKey
from apps.servers.models import Server from apps.servers.models import AgentCertificateAuthority, EnrollmentToken, Server, hostname_validator
from apps.telemetry.models import TelemetryEvent from apps.telemetry.models import TelemetryEvent
@@ -35,9 +41,82 @@ class SyncReportOut(Schema):
status: str status: str
class AgentEnrollIn(Schema):
token: str
csr_pem: str
host: Optional[str] = None
class AgentEnrollOut(Schema):
server_id: str
client_cert_pem: str
ca_cert_pem: str
class LogEventIn(Schema):
timestamp: str
category: str
event_type: str
unit: Optional[str] = None
priority: Optional[str] = None
hostname: Optional[str] = None
username: Optional[str] = None
principal: Optional[str] = None
source_ip: Optional[str] = None
session_id: Optional[str] = None
message: Optional[str] = None
raw: Optional[str] = None
fields: Optional[dict] = None
class LogIngestOut(Schema):
status: str
accepted: int
def build_router() -> Router: def build_router() -> Router:
router = Router() router = Router()
@router.post("/enroll", response=AgentEnrollOut, auth=None)
def enroll_agent(request: HttpRequest, payload: AgentEnrollIn):
"""Enroll a server agent using a one-time token."""
token_value = (payload.token or "").strip()
if not token_value:
raise HttpError(422, "Token required")
try:
token = EnrollmentToken.objects.get(token=token_value)
except EnrollmentToken.DoesNotExist:
raise HttpError(403, "Invalid token")
if not token.is_valid():
raise HttpError(403, "Token expired or already used")
host = (payload.host or "").strip()[:253]
display_name = host or "server"
hostname = None
if host:
try:
hostname_validator(host)
hostname = host
except ValidationError:
hostname = None
server = Server.objects.create(display_name=display_name, hostname=hostname)
token.mark_used(server)
token.save(update_fields=["used_at", "server"])
csr = _load_csr((payload.csr_pem or "").strip())
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
server.agent_enrolled_at = timezone.now()
server.agent_cert_fingerprint = fingerprint
server.agent_cert_serial = serial
server.save(update_fields=["agent_enrolled_at", "agent_cert_fingerprint", "agent_cert_serial"])
return AgentEnrollOut(
server_id=str(server.id),
client_cert_pem=cert_pem,
ca_cert_pem=ca_pem,
)
@router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut]) @router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut])
def authorized_keys(request: HttpRequest, server_id: int): def authorized_keys(request: HttpRequest, server_id: int):
"""Return authorized public keys for a server (admin or operator).""" """Return authorized public keys for a server (admin or operator)."""
@@ -96,7 +175,75 @@ def build_router() -> Router:
) )
return SyncReportOut(status="ok") return SyncReportOut(status="ok")
@router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None)
def ingest_logs(request: HttpRequest, server_id: int, payload: List[LogEventIn]):
"""Accept log batches from agents (mTLS required at the edge)."""
try:
Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
# TODO: enqueue to Valkey and persist to SQLite slices.
return LogIngestOut(status="accepted", accepted=len(payload))
return router return router
def _load_agent_ca() -> tuple[x509.Certificate, object, str]:
ca = (
AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True)
.order_by("-created_at")
.first()
)
if not ca:
raise HttpError(500, "Agent CA not configured")
try:
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
ca_key = serialization.load_pem_private_key(ca.key_pem.encode("utf-8"), password=None)
except (ValueError, TypeError):
raise HttpError(500, "Invalid agent CA material")
return ca_cert, ca_key, ca.cert_pem
def _load_csr(csr_pem: str) -> x509.CertificateSigningRequest:
try:
csr = x509.load_pem_x509_csr(csr_pem.encode("utf-8"))
except ValueError:
raise HttpError(422, "Invalid CSR")
if not csr.is_signature_valid:
raise HttpError(422, "Invalid CSR signature")
return csr
def _issue_client_cert(
csr: x509.CertificateSigningRequest, host: str | None, server_id: int
) -> tuple[str, str, str, str]:
ca_cert, ca_key, ca_pem = _load_agent_ca()
now = datetime.utcnow()
subject = csr.subject
if len(subject) == 0:
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, f"keywarden-agent-{server_id}")])
builder = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(ca_cert.subject)
.public_key(csr.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now - timedelta(minutes=5))
.not_valid_after(now + timedelta(days=settings.KEYWARDEN_AGENT_CERT_VALIDITY_DAYS))
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=False)
)
if host:
try:
hostname_validator(host)
builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(host)]), critical=False)
except ValidationError:
pass
cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
fingerprint = cert.fingerprint(hashes.SHA256()).hex()
serial = format(cert.serial_number, "x")
return cert_pem, ca_pem, fingerprint, serial
router = build_router() router = build_router()

View File

@@ -69,7 +69,14 @@ def build_router() -> Router:
@router.get("/", response=List[KeyOut]) @router.get("/", response=List[KeyOut])
def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)): def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)):
"""List SSH keys for the current user, or any user if admin/operator.""" """List SSH keys with pagination and filters.
Auth: required.
Permissions:
- If user has global `keys.view_sshkey`, returns all keys.
- Otherwise, returns only objects with `keys.view_sshkey` object permission.
Filter: user_id (honored only with global view).
"""
require_authenticated(request) require_authenticated(request)
user = request.user user = request.user
if _has_global_perm(request, "keys.view_sshkey"): if _has_global_perm(request, "keys.view_sshkey"):
@@ -89,7 +96,15 @@ def build_router() -> Router:
@router.post("/", response=KeyOut) @router.post("/", response=KeyOut)
def create_key(request: HttpRequest, payload: KeyCreateIn): def create_key(request: HttpRequest, payload: KeyCreateIn):
"""Create an SSH public key for the current user (admin/operator can specify user_id).""" """Create an SSH public key.
Auth: required.
Permissions: requires global `keys.add_sshkey`.
Rules:
- Default owner is the current user.
- If caller has global `keys.add_sshkey` and `keys.view_sshkey`, they may specify user_id.
Side effects: grants owner object perms on the new key.
"""
require_authenticated(request) require_authenticated(request)
if not request.user.has_perm("keys.add_sshkey"): if not request.user.has_perm("keys.add_sshkey"):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
@@ -121,7 +136,11 @@ def build_router() -> Router:
@router.get("/{key_id}", response=KeyOut) @router.get("/{key_id}", response=KeyOut)
def get_key(request: HttpRequest, key_id: int): def get_key(request: HttpRequest, key_id: int):
"""Get a specific SSH key if permitted.""" """Get a specific SSH key by id.
Auth: required.
Permissions: requires `keys.view_sshkey` on the object.
"""
require_authenticated(request) require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)
@@ -133,7 +152,11 @@ def build_router() -> Router:
@router.patch("/{key_id}", response=KeyOut) @router.patch("/{key_id}", response=KeyOut)
def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn): def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn):
"""Update key name or active state if permitted.""" """Update key name or active state.
Auth: required.
Permissions: requires `keys.change_sshkey` on the object.
"""
require_authenticated(request) require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)
@@ -159,7 +182,12 @@ def build_router() -> Router:
@router.delete("/{key_id}", response={204: None}) @router.delete("/{key_id}", response={204: None})
def delete_key(request: HttpRequest, key_id: int): def delete_key(request: HttpRequest, key_id: int):
"""Revoke an SSH key if permitted (soft delete).""" """Revoke (soft delete) an SSH key.
Auth: required.
Permissions: requires `keys.delete_sshkey` on the object.
Behavior: sets is_active false and revoked_at if key is active.
"""
require_authenticated(request) require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)

View File

@@ -78,21 +78,7 @@ def build_router() -> Router:
def create_server_json(request: HttpRequest, payload: ServerCreate): def create_server_json(request: HttpRequest, payload: ServerCreate):
"""Create a server using JSON payload (admin only).""" """Create a server using JSON payload (admin only)."""
require_perms(request, "servers.add_server") require_perms(request, "servers.add_server")
server = Server.objects.create( raise HttpError(403, "Servers are created via agent enrollment tokens.")
display_name=payload.display_name.strip(),
hostname=(payload.hostname or "").strip() or None,
ipv4=(payload.ipv4 or "").strip() or None,
ipv6=(payload.ipv6 or "").strip() or None,
)
return {
"id": server.id,
"display_name": server.display_name,
"hostname": server.hostname,
"ipv4": server.ipv4,
"ipv6": server.ipv6,
"image_url": server.image_url,
"initial": server.initial,
}
@router.post("/upload", response=ServerOut) @router.post("/upload", response=ServerOut)
def create_server_multipart( def create_server_multipart(
@@ -105,24 +91,7 @@ def build_router() -> Router:
): ):
"""Create a server with optional image upload (admin only).""" """Create a server with optional image upload (admin only)."""
require_perms(request, "servers.add_server") require_perms(request, "servers.add_server")
server = Server( raise HttpError(403, "Servers are created via agent enrollment tokens.")
display_name=display_name.strip(),
hostname=(hostname or "").strip() or None,
ipv4=(ipv4 or "").strip() or None,
ipv6=(ipv6 or "").strip() or None,
)
if image:
server.image.save(image.name, image) # type: ignore[arg-type]
server.save()
return {
"id": server.id,
"display_name": server.display_name,
"hostname": server.hostname,
"ipv4": server.ipv4,
"ipv6": server.ipv6,
"image_url": server.image_url,
"initial": server.initial,
}
@router.patch("/{server_id}", response=ServerOut) @router.patch("/{server_id}", response=ServerOut)
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate): def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):

View File

@@ -94,6 +94,8 @@ CACHES = {
SESSION_ENGINE = "django.contrib.sessions.backends.cache" SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default" SESSION_CACHE_ALIAS = "default"
KEYWARDEN_AGENT_CERT_VALIDITY_DAYS = int(os.getenv("KEYWARDEN_AGENT_CERT_VALIDITY_DAYS", "90"))
PASSWORD_HASHERS = [ PASSWORD_HASHERS = [
"django.contrib.auth.hashers.Argon2PasswordHasher", "django.contrib.auth.hashers.Argon2PasswordHasher",
"django.contrib.auth.hashers.PBKDF2PasswordHasher", "django.contrib.auth.hashers.PBKDF2PasswordHasher",