diff --git a/.env.example b/.env.example index cf08c6e..c4df085 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,7 @@ KEYWARDEN_ADMIN_PASSWORD=password # Auth mode: native | oidc | hybrid KEYWARDEN_AUTH_MODE=native + # OIDC (optional) # KEYWARDEN_OIDC_CLIENT_ID= # KEYWARDEN_OIDC_CLIENT_SECRET= diff --git a/agent/README.md b/agent/README.md new file mode 100644 index 0000000..6659403 --- /dev/null +++ b/agent/README.md @@ -0,0 +1,23 @@ +# keywarden-agent + +Minimal Go agent scaffold for Keywarden. + +## Build + +``` +go build -o keywarden-agent ./cmd/keywarden-agent +``` + +## Run + +``` +./keywarden-agent -config /etc/keywarden/agent.json -server-url https://keywarden.example.com -enroll-token +``` + +You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environment variables. + +## Config + +On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping. + +See `config.example.json`. diff --git a/agent/cmd/keywarden-agent/main.go b/agent/cmd/keywarden-agent/main.go new file mode 100644 index 0000000..d331ab4 --- /dev/null +++ b/agent/cmd/keywarden-agent/main.go @@ -0,0 +1,223 @@ +package main + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "flag" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "keywarden/agent/internal/client" + "keywarden/agent/internal/config" + "keywarden/agent/internal/logs" + "keywarden/agent/internal/version" +) + +func main() { + configPath := flag.String("config", config.DefaultConfigPath, "Path to agent config JSON") + serverURL := flag.String("server-url", "", "Keywarden server URL (first boot)") + enrollToken := flag.String("enroll-token", "", "Enrollment token (first boot)") + showVersion := flag.Bool("version", false, "Print version and exit") + flag.Parse() + + if *showVersion { + fmt.Printf("keywarden-agent %s (commit %s, built %s)\n", version.Version, version.Commit, version.BuildDate) + return + } + + cfg, err := config.LoadOrInit(*configPath, pickServerURL(*serverURL)) + if err != nil { + log.Fatalf("config error: %v", err) + } + if err := ensureDirs(cfg); err != nil { + log.Fatalf("state dir error: %v", err) + } + + if err := bootstrapIfNeeded(cfg, *configPath, pickEnrollToken(*enrollToken)); err != nil { + log.Fatalf("bootstrap error: %v", err) + } + + apiClient, err := client.New(cfg) + if err != nil { + log.Fatalf("client error: %v", err) + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + interval := time.Duration(cfg.SyncIntervalSeconds) * time.Second + log.Printf("keywarden-agent started: server_id=%s interval=%s", cfg.ServerID, interval) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + runOnce(ctx, apiClient, cfg) + + for { + select { + case <-ctx.Done(): + log.Printf("shutdown requested") + return + case <-ticker.C: + runOnce(ctx, apiClient, cfg) + } + } +} + +func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) { + if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { + log.Printf("sync accounts error: %v", err) + } + if err := shipLogs(ctx, apiClient, cfg); err != nil { + log.Printf("log shipping error: %v", err) + } +} + +func ensureDirs(cfg *config.Config) error { + if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil { + return err + } + if err := os.MkdirAll(cfg.LogSpoolDir(), 0o700); err != nil { + return err + } + return nil +} + +func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error { + send := func(payload []byte) error { + return apiClient.SendLogBatch(ctx, cfg.ServerID, payload) + } + if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil { + return err + } + + cursor, err := logs.ReadCursor(cfg.LogCursorPath()) + if err != nil { + return err + } + collector := logs.NewCollector() + events, nextCursor, err := collector.Collect(ctx, cursor, cfg.LogBatchSize) + if err != nil { + return err + } + if len(events) == 0 { + return nil + } + payload, err := json.Marshal(events) + if err != nil { + return err + } + if err := send(payload); err != nil { + if spoolErr := logs.SaveSpool(cfg.LogSpoolDir(), payload); spoolErr != nil { + return spoolErr + } + return err + } + if err := logs.WriteCursor(cfg.LogCursorPath(), nextCursor); err != nil { + return err + } + return nil +} + +func pickServerURL(flagValue string) string { + if flagValue != "" { + return flagValue + } + return os.Getenv("KEYWARDEN_SERVER_URL") +} + +func pickEnrollToken(flagValue string) string { + if flagValue != "" { + return flagValue + } + return os.Getenv("KEYWARDEN_ENROLL_TOKEN") +} + +func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string) error { + if cfg.ServerID != "" && fileExists(cfg.ClientCertPath()) && fileExists(cfg.CACertPath()) { + return nil + } + if enrollToken == "" { + return fmt.Errorf("missing enrollment token; set KEYWARDEN_ENROLL_TOKEN or -enroll-token") + } + keyPath := cfg.ClientKeyPath() + if !fileExists(keyPath) { + if err := generateKey(keyPath); err != nil { + return err + } + } + csrPEM, err := buildCSR(keyPath) + if err != nil { + return err + } + hostname, _ := os.Hostname() + resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{ + Token: enrollToken, + CSRPEM: csrPEM, + Host: hostname, + }) + if err != nil { + return err + } + if err := os.WriteFile(cfg.ClientCertPath(), []byte(resp.ClientCert), 0o600); err != nil { + return err + } + if err := os.WriteFile(cfg.CACertPath(), []byte(resp.CACert), 0o600); err != nil { + return err + } + cfg.ServerID = resp.ServerID + if err := config.Save(configPath, cfg); err != nil { + return err + } + return nil +} + +func generateKey(path string) error { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + keyDER := x509.MarshalPKCS1PrivateKey(key) + block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER} + data := pem.EncodeToMemory(block) + return os.WriteFile(path, data, 0o600) +} + +func buildCSR(keyPath string) (string, error) { + keyData, err := os.ReadFile(keyPath) + if err != nil { + return "", err + } + block, _ := pem.Decode(keyData) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return "", fmt.Errorf("invalid private key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return "", err + } + csrTemplate := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "keywarden-agent"}} + csrDER, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key) + if err != nil { + return "", err + } + csrBlock := &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER} + return string(pem.EncodeToMemory(csrBlock)), nil +} + +func fileExists(path string) bool { + info, err := os.Stat(path) + if err != nil { + return false + } + return !info.IsDir() +} diff --git a/agent/config.example.json b/agent/config.example.json new file mode 100644 index 0000000..f0ebbc8 --- /dev/null +++ b/agent/config.example.json @@ -0,0 +1,14 @@ +{ + "server_url": "https://keywarden.example.com", + "server_id": "", + "sync_interval_seconds": 30, + "log_batch_size": 500, + "state_dir": "/var/lib/keywarden-agent", + "account_policy": { + "username_template": "{{username}}_{{user_id}}", + "default_shell": "/bin/bash", + "admin_group": "sudo", + "create_home": true, + "lock_on_revoke": true + } +} diff --git a/agent/go.mod b/agent/go.mod new file mode 100644 index 0000000..ee3e676 --- /dev/null +++ b/agent/go.mod @@ -0,0 +1,7 @@ +module keywarden/agent + +go 1.22 + +require ( + github.com/coreos/go-systemd/v22 v22.5.0 +) diff --git a/agent/go.sum b/agent/go.sum new file mode 100644 index 0000000..716073a --- /dev/null +++ b/agent/go.sum @@ -0,0 +1,3 @@ +github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= diff --git a/agent/internal/client/client.go b/agent/internal/client/client.go new file mode 100644 index 0000000..4e4b09d --- /dev/null +++ b/agent/internal/client/client.go @@ -0,0 +1,132 @@ +package client + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "strings" + "time" + + "keywarden/agent/internal/config" +) + +const defaultTimeout = 15 * time.Second + +type Client struct { + baseURL string + http *http.Client +} + +func New(cfg *config.Config) (*Client, error) { + baseURL := strings.TrimRight(cfg.ServerURL, "/") + if baseURL == "" { + return nil, errors.New("server url is required") + } + cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath()) + if err != nil { + return nil, fmt.Errorf("load client cert: %w", err) + } + caData, err := os.ReadFile(cfg.CACertPath()) + if err != nil { + return nil, fmt.Errorf("read ca cert: %w", err) + } + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caData) { + return nil, errors.New("parse ca cert") + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caPool, + MinVersion: tls.VersionTLS12, + } + + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + httpClient := &http.Client{ + Timeout: defaultTimeout, + Transport: transport, + } + + return &Client{baseURL: baseURL, http: httpClient}, nil +} + +type EnrollRequest struct { + Token string `json:"token"` + CSRPEM string `json:"csr_pem"` + Host string `json:"host"` + AgentID string `json:"agent_id,omitempty"` +} + +type EnrollResponse struct { + ServerID string `json:"server_id"` + ClientCert string `json:"client_cert_pem"` + CACert string `json:"ca_cert_pem"` + SyncProfile string `json:"sync_profile,omitempty"` + DisplayName string `json:"display_name,omitempty"` +} + +func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) { + baseURL := strings.TrimRight(serverURL, "/") + if baseURL == "" { + return nil, errors.New("server url is required") + } + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("encode enroll request: %w", err) + } + httpClient := &http.Client{Timeout: defaultTimeout} + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("build enroll request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + resp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("enroll request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("enroll failed: status %s", resp.Status) + } + var out EnrollResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("decode enroll response: %w", err) + } + if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" { + return nil, errors.New("enroll response missing required fields") + } + return &out, nil +} + +func (c *Client) SyncAccounts(ctx context.Context, serverID string) error { + _ = ctx + _ = serverID + // TODO: call API to fetch account policy + approved access list. + return nil +} + +func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload)) + if err != nil { + return fmt.Errorf("build log request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("send log batch: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return fmt.Errorf("log batch failed: status %s", resp.Status) + } + return nil +} diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go new file mode 100644 index 0000000..c09dc03 --- /dev/null +++ b/agent/internal/config/config.go @@ -0,0 +1,148 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" +) + +const ( + DefaultConfigPath = "/etc/keywarden/agent.json" + DefaultStateDir = "/var/lib/keywarden-agent" + DefaultSyncIntervalSeconds = 30 + DefaultLogBatchSize = 500 + DefaultUsernameTemplate = "{{username}}_{{user_id}}" + DefaultShell = "/bin/bash" + DefaultAdminGroup = "sudo" +) + +type AccountPolicy struct { + UsernameTemplate string `json:"username_template"` + DefaultShell string `json:"default_shell"` + AdminGroup string `json:"admin_group"` + CreateHome bool `json:"create_home"` + LockOnRevoke bool `json:"lock_on_revoke"` +} + +type Config struct { + ServerURL string `json:"server_url"` + ServerID string `json:"server_id,omitempty"` + SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"` + LogBatchSize int `json:"log_batch_size,omitempty"` + StateDir string `json:"state_dir,omitempty"` + AccountPolicy AccountPolicy `json:"account_policy,omitempty"` +} + +func LoadOrInit(path string, serverURL string) (*Config, error) { + if path == "" { + path = DefaultConfigPath + } + data, err := os.ReadFile(path) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("read config: %w", err) + } + if serverURL == "" { + return nil, errors.New("server url required for first boot") + } + cfg := &Config{ServerURL: serverURL} + applyDefaults(cfg) + if err := validate(cfg, false); err != nil { + return nil, err + } + if err := Save(path, cfg); err != nil { + return nil, err + } + return cfg, nil + } + cfg := &Config{} + if err := json.Unmarshal(data, cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + applyDefaults(cfg) + if err := validate(cfg, false); err != nil { + return nil, err + } + return cfg, nil +} + +func Save(path string, cfg *Config) error { + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("encode config: %w", err) + } + if err := os.MkdirAll(dir(path), 0o755); err != nil { + return fmt.Errorf("create config dir: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write config: %w", err) + } + return nil +} + +func applyDefaults(cfg *Config) { + if cfg.SyncIntervalSeconds <= 0 { + cfg.SyncIntervalSeconds = DefaultSyncIntervalSeconds + } + if cfg.LogBatchSize <= 0 { + cfg.LogBatchSize = DefaultLogBatchSize + } + if cfg.StateDir == "" { + cfg.StateDir = DefaultStateDir + } + if cfg.AccountPolicy.UsernameTemplate == "" { + cfg.AccountPolicy.UsernameTemplate = DefaultUsernameTemplate + } + if cfg.AccountPolicy.DefaultShell == "" { + cfg.AccountPolicy.DefaultShell = DefaultShell + } + if cfg.AccountPolicy.AdminGroup == "" { + cfg.AccountPolicy.AdminGroup = DefaultAdminGroup + } +} + +func validate(cfg *Config, requireServerID bool) error { + var missing []string + if cfg.ServerURL == "" { + missing = append(missing, "server_url") + } + if requireServerID && cfg.ServerID == "" { + missing = append(missing, "server_id") + } + if len(missing) > 0 { + return fmt.Errorf("missing required config fields: %v", missing) + } + if cfg.SyncIntervalSeconds < 5 { + return errors.New("sync_interval_seconds must be >= 5") + } + return nil +} + +func (c *Config) ClientCertPath() string { + return c.StateDir + "/agent.crt" +} + +func (c *Config) ClientKeyPath() string { + return c.StateDir + "/agent.key" +} + +func (c *Config) CACertPath() string { + return c.StateDir + "/ca.crt" +} + +func (c *Config) LogCursorPath() string { + return c.StateDir + "/journal.cursor" +} + +func (c *Config) LogSpoolDir() string { + return c.StateDir + "/spool" +} + +func dir(path string) string { + if idx := strings.LastIndex(path, string(os.PathSeparator)); idx != -1 { + return path[:idx] + } + return "." +} diff --git a/agent/internal/logs/collector.go b/agent/internal/logs/collector.go new file mode 100644 index 0000000..1d4e534 --- /dev/null +++ b/agent/internal/logs/collector.go @@ -0,0 +1,177 @@ +package logs + +import ( + "context" + "strings" + "time" + + "github.com/coreos/go-systemd/v22/sdjournal" +) + +const defaultLimit = 500 + +type Collector struct { + matches []string +} + +func NewCollector() *Collector { + return &Collector{matches: defaultMatches()} +} + +func (c *Collector) Collect(ctx context.Context, cursor string, limit int) ([]Event, string, error) { + if limit <= 0 { + limit = defaultLimit + } + j, err := sdjournal.NewJournal() + if err != nil { + return nil, "", err + } + defer j.Close() + + for i, match := range c.matches { + if i > 0 { + if err := j.AddDisjunction(); err != nil { + return nil, "", err + } + } + if err := j.AddMatch(match); err != nil { + return nil, "", err + } + } + + if cursor != "" { + if err := j.SeekCursor(cursor); err == nil { + _, _ = j.Next() + } + } else { + _ = j.SeekTail() + _, _ = j.Next() + } + + var events []Event + var nextCursor string + for len(events) < limit { + select { + case <-ctx.Done(): + return events, nextCursor, ctx.Err() + default: + } + n, err := j.Next() + if err != nil { + return events, nextCursor, err + } + if n == 0 { + break + } + entry, err := j.GetEntry() + if err != nil { + return events, nextCursor, err + } + event := fromEntry(entry) + events = append(events, event) + nextCursor = entry.Cursor + } + + return events, nextCursor, nil +} + +func defaultMatches() []string { + return []string{ + "_SYSTEMD_UNIT=sshd.service", + "_SYSTEMD_UNIT=sudo.service", + "_SYSTEMD_UNIT=systemd-networkd.service", + "_SYSTEMD_UNIT=NetworkManager.service", + "_SYSTEMD_UNIT=systemd-logind.service", + "_TRANSPORT=kernel", + } +} + +func fromEntry(entry *sdjournal.JournalEntry) Event { + ts := time.Unix(0, int64(entry.RealtimeTimestamp)*int64(time.Microsecond)) + event := NewEvent(ts) + fields := entry.Fields + unit := fields["_SYSTEMD_UNIT"] + message := fields["MESSAGE"] + identifier := fields["SYSLOG_IDENTIFIER"] + + event.Unit = unit + event.Message = message + event.Priority = fields["PRIORITY"] + event.Hostname = fields["_HOSTNAME"] + event.Fields = fields + + event.Category = categorize(unit, identifier, fields) + event.EventType, event.Username, event.SourceIP, event.SessionID = parseMessage(event.Category, message) + if event.EventType == "" { + event.EventType = defaultEventType(event.Category) + } + return event +} + +func categorize(unit string, identifier string, fields map[string]string) string { + switch { + case unit == "sshd.service" || identifier == "sshd": + return "access" + case unit == "sudo.service" || identifier == "sudo": + return "auth" + case unit == "systemd-networkd.service" || identifier == "NetworkManager": + return "network" + case fields["_TRANSPORT"] == "kernel": + return "system" + default: + return "system" + } +} + +func defaultEventType(category string) string { + switch category { + case "access": + return "ssh" + case "auth": + return "auth" + case "network": + return "network" + default: + return "system" + } +} + +func parseMessage(category string, msg string) (eventType string, username string, sourceIP string, sessionID string) { + if msg == "" { + return "", "", "", "" + } + lower := strings.ToLower(msg) + if category == "access" { + switch { + case strings.Contains(lower, "accepted"): + eventType = "ssh.login.success" + username = extractBetween(msg, "for ", " from") + sourceIP = extractBetween(msg, "from ", " port") + case strings.Contains(lower, "failed password"): + eventType = "ssh.login.fail" + username = extractBetween(msg, "for ", " from") + sourceIP = extractBetween(msg, "from ", " port") + case strings.Contains(lower, "session opened"): + eventType = "ssh.session.open" + username = extractBetween(msg, "for user ", " by") + case strings.Contains(lower, "session closed"): + eventType = "ssh.session.close" + username = extractBetween(msg, "for user ", " by") + } + } + return eventType, strings.TrimSpace(username), strings.TrimSpace(sourceIP), strings.TrimSpace(sessionID) +} + +func extractBetween(msg string, start string, end string) string { + startIdx := strings.Index(msg, start) + if startIdx == -1 { + return "" + } + startIdx += len(start) + rest := msg[startIdx:] + endIdx := strings.Index(rest, end) + if endIdx == -1 { + return strings.TrimSpace(rest) + } + return strings.TrimSpace(rest[:endIdx]) +} diff --git a/agent/internal/logs/cursor.go b/agent/internal/logs/cursor.go new file mode 100644 index 0000000..acba1f0 --- /dev/null +++ b/agent/internal/logs/cursor.go @@ -0,0 +1,24 @@ +package logs + +import ( + "os" + "strings" +) + +func ReadCursor(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +func WriteCursor(path string, cursor string) error { + if cursor == "" { + return nil + } + return os.WriteFile(path, []byte(cursor+"\n"), 0o600) +} diff --git a/agent/internal/logs/spool.go b/agent/internal/logs/spool.go new file mode 100644 index 0000000..82260b7 --- /dev/null +++ b/agent/internal/logs/spool.go @@ -0,0 +1,53 @@ +package logs + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "time" +) + +func SaveSpool(dir string, payload []byte) error { + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + name := fmt.Sprintf("%d.json", time.Now().UnixNano()) + tmp := filepath.Join(dir, name+".tmp") + final := filepath.Join(dir, name) + if err := os.WriteFile(tmp, payload, 0o600); err != nil { + return err + } + return os.Rename(tmp, final) +} + +func DrainSpool(dir string, send func([]byte) error) error { + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + var files []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + files = append(files, filepath.Join(dir, entry.Name())) + } + sort.Strings(files) + for _, path := range files { + data, err := os.ReadFile(path) + if err != nil { + return err + } + if err := send(data); err != nil { + return err + } + if err := os.Remove(path); err != nil { + return err + } + } + return nil +} diff --git a/agent/internal/logs/types.go b/agent/internal/logs/types.go new file mode 100644 index 0000000..db6e961 --- /dev/null +++ b/agent/internal/logs/types.go @@ -0,0 +1,23 @@ +package logs + +import "time" + +type Event struct { + Timestamp string `json:"timestamp"` + Category string `json:"category"` + EventType string `json:"event_type"` + Unit string `json:"unit,omitempty"` + Priority string `json:"priority,omitempty"` + Hostname string `json:"hostname,omitempty"` + Username string `json:"username,omitempty"` + Principal string `json:"principal,omitempty"` + SourceIP string `json:"source_ip,omitempty"` + SessionID string `json:"session_id,omitempty"` + Message string `json:"message,omitempty"` + Raw string `json:"raw,omitempty"` + Fields map[string]string `json:"fields,omitempty"` +} + +func NewEvent(ts time.Time) Event { + return Event{Timestamp: ts.UTC().Format(time.RFC3339Nano)} +} diff --git a/agent/internal/version/version.go b/agent/internal/version/version.go new file mode 100644 index 0000000..c126426 --- /dev/null +++ b/agent/internal/version/version.go @@ -0,0 +1,7 @@ +package version + +var ( + Version = "0.0.1-dev" + Commit = "" + BuildDate = "" +) diff --git a/agent/keywarden-agent b/agent/keywarden-agent new file mode 100755 index 0000000..db4bcb7 Binary files /dev/null and b/agent/keywarden-agent differ diff --git a/app/apps/servers/admin.py b/app/apps/servers/admin.py index c461aab..839c387 100644 --- a/app/apps/servers/admin.py +++ b/app/apps/servers/admin.py @@ -1,17 +1,27 @@ from django.contrib import admin -from guardian.admin import GuardedModelAdmin from django.utils.html import format_html -from .models import Server +from guardian.admin import GuardedModelAdmin + +from .models import AgentCertificateAuthority, EnrollmentToken, Server @admin.register(Server) class ServerAdmin(GuardedModelAdmin): - list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "created_at") + list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "agent_enrolled_at", "created_at") list_display_links = ("display_name",) search_fields = ("display_name", "hostname", "ipv4", "ipv6") list_filter = ("created_at",) - readonly_fields = ("created_at", "updated_at") - fields = ("display_name", "hostname", "ipv4", "ipv6", "image", "created_at", "updated_at") + readonly_fields = ("created_at", "updated_at", "agent_enrolled_at") + fields = ( + "display_name", + "hostname", + "ipv4", + "ipv6", + "image", + "agent_enrolled_at", + "created_at", + "updated_at", + ) def avatar(self, obj: Server): if obj.image_url: @@ -27,3 +37,52 @@ class ServerAdmin(GuardedModelAdmin): ) avatar.short_description = "" + +@admin.register(EnrollmentToken) +class EnrollmentTokenAdmin(admin.ModelAdmin): + list_display = ("token", "created_at", "expires_at", "used_at", "server") + list_filter = ("created_at", "used_at") + search_fields = ("token", "server__display_name", "server__hostname") + readonly_fields = ("token", "created_at", "used_at", "server", "created_by") + fields = ("token", "expires_at", "created_by", "created_at", "used_at", "server") + + def save_model(self, request, obj, form, change) -> None: + if not obj.pk: + obj.ensure_token() + if request.user and request.user.is_authenticated and not obj.created_by_id: + obj.created_by = request.user + super().save_model(request, obj, form, change) + + +@admin.register(AgentCertificateAuthority) +class AgentCertificateAuthorityAdmin(admin.ModelAdmin): + list_display = ("name", "is_active", "created_at", "revoked_at") + list_filter = ("is_active", "created_at", "revoked_at") + search_fields = ("name", "fingerprint") + readonly_fields = ("fingerprint", "serial", "created_at", "revoked_at", "created_by") + fields = ( + "name", + "is_active", + "cert_pem", + "key_pem", + "fingerprint", + "serial", + "created_by", + "created_at", + "revoked_at", + ) + actions = ["revoke_selected"] + + def save_model(self, request, obj, form, change) -> None: + if request.user and request.user.is_authenticated and not obj.created_by_id: + obj.created_by = request.user + obj.ensure_material() + if obj.is_active: + AgentCertificateAuthority.objects.exclude(pk=obj.pk).update(is_active=False) + super().save_model(request, obj, form, change) + + @admin.action(description="Revoke selected CAs") + def revoke_selected(self, request, queryset): + for ca in queryset: + ca.revoke() + ca.save(update_fields=["is_active", "revoked_at"]) diff --git a/app/apps/servers/migrations/0002_agent_enrollment.py b/app/apps/servers/migrations/0002_agent_enrollment.py new file mode 100644 index 0000000..9979d77 --- /dev/null +++ b/app/apps/servers/migrations/0002_agent_enrollment.py @@ -0,0 +1,73 @@ +from django.conf import settings +from django.db import migrations, models +import django.utils.timezone +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("servers", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name="server", + name="agent_cert_fingerprint", + field=models.CharField(blank=True, max_length=128, null=True), + ), + migrations.AddField( + model_name="server", + name="agent_cert_serial", + field=models.CharField(blank=True, max_length=64, null=True), + ), + migrations.AddField( + model_name="server", + name="agent_enrolled_at", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.CreateModel( + name="EnrollmentToken", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("token", models.CharField(max_length=128, unique=True)), + ("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("expires_at", models.DateTimeField(blank=True, null=True)), + ("used_at", models.DateTimeField(blank=True, null=True)), + ( + "created_by", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="server_enrollment_tokens", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "server", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="enrollment_tokens", + to="servers.server", + ), + ), + ], + options={ + "verbose_name": "Enrollment token", + "verbose_name_plural": "Enrollment tokens", + "ordering": ["-created_at"], + }, + ), + migrations.AddIndex( + model_name="enrollmenttoken", + index=models.Index(fields=["created_at"], name="servers_enroll_created_idx"), + ), + migrations.AddIndex( + model_name="enrollmenttoken", + index=models.Index(fields=["used_at"], name="servers_enroll_used_idx"), + ), + ] diff --git a/app/apps/servers/migrations/0003_agent_ca.py b/app/apps/servers/migrations/0003_agent_ca.py new file mode 100644 index 0000000..5f4ecc9 --- /dev/null +++ b/app/apps/servers/migrations/0003_agent_ca.py @@ -0,0 +1,44 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + + dependencies = [ + ("servers", "0002_agent_enrollment"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="AgentCertificateAuthority", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(default="Keywarden Agent CA", max_length=128)), + ("cert_pem", models.TextField()), + ("key_pem", models.TextField()), + ("fingerprint", models.CharField(blank=True, max_length=128)), + ("serial", models.CharField(blank=True, max_length=64)), + ("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("revoked_at", models.DateTimeField(blank=True, null=True)), + ("is_active", models.BooleanField(db_index=True, default=True)), + ( + "created_by", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="agent_certificate_authorities", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Agent certificate authority", + "verbose_name_plural": "Agent certificate authorities", + "ordering": ["-created_at"], + }, + ), + ] diff --git a/app/apps/servers/models.py b/app/apps/servers/models.py index 4ce6361..1efe9ed 100644 --- a/app/apps/servers/models.py +++ b/app/apps/servers/models.py @@ -1,8 +1,16 @@ from __future__ import annotations +import secrets +from datetime import datetime, timedelta + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from django.conf import settings from django.core.validators import RegexValidator from django.db import models -from django.utils.text import slugify +from django.utils import timezone hostname_validator = RegexValidator( @@ -17,6 +25,9 @@ class Server(models.Model): ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True) ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True) image = models.ImageField(upload_to="servers/", null=True, blank=True) + agent_enrolled_at = models.DateTimeField(null=True, blank=True) + agent_cert_fingerprint = models.CharField(max_length=128, null=True, blank=True) + agent_cert_serial = models.CharField(max_length=64, null=True, blank=True) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -41,3 +52,108 @@ class Server(models.Model): return (self.display_name or "?").strip()[:1].upper() or "?" +class EnrollmentToken(models.Model): + token = models.CharField(max_length=128, unique=True) + created_at = models.DateTimeField(default=timezone.now, editable=False) + expires_at = models.DateTimeField(null=True, blank=True) + created_by = models.ForeignKey( + settings.AUTH_USER_MODEL, + null=True, + blank=True, + on_delete=models.SET_NULL, + related_name="server_enrollment_tokens", + ) + used_at = models.DateTimeField(null=True, blank=True) + server = models.ForeignKey( + Server, null=True, blank=True, on_delete=models.SET_NULL, related_name="enrollment_tokens" + ) + + class Meta: + verbose_name = "Enrollment token" + verbose_name_plural = "Enrollment tokens" + indexes = [ + models.Index(fields=["created_at"], name="servers_enroll_created_idx"), + models.Index(fields=["used_at"], name="servers_enroll_used_idx"), + ] + ordering = ["-created_at"] + + def __str__(self) -> str: + return f"{self.token[:8]}... ({'used' if self.used_at else 'unused'})" + + def ensure_token(self) -> None: + if not self.token: + self.token = secrets.token_urlsafe(32) + + def is_valid(self) -> bool: + if self.used_at: + return False + if self.expires_at and self.expires_at <= timezone.now(): + return False + return True + + def mark_used(self, server: Server) -> None: + self.used_at = timezone.now() + self.server = server + + def save(self, *args, **kwargs): + self.ensure_token() + super().save(*args, **kwargs) + + +class AgentCertificateAuthority(models.Model): + name = models.CharField(max_length=128, default="Keywarden Agent CA") + cert_pem = models.TextField() + key_pem = models.TextField() + fingerprint = models.CharField(max_length=128, blank=True) + serial = models.CharField(max_length=64, blank=True) + created_at = models.DateTimeField(default=timezone.now, editable=False) + revoked_at = models.DateTimeField(null=True, blank=True) + is_active = models.BooleanField(default=True, db_index=True) + created_by = models.ForeignKey( + settings.AUTH_USER_MODEL, + null=True, + blank=True, + on_delete=models.SET_NULL, + related_name="agent_certificate_authorities", + ) + + class Meta: + verbose_name = "Agent certificate authority" + verbose_name_plural = "Agent certificate authorities" + ordering = ["-created_at"] + + def __str__(self) -> str: + status = "active" if self.is_active and not self.revoked_at else "revoked" + return f"{self.name} ({status})" + + def revoke(self) -> None: + self.is_active = False + self.revoked_at = timezone.now() + + def ensure_material(self) -> None: + if self.cert_pem and self.key_pem: + return + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.name)]) + now = datetime.utcnow() + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(minutes=5)) + .not_valid_after(now + timedelta(days=3650)) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .sign(key, hashes.SHA256()) + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + self.cert_pem = cert_pem + self.key_pem = key_pem + self.fingerprint = cert.fingerprint(hashes.SHA256()).hex() + self.serial = format(cert.serial_number, "x") diff --git a/app/keywarden/api/routers/access.py b/app/keywarden/api/routers/access.py index 82788f0..b06cdfe 100644 --- a/app/keywarden/api/routers/access.py +++ b/app/keywarden/api/routers/access.py @@ -70,7 +70,14 @@ def build_router() -> Router: @router.get("/", response=List[AccessRequestOut]) def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)): - """List access requests for the user, or all if admin/operator.""" + """List access requests with pagination and filters. + + Auth: required. + Permissions: + - If user has global `access.view_accessrequest`, returns all requests. + - Otherwise, returns only objects with `access.view_accessrequest` object permission. + Filters: status, server_id, requester_id (requester_id is honored only with global view). + """ require_authenticated(request) user = request.user if _has_global_perm(request, "access.view_accessrequest"): @@ -94,7 +101,12 @@ def build_router() -> Router: @router.post("/", response=AccessRequestOut) def create_request(request: HttpRequest, payload: AccessRequestCreateIn): - """Create a new access request for a server.""" + """Create a new access request for the current user. + + Auth: required. + Permissions: requires global `access.add_accessrequest`. + Side effects: grants owner object perms on the new request. + """ require_authenticated(request) if not request.user.has_perm("access.add_accessrequest"): raise HttpError(403, "Forbidden") @@ -116,7 +128,11 @@ def build_router() -> Router: @router.get("/{request_id}", response=AccessRequestOut) def get_request(request: HttpRequest, request_id: int): - """Get an access request if permitted.""" + """Get a single access request by id. + + Auth: required. + Permissions: requires `access.view_accessrequest` on the object. + """ require_authenticated(request) try: access_request = AccessRequest.objects.get(id=request_id) @@ -128,7 +144,15 @@ def build_router() -> Router: @router.patch("/{request_id}", response=AccessRequestOut) def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn): - """Update request status or expiry (admin/operator or owner with restrictions).""" + """Update request status or expiry. + + Auth: required. + Permissions: requires `access.change_accessrequest` on the object. + Rules: + - Admin/operator (global change) can set status to approved/denied/revoked/cancelled and + update expires_at. + - Non-admin can only set status to cancelled, and only while pending. + """ require_authenticated(request) try: access_request = AccessRequest.objects.get(id=request_id) @@ -171,7 +195,11 @@ def build_router() -> Router: @router.delete("/{request_id}", response={204: None}) def delete_request(request: HttpRequest, request_id: int): - """Delete an access request if permitted.""" + """Delete an access request. + + Auth: required. + Permissions: requires `access.delete_accessrequest` on the object. + """ require_authenticated(request) try: access_request = AccessRequest.objects.get(id=request_id) diff --git a/app/keywarden/api/routers/agent.py b/app/keywarden/api/routers/agent.py index 39c24db..7ba0115 100644 --- a/app/keywarden/api/routers/agent.py +++ b/app/keywarden/api/routers/agent.py @@ -1,7 +1,13 @@ from __future__ import annotations +from datetime import datetime, timedelta from typing import List, Optional +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID +from django.conf import settings +from django.core.exceptions import ValidationError from django.db import models from django.http import HttpRequest from django.utils import timezone @@ -12,7 +18,7 @@ from pydantic import Field from apps.core.rbac import require_perms from apps.access.models import AccessRequest from apps.keys.models import SSHKey -from apps.servers.models import Server +from apps.servers.models import AgentCertificateAuthority, EnrollmentToken, Server, hostname_validator from apps.telemetry.models import TelemetryEvent @@ -35,9 +41,82 @@ class SyncReportOut(Schema): status: str +class AgentEnrollIn(Schema): + token: str + csr_pem: str + host: Optional[str] = None + + +class AgentEnrollOut(Schema): + server_id: str + client_cert_pem: str + ca_cert_pem: str + + +class LogEventIn(Schema): + timestamp: str + category: str + event_type: str + unit: Optional[str] = None + priority: Optional[str] = None + hostname: Optional[str] = None + username: Optional[str] = None + principal: Optional[str] = None + source_ip: Optional[str] = None + session_id: Optional[str] = None + message: Optional[str] = None + raw: Optional[str] = None + fields: Optional[dict] = None + + +class LogIngestOut(Schema): + status: str + accepted: int + + def build_router() -> Router: router = Router() + @router.post("/enroll", response=AgentEnrollOut, auth=None) + def enroll_agent(request: HttpRequest, payload: AgentEnrollIn): + """Enroll a server agent using a one-time token.""" + token_value = (payload.token or "").strip() + if not token_value: + raise HttpError(422, "Token required") + try: + token = EnrollmentToken.objects.get(token=token_value) + except EnrollmentToken.DoesNotExist: + raise HttpError(403, "Invalid token") + if not token.is_valid(): + raise HttpError(403, "Token expired or already used") + + host = (payload.host or "").strip()[:253] + display_name = host or "server" + hostname = None + if host: + try: + hostname_validator(host) + hostname = host + except ValidationError: + hostname = None + + server = Server.objects.create(display_name=display_name, hostname=hostname) + token.mark_used(server) + token.save(update_fields=["used_at", "server"]) + + csr = _load_csr((payload.csr_pem or "").strip()) + cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id) + server.agent_enrolled_at = timezone.now() + server.agent_cert_fingerprint = fingerprint + server.agent_cert_serial = serial + server.save(update_fields=["agent_enrolled_at", "agent_cert_fingerprint", "agent_cert_serial"]) + + return AgentEnrollOut( + server_id=str(server.id), + client_cert_pem=cert_pem, + ca_cert_pem=ca_pem, + ) + @router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut]) def authorized_keys(request: HttpRequest, server_id: int): """Return authorized public keys for a server (admin or operator).""" @@ -96,7 +175,75 @@ def build_router() -> Router: ) return SyncReportOut(status="ok") + @router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None) + def ingest_logs(request: HttpRequest, server_id: int, payload: List[LogEventIn]): + """Accept log batches from agents (mTLS required at the edge).""" + try: + Server.objects.get(id=server_id) + except Server.DoesNotExist: + raise HttpError(404, "Server not found") + # TODO: enqueue to Valkey and persist to SQLite slices. + return LogIngestOut(status="accepted", accepted=len(payload)) + return router +def _load_agent_ca() -> tuple[x509.Certificate, object, str]: + ca = ( + AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True) + .order_by("-created_at") + .first() + ) + if not ca: + raise HttpError(500, "Agent CA not configured") + try: + ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8")) + ca_key = serialization.load_pem_private_key(ca.key_pem.encode("utf-8"), password=None) + except (ValueError, TypeError): + raise HttpError(500, "Invalid agent CA material") + return ca_cert, ca_key, ca.cert_pem + + +def _load_csr(csr_pem: str) -> x509.CertificateSigningRequest: + try: + csr = x509.load_pem_x509_csr(csr_pem.encode("utf-8")) + except ValueError: + raise HttpError(422, "Invalid CSR") + if not csr.is_signature_valid: + raise HttpError(422, "Invalid CSR signature") + return csr + + +def _issue_client_cert( + csr: x509.CertificateSigningRequest, host: str | None, server_id: int +) -> tuple[str, str, str, str]: + ca_cert, ca_key, ca_pem = _load_agent_ca() + now = datetime.utcnow() + subject = csr.subject + if len(subject) == 0: + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, f"keywarden-agent-{server_id}")]) + builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(ca_cert.subject) + .public_key(csr.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(minutes=5)) + .not_valid_after(now + timedelta(days=settings.KEYWARDEN_AGENT_CERT_VALIDITY_DAYS)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=False) + ) + if host: + try: + hostname_validator(host) + builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(host)]), critical=False) + except ValidationError: + pass + cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256()) + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + fingerprint = cert.fingerprint(hashes.SHA256()).hex() + serial = format(cert.serial_number, "x") + return cert_pem, ca_pem, fingerprint, serial + + router = build_router() diff --git a/app/keywarden/api/routers/keys.py b/app/keywarden/api/routers/keys.py index b9011e1..bba28a8 100644 --- a/app/keywarden/api/routers/keys.py +++ b/app/keywarden/api/routers/keys.py @@ -69,7 +69,14 @@ def build_router() -> Router: @router.get("/", response=List[KeyOut]) def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)): - """List SSH keys for the current user, or any user if admin/operator.""" + """List SSH keys with pagination and filters. + + Auth: required. + Permissions: + - If user has global `keys.view_sshkey`, returns all keys. + - Otherwise, returns only objects with `keys.view_sshkey` object permission. + Filter: user_id (honored only with global view). + """ require_authenticated(request) user = request.user if _has_global_perm(request, "keys.view_sshkey"): @@ -89,7 +96,15 @@ def build_router() -> Router: @router.post("/", response=KeyOut) def create_key(request: HttpRequest, payload: KeyCreateIn): - """Create an SSH public key for the current user (admin/operator can specify user_id).""" + """Create an SSH public key. + + Auth: required. + Permissions: requires global `keys.add_sshkey`. + Rules: + - Default owner is the current user. + - If caller has global `keys.add_sshkey` and `keys.view_sshkey`, they may specify user_id. + Side effects: grants owner object perms on the new key. + """ require_authenticated(request) if not request.user.has_perm("keys.add_sshkey"): raise HttpError(403, "Forbidden") @@ -121,7 +136,11 @@ def build_router() -> Router: @router.get("/{key_id}", response=KeyOut) def get_key(request: HttpRequest, key_id: int): - """Get a specific SSH key if permitted.""" + """Get a specific SSH key by id. + + Auth: required. + Permissions: requires `keys.view_sshkey` on the object. + """ require_authenticated(request) try: key = SSHKey.objects.get(id=key_id) @@ -133,7 +152,11 @@ def build_router() -> Router: @router.patch("/{key_id}", response=KeyOut) def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn): - """Update key name or active state if permitted.""" + """Update key name or active state. + + Auth: required. + Permissions: requires `keys.change_sshkey` on the object. + """ require_authenticated(request) try: key = SSHKey.objects.get(id=key_id) @@ -159,7 +182,12 @@ def build_router() -> Router: @router.delete("/{key_id}", response={204: None}) def delete_key(request: HttpRequest, key_id: int): - """Revoke an SSH key if permitted (soft delete).""" + """Revoke (soft delete) an SSH key. + + Auth: required. + Permissions: requires `keys.delete_sshkey` on the object. + Behavior: sets is_active false and revoked_at if key is active. + """ require_authenticated(request) try: key = SSHKey.objects.get(id=key_id) diff --git a/app/keywarden/api/routers/servers.py b/app/keywarden/api/routers/servers.py index ddb5374..47348d7 100644 --- a/app/keywarden/api/routers/servers.py +++ b/app/keywarden/api/routers/servers.py @@ -78,21 +78,7 @@ def build_router() -> Router: def create_server_json(request: HttpRequest, payload: ServerCreate): """Create a server using JSON payload (admin only).""" require_perms(request, "servers.add_server") - server = Server.objects.create( - display_name=payload.display_name.strip(), - hostname=(payload.hostname or "").strip() or None, - ipv4=(payload.ipv4 or "").strip() or None, - ipv6=(payload.ipv6 or "").strip() or None, - ) - return { - "id": server.id, - "display_name": server.display_name, - "hostname": server.hostname, - "ipv4": server.ipv4, - "ipv6": server.ipv6, - "image_url": server.image_url, - "initial": server.initial, - } + raise HttpError(403, "Servers are created via agent enrollment tokens.") @router.post("/upload", response=ServerOut) def create_server_multipart( @@ -105,24 +91,7 @@ def build_router() -> Router: ): """Create a server with optional image upload (admin only).""" require_perms(request, "servers.add_server") - server = Server( - display_name=display_name.strip(), - hostname=(hostname or "").strip() or None, - ipv4=(ipv4 or "").strip() or None, - ipv6=(ipv6 or "").strip() or None, - ) - if image: - server.image.save(image.name, image) # type: ignore[arg-type] - server.save() - return { - "id": server.id, - "display_name": server.display_name, - "hostname": server.hostname, - "ipv4": server.ipv4, - "ipv6": server.ipv6, - "image_url": server.image_url, - "initial": server.initial, - } + raise HttpError(403, "Servers are created via agent enrollment tokens.") @router.patch("/{server_id}", response=ServerOut) def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate): diff --git a/app/keywarden/settings/base.py b/app/keywarden/settings/base.py index e08337a..c0524c2 100644 --- a/app/keywarden/settings/base.py +++ b/app/keywarden/settings/base.py @@ -94,6 +94,8 @@ CACHES = { SESSION_ENGINE = "django.contrib.sessions.backends.cache" SESSION_CACHE_ALIAS = "default" +KEYWARDEN_AGENT_CERT_VALIDITY_DAYS = int(os.getenv("KEYWARDEN_AGENT_CERT_VALIDITY_DAYS", "90")) + PASSWORD_HASHERS = [ "django.contrib.auth.hashers.Argon2PasswordHasher", "django.contrib.auth.hashers.PBKDF2PasswordHasher",