Initial linux agent and api functionality for enrolling servers
This commit is contained in:
@@ -21,6 +21,7 @@ KEYWARDEN_ADMIN_PASSWORD=password
|
|||||||
# Auth mode: native | oidc | hybrid
|
# Auth mode: native | oidc | hybrid
|
||||||
KEYWARDEN_AUTH_MODE=native
|
KEYWARDEN_AUTH_MODE=native
|
||||||
|
|
||||||
|
|
||||||
# OIDC (optional)
|
# OIDC (optional)
|
||||||
# KEYWARDEN_OIDC_CLIENT_ID=
|
# KEYWARDEN_OIDC_CLIENT_ID=
|
||||||
# KEYWARDEN_OIDC_CLIENT_SECRET=
|
# KEYWARDEN_OIDC_CLIENT_SECRET=
|
||||||
|
|||||||
23
agent/README.md
Normal file
23
agent/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# keywarden-agent
|
||||||
|
|
||||||
|
Minimal Go agent scaffold for Keywarden.
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
```
|
||||||
|
go build -o keywarden-agent ./cmd/keywarden-agent
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
```
|
||||||
|
./keywarden-agent -config /etc/keywarden/agent.json -server-url https://keywarden.example.com -enroll-token <token>
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environment variables.
|
||||||
|
|
||||||
|
## Config
|
||||||
|
|
||||||
|
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
|
||||||
|
|
||||||
|
See `config.example.json`.
|
||||||
223
agent/cmd/keywarden-agent/main.go
Normal file
223
agent/cmd/keywarden-agent/main.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keywarden/agent/internal/client"
|
||||||
|
"keywarden/agent/internal/config"
|
||||||
|
"keywarden/agent/internal/logs"
|
||||||
|
"keywarden/agent/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
configPath := flag.String("config", config.DefaultConfigPath, "Path to agent config JSON")
|
||||||
|
serverURL := flag.String("server-url", "", "Keywarden server URL (first boot)")
|
||||||
|
enrollToken := flag.String("enroll-token", "", "Enrollment token (first boot)")
|
||||||
|
showVersion := flag.Bool("version", false, "Print version and exit")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *showVersion {
|
||||||
|
fmt.Printf("keywarden-agent %s (commit %s, built %s)\n", version.Version, version.Commit, version.BuildDate)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := config.LoadOrInit(*configPath, pickServerURL(*serverURL))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("config error: %v", err)
|
||||||
|
}
|
||||||
|
if err := ensureDirs(cfg); err != nil {
|
||||||
|
log.Fatalf("state dir error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bootstrapIfNeeded(cfg, *configPath, pickEnrollToken(*enrollToken)); err != nil {
|
||||||
|
log.Fatalf("bootstrap error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiClient, err := client.New(cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("client error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
interval := time.Duration(cfg.SyncIntervalSeconds) * time.Second
|
||||||
|
log.Printf("keywarden-agent started: server_id=%s interval=%s", cfg.ServerID, interval)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
runOnce(ctx, apiClient, cfg)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("shutdown requested")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
runOnce(ctx, apiClient, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
|
||||||
|
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
||||||
|
log.Printf("sync accounts error: %v", err)
|
||||||
|
}
|
||||||
|
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
||||||
|
log.Printf("log shipping error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureDirs(cfg *config.Config) error {
|
||||||
|
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(cfg.LogSpoolDir(), 0o700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||||
|
send := func(payload []byte) error {
|
||||||
|
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
||||||
|
}
|
||||||
|
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor, err := logs.ReadCursor(cfg.LogCursorPath())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
collector := logs.NewCollector()
|
||||||
|
events, nextCursor, err := collector.Collect(ctx, cursor, cfg.LogBatchSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(events) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(events)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := send(payload); err != nil {
|
||||||
|
if spoolErr := logs.SaveSpool(cfg.LogSpoolDir(), payload); spoolErr != nil {
|
||||||
|
return spoolErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := logs.WriteCursor(cfg.LogCursorPath(), nextCursor); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickServerURL(flagValue string) string {
|
||||||
|
if flagValue != "" {
|
||||||
|
return flagValue
|
||||||
|
}
|
||||||
|
return os.Getenv("KEYWARDEN_SERVER_URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickEnrollToken(flagValue string) string {
|
||||||
|
if flagValue != "" {
|
||||||
|
return flagValue
|
||||||
|
}
|
||||||
|
return os.Getenv("KEYWARDEN_ENROLL_TOKEN")
|
||||||
|
}
|
||||||
|
|
||||||
|
func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string) error {
|
||||||
|
if cfg.ServerID != "" && fileExists(cfg.ClientCertPath()) && fileExists(cfg.CACertPath()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if enrollToken == "" {
|
||||||
|
return fmt.Errorf("missing enrollment token; set KEYWARDEN_ENROLL_TOKEN or -enroll-token")
|
||||||
|
}
|
||||||
|
keyPath := cfg.ClientKeyPath()
|
||||||
|
if !fileExists(keyPath) {
|
||||||
|
if err := generateKey(keyPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
csrPEM, err := buildCSR(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
|
||||||
|
Token: enrollToken,
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
Host: hostname,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(cfg.ClientCertPath(), []byte(resp.ClientCert), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(cfg.CACertPath(), []byte(resp.CACert), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cfg.ServerID = resp.ServerID
|
||||||
|
if err := config.Save(configPath, cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateKey(path string) error {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
keyDER := x509.MarshalPKCS1PrivateKey(key)
|
||||||
|
block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}
|
||||||
|
data := pem.EncodeToMemory(block)
|
||||||
|
return os.WriteFile(path, data, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCSR(keyPath string) (string, error) {
|
||||||
|
keyData, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(keyData)
|
||||||
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||||
|
return "", fmt.Errorf("invalid private key")
|
||||||
|
}
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
csrTemplate := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "keywarden-agent"}}
|
||||||
|
csrDER, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
csrBlock := &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER}
|
||||||
|
return string(pem.EncodeToMemory(csrBlock)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(path string) bool {
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !info.IsDir()
|
||||||
|
}
|
||||||
14
agent/config.example.json
Normal file
14
agent/config.example.json
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"server_url": "https://keywarden.example.com",
|
||||||
|
"server_id": "",
|
||||||
|
"sync_interval_seconds": 30,
|
||||||
|
"log_batch_size": 500,
|
||||||
|
"state_dir": "/var/lib/keywarden-agent",
|
||||||
|
"account_policy": {
|
||||||
|
"username_template": "{{username}}_{{user_id}}",
|
||||||
|
"default_shell": "/bin/bash",
|
||||||
|
"admin_group": "sudo",
|
||||||
|
"create_home": true,
|
||||||
|
"lock_on_revoke": true
|
||||||
|
}
|
||||||
|
}
|
||||||
7
agent/go.mod
Normal file
7
agent/go.mod
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
module keywarden/agent
|
||||||
|
|
||||||
|
go 1.22
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0
|
||||||
|
)
|
||||||
3
agent/go.sum
Normal file
3
agent/go.sum
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
132
agent/internal/client/client.go
Normal file
132
agent/internal/client/client.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keywarden/agent/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
baseURL string
|
||||||
|
http *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg *config.Config) (*Client, error) {
|
||||||
|
baseURL := strings.TrimRight(cfg.ServerURL, "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
return nil, errors.New("server url is required")
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load client cert: %w", err)
|
||||||
|
}
|
||||||
|
caData, err := os.ReadFile(cfg.CACertPath())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read ca cert: %w", err)
|
||||||
|
}
|
||||||
|
caPool := x509.NewCertPool()
|
||||||
|
if !caPool.AppendCertsFromPEM(caData) {
|
||||||
|
return nil, errors.New("parse ca cert")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: caPool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{baseURL: baseURL, http: httpClient}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnrollRequest struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
CSRPEM string `json:"csr_pem"`
|
||||||
|
Host string `json:"host"`
|
||||||
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnrollResponse struct {
|
||||||
|
ServerID string `json:"server_id"`
|
||||||
|
ClientCert string `json:"client_cert_pem"`
|
||||||
|
CACert string `json:"ca_cert_pem"`
|
||||||
|
SyncProfile string `json:"sync_profile,omitempty"`
|
||||||
|
DisplayName string `json:"display_name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
|
||||||
|
baseURL := strings.TrimRight(serverURL, "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
return nil, errors.New("server url is required")
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encode enroll request: %w", err)
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{Timeout: defaultTimeout}
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build enroll request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("enroll request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("enroll failed: status %s", resp.Status)
|
||||||
|
}
|
||||||
|
var out EnrollResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode enroll response: %w", err)
|
||||||
|
}
|
||||||
|
if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" {
|
||||||
|
return nil, errors.New("enroll response missing required fields")
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SyncAccounts(ctx context.Context, serverID string) error {
|
||||||
|
_ = ctx
|
||||||
|
_ = serverID
|
||||||
|
// TODO: call API to fetch account policy + approved access list.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build log request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("send log batch: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("log batch failed: status %s", resp.Status)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
148
agent/internal/config/config.go
Normal file
148
agent/internal/config/config.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultConfigPath = "/etc/keywarden/agent.json"
|
||||||
|
DefaultStateDir = "/var/lib/keywarden-agent"
|
||||||
|
DefaultSyncIntervalSeconds = 30
|
||||||
|
DefaultLogBatchSize = 500
|
||||||
|
DefaultUsernameTemplate = "{{username}}_{{user_id}}"
|
||||||
|
DefaultShell = "/bin/bash"
|
||||||
|
DefaultAdminGroup = "sudo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccountPolicy struct {
|
||||||
|
UsernameTemplate string `json:"username_template"`
|
||||||
|
DefaultShell string `json:"default_shell"`
|
||||||
|
AdminGroup string `json:"admin_group"`
|
||||||
|
CreateHome bool `json:"create_home"`
|
||||||
|
LockOnRevoke bool `json:"lock_on_revoke"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
ServerURL string `json:"server_url"`
|
||||||
|
ServerID string `json:"server_id,omitempty"`
|
||||||
|
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
|
||||||
|
LogBatchSize int `json:"log_batch_size,omitempty"`
|
||||||
|
StateDir string `json:"state_dir,omitempty"`
|
||||||
|
AccountPolicy AccountPolicy `json:"account_policy,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOrInit(path string, serverURL string) (*Config, error) {
|
||||||
|
if path == "" {
|
||||||
|
path = DefaultConfigPath
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("read config: %w", err)
|
||||||
|
}
|
||||||
|
if serverURL == "" {
|
||||||
|
return nil, errors.New("server url required for first boot")
|
||||||
|
}
|
||||||
|
cfg := &Config{ServerURL: serverURL}
|
||||||
|
applyDefaults(cfg)
|
||||||
|
if err := validate(cfg, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := Save(path, cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
cfg := &Config{}
|
||||||
|
if err := json.Unmarshal(data, cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
applyDefaults(cfg)
|
||||||
|
if err := validate(cfg, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Save(path string, cfg *Config) error {
|
||||||
|
data, err := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encode config: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(dir(path), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("create config dir: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write config: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDefaults(cfg *Config) {
|
||||||
|
if cfg.SyncIntervalSeconds <= 0 {
|
||||||
|
cfg.SyncIntervalSeconds = DefaultSyncIntervalSeconds
|
||||||
|
}
|
||||||
|
if cfg.LogBatchSize <= 0 {
|
||||||
|
cfg.LogBatchSize = DefaultLogBatchSize
|
||||||
|
}
|
||||||
|
if cfg.StateDir == "" {
|
||||||
|
cfg.StateDir = DefaultStateDir
|
||||||
|
}
|
||||||
|
if cfg.AccountPolicy.UsernameTemplate == "" {
|
||||||
|
cfg.AccountPolicy.UsernameTemplate = DefaultUsernameTemplate
|
||||||
|
}
|
||||||
|
if cfg.AccountPolicy.DefaultShell == "" {
|
||||||
|
cfg.AccountPolicy.DefaultShell = DefaultShell
|
||||||
|
}
|
||||||
|
if cfg.AccountPolicy.AdminGroup == "" {
|
||||||
|
cfg.AccountPolicy.AdminGroup = DefaultAdminGroup
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validate(cfg *Config, requireServerID bool) error {
|
||||||
|
var missing []string
|
||||||
|
if cfg.ServerURL == "" {
|
||||||
|
missing = append(missing, "server_url")
|
||||||
|
}
|
||||||
|
if requireServerID && cfg.ServerID == "" {
|
||||||
|
missing = append(missing, "server_id")
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return fmt.Errorf("missing required config fields: %v", missing)
|
||||||
|
}
|
||||||
|
if cfg.SyncIntervalSeconds < 5 {
|
||||||
|
return errors.New("sync_interval_seconds must be >= 5")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ClientCertPath() string {
|
||||||
|
return c.StateDir + "/agent.crt"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ClientKeyPath() string {
|
||||||
|
return c.StateDir + "/agent.key"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) CACertPath() string {
|
||||||
|
return c.StateDir + "/ca.crt"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) LogCursorPath() string {
|
||||||
|
return c.StateDir + "/journal.cursor"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) LogSpoolDir() string {
|
||||||
|
return c.StateDir + "/spool"
|
||||||
|
}
|
||||||
|
|
||||||
|
func dir(path string) string {
|
||||||
|
if idx := strings.LastIndex(path, string(os.PathSeparator)); idx != -1 {
|
||||||
|
return path[:idx]
|
||||||
|
}
|
||||||
|
return "."
|
||||||
|
}
|
||||||
177
agent/internal/logs/collector.go
Normal file
177
agent/internal/logs/collector.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-systemd/v22/sdjournal"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultLimit = 500
|
||||||
|
|
||||||
|
type Collector struct {
|
||||||
|
matches []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCollector() *Collector {
|
||||||
|
return &Collector{matches: defaultMatches()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Collector) Collect(ctx context.Context, cursor string, limit int) ([]Event, string, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = defaultLimit
|
||||||
|
}
|
||||||
|
j, err := sdjournal.NewJournal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
defer j.Close()
|
||||||
|
|
||||||
|
for i, match := range c.matches {
|
||||||
|
if i > 0 {
|
||||||
|
if err := j.AddDisjunction(); err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := j.AddMatch(match); err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cursor != "" {
|
||||||
|
if err := j.SeekCursor(cursor); err == nil {
|
||||||
|
_, _ = j.Next()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_ = j.SeekTail()
|
||||||
|
_, _ = j.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []Event
|
||||||
|
var nextCursor string
|
||||||
|
for len(events) < limit {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return events, nextCursor, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
n, err := j.Next()
|
||||||
|
if err != nil {
|
||||||
|
return events, nextCursor, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
entry, err := j.GetEntry()
|
||||||
|
if err != nil {
|
||||||
|
return events, nextCursor, err
|
||||||
|
}
|
||||||
|
event := fromEntry(entry)
|
||||||
|
events = append(events, event)
|
||||||
|
nextCursor = entry.Cursor
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, nextCursor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultMatches() []string {
|
||||||
|
return []string{
|
||||||
|
"_SYSTEMD_UNIT=sshd.service",
|
||||||
|
"_SYSTEMD_UNIT=sudo.service",
|
||||||
|
"_SYSTEMD_UNIT=systemd-networkd.service",
|
||||||
|
"_SYSTEMD_UNIT=NetworkManager.service",
|
||||||
|
"_SYSTEMD_UNIT=systemd-logind.service",
|
||||||
|
"_TRANSPORT=kernel",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromEntry(entry *sdjournal.JournalEntry) Event {
|
||||||
|
ts := time.Unix(0, int64(entry.RealtimeTimestamp)*int64(time.Microsecond))
|
||||||
|
event := NewEvent(ts)
|
||||||
|
fields := entry.Fields
|
||||||
|
unit := fields["_SYSTEMD_UNIT"]
|
||||||
|
message := fields["MESSAGE"]
|
||||||
|
identifier := fields["SYSLOG_IDENTIFIER"]
|
||||||
|
|
||||||
|
event.Unit = unit
|
||||||
|
event.Message = message
|
||||||
|
event.Priority = fields["PRIORITY"]
|
||||||
|
event.Hostname = fields["_HOSTNAME"]
|
||||||
|
event.Fields = fields
|
||||||
|
|
||||||
|
event.Category = categorize(unit, identifier, fields)
|
||||||
|
event.EventType, event.Username, event.SourceIP, event.SessionID = parseMessage(event.Category, message)
|
||||||
|
if event.EventType == "" {
|
||||||
|
event.EventType = defaultEventType(event.Category)
|
||||||
|
}
|
||||||
|
return event
|
||||||
|
}
|
||||||
|
|
||||||
|
func categorize(unit string, identifier string, fields map[string]string) string {
|
||||||
|
switch {
|
||||||
|
case unit == "sshd.service" || identifier == "sshd":
|
||||||
|
return "access"
|
||||||
|
case unit == "sudo.service" || identifier == "sudo":
|
||||||
|
return "auth"
|
||||||
|
case unit == "systemd-networkd.service" || identifier == "NetworkManager":
|
||||||
|
return "network"
|
||||||
|
case fields["_TRANSPORT"] == "kernel":
|
||||||
|
return "system"
|
||||||
|
default:
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultEventType(category string) string {
|
||||||
|
switch category {
|
||||||
|
case "access":
|
||||||
|
return "ssh"
|
||||||
|
case "auth":
|
||||||
|
return "auth"
|
||||||
|
case "network":
|
||||||
|
return "network"
|
||||||
|
default:
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMessage(category string, msg string) (eventType string, username string, sourceIP string, sessionID string) {
|
||||||
|
if msg == "" {
|
||||||
|
return "", "", "", ""
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(msg)
|
||||||
|
if category == "access" {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "accepted"):
|
||||||
|
eventType = "ssh.login.success"
|
||||||
|
username = extractBetween(msg, "for ", " from")
|
||||||
|
sourceIP = extractBetween(msg, "from ", " port")
|
||||||
|
case strings.Contains(lower, "failed password"):
|
||||||
|
eventType = "ssh.login.fail"
|
||||||
|
username = extractBetween(msg, "for ", " from")
|
||||||
|
sourceIP = extractBetween(msg, "from ", " port")
|
||||||
|
case strings.Contains(lower, "session opened"):
|
||||||
|
eventType = "ssh.session.open"
|
||||||
|
username = extractBetween(msg, "for user ", " by")
|
||||||
|
case strings.Contains(lower, "session closed"):
|
||||||
|
eventType = "ssh.session.close"
|
||||||
|
username = extractBetween(msg, "for user ", " by")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return eventType, strings.TrimSpace(username), strings.TrimSpace(sourceIP), strings.TrimSpace(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBetween(msg string, start string, end string) string {
|
||||||
|
startIdx := strings.Index(msg, start)
|
||||||
|
if startIdx == -1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
startIdx += len(start)
|
||||||
|
rest := msg[startIdx:]
|
||||||
|
endIdx := strings.Index(rest, end)
|
||||||
|
if endIdx == -1 {
|
||||||
|
return strings.TrimSpace(rest)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(rest[:endIdx])
|
||||||
|
}
|
||||||
24
agent/internal/logs/cursor.go
Normal file
24
agent/internal/logs/cursor.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ReadCursor(path string) (string, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteCursor(path string, cursor string) error {
|
||||||
|
if cursor == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, []byte(cursor+"\n"), 0o600)
|
||||||
|
}
|
||||||
53
agent/internal/logs/spool.go
Normal file
53
agent/internal/logs/spool.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SaveSpool(dir string, payload []byte) error {
|
||||||
|
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
name := fmt.Sprintf("%d.json", time.Now().UnixNano())
|
||||||
|
tmp := filepath.Join(dir, name+".tmp")
|
||||||
|
final := filepath.Join(dir, name)
|
||||||
|
if err := os.WriteFile(tmp, payload, 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.Rename(tmp, final)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DrainSpool(dir string, send func([]byte) error) error {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var files []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
files = append(files, filepath.Join(dir, entry.Name()))
|
||||||
|
}
|
||||||
|
sort.Strings(files)
|
||||||
|
for _, path := range files {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := send(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
23
agent/internal/logs/types.go
Normal file
23
agent/internal/logs/types.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package logs
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
Timestamp string `json:"timestamp"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
EventType string `json:"event_type"`
|
||||||
|
Unit string `json:"unit,omitempty"`
|
||||||
|
Priority string `json:"priority,omitempty"`
|
||||||
|
Hostname string `json:"hostname,omitempty"`
|
||||||
|
Username string `json:"username,omitempty"`
|
||||||
|
Principal string `json:"principal,omitempty"`
|
||||||
|
SourceIP string `json:"source_ip,omitempty"`
|
||||||
|
SessionID string `json:"session_id,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
Raw string `json:"raw,omitempty"`
|
||||||
|
Fields map[string]string `json:"fields,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEvent(ts time.Time) Event {
|
||||||
|
return Event{Timestamp: ts.UTC().Format(time.RFC3339Nano)}
|
||||||
|
}
|
||||||
7
agent/internal/version/version.go
Normal file
7
agent/internal/version/version.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package version
|
||||||
|
|
||||||
|
var (
|
||||||
|
Version = "0.0.1-dev"
|
||||||
|
Commit = ""
|
||||||
|
BuildDate = ""
|
||||||
|
)
|
||||||
BIN
agent/keywarden-agent
Executable file
BIN
agent/keywarden-agent
Executable file
Binary file not shown.
@@ -1,17 +1,27 @@
|
|||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from guardian.admin import GuardedModelAdmin
|
|
||||||
from django.utils.html import format_html
|
from django.utils.html import format_html
|
||||||
from .models import Server
|
from guardian.admin import GuardedModelAdmin
|
||||||
|
|
||||||
|
from .models import AgentCertificateAuthority, EnrollmentToken, Server
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Server)
|
@admin.register(Server)
|
||||||
class ServerAdmin(GuardedModelAdmin):
|
class ServerAdmin(GuardedModelAdmin):
|
||||||
list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "created_at")
|
list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "agent_enrolled_at", "created_at")
|
||||||
list_display_links = ("display_name",)
|
list_display_links = ("display_name",)
|
||||||
search_fields = ("display_name", "hostname", "ipv4", "ipv6")
|
search_fields = ("display_name", "hostname", "ipv4", "ipv6")
|
||||||
list_filter = ("created_at",)
|
list_filter = ("created_at",)
|
||||||
readonly_fields = ("created_at", "updated_at")
|
readonly_fields = ("created_at", "updated_at", "agent_enrolled_at")
|
||||||
fields = ("display_name", "hostname", "ipv4", "ipv6", "image", "created_at", "updated_at")
|
fields = (
|
||||||
|
"display_name",
|
||||||
|
"hostname",
|
||||||
|
"ipv4",
|
||||||
|
"ipv6",
|
||||||
|
"image",
|
||||||
|
"agent_enrolled_at",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
)
|
||||||
|
|
||||||
def avatar(self, obj: Server):
|
def avatar(self, obj: Server):
|
||||||
if obj.image_url:
|
if obj.image_url:
|
||||||
@@ -27,3 +37,52 @@ class ServerAdmin(GuardedModelAdmin):
|
|||||||
)
|
)
|
||||||
avatar.short_description = ""
|
avatar.short_description = ""
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(EnrollmentToken)
|
||||||
|
class EnrollmentTokenAdmin(admin.ModelAdmin):
|
||||||
|
list_display = ("token", "created_at", "expires_at", "used_at", "server")
|
||||||
|
list_filter = ("created_at", "used_at")
|
||||||
|
search_fields = ("token", "server__display_name", "server__hostname")
|
||||||
|
readonly_fields = ("token", "created_at", "used_at", "server", "created_by")
|
||||||
|
fields = ("token", "expires_at", "created_by", "created_at", "used_at", "server")
|
||||||
|
|
||||||
|
def save_model(self, request, obj, form, change) -> None:
|
||||||
|
if not obj.pk:
|
||||||
|
obj.ensure_token()
|
||||||
|
if request.user and request.user.is_authenticated and not obj.created_by_id:
|
||||||
|
obj.created_by = request.user
|
||||||
|
super().save_model(request, obj, form, change)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(AgentCertificateAuthority)
|
||||||
|
class AgentCertificateAuthorityAdmin(admin.ModelAdmin):
|
||||||
|
list_display = ("name", "is_active", "created_at", "revoked_at")
|
||||||
|
list_filter = ("is_active", "created_at", "revoked_at")
|
||||||
|
search_fields = ("name", "fingerprint")
|
||||||
|
readonly_fields = ("fingerprint", "serial", "created_at", "revoked_at", "created_by")
|
||||||
|
fields = (
|
||||||
|
"name",
|
||||||
|
"is_active",
|
||||||
|
"cert_pem",
|
||||||
|
"key_pem",
|
||||||
|
"fingerprint",
|
||||||
|
"serial",
|
||||||
|
"created_by",
|
||||||
|
"created_at",
|
||||||
|
"revoked_at",
|
||||||
|
)
|
||||||
|
actions = ["revoke_selected"]
|
||||||
|
|
||||||
|
def save_model(self, request, obj, form, change) -> None:
|
||||||
|
if request.user and request.user.is_authenticated and not obj.created_by_id:
|
||||||
|
obj.created_by = request.user
|
||||||
|
obj.ensure_material()
|
||||||
|
if obj.is_active:
|
||||||
|
AgentCertificateAuthority.objects.exclude(pk=obj.pk).update(is_active=False)
|
||||||
|
super().save_model(request, obj, form, change)
|
||||||
|
|
||||||
|
@admin.action(description="Revoke selected CAs")
|
||||||
|
def revoke_selected(self, request, queryset):
|
||||||
|
for ca in queryset:
|
||||||
|
ca.revoke()
|
||||||
|
ca.save(update_fields=["is_active", "revoked_at"])
|
||||||
|
|||||||
73
app/apps/servers/migrations/0002_agent_enrollment.py
Normal file
73
app/apps/servers/migrations/0002_agent_enrollment.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.utils.timezone
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("servers", "0001_initial"),
|
||||||
|
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="server",
|
||||||
|
name="agent_cert_fingerprint",
|
||||||
|
field=models.CharField(blank=True, max_length=128, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="server",
|
||||||
|
name="agent_cert_serial",
|
||||||
|
field=models.CharField(blank=True, max_length=64, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="server",
|
||||||
|
name="agent_enrolled_at",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="EnrollmentToken",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("token", models.CharField(max_length=128, unique=True)),
|
||||||
|
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
||||||
|
("expires_at", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("used_at", models.DateTimeField(blank=True, null=True)),
|
||||||
|
(
|
||||||
|
"created_by",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="server_enrollment_tokens",
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"server",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="enrollment_tokens",
|
||||||
|
to="servers.server",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "Enrollment token",
|
||||||
|
"verbose_name_plural": "Enrollment tokens",
|
||||||
|
"ordering": ["-created_at"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.AddIndex(
|
||||||
|
model_name="enrollmenttoken",
|
||||||
|
index=models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
|
||||||
|
),
|
||||||
|
migrations.AddIndex(
|
||||||
|
model_name="enrollmenttoken",
|
||||||
|
index=models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
|
||||||
|
),
|
||||||
|
]
|
||||||
44
app/apps/servers/migrations/0003_agent_ca.py
Normal file
44
app/apps/servers/migrations/0003_agent_ca.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import django.utils.timezone
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("servers", "0002_agent_enrollment"),
|
||||||
|
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="AgentCertificateAuthority",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("name", models.CharField(default="Keywarden Agent CA", max_length=128)),
|
||||||
|
("cert_pem", models.TextField()),
|
||||||
|
("key_pem", models.TextField()),
|
||||||
|
("fingerprint", models.CharField(blank=True, max_length=128)),
|
||||||
|
("serial", models.CharField(blank=True, max_length=64)),
|
||||||
|
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
||||||
|
("revoked_at", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("is_active", models.BooleanField(db_index=True, default=True)),
|
||||||
|
(
|
||||||
|
"created_by",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
related_name="agent_certificate_authorities",
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "Agent certificate authority",
|
||||||
|
"verbose_name_plural": "Agent certificate authorities",
|
||||||
|
"ordering": ["-created_at"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -1,8 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.x509.oid import NameOID
|
||||||
|
from django.conf import settings
|
||||||
from django.core.validators import RegexValidator
|
from django.core.validators import RegexValidator
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.text import slugify
|
from django.utils import timezone
|
||||||
|
|
||||||
|
|
||||||
hostname_validator = RegexValidator(
|
hostname_validator = RegexValidator(
|
||||||
@@ -17,6 +25,9 @@ class Server(models.Model):
|
|||||||
ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True)
|
ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True)
|
||||||
ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True)
|
ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True)
|
||||||
image = models.ImageField(upload_to="servers/", null=True, blank=True)
|
image = models.ImageField(upload_to="servers/", null=True, blank=True)
|
||||||
|
agent_enrolled_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
agent_cert_fingerprint = models.CharField(max_length=128, null=True, blank=True)
|
||||||
|
agent_cert_serial = models.CharField(max_length=64, null=True, blank=True)
|
||||||
created_at = models.DateTimeField(auto_now_add=True)
|
created_at = models.DateTimeField(auto_now_add=True)
|
||||||
updated_at = models.DateTimeField(auto_now=True)
|
updated_at = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
@@ -41,3 +52,108 @@ class Server(models.Model):
|
|||||||
return (self.display_name or "?").strip()[:1].upper() or "?"
|
return (self.display_name or "?").strip()[:1].upper() or "?"
|
||||||
|
|
||||||
|
|
||||||
|
class EnrollmentToken(models.Model):
|
||||||
|
token = models.CharField(max_length=128, unique=True)
|
||||||
|
created_at = models.DateTimeField(default=timezone.now, editable=False)
|
||||||
|
expires_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
created_by = models.ForeignKey(
|
||||||
|
settings.AUTH_USER_MODEL,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
on_delete=models.SET_NULL,
|
||||||
|
related_name="server_enrollment_tokens",
|
||||||
|
)
|
||||||
|
used_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
server = models.ForeignKey(
|
||||||
|
Server, null=True, blank=True, on_delete=models.SET_NULL, related_name="enrollment_tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = "Enrollment token"
|
||||||
|
verbose_name_plural = "Enrollment tokens"
|
||||||
|
indexes = [
|
||||||
|
models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
|
||||||
|
models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
|
||||||
|
]
|
||||||
|
ordering = ["-created_at"]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.token[:8]}... ({'used' if self.used_at else 'unused'})"
|
||||||
|
|
||||||
|
def ensure_token(self) -> None:
|
||||||
|
if not self.token:
|
||||||
|
self.token = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
if self.used_at:
|
||||||
|
return False
|
||||||
|
if self.expires_at and self.expires_at <= timezone.now():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def mark_used(self, server: Server) -> None:
|
||||||
|
self.used_at = timezone.now()
|
||||||
|
self.server = server
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
self.ensure_token()
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCertificateAuthority(models.Model):
|
||||||
|
name = models.CharField(max_length=128, default="Keywarden Agent CA")
|
||||||
|
cert_pem = models.TextField()
|
||||||
|
key_pem = models.TextField()
|
||||||
|
fingerprint = models.CharField(max_length=128, blank=True)
|
||||||
|
serial = models.CharField(max_length=64, blank=True)
|
||||||
|
created_at = models.DateTimeField(default=timezone.now, editable=False)
|
||||||
|
revoked_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
is_active = models.BooleanField(default=True, db_index=True)
|
||||||
|
created_by = models.ForeignKey(
|
||||||
|
settings.AUTH_USER_MODEL,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
on_delete=models.SET_NULL,
|
||||||
|
related_name="agent_certificate_authorities",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = "Agent certificate authority"
|
||||||
|
verbose_name_plural = "Agent certificate authorities"
|
||||||
|
ordering = ["-created_at"]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
status = "active" if self.is_active and not self.revoked_at else "revoked"
|
||||||
|
return f"{self.name} ({status})"
|
||||||
|
|
||||||
|
def revoke(self) -> None:
|
||||||
|
self.is_active = False
|
||||||
|
self.revoked_at = timezone.now()
|
||||||
|
|
||||||
|
def ensure_material(self) -> None:
|
||||||
|
if self.cert_pem and self.key_pem:
|
||||||
|
return
|
||||||
|
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||||
|
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.name)])
|
||||||
|
now = datetime.utcnow()
|
||||||
|
cert = (
|
||||||
|
x509.CertificateBuilder()
|
||||||
|
.subject_name(subject)
|
||||||
|
.issuer_name(subject)
|
||||||
|
.public_key(key.public_key())
|
||||||
|
.serial_number(x509.random_serial_number())
|
||||||
|
.not_valid_before(now - timedelta(minutes=5))
|
||||||
|
.not_valid_after(now + timedelta(days=3650))
|
||||||
|
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||||
|
.sign(key, hashes.SHA256())
|
||||||
|
)
|
||||||
|
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
||||||
|
key_pem = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
).decode("utf-8")
|
||||||
|
self.cert_pem = cert_pem
|
||||||
|
self.key_pem = key_pem
|
||||||
|
self.fingerprint = cert.fingerprint(hashes.SHA256()).hex()
|
||||||
|
self.serial = format(cert.serial_number, "x")
|
||||||
|
|||||||
@@ -70,7 +70,14 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.get("/", response=List[AccessRequestOut])
|
@router.get("/", response=List[AccessRequestOut])
|
||||||
def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)):
|
def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)):
|
||||||
"""List access requests for the user, or all if admin/operator."""
|
"""List access requests with pagination and filters.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions:
|
||||||
|
- If user has global `access.view_accessrequest`, returns all requests.
|
||||||
|
- Otherwise, returns only objects with `access.view_accessrequest` object permission.
|
||||||
|
Filters: status, server_id, requester_id (requester_id is honored only with global view).
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
user = request.user
|
user = request.user
|
||||||
if _has_global_perm(request, "access.view_accessrequest"):
|
if _has_global_perm(request, "access.view_accessrequest"):
|
||||||
@@ -94,7 +101,12 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.post("/", response=AccessRequestOut)
|
@router.post("/", response=AccessRequestOut)
|
||||||
def create_request(request: HttpRequest, payload: AccessRequestCreateIn):
|
def create_request(request: HttpRequest, payload: AccessRequestCreateIn):
|
||||||
"""Create a new access request for a server."""
|
"""Create a new access request for the current user.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires global `access.add_accessrequest`.
|
||||||
|
Side effects: grants owner object perms on the new request.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
if not request.user.has_perm("access.add_accessrequest"):
|
if not request.user.has_perm("access.add_accessrequest"):
|
||||||
raise HttpError(403, "Forbidden")
|
raise HttpError(403, "Forbidden")
|
||||||
@@ -116,7 +128,11 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.get("/{request_id}", response=AccessRequestOut)
|
@router.get("/{request_id}", response=AccessRequestOut)
|
||||||
def get_request(request: HttpRequest, request_id: int):
|
def get_request(request: HttpRequest, request_id: int):
|
||||||
"""Get an access request if permitted."""
|
"""Get a single access request by id.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `access.view_accessrequest` on the object.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
access_request = AccessRequest.objects.get(id=request_id)
|
access_request = AccessRequest.objects.get(id=request_id)
|
||||||
@@ -128,7 +144,15 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.patch("/{request_id}", response=AccessRequestOut)
|
@router.patch("/{request_id}", response=AccessRequestOut)
|
||||||
def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn):
|
def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn):
|
||||||
"""Update request status or expiry (admin/operator or owner with restrictions)."""
|
"""Update request status or expiry.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `access.change_accessrequest` on the object.
|
||||||
|
Rules:
|
||||||
|
- Admin/operator (global change) can set status to approved/denied/revoked/cancelled and
|
||||||
|
update expires_at.
|
||||||
|
- Non-admin can only set status to cancelled, and only while pending.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
access_request = AccessRequest.objects.get(id=request_id)
|
access_request = AccessRequest.objects.get(id=request_id)
|
||||||
@@ -171,7 +195,11 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.delete("/{request_id}", response={204: None})
|
@router.delete("/{request_id}", response={204: None})
|
||||||
def delete_request(request: HttpRequest, request_id: int):
|
def delete_request(request: HttpRequest, request_id: int):
|
||||||
"""Delete an access request if permitted."""
|
"""Delete an access request.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `access.delete_accessrequest` on the object.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
access_request = AccessRequest.objects.get(id=request_id)
|
access_request = AccessRequest.objects.get(id=request_id)
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
@@ -12,7 +18,7 @@ from pydantic import Field
|
|||||||
from apps.core.rbac import require_perms
|
from apps.core.rbac import require_perms
|
||||||
from apps.access.models import AccessRequest
|
from apps.access.models import AccessRequest
|
||||||
from apps.keys.models import SSHKey
|
from apps.keys.models import SSHKey
|
||||||
from apps.servers.models import Server
|
from apps.servers.models import AgentCertificateAuthority, EnrollmentToken, Server, hostname_validator
|
||||||
from apps.telemetry.models import TelemetryEvent
|
from apps.telemetry.models import TelemetryEvent
|
||||||
|
|
||||||
|
|
||||||
@@ -35,9 +41,82 @@ class SyncReportOut(Schema):
|
|||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentEnrollIn(Schema):
|
||||||
|
token: str
|
||||||
|
csr_pem: str
|
||||||
|
host: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentEnrollOut(Schema):
|
||||||
|
server_id: str
|
||||||
|
client_cert_pem: str
|
||||||
|
ca_cert_pem: str
|
||||||
|
|
||||||
|
|
||||||
|
class LogEventIn(Schema):
|
||||||
|
timestamp: str
|
||||||
|
category: str
|
||||||
|
event_type: str
|
||||||
|
unit: Optional[str] = None
|
||||||
|
priority: Optional[str] = None
|
||||||
|
hostname: Optional[str] = None
|
||||||
|
username: Optional[str] = None
|
||||||
|
principal: Optional[str] = None
|
||||||
|
source_ip: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
message: Optional[str] = None
|
||||||
|
raw: Optional[str] = None
|
||||||
|
fields: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LogIngestOut(Schema):
|
||||||
|
status: str
|
||||||
|
accepted: int
|
||||||
|
|
||||||
|
|
||||||
def build_router() -> Router:
|
def build_router() -> Router:
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
|
@router.post("/enroll", response=AgentEnrollOut, auth=None)
|
||||||
|
def enroll_agent(request: HttpRequest, payload: AgentEnrollIn):
|
||||||
|
"""Enroll a server agent using a one-time token."""
|
||||||
|
token_value = (payload.token or "").strip()
|
||||||
|
if not token_value:
|
||||||
|
raise HttpError(422, "Token required")
|
||||||
|
try:
|
||||||
|
token = EnrollmentToken.objects.get(token=token_value)
|
||||||
|
except EnrollmentToken.DoesNotExist:
|
||||||
|
raise HttpError(403, "Invalid token")
|
||||||
|
if not token.is_valid():
|
||||||
|
raise HttpError(403, "Token expired or already used")
|
||||||
|
|
||||||
|
host = (payload.host or "").strip()[:253]
|
||||||
|
display_name = host or "server"
|
||||||
|
hostname = None
|
||||||
|
if host:
|
||||||
|
try:
|
||||||
|
hostname_validator(host)
|
||||||
|
hostname = host
|
||||||
|
except ValidationError:
|
||||||
|
hostname = None
|
||||||
|
|
||||||
|
server = Server.objects.create(display_name=display_name, hostname=hostname)
|
||||||
|
token.mark_used(server)
|
||||||
|
token.save(update_fields=["used_at", "server"])
|
||||||
|
|
||||||
|
csr = _load_csr((payload.csr_pem or "").strip())
|
||||||
|
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
|
||||||
|
server.agent_enrolled_at = timezone.now()
|
||||||
|
server.agent_cert_fingerprint = fingerprint
|
||||||
|
server.agent_cert_serial = serial
|
||||||
|
server.save(update_fields=["agent_enrolled_at", "agent_cert_fingerprint", "agent_cert_serial"])
|
||||||
|
|
||||||
|
return AgentEnrollOut(
|
||||||
|
server_id=str(server.id),
|
||||||
|
client_cert_pem=cert_pem,
|
||||||
|
ca_cert_pem=ca_pem,
|
||||||
|
)
|
||||||
|
|
||||||
@router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut])
|
@router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut])
|
||||||
def authorized_keys(request: HttpRequest, server_id: int):
|
def authorized_keys(request: HttpRequest, server_id: int):
|
||||||
"""Return authorized public keys for a server (admin or operator)."""
|
"""Return authorized public keys for a server (admin or operator)."""
|
||||||
@@ -96,7 +175,75 @@ def build_router() -> Router:
|
|||||||
)
|
)
|
||||||
return SyncReportOut(status="ok")
|
return SyncReportOut(status="ok")
|
||||||
|
|
||||||
|
@router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None)
|
||||||
|
def ingest_logs(request: HttpRequest, server_id: int, payload: List[LogEventIn]):
|
||||||
|
"""Accept log batches from agents (mTLS required at the edge)."""
|
||||||
|
try:
|
||||||
|
Server.objects.get(id=server_id)
|
||||||
|
except Server.DoesNotExist:
|
||||||
|
raise HttpError(404, "Server not found")
|
||||||
|
# TODO: enqueue to Valkey and persist to SQLite slices.
|
||||||
|
return LogIngestOut(status="accepted", accepted=len(payload))
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|
||||||
|
def _load_agent_ca() -> tuple[x509.Certificate, object, str]:
|
||||||
|
ca = (
|
||||||
|
AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True)
|
||||||
|
.order_by("-created_at")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not ca:
|
||||||
|
raise HttpError(500, "Agent CA not configured")
|
||||||
|
try:
|
||||||
|
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
|
||||||
|
ca_key = serialization.load_pem_private_key(ca.key_pem.encode("utf-8"), password=None)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise HttpError(500, "Invalid agent CA material")
|
||||||
|
return ca_cert, ca_key, ca.cert_pem
|
||||||
|
|
||||||
|
|
||||||
|
def _load_csr(csr_pem: str) -> x509.CertificateSigningRequest:
|
||||||
|
try:
|
||||||
|
csr = x509.load_pem_x509_csr(csr_pem.encode("utf-8"))
|
||||||
|
except ValueError:
|
||||||
|
raise HttpError(422, "Invalid CSR")
|
||||||
|
if not csr.is_signature_valid:
|
||||||
|
raise HttpError(422, "Invalid CSR signature")
|
||||||
|
return csr
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_client_cert(
|
||||||
|
csr: x509.CertificateSigningRequest, host: str | None, server_id: int
|
||||||
|
) -> tuple[str, str, str, str]:
|
||||||
|
ca_cert, ca_key, ca_pem = _load_agent_ca()
|
||||||
|
now = datetime.utcnow()
|
||||||
|
subject = csr.subject
|
||||||
|
if len(subject) == 0:
|
||||||
|
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, f"keywarden-agent-{server_id}")])
|
||||||
|
builder = (
|
||||||
|
x509.CertificateBuilder()
|
||||||
|
.subject_name(subject)
|
||||||
|
.issuer_name(ca_cert.subject)
|
||||||
|
.public_key(csr.public_key())
|
||||||
|
.serial_number(x509.random_serial_number())
|
||||||
|
.not_valid_before(now - timedelta(minutes=5))
|
||||||
|
.not_valid_after(now + timedelta(days=settings.KEYWARDEN_AGENT_CERT_VALIDITY_DAYS))
|
||||||
|
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||||
|
.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=False)
|
||||||
|
)
|
||||||
|
if host:
|
||||||
|
try:
|
||||||
|
hostname_validator(host)
|
||||||
|
builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(host)]), critical=False)
|
||||||
|
except ValidationError:
|
||||||
|
pass
|
||||||
|
cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
|
||||||
|
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
||||||
|
fingerprint = cert.fingerprint(hashes.SHA256()).hex()
|
||||||
|
serial = format(cert.serial_number, "x")
|
||||||
|
return cert_pem, ca_pem, fingerprint, serial
|
||||||
|
|
||||||
|
|
||||||
router = build_router()
|
router = build_router()
|
||||||
|
|||||||
@@ -69,7 +69,14 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.get("/", response=List[KeyOut])
|
@router.get("/", response=List[KeyOut])
|
||||||
def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)):
|
def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)):
|
||||||
"""List SSH keys for the current user, or any user if admin/operator."""
|
"""List SSH keys with pagination and filters.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions:
|
||||||
|
- If user has global `keys.view_sshkey`, returns all keys.
|
||||||
|
- Otherwise, returns only objects with `keys.view_sshkey` object permission.
|
||||||
|
Filter: user_id (honored only with global view).
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
user = request.user
|
user = request.user
|
||||||
if _has_global_perm(request, "keys.view_sshkey"):
|
if _has_global_perm(request, "keys.view_sshkey"):
|
||||||
@@ -89,7 +96,15 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.post("/", response=KeyOut)
|
@router.post("/", response=KeyOut)
|
||||||
def create_key(request: HttpRequest, payload: KeyCreateIn):
|
def create_key(request: HttpRequest, payload: KeyCreateIn):
|
||||||
"""Create an SSH public key for the current user (admin/operator can specify user_id)."""
|
"""Create an SSH public key.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires global `keys.add_sshkey`.
|
||||||
|
Rules:
|
||||||
|
- Default owner is the current user.
|
||||||
|
- If caller has global `keys.add_sshkey` and `keys.view_sshkey`, they may specify user_id.
|
||||||
|
Side effects: grants owner object perms on the new key.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
if not request.user.has_perm("keys.add_sshkey"):
|
if not request.user.has_perm("keys.add_sshkey"):
|
||||||
raise HttpError(403, "Forbidden")
|
raise HttpError(403, "Forbidden")
|
||||||
@@ -121,7 +136,11 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.get("/{key_id}", response=KeyOut)
|
@router.get("/{key_id}", response=KeyOut)
|
||||||
def get_key(request: HttpRequest, key_id: int):
|
def get_key(request: HttpRequest, key_id: int):
|
||||||
"""Get a specific SSH key if permitted."""
|
"""Get a specific SSH key by id.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `keys.view_sshkey` on the object.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
key = SSHKey.objects.get(id=key_id)
|
key = SSHKey.objects.get(id=key_id)
|
||||||
@@ -133,7 +152,11 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.patch("/{key_id}", response=KeyOut)
|
@router.patch("/{key_id}", response=KeyOut)
|
||||||
def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn):
|
def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn):
|
||||||
"""Update key name or active state if permitted."""
|
"""Update key name or active state.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `keys.change_sshkey` on the object.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
key = SSHKey.objects.get(id=key_id)
|
key = SSHKey.objects.get(id=key_id)
|
||||||
@@ -159,7 +182,12 @@ def build_router() -> Router:
|
|||||||
|
|
||||||
@router.delete("/{key_id}", response={204: None})
|
@router.delete("/{key_id}", response={204: None})
|
||||||
def delete_key(request: HttpRequest, key_id: int):
|
def delete_key(request: HttpRequest, key_id: int):
|
||||||
"""Revoke an SSH key if permitted (soft delete)."""
|
"""Revoke (soft delete) an SSH key.
|
||||||
|
|
||||||
|
Auth: required.
|
||||||
|
Permissions: requires `keys.delete_sshkey` on the object.
|
||||||
|
Behavior: sets is_active false and revoked_at if key is active.
|
||||||
|
"""
|
||||||
require_authenticated(request)
|
require_authenticated(request)
|
||||||
try:
|
try:
|
||||||
key = SSHKey.objects.get(id=key_id)
|
key = SSHKey.objects.get(id=key_id)
|
||||||
|
|||||||
@@ -78,21 +78,7 @@ def build_router() -> Router:
|
|||||||
def create_server_json(request: HttpRequest, payload: ServerCreate):
|
def create_server_json(request: HttpRequest, payload: ServerCreate):
|
||||||
"""Create a server using JSON payload (admin only)."""
|
"""Create a server using JSON payload (admin only)."""
|
||||||
require_perms(request, "servers.add_server")
|
require_perms(request, "servers.add_server")
|
||||||
server = Server.objects.create(
|
raise HttpError(403, "Servers are created via agent enrollment tokens.")
|
||||||
display_name=payload.display_name.strip(),
|
|
||||||
hostname=(payload.hostname or "").strip() or None,
|
|
||||||
ipv4=(payload.ipv4 or "").strip() or None,
|
|
||||||
ipv6=(payload.ipv6 or "").strip() or None,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"id": server.id,
|
|
||||||
"display_name": server.display_name,
|
|
||||||
"hostname": server.hostname,
|
|
||||||
"ipv4": server.ipv4,
|
|
||||||
"ipv6": server.ipv6,
|
|
||||||
"image_url": server.image_url,
|
|
||||||
"initial": server.initial,
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.post("/upload", response=ServerOut)
|
@router.post("/upload", response=ServerOut)
|
||||||
def create_server_multipart(
|
def create_server_multipart(
|
||||||
@@ -105,24 +91,7 @@ def build_router() -> Router:
|
|||||||
):
|
):
|
||||||
"""Create a server with optional image upload (admin only)."""
|
"""Create a server with optional image upload (admin only)."""
|
||||||
require_perms(request, "servers.add_server")
|
require_perms(request, "servers.add_server")
|
||||||
server = Server(
|
raise HttpError(403, "Servers are created via agent enrollment tokens.")
|
||||||
display_name=display_name.strip(),
|
|
||||||
hostname=(hostname or "").strip() or None,
|
|
||||||
ipv4=(ipv4 or "").strip() or None,
|
|
||||||
ipv6=(ipv6 or "").strip() or None,
|
|
||||||
)
|
|
||||||
if image:
|
|
||||||
server.image.save(image.name, image) # type: ignore[arg-type]
|
|
||||||
server.save()
|
|
||||||
return {
|
|
||||||
"id": server.id,
|
|
||||||
"display_name": server.display_name,
|
|
||||||
"hostname": server.hostname,
|
|
||||||
"ipv4": server.ipv4,
|
|
||||||
"ipv6": server.ipv6,
|
|
||||||
"image_url": server.image_url,
|
|
||||||
"initial": server.initial,
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.patch("/{server_id}", response=ServerOut)
|
@router.patch("/{server_id}", response=ServerOut)
|
||||||
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
|
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
|
||||||
|
|||||||
@@ -94,6 +94,8 @@ CACHES = {
|
|||||||
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
|
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
|
||||||
SESSION_CACHE_ALIAS = "default"
|
SESSION_CACHE_ALIAS = "default"
|
||||||
|
|
||||||
|
KEYWARDEN_AGENT_CERT_VALIDITY_DAYS = int(os.getenv("KEYWARDEN_AGENT_CERT_VALIDITY_DAYS", "90"))
|
||||||
|
|
||||||
PASSWORD_HASHERS = [
|
PASSWORD_HASHERS = [
|
||||||
"django.contrib.auth.hashers.Argon2PasswordHasher",
|
"django.contrib.auth.hashers.Argon2PasswordHasher",
|
||||||
"django.contrib.auth.hashers.PBKDF2PasswordHasher",
|
"django.contrib.auth.hashers.PBKDF2PasswordHasher",
|
||||||
|
|||||||
Reference in New Issue
Block a user