Compare commits

25 Commits
api-dev ... dev

Author SHA1 Message Date
70d0e808f8 Updated agent to include ping in heartbeat. 2026-02-03 15:24:11 +00:00
bebaaf1367 Refactor to Flowbite for UI 2026-02-03 09:54:49 +00:00
962ba27679 Commented terminal files 2026-02-03 09:33:49 +00:00
f54cc3f09b Changed ephemeral key to 30m lifespan; keys stored in /dev/shm; explicit 0600 perms; delete keys when session opens. 2026-02-03 09:17:15 +00:00
667b02f0c3 Ephemeral keys for xterm.js. Initial rework of audit logging. All endpoints now return a 401 regardless of presence if not logged in. 2026-02-03 08:26:37 +00:00
3e17d6412c ASGI via Daphne for websockets, WSGI via Gunicorn. Implemented xterm.js for shell proxy to target servers. 2026-01-27 00:33:53 +00:00
56caa194ec Cleaned up object perms 2026-01-26 23:55:58 +00:00
9cf782ffd6 Added certificate regeneration. Refactored server dashboard. 2026-01-26 23:36:12 +00:00
664e7be9f0 Certificate generation and sync, implemented proper grant and revocation flows. Pubkey uploading. Added openssh-client to Dockerfile 2026-01-26 23:27:18 +00:00
cdaceb1cf7 Display username on client panel 2026-01-26 17:18:31 +00:00
43bff4513a object‑permission–driven server access; agent‑managed account provisioning with presence reporting 2026-01-26 17:03:44 +00:00
ed2f921b0f Fixed index name too long. Added icon and fixed unfold branding 2026-01-26 16:41:16 +00:00
e693a7616c GDPR Compliant erasure requests 2026-01-26 13:50:21 +00:00
548681face Improved API docs, removed DELETE endpoint from user 2026-01-26 13:42:08 +00:00
c115f41dac Switched to Redoc 2026-01-26 13:31:08 +00:00
69802f3ece Agent retries on connection loss, sends connection info (v4 v6) Uses system CA for mTLS. Removed server endpoints. 2026-01-26 01:13:51 +00:00
e7d20360a2 Django->6.0.1 Django Ninja->1.5.2 mozilla-django-oidc->5.0.2 django-guardian->3.2 gunicorn->24.1 django-unfold->0.76 2026-01-26 00:43:49 +00:00
1d0c075d68 Attempting to resolve unfold form inconsistencies. 2026-01-26 00:21:49 +00:00
b95084ddc3 Linux agent functional. Added new client-facing server panel. Removed deferred pydantic annotations. 2026-01-25 23:08:40 +00:00
4885622d6a Initial linux agent and api functionality for enrolling servers 2026-01-25 22:24:20 +00:00
66ffa3d3fb Initial django guardian integrations 2026-01-25 17:48:14 +00:00
6901f6fcc4 RBAC + Per-Route Audit Events 2026-01-20 10:08:32 +00:00
47b90fee87 Added logging, self-signed certs and KEYWARDEN_DOMAIN env variable 2026-01-19 19:47:31 +00:00
43fe875cde Created example env, updated docker-compose, added valkey to supervisord 2026-01-19 19:05:45 +00:00
35252fa1e8 Merge pull request 'Merge new endpoints with development prototype' (#6) from api-dev into dev
Reviewed-on: #6
2026-01-19 18:40:21 +00:00
166 changed files with 9264 additions and 1338 deletions

31
.env.example Normal file
View File

@@ -0,0 +1,31 @@
# Django settings
KEYWARDEN_SECRET_KEY=supersecret
KEYWARDEN_DEBUG=True
KEYWARDEN_ALLOWED_HOSTS=*
KEYWARDEN_TRUSTED_ORIGINS=https://reverse.proxy.domain.xyz,https://127.0.0.1
KEYWARDEN_DOMAIN=https://example.domain.xyz
# Database
KEYWARDEN_POSTGRES_DB=keywarden
KEYWARDEN_POSTGRES_USER=keywarden
KEYWARDEN_POSTGRES_PASSWORD=postgres
KEYWARDEN_POSTGRES_HOST=keywarden-db
KEYWARDEN_POSTGRES_PORT=5432
# Admin
KEYWARDEN_ADMIN_USERNAME=admin
KEYWARDEN_ADMIN_EMAIL=admin@example.com
KEYWARDEN_ADMIN_PASSWORD=password
# Auth mode: native | oidc | hybrid
KEYWARDEN_AUTH_MODE=native
# OIDC (optional)
# KEYWARDEN_OIDC_CLIENT_ID=
# KEYWARDEN_OIDC_CLIENT_SECRET=
# KEYWARDEN_OIDC_AUTHORIZATION_ENDPOINT=
# KEYWARDEN_OIDC_TOKEN_ENDPOINT=
# KEYWARDEN_OIDC_USER_ENDPOINT=
# KEYWARDEN_OIDC_JWKS_ENDPOINT=

3
.gitignore vendored
View File

@@ -218,9 +218,6 @@ __marimo__/
# Certificates # Certificates
*.pem *.pem
# Docker
*compose.yml
nginx/logs/* nginx/logs/*
nginx/certs/*.pem nginx/certs/*.pem

View File

@@ -13,3 +13,35 @@ Authentication:
Notes: Notes:
- Base URL for v1 endpoints is `/api/v1`. - Base URL for v1 endpoints is `/api/v1`.
- Admin-only routes return `403 Forbidden` when the token user is not staff/superuser. - Admin-only routes return `403 Forbidden` when the token user is not staff/superuser.
Example: update server display name (admin-only)
PATCH `/api/v1/servers/{server_id}`
```json
{
"display_name": "Keywarden Prod"
}
```
## SSH user certificates (OpenSSH CA)
Keywarden signs user SSH keys with an OpenSSH certificate authority. The flow is:
- User uploads a public key (`POST /api/v1/keys`).
- Server signs the key using the active user CA.
- Certificate is stored server-side and can be downloaded by the user.
Endpoints:
- `POST /api/v1/keys/{key_id}/certificate` issues (or re-issues) a certificate.
- `GET /api/v1/keys/{key_id}/certificate` downloads the certificate.
- `GET /api/v1/keys/{key_id}/certificate.sha256` downloads a sha256 hash file.
Agent endpoints (mTLS):
- `GET /api/v1/agent/servers/{server_id}/ssh-ca` returns the CA public key for agent install.
- `GET /api/v1/agent/servers/{server_id}/accounts` returns account + system username (no raw keys).
Configuration:
- `KEYWARDEN_USER_CERT_VALIDITY_DAYS` controls certificate lifetime (default: 30 days).
- `KEYWARDEN_ACCOUNT_USERNAME_TEMPLATE` controls account name derivation.
Note: `ssh-keygen` must be available on the Keywarden server to sign certificates.

View File

@@ -13,12 +13,18 @@ WORKDIR /app
# System deps for psycopg2, node (for Tailwind), etc. # System deps for psycopg2, node (for Tailwind), etc.
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \ build-essential \
ca-certificates \
libpq-dev \ libpq-dev \
curl \ curl \
openssl \
openssh-client \
nginx \ nginx \
nodejs \ nodejs \
npm \ npm \
supervisor \ supervisor \
mkcert \
libnss3-tools \
valkey-server \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# ============================================= # =============================================
@@ -44,13 +50,13 @@ RUN pip install --upgrade pip \
WORKDIR /app WORKDIR /app
COPY ./app . COPY ./app .
COPY nginx/configs/nginx.conf /etc/nginx/nginx.conf COPY nginx/configs/nginx.conf.template /etc/nginx/nginx.conf.template
COPY nginx/configs/options-* /etc/nginx/ COPY nginx/configs/options-* /etc/nginx/
#COPY nginx/configs/sites/ /etc/nginx/conf.d/ #COPY nginx/configs/sites/ /etc/nginx/conf.d/
COPY supervisor/supervisord.conf /etc/supervisor/supervisord.conf COPY supervisor/supervisord.conf /etc/supervisor/supervisord.conf
RUN python manage.py collectstatic --noinput RUN python manage.py collectstatic --noinput
RUN chmod +x /app/entrypoint.sh /app/scripts/gunicorn.sh RUN chmod +x /app/entrypoint.sh /app/scripts/gunicorn.sh /app/scripts/daphne.sh
# ============================================= # =============================================
# 5. Create users for services # 5. Create users for services

View File

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2024, Valkey contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
THIRD_PARTY_NOTICES.md Normal file
View File

@@ -0,0 +1,22 @@
# Third-party notices
This project is licensed under the GNU AGPL v3. It includes third-party components that
are distributed under their own licenses. When redistributing Keywarden (source or binary),
ensure you comply with each component's license terms and include required notices.
## Valkey
Valkey is included in the container image and used as the cache backend.
License: BSD 3-Clause. See `LICENSES/valkey.BSD-3-Clause.txt`.
## Other third-party components
This repository and container image include additional dependencies (Python packages and
system packages). Their licenses typically require you to retain copyright notices and
license texts when redistributing binaries. Review the following sources to determine
exact obligations:
- `requirements.txt` for Python dependencies.
- `Dockerfile` for system packages installed into the image.
- `app/static/` and `app/theme/` for bundled frontend assets.
If you need a full license inventory, generate it from your build environment and add
corresponding license texts under `LICENSES/`.

37
TODO.md Normal file
View File

@@ -0,0 +1,37 @@
Next steps:
Certificate Generation:
- User account is created
- User can input SSH pubkey into profile page
- Keywarden creates signed SSH Certificate from User's pubkey and Keywarden CA
Grant:
- User requests access to target server
- Access request approved
- User has linux account created and has key / cert trusted by target server
- User can log into account
Revocation:
- User has access expire or revoked
- Keywarden removes key / cert from target server, or invalidates on Keywarden's side
- Keywarden removes object permissions
- User cannot access server anymore
Permissions:
Administrator:
- Everything
Auditor:
- Can exclusively view audit logs of servers they have access to via request.
User:
Access Requests:
- Can use Shell?
- Can view logs?
- Can have user account?

1
agent/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
keywarden-agent

27
agent/README.md Normal file
View File

@@ -0,0 +1,27 @@
TODO: Move to boris/keywarden-agent. In main repo for now for development.
# keywarden-agent
Minimal Go agent for Keywarden.
## Build
```
go build -o keywarden-agent ./cmd/keywarden-agent
```
## Run
```
./keywarden-agent -config /etc/keywarden/agent.json -server-url https://keywarden.example.com -enroll-token <token>
```
You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environment variables.
## Config
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate.
See `config.example.json`.

View File

@@ -0,0 +1,278 @@
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"time"
"keywarden/agent/internal/client"
"keywarden/agent/internal/config"
"keywarden/agent/internal/host"
"keywarden/agent/internal/logs"
"keywarden/agent/internal/version"
)
func main() {
configPath := flag.String("config", config.DefaultConfigPath, "Path to agent config JSON")
serverURL := flag.String("server-url", "", "Keywarden server URL (first boot)")
enrollToken := flag.String("enroll-token", "", "Enrollment token (first boot)")
showVersion := flag.Bool("version", false, "Print version and exit")
flag.Parse()
if *showVersion {
fmt.Printf("keywarden-agent %s (commit %s, built %s)\n", version.Version, version.Commit, version.BuildDate)
return
}
cfg, err := config.LoadOrInit(*configPath, pickServerURL(*serverURL))
if err != nil {
log.Fatalf("config error: %v", err)
}
if err := ensureDirs(cfg); err != nil {
log.Fatalf("state dir error: %v", err)
}
if err := bootstrapIfNeeded(cfg, *configPath, pickEnrollToken(*enrollToken)); err != nil {
log.Fatalf("bootstrap error: %v", err)
}
apiClient, err := client.New(cfg)
if err != nil {
log.Fatalf("client error: %v", err)
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
interval := time.Duration(cfg.SyncIntervalSeconds) * time.Second
log.Printf("keywarden-agent started: server_id=%s interval=%s", cfg.ServerID, interval)
ticker := time.NewTicker(interval)
defer ticker.Stop()
runOnce(ctx, apiClient, cfg)
for {
select {
case <-ctx.Done():
log.Printf("shutdown requested")
return
case <-ticker.C:
runOnce(ctx, apiClient, cfg)
}
}
}
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
if err := reportHost(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("host update deferred; will retry: %v", err)
} else {
log.Printf("host update error: %v", err)
}
}
if err := apiClient.SyncAccounts(ctx, cfg); err != nil {
log.Printf("sync accounts error: %v", err)
}
if err := shipLogs(ctx, apiClient, cfg); err != nil {
if client.IsRetriable(err) {
log.Printf("log shipping deferred; will retry: %v", err)
} else {
log.Printf("log shipping error: %v", err)
}
}
}
func ensureDirs(cfg *config.Config) error {
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
return err
}
if err := os.MkdirAll(cfg.LogSpoolDir(), 0o700); err != nil {
return err
}
return nil
}
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
send := func(payload []byte) error {
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
})
}
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
return err
}
cursor, err := logs.ReadCursor(cfg.LogCursorPath())
if err != nil {
return err
}
collector := logs.NewCollector()
events, nextCursor, err := collector.Collect(ctx, cursor, cfg.LogBatchSize)
if err != nil {
return err
}
if len(events) == 0 {
return nil
}
payload, err := json.Marshal(events)
if err != nil {
return err
}
if err := send(payload); err != nil {
if spoolErr := logs.SaveSpool(cfg.LogSpoolDir(), payload); spoolErr != nil {
return spoolErr
}
return err
}
if err := logs.WriteCursor(cfg.LogCursorPath(), nextCursor); err != nil {
return err
}
return nil
}
func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
info := host.Detect()
var pingPtr *int
if pingMs, err := apiClient.Ping(ctx); err == nil {
pingPtr = &pingMs
}
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
return apiClient.UpdateHost(ctx, cfg.ServerID, client.HeartbeatRequest{
Host: info.Hostname,
IPv4: info.IPv4,
IPv6: info.IPv6,
PingMs: pingPtr,
})
})
}
func pickServerURL(flagValue string) string {
if flagValue != "" {
return flagValue
}
return os.Getenv("KEYWARDEN_SERVER_URL")
}
func pickEnrollToken(flagValue string) string {
if flagValue != "" {
return flagValue
}
return os.Getenv("KEYWARDEN_ENROLL_TOKEN")
}
func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string) error {
if cfg.ServerID != "" && fileExists(cfg.ClientCertPath()) && fileExists(cfg.CACertPath()) {
return nil
}
if enrollToken == "" {
return fmt.Errorf("missing enrollment token; set KEYWARDEN_ENROLL_TOKEN or -enroll-token")
}
keyPath := cfg.ClientKeyPath()
if !fileExists(keyPath) {
if err := generateKey(keyPath); err != nil {
return err
}
}
csrPEM, err := buildCSR(keyPath)
if err != nil {
return err
}
info := host.Detect()
hostname := info.Hostname
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
Token: enrollToken,
CSRPEM: csrPEM,
Host: hostname,
IPv4: info.IPv4,
IPv6: info.IPv6,
})
if err != nil {
return err
}
if err := os.WriteFile(cfg.ClientCertPath(), []byte(resp.ClientCert), 0o600); err != nil {
return err
}
if err := os.WriteFile(cfg.CACertPath(), []byte(resp.CACert), 0o600); err != nil {
return err
}
cfg.ServerID = resp.ServerID
if err := config.Save(configPath, cfg); err != nil {
return err
}
return nil
}
func retry(ctx context.Context, delays []time.Duration, fn func() error) error {
var lastErr error
for attempt := 0; attempt <= len(delays); attempt++ {
if attempt > 0 {
if !client.IsRetriable(lastErr) {
return lastErr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delays[attempt-1]):
}
}
if err := fn(); err != nil {
lastErr = err
continue
}
return nil
}
return lastErr
}
func generateKey(path string) error {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
keyDER := x509.MarshalPKCS1PrivateKey(key)
block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}
data := pem.EncodeToMemory(block)
return os.WriteFile(path, data, 0o600)
}
func buildCSR(keyPath string) (string, error) {
keyData, err := os.ReadFile(keyPath)
if err != nil {
return "", err
}
block, _ := pem.Decode(keyData)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return "", fmt.Errorf("invalid private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", err
}
csrTemplate := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "keywarden-agent"}}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
if err != nil {
return "", err
}
csrBlock := &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER}
return string(pem.EncodeToMemory(csrBlock)), nil
}
func fileExists(path string) bool {
info, err := os.Stat(path)
if err != nil {
return false
}
return !info.IsDir()
}

15
agent/config.example.json Normal file
View File

@@ -0,0 +1,15 @@
{
"server_url": "https://keywarden.dev.ntbx.io/api/v1",
"server_id": "4",
"server_ca_path": "",
"sync_interval_seconds": 5,
"log_batch_size": 500,
"state_dir": "/var/lib/keywarden-agent",
"account_policy": {
"username_template": "{{username}}_{{user_id}}",
"default_shell": "/bin/bash",
"admin_group": "sudo",
"create_home": true,
"lock_on_revoke": true
}
}

7
agent/go.mod Normal file
View File

@@ -0,0 +1,7 @@
module keywarden/agent
go 1.22
require (
github.com/coreos/go-systemd/v22 v22.5.0
)

3
agent/go.sum Normal file
View File

@@ -0,0 +1,3 @@
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=

View File

@@ -0,0 +1,496 @@
package accounts
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"keywarden/agent/internal/config"
)
const (
stateFileName = "accounts.json"
maxUsernameLen = 32
passwdFilePath = "/etc/passwd"
groupFilePath = "/etc/group"
sshDirName = ".ssh"
authKeysName = "authorized_keys"
keywardenGroup = "keywarden"
userCAPath = "/etc/ssh/keywarden_user_ca.pub"
sshdConfigDropDir = "/etc/ssh/sshd_config.d"
sshdConfigDropIn = "/etc/ssh/sshd_config.d/keywarden.conf"
sshdConfigPath = "/etc/ssh/sshd_config"
)
type AccessUser struct {
UserID int
Username string
Email string
SystemUsername string
}
type ReportAccount struct {
UserID int `json:"user_id"`
SystemUser string `json:"system_username"`
Present bool `json:"present"`
}
type Result struct {
Applied int
Revoked int
Accounts []ReportAccount
}
type managedAccount struct {
UserID int `json:"user_id"`
SystemUser string `json:"system_username"`
}
type state struct {
Users map[string]managedAccount `json:"users"`
}
type passwdEntry struct {
UID int
GID int
Home string
}
func Sync(policy config.AccountPolicy, stateDir string, users []AccessUser) (Result, error) {
result := Result{}
statePath := filepath.Join(stateDir, stateFileName)
current, err := loadState(statePath)
if err != nil {
return result, err
}
desired := make(map[int]managedAccount, len(users))
for _, user := range users {
systemUser := user.SystemUsername
if strings.TrimSpace(systemUser) == "" {
systemUser = renderUsername(policy.UsernameTemplate, user.Username, user.UserID)
}
desired[user.UserID] = managedAccount{UserID: user.UserID, SystemUser: systemUser}
}
var syncErr error
if err := ensureGroup(keywardenGroup); err != nil && syncErr == nil {
syncErr = err
}
for _, account := range current.Users {
if _, ok := desired[account.UserID]; ok {
continue
}
if err := revokeUser(account.SystemUser, policy); err != nil && syncErr == nil {
syncErr = err
}
result.Revoked++
}
for userID, account := range desired {
present, err := ensureAccount(account.SystemUser, policy)
if err != nil && syncErr == nil {
syncErr = err
}
if present {
result.Applied++
}
result.Accounts = append(result.Accounts, ReportAccount{
UserID: userID,
SystemUser: account.SystemUser,
Present: present,
})
}
if err := saveState(statePath, desired); err != nil && syncErr == nil {
syncErr = err
}
return result, syncErr
}
func loadState(path string) (state, error) {
st := state{Users: map[string]managedAccount{}}
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return st, nil
}
return st, fmt.Errorf("read state: %w", err)
}
if err := json.Unmarshal(data, &st); err != nil {
return st, fmt.Errorf("parse state: %w", err)
}
if st.Users == nil {
st.Users = map[string]managedAccount{}
}
return st, nil
}
func saveState(path string, desired map[int]managedAccount) error {
st := state{Users: map[string]managedAccount{}}
for id, account := range desired {
st.Users[strconv.Itoa(id)] = account
}
data, err := json.MarshalIndent(st, "", " ")
if err != nil {
return fmt.Errorf("encode state: %w", err)
}
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("create state dir: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write state: %w", err)
}
return nil
}
func renderUsername(template string, username string, userID int) string {
raw := strings.ReplaceAll(template, "{{username}}", username)
raw = strings.ReplaceAll(raw, "{{user_id}}", strconv.Itoa(userID))
clean := sanitizeUsername(raw)
if len(clean) > maxUsernameLen {
clean = clean[:maxUsernameLen]
}
if clean == "" {
clean = fmt.Sprintf("kw_%d", userID)
}
return clean
}
func sanitizeUsername(raw string) string {
raw = strings.ToLower(raw)
var b strings.Builder
b.Grow(len(raw))
for _, r := range raw {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
b.WriteRune(r)
continue
}
b.WriteByte('_')
}
out := strings.Trim(b.String(), "-_")
if out == "" {
return ""
}
if strings.HasPrefix(out, "-") {
return "kw" + out
}
return out
}
func userExists(username string) (bool, error) {
cmd := exec.Command("id", "-u", username)
if err := cmd.Run(); err != nil {
if _, ok := err.(*exec.ExitError); ok {
return false, nil
}
return false, err
}
return true, nil
}
func ensureAccount(username string, policy config.AccountPolicy) (bool, error) {
exists, err := userExists(username)
if err != nil {
return false, err
}
if !exists {
if err := createUser(username, policy); err != nil {
return false, err
}
}
if err := ensureGroupMembership(username, keywardenGroup); err != nil {
return true, err
}
if err := enforceCertificateOnly(username, policy); err != nil {
return true, err
}
if err := writeAuthorizedKeys(username, nil); err != nil {
return true, err
}
return true, nil
}
func createUser(username string, policy config.AccountPolicy) error {
args := []string{"-U", "-G", keywardenGroup}
if policy.CreateHome {
args = append(args, "-m")
} else {
args = append(args, "-M")
}
if policy.DefaultShell != "" {
args = append(args, "-s", policy.DefaultShell)
}
args = append(args, username)
cmd := exec.Command("useradd", args...)
if err := cmd.Run(); err != nil {
return fmt.Errorf("useradd %s: %w", username, err)
}
return nil
}
func enforceCertificateOnly(username string, policy config.AccountPolicy) error {
cmd := exec.Command("usermod", "-L", username)
if err := cmd.Run(); err != nil {
return fmt.Errorf("lock account %s: %w", username, err)
}
if policy.DefaultShell != "" {
shellCmd := exec.Command("usermod", "-s", policy.DefaultShell, username)
if err := shellCmd.Run(); err != nil {
return fmt.Errorf("set shell %s: %w", username, err)
}
}
expiryCmd := exec.Command("chage", "-E", "-1", username)
if err := expiryCmd.Run(); err != nil {
return fmt.Errorf("clear expiry %s: %w", username, err)
}
return nil
}
func revokeUser(username string, policy config.AccountPolicy) error {
exists, err := userExists(username)
if err != nil {
return err
}
if !exists {
return nil
}
var revokeErr error
if policy.LockOnRevoke {
if err := disableAccount(username); err != nil {
revokeErr = err
}
}
if err := writeAuthorizedKeys(username, nil); err != nil && revokeErr == nil {
revokeErr = err
}
return revokeErr
}
func disableAccount(username string) error {
cmd := exec.Command("usermod", "-L", username)
if err := cmd.Run(); err != nil {
return fmt.Errorf("lock account %s: %w", username, err)
}
expiryCmd := exec.Command("chage", "-E", "0", username)
if err := expiryCmd.Run(); err != nil {
return fmt.Errorf("expire account %s: %w", username, err)
}
return nil
}
func ensureGroup(name string) error {
exists, err := groupExists(name)
if err != nil {
return err
}
if exists {
return nil
}
cmd := exec.Command("groupadd", name)
if err := cmd.Run(); err != nil {
return fmt.Errorf("groupadd %s: %w", name, err)
}
return nil
}
func groupExists(name string) (bool, error) {
file, err := os.Open(groupFilePath)
if err != nil {
return false, fmt.Errorf("open group file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.SplitN(line, ":", 4)
if len(fields) < 1 {
continue
}
if fields[0] == name {
return true, nil
}
}
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("scan group file: %w", err)
}
return false, nil
}
func ensureGroupMembership(username string, group string) error {
cmd := exec.Command("usermod", "-a", "-G", group, username)
if err := cmd.Run(); err != nil {
return fmt.Errorf("usermod add %s to %s: %w", username, group, err)
}
return nil
}
func EnsureCA(publicKey string) error {
key := strings.TrimSpace(publicKey)
if key == "" {
return errors.New("user CA public key required")
}
changed, err := writeCAKeyIfChanged(key)
if err != nil {
return err
}
configChanged, err := ensureSSHDConfig()
if err != nil {
return err
}
if changed || configChanged {
if err := reloadSSHD(); err != nil {
return err
}
}
return nil
}
func writeCAKeyIfChanged(key string) (bool, error) {
if data, err := os.ReadFile(userCAPath); err == nil {
if strings.TrimSpace(string(data)) == key {
return false, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return false, fmt.Errorf("read user CA key: %w", err)
}
if err := os.WriteFile(userCAPath, []byte(key+"\n"), 0o644); err != nil {
return false, fmt.Errorf("write user CA key: %w", err)
}
return true, nil
}
func ensureSSHDConfig() (bool, error) {
content := fmt.Sprintf(
"TrustedUserCAKeys %s\nMatch Group %s\n AuthorizedKeysFile none\n PasswordAuthentication no\n ChallengeResponseAuthentication no\n",
userCAPath,
keywardenGroup,
)
if info, err := os.Stat(sshdConfigDropDir); err == nil && info.IsDir() {
if existing, err := os.ReadFile(sshdConfigDropIn); err == nil {
if string(existing) == content {
return false, nil
}
}
if err := os.WriteFile(sshdConfigDropIn, []byte(content), 0o644); err != nil {
return false, fmt.Errorf("write sshd drop-in: %w", err)
}
return true, nil
}
data, err := os.ReadFile(sshdConfigPath)
if err != nil {
return false, fmt.Errorf("read sshd config: %w", err)
}
if strings.Contains(string(data), "TrustedUserCAKeys "+userCAPath) {
return false, nil
}
updated := string(data)
if !strings.HasSuffix(updated, "\n") {
updated += "\n"
}
updated += "\n# Keywarden managed users\n" + content
if err := os.WriteFile(sshdConfigPath, []byte(updated), 0o644); err != nil {
return false, fmt.Errorf("write sshd config: %w", err)
}
return true, nil
}
func reloadSSHD() error {
if path, _ := exec.LookPath("systemctl"); path != "" {
if err := exec.Command("systemctl", "reload", "sshd").Run(); err == nil {
return nil
}
if err := exec.Command("systemctl", "reload", "ssh").Run(); err == nil {
return nil
}
}
if path, _ := exec.LookPath("service"); path != "" {
if err := exec.Command("service", "sshd", "reload").Run(); err == nil {
return nil
}
if err := exec.Command("service", "ssh", "reload").Run(); err == nil {
return nil
}
}
return errors.New("unable to reload sshd")
}
func writeAuthorizedKeys(username string, keys []string) error {
entry, err := lookupUser(username)
if err != nil {
return err
}
if entry.Home == "" {
return fmt.Errorf("missing home dir for %s", username)
}
sshDir := filepath.Join(entry.Home, sshDirName)
if err := os.MkdirAll(sshDir, 0o700); err != nil {
return fmt.Errorf("mkdir %s: %w", sshDir, err)
}
if err := os.Chmod(sshDir, 0o700); err != nil {
return fmt.Errorf("chmod %s: %w", sshDir, err)
}
if err := os.Chown(sshDir, entry.UID, entry.GID); err != nil {
return fmt.Errorf("chown %s: %w", sshDir, err)
}
authKeysPath := filepath.Join(sshDir, authKeysName)
payload := strings.TrimSpace(strings.Join(keys, "\n"))
if payload != "" {
payload += "\n"
}
if err := os.WriteFile(authKeysPath, []byte(payload), 0o600); err != nil {
return fmt.Errorf("write %s: %w", authKeysPath, err)
}
if err := os.Chown(authKeysPath, entry.UID, entry.GID); err != nil {
return fmt.Errorf("chown %s: %w", authKeysPath, err)
}
return nil
}
func lookupUser(username string) (passwdEntry, error) {
file, err := os.Open(passwdFilePath)
if err != nil {
return passwdEntry{}, fmt.Errorf("open passwd: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.SplitN(line, ":", 7)
if len(fields) < 7 {
continue
}
if fields[0] != username {
continue
}
uid, err := strconv.Atoi(fields[2])
if err != nil {
return passwdEntry{}, fmt.Errorf("parse uid for %s: %w", username, err)
}
gid, err := strconv.Atoi(fields[3])
if err != nil {
return passwdEntry{}, fmt.Errorf("parse gid for %s: %w", username, err)
}
return passwdEntry{
UID: uid,
GID: gid,
Home: fields[5],
}, nil
}
if err := scanner.Err(); err != nil {
return passwdEntry{}, fmt.Errorf("scan passwd: %w", err)
}
return passwdEntry{}, fmt.Errorf("user %s not found", username)
}

View File

@@ -0,0 +1,382 @@
package client
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"keywarden/agent/internal/accounts"
"keywarden/agent/internal/config"
)
const defaultTimeout = 15 * time.Second
type Client struct {
baseURL string
http *http.Client
tlsCfg *tls.Config
scheme string
host string
addr string
}
func New(cfg *config.Config) (*Client, error) {
baseURL := strings.TrimRight(cfg.ServerURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath())
if err != nil {
return nil, fmt.Errorf("load client cert: %w", err)
}
caPool, err := x509.SystemCertPool()
if err != nil || caPool == nil {
caPool = x509.NewCertPool()
}
if cfg.ServerCAPath != "" {
caData, err := os.ReadFile(cfg.ServerCAPath)
if err != nil {
return nil, fmt.Errorf("read server ca cert: %w", err)
}
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse server ca cert")
}
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caPool,
MinVersion: tls.VersionTLS12,
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
}
httpClient := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
parsed, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("parse server url: %w", err)
}
if parsed.Host == "" {
return nil, errors.New("server url missing host")
}
scheme := parsed.Scheme
if scheme == "" {
scheme = "https"
}
host := parsed.Hostname()
port := parsed.Port()
if port == "" {
if scheme == "http" {
port = "80"
} else {
port = "443"
}
}
addr := net.JoinHostPort(host, port)
return &Client{
baseURL: baseURL,
http: httpClient,
tlsCfg: tlsConfig,
scheme: scheme,
host: host,
addr: addr,
}, nil
}
type EnrollRequest struct {
Token string `json:"token"`
CSRPEM string `json:"csr_pem"`
Host string `json:"host"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
AgentID string `json:"agent_id,omitempty"`
}
type EnrollResponse struct {
ServerID string `json:"server_id"`
ClientCert string `json:"client_cert_pem"`
CACert string `json:"ca_cert_pem"`
SyncProfile string `json:"sync_profile,omitempty"`
DisplayName string `json:"display_name,omitempty"`
}
type AccountKey struct {
PublicKey string `json:"public_key"`
Fingerprint string `json:"fingerprint"`
}
type AccountAccess struct {
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
SystemUsername string `json:"system_username"`
Keys []AccountKey `json:"keys"`
}
type UserCAResponse struct {
PublicKey string `json:"public_key"`
Fingerprint string `json:"fingerprint"`
}
type AccountSyncEntry struct {
UserID int `json:"user_id"`
SystemUsername string `json:"system_username"`
Present bool `json:"present"`
}
type SyncReportRequest struct {
AppliedCount int `json:"applied_count"`
RevokedCount int `json:"revoked_count"`
Message string `json:"message,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Accounts []AccountSyncEntry `json:"accounts,omitempty"`
}
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
baseURL := strings.TrimRight(serverURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("encode enroll request: %w", err)
}
httpClient := &http.Client{Timeout: defaultTimeout}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build enroll request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("enroll request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("enroll failed: status %s", resp.Status)
}
var out EnrollResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode enroll response: %w", err)
}
if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" {
return nil, errors.New("enroll response missing required fields")
}
return &out, nil
}
func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error {
if cfg == nil {
return errors.New("config required for account sync")
}
ca, err := c.FetchUserCA(ctx, cfg.ServerID)
if err != nil {
return err
}
if err := accounts.EnsureCA(ca.PublicKey); err != nil {
return err
}
users, err := c.FetchAccountAccess(ctx, cfg.ServerID)
if err != nil {
return err
}
accessUsers := make([]accounts.AccessUser, 0, len(users))
for _, user := range users {
accessUsers = append(accessUsers, accounts.AccessUser{
UserID: user.UserID,
Username: user.Username,
Email: user.Email,
SystemUsername: user.SystemUsername,
})
}
result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers)
report := SyncReportRequest{
AppliedCount: result.Applied,
RevokedCount: result.Revoked,
Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)),
}
for _, account := range result.Accounts {
report.Accounts = append(report.Accounts, AccountSyncEntry{
UserID: account.UserID,
SystemUsername: account.SystemUser,
Present: account.Present,
})
}
if syncErr != nil {
report.Message = syncErr.Error()
}
if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil {
if syncErr != nil {
return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr)
}
return err
}
return syncErr
}
func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
c.baseURL+"/agent/servers/"+serverID+"/accounts",
nil,
)
if err != nil {
return nil, fmt.Errorf("build account access request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch account access: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
var out []AccountAccess
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode account access: %w", err)
}
return out, nil
}
func (c *Client) FetchUserCA(ctx context.Context, serverID string) (*UserCAResponse, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
c.baseURL+"/agent/servers/"+serverID+"/ssh-ca",
nil,
)
if err != nil {
return nil, fmt.Errorf("build user ca request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch user ca: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
var out UserCAResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode user ca: %w", err)
}
if strings.TrimSpace(out.PublicKey) == "" {
return nil, errors.New("user ca missing public key")
}
return &out, nil
}
func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error {
body, err := json.Marshal(report)
if err != nil {
return fmt.Errorf("encode sync report: %w", err)
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
c.baseURL+"/agent/servers/"+serverID+"/sync-report",
bytes.NewReader(body),
)
if err != nil {
return fmt.Errorf("build sync report: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send sync report: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload))
if err != nil {
return fmt.Errorf("build log request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send log batch: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
type HeartbeatRequest struct {
Host string `json:"host,omitempty"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
PingMs *int `json:"ping_ms,omitempty"`
}
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
body, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("encode host update: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build host update: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send host update: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
func (c *Client) Ping(ctx context.Context) (int, error) {
if c.addr == "" {
return 0, errors.New("server address not configured")
}
start := time.Now()
dialer := &net.Dialer{Timeout: defaultTimeout}
if c.scheme == "http" {
conn, err := dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
return 0, err
}
_ = conn.Close()
return int(time.Since(start).Milliseconds()), nil
}
cfg := c.tlsCfg.Clone()
if cfg.ServerName == "" && c.host != "" {
cfg.ServerName = c.host
}
conn, err := tls.DialWithDialer(dialer, "tcp", c.addr, cfg)
if err != nil {
return 0, err
}
_ = conn.Close()
return int(time.Since(start).Milliseconds()), nil
}

View File

@@ -0,0 +1,36 @@
package client
import (
"context"
"errors"
"net"
)
type HTTPStatusError struct {
StatusCode int
Status string
}
func (e *HTTPStatusError) Error() string {
return "remote status " + e.Status
}
func IsRetriable(err error) bool {
if err == nil {
return false
}
var statusErr *HTTPStatusError
if errors.As(err, &statusErr) {
switch statusErr.StatusCode {
case 404, 408, 429, 500, 502, 503, 504:
return true
default:
return false
}
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr)
}

View File

@@ -0,0 +1,152 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
)
const (
DefaultConfigPath = "/etc/keywarden/agent.json"
DefaultStateDir = "/var/lib/keywarden-agent"
DefaultSyncIntervalSeconds = 30
DefaultLogBatchSize = 500
DefaultUsernameTemplate = "{{username}}_{{user_id}}"
DefaultShell = "/bin/bash"
DefaultAdminGroup = "sudo"
)
type AccountPolicy struct {
UsernameTemplate string `json:"username_template"`
DefaultShell string `json:"default_shell"`
AdminGroup string `json:"admin_group"`
CreateHome bool `json:"create_home"`
LockOnRevoke bool `json:"lock_on_revoke"`
}
type Config struct {
ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"`
ServerCAPath string `json:"server_ca_path,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"`
AccountPolicy AccountPolicy `json:"account_policy,omitempty"`
}
func LoadOrInit(path string, serverURL string) (*Config, error) {
if path == "" {
path = DefaultConfigPath
}
data, err := os.ReadFile(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("read config: %w", err)
}
if serverURL == "" {
return nil, errors.New("server url required for first boot")
}
cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
if err := Save(path, cfg); err != nil {
return nil, err
}
return cfg, nil
}
cfg := &Config{}
if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.ServerCAPath == "" {
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
return cfg, nil
}
func Save(path string, cfg *Config) error {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("encode config: %w", err)
}
if err := os.MkdirAll(dir(path), 0o755); err != nil {
return fmt.Errorf("create config dir: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write config: %w", err)
}
return nil
}
func applyDefaults(cfg *Config) {
if cfg.SyncIntervalSeconds <= 0 {
cfg.SyncIntervalSeconds = DefaultSyncIntervalSeconds
}
if cfg.LogBatchSize <= 0 {
cfg.LogBatchSize = DefaultLogBatchSize
}
if cfg.StateDir == "" {
cfg.StateDir = DefaultStateDir
}
if cfg.AccountPolicy.UsernameTemplate == "" {
cfg.AccountPolicy.UsernameTemplate = DefaultUsernameTemplate
}
if cfg.AccountPolicy.DefaultShell == "" {
cfg.AccountPolicy.DefaultShell = DefaultShell
}
if cfg.AccountPolicy.AdminGroup == "" {
cfg.AccountPolicy.AdminGroup = DefaultAdminGroup
}
}
func validate(cfg *Config, requireServerID bool) error {
var missing []string
if cfg.ServerURL == "" {
missing = append(missing, "server_url")
}
if requireServerID && cfg.ServerID == "" {
missing = append(missing, "server_id")
}
if len(missing) > 0 {
return fmt.Errorf("missing required config fields: %v", missing)
}
if cfg.SyncIntervalSeconds < 5 {
return errors.New("sync_interval_seconds must be >= 5")
}
return nil
}
func (c *Config) ClientCertPath() string {
return c.StateDir + "/agent.crt"
}
func (c *Config) ClientKeyPath() string {
return c.StateDir + "/agent.key"
}
func (c *Config) CACertPath() string {
return c.StateDir + "/ca.crt"
}
func (c *Config) LogCursorPath() string {
return c.StateDir + "/journal.cursor"
}
func (c *Config) LogSpoolDir() string {
return c.StateDir + "/spool"
}
func dir(path string) string {
if idx := strings.LastIndex(path, string(os.PathSeparator)); idx != -1 {
return path[:idx]
}
return "."
}

View File

@@ -0,0 +1,57 @@
package host
import (
"net"
"os"
)
type Info struct {
Hostname string
IPv4 string
IPv6 string
}
func Detect() Info {
hostname, _ := os.Hostname()
info := Info{Hostname: hostname}
ifaces, err := net.Interfaces()
if err != nil {
return info
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
if ip4 := ip.To4(); ip4 != nil {
if info.IPv4 == "" {
info.IPv4 = ip4.String()
}
continue
}
if ip.To16() != nil && info.IPv6 == "" {
info.IPv6 = ip.String()
}
}
if info.IPv4 != "" && info.IPv6 != "" {
break
}
}
return info
}

View File

@@ -0,0 +1,177 @@
package logs
import (
"context"
"strings"
"time"
"github.com/coreos/go-systemd/v22/sdjournal"
)
const defaultLimit = 500
type Collector struct {
matches []string
}
func NewCollector() *Collector {
return &Collector{matches: defaultMatches()}
}
func (c *Collector) Collect(ctx context.Context, cursor string, limit int) ([]Event, string, error) {
if limit <= 0 {
limit = defaultLimit
}
j, err := sdjournal.NewJournal()
if err != nil {
return nil, "", err
}
defer j.Close()
for i, match := range c.matches {
if i > 0 {
if err := j.AddDisjunction(); err != nil {
return nil, "", err
}
}
if err := j.AddMatch(match); err != nil {
return nil, "", err
}
}
if cursor != "" {
if err := j.SeekCursor(cursor); err == nil {
_, _ = j.Next()
}
} else {
_ = j.SeekTail()
_, _ = j.Next()
}
var events []Event
var nextCursor string
for len(events) < limit {
select {
case <-ctx.Done():
return events, nextCursor, ctx.Err()
default:
}
n, err := j.Next()
if err != nil {
return events, nextCursor, err
}
if n == 0 {
break
}
entry, err := j.GetEntry()
if err != nil {
return events, nextCursor, err
}
event := fromEntry(entry)
events = append(events, event)
nextCursor = entry.Cursor
}
return events, nextCursor, nil
}
func defaultMatches() []string {
return []string{
"_SYSTEMD_UNIT=sshd.service",
"_SYSTEMD_UNIT=sudo.service",
"_SYSTEMD_UNIT=systemd-networkd.service",
"_SYSTEMD_UNIT=NetworkManager.service",
"_SYSTEMD_UNIT=systemd-logind.service",
"_TRANSPORT=kernel",
}
}
func fromEntry(entry *sdjournal.JournalEntry) Event {
ts := time.Unix(0, int64(entry.RealtimeTimestamp)*int64(time.Microsecond))
event := NewEvent(ts)
fields := entry.Fields
unit := fields["_SYSTEMD_UNIT"]
message := fields["MESSAGE"]
identifier := fields["SYSLOG_IDENTIFIER"]
event.Unit = unit
event.Message = message
event.Priority = fields["PRIORITY"]
event.Hostname = fields["_HOSTNAME"]
event.Fields = fields
event.Category = categorize(unit, identifier, fields)
event.EventType, event.Username, event.SourceIP, event.SessionID = parseMessage(event.Category, message)
if event.EventType == "" {
event.EventType = defaultEventType(event.Category)
}
return event
}
func categorize(unit string, identifier string, fields map[string]string) string {
switch {
case unit == "sshd.service" || identifier == "sshd":
return "access"
case unit == "sudo.service" || identifier == "sudo":
return "auth"
case unit == "systemd-networkd.service" || identifier == "NetworkManager":
return "network"
case fields["_TRANSPORT"] == "kernel":
return "system"
default:
return "system"
}
}
func defaultEventType(category string) string {
switch category {
case "access":
return "ssh"
case "auth":
return "auth"
case "network":
return "network"
default:
return "system"
}
}
func parseMessage(category string, msg string) (eventType string, username string, sourceIP string, sessionID string) {
if msg == "" {
return "", "", "", ""
}
lower := strings.ToLower(msg)
if category == "access" {
switch {
case strings.Contains(lower, "accepted"):
eventType = "ssh.login.success"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "failed password"):
eventType = "ssh.login.fail"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "session opened"):
eventType = "ssh.session.open"
username = extractBetween(msg, "for user ", " by")
case strings.Contains(lower, "session closed"):
eventType = "ssh.session.close"
username = extractBetween(msg, "for user ", " by")
}
}
return eventType, strings.TrimSpace(username), strings.TrimSpace(sourceIP), strings.TrimSpace(sessionID)
}
func extractBetween(msg string, start string, end string) string {
startIdx := strings.Index(msg, start)
if startIdx == -1 {
return ""
}
startIdx += len(start)
rest := msg[startIdx:]
endIdx := strings.Index(rest, end)
if endIdx == -1 {
return strings.TrimSpace(rest)
}
return strings.TrimSpace(rest[:endIdx])
}

View File

@@ -0,0 +1,24 @@
package logs
import (
"os"
"strings"
)
func ReadCursor(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func WriteCursor(path string, cursor string) error {
if cursor == "" {
return nil
}
return os.WriteFile(path, []byte(cursor+"\n"), 0o600)
}

View File

@@ -0,0 +1,53 @@
package logs
import (
"fmt"
"os"
"path/filepath"
"sort"
"time"
)
func SaveSpool(dir string, payload []byte) error {
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
name := fmt.Sprintf("%d.json", time.Now().UnixNano())
tmp := filepath.Join(dir, name+".tmp")
final := filepath.Join(dir, name)
if err := os.WriteFile(tmp, payload, 0o600); err != nil {
return err
}
return os.Rename(tmp, final)
}
func DrainSpool(dir string, send func([]byte) error) error {
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var files []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
files = append(files, filepath.Join(dir, entry.Name()))
}
sort.Strings(files)
for _, path := range files {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if err := send(data); err != nil {
return err
}
if err := os.Remove(path); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,23 @@
package logs
import "time"
type Event struct {
Timestamp string `json:"timestamp"`
Category string `json:"category"`
EventType string `json:"event_type"`
Unit string `json:"unit,omitempty"`
Priority string `json:"priority,omitempty"`
Hostname string `json:"hostname,omitempty"`
Username string `json:"username,omitempty"`
Principal string `json:"principal,omitempty"`
SourceIP string `json:"source_ip,omitempty"`
SessionID string `json:"session_id,omitempty"`
Message string `json:"message,omitempty"`
Raw string `json:"raw,omitempty"`
Fields map[string]string `json:"fields,omitempty"`
}
func NewEvent(ts time.Time) Event {
return Event{Timestamp: ts.UTC().Format(time.RFC3339Nano)}
}

View File

@@ -0,0 +1,7 @@
package version
var (
Version = "0.0.1-dev"
Commit = ""
BuildDate = ""
)

BIN
agent/keywarden-agent Executable file

Binary file not shown.

View File

@@ -1,19 +1,108 @@
from django.contrib import admin from django.contrib import admin
from django.urls import reverse
from django.utils import timezone
from django.utils.html import format_html
try:
from unfold.contrib.guardian.admin import GuardedModelAdmin
except ImportError: # Fallback for older Unfold builds without guardian admin shim.
from guardian.admin import GuardedModelAdmin as GuardianGuardedModelAdmin
from unfold.admin import ModelAdmin as UnfoldModelAdmin
class GuardedModelAdmin(GuardianGuardedModelAdmin, UnfoldModelAdmin):
pass
from .models import AccessRequest from .models import AccessRequest
@admin.register(AccessRequest) @admin.register(AccessRequest)
class AccessRequestAdmin(admin.ModelAdmin): class AccessRequestAdmin(GuardedModelAdmin):
autocomplete_fields = ("requester", "server", "decided_by")
list_display = ( list_display = (
"id", "id",
"requester", "requester",
"server", "server",
"status", "status",
"request_shell",
"request_logs",
"request_users",
"requested_at", "requested_at",
"expires_at", "expires_at",
"decided_by", "decided_by",
"delete_link",
) )
list_filter = ("status", "server") list_filter = ("status", "server")
search_fields = ("requester__username", "requester__email", "server__display_name") search_fields = ("requester__username", "requester__email", "server__display_name")
ordering = ("-requested_at",) ordering = ("-requested_at",)
compressed_fields = True
actions_on_top = True
actions_on_bottom = True
def get_readonly_fields(self, request, obj=None):
readonly = ["requested_at"]
if obj:
readonly.extend(["decided_at", "decided_by"])
return readonly
def get_fieldsets(self, request, obj=None):
if obj is None:
return (
(
"Request",
{
"fields": (
"requester",
"server",
"status",
"reason",
"request_shell",
"request_logs",
"request_users",
"expires_at",
)
},
),
)
return (
(
"Request",
{
"fields": (
"requester",
"server",
"status",
"reason",
"request_shell",
"request_logs",
"request_users",
"expires_at",
)
},
),
(
"Decision",
{
"fields": (
"decided_at",
"decided_by",
)
},
),
)
def save_model(self, request, obj, form, change) -> None:
if obj.status in {
AccessRequest.Status.APPROVED,
AccessRequest.Status.DENIED,
AccessRequest.Status.REVOKED,
AccessRequest.Status.CANCELLED,
}:
if not obj.decided_at:
obj.decided_at = timezone.now()
if not obj.decided_by_id and request.user and request.user.is_authenticated:
obj.decided_by = request.user
super().save_model(request, obj, form, change)
def delete_link(self, obj: AccessRequest):
url = reverse("admin:access_accessrequest_delete", args=[obj.pk])
return format_html('<a class="text-red-600" href="{}">Delete</a>', url)
delete_link.short_description = "Delete"

View File

@@ -5,3 +5,7 @@ class AccessConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "apps.access" name = "apps.access"
verbose_name = "Access Requests" verbose_name = "Access Requests"
def ready(self) -> None:
from . import signals # noqa: F401
return super().ready()

View File

@@ -0,0 +1,37 @@
from django.db import migrations, models
def remove_delete_accessrequest_perm(apps, schema_editor):
Permission = apps.get_model("auth", "Permission")
ContentType = apps.get_model("contenttypes", "ContentType")
try:
content_type = ContentType.objects.get(app_label="access", model="accessrequest")
except ContentType.DoesNotExist:
return
Permission.objects.filter(content_type=content_type, codename="delete_accessrequest").delete()
class Migration(migrations.Migration):
dependencies = [
("access", "0001_initial"),
("auth", "__latest__"),
("contenttypes", "__latest__"),
]
operations = [
migrations.RunPython(remove_delete_accessrequest_perm, migrations.RunPython.noop),
migrations.AlterModelOptions(
name="accessrequest",
options={
"verbose_name": "Access request",
"verbose_name_plural": "Access requests",
"default_permissions": ("add", "view", "change"),
"indexes": [
models.Index(fields=["status", "requested_at"], name="acc_req_status_req_idx"),
models.Index(fields=["server", "status"], name="acc_req_server_status_idx"),
],
"ordering": ["-requested_at"],
},
),
]

View File

@@ -0,0 +1,26 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("access", "0002_remove_delete_permission"),
]
operations = [
migrations.AddField(
model_name="accessrequest",
name="request_shell",
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name="accessrequest",
name="request_logs",
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name="accessrequest",
name="request_users",
field=models.BooleanField(default=False),
),
]

View File

@@ -28,6 +28,9 @@ class AccessRequest(models.Model):
max_length=16, choices=Status.choices, default=Status.PENDING, db_index=True max_length=16, choices=Status.choices, default=Status.PENDING, db_index=True
) )
reason = models.TextField(blank=True) reason = models.TextField(blank=True)
request_shell = models.BooleanField(default=False)
request_logs = models.BooleanField(default=False)
request_users = models.BooleanField(default=False)
requested_at = models.DateTimeField(default=timezone.now, editable=False) requested_at = models.DateTimeField(default=timezone.now, editable=False)
decided_at = models.DateTimeField(null=True, blank=True) decided_at = models.DateTimeField(null=True, blank=True)
expires_at = models.DateTimeField(null=True, blank=True) expires_at = models.DateTimeField(null=True, blank=True)
@@ -42,6 +45,7 @@ class AccessRequest(models.Model):
class Meta: class Meta:
verbose_name = "Access request" verbose_name = "Access request"
verbose_name_plural = "Access requests" verbose_name_plural = "Access requests"
default_permissions = ("add", "view", "change")
indexes = [ indexes = [
models.Index(fields=["status", "requested_at"], name="acc_req_status_req_idx"), models.Index(fields=["status", "requested_at"], name="acc_req_status_req_idx"),
models.Index(fields=["server", "status"], name="acc_req_server_status_idx"), models.Index(fields=["server", "status"], name="acc_req_server_status_idx"),

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from django.db.models import Q
from django.utils import timezone
from guardian.shortcuts import assign_perm, remove_perm
from .models import AccessRequest
def sync_server_view_perm(access_request: AccessRequest) -> None:
if not access_request or not access_request.requester_id or not access_request.server_id:
return
now = timezone.now()
has_valid_access = (
AccessRequest.objects.filter(
requester_id=access_request.requester_id,
server_id=access_request.server_id,
status=AccessRequest.Status.APPROVED,
)
.filter(Q(expires_at__isnull=True) | Q(expires_at__gt=now))
.exists()
)
if has_valid_access:
assign_perm("servers.view_server", access_request.requester, access_request.server)
return
remove_perm("servers.view_server", access_request.requester, access_request.server)

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from django.db.models.signals import post_save
from django.dispatch import receiver
from guardian.shortcuts import assign_perm
from apps.core.rbac import assign_default_object_permissions
from .models import AccessRequest
from .permissions import sync_server_view_perm
@receiver(post_save, sender=AccessRequest)
def assign_access_request_perms(sender, instance: AccessRequest, created: bool, **kwargs) -> None:
if not created:
sync_server_view_perm(instance)
return
if instance.requester_id:
user = instance.requester
for perm in ("access.view_accessrequest", "access.change_accessrequest"):
assign_perm(perm, user, instance)
assign_default_object_permissions(instance)
sync_server_view_perm(instance)

27
app/apps/access/tasks.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from celery import shared_task
from django.db import transaction
from django.utils import timezone
from .models import AccessRequest
from .permissions import sync_server_view_perm
@shared_task
def expire_access_requests() -> int:
now = timezone.now()
expired_qs = AccessRequest.objects.select_related("server", "requester").filter(
status=AccessRequest.Status.APPROVED,
expires_at__isnull=False,
expires_at__lte=now,
)
count = 0
for access_request in expired_qs:
with transaction.atomic():
access_request.status = AccessRequest.Status.EXPIRED
access_request.decided_at = now
access_request.decided_by = None
access_request.save(update_fields=["status", "decided_at", "decided_by"])
sync_server_view_perm(access_request)
count += 1
return count

View File

@@ -1,3 +1,58 @@
from django import forms
from django.contrib import admin from django.contrib import admin
# from django.utils import timezone
# No custom models registered in accounts app. The legacy Account model has been removed. from unfold.admin import ModelAdmin
from .models import ErasureRequest
class ErasureRequestAdminForm(forms.ModelForm):
class Meta:
model = ErasureRequest
fields = "__all__"
def clean(self):
cleaned = super().clean()
status = cleaned.get("status")
decision_reason = (cleaned.get("decision_reason") or "").strip()
if status in {ErasureRequest.Status.DENIED, ErasureRequest.Status.PROCESSED} and not decision_reason:
raise forms.ValidationError("Decision reason is required for denied or processed requests.")
return cleaned
@admin.register(ErasureRequest)
class ErasureRequestAdmin(ModelAdmin):
form = ErasureRequestAdminForm
list_display = ("id", "user", "status", "requested_at", "decided_at", "processed_at")
list_filter = ("status", "requested_at", "processed_at")
search_fields = ("user__username", "user__email")
readonly_fields = ("requested_at", "decided_at", "processed_at", "decided_by", "processed_by")
fieldsets = (
(
"Request",
{
"fields": ("user", "reason", "status", "requested_at"),
},
),
(
"Decision",
{
"fields": ("decision_reason", "decided_by", "decided_at"),
},
),
(
"Processing",
{
"fields": ("processed_by", "processed_at"),
},
),
)
def save_model(self, request, obj, form, change) -> None:
if obj.status == ErasureRequest.Status.PROCESSED:
obj.process(request.user, decision_reason=obj.decision_reason)
return
if obj.status == ErasureRequest.Status.DENIED and not obj.decided_at:
obj.decided_at = timezone.now()
obj.decided_by = request.user
super().save_model(request, obj, form, change)

View File

@@ -0,0 +1,39 @@
from django import forms
class ErasureRequestForm(forms.Form):
reason = forms.CharField(
label="Reason for erasure request",
widget=forms.Textarea(
attrs={
"rows": 4,
"placeholder": "Explain why you are requesting data erasure.",
"class": "block w-full resize-y rounded-lg border border-gray-300 bg-gray-50 p-2.5 text-sm text-gray-900 focus:border-blue-500 focus:ring-blue-500",
}
),
min_length=10,
max_length=2000,
)
class SSHKeyForm(forms.Form):
name = forms.CharField(
label="Key Name",
max_length=128,
widget=forms.TextInput(
attrs={
"placeholder": "Device Name",
"class": "block w-full rounded-lg border border-gray-300 bg-gray-50 p-2.5 text-sm text-gray-900 focus:border-blue-500 focus:ring-blue-500",
}
),
)
public_key = forms.CharField(
label="SSH Public Key",
widget=forms.Textarea(
attrs={
"rows": 4,
"placeholder": "ssh-ed25519 AAAaBBbBcCcc111122223333... user@host",
"class": "block w-full resize-y rounded-lg border border-gray-300 bg-gray-50 p-2.5 text-sm text-gray-900 focus:border-blue-500 focus:ring-blue-500",
}
),
)

View File

@@ -0,0 +1,75 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("accounts", "0005_unique_user_email_index"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="ErasureRequest",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("reason", models.TextField()),
(
"status",
models.CharField(
choices=[("pending", "Pending"), ("denied", "Denied"), ("processed", "Processed")],
db_index=True,
default="pending",
max_length=16,
),
),
("requested_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("decided_at", models.DateTimeField(blank=True, null=True)),
("decision_reason", models.TextField(blank=True)),
("processed_at", models.DateTimeField(blank=True, null=True)),
(
"decided_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="erasure_decisions",
to=settings.AUTH_USER_MODEL,
),
),
(
"processed_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="erasure_processes",
to=settings.AUTH_USER_MODEL,
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="erasure_requests",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "Erasure request",
"verbose_name_plural": "Erasure requests",
"ordering": ["-requested_at"],
},
),
migrations.AddIndex(
model_name="erasurerequest",
index=models.Index(fields=["status", "requested_at"], name="accounts_erasure_status_idx"),
),
migrations.AddIndex(
model_name="erasurerequest",
index=models.Index(fields=["user", "status"], name="accounts_er_user_status_idx"),
),
]

View File

@@ -1,3 +1,126 @@
from django.db import models from __future__ import annotations
#
# Legacy Account model has been removed. This app now contains URLs/views only. import uuid
from django.conf import settings
from django.db import models, transaction
from django.utils import timezone
class ErasureRequest(models.Model):
class Status(models.TextChoices):
PENDING = "pending", "Pending"
DENIED = "denied", "Denied"
PROCESSED = "processed", "Processed"
user = models.ForeignKey(
settings.AUTH_USER_MODEL,
on_delete=models.CASCADE,
related_name="erasure_requests",
)
reason = models.TextField()
status = models.CharField(max_length=16, choices=Status.choices, default=Status.PENDING, db_index=True)
requested_at = models.DateTimeField(default=timezone.now, editable=False)
decided_at = models.DateTimeField(null=True, blank=True)
decided_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="erasure_decisions",
)
decision_reason = models.TextField(blank=True)
processed_at = models.DateTimeField(null=True, blank=True)
processed_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="erasure_processes",
)
class Meta:
verbose_name = "Erasure request"
verbose_name_plural = "Erasure requests"
ordering = ["-requested_at"]
indexes = [
models.Index(fields=["status", "requested_at"], name="accounts_erasure_status_idx"),
models.Index(fields=["user", "status"], name="accounts_er_user_status_idx"),
]
def __str__(self) -> str:
return f"Erasure request #{self.id} ({self.user_id})"
def process(self, admin_user, decision_reason: str = "") -> None:
if self.status == self.Status.PROCESSED:
return
now = timezone.now()
with transaction.atomic():
self._anonymize_user(admin_user, now)
self.status = self.Status.PROCESSED
self.decided_at = now
self.decided_by = admin_user
self.decision_reason = (decision_reason or "").strip()
self.processed_at = now
self.processed_by = admin_user
self.save(
update_fields=[
"status",
"decided_at",
"decided_by",
"decision_reason",
"processed_at",
"processed_by",
]
)
def _anonymize_user(self, admin_user, now) -> None:
from guardian.models import UserObjectPermission
from apps.access.models import AccessRequest
from apps.keys.models import SSHCertificate, SSHKey
user = self.user
token = uuid.uuid4().hex
anonymous_username = f"erased-{token}"
anonymous_email = f"{anonymous_username}@erased.local"
user.username = anonymous_username
user.email = anonymous_email
user.first_name = ""
user.last_name = ""
user.is_active = False
user.is_staff = False
user.is_superuser = False
user.last_login = None
user.set_unusable_password()
user.save(
update_fields=[
"username",
"email",
"first_name",
"last_name",
"is_active",
"is_staff",
"is_superuser",
"last_login",
"password",
]
)
user.groups.clear()
user.user_permissions.clear()
UserObjectPermission.objects.filter(user=user).delete()
SSHKey.objects.filter(user=user, is_active=True).update(is_active=False, revoked_at=now)
SSHCertificate.objects.filter(user=user, is_active=True).update(is_active=False, revoked_at=now)
AccessRequest.objects.filter(requester=user).update(reason="[redacted]")
AccessRequest.objects.filter(
requester=user,
status__in=[AccessRequest.Status.PENDING, AccessRequest.Status.APPROVED],
).update(
status=AccessRequest.Status.REVOKED,
decided_at=now,
decided_by=admin_user,
expires_at=now,
)

View File

@@ -4,35 +4,55 @@
{% block content %} {% block content %}
<div class="mx-auto max-w-md"> <div class="mx-auto max-w-md">
<div class="rounded-xl border border-gray-200 bg-white p-6 shadow-sm sm:p-8"> <div class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm sm:p-8">
<h1 class="mb-6 text-xl font-semibold tracking-tight text-gray-900">Sign in</h1> <div class="space-y-2">
<form method="post" class="space-y-4"> <h1 class="text-2xl font-semibold tracking-tight text-gray-900">Welcome back</h1>
<p class="text-sm text-gray-500">Sign in to manage server access and certificates.</p>
</div>
<form method="post" class="mt-6 space-y-5">
{% csrf_token %} {% csrf_token %}
<input type="hidden" name="next" value="{% url 'accounts:profile' %}"> <input type="hidden" name="next" value="{% url 'servers:dashboard' %}">
<div class="space-y-1.5"> <div>
<label class="block text-sm font-medium text-gray-700">Username</label> <label class="mb-2 block text-sm font-medium text-gray-900">Username</label>
<input type="text" name="username" autocomplete="username" required class="block w-full rounded-md border-gray-300 shadow-sm focus:border-purple-600 focus:ring-purple-600"> <input
type="text"
name="username"
autocomplete="username"
required
class="block w-full rounded-lg border border-gray-300 bg-gray-50 p-2.5 text-sm text-gray-900 focus:border-blue-500 focus:ring-blue-500"
>
</div> </div>
<div class="space-y-1.5"> <div>
<label class="block text-sm font-medium text-gray-700">Password</label> <label class="mb-2 block text-sm font-medium text-gray-900">Password</label>
<input type="password" name="password" autocomplete="current-password" required class="block w-full rounded-md border-gray-300 shadow-sm focus:border-purple-600 focus:ring-purple-600"> <input
type="password"
name="password"
autocomplete="current-password"
required
class="block w-full rounded-lg border border-gray-300 bg-gray-50 p-2.5 text-sm text-gray-900 focus:border-blue-500 focus:ring-blue-500"
>
</div> </div>
{% if form.errors %} {% if form.errors %}
<p class="text-sm text-red-600">Please check your username and password.</p> <div class="flex items-center gap-2 rounded-lg bg-red-50 p-3 text-sm text-red-800" role="alert">
<span class="font-medium">Sign-in failed.</span>
<span>Please check your username and password.</span>
</div>
{% endif %} {% endif %}
<div class="pt-2"> <button
<button type="submit" class="inline-flex w-full items-center justify-center rounded-md bg-purple-600 px-4 py-2.5 text-sm font-semibold text-white shadow hover:bg-purple-700 focus:outline-none focus-visible:ring-2 focus-visible:ring-purple-600"> type="submit"
Sign in class="inline-flex w-full items-center justify-center rounded-lg bg-blue-700 px-5 py-2.5 text-sm font-semibold text-white shadow-sm hover:bg-blue-800 focus:outline-none focus:ring-4 focus:ring-blue-300"
</button> >
</div> Sign in
</button>
</form> </form>
<div class="mt-6 border-t border-gray-200 pt-6"> <div class="mt-6 border-t border-gray-200 pt-6">
<p class="text-sm text-gray-600"> <p class="text-sm text-gray-600">
Or, if configured, use Or, if configured, use
<a href="/oidc/authenticate/" class="font-medium text-purple-700 hover:text-purple-800">OIDC login</a>. <a href="/oidc/authenticate/" class="font-medium text-blue-700 hover:underline">OIDC login</a>.
</p> </p>
</div> </div>
</div> </div>
</div> </div>
{% endblock %} {% endblock %}

View File

@@ -3,47 +3,281 @@
{% block title %}Profile • Keywarden{% endblock %} {% block title %}Profile • Keywarden{% endblock %}
{% block content %} {% block content %}
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2"> <div class="space-y-6">
<div> <div class="grid gap-6 lg:grid-cols-3">
<div class="rounded-xl border border-gray-200 bg-white p-6 shadow-sm sm:p-8"> <section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm lg:col-span-2">
<h1 class="mb-6 text-xl font-semibold tracking-tight text-gray-900">Your Profile</h1> <div class="space-y-2">
<dl class="grid grid-cols-1 gap-x-6 gap-y-4 sm:grid-cols-2"> <h1 class="text-xl font-semibold tracking-tight text-gray-900">Your Profile</h1>
<div> <p class="text-sm text-gray-500">Account details and contact information.</p>
<dt class="text-sm font-medium text-gray-500">Username</dt> </div>
<dd class="mt-1 text-sm text-gray-900">{{ user.username }}</dd> <dl class="mt-6 grid grid-cols-1 gap-4 text-sm text-gray-600 sm:grid-cols-2">
<div class="rounded-lg border border-gray-100 bg-gray-50 p-4">
<dt class="text-xs font-semibold uppercase tracking-wide text-gray-500">Username</dt>
<dd class="mt-2 text-sm font-medium text-gray-900">{{ user.username }}</dd>
</div> </div>
<div> <div class="rounded-lg border border-gray-100 bg-gray-50 p-4">
<dt class="text-sm font-medium text-gray-500">Email</dt> <dt class="text-xs font-semibold uppercase tracking-wide text-gray-500">Email</dt>
<dd class="mt-1 text-sm text-gray-900">{{ user.email }}</dd> <dd class="mt-2 text-sm font-medium text-gray-900">{{ user.email }}</dd>
</div> </div>
<div> <div class="rounded-lg border border-gray-100 bg-gray-50 p-4">
<dt class="text-sm font-medium text-gray-500">First name</dt> <dt class="text-xs font-semibold uppercase tracking-wide text-gray-500">First name</dt>
<dd class="mt-1 text-sm text-gray-900">{{ user.first_name|default:"—" }}</dd> <dd class="mt-2 text-sm font-medium text-gray-900">{{ user.first_name|default:"—" }}</dd>
</div> </div>
<div> <div class="rounded-lg border border-gray-100 bg-gray-50 p-4">
<dt class="text-sm font-medium text-gray-500">Last name</dt> <dt class="text-xs font-semibold uppercase tracking-wide text-gray-500">Last name</dt>
<dd class="mt-1 text-sm text-gray-900">{{ user.last_name|default:"—" }}</dd> <dd class="mt-2 text-sm font-medium text-gray-900">{{ user.last_name|default:"—" }}</dd>
</div> </div>
</dl> </dl>
</div> </section>
</div>
<div>
<div class="rounded-xl border border-gray-200 bg-white p-6 shadow-sm sm:p-8">
<h2 class="mb-4 text-base font-semibold tracking-tight text-gray-900">Single Sign-On</h2>
{% if auth_mode == "hybrid" %}
<div class="mt-6 border-t border-gray-200 pt-6">
<p class="text-sm text-gray-600">
Optional: Link your account with your identity provider for single sign-on.
<a href="/oidc/authenticate/" class="font-medium text-purple-700 hover:text-purple-800">Link with SSO</a>
</p>
</div>
{% elif auth_mode == "oidc" %}
<p class="text-sm text-gray-600">OIDC is required. Sign-in is managed by your identity provider.</p>
{% else %}
<p class="text-sm text-gray-600">OIDC is disabled. You are using native authentication.</p>
{% endif %}
</div>
</div>
</div>
{% endblock %}
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="space-y-2">
<h2 class="text-base font-semibold text-gray-900">Single Sign-On</h2>
<p class="text-sm text-gray-500">Manage how you authenticate with external providers.</p>
</div>
<div class="mt-4 rounded-xl border border-dashed border-gray-200 bg-gray-50 p-4 text-sm text-gray-600">
{% if auth_mode == "hybrid" %}
Optional: Link your account with your identity provider for single sign-on.
<a href="/oidc/authenticate/" class="font-semibold text-blue-700 hover:underline">Link with SSO</a>
{% elif auth_mode == "oidc" %}
OIDC is required. Sign-in is managed by your identity provider.
{% else %}
OIDC is disabled. You are using native authentication.
{% endif %}
</div>
</section>
</div>
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-base font-semibold text-gray-900">SSH certificates</h2>
<p class="mt-1 text-sm text-gray-500">
Upload your SSH public key to receive a signed certificate for server access.
</p>
</div>
<span class="inline-flex items-center rounded-full bg-blue-50 px-2.5 py-1 text-xs font-semibold text-blue-700">Certificates</span>
</div>
{% if can_add_key %}
<form method="post" class="mt-6 grid gap-4 lg:grid-cols-2">
{% csrf_token %}
<input type="hidden" name="form_type" value="ssh_key">
<div>
<label for="{{ key_form.name.id_for_label }}" class="mb-2 block text-sm font-medium text-gray-900">
Key name
</label>
{{ key_form.name }}
{% if key_form.name.errors %}
<p class="mt-1 text-sm text-red-600">{{ key_form.name.errors|striptags }}</p>
{% endif %}
</div>
<div class="lg:col-span-2">
<label for="{{ key_form.public_key.id_for_label }}" class="mb-2 block text-sm font-medium text-gray-900">
SSH public key
</label>
{{ key_form.public_key }}
{% if key_form.public_key.errors %}
<p class="mt-1 text-sm text-red-600">{{ key_form.public_key.errors|striptags }}</p>
{% endif %}
</div>
{% if key_form.non_field_errors %}
<p class="text-sm text-red-600">{{ key_form.non_field_errors|striptags }}</p>
{% endif %}
<div>
<button
type="submit"
class="inline-flex items-center rounded-lg bg-blue-700 px-4 py-2 text-sm font-semibold text-white shadow-sm hover:bg-blue-800 focus:outline-none focus:ring-4 focus:ring-blue-300"
>
Upload key
</button>
</div>
</form>
{% else %}
<p class="mt-4 text-sm text-gray-600">You do not have permission to add SSH keys.</p>
{% endif %}
{% if ssh_keys %}
<div class="mt-6 overflow-hidden rounded-xl border border-gray-200">
<table class="w-full text-left text-sm text-gray-500">
<thead class="bg-gray-50 text-xs uppercase text-gray-500">
<tr>
<th scope="col" class="px-6 py-3">Key</th>
<th scope="col" class="px-6 py-3">Fingerprint</th>
<th scope="col" class="px-6 py-3 text-right">Actions</th>
</tr>
</thead>
<tbody>
{% for key in ssh_keys %}
<tr class="border-t bg-white">
<th scope="row" class="px-6 py-4 font-medium text-gray-900">
{{ key.name }}
</th>
<td class="px-6 py-4 text-xs text-gray-500">{{ key.fingerprint }}</td>
<td class="px-6 py-4">
<div class="flex flex-wrap items-center justify-end gap-2">
{% if key.is_active %}
<span class="inline-flex items-center rounded-full bg-emerald-100 px-2.5 py-0.5 text-xs font-medium text-emerald-800">Active</span>
<div class="inline-flex rounded-lg shadow-sm" role="group">
<button
type="button"
class="inline-flex items-center rounded-l-lg bg-blue-700 px-3 py-1.5 text-xs font-semibold text-white hover:bg-blue-800 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ key.id }}/certificate"
>
Download
</button>
<button
type="button"
class="inline-flex items-center rounded-r-lg border border-gray-200 bg-white px-3 py-1.5 text-xs font-semibold text-gray-700 hover:bg-gray-50 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ key.id }}/certificate.sha256"
>
Hash
</button>
</div>
<button
type="button"
class="inline-flex items-center rounded-lg bg-rose-600 px-3 py-1.5 text-xs font-semibold text-white hover:bg-rose-700 focus:outline-none focus:ring-2 focus:ring-rose-300 js-regenerate-cert"
data-key-id="{{ key.id }}"
>
Regenerate
</button>
{% else %}
<span class="inline-flex items-center rounded-full bg-gray-200 px-2.5 py-0.5 text-xs font-medium text-gray-700">Revoked</span>
{% endif %}
</div>
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
<p class="mt-4 text-sm text-gray-600">No SSH keys uploaded yet.</p>
{% endif %}
</section>
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-base font-semibold text-gray-900">Data erasure request</h2>
<p class="mt-1 text-sm text-gray-500">
Submit a GDPR erasure request to anonymize your account data. An administrator
must review and approve the request before processing.
</p>
</div>
<span class="inline-flex items-center rounded-full bg-gray-100 px-2.5 py-1 text-xs font-semibold text-gray-700">GDPR</span>
</div>
{% if erasure_request %}
<div class="mt-4 rounded-lg border border-gray-200 bg-gray-50 p-4 text-sm text-gray-700">
<div class="flex flex-wrap items-center gap-2">
<span class="text-xs font-semibold uppercase tracking-wide text-gray-500">Status</span>
<span class="inline-flex items-center rounded-full bg-gray-200 px-2.5 py-1 text-xs font-semibold text-gray-700">
{{ erasure_request.status|capfirst }}
</span>
<span class="text-gray-500">Requested {{ erasure_request.requested_at|date:"M j, Y H:i" }}</span>
</div>
{% if erasure_request.decided_at %}
<p class="mt-2 text-gray-600">
Decision {{ erasure_request.decided_at|date:"M j, Y H:i" }}.
{% if erasure_request.decision_reason %}
Reason: {{ erasure_request.decision_reason }}
{% endif %}
</p>
{% endif %}
{% if erasure_request.status == "processed" %}
<p class="mt-2 text-gray-600">
Your account has been anonymized. Access has been revoked and SSH keys disabled.
</p>
{% endif %}
</div>
{% endif %}
{% if not erasure_request or erasure_request.status != "pending" %}
<form method="post" class="mt-6 grid gap-4">
{% csrf_token %}
<input type="hidden" name="form_type" value="erasure">
<div>
<label for="{{ erasure_form.reason.id_for_label }}" class="mb-2 block text-sm font-medium text-gray-900">
Reason for request
</label>
{{ erasure_form.reason }}
{% if erasure_form.reason.errors %}
<p class="mt-1 text-sm text-red-600">{{ erasure_form.reason.errors|striptags }}</p>
{% endif %}
</div>
{% if erasure_form.non_field_errors %}
<p class="text-sm text-red-600">{{ erasure_form.non_field_errors|striptags }}</p>
{% endif %}
<div>
<button
type="submit"
class="inline-flex items-center rounded-lg bg-blue-700 px-4 py-2 text-sm font-semibold text-white shadow-sm hover:bg-blue-800 focus:outline-none focus:ring-4 focus:ring-blue-300"
>
Submit erasure request
</button>
</div>
</form>
{% endif %}
</section>
</div>
<script>
(function () {
function getCookie(name) {
var value = "; " + document.cookie;
var parts = value.split("; " + name + "=");
if (parts.length === 2) {
return parts.pop().split(";").shift();
}
return "";
}
function handleDownload(event) {
var button = event.currentTarget;
var url = button.getAttribute("data-download-url");
if (!url) {
return;
}
window.location.href = url;
}
function handleRegenerate(event) {
var button = event.currentTarget;
var keyId = button.getAttribute("data-key-id");
if (!keyId) {
return;
}
if (!window.confirm("Regenerate the certificate for this key?")) {
return;
}
var csrf = getCookie("csrftoken");
fetch("/api/v1/keys/" + keyId + "/certificate", {
method: "POST",
credentials: "same-origin",
headers: {
"X-CSRFToken": csrf,
},
})
.then(function (response) {
if (!response.ok) {
throw new Error("Certificate regeneration failed.");
}
window.alert("Certificate regenerated.");
})
.catch(function (err) {
window.alert(err.message);
});
}
var downloadButtons = document.querySelectorAll("[data-download-url]");
for (var i = 0; i < downloadButtons.length; i += 1) {
downloadButtons[i].addEventListener("click", handleDownload);
}
var buttons = document.querySelectorAll(".js-regenerate-cert");
for (var j = 0; j < buttons.length; j += 1) {
buttons[j].addEventListener("click", handleRegenerate);
}
})();
</script>
{% endblock %}

View File

@@ -1,16 +1,72 @@
from django.contrib.auth.decorators import login_required
from django.shortcuts import render
from django.conf import settings from django.conf import settings
from django.shortcuts import redirect
from django.contrib.auth import views as auth_views
from django.contrib.auth import logout from django.contrib.auth import logout
from django.contrib.auth import views as auth_views
from django.contrib.auth.decorators import login_required
from django.core.exceptions import ValidationError
from django.db import IntegrityError
from django.shortcuts import redirect, render
from apps.keys.certificates import issue_certificate_for_key
from apps.keys.models import SSHKey
from .forms import ErasureRequestForm, SSHKeyForm
from .models import ErasureRequest
@login_required(login_url="/accounts/login/") @login_required(login_url="/accounts/login/")
def profile(request): def profile(request):
erasure_request = (
ErasureRequest.objects.filter(user=request.user).order_by("-requested_at").first()
)
can_add_key = request.user.has_perm("keys.add_sshkey")
if request.method == "POST":
form_type = request.POST.get("form_type")
if form_type == "ssh_key":
erasure_form = ErasureRequestForm()
key_form = SSHKeyForm(request.POST)
if key_form.is_valid():
if not can_add_key:
key_form.add_error(None, "You do not have permission to add SSH keys.")
else:
name = key_form.cleaned_data["name"].strip()
public_key = key_form.cleaned_data["public_key"].strip()
key = SSHKey(user=request.user, name=name)
try:
key.set_public_key(public_key)
key.save()
issue_certificate_for_key(key, created_by=request.user)
return redirect("accounts:profile")
except ValidationError as exc:
key_form.add_error("public_key", str(exc))
except IntegrityError:
key_form.add_error("public_key", "Key already exists.")
except Exception:
key_form.add_error(None, "Certificate issuance failed.")
else:
key_form = SSHKeyForm()
erasure_form = ErasureRequestForm(request.POST)
if erasure_form.is_valid():
if erasure_request and erasure_request.status == ErasureRequest.Status.PENDING:
erasure_form.add_error(None, "You already have a pending erasure request.")
else:
ErasureRequest.objects.create(
user=request.user,
reason=erasure_form.cleaned_data["reason"].strip(),
)
return redirect("accounts:profile")
else:
erasure_form = ErasureRequestForm()
key_form = SSHKeyForm()
ssh_keys = SSHKey.objects.filter(user=request.user).order_by("-created_at")
context = { context = {
"user": request.user, "user": request.user,
"auth_mode": getattr(settings, "KEYWARDEN_AUTH_MODE", "hybrid"), "auth_mode": getattr(settings, "KEYWARDEN_AUTH_MODE", "hybrid"),
"erasure_request": erasure_request,
"erasure_form": erasure_form,
"key_form": key_form,
"ssh_keys": ssh_keys,
"can_add_key": can_add_key,
} }
return render(request, "accounts/profile.html", context) return render(request, "accounts/profile.html", context)
@@ -26,4 +82,3 @@ def login_view(request):
def logout_view(request): def logout_view(request):
logout(request) logout(request)
return redirect(getattr(settings, "LOGOUT_REDIRECT_URL", "/")) return redirect(getattr(settings, "LOGOUT_REDIRECT_URL", "/"))

View File

@@ -1,17 +1,140 @@
import json
from django import forms
from django.contrib import admin from django.contrib import admin
from unfold.admin import ModelAdmin from unfold.admin import ModelAdmin
from unfold.decorators import action # type: ignore
from .matching import list_api_endpoint_suggestions, list_websocket_endpoint_suggestions
from .models import AuditEventType, AuditLog from .models import AuditEventType, AuditLog
class AuditEventTypeAdminForm(forms.ModelForm):
endpoints_text = forms.CharField(
required=False,
widget=forms.Textarea(
attrs={
"rows": 8,
"placeholder": "/api/v1/servers/\nGET /api/v1/servers/<int:server_id>/\n/ws/servers/*/shell/",
}
),
help_text=(
"One endpoint pattern per line. Supports '*' wildcards and optional METHOD prefixes "
"like 'GET /api/v1/servers/*'."
),
label="Endpoint patterns",
)
ip_whitelist_text = forms.CharField(
required=False,
widget=forms.Textarea(attrs={"rows": 4, "placeholder": "10.0.0.1\n192.168.1.0/24"}),
help_text="One IP address or CIDR range per line.",
label="IP whitelist entries",
)
ip_blacklist_text = forms.CharField(
required=False,
widget=forms.Textarea(attrs={"rows": 4, "placeholder": "203.0.113.10\n198.51.100.0/24"}),
help_text="One IP address or CIDR range per line.",
label="IP blacklist entries",
)
class Meta:
model = AuditEventType
fields = (
"key",
"title",
"description",
"kind",
"default_severity",
"endpoints_text",
"ip_whitelist_enabled",
"ip_whitelist_text",
"ip_blacklist_enabled",
"ip_blacklist_text",
)
class Media:
js = ("audit/eventtype_form.js",)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
instance = kwargs.get("instance") or getattr(self, "instance", None)
if instance and instance.pk:
self.fields["endpoints_text"].initial = "\n".join(instance.endpoints or [])
self.fields["ip_whitelist_text"].initial = "\n".join(instance.ip_whitelist or [])
self.fields["ip_blacklist_text"].initial = "\n".join(instance.ip_blacklist or [])
self.fields["endpoints_text"].widget.attrs["data-api-suggestions"] = json.dumps(
list_api_endpoint_suggestions()
)
self.fields["endpoints_text"].widget.attrs["data-ws-suggestions"] = json.dumps(
list_websocket_endpoint_suggestions()
)
def _lines_to_list(self, value: str) -> list[str]:
results: list[str] = []
for line in (value or "").splitlines():
candidate = line.strip()
if candidate:
results.append(candidate)
return results
def clean_endpoints_text(self) -> str:
value = self.cleaned_data.get("endpoints_text", "")
# Normalize whitespace but keep the raw text for display.
lines = self._lines_to_list(value)
return "\n".join(lines)
def save(self, commit: bool = True):
instance: AuditEventType = super().save(commit=False)
endpoints_text = self.cleaned_data.get("endpoints_text", "")
whitelist_text = self.cleaned_data.get("ip_whitelist_text", "")
blacklist_text = self.cleaned_data.get("ip_blacklist_text", "")
instance.endpoints = self._lines_to_list(endpoints_text)
instance.ip_whitelist = self._lines_to_list(whitelist_text)
instance.ip_blacklist = self._lines_to_list(blacklist_text)
if commit:
instance.save()
return instance
@admin.register(AuditEventType) @admin.register(AuditEventType)
class AuditEventTypeAdmin(ModelAdmin): class AuditEventTypeAdmin(ModelAdmin):
list_display = ("key", "title", "default_severity", "created_at") form = AuditEventTypeAdminForm
search_fields = ("key", "title", "description") list_display = ("key", "title", "kind", "default_severity", "created_at")
list_filter = ("default_severity",) search_fields = ("key", "title", "description", "endpoints")
list_filter = ("kind", "default_severity", "ip_whitelist_enabled", "ip_blacklist_enabled")
ordering = ("key",) ordering = ("key",)
compressed_fields = True compressed_fields = True
fieldsets = (
(
"Event Type",
{
"fields": (
"key",
"title",
"description",
"kind",
"default_severity",
)
},
),
(
"Endpoints",
{
"fields": ("endpoints_text",),
"description": "Only matching endpoints will create audit events.",
},
),
(
"IP Controls",
{
"fields": (
"ip_whitelist_enabled",
"ip_whitelist_text",
"ip_blacklist_enabled",
"ip_blacklist_text",
),
},
),
)
@admin.register(AuditLog) @admin.register(AuditLog)
@@ -87,5 +210,3 @@ class AuditLogAdmin(ModelAdmin):
{"fields": ("metadata",)}, {"fields": ("metadata",)},
), ),
) )

View File

@@ -1,4 +1,5 @@
from django.apps import AppConfig from django.apps import AppConfig
from django.db.models.signals import post_delete, post_save
class AuditConfig(AppConfig): class AuditConfig(AppConfig):
@@ -10,6 +11,10 @@ class AuditConfig(AppConfig):
def ready(self) -> None: def ready(self) -> None:
# Import signal handlers # Import signal handlers
from . import signals # noqa: F401 from . import signals # noqa: F401
from .matching import clear_event_type_cache
from .models import AuditEventType
post_save.connect(clear_event_type_cache, sender=AuditEventType)
post_delete.connect(clear_event_type_cache, sender=AuditEventType)
return super().ready() return super().ready()

231
app/apps/audit/matching.py Normal file
View File

@@ -0,0 +1,231 @@
from __future__ import annotations
import fnmatch
import ipaddress
import re
import time
from dataclasses import dataclass
from typing import Iterable
from django.urls import URLPattern, URLResolver, get_resolver
from .models import AuditEventType
_CACHE_TTL_SECONDS = 15.0
_METHOD_RE = re.compile(r"^(GET|POST|PUT|PATCH|DELETE|OPTIONS|HEAD)\s+(.+)$", re.IGNORECASE)
_REGEX_GROUP_RE = re.compile(r"\(\?P<(?P<name>\w+)>[^)]+\)")
_CONVERTER_RE = re.compile(r"<(?:(?P<converter>[^:>]+):)?(?P<name>[^>]+)>")
@dataclass(frozen=True)
class ParsedEndpointPattern:
method: str | None
pattern: str
def _normalize_path(value: str) -> str:
candidate = (value or "").strip()
if not candidate:
return ""
if "?" in candidate:
candidate = candidate.split("?", 1)[0]
if not candidate.startswith("/"):
candidate = f"/{candidate}"
# Collapse duplicate slashes without being clever.
while "//" in candidate:
candidate = candidate.replace("//", "/")
return candidate
def _strip_regex_anchors(value: str) -> str:
candidate = value.strip()
if candidate.startswith("^"):
candidate = candidate[1:]
if candidate.endswith("$"):
candidate = candidate[:-1]
return candidate
def _placeholder_to_wildcard(value: str) -> str:
candidate = _strip_regex_anchors(value)
candidate = _REGEX_GROUP_RE.sub("*", candidate)
candidate = _CONVERTER_RE.sub("*", candidate)
return candidate
def parse_endpoint_pattern(raw_pattern: str) -> ParsedEndpointPattern | None:
# Parse admin-provided patterns like:
# - "/api/v1/servers/*"
# - "GET /api/v1/servers/<int:server_id>/"
# We normalize both Django-style placeholders and regex routes into
# fnmatch-friendly wildcard patterns.
if not raw_pattern:
return None
raw = raw_pattern.strip()
if not raw:
return None
method: str | None = None
endpoint = raw
match = _METHOD_RE.match(raw)
if match:
method = match.group(1).upper()
endpoint = match.group(2)
endpoint = _normalize_path(_placeholder_to_wildcard(endpoint))
if not endpoint:
return None
return ParsedEndpointPattern(method=method, pattern=endpoint)
def _endpoint_matches_pattern(pattern: ParsedEndpointPattern, method: str, route: str, path: str) -> bool:
if pattern.method and pattern.method != method.upper():
return False
route_norm = _normalize_path(route)
path_norm = _normalize_path(path)
return fnmatch.fnmatch(route_norm, pattern.pattern) or fnmatch.fnmatch(path_norm, pattern.pattern)
def _parse_ip_entry(
entry: str,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | ipaddress.IPv4Network | ipaddress.IPv6Network | None:
raw = (entry or "").strip()
if not raw:
return None
try:
if "/" in raw:
return ipaddress.ip_network(raw, strict=False)
return ipaddress.ip_address(raw)
except ValueError:
return None
def _ip_in_entries(ip: str, entries: Iterable[str]) -> bool:
try:
candidate_ip = ipaddress.ip_address(ip)
except ValueError:
return False
for entry in entries:
parsed = _parse_ip_entry(entry)
if parsed is None:
continue
if isinstance(parsed, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
if candidate_ip in parsed:
return True
elif candidate_ip == parsed:
return True
return False
def ip_allowed_for_event(event_type: AuditEventType, ip: str | None) -> bool:
# Apply whitelist first (default deny when enabled), then blacklist
# (explicit deny). If the IP cannot be determined, we only allow it
# when no whitelist is enforced.
if not ip:
# If we cannot determine the IP, allow by default unless a whitelist is enforced.
return not event_type.ip_whitelist_enabled
if event_type.ip_whitelist_enabled and not _ip_in_entries(ip, event_type.ip_whitelist or []):
return False
if event_type.ip_blacklist_enabled and _ip_in_entries(ip, event_type.ip_blacklist or []):
return False
return True
def endpoint_matches_event(event_type: AuditEventType, method: str, route: str, path: str) -> bool:
# Event types are opt-in: an empty endpoint list never matches.
# We allow either the resolved Django route or the raw path to match
# so patterns can be authored using whichever is more stable.
patterns = event_type.endpoints or []
if not patterns:
return False
for raw_pattern in patterns:
parsed = parse_endpoint_pattern(str(raw_pattern))
if parsed and _endpoint_matches_pattern(parsed, method, route, path):
return True
return False
_EVENT_TYPE_CACHE: dict[str, tuple[float, list[AuditEventType]]] = {}
def clear_event_type_cache(*_args, **_kwargs) -> None:
_EVENT_TYPE_CACHE.clear()
def get_event_types_for_kind(kind: str) -> list[AuditEventType]:
# Cache event-type catalogs briefly to avoid repeated DB hits on
# high-volume request paths. The cache is cleared on save/delete.
now = time.monotonic()
cached = _EVENT_TYPE_CACHE.get(kind)
if cached and (now - cached[0]) < _CACHE_TTL_SECONDS:
return cached[1]
event_types = list(AuditEventType.objects.filter(kind=kind).order_by("key"))
_EVENT_TYPE_CACHE[kind] = (now, event_types)
return event_types
def find_matching_event_type(kind: str, method: str, route: str, path: str, ip: str | None) -> AuditEventType | None:
# Deterministic first-match semantics: the ordered catalog defines
# precedence when multiple event types could match.
for event_type in get_event_types_for_kind(kind):
if not endpoint_matches_event(event_type, method=method, route=route, path=path):
continue
if not ip_allowed_for_event(event_type, ip):
continue
return event_type
return None
def _join_paths(prefix: str, segment: str) -> str:
if not prefix:
return segment
if not segment:
return prefix
return f"{prefix.rstrip('/')}/{segment.lstrip('/')}"
def _walk_urlpatterns(patterns: Iterable[URLPattern | URLResolver], prefix: str = "") -> list[str]:
# Flatten the resolver tree into full route strings so the admin
# UI can offer endpoint suggestions without hardcoding routes.
results: list[str] = []
for pattern in patterns:
segment = str(pattern.pattern)
combined = _join_paths(prefix, segment)
if isinstance(pattern, URLResolver):
results.extend(_walk_urlpatterns(pattern.url_patterns, combined))
else:
results.append(combined)
return results
def _normalize_suggestion(value: str) -> str:
candidate = _strip_regex_anchors(value)
candidate = candidate.replace("\\", "")
candidate = _REGEX_GROUP_RE.sub(lambda m: f"<{m.group('name')}>", candidate)
candidate = _normalize_path(candidate)
return candidate
def list_api_endpoint_suggestions() -> list[str]:
# Introspect the URL resolver and keep only API routes. Suggestions
# are normalized to human-editable patterns (e.g., "<server_id>").
resolver = get_resolver()
raw_patterns = _walk_urlpatterns(resolver.url_patterns)
suggestions: set[str] = set()
for pattern in raw_patterns:
if not pattern:
continue
normalized = _normalize_suggestion(pattern)
if normalized.startswith("/api"):
suggestions.add(normalized)
return sorted(s for s in suggestions if s)
def list_websocket_endpoint_suggestions() -> list[str]:
# WebSocket routes are maintained separately by Channels, so we
# import them directly from the ASGI routing module.
try:
from keywarden.routing import websocket_urlpatterns
except Exception:
return []
raw_patterns = [str(p.pattern) for p in websocket_urlpatterns]
suggestions = {_normalize_suggestion(p) for p in raw_patterns}
return sorted(s for s in suggestions if s)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
import time
from django.utils import timezone
from .matching import find_matching_event_type
from .models import AuditEventType, AuditLog
from .utils import get_client_ip, get_request_id
_SKIP_PREFIXES = ("/api/v1/audit", "/api/v1/user")
_SKIP_SUFFIXES = ("/health", "/health/")
def _is_api_request(path: str) -> bool:
return path == "/api" or path.startswith("/api/")
def _should_log_request(path: str) -> bool:
# Only audit API traffic and skip endpoints that would recursively
# generate noisy audit events (audit endpoints, health checks, etc.).
if not _is_api_request(path):
return False
if path in _SKIP_PREFIXES:
return False
if any(path.startswith(prefix + "/") for prefix in _SKIP_PREFIXES):
return False
if any(path.endswith(suffix) for suffix in _SKIP_SUFFIXES):
return False
return True
def _resolve_route(request, fallback: str) -> str:
match = getattr(request, "resolver_match", None)
route = getattr(match, "route", None) if match else None
if route:
return route if route.startswith("/") else f"/{route}"
return fallback
class ApiAuditLogMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# Fast-exit for non-audited paths before taking timing measurements.
path = request.path_info or request.path
if not _should_log_request(path):
return self.get_response(request)
start = time.monotonic()
try:
response = self.get_response(request)
except Exception as exc:
duration_ms = int((time.monotonic() - start) * 1000)
self._write_log(request, path, 500, duration_ms, error=type(exc).__name__)
raise
duration_ms = int((time.monotonic() - start) * 1000)
self._write_log(request, path, response.status_code, duration_ms)
return response
def _write_log(self, request, path: str, status_code: int, duration_ms: int, error: str | None = None) -> None:
try:
route = _resolve_route(request, path)
client_ip = get_client_ip(request)
# Audit events are explicit: if no configured event type matches,
# we do not create either an event type or a log entry.
event_type = find_matching_event_type(
kind=AuditEventType.Kind.API,
method=request.method,
route=route,
path=path,
ip=client_ip,
)
if event_type is None:
return
user = getattr(request, "user", None)
actor = user if getattr(user, "is_authenticated", False) else None
# Store normalized request context for filtering and forensics.
metadata = {
"method": request.method,
"path": path,
"route": route,
"status_code": status_code,
"duration_ms": duration_ms,
"query_string": request.META.get("QUERY_STRING", ""),
}
if error:
metadata["error"] = error
AuditLog.objects.create(
created_at=timezone.now(),
actor=actor,
event_type=event_type,
message=f"API request {request.method} {route} -> {status_code}",
severity=event_type.default_severity,
source=AuditLog.Source.API,
ip_address=client_ip,
user_agent=request.META.get("HTTP_USER_AGENT", ""),
request_id=get_request_id(request),
metadata=metadata,
)
except Exception:
return

View File

@@ -0,0 +1,69 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("audit", "0002_alter_auditlog_event_type"),
]
operations = [
migrations.AddField(
model_name="auditeventtype",
name="kind",
field=models.CharField(
choices=[("api", "API"), ("websocket", "WebSocket")],
db_index=True,
default="api",
help_text="Whether this event type applies to API or WebSocket traffic.",
max_length=16,
),
),
migrations.AddField(
model_name="auditeventtype",
name="endpoints",
field=models.JSONField(
blank=True,
default=list,
help_text=(
"List of endpoint patterns that should generate this event type. "
"Use one pattern per line in the admin form. Supports '*' wildcards "
"and optional METHOD prefixes like 'GET /api/v1/servers/*'."
),
),
),
migrations.AddField(
model_name="auditeventtype",
name="ip_whitelist_enabled",
field=models.BooleanField(
default=False,
help_text="If enabled, only IPs in the whitelist will generate this event type.",
),
),
migrations.AddField(
model_name="auditeventtype",
name="ip_whitelist",
field=models.JSONField(
blank=True,
default=list,
help_text="List of allowed IP addresses or CIDR ranges. One per line in the admin form.",
),
),
migrations.AddField(
model_name="auditeventtype",
name="ip_blacklist_enabled",
field=models.BooleanField(
default=False,
help_text="If enabled, IPs in the blacklist will be blocked for this event type.",
),
),
migrations.AddField(
model_name="auditeventtype",
name="ip_blacklist",
field=models.JSONField(
blank=True,
default=list,
help_text="List of denied IP addresses or CIDR ranges. One per line in the admin form.",
),
),
]

View File

@@ -13,6 +13,10 @@ class AuditEventType(models.Model):
Useful for consistent naming, severity, and descriptions. Useful for consistent naming, severity, and descriptions.
""" """
class Kind(models.TextChoices):
API = "api", "API"
WEBSOCKET = "websocket", "WebSocket"
class Severity(models.TextChoices): class Severity(models.TextChoices):
INFO = "info", "Info" INFO = "info", "Info"
WARNING = "warning", "Warning" WARNING = "warning", "Warning"
@@ -22,9 +26,43 @@ class AuditEventType(models.Model):
key = models.SlugField(max_length=64, unique=True, help_text="Stable machine key, e.g., user_login") key = models.SlugField(max_length=64, unique=True, help_text="Stable machine key, e.g., user_login")
title = models.CharField(max_length=128, help_text="Human-readable title") title = models.CharField(max_length=128, help_text="Human-readable title")
description = models.TextField(blank=True) description = models.TextField(blank=True)
kind = models.CharField(
max_length=16,
choices=Kind.choices,
default=Kind.API,
db_index=True,
help_text="Whether this event type applies to API or WebSocket traffic.",
)
default_severity = models.CharField( default_severity = models.CharField(
max_length=16, choices=Severity.choices, default=Severity.INFO, db_index=True max_length=16, choices=Severity.choices, default=Severity.INFO, db_index=True
) )
endpoints = models.JSONField(
default=list,
blank=True,
help_text=(
"List of endpoint patterns that should generate this event type. "
"Use one pattern per line in the admin form. Supports '*' wildcards "
"and optional METHOD prefixes like 'GET /api/v1/servers/*'."
),
)
ip_whitelist_enabled = models.BooleanField(
default=False,
help_text="If enabled, only IPs in the whitelist will generate this event type.",
)
ip_whitelist = models.JSONField(
default=list,
blank=True,
help_text="List of allowed IP addresses or CIDR ranges. One per line in the admin form.",
)
ip_blacklist_enabled = models.BooleanField(
default=False,
help_text="If enabled, IPs in the blacklist will be blocked for this event type.",
)
ip_blacklist = models.JSONField(
default=list,
blank=True,
help_text="List of denied IP addresses or CIDR ranges. One per line in the admin form.",
)
created_at = models.DateTimeField(default=timezone.now, editable=False) created_at = models.DateTimeField(default=timezone.now, editable=False)
class Meta: class Meta:
@@ -33,7 +71,7 @@ class AuditEventType(models.Model):
ordering = ["key"] ordering = ["key"]
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.key} ({self.default_severity})" return f"{self.key} [{self.kind}] ({self.default_severity})"
class AuditLog(models.Model): class AuditLog(models.Model):

View File

@@ -6,21 +6,23 @@ from django.dispatch import receiver
from django.utils import timezone from django.utils import timezone
from .models import AuditEventType, AuditLog from .models import AuditEventType, AuditLog
from .utils import get_client_ip
User = get_user_model() User = get_user_model()
def _get_or_create_event(key: str, title: str, severity: str = AuditEventType.Severity.INFO) -> AuditEventType: def _get_event(key: str) -> AuditEventType | None:
event, _ = AuditEventType.objects.get_or_create( try:
key=key, return AuditEventType.objects.get(key=key)
defaults={"title": title, "default_severity": severity}, except AuditEventType.DoesNotExist:
) return None
return event
@receiver(user_logged_in) @receiver(user_logged_in)
def on_user_logged_in(sender, request, user: User, **kwargs): def on_user_logged_in(sender, request, user: User, **kwargs):
event = _get_or_create_event("user_login", "User logged in", AuditEventType.Severity.INFO) event = _get_event("user_login")
if event is None:
return
AuditLog.objects.create( AuditLog.objects.create(
created_at=timezone.now(), created_at=timezone.now(),
actor=user, actor=user,
@@ -28,7 +30,7 @@ def on_user_logged_in(sender, request, user: User, **kwargs):
message=f"User {user} logged in", message=f"User {user} logged in",
severity=event.default_severity, severity=event.default_severity,
source=AuditLog.Source.UI, source=AuditLog.Source.UI,
ip_address=(request.META.get("REMOTE_ADDR") if request else None), ip_address=get_client_ip(request),
user_agent=(request.META.get("HTTP_USER_AGENT") if request else ""), user_agent=(request.META.get("HTTP_USER_AGENT") if request else ""),
metadata={"path": request.path} if request else {}, metadata={"path": request.path} if request else {},
) )
@@ -36,7 +38,9 @@ def on_user_logged_in(sender, request, user: User, **kwargs):
@receiver(user_logged_out) @receiver(user_logged_out)
def on_user_logged_out(sender, request, user: User, **kwargs): def on_user_logged_out(sender, request, user: User, **kwargs):
event = _get_or_create_event("user_logout", "User logged out", AuditEventType.Severity.INFO) event = _get_event("user_logout")
if event is None:
return
AuditLog.objects.create( AuditLog.objects.create(
created_at=timezone.now(), created_at=timezone.now(),
actor=user, actor=user,
@@ -44,9 +48,7 @@ def on_user_logged_out(sender, request, user: User, **kwargs):
message=f"User {user} logged out", message=f"User {user} logged out",
severity=event.default_severity, severity=event.default_severity,
source=AuditLog.Source.UI, source=AuditLog.Source.UI,
ip_address=(request.META.get("REMOTE_ADDR") if request else None), ip_address=get_client_ip(request),
user_agent=(request.META.get("HTTP_USER_AGENT") if request else ""), user_agent=(request.META.get("HTTP_USER_AGENT") if request else ""),
metadata={"path": request.path} if request else {}, metadata={"path": request.path} if request else {},
) )

View File

@@ -0,0 +1,93 @@
(function () {
function parseSuggestions(textarea, key) {
try {
var raw = textarea.dataset[key];
return raw ? JSON.parse(raw) : [];
} catch (err) {
return [];
}
}
function splitLines(value) {
return (value || "")
.split(/\r?\n/)
.map(function (line) {
return line.trim();
})
.filter(function (line) {
return line.length > 0;
});
}
function appendLine(textarea, value) {
var lines = splitLines(textarea.value);
if (lines.indexOf(value) !== -1) {
return;
}
lines.push(value);
textarea.value = lines.join("\n");
textarea.dispatchEvent(new Event("change", { bubbles: true }));
}
document.addEventListener("DOMContentLoaded", function () {
var textarea = document.getElementById("id_endpoints_text");
var kindSelect = document.getElementById("id_kind");
if (!textarea || !kindSelect) {
return;
}
var apiSuggestions = parseSuggestions(textarea, "apiSuggestions");
var wsSuggestions = parseSuggestions(textarea, "wsSuggestions");
var container = document.createElement("div");
container.className = "audit-endpoint-suggestions";
container.style.marginTop = "0.5rem";
var title = document.createElement("div");
title.style.fontWeight = "600";
title.style.marginBottom = "0.25rem";
title.textContent = "Suggested endpoints";
container.appendChild(title);
var list = document.createElement("div");
list.style.display = "flex";
list.style.flexWrap = "wrap";
list.style.gap = "0.25rem";
container.appendChild(list);
textarea.parentNode.insertBefore(container, textarea.nextSibling);
function currentSuggestions() {
return kindSelect.value === "websocket" ? wsSuggestions : apiSuggestions;
}
function renderSuggestions() {
var suggestions = currentSuggestions();
list.innerHTML = "";
if (!suggestions || suggestions.length === 0) {
var empty = document.createElement("span");
empty.textContent = "No endpoint suggestions were found.";
empty.style.opacity = "0.7";
list.appendChild(empty);
return;
}
suggestions.slice(0, 40).forEach(function (suggestion) {
var button = document.createElement("button");
button.type = "button";
button.textContent = suggestion;
button.style.padding = "0.2rem 0.45rem";
button.style.borderRadius = "999px";
button.style.border = "1px solid #d1d5db";
button.style.background = "#f9fafb";
button.style.cursor = "pointer";
button.addEventListener("click", function () {
appendLine(textarea, suggestion);
});
list.appendChild(button);
});
}
kindSelect.addEventListener("change", renderSuggestions);
renderSuggestions();
});
})();

View File

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from django.http import HttpResponse
from django.test import RequestFactory, TestCase
from apps.audit.matching import find_matching_event_type
from apps.audit.middleware import ApiAuditLogMiddleware
from apps.audit.models import AuditEventType, AuditLog
class ApiAuditMiddlewareTests(TestCase):
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.middleware = ApiAuditLogMiddleware(lambda request: HttpResponse("ok"))
def _call(self, method: str, path: str, ip: str = "203.0.113.5") -> None:
request = self.factory.generic(method, path)
request.META["REMOTE_ADDR"] = ip
self.middleware(request)
def test_no_matching_event_type_creates_no_logs_or_event_types(self) -> None:
self._call("GET", "/api/auto/")
self.assertEqual(AuditEventType.objects.count(), 0)
self.assertEqual(AuditLog.objects.count(), 0)
def test_matching_event_type_creates_log(self) -> None:
event_type = AuditEventType.objects.create(
key="api_test",
title="API test",
kind=AuditEventType.Kind.API,
endpoints=["/api/test/"],
)
self._call("GET", "/api/test/")
log = AuditLog.objects.get()
self.assertEqual(log.event_type_id, event_type.id)
self.assertEqual(log.source, AuditLog.Source.API)
self.assertEqual(log.severity, event_type.default_severity)
def test_ip_whitelist_blocks_and_allows(self) -> None:
AuditEventType.objects.create(
key="api_whitelist",
title="API whitelist",
kind=AuditEventType.Kind.API,
endpoints=["/api/whitelist/"],
ip_whitelist_enabled=True,
ip_whitelist=["203.0.113.10"],
)
self._call("GET", "/api/whitelist/", ip="203.0.113.5")
self.assertEqual(AuditLog.objects.count(), 0)
self._call("GET", "/api/whitelist/", ip="203.0.113.10")
self.assertEqual(AuditLog.objects.count(), 1)
def test_ip_blacklist_blocks(self) -> None:
AuditEventType.objects.create(
key="api_blacklist",
title="API blacklist",
kind=AuditEventType.Kind.API,
endpoints=["/api/blacklist/"],
ip_blacklist_enabled=True,
ip_blacklist=["203.0.113.5"],
)
self._call("GET", "/api/blacklist/", ip="203.0.113.5")
self.assertEqual(AuditLog.objects.count(), 0)
class AuditEventMatchingTests(TestCase):
def test_websocket_event_type_can_match(self) -> None:
event_type = AuditEventType.objects.create(
key="ws_shell",
title="WebSocket shell",
kind=AuditEventType.Kind.WEBSOCKET,
endpoints=["/ws/servers/*/shell/"],
)
matched = find_matching_event_type(
kind=AuditEventType.Kind.WEBSOCKET,
method="GET",
route="/ws/servers/123/shell/",
path="/ws/servers/123/shell/",
ip="203.0.113.10",
)
self.assertIsNotNone(matched)
self.assertEqual(matched.id, event_type.id)

92
app/apps/audit/utils.py Normal file
View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import ipaddress
def _normalize_ip(value: str | None) -> str | None:
if not value:
return None
candidate = value.strip()
if not candidate:
return None
if candidate.startswith("[") and "]" in candidate:
candidate = candidate[1 : candidate.index("]")]
elif candidate.count(":") == 1 and candidate.rsplit(":", 1)[1].isdigit():
candidate = candidate.rsplit(":", 1)[0]
try:
return str(ipaddress.ip_address(candidate))
except ValueError:
return None
def get_client_ip(request) -> str | None:
if not request:
return None
x_real_ip = _normalize_ip(request.META.get("HTTP_X_REAL_IP"))
if x_real_ip:
return x_real_ip
forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR", "")
if forwarded_for:
for part in forwarded_for.split(","):
ip = _normalize_ip(part)
if ip:
return ip
return _normalize_ip(request.META.get("REMOTE_ADDR"))
def get_request_id(request) -> str:
if not request:
return ""
return (
request.META.get("HTTP_X_REQUEST_ID")
or request.META.get("HTTP_X_CORRELATION_ID")
or ""
)
def _get_scope_header(scope, header_name: str) -> str | None:
headers = scope.get("headers") if scope else None
if not headers:
return None
target = header_name.lower().encode("latin-1")
for key, value in headers:
if key.lower() == target:
try:
return value.decode("latin-1")
except Exception:
return None
return None
def get_client_ip_from_scope(scope) -> str | None:
if not scope:
return None
x_real_ip = _normalize_ip(_get_scope_header(scope, "x-real-ip"))
if x_real_ip:
return x_real_ip
forwarded_for = _get_scope_header(scope, "x-forwarded-for") or ""
if forwarded_for:
for part in forwarded_for.split(","):
ip = _normalize_ip(part)
if ip:
return ip
client = scope.get("client")
if isinstance(client, (list, tuple)) and client:
return _normalize_ip(str(client[0]))
return None
def get_request_id_from_scope(scope) -> str:
if not scope:
return ""
return (
_get_scope_header(scope, "x-request-id")
or _get_scope_header(scope, "x-correlation-id")
or ""
)
def get_user_agent_from_scope(scope) -> str:
if not scope:
return ""
return _get_scope_header(scope, "user-agent") or ""

21
app/apps/core/apps.py Normal file
View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from django.apps import AppConfig
from django.db.models.signals import post_migrate
class CoreConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "apps.core"
label = "core"
verbose_name = "Core"
def ready(self) -> None:
from .rbac import assign_role_permissions, ensure_role_groups
def _ensure_roles(**_kwargs) -> None:
ensure_role_groups()
assign_role_permissions()
post_migrate.connect(_ensure_roles, dispatch_uid="core_rbac")
return super().ready()

View File

@@ -3,6 +3,8 @@ import os
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from apps.core.rbac import ROLE_ADMIN, set_user_role
class Command(BaseCommand): class Command(BaseCommand):
help = "Ensure a Django superuser exists using environment variables" help = "Ensure a Django superuser exists using environment variables"
@@ -41,6 +43,7 @@ class Command(BaseCommand):
if created: if created:
user.set_password(password) user.set_password(password)
set_user_role(user, ROLE_ADMIN)
user.save() user.save()
self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' created.")) self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' created."))
return return
@@ -59,10 +62,11 @@ class Command(BaseCommand):
user.is_superuser = True user.is_superuser = True
changed = True changed = True
set_user_role(user, ROLE_ADMIN)
if changed: if changed:
user.save() user.save()
self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' updated.")) self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' updated."))
else: else:
self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' already present.")) self.stdout.write(self.style.SUCCESS(f"Superuser '{username}' already present."))

View File

@@ -0,0 +1,50 @@
from django.core.management.base import BaseCommand
from guardian.shortcuts import assign_perm
from apps.access.models import AccessRequest
from apps.core.rbac import assign_default_object_permissions
from apps.keys.models import SSHKey
from apps.servers.models import Server
class Command(BaseCommand):
help = "Backfill guardian object permissions for access requests and SSH keys."
def handle(self, *args, **options):
access_count = 0
for access_request in AccessRequest.objects.select_related("requester"):
if not access_request.requester_id:
assign_default_object_permissions(access_request)
else:
for perm in (
"access.view_accessrequest",
"access.change_accessrequest",
):
assign_perm(perm, access_request.requester, access_request)
assign_default_object_permissions(access_request)
access_count += 1
key_count = 0
for key in SSHKey.objects.select_related("user"):
if not key.user_id:
assign_default_object_permissions(key)
else:
for perm in ("keys.view_sshkey", "keys.change_sshkey", "keys.delete_sshkey"):
assign_perm(perm, key.user, key)
assign_default_object_permissions(key)
key_count += 1
server_count = 0
for server in Server.objects.all():
assign_default_object_permissions(server)
server_count += 1
self.stdout.write(
self.style.SUCCESS(
"Synced object permissions for "
f"{access_count} access requests, "
f"{key_count} SSH keys, "
f"and {server_count} servers."
)
)

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from django.http import HttpRequest, HttpResponse
from .views import disguised_not_found
class DisguiseNotFoundMiddleware:
"""Mask 404 responses with a less-informative alternative."""
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
response = self.get_response(request)
if getattr(response, "status_code", None) != 404:
return response
# Replace all 404 responses, even when DEBUG=True, because Django's
# handler404 is bypassed in debug mode.
return disguised_not_found(request)

136
app/apps/core/rbac.py Normal file
View File

@@ -0,0 +1,136 @@
from __future__ import annotations
from django.contrib.auth.models import Group, Permission
from guardian.shortcuts import assign_perm
from ninja.errors import HttpError
ROLE_ADMIN = "administrator"
ROLE_USER = "user"
ROLE_ORDER = (ROLE_ADMIN, ROLE_USER)
ROLE_ALL = ROLE_ORDER
ROLE_ALIASES = {"admin": ROLE_ADMIN}
ROLE_INPUTS = tuple(sorted(set(ROLE_ORDER) | set(ROLE_ALIASES.keys())))
def _model_perms(app_label: str, model: str, actions: list[str]) -> list[str]:
return [f"{app_label}.{action}_{model}" for action in actions]
ROLE_PERMISSIONS = {
ROLE_ADMIN: [],
ROLE_USER: [
*_model_perms("access", "accessrequest", ["add"]),
*_model_perms("keys", "sshkey", ["add"]),
],
}
OBJECT_PERMISSION_MODELS = {
("servers", "server"),
("access", "accessrequest"),
("keys", "sshkey"),
}
def normalize_role(role: str) -> str:
normalized = (role or "").strip().lower()
return ROLE_ALIASES.get(normalized, normalized)
def ensure_role_groups() -> None:
for role in ROLE_ORDER:
Group.objects.get_or_create(name=role)
def assign_role_permissions() -> None:
ensure_role_groups()
for role, perm_codes in ROLE_PERMISSIONS.items():
group = Group.objects.get(name=role)
if role == ROLE_ADMIN:
group.permissions.set(Permission.objects.all())
continue
perms = []
for code in perm_codes:
if "." not in code:
continue
app_label, codename = code.split(".", 1)
try:
perms.append(
Permission.objects.get(
content_type__app_label=app_label,
codename=codename,
)
)
except Permission.DoesNotExist:
continue
group.permissions.set(perms)
def assign_default_object_permissions(instance) -> None:
app_label = instance._meta.app_label
model_name = instance._meta.model_name
if (app_label, model_name) not in OBJECT_PERMISSION_MODELS:
return
ensure_role_groups()
groups = {group.name: group for group in Group.objects.filter(name__in=ROLE_ORDER)}
for role, perm_codes in ROLE_PERMISSIONS.items():
if role == ROLE_ADMIN:
continue
group = groups.get(role)
if not group:
continue
for code in perm_codes:
if "." not in code:
continue
perm_app, codename = code.split(".", 1)
if perm_app != app_label:
continue
if not codename.endswith(f"_{model_name}"):
continue
if codename.startswith("add_"):
continue
assign_perm(code, group, instance)
def get_user_role(user, default: str = ROLE_USER) -> str | None:
if not user or not getattr(user, "is_authenticated", False):
return None
if getattr(user, "is_superuser", False):
return ROLE_ADMIN
group_names = set(user.groups.values_list("name", flat=True))
for role in ROLE_ORDER:
if role in group_names:
return role
return default
def set_user_role(user, role: str) -> str:
canonical = normalize_role(role)
if canonical not in ROLE_ORDER:
raise ValueError(f"Invalid role: {role}")
ensure_role_groups()
role_groups = list(Group.objects.filter(name__in=ROLE_ORDER))
if role_groups:
user.groups.remove(*role_groups)
target_group = Group.objects.get(name=canonical)
user.groups.add(target_group)
if canonical == ROLE_ADMIN:
user.is_staff = True
user.is_superuser = True
else:
user.is_staff = False
user.is_superuser = False
return canonical
def require_authenticated(request) -> None:
user = getattr(request, "user", None)
if not user or not getattr(user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
def require_perms(request, *perms: str) -> None:
user = getattr(request, "user", None)
if not user or not getattr(user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
if not user.has_perms(perms):
raise HttpError(403, "Forbidden")

27
app/apps/core/views.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
from django.urls import reverse
from django.views.decorators.cache import never_cache
@never_cache
def disguised_not_found(request: HttpRequest, exception=None) -> HttpResponse:
"""Return a less-informative response for unknown endpoints."""
path = request.path or ""
accepts = (request.META.get("HTTP_ACCEPT") or "").lower()
# Treat anything that looks API-like as a probe and return a generic
# auth-style response rather than a 404 page.
is_api_like = path.startswith("/api/") or "application/json" in accepts
if is_api_like:
# Avoid a 404 response for unknown API paths.
return JsonResponse({"detail": "Unauthorized."}, status=401)
try:
# For browser traffic, redirect to a known entry point so the
# response shape is predictable and uninformative.
target = reverse("servers:dashboard")
except Exception:
target = "/"
return HttpResponseRedirect(target)

View File

@@ -1,11 +1,37 @@
from django.contrib import admin from django.contrib import admin
try:
from unfold.contrib.guardian.admin import GuardedModelAdmin
except ImportError: # Fallback for older Unfold builds without guardian admin shim.
from guardian.admin import GuardedModelAdmin as GuardianGuardedModelAdmin
from unfold.admin import ModelAdmin as UnfoldModelAdmin
from .models import SSHKey class GuardedModelAdmin(GuardianGuardedModelAdmin, UnfoldModelAdmin):
pass
from .models import SSHCertificate, SSHCertificateAuthority, SSHKey
@admin.register(SSHKey) @admin.register(SSHKey)
class SSHKeyAdmin(admin.ModelAdmin): class SSHKeyAdmin(GuardedModelAdmin):
list_display = ("id", "user", "name", "key_type", "fingerprint", "is_active", "created_at") list_display = ("id", "user", "name", "key_type", "fingerprint", "is_active", "created_at")
list_filter = ("is_active", "key_type") list_filter = ("is_active", "key_type")
search_fields = ("name", "user__username", "user__email", "fingerprint") search_fields = ("name", "user__username", "user__email", "fingerprint")
ordering = ("-created_at",) ordering = ("-created_at",)
@admin.register(SSHCertificateAuthority)
class SSHCertificateAuthorityAdmin(admin.ModelAdmin):
list_display = ("name", "fingerprint", "is_active", "created_at", "revoked_at")
list_filter = ("is_active",)
search_fields = ("name", "fingerprint")
readonly_fields = ("created_at", "revoked_at", "fingerprint", "public_key", "private_key")
ordering = ("-created_at",)
@admin.register(SSHCertificate)
class SSHCertificateAdmin(admin.ModelAdmin):
list_display = ("id", "user", "key", "serial", "is_active", "valid_before", "created_at")
list_filter = ("is_active",)
search_fields = ("user__username", "user__email", "serial")
readonly_fields = ("created_at", "revoked_at", "certificate")
ordering = ("-created_at",)

View File

@@ -5,3 +5,7 @@ class KeysConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "apps.keys" name = "apps.keys"
verbose_name = "SSH Keys" verbose_name = "SSH Keys"
def ready(self) -> None:
from . import signals # noqa: F401
return super().ready()

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import os
import re
import secrets
import subprocess
import tempfile
from datetime import timedelta
from django.conf import settings
from django.utils import timezone
from .models import SSHCertificate, SSHCertificateAuthority, SSHKey
from .utils import render_system_username
def get_active_ca(created_by=None) -> SSHCertificateAuthority:
# Reuse the most recent active CA, or lazily create one if missing.
ca = (
SSHCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True)
.order_by("-created_at")
.first()
)
if not ca:
ca = SSHCertificateAuthority(created_by=created_by)
ca.ensure_material()
ca.save()
return ca
def issue_certificate_for_key(key: SSHKey, created_by=None) -> SSHCertificate:
if not key or not key.user_id:
raise ValueError("key must have a user")
ca = get_active_ca(created_by=created_by)
# Principal must match the system account used for SSH logins.
principal = render_system_username(key.user.username, key.user_id)
now = timezone.now()
valid_before = now + timedelta(days=settings.KEYWARDEN_USER_CERT_VALIDITY_DAYS)
# Serial should be unique and non-guessable for audit purposes.
serial = secrets.randbits(63)
safe_name = _sanitize_label(key.name or "key")
identity = f"keywarden-cert-{key.user_id}-{safe_name}-{key.id}"
cert_text = _sign_public_key(
ca_private_key=ca.private_key,
ca_public_key=ca.public_key,
public_key=key.public_key,
identity=identity,
principal=principal,
serial=serial,
validity_days=settings.KEYWARDEN_USER_CERT_VALIDITY_DAYS,
comment=identity,
)
cert, _ = SSHCertificate.objects.update_or_create(
key=key,
defaults={
"user": key.user,
"certificate": cert_text,
"serial": serial,
"principals": [principal],
"valid_after": now,
"valid_before": valid_before,
"revoked_at": None,
"is_active": True,
},
)
return cert
def revoke_certificate_for_key(key: SSHKey) -> None:
if not key:
return
try:
cert = key.certificate
except SSHCertificate.DoesNotExist:
return
# Mark the cert as revoked but keep the record for audit/history.
cert.revoke()
cert.save(update_fields=["is_active", "revoked_at"])
def _sign_public_key(
ca_private_key: str,
ca_public_key: str,
public_key: str,
identity: str,
principal: str,
serial: int,
validity_days: int,
comment: str,
validity_override: str | None = None,
) -> str:
if not ca_private_key or not ca_public_key:
raise RuntimeError("CA material missing")
# Write key material into a temp dir to avoid persisting secrets.
with tempfile.TemporaryDirectory() as tmpdir:
ca_path = os.path.join(tmpdir, "user_ca")
pubkey_path = os.path.join(tmpdir, "user.pub")
_write_file(ca_path, ca_private_key, 0o600)
_write_file(ca_path + ".pub", ca_public_key.strip() + "\n", 0o644)
pubkey_with_comment = _ensure_comment(public_key, comment)
_write_file(pubkey_path, pubkey_with_comment + "\n", 0o644)
# Use ssh-keygen to sign the public key with the CA.
cmd = [
"ssh-keygen",
"-s",
ca_path,
"-I",
identity,
"-n",
principal,
"-V",
validity_override or f"+{validity_days}d",
"-z",
str(serial),
pubkey_path,
]
try:
result = subprocess.run(cmd, check=True, capture_output=True)
except FileNotFoundError as exc:
raise RuntimeError("ssh-keygen not available") from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"ssh-keygen failed: {exc.stderr.decode('utf-8', 'ignore')}") from exc
# ssh-keygen writes the cert alongside the input pubkey.
cert_path = pubkey_path
if cert_path.endswith(".pub"):
cert_path = cert_path[: -len(".pub")]
cert_path += "-cert.pub"
if not os.path.exists(cert_path):
stderr = result.stderr.decode("utf-8", "ignore")
raise RuntimeError(f"ssh-keygen output missing: {cert_path} {stderr}")
with open(cert_path, "r", encoding="utf-8") as handle:
return handle.read().strip()
def _ensure_comment(public_key: str, comment: str) -> str:
# Preserve the key type and base64 payload; replace/append only the comment.
parts = (public_key or "").strip().split()
if len(parts) < 2:
return public_key.strip()
key_type, key_b64 = parts[0], parts[1]
if not comment:
return f"{key_type} {key_b64}"
return f"{key_type} {key_b64} {comment}"
def _sanitize_label(value: str) -> str:
# Reduce label to a safe, lowercase token for certificate identity.
cleaned = re.sub(r"[^a-zA-Z0-9_-]+", "-", (value or "").strip())
cleaned = cleaned.strip("-_")
if cleaned:
return cleaned.lower()
return "key"
def _write_file(path: str, data: str, mode: int) -> None:
with open(path, "w", encoding="utf-8") as handle:
handle.write(data)
# Apply explicit permissions for key material.
os.chmod(path, mode)

View File

@@ -0,0 +1,86 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("keys", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="SSHCertificateAuthority",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(default="Keywarden User SSH CA", max_length=128)),
("public_key", models.TextField(blank=True)),
("private_key", models.TextField(blank=True)),
("fingerprint", models.CharField(blank=True, max_length=128)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("revoked_at", models.DateTimeField(blank=True, null=True)),
("is_active", models.BooleanField(db_index=True, default=True)),
(
"created_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="ssh_certificate_authorities",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "SSH certificate authority",
"verbose_name_plural": "SSH certificate authorities",
"ordering": ["-created_at"],
},
),
migrations.CreateModel(
name="SSHCertificate",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("certificate", models.TextField()),
("serial", models.BigIntegerField()),
("principals", models.JSONField(blank=True, default=list)),
("valid_after", models.DateTimeField()),
("valid_before", models.DateTimeField()),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("revoked_at", models.DateTimeField(blank=True, null=True)),
("is_active", models.BooleanField(db_index=True, default=True)),
(
"key",
models.OneToOneField(
on_delete=django.db.models.deletion.CASCADE,
related_name="certificate",
to="keys.sshkey",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="ssh_certificates",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "SSH certificate",
"verbose_name_plural": "SSH certificates",
"ordering": ["-created_at"],
},
),
migrations.AddIndex(
model_name="sshcertificate",
index=models.Index(fields=["user", "is_active"], name="keys_cert_user_active_idx"),
),
migrations.AddIndex(
model_name="sshcertificate",
index=models.Index(fields=["valid_before"], name="keys_cert_valid_before_idx"),
),
]

View File

@@ -3,6 +3,9 @@ from __future__ import annotations
import base64 import base64
import binascii import binascii
import hashlib import hashlib
import os
import subprocess
import tempfile
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
@@ -61,6 +64,107 @@ class SSHKey(models.Model):
def revoke(self) -> None: def revoke(self) -> None:
self.is_active = False self.is_active = False
self.revoked_at = timezone.now() self.revoked_at = timezone.now()
try:
cert = self.certificate
except SSHCertificate.DoesNotExist:
return
cert.revoke()
cert.save(update_fields=["is_active", "revoked_at"])
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.name} ({self.user_id})" return f"{self.name} ({self.user_id})"
class SSHCertificateAuthority(models.Model):
name = models.CharField(max_length=128, default="Keywarden User SSH CA")
public_key = models.TextField(blank=True)
private_key = models.TextField(blank=True)
fingerprint = models.CharField(max_length=128, blank=True)
created_at = models.DateTimeField(default=timezone.now, editable=False)
revoked_at = models.DateTimeField(null=True, blank=True)
is_active = models.BooleanField(default=True, db_index=True)
created_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="ssh_certificate_authorities",
)
class Meta:
verbose_name = "SSH certificate authority"
verbose_name_plural = "SSH certificate authorities"
ordering = ["-created_at"]
def __str__(self) -> str:
status = "active" if self.is_active and not self.revoked_at else "revoked"
return f"{self.name} ({status})"
def revoke(self) -> None:
self.is_active = False
self.revoked_at = timezone.now()
def ensure_material(self) -> None:
if self.public_key and self.private_key:
if not self.fingerprint:
_, _, fingerprint = parse_public_key(self.public_key)
self.fingerprint = fingerprint
return
with tempfile.TemporaryDirectory() as tmpdir:
key_path = os.path.join(tmpdir, "keywarden_user_ca")
cmd = [
"ssh-keygen",
"-t",
"ed25519",
"-f",
key_path,
"-C",
self.name,
"-N",
"",
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except FileNotFoundError as exc:
raise RuntimeError("ssh-keygen not available") from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"ssh-keygen failed: {exc.stderr.decode('utf-8', 'ignore')}") from exc
with open(key_path, "r", encoding="utf-8") as handle:
self.private_key = handle.read()
with open(key_path + ".pub", "r", encoding="utf-8") as handle:
self.public_key = handle.read().strip()
_, _, fingerprint = parse_public_key(self.public_key)
self.fingerprint = fingerprint
class SSHCertificate(models.Model):
key = models.OneToOneField(
SSHKey, on_delete=models.CASCADE, related_name="certificate"
)
user = models.ForeignKey(
settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="ssh_certificates"
)
certificate = models.TextField()
serial = models.BigIntegerField()
principals = models.JSONField(default=list, blank=True)
valid_after = models.DateTimeField()
valid_before = models.DateTimeField()
created_at = models.DateTimeField(default=timezone.now, editable=False)
revoked_at = models.DateTimeField(null=True, blank=True)
is_active = models.BooleanField(default=True, db_index=True)
class Meta:
verbose_name = "SSH certificate"
verbose_name_plural = "SSH certificates"
indexes = [
models.Index(fields=["user", "is_active"], name="keys_cert_user_active_idx"),
models.Index(fields=["valid_before"], name="keys_cert_valid_before_idx"),
]
ordering = ["-created_at"]
def revoke(self) -> None:
self.is_active = False
self.revoked_at = timezone.now()
def __str__(self) -> str:
return f"{self.user_id} ({self.serial})"

19
app/apps/keys/signals.py Normal file
View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from django.db.models.signals import post_save
from django.dispatch import receiver
from guardian.shortcuts import assign_perm
from apps.core.rbac import assign_default_object_permissions
from .models import SSHKey
@receiver(post_save, sender=SSHKey)
def assign_ssh_key_perms(sender, instance: SSHKey, created: bool, **kwargs) -> None:
if not created:
return
if instance.user_id:
user = instance.user
for perm in ("keys.view_sshkey", "keys.change_sshkey", "keys.delete_sshkey"):
assign_perm(perm, user, instance)
assign_default_object_permissions(instance)

33
app/apps/keys/utils.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import re
from django.conf import settings
MAX_USERNAME_LEN = 32
_SANITIZE_RE = re.compile(r"[^a-z0-9_-]")
def render_system_username(username: str, user_id: int) -> str:
# Render from template and then sanitize to an OS-safe username.
template = settings.KEYWARDEN_ACCOUNT_USERNAME_TEMPLATE
raw = template.replace("{{username}}", username or "")
raw = raw.replace("{{user_id}}", str(user_id))
cleaned = sanitize_username(raw)
if len(cleaned) > MAX_USERNAME_LEN:
cleaned = cleaned[:MAX_USERNAME_LEN]
if cleaned:
return cleaned
# Fall back to a deterministic, non-empty username.
return f"kw_{user_id}"
def sanitize_username(raw: str) -> str:
# Normalize to lowercase and replace disallowed characters.
raw = (raw or "").lower()
raw = _SANITIZE_RE.sub("_", raw)
raw = raw.strip("-_")
if raw.startswith("-"):
# Avoid leading dash, which can be interpreted as a CLI flag.
return "kw" + raw
return raw

View File

@@ -1,16 +1,34 @@
from django.contrib import admin from django.contrib import admin
from django.utils.html import format_html from django.utils.html import format_html
from .models import Server try:
from unfold.contrib.guardian.admin import GuardedModelAdmin
except ImportError: # Fallback for older Unfold builds without guardian admin shim.
from guardian.admin import GuardedModelAdmin as GuardianGuardedModelAdmin
from unfold.admin import ModelAdmin as UnfoldModelAdmin
class GuardedModelAdmin(GuardianGuardedModelAdmin, UnfoldModelAdmin):
pass
from .models import AgentCertificateAuthority, EnrollmentToken, Server
@admin.register(Server) @admin.register(Server)
class ServerAdmin(admin.ModelAdmin): class ServerAdmin(GuardedModelAdmin):
list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "created_at") list_display = ("avatar", "display_name", "hostname", "ipv4", "ipv6", "agent_enrolled_at", "created_at")
list_display_links = ("display_name",) list_display_links = ("display_name",)
search_fields = ("display_name", "hostname", "ipv4", "ipv6") search_fields = ("display_name", "hostname", "ipv4", "ipv6")
list_filter = ("created_at",) list_filter = ("created_at",)
readonly_fields = ("created_at", "updated_at") readonly_fields = ("created_at", "updated_at", "agent_enrolled_at")
fields = ("display_name", "hostname", "ipv4", "ipv6", "image", "created_at", "updated_at") fields = (
"display_name",
"hostname",
"ipv4",
"ipv6",
"image",
"agent_enrolled_at",
"created_at",
"updated_at",
)
def avatar(self, obj: Server): def avatar(self, obj: Server):
if obj.image_url: if obj.image_url:
@@ -27,3 +45,50 @@ class ServerAdmin(admin.ModelAdmin):
avatar.short_description = "" avatar.short_description = ""
@admin.register(EnrollmentToken)
class EnrollmentTokenAdmin(admin.ModelAdmin):
list_display = ("token", "created_at", "expires_at", "used_at", "server")
list_filter = ("created_at", "used_at")
search_fields = ("token", "server__display_name", "server__hostname")
readonly_fields = ("token", "created_at", "used_at", "server", "created_by")
fields = ("token", "expires_at", "created_by", "created_at", "used_at", "server")
def save_model(self, request, obj, form, change) -> None:
if not obj.pk:
obj.ensure_token()
if request.user and request.user.is_authenticated and not obj.created_by_id:
obj.created_by = request.user
super().save_model(request, obj, form, change)
@admin.register(AgentCertificateAuthority)
class AgentCertificateAuthorityAdmin(admin.ModelAdmin):
list_display = ("name", "is_active", "created_at", "revoked_at")
list_filter = ("is_active", "created_at", "revoked_at")
search_fields = ("name", "fingerprint")
readonly_fields = ("cert_pem", "fingerprint", "serial", "created_at", "revoked_at", "created_by")
fields = (
"name",
"is_active",
"cert_pem",
"fingerprint",
"serial",
"created_by",
"created_at",
"revoked_at",
)
actions = ["revoke_selected"]
def save_model(self, request, obj, form, change) -> None:
if request.user and request.user.is_authenticated and not obj.created_by_id:
obj.created_by = request.user
obj.ensure_material()
if obj.is_active:
AgentCertificateAuthority.objects.exclude(pk=obj.pk).update(is_active=False)
super().save_model(request, obj, form, change)
@admin.action(description="Revoke selected CAs")
def revoke_selected(self, request, queryset):
for ca in queryset:
ca.revoke()
ca.save(update_fields=["is_active", "revoked_at"])

View File

@@ -6,4 +6,7 @@ class ServersConfig(AppConfig):
name = "apps.servers" name = "apps.servers"
verbose_name = "Servers" verbose_name = "Servers"
def ready(self) -> None:
from . import signals # noqa: F401
return super().ready()

View File

@@ -0,0 +1,295 @@
from __future__ import annotations
import asyncio
import os
import secrets
import subprocess
import tempfile
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings
from django.utils import timezone
from apps.audit.matching import find_matching_event_type
from apps.audit.models import AuditEventType, AuditLog
from apps.audit.utils import (
get_client_ip_from_scope,
get_request_id_from_scope,
get_user_agent_from_scope,
)
from apps.keys.certificates import get_active_ca, _sign_public_key
from apps.keys.utils import render_system_username
from apps.servers.models import Server, ServerAccount
from apps.servers.permissions import user_can_shell
class ShellConsumer(AsyncWebsocketConsumer):
async def connect(self):
# Initialize per-connection state; this consumer is stateful
# across the WebSocket lifecycle.
self.proc = None
self.reader_task = None
self.tempdir = None
self.system_username = ""
self.shell_target = ""
self.server_id: int | None = None
# Reject unauthenticated connections before any side effects.
user = self.scope.get("user")
if not user or not getattr(user, "is_authenticated", False):
await self.close(code=4401)
return
server_id = self.scope.get("url_route", {}).get("kwargs", {}).get("server_id")
if not server_id:
await self.close(code=4400)
return
# Resolve the server and enforce object-level permissions before
# accepting the socket.
server = await self._get_server(user, int(server_id))
if not server:
await self.close(code=4404)
return
self.server_id = server.id
can_shell = await self._can_shell(user, server)
if not can_shell:
await self.close(code=4403)
return
# Resolve the per-user system account name and the best reachable host.
system_username = await self._get_system_username(user, server)
shell_target = server.hostname or server.ipv4 or server.ipv6
if not system_username or not shell_target:
await self.close(code=4400)
return
self.system_username = system_username
self.shell_target = shell_target
# Only accept the socket after all authn/authz checks have passed.
await self.accept()
# Audit the WebSocket connection as an explicit, opt-in event.
await self._audit_websocket_event(user=user, action="connect", metadata={"server_id": server.id})
await self.send(text_data="Connecting...\r\n")
try:
await self._start_ssh(user)
except Exception:
await self.send(text_data="Connection failed.\r\n")
await self.close()
async def disconnect(self, code):
user = self.scope.get("user")
if user and getattr(user, "is_authenticated", False):
await self._audit_websocket_event(
user=user,
action="disconnect",
metadata={"code": code, "server_id": self.server_id},
)
if self.reader_task:
self.reader_task.cancel()
self.reader_task = None
if self.proc and self.proc.returncode is None:
self.proc.terminate()
try:
await asyncio.wait_for(self.proc.wait(), timeout=2.0)
except asyncio.TimeoutError:
self.proc.kill()
if self.tempdir:
self.tempdir.cleanup()
self.tempdir = None
async def receive(self, text_data=None, bytes_data=None):
if not self.proc or not self.proc.stdin:
return
# Forward WebSocket payloads directly to the SSH subprocess stdin.
if bytes_data is not None:
data = bytes_data
elif text_data is not None:
data = text_data.encode("utf-8")
else:
return
if data:
self.proc.stdin.write(data)
await self.proc.stdin.drain()
async def _start_ssh(self, user):
# Generate a short-lived keypair + SSH certificate and then
# bridge the WebSocket to an SSH subprocess.
# Prefer tmpfs when available so the private key never hits disk.
temp_base = "/dev/shm" if os.path.isdir("/dev/shm") and os.access("/dev/shm", os.W_OK) else None
self.tempdir = tempfile.TemporaryDirectory(prefix="keywarden-shell-", dir=temp_base)
key_path, cert_path = await asyncio.to_thread(
_generate_session_keypair,
self.tempdir.name,
user,
self.system_username,
)
ssh_host = _format_ssh_host(self.shell_target)
# Use a locked-down, non-interactive SSH invocation suitable for websockets.
command = [
"ssh",
"-tt",
"-i",
key_path,
"-o",
f"CertificateFile={cert_path}",
"-o",
"BatchMode=yes",
"-o",
"PasswordAuthentication=no",
"-o",
"KbdInteractiveAuthentication=no",
"-o",
"ChallengeResponseAuthentication=no",
"-o",
"PreferredAuthentications=publickey",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"GlobalKnownHostsFile=/dev/null",
"-o",
"StrictHostKeyChecking=no",
"-o",
"VerifyHostKeyDNS=no",
"-o",
"LogLevel=ERROR",
f"{self.system_username}@{ssh_host}",
"/bin/bash",
]
self.proc = await asyncio.create_subprocess_exec(
*command,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
# Delete key material immediately after the SSH process has it open.
for path in (key_path, cert_path, f"{key_path}.pub"):
try:
os.remove(path)
except FileNotFoundError:
continue
except Exception:
pass
self.reader_task = asyncio.create_task(self._stream_output())
async def _stream_output(self):
if not self.proc or not self.proc.stdout:
return
# Pump subprocess output until EOF, then close the socket.
while True:
chunk = await self.proc.stdout.read(4096)
if not chunk:
break
await self.send(bytes_data=chunk)
await self.close()
@database_sync_to_async
def _get_server(self, user, server_id: int):
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
return None
if not user.has_perm("servers.view_server", server):
return None
return server
@database_sync_to_async
def _can_shell(self, user, server) -> bool:
return user_can_shell(user, server, timezone.now())
@database_sync_to_async
def _get_system_username(self, user, server) -> str:
account = ServerAccount.objects.filter(server=server, user=user).first()
if account:
return account.system_username
return render_system_username(user.username, user.id)
@database_sync_to_async
def _audit_websocket_event(self, user, action: str, metadata: dict | None = None) -> None:
try:
path = str(self.scope.get("path") or "")
client_ip = get_client_ip_from_scope(self.scope)
# Match only against explicitly configured WebSocket event types.
event_type = find_matching_event_type(
kind=AuditEventType.Kind.WEBSOCKET,
method="GET",
route=path,
path=path,
ip=client_ip,
)
if event_type is None:
return
combined_metadata = {
"action": action,
"path": path,
}
if metadata:
combined_metadata.update(metadata)
AuditLog.objects.create(
created_at=timezone.now(),
actor=user,
event_type=event_type,
message=f"WebSocket {action} {path}",
severity=event_type.default_severity,
source=AuditLog.Source.API,
ip_address=client_ip,
user_agent=get_user_agent_from_scope(self.scope),
request_id=get_request_id_from_scope(self.scope),
metadata=combined_metadata,
)
except Exception:
# Auditing is best-effort; never fail the shell session.
return
def _generate_session_keypair(tempdir: str, user, principal: str) -> tuple[str, str]:
# Create an ephemeral SSH keypair and sign it with the active CA so
# the user gets time-scoped shell access without long-lived keys.
ca = get_active_ca(created_by=user)
serial = secrets.randbits(63)
identity = f"keywarden-shell-{user.id}-{serial}"
key_path = os.path.join(tempdir, "session_key")
cmd = [
"ssh-keygen",
"-t",
"ed25519",
"-f",
key_path,
"-C",
identity,
"-N",
"",
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except FileNotFoundError as exc:
raise RuntimeError("ssh-keygen not available") from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"ssh-keygen failed: {exc.stderr.decode('utf-8', 'ignore')}") from exc
# Restrict filesystem access to the private key.
os.chmod(key_path, 0o600)
pubkey_path = key_path + ".pub"
with open(pubkey_path, "r", encoding="utf-8") as handle:
public_key = handle.read().strip()
cert_text = _sign_public_key(
ca_private_key=ca.private_key,
ca_public_key=ca.public_key,
public_key=public_key,
identity=identity,
principal=principal,
serial=serial,
validity_days=1,
validity_override=f"+{settings.KEYWARDEN_SHELL_CERT_VALIDITY_MINUTES}m",
comment=identity,
)
cert_path = key_path + "-cert.pub"
with open(cert_path, "w", encoding="utf-8") as handle:
handle.write(cert_text + "\n")
# Public cert is safe to be world-readable.
os.chmod(cert_path, 0o644)
return key_path, cert_path
def _format_ssh_host(host: str) -> str:
# IPv6 hosts must be wrapped in brackets for the SSH CLI.
if ":" in host and not (host.startswith("[") and host.endswith("]")):
return f"[{host}]"
return host

View File

@@ -0,0 +1,73 @@
from django.conf import settings
from django.db import migrations, models
import django.utils.timezone
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("servers", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddField(
model_name="server",
name="agent_cert_fingerprint",
field=models.CharField(blank=True, max_length=128, null=True),
),
migrations.AddField(
model_name="server",
name="agent_cert_serial",
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AddField(
model_name="server",
name="agent_enrolled_at",
field=models.DateTimeField(blank=True, null=True),
),
migrations.CreateModel(
name="EnrollmentToken",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("token", models.CharField(max_length=128, unique=True)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("expires_at", models.DateTimeField(blank=True, null=True)),
("used_at", models.DateTimeField(blank=True, null=True)),
(
"created_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="server_enrollment_tokens",
to=settings.AUTH_USER_MODEL,
),
),
(
"server",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="enrollment_tokens",
to="servers.server",
),
),
],
options={
"verbose_name": "Enrollment token",
"verbose_name_plural": "Enrollment tokens",
"ordering": ["-created_at"],
},
),
migrations.AddIndex(
model_name="enrollmenttoken",
index=models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
),
migrations.AddIndex(
model_name="enrollmenttoken",
index=models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
),
]

View File

@@ -0,0 +1,44 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("servers", "0002_agent_enrollment"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="AgentCertificateAuthority",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(default="Keywarden Agent CA", max_length=128)),
("cert_pem", models.TextField()),
("key_pem", models.TextField()),
("fingerprint", models.CharField(blank=True, max_length=128)),
("serial", models.CharField(blank=True, max_length=64)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("revoked_at", models.DateTimeField(blank=True, null=True)),
("is_active", models.BooleanField(db_index=True, default=True)),
(
"created_by",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="agent_certificate_authorities",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "Agent certificate authority",
"verbose_name_plural": "Agent certificate authorities",
"ordering": ["-created_at"],
},
),
]

View File

@@ -0,0 +1,59 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("servers", "0003_agent_ca"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="ServerAccount",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("system_username", models.CharField(max_length=128)),
("is_present", models.BooleanField(db_index=True, default=False)),
("last_synced_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"server",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="accounts",
to="servers.server",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="server_accounts",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "Server account",
"verbose_name_plural": "Server accounts",
"ordering": ["server_id", "user_id"],
},
),
migrations.AddConstraint(
model_name="serveraccount",
constraint=models.UniqueConstraint(fields=("server", "user"), name="unique_server_account"),
),
migrations.AddIndex(
model_name="serveraccount",
index=models.Index(fields=["server", "user"], name="servers_account_user_idx"),
),
migrations.AddIndex(
model_name="serveraccount",
index=models.Index(fields=["server", "is_present"], name="servers_account_present_idx"),
),
]

View File

@@ -0,0 +1,19 @@
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("servers", "0004_server_account"),
]
operations = [
migrations.AlterModelOptions(
name="server",
options={
"ordering": ["display_name", "hostname", "ipv4", "ipv6"],
"permissions": [("shell_server", "Can access server shell")],
"verbose_name": "Server",
"verbose_name_plural": "Servers",
},
),
]

View File

@@ -0,0 +1,35 @@
from django.db import migrations
def remove_user_group_server_perms(apps, schema_editor):
Group = apps.get_model("auth", "Group")
Permission = apps.get_model("auth", "Permission")
ContentType = apps.get_model("contenttypes", "ContentType")
GroupObjectPermission = apps.get_model("guardian", "GroupObjectPermission")
try:
group = Group.objects.get(name="user")
except Group.DoesNotExist:
return
try:
content_type = ContentType.objects.get(app_label="servers", model="server")
except ContentType.DoesNotExist:
return
perm_ids = Permission.objects.filter(content_type=content_type).values_list("id", flat=True)
GroupObjectPermission.objects.filter(
group_id=group.id,
permission_id__in=list(perm_ids),
).delete()
class Migration(migrations.Migration):
dependencies = [
("servers", "0005_server_shell_permission"),
("guardian", "0001_initial"),
]
operations = [
migrations.RunPython(remove_user_group_server_perms, migrations.RunPython.noop),
]

View File

@@ -0,0 +1,20 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("servers", "0006_remove_user_group_server_perms"),
]
operations = [
migrations.AddField(
model_name="server",
name="ssh_host_public_key",
field=models.TextField(blank=True),
),
migrations.AddField(
model_name="server",
name="ssh_host_fingerprint",
field=models.CharField(blank=True, max_length=128),
),
]

View File

@@ -0,0 +1,18 @@
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("servers", "0007_server_host_key"),
]
operations = [
migrations.RemoveField(
model_name="server",
name="ssh_host_fingerprint",
),
migrations.RemoveField(
model_name="server",
name="ssh_host_public_key",
),
]

View File

@@ -0,0 +1,21 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("servers", "0008_remove_server_host_key"),
]
operations = [
migrations.AddField(
model_name="server",
name="last_heartbeat_at",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AddField(
model_name="server",
name="last_ping_ms",
field=models.PositiveIntegerField(blank=True, null=True),
),
]

View File

@@ -1,8 +1,16 @@
from __future__ import annotations from __future__ import annotations
import secrets
from datetime import datetime, timedelta
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from django.conf import settings
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.db import models from django.db import models
from django.utils.text import slugify from django.utils import timezone
hostname_validator = RegexValidator( hostname_validator = RegexValidator(
@@ -17,6 +25,11 @@ class Server(models.Model):
ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True) ipv4 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv4", unique=True)
ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True) ipv6 = models.GenericIPAddressField(null=True, blank=True, protocol="IPv6", unique=True)
image = models.ImageField(upload_to="servers/", null=True, blank=True) image = models.ImageField(upload_to="servers/", null=True, blank=True)
agent_enrolled_at = models.DateTimeField(null=True, blank=True)
agent_cert_fingerprint = models.CharField(max_length=128, null=True, blank=True)
agent_cert_serial = models.CharField(max_length=64, null=True, blank=True)
last_heartbeat_at = models.DateTimeField(null=True, blank=True, db_index=True)
last_ping_ms = models.PositiveIntegerField(null=True, blank=True)
created_at = models.DateTimeField(auto_now_add=True) created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True) updated_at = models.DateTimeField(auto_now=True)
@@ -24,6 +37,9 @@ class Server(models.Model):
ordering = ["display_name", "hostname", "ipv4", "ipv6"] ordering = ["display_name", "hostname", "ipv4", "ipv6"]
verbose_name = "Server" verbose_name = "Server"
verbose_name_plural = "Servers" verbose_name_plural = "Servers"
permissions = [
("shell_server", "Can access server shell"),
]
def __str__(self) -> str: def __str__(self) -> str:
primary = self.hostname or self.ipv4 or self.ipv6 or "unassigned" primary = self.hostname or self.ipv4 or self.ipv6 or "unassigned"
@@ -41,3 +57,135 @@ class Server(models.Model):
return (self.display_name or "?").strip()[:1].upper() or "?" return (self.display_name or "?").strip()[:1].upper() or "?"
class EnrollmentToken(models.Model):
token = models.CharField(max_length=128, unique=True)
created_at = models.DateTimeField(default=timezone.now, editable=False)
expires_at = models.DateTimeField(null=True, blank=True)
created_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="server_enrollment_tokens",
)
used_at = models.DateTimeField(null=True, blank=True)
server = models.ForeignKey(
Server, null=True, blank=True, on_delete=models.SET_NULL, related_name="enrollment_tokens"
)
class Meta:
verbose_name = "Enrollment token"
verbose_name_plural = "Enrollment tokens"
indexes = [
models.Index(fields=["created_at"], name="servers_enroll_created_idx"),
models.Index(fields=["used_at"], name="servers_enroll_used_idx"),
]
ordering = ["-created_at"]
def __str__(self) -> str:
return f"{self.token[:8]}... ({'used' if self.used_at else 'unused'})"
def ensure_token(self) -> None:
if not self.token:
self.token = secrets.token_urlsafe(32)
def is_valid(self) -> bool:
if self.used_at:
return False
if self.expires_at and self.expires_at <= timezone.now():
return False
return True
def mark_used(self, server: Server) -> None:
self.used_at = timezone.now()
self.server = server
def save(self, *args, **kwargs):
self.ensure_token()
super().save(*args, **kwargs)
class AgentCertificateAuthority(models.Model):
name = models.CharField(max_length=128, default="Keywarden Agent CA")
cert_pem = models.TextField()
key_pem = models.TextField()
fingerprint = models.CharField(max_length=128, blank=True)
serial = models.CharField(max_length=64, blank=True)
created_at = models.DateTimeField(default=timezone.now, editable=False)
revoked_at = models.DateTimeField(null=True, blank=True)
is_active = models.BooleanField(default=True, db_index=True)
created_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True,
blank=True,
on_delete=models.SET_NULL,
related_name="agent_certificate_authorities",
)
class Meta:
verbose_name = "Agent certificate authority"
verbose_name_plural = "Agent certificate authorities"
ordering = ["-created_at"]
def __str__(self) -> str:
status = "active" if self.is_active and not self.revoked_at else "revoked"
return f"{self.name} ({status})"
def revoke(self) -> None:
self.is_active = False
self.revoked_at = timezone.now()
def ensure_material(self) -> None:
if self.cert_pem and self.key_pem:
return
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.name)])
now = datetime.utcnow()
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now - timedelta(minutes=5))
.not_valid_after(now + timedelta(days=3650))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.sign(key, hashes.SHA256())
)
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
key_pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
self.cert_pem = cert_pem
self.key_pem = key_pem
self.fingerprint = cert.fingerprint(hashes.SHA256()).hex()
self.serial = format(cert.serial_number, "x")
class ServerAccount(models.Model):
server = models.ForeignKey(Server, on_delete=models.CASCADE, related_name="accounts")
user = models.ForeignKey(
settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="server_accounts"
)
system_username = models.CharField(max_length=128)
is_present = models.BooleanField(default=False, db_index=True)
last_synced_at = models.DateTimeField(default=timezone.now, editable=False)
created_at = models.DateTimeField(default=timezone.now, editable=False)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
verbose_name = "Server account"
verbose_name_plural = "Server accounts"
constraints = [
models.UniqueConstraint(fields=["server", "user"], name="unique_server_account")
]
indexes = [
models.Index(fields=["server", "user"], name="servers_account_user_idx"),
models.Index(fields=["server", "is_present"], name="servers_account_present_idx"),
]
ordering = ["server_id", "user_id"]
def __str__(self) -> str:
return f"{self.system_username} ({self.server_id})"

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from django.db.models import Q
from django.utils import timezone
from apps.access.models import AccessRequest
def user_can_shell(user, server, now=None) -> bool:
if user.has_perm("servers.shell_server", server):
return True
if now is None:
now = timezone.now()
return (
AccessRequest.objects.filter(
requester=user,
server=server,
status=AccessRequest.Status.APPROVED,
request_shell=True,
)
.filter(Q(expires_at__isnull=True) | Q(expires_at__gt=now))
.exists()
)

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from django.db.models.signals import post_save
from django.dispatch import receiver
from apps.core.rbac import assign_default_object_permissions
from .models import Server
@receiver(post_save, sender=Server)
def assign_server_perms(sender, instance: Server, created: bool, **kwargs) -> None:
if not created:
return
assign_default_object_permissions(instance)

View File

@@ -0,0 +1,114 @@
<div class="space-y-4">
<nav class="flex" aria-label="Breadcrumb">
<ol class="inline-flex items-center space-x-1 text-sm text-gray-500">
<li class="inline-flex items-center">
<a href="{% url 'servers:dashboard' %}" class="inline-flex items-center gap-1 font-medium text-gray-600 hover:text-blue-700">
<svg class="h-4 w-4" aria-hidden="true" fill="currentColor" viewBox="0 0 20 20">
<path d="M10 3.172 2 10v7a1 1 0 0 0 1 1h5v-5h4v5h5a1 1 0 0 0 1-1v-7l-8-6.828Z"></path>
</svg>
Servers
</a>
</li>
<li class="inline-flex items-center">
<svg class="h-4 w-4 text-gray-400" aria-hidden="true" fill="currentColor" viewBox="0 0 20 20">
<path d="M7.05 4.55a1 1 0 0 1 1.4-1.42l6 5.9a1 1 0 0 1 0 1.42l-6 5.9a1 1 0 1 1-1.4-1.42L12.5 10 7.05 4.55Z"></path>
</svg>
<span class="ml-1 font-medium text-gray-700">{{ server.display_name }}</span>
</li>
</ol>
</nav>
<div class="flex flex-col gap-4 rounded-2xl border border-gray-200 bg-white p-5 shadow-sm sm:flex-row sm:items-center sm:justify-between">
<div class="flex items-center gap-4">
<div class="flex h-12 w-12 items-center justify-center rounded-2xl bg-blue-700 text-lg font-semibold text-white shadow-sm">
{{ server.initial }}
</div>
<div>
<h1 class="text-2xl font-semibold tracking-tight text-gray-900">{{ server.display_name }}</h1>
<p class="text-sm text-gray-500">
{{ server.hostname|default:server.ipv4|default:server.ipv6|default:"Unassigned" }}
</p>
</div>
</div>
<div class="flex items-center gap-2">
<div class="relative">
<button
type="button"
data-tooltip-target="server-header-status-{{ server.id }}"
class="{% if server_status.is_active %}inline-flex items-center rounded-full bg-emerald-50 px-2.5 py-1 text-xs font-semibold text-emerald-700{% else %}inline-flex items-center rounded-full bg-rose-50 px-2.5 py-1 text-xs font-semibold text-rose-700{% endif %}"
>
{{ server_status.label }}: {{ server_status.detail }}
</button>
<div
id="server-header-status-{{ server.id }}"
role="tooltip"
class="invisible absolute z-10 inline-block w-64 rounded-lg border border-gray-200 bg-white p-3 text-xs text-gray-700 shadow-sm opacity-0 transition-opacity"
>
<div class="space-y-1">
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Status</span>
<span class="font-medium text-gray-900">{{ server_status.label }}: {{ server_status.detail }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Ping</span>
<span class="font-medium text-gray-900">
{% if server_status.ping_ms is not None %}{{ server_status.ping_ms }}ms{% else %}—{% endif %}
</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Hostname</span>
<span class="font-medium text-gray-900">{{ server.hostname|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">IPv4</span>
<span class="font-medium text-gray-900">{{ server.ipv4|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">IPv6</span>
<span class="font-medium text-gray-900">{{ server.ipv6|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Last heartbeat</span>
<span class="font-medium text-gray-900">
{% if server_status.heartbeat_at %}{{ server_status.heartbeat_at|date:"M j, Y H:i:s" }}{% else %}—{% endif %}
</span>
</div>
</div>
<div class="tooltip-arrow" data-popper-arrow></div>
</div>
</div>
<a href="{% url 'servers:dashboard' %}" class="inline-flex items-center rounded-lg border border-gray-200 bg-white px-3 py-2 text-xs font-semibold text-gray-700 hover:bg-gray-50">
Back to servers
</a>
</div>
</div>
</div>
<nav class="mt-4 flex flex-wrap gap-2 border-b border-gray-200 pb-3 text-sm font-medium text-gray-500">
<a
href="{% url 'servers:detail' server.id %}"
class="{% if active_tab == 'details' %}rounded-full bg-blue-50 px-4 py-1.5 text-blue-700 ring-1 ring-blue-100{% else %}rounded-full bg-gray-100 px-4 py-1.5 text-gray-600 hover:bg-white hover:text-gray-900{% endif %}"
>
Details
</a>
<a
href="{% url 'servers:audit' server.id %}"
class="{% if active_tab == 'audit' %}rounded-full bg-blue-50 px-4 py-1.5 text-blue-700 ring-1 ring-blue-100{% else %}rounded-full bg-gray-100 px-4 py-1.5 text-gray-600 hover:bg-white hover:text-gray-900{% endif %}"
>
Audit
</a>
{% if can_shell %}
<a
href="{% url 'servers:shell' server.id %}"
class="{% if active_tab == 'shell' %}rounded-full bg-blue-50 px-4 py-1.5 text-blue-700 ring-1 ring-blue-100{% else %}rounded-full bg-gray-100 px-4 py-1.5 text-gray-600 hover:bg-white hover:text-gray-900{% endif %}"
>
Shell
</a>
{% endif %}
<a
href="{% url 'servers:settings' server.id %}"
class="{% if active_tab == 'settings' %}rounded-full bg-blue-50 px-4 py-1.5 text-blue-700 ring-1 ring-blue-100{% else %}rounded-full bg-gray-100 px-4 py-1.5 text-gray-600 hover:bg-white hover:text-gray-900{% endif %}"
>
Settings
</a>
</nav>

View File

@@ -0,0 +1,46 @@
{% extends "base.html" %}
{% block title %}Audit • {{ server.display_name }} • Keywarden{% endblock %}
{% block content %}
<div class="space-y-6">
{% include "servers/_header.html" %}
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-gray-900">Audit logs</h2>
<p class="mt-1 text-sm text-gray-500">Track certificate issuance and access events.</p>
</div>
<span class="inline-flex items-center rounded-full bg-gray-100 px-2.5 py-1 text-xs font-semibold text-gray-700">Placeholder</span>
</div>
<div class="mt-5 rounded-xl border border-dashed border-gray-200 bg-gray-50 p-6 text-center">
<div class="mx-auto flex h-12 w-12 items-center justify-center rounded-full bg-blue-50 text-blue-700">
<svg class="h-6 w-6" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 6v6l4 2" />
<circle cx="12" cy="12" r="9" stroke-width="1.5"></circle>
</svg>
</div>
<p class="mt-3 text-sm text-gray-600">Logs will appear here once collection is enabled for this server.</p>
</div>
</section>
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-gray-900">Metrics</h2>
<p class="mt-1 text-sm text-gray-500">Monitor CPU, memory, and session activity.</p>
</div>
<span class="inline-flex items-center rounded-full bg-gray-100 px-2.5 py-1 text-xs font-semibold text-gray-700">Placeholder</span>
</div>
<div class="mt-5 rounded-xl border border-dashed border-gray-200 bg-gray-50 p-6 text-center">
<div class="mx-auto flex h-12 w-12 items-center justify-center rounded-full bg-blue-50 text-blue-700">
<svg class="h-6 w-6" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 12h18M7 8v8M17 8v8" />
</svg>
</div>
<p class="mt-3 text-sm text-gray-600">Metrics will appear here once collection is enabled for this server.</p>
</div>
</section>
</div>
{% endblock %}

View File

@@ -0,0 +1,124 @@
{% extends "base.html" %}
{% block title %}Servers • Keywarden{% endblock %}
{% block content %}
<div class="space-y-6">
<div class="flex flex-wrap items-center justify-between gap-4">
<div>
<h1 class="text-2xl font-semibold tracking-tight text-gray-900">Servers</h1>
<p class="mt-1 text-sm text-gray-500">Review the servers you can access and their certificate status.</p>
</div>
<span class="inline-flex items-center rounded-full bg-blue-50 px-3 py-1 text-xs font-semibold text-blue-700">
{{ servers|length }} total
</span>
</div>
{% if servers %}
<div class="grid gap-5 sm:grid-cols-2 lg:grid-cols-3">
{% for item in servers %}
<article class="flex h-full flex-col rounded-2xl border border-gray-200 bg-white p-5 shadow-sm transition hover:-translate-y-0.5 hover:shadow-md">
<div class="flex items-start justify-between">
<div class="flex items-center gap-3">
<div class="flex h-10 w-10 items-center justify-center rounded-xl bg-blue-700 text-sm font-semibold text-white">
{{ item.server.initial }}
</div>
<div>
<h2 class="text-lg font-semibold text-gray-900">{{ item.server.display_name }}</h2>
<p class="text-xs text-gray-500">
{{ item.server.hostname|default:item.server.ipv4|default:item.server.ipv6|default:"Unassigned" }}
</p>
</div>
</div>
<div class="relative">
<button
type="button"
data-tooltip-target="server-status-{{ item.server.id }}"
class="{% if item.status.is_active %}inline-flex items-center rounded-full bg-emerald-50 px-2.5 py-1 text-xs font-semibold text-emerald-700{% else %}inline-flex items-center rounded-full bg-rose-50 px-2.5 py-1 text-xs font-semibold text-rose-700{% endif %}"
>
{{ item.status.label }}: {{ item.status.detail }}
</button>
<div
id="server-status-{{ item.server.id }}"
role="tooltip"
class="invisible absolute z-10 inline-block w-64 rounded-lg border border-gray-200 bg-white p-3 text-xs text-gray-700 shadow-sm opacity-0 transition-opacity"
>
<div class="space-y-1">
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Status</span>
<span class="font-medium text-gray-900">{{ item.status.label }}: {{ item.status.detail }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Ping</span>
<span class="font-medium text-gray-900">
{% if item.status.ping_ms is not None %}{{ item.status.ping_ms }}ms{% else %}—{% endif %}
</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Hostname</span>
<span class="font-medium text-gray-900">{{ item.server.hostname|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">IPv4</span>
<span class="font-medium text-gray-900">{{ item.server.ipv4|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">IPv6</span>
<span class="font-medium text-gray-900">{{ item.server.ipv6|default:"—" }}</span>
</div>
<div class="flex items-center justify-between">
<span class="font-semibold text-gray-500">Last heartbeat</span>
<span class="font-medium text-gray-900">
{% if item.status.heartbeat_at %}{{ item.status.heartbeat_at|date:"M j, Y H:i:s" }}{% else %}—{% endif %}
</span>
</div>
</div>
<div class="tooltip-arrow" data-popper-arrow></div>
</div>
</div>
</div>
<dl class="mt-5 divide-y divide-gray-100 text-sm text-gray-600">
<div class="flex items-center justify-between py-2">
<dt>Access until</dt>
<dd class="font-medium text-gray-900">
{% if item.expires_at %}
{{ item.expires_at|date:"M j, Y H:i" }}
{% else %}
No expiry
{% endif %}
</dd>
</div>
<div class="flex items-center justify-between py-2">
<dt>Last accessed</dt>
<dd class="font-medium text-gray-900">
{% if item.last_accessed %}
{{ item.last_accessed|date:"M j, Y H:i" }}
{% else %}
{% endif %}
</dd>
</div>
</dl>
<div class="mt-5 flex items-center justify-between border-t border-gray-100 pt-4 text-xs text-gray-500">
<span>Certificates and access</span>
<a href="{% url 'servers:detail' item.server.id %}" class="font-semibold text-blue-700 hover:underline">View details</a>
</div>
</article>
{% endfor %}
</div>
{% else %}
<div class="rounded-2xl border border-dashed border-gray-200 bg-white p-10 text-center">
<div class="mx-auto flex h-12 w-12 items-center justify-center rounded-full bg-blue-50 text-blue-700">
<svg class="h-6 w-6" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 6v6l4 2" />
<circle cx="12" cy="12" r="9" stroke-width="1.5"></circle>
</svg>
</div>
<h2 class="mt-4 text-lg font-semibold text-gray-900">No server access yet</h2>
<p class="mt-2 text-sm text-gray-600">Request access to a server to see it listed here.</p>
</div>
{% endif %}
</div>
{% endblock %}

View File

@@ -0,0 +1,188 @@
{% extends "base.html" %}
{% block title %}{{ server.display_name }} • Keywarden{% endblock %}
{% block content %}
<div class="space-y-6">
{% include "servers/_header.html" %}
<section class="grid gap-4 lg:grid-cols-3">
<div class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<h2 class="text-lg font-semibold text-gray-900">Server details</h2>
<dl class="mt-4 divide-y divide-gray-100 text-sm text-gray-600">
<div class="flex items-center justify-between py-2">
<dt>Hostname</dt>
<dd class="font-medium text-gray-900">{{ server.hostname|default:"—" }}</dd>
</div>
<div class="flex items-center justify-between py-2">
<dt>IPv4</dt>
<dd class="font-medium text-gray-900">{{ server.ipv4|default:"—" }}</dd>
</div>
<div class="flex items-center justify-between py-2">
<dt>IPv6</dt>
<dd class="font-medium text-gray-900">{{ server.ipv6|default:"—" }}</dd>
</div>
</dl>
</div>
<div class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm lg:col-span-2">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-gray-900">Account & certificate</h2>
<p class="mt-1 text-sm text-gray-500">Credentials and certificate download options.</p>
</div>
<span class="inline-flex items-center rounded-full bg-blue-50 px-2.5 py-1 text-xs font-semibold text-blue-700">Access</span>
</div>
<dl class="mt-4 divide-y divide-gray-100 text-sm text-gray-600">
<div class="flex items-center justify-between py-2">
<dt>Account name</dt>
<dd class="font-medium text-gray-900">
{% if system_username %}
{{ system_username }}
{% else %}
Unknown
{% endif %}
</dd>
</div>
<div class="flex items-center justify-between py-2">
<dt>Account status</dt>
<dd class="font-medium text-gray-900">
{% if account_present is None %}
Unknown
{% elif account_present %}
Present
{% else %}
Not on server
{% endif %}
</dd>
</div>
<div class="flex flex-col gap-3 py-2 sm:flex-row sm:items-center sm:justify-between">
<dt>Certificate</dt>
<dd class="font-medium text-gray-900">
{% if certificate_key_id %}
<div class="flex flex-wrap items-center gap-2">
<div class="inline-flex rounded-lg shadow-sm" role="group">
<button
type="button"
class="inline-flex items-center rounded-l-lg bg-blue-700 px-3 py-1.5 text-xs font-semibold text-white hover:bg-blue-800 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ certificate_key_id }}/certificate"
>
Download
</button>
<button
type="button"
class="inline-flex items-center rounded-r-lg border border-gray-200 bg-white px-3 py-1.5 text-xs font-semibold text-gray-700 hover:bg-gray-50 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ certificate_key_id }}/certificate.sha256"
>
Hash
</button>
</div>
<button
type="button"
class="inline-flex items-center rounded-lg bg-rose-600 px-3 py-1.5 text-xs font-semibold text-white hover:bg-rose-700 focus:outline-none focus:ring-2 focus:ring-rose-300 js-regenerate-cert"
data-key-id="{{ certificate_key_id }}"
>
Regenerate
</button>
</div>
{% else %}
<span class="text-xs font-semibold text-gray-500">Upload a key to download</span>
{% endif %}
</dd>
</div>
</dl>
</div>
<div class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm lg:col-span-3">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-gray-900">Access</h2>
<p class="mt-1 text-sm text-gray-500">Review access windows and last usage.</p>
</div>
<span class="inline-flex items-center rounded-full bg-gray-100 px-2.5 py-1 text-xs font-semibold text-gray-700">Usage</span>
</div>
<dl class="mt-4 divide-y divide-gray-100 text-sm text-gray-600">
<div class="flex items-center justify-between py-2">
<dt>Access until</dt>
<dd class="font-medium text-gray-900">
{% if expires_at %}
{{ expires_at|date:"M j, Y H:i" }}
{% else %}
No expiry
{% endif %}
</dd>
</div>
<div class="flex items-center justify-between py-2">
<dt>Last accessed</dt>
<dd class="font-medium text-gray-900">
{% if last_accessed %}
{{ last_accessed|date:"M j, Y H:i" }}
{% else %}
{% endif %}
</dd>
</div>
</dl>
</div>
</section>
</div>
<script>
(function () {
function getCookie(name) {
var value = "; " + document.cookie;
var parts = value.split("; " + name + "=");
if (parts.length === 2) {
return parts.pop().split(";").shift();
}
return "";
}
function handleDownload(event) {
var button = event.currentTarget;
var url = button.getAttribute("data-download-url");
if (!url) {
return;
}
window.location.href = url;
}
function handleRegenerate(event) {
var button = event.currentTarget;
var keyId = button.getAttribute("data-key-id");
if (!keyId) {
return;
}
if (!window.confirm("Regenerate the certificate for this key?")) {
return;
}
var csrf = getCookie("csrftoken");
fetch("/api/v1/keys/" + keyId + "/certificate", {
method: "POST",
credentials: "same-origin",
headers: {
"X-CSRFToken": csrf,
},
})
.then(function (response) {
if (!response.ok) {
throw new Error("Certificate regeneration failed.");
}
window.alert("Certificate regenerated.");
})
.catch(function (err) {
window.alert(err.message);
});
}
var downloadButtons = document.querySelectorAll("[data-download-url]");
for (var i = 0; i < downloadButtons.length; i += 1) {
downloadButtons[i].addEventListener("click", handleDownload);
}
var buttons = document.querySelectorAll(".js-regenerate-cert");
for (var j = 0; j < buttons.length; j += 1) {
buttons[j].addEventListener("click", handleRegenerate);
}
})();
</script>
{% endblock %}

View File

@@ -0,0 +1,28 @@
{% extends "base.html" %}
{% block title %}Settings • {{ server.display_name }} • Keywarden{% endblock %}
{% block content %}
<div class="space-y-6">
{% include "servers/_header.html" %}
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-gray-900">Settings</h2>
<p class="mt-1 text-sm text-gray-500">Manage server-level access policies and metadata.</p>
</div>
<span class="inline-flex items-center rounded-full bg-gray-100 px-2.5 py-1 text-xs font-semibold text-gray-700">Placeholder</span>
</div>
<div class="mt-5 rounded-xl border border-dashed border-gray-200 bg-gray-50 p-6 text-center">
<div class="mx-auto flex h-12 w-12 items-center justify-center rounded-full bg-blue-50 text-blue-700">
<svg class="h-6 w-6" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 6v6l4 2" />
<circle cx="12" cy="12" r="9" stroke-width="1.5"></circle>
</svg>
</div>
<p class="mt-3 text-sm text-gray-600">Settings will appear here as server options are added.</p>
</div>
</section>
</div>
{% endblock %}

View File

@@ -0,0 +1,389 @@
{% extends "base.html" %}
{% load static %}
{% block title %}Shell • {{ server.display_name }} • Keywarden{% endblock %}
{% block extra_head %}
<link rel="stylesheet" href="{% static 'vendor/xterm/xterm.css' %}">
{% if is_popout %}
<style>
body.popout-shell main {
max-width: none !important;
padding: 0 !important;
}
</style>
{% endif %}
{% endblock %}
{% block content %}
{% if is_popout %}
<div class="w-screen">
<div id="shell-popout-shell" class="w-full rounded-2xl border border-gray-200 bg-slate-950 shadow-sm">
<div id="shell-terminal" class="h-full w-full p-3"></div>
</div>
</div>
{% else %}
<div class="space-y-6">
{% include "servers/_header.html" %}
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-center justify-between gap-3">
<div>
<h2 class="text-lg font-semibold text-gray-900">Shell access</h2>
<p class="mt-1 text-sm text-gray-500">
Connect with your private key and the signed certificate for this server.
</p>
</div>
<button
type="button"
class="inline-flex items-center rounded-lg border border-gray-200 bg-white px-3 py-2 text-xs font-semibold text-gray-700 hover:bg-gray-50 focus:outline-none focus:ring-2 focus:ring-blue-300"
id="shell-popout"
data-popout-url="{% url 'servers:shell' server.id %}?popout=1"
>
Pop out terminal
</button>
</div>
<div class="mt-5 grid gap-4 text-sm text-gray-600 lg:grid-cols-2">
<dl class="space-y-4">
<div class="flex items-center justify-between">
<dt>Account name</dt>
<dd class="font-medium text-gray-900">
{% if system_username %}
{{ system_username }}
{% else %}
Unknown
{% endif %}
</dd>
</div>
<div class="flex items-center justify-between">
<dt>Host</dt>
<dd class="font-medium text-gray-900">
{% if shell_target %}
{{ shell_target }}
{% else %}
Unknown
{% endif %}
</dd>
</div>
</dl>
{% if shell_command %}
<div class="rounded-xl border border-gray-200 bg-gray-50 p-4">
<div class="flex flex-wrap items-center justify-between gap-2">
<span class="text-xs font-semibold uppercase tracking-wide text-gray-500">SSH command</span>
<button
type="button"
class="text-xs font-semibold text-blue-700 hover:underline"
data-copy-target="shell-command"
>
Copy command
</button>
</div>
<code class="mt-3 block break-all rounded-lg bg-white p-3 text-xs text-gray-800" id="shell-command">{{ shell_command }}</code>
<div class="mt-4 flex flex-wrap items-center gap-2">
{% if certificate_key_id %}
<div class="inline-flex rounded-lg shadow-sm" role="group">
<button
type="button"
class="inline-flex items-center rounded-l-lg bg-blue-700 px-3 py-1.5 text-xs font-semibold text-white hover:bg-blue-800 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ certificate_key_id }}/certificate"
>
Download
</button>
<button
type="button"
class="inline-flex items-center rounded-r-lg border border-gray-200 bg-white px-3 py-1.5 text-xs font-semibold text-gray-700 hover:bg-gray-50 focus:outline-none focus:ring-2 focus:ring-blue-300"
data-download-url="/api/v1/keys/{{ certificate_key_id }}/certificate.sha256"
>
Hash
</button>
</div>
<button
type="button"
class="inline-flex items-center rounded-lg bg-rose-600 px-3 py-1.5 text-xs font-semibold text-white hover:bg-rose-700 focus:outline-none focus:ring-2 focus:ring-rose-300 js-regenerate-cert"
data-key-id="{{ certificate_key_id }}"
>
Regenerate
</button>
{% endif %}
<span class="text-xs text-gray-500">Use the command above for local SSH.</span>
</div>
</div>
{% else %}
<p class="text-sm text-gray-600">Upload a key to enable downloads and a local SSH command.</p>
{% endif %}
</div>
</section>
<section class="rounded-2xl border border-gray-200 bg-white p-6 shadow-sm">
<div class="flex flex-wrap items-center justify-between gap-3">
<div>
<h2 class="text-lg font-semibold text-gray-900">Browser terminal</h2>
<p class="mt-1 text-sm text-gray-500">
Launch a proxied terminal session to the target host in your browser.
</p>
</div>
<div class="flex items-center gap-2">
<span class="inline-flex items-center rounded-full bg-amber-100 px-2.5 py-1 text-xs font-semibold text-amber-800">Beta</span>
<button
type="button"
class="inline-flex items-center rounded-lg bg-blue-700 px-3 py-2 text-xs font-semibold text-white hover:bg-blue-800 focus:outline-none focus:ring-4 focus:ring-blue-300"
id="shell-start"
>
Start terminal
</button>
</div>
</div>
<div class="mt-4 rounded-xl border border-gray-200 bg-slate-950 p-2">
<div id="shell-terminal" class="h-96"></div>
</div>
<p class="mt-3 text-xs text-gray-500">
Sessions are proxied through Keywarden and end when this page closes.
</p>
</section>
</div>
{% endif %}
<script src="{% static 'vendor/xterm/xterm.js' %}"></script>
<script>
(function () {
function getCookie(name) {
var value = "; " + document.cookie;
var parts = value.split("; " + name + "=");
if (parts.length === 2) {
return parts.pop().split(";").shift();
}
return "";
}
function handleDownload(event) {
var button = event.currentTarget;
var url = button.getAttribute("data-download-url");
if (!url) {
return;
}
window.location.href = url;
}
function handleRegenerate(event) {
var button = event.currentTarget;
var keyId = button.getAttribute("data-key-id");
if (!keyId) {
return;
}
if (!window.confirm("Regenerate the certificate for this key?")) {
return;
}
var csrf = getCookie("csrftoken");
fetch("/api/v1/keys/" + keyId + "/certificate", {
method: "POST",
credentials: "same-origin",
headers: {
"X-CSRFToken": csrf,
},
})
.then(function (response) {
if (!response.ok) {
throw new Error("Certificate regeneration failed.");
}
window.alert("Certificate regenerated.");
})
.catch(function (err) {
window.alert(err.message);
});
}
function handleCopy(event) {
var targetId = event.currentTarget.getAttribute("data-copy-target");
if (!targetId) {
return;
}
var node = document.getElementById(targetId);
if (!node) {
return;
}
var text = node.textContent || "";
if (!navigator.clipboard || !text) {
return;
}
navigator.clipboard.writeText(text).then(function () {
window.alert("Command copied.");
});
}
var popout = document.getElementById("shell-popout");
if (popout) {
popout.addEventListener("click", function () {
var url = popout.getAttribute("data-popout-url");
if (!url) {
return;
}
window.open(url, "_blank", "width=900,height=700");
});
}
var downloadButtons = document.querySelectorAll("[data-download-url]");
for (var i = 0; i < downloadButtons.length; i += 1) {
downloadButtons[i].addEventListener("click", handleDownload);
}
var buttons = document.querySelectorAll(".js-regenerate-cert");
for (var j = 0; j < buttons.length; j += 1) {
buttons[j].addEventListener("click", handleRegenerate);
}
var copyButtons = document.querySelectorAll("[data-copy-target]");
for (var k = 0; k < copyButtons.length; k += 1) {
copyButtons[k].addEventListener("click", handleCopy);
}
var termContainer = document.getElementById("shell-terminal");
var startButton = document.getElementById("shell-start");
var activeSocket = null;
var activeTerm = null;
var popoutShell = document.getElementById("shell-popout-shell");
var isPopout = {{ is_popout|yesno:"true,false" }};
function sizePopoutTerminal() {
if (!isPopout || !popoutShell || !termContainer) {
return;
}
var padding = 24;
var height = Math.max(320, window.innerHeight - padding);
popoutShell.style.height = height + "px";
termContainer.style.height = (height - 8) + "px";
}
function fitTerminal(term) {
if (!termContainer || !term || !term._core || !term._core._renderService) {
return;
}
var dims = term._core._renderService.dimensions;
if (!dims || !dims.css || !dims.css.cell) {
return;
}
var cellWidth = dims.css.cell.width || 9;
var cellHeight = dims.css.cell.height || 18;
if (!cellWidth || !cellHeight) {
return;
}
var cols = Math.max(20, Math.floor(termContainer.clientWidth / cellWidth));
var rows = Math.max(10, Math.floor(termContainer.clientHeight / cellHeight));
term.resize(cols, rows);
}
function setButtonState(isRunning) {
if (!startButton) {
return;
}
startButton.disabled = false;
startButton.textContent = isRunning ? "Stop terminal" : "Start terminal";
startButton.classList.toggle("bg-red-600", isRunning);
startButton.classList.toggle("hover:bg-red-700", isRunning);
startButton.classList.toggle("bg-blue-700", !isRunning);
startButton.classList.toggle("hover:bg-blue-800", !isRunning);
}
function stopTerminal() {
if (activeSocket) {
try {
activeSocket.close();
} catch (err) {
// noop
}
}
if (termContainer) {
termContainer.dataset.started = "0";
}
activeSocket = null;
activeTerm = null;
setButtonState(false);
}
function startTerminal() {
if (!termContainer || !window.Terminal || termContainer.dataset.started === "1") {
return;
}
termContainer.dataset.started = "1";
if (startButton) {
startButton.disabled = true;
startButton.textContent = "Starting...";
}
var term = new window.Terminal({
cursorBlink: true,
fontFamily: "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace",
fontSize: 13,
theme: {
background: "#0b1120",
foreground: "#e2e8f0",
cursor: "#38bdf8"
}
});
term.open(termContainer);
setTimeout(function () {
fitTerminal(term);
}, 0);
var protocol = window.location.protocol === "https:" ? "wss" : "ws";
var socketUrl = protocol + "://" + window.location.host + "/ws/servers/{{ server.id }}/shell/";
var socket = new WebSocket(socketUrl);
socket.binaryType = "arraybuffer";
activeSocket = socket;
activeTerm = term;
socket.onmessage = function (event) {
if (typeof event.data === "string") {
term.write(event.data);
return;
}
var data = new Uint8Array(event.data);
var text = new TextDecoder("utf-8").decode(data);
term.write(text);
};
socket.onclose = function () {
term.write("\r\nSession closed.\r\n");
if (activeSocket === socket) {
stopTerminal();
}
};
term.onData(function (data) {
if (socket.readyState === WebSocket.OPEN) {
socket.send(data);
}
});
setButtonState(true);
if (isPopout) {
var onResize = function () {
sizePopoutTerminal();
fitTerminal(term);
};
window.addEventListener("resize", onResize);
}
}
if (termContainer && window.Terminal) {
if (isPopout) {
document.body.classList.add("popout-shell");
sizePopoutTerminal();
window.addEventListener("resize", sizePopoutTerminal);
}
if (startButton) {
startButton.addEventListener("click", function () {
if (termContainer.dataset.started === "1") {
stopTerminal();
return;
}
startTerminal();
});
} else {
startTerminal();
}
} else if (termContainer) {
termContainer.textContent = "Terminal assets failed to load.";
}
})();
</script>
{% endblock %}

13
app/apps/servers/urls.py Normal file
View File

@@ -0,0 +1,13 @@
from django.urls import path
from . import views
app_name = "servers"
urlpatterns = [
path("", views.dashboard, name="dashboard"),
path("<int:server_id>/", views.detail, name="detail"),
path("<int:server_id>/audit/", views.audit, name="audit"),
path("<int:server_id>/shell/", views.shell, name="shell"),
path("<int:server_id>/settings/", views.settings, name="settings"),
]

230
app/apps/servers/views.py Normal file
View File

@@ -0,0 +1,230 @@
from __future__ import annotations
from datetime import timedelta
from django.conf import settings
from django.contrib.auth.decorators import login_required
from django.db.models import Q
from django.http import Http404
from django.shortcuts import render
from django.utils import timezone
from guardian.shortcuts import get_objects_for_user, get_perms
from apps.access.models import AccessRequest
from apps.keys.utils import render_system_username
from apps.keys.models import SSHKey
from apps.servers.models import Server, ServerAccount
from apps.servers.permissions import user_can_shell
@login_required(login_url="/accounts/login/")
def dashboard(request):
now = timezone.now()
server_qs = get_objects_for_user(
request.user,
"servers.view_server",
klass=Server,
accept_global_perms=False,
)
access_qs = (
AccessRequest.objects.select_related("server")
.filter(
requester=request.user,
status=AccessRequest.Status.APPROVED,
)
.filter(Q(expires_at__isnull=True) | Q(expires_at__gt=now))
)
expires_map = {}
for access in access_qs:
expires_at = access.expires_at
if access.server_id not in expires_map:
expires_map[access.server_id] = expires_at
continue
current = expires_map[access.server_id]
if current is None:
continue
if expires_at is None or expires_at > current:
expires_map[access.server_id] = expires_at
servers = []
for server in server_qs:
servers.append(
{
"server": server,
"expires_at": expires_map.get(server.id),
"last_accessed": None,
"status": _build_server_status(server, now),
}
)
context = {
"servers": servers,
}
return render(request, "servers/dashboard.html", context)
@login_required(login_url="/accounts/login/")
def detail(request, server_id: int):
now = timezone.now()
# Authorization is enforced via object-level permissions before we do
# any other server-specific work.
server = _get_server_or_404(request, server_id)
can_shell = user_can_shell(request.user, server, now)
access = (
AccessRequest.objects.filter(
requester=request.user,
server_id=server_id,
status=AccessRequest.Status.APPROVED,
)
.filter(Q(expires_at__isnull=True) | Q(expires_at__gt=now))
.order_by("-requested_at")
.first()
)
account, system_username, certificate_key_id = _load_account_context(request, server)
context = {
"server": server,
"expires_at": access.expires_at if access else None,
"last_accessed": None,
"account_present": account.is_present if account else None,
"account_synced_at": account.last_synced_at if account else None,
"system_username": system_username,
"certificate_key_id": certificate_key_id,
"active_tab": "details",
"can_shell": can_shell,
"server_status": _build_server_status(server, now),
}
return render(request, "servers/detail.html", context)
@login_required(login_url="/accounts/login/")
def shell(request, server_id: int):
now = timezone.now()
server = _get_server_or_404(request, server_id)
# We intentionally return a 404 on denied shell access to avoid
# disclosing that the server exists but is restricted.
if not user_can_shell(request.user, server):
raise Http404("Shell access not available")
_, system_username, certificate_key_id = _load_account_context(request, server)
shell_target = server.hostname or server.ipv4 or server.ipv6 or ""
cert_filename = ""
if certificate_key_id:
cert_filename = f"keywarden-{request.user.id}-{certificate_key_id}-cert.pub"
command = ""
if shell_target and system_username and certificate_key_id:
command = (
"ssh -i /path/to/private_key "
f"-o CertificateFile=~/Downloads/{cert_filename} "
f"{system_username}@{shell_target} -t /bin/bash"
)
context = {
"server": server,
"system_username": system_username,
"certificate_key_id": certificate_key_id,
"shell_target": shell_target,
"shell_command": command,
"cert_filename": cert_filename,
"active_tab": "shell",
"is_popout": request.GET.get("popout") == "1",
"can_shell": True,
"server_status": _build_server_status(server, now),
}
return render(request, "servers/shell.html", context)
@login_required(login_url="/accounts/login/")
def audit(request, server_id: int):
now = timezone.now()
server = _get_server_or_404(request, server_id)
context = {
"server": server,
"active_tab": "audit",
"can_shell": user_can_shell(request.user, server),
"server_status": _build_server_status(server, now),
}
return render(request, "servers/audit.html", context)
@login_required(login_url="/accounts/login/")
def settings(request, server_id: int):
now = timezone.now()
server = _get_server_or_404(request, server_id)
context = {
"server": server,
"active_tab": "settings",
"can_shell": user_can_shell(request.user, server),
"server_status": _build_server_status(server, now),
}
return render(request, "servers/settings.html", context)
def _get_server_or_404(request, server_id: int) -> Server:
# Centralized object lookup + permission gate. We raise 404 for both
# missing objects and permission denials to reduce enumeration signals.
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise Http404("Server not found")
if "view_server" not in get_perms(request.user, server):
raise Http404("Server not found")
return server
def _load_account_context(request, server: Server):
# Resolve the effective system username and the currently active SSH
# key/certificate context used by the shell UI.
account = ServerAccount.objects.filter(server=server, user=request.user).first()
system_username = account.system_username if account else render_system_username(
request.user.username, request.user.id
)
active_key = SSHKey.objects.filter(user=request.user, is_active=True).order_by("-created_at").first()
certificate_key_id = active_key.id if active_key else None
return account, system_username, certificate_key_id
def _format_age_short(delta: timedelta) -> str:
seconds = max(0, int(delta.total_seconds()))
if seconds < 60:
return f"{seconds}s"
minutes = seconds // 60
rem_seconds = seconds % 60
if minutes < 60:
return f"{minutes}m {rem_seconds}s"
hours = minutes // 60
rem_minutes = minutes % 60
if hours < 48:
return f"{hours}h {rem_minutes}m {rem_seconds}s"
days = hours // 24
if days < 14:
return f"{days}d {hours % 24}h"
weeks = days // 7
return f"{weeks}w {days % 7}d"
def _build_server_status(server: Server, now):
stale_seconds = int(getattr(settings, "KEYWARDEN_HEARTBEAT_STALE_SECONDS", 120))
heartbeat_at = getattr(server, "last_heartbeat_at", None)
ping_ms = getattr(server, "last_ping_ms", None)
if heartbeat_at:
age = now - heartbeat_at
age_seconds = max(0, int(age.total_seconds()))
is_active = age_seconds <= stale_seconds
age_short = _format_age_short(age)
else:
is_active = False
age_short = "never"
label = "Active" if is_active else "Inactive"
if is_active:
detail = f"{ping_ms}ms" if ping_ms is not None else ""
else:
detail = age_short
return {
"is_active": is_active,
"label": label,
"detail": detail,
"ping_ms": ping_ms,
"age_short": age_short,
"heartbeat_at": heartbeat_at,
}

View File

@@ -1,6 +1,31 @@
#!/bin/sh #!/bin/sh
set -eu set -eu
DOMAIN="${KEYWARDEN_DOMAIN:-localhost}"
CERT_DIR="/etc/nginx/certs"
NGINX_TEMPLATE="/etc/nginx/nginx.conf.template"
NGINX_CONF="/etc/nginx/nginx.conf"
# Replaces server_name in nginx.conf with $KEYWARDEN_DOMAIN
if [ -f "$NGINX_TEMPLATE" ]; then
ESCAPED_DOMAIN=$(printf '%s' "$DOMAIN" | sed 's/[&/]/\\&/g')
sed "s/__SERVER_NAME__/${ESCAPED_DOMAIN}/g" "$NGINX_TEMPLATE" > "$NGINX_CONF"
fi
# Creates self-signed certs using mkcert $KEYWARDEN_DOMAIN, and renaming them.
if [ ! -f "$CERT_DIR/certificate.pem" ] || [ ! -f "$CERT_DIR/key.pem" ]; then
mkdir -p "$CERT_DIR"
if command -v mkcert >/dev/null 2>&1; then
mkcert -install >/dev/null 2>&1 || true
mkcert -cert-file "$CERT_DIR/certificate.pem" -key-file "$CERT_DIR/key.pem" "$DOMAIN"
else
openssl req -x509 -nodes -newkey rsa:2048 -days 365 \
-subj "/CN=$DOMAIN" \
-keyout "$CERT_DIR/key.pem" \
-out "$CERT_DIR/certificate.pem"
fi
fi
# Build Tailwind CSS (best-effort; skip if not configured) # Build Tailwind CSS (best-effort; skip if not configured)
python manage.py tailwind install || true python manage.py tailwind install || true
python manage.py tailwind build || true python manage.py tailwind build || true
@@ -12,4 +37,3 @@ python manage.py migrate --noinput
python manage.py ensure_admin python manage.py ensure_admin
exec /usr/bin/supervisord -c /etc/supervisor/supervisord.conf exec /usr/bin/supervisord -c /etc/supervisor/supervisord.conf

View File

@@ -0,0 +1,3 @@
from .celery import app as celery_app
__all__ = ("celery_app",)

View File

@@ -1,6 +1,7 @@
import inspect
from typing import List, Optional from typing import List, Optional
from ninja import NinjaAPI, Router, Schema from ninja import NinjaAPI, Router, Schema, Redoc
from ninja.security import django_auth from ninja.security import django_auth
from .security import JWTAuth from .security import JWTAuth
@@ -14,34 +15,43 @@ from .routers.access import build_router as build_access_router
from .routers.telemetry import build_router as build_telemetry_router from .routers.telemetry import build_router as build_telemetry_router
from .routers.agent import build_router as build_agent_router from .routers.agent import build_router as build_agent_router
from django.contrib.admin.views.decorators import staff_member_required
def register_routers(target_api: NinjaAPI) -> None: def register_routers(target_api: NinjaAPI) -> None:
target_api.add_router("/system", build_system_router(), tags=["system"]) target_api.add_router("/system", build_system_router(), tags=["System"])
target_api.add_router("/user", build_accounts_router(), tags=["user"]) target_api.add_router("/user", build_accounts_router(), tags=["Account Context"])
target_api.add_router("/audit", build_audit_router(), tags=["audit"]) target_api.add_router("/audit", build_audit_router(), tags=["Audit Logging"])
target_api.add_router("/servers", build_servers_router(), tags=["servers"]) target_api.add_router("/servers", build_servers_router(), tags=["Servers"])
target_api.add_router("/users", build_users_router(), tags=["users"]) target_api.add_router("/users", build_users_router(), tags=["User Directory"])
target_api.add_router("/keys", build_keys_router(), tags=["keys"]) target_api.add_router("/keys", build_keys_router(), tags=["SSH Keys"])
target_api.add_router("/access-requests", build_access_router(), tags=["access"]) target_api.add_router("/access-requests", build_access_router(), tags=["Access Requests"])
target_api.add_router("/telemetry", build_telemetry_router(), tags=["telemetry"]) target_api.add_router("/telemetry", build_telemetry_router(), tags=["Telemetry"])
target_api.add_router("/agent", build_agent_router(), tags=["agent"]) target_api.add_router("/agent", build_agent_router(), tags=["Agent"])
api = NinjaAPI( def build_api(**kwargs) -> NinjaAPI:
if "csrf" in inspect.signature(NinjaAPI).parameters:
return NinjaAPI(csrf=True, **kwargs)
return NinjaAPI(**kwargs)
api = build_api(
title="Keywarden API", title="Keywarden API",
version="1.0.0", version="1.0.0",
description="Authenticated API for internal app use and external clients.", description="Authenticated API for internal app use and external clients.",
auth=[django_auth, JWTAuth()], auth=[django_auth, JWTAuth()],
csrf=True, # enforce CSRF for session-authenticated unsafe requests docs=Redoc(),
docs_decorator=staff_member_required,
) )
register_routers(api) register_routers(api)
api_v1 = NinjaAPI( api_v1 = build_api(
title="Keywarden API", title="Keywarden API",
version="1.0.0", version="1.0.0",
description="Authenticated API for internal app use and external clients.", description="Authenticated API for internal app use and external clients.",
auth=[django_auth, JWTAuth()], auth=[django_auth, JWTAuth()],
csrf=True,
urls_namespace="api-v1", urls_namespace="api-v1",
docs=Redoc(),
docs_decorator=staff_member_required,
) )
register_routers(api_v1) register_routers(api_v1)

View File

@@ -5,17 +5,23 @@ from typing import List, Optional
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import timezone from django.utils import timezone
from guardian.shortcuts import get_objects_for_user
from ninja import Query, Router, Schema from ninja import Query, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import Field from pydantic import Field
from apps.access.models import AccessRequest from apps.access.models import AccessRequest
from apps.core.rbac import require_authenticated
from apps.servers.models import Server from apps.servers.models import Server
from apps.access.permissions import sync_server_view_perm
class AccessRequestCreateIn(Schema): class AccessRequestCreateIn(Schema):
server_id: int server_id: int
reason: Optional[str] = None reason: Optional[str] = None
request_shell: bool = False
request_logs: bool = False
request_users: bool = False
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
@@ -30,6 +36,9 @@ class AccessRequestOut(Schema):
server_id: int server_id: int
status: str status: str
reason: str reason: str
request_shell: bool
request_logs: bool
request_users: bool
requested_at: str requested_at: str
decided_at: Optional[str] = None decided_at: Optional[str] = None
expires_at: Optional[str] = None expires_at: Optional[str] = None
@@ -44,16 +53,6 @@ class AccessQuery(Schema):
requester_id: Optional[int] = None requester_id: Optional[int] = None
def _require_authenticated(request: HttpRequest) -> None:
if not getattr(request.user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
def _is_admin(request: HttpRequest) -> bool:
user = request.user
return bool(getattr(user, "is_staff", False) or getattr(user, "is_superuser", False))
def _request_to_out(access_request: AccessRequest) -> AccessRequestOut: def _request_to_out(access_request: AccessRequest) -> AccessRequestOut:
return AccessRequestOut( return AccessRequestOut(
id=access_request.id, id=access_request.id,
@@ -61,6 +60,9 @@ def _request_to_out(access_request: AccessRequest) -> AccessRequestOut:
server_id=access_request.server_id, server_id=access_request.server_id,
status=access_request.status, status=access_request.status,
reason=access_request.reason or "", reason=access_request.reason or "",
request_shell=access_request.request_shell,
request_logs=access_request.request_logs,
request_users=access_request.request_users,
requested_at=access_request.requested_at.isoformat(), requested_at=access_request.requested_at.isoformat(),
decided_at=access_request.decided_at.isoformat() if access_request.decided_at else None, decided_at=access_request.decided_at.isoformat() if access_request.decided_at else None,
expires_at=access_request.expires_at.isoformat() if access_request.expires_at else None, expires_at=access_request.expires_at.isoformat() if access_request.expires_at else None,
@@ -68,19 +70,39 @@ def _request_to_out(access_request: AccessRequest) -> AccessRequestOut:
) )
def _has_global_perm(request: HttpRequest, perm: str) -> bool:
user = request.user
return bool(user and user.has_perm(perm))
def build_router() -> Router: def build_router() -> Router:
router = Router() router = Router()
@router.get("/", response=List[AccessRequestOut]) @router.get("/", response=List[AccessRequestOut])
def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)): def list_requests(request: HttpRequest, filters: AccessQuery = Query(...)):
"""List access requests for the user, or all if admin.""" """List access requests with pagination and filters.
_require_authenticated(request)
qs = AccessRequest.objects.order_by("-requested_at") Auth: required.
if _is_admin(request): Permissions:
if filters.requester_id: - If user has global `access.view_accessrequest`, returns all requests.
qs = qs.filter(requester_id=filters.requester_id) - Otherwise, returns only objects with `access.view_accessrequest` object permission.
Filters: status, server_id, requester_id (requester_id is honored only with global view).
Rationale: powers the access request queue and auditing views.
"""
require_authenticated(request)
user = request.user
if _has_global_perm(request, "access.view_accessrequest"):
qs = AccessRequest.objects.all()
else: else:
qs = qs.filter(requester=request.user) qs = get_objects_for_user(
user,
"access.view_accessrequest",
klass=AccessRequest,
accept_global_perms=False,
)
qs = qs.order_by("-requested_at")
if filters.requester_id and _has_global_perm(request, "access.view_accessrequest"):
qs = qs.filter(requester_id=filters.requester_id)
if filters.status: if filters.status:
qs = qs.filter(status=filters.status) qs = qs.filter(status=filters.status)
if filters.server_id: if filters.server_id:
@@ -90,8 +112,18 @@ def build_router() -> Router:
@router.post("/", response=AccessRequestOut) @router.post("/", response=AccessRequestOut)
def create_request(request: HttpRequest, payload: AccessRequestCreateIn): def create_request(request: HttpRequest, payload: AccessRequestCreateIn):
"""Create a new access request for a server.""" """Create a new access request for the current user.
_require_authenticated(request)
Auth: required.
Permissions: requires global `access.add_accessrequest`.
Side effects: grants owner object perms on the new request.
Behavior: creates a pending access request; it does not grant access
until approved. Optional expires_at defines the requested access window.
Rationale: this is the entry point for delegating server access.
"""
require_authenticated(request)
if not request.user.has_perm("access.add_accessrequest"):
raise HttpError(403, "Forbidden")
try: try:
server = Server.objects.get(id=payload.server_id) server = Server.objects.get(id=payload.server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
@@ -100,6 +132,9 @@ def build_router() -> Router:
requester=request.user, requester=request.user,
server=server, server=server,
reason=(payload.reason or "").strip(), reason=(payload.reason or "").strip(),
request_shell=payload.request_shell,
request_logs=payload.request_logs,
request_users=payload.request_users,
) )
if payload.expires_at: if payload.expires_at:
access_request.expires_at = payload.expires_at access_request.expires_at = payload.expires_at
@@ -110,28 +145,43 @@ def build_router() -> Router:
@router.get("/{request_id}", response=AccessRequestOut) @router.get("/{request_id}", response=AccessRequestOut)
def get_request(request: HttpRequest, request_id: int): def get_request(request: HttpRequest, request_id: int):
"""Get an access request if permitted.""" """Get a single access request by id.
_require_authenticated(request)
Auth: required.
Permissions: requires `access.view_accessrequest` on the object.
Rationale: used for request detail views and approval workflows.
"""
require_authenticated(request)
try: try:
access_request = AccessRequest.objects.get(id=request_id) access_request = AccessRequest.objects.get(id=request_id)
except AccessRequest.DoesNotExist: except AccessRequest.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if not _is_admin(request) and access_request.requester_id != request.user.id: if not request.user.has_perm("access.view_accessrequest", access_request):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
return _request_to_out(access_request) return _request_to_out(access_request)
@router.patch("/{request_id}", response=AccessRequestOut) @router.patch("/{request_id}", response=AccessRequestOut)
def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn): def update_request(request: HttpRequest, request_id: int, payload: AccessRequestUpdateIn):
"""Update request status or expiry (admin or owner with restrictions).""" """Update request status or expiry.
_require_authenticated(request)
Auth: required.
Permissions: requires `access.change_accessrequest` on the object.
Rules:
- Admin/operator (global change) can set status to approved/denied/revoked/cancelled and
update expires_at.
- Non-admin can only set status to cancelled, and only while pending.
Side effects: updates object permissions for server visibility when
approvals or revocations occur.
Rationale: this is the core approval/denial path for access control.
"""
require_authenticated(request)
try: try:
access_request = AccessRequest.objects.get(id=request_id) access_request = AccessRequest.objects.get(id=request_id)
except AccessRequest.DoesNotExist: except AccessRequest.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
is_admin = _is_admin(request) if not request.user.has_perm("access.change_accessrequest", access_request):
is_owner = access_request.requester_id == request.user.id
if not is_admin and not is_owner:
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
is_admin = _has_global_perm(request, "access.change_accessrequest")
if payload.status is None and payload.expires_at is None: if payload.status is None and payload.expires_at is None:
raise HttpError(422, {"detail": "No fields provided."}) raise HttpError(422, {"detail": "No fields provided."})
if payload.expires_at is not None: if payload.expires_at is not None:
@@ -162,21 +212,9 @@ def build_router() -> Router:
else: else:
access_request.decided_by = None access_request.decided_by = None
access_request.save() access_request.save()
sync_server_view_perm(access_request)
return _request_to_out(access_request) return _request_to_out(access_request)
@router.delete("/{request_id}", response={204: None})
def delete_request(request: HttpRequest, request_id: int):
"""Delete an access request if permitted."""
_require_authenticated(request)
try:
access_request = AccessRequest.objects.get(id=request_id)
except AccessRequest.DoesNotExist:
raise HttpError(404, "Not Found")
if not _is_admin(request) and access_request.requester_id != request.user.id:
raise HttpError(403, "Forbidden")
access_request.delete()
return 204, None
return router return router

View File

@@ -3,6 +3,7 @@ from typing import Optional
from django.http import HttpRequest from django.http import HttpRequest
from ninja import Router, Schema from ninja import Router, Schema
from apps.core.rbac import require_authenticated
class UserSchema(Schema): class UserSchema(Schema):
id: int id: int
@@ -19,7 +20,15 @@ def build_router() -> Router:
@router.get("/me", response=UserSchema) @router.get("/me", response=UserSchema)
def me(request: HttpRequest): def me(request: HttpRequest):
"""Return the current authenticated user's profile.""" """Return the authenticated user's profile and role context.
Auth: required (session or JWT). Used by the UI to build navigation,
display the user identity, and decide which actions are enabled.
Fields: returns only the minimal identity and privilege flags needed
by the client; no secrets or permissions lists are exposed here.
Rationale: keeps the client-side state aligned with the session user.
"""
require_authenticated(request)
user = request.user user = request.user
return { return {
"id": user.id, "id": user.id,

View File

@@ -1,17 +1,33 @@
from __future__ import annotations from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
from django.db import models from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv4_address, validate_ipv6_address
from django.db import IntegrityError, transaction
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import timezone from django.utils import timezone
from ninja import Router, Schema from django.views.decorators.csrf import csrf_exempt
from ninja import Body, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import Field from pydantic import Field
from guardian.shortcuts import get_users_with_perms
from apps.access.models import AccessRequest from apps.core.rbac import require_perms
from apps.keys.certificates import get_active_ca
from apps.keys.models import SSHKey from apps.keys.models import SSHKey
from apps.servers.models import Server from apps.keys.utils import render_system_username
from apps.servers.models import (
AgentCertificateAuthority,
EnrollmentToken,
Server,
ServerAccount,
hostname_validator,
)
from apps.telemetry.models import TelemetryEvent from apps.telemetry.models import TelemetryEvent
@@ -23,63 +39,220 @@ class AuthorizedKeyOut(Schema):
fingerprint: str fingerprint: str
class AccountKeyOut(Schema):
public_key: str
fingerprint: str
class AccountAccessOut(Schema):
user_id: int
username: str
email: str
system_username: str
keys: List[AccountKeyOut] = Field(default_factory=list)
class AccountSyncIn(Schema):
user_id: int
system_username: str
present: bool
class SyncReportIn(Schema): class SyncReportIn(Schema):
applied_count: int = Field(default=0, ge=0) applied_count: int = Field(default=0, ge=0)
revoked_count: int = Field(default=0, ge=0) revoked_count: int = Field(default=0, ge=0)
message: Optional[str] = None message: Optional[str] = None
metadata: dict = Field(default_factory=dict) metadata: dict = Field(default_factory=dict)
accounts: List[AccountSyncIn] = Field(default_factory=list)
class SyncReportOut(Schema): class SyncReportOut(Schema):
status: str status: str
def _require_admin(request: HttpRequest) -> None: class AgentEnrollIn(Schema):
user = request.user token: str
if not getattr(user, "is_authenticated", False): csr_pem: str
raise HttpError(403, "Forbidden") host: Optional[str] = None
if not (user.is_staff or user.is_superuser): ipv4: Optional[str] = None
raise HttpError(403, "Forbidden") ipv6: Optional[str] = None
class AgentEnrollOut(Schema):
server_id: str
client_cert_pem: str
ca_cert_pem: str
class LogEventIn(Schema):
timestamp: str
category: str
event_type: str
unit: Optional[str] = None
priority: Optional[str] = None
hostname: Optional[str] = None
username: Optional[str] = None
principal: Optional[str] = None
source_ip: Optional[str] = None
session_id: Optional[str] = None
message: Optional[str] = None
raw: Optional[str] = None
fields: Optional[dict] = None
class LogIngestOut(Schema):
status: str
accepted: int
class AgentHeartbeatIn(Schema):
host: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
ping_ms: Optional[int] = None
def build_router() -> Router: def build_router() -> Router:
router = Router() router = Router()
@router.post("/enroll", response=AgentEnrollOut, auth=None)
@csrf_exempt
def enroll_agent(request: HttpRequest, payload: AgentEnrollIn = Body(...)):
"""Enroll a server agent using a one-time enrollment token.
Auth: token only (no session/JWT); mTLS is not yet available until
enrollment completes.
Inputs: enrollment token + CSR from the agent, optional host/IP hints.
Behavior:
- Creates a Server record (agent is the source of truth for host/IP).
- Marks the token as used (single-use).
- Signs the CSR with the active Agent CA and returns client cert + CA.
Rationale: this is the only supported server onboarding flow. If this
endpoint is removed, agents cannot bootstrap mTLS credentials.
"""
token_value = (payload.token or "").strip()
if not token_value:
raise HttpError(422, "Token required")
try:
token = EnrollmentToken.objects.get(token=token_value)
except EnrollmentToken.DoesNotExist:
raise HttpError(403, "Invalid token")
if not token.is_valid():
raise HttpError(403, "Token expired or already used")
host = (payload.host or "").strip()[:253]
display_name = host or "server"
hostname = None
if host:
try:
hostname_validator(host)
hostname = host
except ValidationError:
hostname = None
ipv4 = _normalize_ip(payload.ipv4, 4)
ipv6 = _normalize_ip(payload.ipv6, 6)
csr = _load_csr((payload.csr_pem or "").strip())
try:
with transaction.atomic():
server = Server.objects.create(
display_name=display_name,
hostname=hostname,
ipv4=ipv4,
ipv6=ipv6,
)
token.mark_used(server)
token.save(update_fields=["used_at", "server"])
cert_pem, ca_pem, fingerprint, serial = _issue_client_cert(csr, host, server.id)
server.agent_enrolled_at = timezone.now()
server.agent_cert_fingerprint = fingerprint
server.agent_cert_serial = serial
server.save(update_fields=["agent_enrolled_at", "agent_cert_fingerprint", "agent_cert_serial"])
except IntegrityError:
raise HttpError(409, "Server already enrolled")
return AgentEnrollOut(
server_id=str(server.id),
client_cert_pem=cert_pem,
ca_cert_pem=ca_pem,
)
@router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut]) @router.get("/servers/{server_id}/authorized-keys", response=List[AuthorizedKeyOut])
def authorized_keys(request: HttpRequest, server_id: int): def authorized_keys(request: HttpRequest, server_id: int):
"""Return authorized public keys for a server (admin only).""" """Resolve the effective authorized_keys list for a server.
_require_admin(request)
try: Auth: required (admin/operator via API).
server = Server.objects.get(id=server_id) Permissions: requires view access to servers and keys.
except Server.DoesNotExist: Behavior: uses server object permissions + active SSH keys to produce
raise HttpError(404, "Server not found") the exact key list the agent should deploy to the server.
now = timezone.now() Rationale: this is the policy enforcement point for per-user access.
access_qs = AccessRequest.objects.select_related("requester").filter( """
server=server, require_perms(
status=AccessRequest.Status.APPROVED, request,
) "servers.view_server",
access_qs = access_qs.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=now)) "keys.view_sshkey",
users = [req.requester for req in access_qs if req.requester and req.requester.is_active]
keys = SSHKey.objects.select_related("user").filter(
user__in=users,
is_active=True,
revoked_at__isnull=True,
) )
server = _get_server_or_404(server_id)
users = _resolve_access_users(server)
key_map = _key_map_for_users(users)
output: list[AuthorizedKeyOut] = []
for user in users:
for key in key_map.get(user.id, []):
output.append(
AuthorizedKeyOut(
user_id=user.id,
username=user.username,
email=user.email or "",
public_key=key.public_key,
fingerprint=key.fingerprint,
)
)
return output
@router.get("/servers/{server_id}/accounts", response=List[AccountAccessOut], auth=None)
def account_access(request: HttpRequest, server_id: int):
"""List accounts that should exist on a server.
Auth: mTLS expected at the edge (no session/JWT).
Behavior: resolves active users with server object perms and their keys.
Rationale: drives agent-side account provisioning.
"""
server = _get_server_or_404(server_id)
users = _resolve_access_users(server)
return [ return [
AuthorizedKeyOut( AccountAccessOut(
user_id=key.user_id, user_id=user.id,
username=key.user.username, username=user.username,
email=key.user.email or "", email=user.email or "",
public_key=key.public_key, system_username=render_system_username(user.username, user.id),
fingerprint=key.fingerprint, keys=[],
) )
for key in keys for user in users
] ]
@router.post("/servers/{server_id}/sync-report", response=SyncReportOut) @router.get("/servers/{server_id}/ssh-ca", auth=None)
def sync_report(request: HttpRequest, server_id: int, payload: SyncReportIn): @csrf_exempt
"""Record an agent sync report for a server (admin only).""" def ssh_ca(request: HttpRequest, server_id: int):
_require_admin(request) """Return the active SSH user CA public key for agents.
Auth: mTLS expected at the edge (no session/JWT).
"""
_ = _get_server_or_404(server_id)
ca = get_active_ca()
if not ca.public_key:
raise HttpError(404, "SSH CA not configured")
return {"public_key": ca.public_key, "fingerprint": ca.fingerprint}
@router.post("/servers/{server_id}/sync-report", response=SyncReportOut, auth=None)
@csrf_exempt
def sync_report(request: HttpRequest, server_id: int, payload: SyncReportIn = Body(...)):
"""Record an agent sync report for a server.
Auth: mTLS expected at the edge (no session/JWT).
Behavior: stores a telemetry event with counts of applied/revoked keys.
Rationale: provides an audit trail of enforcement actions without
requiring full log ingestion for every sync cycle.
"""
try: try:
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
@@ -96,9 +269,198 @@ def build_router() -> Router:
**(payload.metadata or {}), **(payload.metadata or {}),
}, },
) )
if payload.accounts:
_update_server_accounts(server, payload.accounts)
return SyncReportOut(status="ok")
@router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None)
@csrf_exempt
def ingest_logs(request: HttpRequest, server_id: int, payload: List[LogEventIn] = Body(...)):
"""Accept log batches from agents for audit collection.
Auth: mTLS expected at the edge (no session/JWT).
Behavior: accepts structured log events for later storage and indexing.
Storage: raw logs are persisted separately per-server (SQLite shards),
not in the primary Postgres database.
Rationale: this is the ingestion pipe for security audit logging.
"""
try:
Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
# TODO: enqueue to Valkey and persist to SQLite slices.
return LogIngestOut(status="accepted", accepted=len(payload))
@router.post("/servers/{server_id}/heartbeat", response=SyncReportOut, auth=None)
@csrf_exempt
def heartbeat(request: HttpRequest, server_id: int, payload: AgentHeartbeatIn = Body(...)):
"""Update server host metadata (hostname/IPs) reported by the agent.
Auth: mTLS expected at the edge (no session/JWT).
Behavior: updates hostname/IPv4/IPv6 when they change (e.g., DHCP).
Conflict: unique constraints are enforced; conflicts return 409.
Rationale: keeps the server inventory accurate without manual edits.
"""
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
updates: dict[str, str | int | datetime] = {}
host = (payload.host or "").strip()[:253]
if host:
try:
hostname_validator(host)
if server.hostname != host:
updates["hostname"] = host
except ValidationError:
pass
ipv4 = _normalize_ip(payload.ipv4, 4)
if ipv4 and server.ipv4 != ipv4:
updates["ipv4"] = ipv4
ipv6 = _normalize_ip(payload.ipv6, 6)
if ipv6 and server.ipv6 != ipv6:
updates["ipv6"] = ipv6
now = timezone.now()
updates["last_heartbeat_at"] = now
if payload.ping_ms is not None:
updates["last_ping_ms"] = max(0, int(payload.ping_ms))
if updates:
for field, value in updates.items():
setattr(server, field, value)
try:
server.save(update_fields=list(updates.keys()))
except IntegrityError:
raise HttpError(409, "Server address already in use")
return SyncReportOut(status="ok") return SyncReportOut(status="ok")
return router return router
def _get_server_or_404(server_id: int) -> Server:
try:
return Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
def _resolve_access_users(server: Server) -> list:
users = list(
get_users_with_perms(
server,
only_with_perms_in=["view_server"],
with_group_users=True,
with_superusers=False,
)
)
active = [user for user in users if getattr(user, "is_active", False)]
return sorted(active, key=lambda user: (user.username or "", user.id))
def _key_map_for_users(users: list) -> dict[int, list[SSHKey]]:
if not users:
return {}
keys = SSHKey.objects.select_related("user").filter(
user__in=users,
is_active=True,
revoked_at__isnull=True,
)
key_map: dict[int, list[SSHKey]] = {}
for key in keys:
key_map.setdefault(key.user_id, []).append(key)
return key_map
def _update_server_accounts(server: Server, accounts: list[AccountSyncIn]) -> None:
user_ids = {account.user_id for account in accounts}
if not user_ids:
return
User = get_user_model()
users = {user.id: user for user in User.objects.filter(id__in=user_ids)}
now = timezone.now()
for account in accounts:
user = users.get(account.user_id)
if not user:
continue
ServerAccount.objects.update_or_create(
server=server,
user=user,
defaults={
"system_username": account.system_username,
"is_present": account.present,
"last_synced_at": now,
},
)
def _load_agent_ca() -> tuple[x509.Certificate, object, str]:
ca = (
AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True)
.order_by("-created_at")
.first()
)
if not ca:
raise HttpError(500, "Agent CA not configured")
try:
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
ca_key = serialization.load_pem_private_key(ca.key_pem.encode("utf-8"), password=None)
except (ValueError, TypeError):
raise HttpError(500, "Invalid agent CA material")
return ca_cert, ca_key, ca.cert_pem
def _load_csr(csr_pem: str) -> x509.CertificateSigningRequest:
try:
csr = x509.load_pem_x509_csr(csr_pem.encode("utf-8"))
except ValueError:
raise HttpError(422, "Invalid CSR")
if not csr.is_signature_valid:
raise HttpError(422, "Invalid CSR signature")
return csr
def _issue_client_cert(
csr: x509.CertificateSigningRequest, host: str | None, server_id: int
) -> tuple[str, str, str, str]:
ca_cert, ca_key, ca_pem = _load_agent_ca()
now = datetime.utcnow()
subject = csr.subject
if len(subject) == 0:
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, f"keywarden-agent-{server_id}")])
builder = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(ca_cert.subject)
.public_key(csr.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now - timedelta(minutes=5))
.not_valid_after(now + timedelta(days=settings.KEYWARDEN_AGENT_CERT_VALIDITY_DAYS))
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=False)
)
if host:
try:
hostname_validator(host)
builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(host)]), critical=False)
except ValidationError:
pass
cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
fingerprint = cert.fingerprint(hashes.SHA256()).hex()
serial = format(cert.serial_number, "x")
return cert_pem, ca_pem, fingerprint, serial
def _normalize_ip(value: Optional[str], version: int) -> Optional[str]:
if not value:
return None
try:
if version == 4:
validate_ipv4_address(value)
else:
validate_ipv6_address(value)
except ValidationError:
return None
return value
router = build_router() router = build_router()

View File

@@ -8,13 +8,20 @@ from django.http import HttpRequest
from ninja import Query, Router, Schema from ninja import Query, Router, Schema
from apps.audit.models import AuditEventType, AuditLog from apps.audit.models import AuditEventType, AuditLog
from apps.core.rbac import require_perms
class AuditEventTypeSchema(Schema): class AuditEventTypeSchema(Schema):
id: int id: int
key: str key: str
title: str title: str
description: str | None = None description: str | None = None
kind: str
default_severity: str default_severity: str
endpoints: list[str]
ip_whitelist_enabled: bool
ip_whitelist: list[str]
ip_blacklist_enabled: bool
ip_blacklist: list[str]
class AuditLogSchema(Schema): class AuditLogSchema(Schema):
@@ -46,7 +53,15 @@ def build_router() -> Router:
@router.get("/event-types", response=List[AuditEventTypeSchema]) @router.get("/event-types", response=List[AuditEventTypeSchema])
def list_event_types(request: HttpRequest): def list_event_types(request: HttpRequest):
"""List audit event types and their default severity.""" """List audit event types used by the platform audit log.
Auth: required.
Permissions: requires global `audit.view_auditeventtype`.
Behavior: returns the canonical event taxonomy (key, title, severity).
Rationale: the admin UI and audit filters use this to map log entries
to human-readable categories and severity defaults.
"""
require_perms(request, "audit.view_auditeventtype")
qs: QuerySet[AuditEventType] = AuditEventType.objects.all() qs: QuerySet[AuditEventType] = AuditEventType.objects.all()
return [ return [
{ {
@@ -54,14 +69,30 @@ def build_router() -> Router:
"key": et.key, "key": et.key,
"title": et.title, "title": et.title,
"description": et.description or "", "description": et.description or "",
"kind": et.kind,
"default_severity": et.default_severity, "default_severity": et.default_severity,
"endpoints": list(et.endpoints or []),
"ip_whitelist_enabled": bool(et.ip_whitelist_enabled),
"ip_whitelist": list(et.ip_whitelist or []),
"ip_blacklist_enabled": bool(et.ip_blacklist_enabled),
"ip_blacklist": list(et.ip_blacklist or []),
} }
for et in qs for et in qs
] ]
@router.get("/logs", response=List[AuditLogSchema]) @router.get("/logs", response=List[AuditLogSchema])
def list_logs(request: HttpRequest, filters: LogsQuery = Query(...)): def list_logs(request: HttpRequest, filters: LogsQuery = Query(...)):
"""List audit logs with optional filters and pagination.""" """List application audit log entries with filters and pagination.
Auth: required.
Permissions: requires global `audit.view_auditlog`.
Filters: severity, actor_id, event_type_key, source.
Pagination: limit + offset.
Scope: this is the Keywarden app audit trail (who changed what), not
the server OS log ingestion stream stored by the agent.
Rationale: used by the audit UI and for administrative forensics.
"""
require_perms(request, "audit.view_auditlog")
qs: QuerySet[AuditLog] = AuditLog.objects.select_related("event_type", "actor").all() qs: QuerySet[AuditLog] = AuditLog.objects.select_related("event_type", "actor").all()
if filters.severity: if filters.severity:
qs = qs.filter(severity=filters.severity) qs = qs.filter(severity=filters.severity)

View File

@@ -3,15 +3,20 @@ from __future__ import annotations
from typing import List, Optional from typing import List, Optional
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
import hashlib
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import IntegrityError from django.db import IntegrityError, transaction
from django.http import HttpRequest from django.http import HttpRequest, HttpResponse
from django.utils import timezone from django.utils import timezone
from guardian.shortcuts import get_objects_for_user
from ninja import Query, Router, Schema from ninja import Query, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import Field from pydantic import Field
from apps.keys.models import SSHKey from apps.core.rbac import require_authenticated
from apps.keys.certificates import issue_certificate_for_key, revoke_certificate_for_key
from apps.keys.models import SSHCertificate, SSHKey
class KeyCreateIn(Schema): class KeyCreateIn(Schema):
@@ -37,22 +42,20 @@ class KeyOut(Schema):
revoked_at: Optional[str] = None revoked_at: Optional[str] = None
class CertificateOut(Schema):
key_id: int
serial: int
valid_after: str
valid_before: str
principals: List[str]
class KeysQuery(Schema): class KeysQuery(Schema):
limit: int = Field(default=50, ge=1, le=200) limit: int = Field(default=50, ge=1, le=200)
offset: int = Field(default=0, ge=0) offset: int = Field(default=0, ge=0)
user_id: Optional[int] = None user_id: Optional[int] = None
def _require_authenticated(request: HttpRequest) -> None:
if not getattr(request.user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
def _is_admin(request: HttpRequest) -> bool:
user = request.user
return bool(getattr(user, "is_staff", False) or getattr(user, "is_superuser", False))
def _key_to_out(key: SSHKey) -> KeyOut: def _key_to_out(key: SSHKey) -> KeyOut:
return KeyOut( return KeyOut(
id=key.id, id=key.id,
@@ -67,33 +70,82 @@ def _key_to_out(key: SSHKey) -> KeyOut:
) )
def _ensure_certificate(key: SSHKey, request_user) -> SSHCertificate:
if not key.is_active:
raise HttpError(409, "Key is revoked")
now = timezone.now()
try:
cert = key.certificate
except SSHCertificate.DoesNotExist:
return issue_certificate_for_key(key, created_by=request_user)
if not cert.is_active or cert.valid_before <= now:
return issue_certificate_for_key(key, created_by=request_user)
return cert
def _has_global_perm(request: HttpRequest, perm: str) -> bool:
user = request.user
return bool(user and user.has_perm(perm))
def build_router() -> Router: def build_router() -> Router:
router = Router() router = Router()
@router.get("/", response=List[KeyOut]) @router.get("/", response=List[KeyOut])
def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)): def list_keys(request: HttpRequest, filters: KeysQuery = Query(...)):
"""List SSH keys for the current user, or any user if admin.""" """List SSH keys with pagination and filters.
_require_authenticated(request)
qs = SSHKey.objects.order_by("-created_at") Auth: required.
if _is_admin(request): Permissions:
if filters.user_id: - If user has global `keys.view_sshkey`, returns all keys.
qs = qs.filter(user_id=filters.user_id) - Otherwise, returns only objects with `keys.view_sshkey` object permission.
Filter: user_id (honored only with global view).
Rationale: powers the key inventory UI and lets admins audit key usage.
"""
require_authenticated(request)
user = request.user
if _has_global_perm(request, "keys.view_sshkey"):
qs = SSHKey.objects.all()
else: else:
qs = qs.filter(user=request.user) qs = get_objects_for_user(
user,
"keys.view_sshkey",
klass=SSHKey,
accept_global_perms=False,
)
qs = qs.order_by("-created_at")
if filters.user_id and _has_global_perm(request, "keys.view_sshkey"):
qs = qs.filter(user_id=filters.user_id)
qs = qs[filters.offset : filters.offset + filters.limit] qs = qs[filters.offset : filters.offset + filters.limit]
return [_key_to_out(key) for key in qs] return [_key_to_out(key) for key in qs]
@router.post("/", response=KeyOut) @router.post("/", response=KeyOut)
def create_key(request: HttpRequest, payload: KeyCreateIn): def create_key(request: HttpRequest, payload: KeyCreateIn):
"""Create an SSH public key for the current user (admin can specify user_id).""" """Create an SSH public key.
_require_authenticated(request)
Auth: required.
Permissions: requires global `keys.add_sshkey`.
Rules:
- Default owner is the current user.
- If caller has global `keys.add_sshkey` and `keys.view_sshkey`, they may specify user_id.
Side effects: grants owner object perms on the new key.
Rationale: keys are the core authorization material synced to servers.
"""
require_authenticated(request)
if not request.user.has_perm("keys.add_sshkey"):
raise HttpError(403, "Forbidden")
is_admin = _has_global_perm(request, "keys.add_sshkey") and _has_global_perm(
request, "keys.view_sshkey"
)
owner = request.user owner = request.user
if _is_admin(request) and payload.user_id: if is_admin and payload.user_id:
User = get_user_model() User = get_user_model()
try: try:
owner = User.objects.get(id=payload.user_id) owner = User.objects.get(id=payload.user_id)
except User.DoesNotExist: except User.DoesNotExist:
raise HttpError(404, "User not found") raise HttpError(404, "User not found")
elif payload.user_id and payload.user_id != request.user.id:
raise HttpError(403, "Forbidden")
name = (payload.name or "").strip() name = (payload.name or "").strip()
if not name: if not name:
raise HttpError(422, {"name": ["Name cannot be empty."]}) raise HttpError(422, {"name": ["Name cannot be empty."]})
@@ -103,32 +155,104 @@ def build_router() -> Router:
except ValidationError as exc: except ValidationError as exc:
raise HttpError(422, {"public_key": [str(exc)]}) raise HttpError(422, {"public_key": [str(exc)]})
try: try:
key.save() with transaction.atomic():
key.save()
issue_certificate_for_key(key, created_by=request.user)
except IntegrityError: except IntegrityError:
raise HttpError(422, {"public_key": ["Key already exists."]}) raise HttpError(422, {"public_key": ["Key already exists."]})
except Exception as exc:
raise HttpError(500, {"detail": f"Certificate issuance failed: {exc}"})
return _key_to_out(key) return _key_to_out(key)
@router.get("/{key_id}", response=KeyOut) @router.get("/{key_id}", response=KeyOut)
def get_key(request: HttpRequest, key_id: int): def get_key(request: HttpRequest, key_id: int):
"""Get a specific SSH key if permitted.""" """Get a specific SSH key by id.
_require_authenticated(request)
Auth: required.
Permissions: requires `keys.view_sshkey` on the object.
Rationale: used by key detail views and server access debugging.
"""
require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist: except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if not _is_admin(request) and key.user_id != request.user.id: if not request.user.has_perm("keys.view_sshkey", key):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
return _key_to_out(key) return _key_to_out(key)
@router.patch("/{key_id}", response=KeyOut) @router.post("/{key_id}/certificate", response=CertificateOut)
def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn): def issue_certificate(request: HttpRequest, key_id: int):
"""Update key name or active state if permitted.""" """Issue or re-issue an SSH certificate for a key.
_require_authenticated(request)
Auth: required.
Permissions: requires `keys.view_sshkey` on the object.
Rationale: allows users to download a fresh certificate as needed.
"""
require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist: except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if not _is_admin(request) and key.user_id != request.user.id: if not request.user.has_perm("keys.view_sshkey", key):
raise HttpError(403, "Forbidden")
cert = issue_certificate_for_key(key, created_by=request.user)
return CertificateOut(
key_id=key.id,
serial=cert.serial,
valid_after=cert.valid_after.isoformat(),
valid_before=cert.valid_before.isoformat(),
principals=list(cert.principals or []),
)
@router.get("/{key_id}/certificate")
def download_certificate(request: HttpRequest, key_id: int):
"""Download the SSH certificate for a key."""
require_authenticated(request)
try:
key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found")
if not request.user.has_perm("keys.view_sshkey", key):
raise HttpError(403, "Forbidden")
cert = _ensure_certificate(key, request.user)
filename = f"keywarden-{key.user_id}-{key.id}-cert.pub"
response = HttpResponse(cert.certificate, content_type="text/plain")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
@router.get("/{key_id}/certificate.sha256")
def download_certificate_hash(request: HttpRequest, key_id: int):
"""Download the SSH certificate hash for a key."""
require_authenticated(request)
try:
key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found")
if not request.user.has_perm("keys.view_sshkey", key):
raise HttpError(403, "Forbidden")
cert = _ensure_certificate(key, request.user)
filename = f"keywarden-{key.user_id}-{key.id}-cert.pub"
digest = hashlib.sha256(cert.certificate.encode("utf-8")).hexdigest()
payload = f"{digest} {filename}\n"
response = HttpResponse(payload, content_type="text/plain")
response["Content-Disposition"] = f'attachment; filename="{filename}.sha256"'
return response
@router.patch("/{key_id}", response=KeyOut)
def update_key(request: HttpRequest, key_id: int, payload: KeyUpdateIn):
"""Update key name or active state.
Auth: required.
Permissions: requires `keys.change_sshkey` on the object.
Rationale: allows key rotation and revocation without deletion.
"""
require_authenticated(request)
try:
key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found")
if not request.user.has_perm("keys.change_sshkey", key):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
if payload.name is None and payload.is_active is None: if payload.name is None and payload.is_active is None:
raise HttpError(422, {"detail": "No fields provided."}) raise HttpError(422, {"detail": "No fields provided."})
@@ -141,25 +265,37 @@ def build_router() -> Router:
key.is_active = payload.is_active key.is_active = payload.is_active
if payload.is_active: if payload.is_active:
key.revoked_at = None key.revoked_at = None
try:
issue_certificate_for_key(key, created_by=request.user)
except Exception as exc:
raise HttpError(500, {"detail": f"Certificate issuance failed: {exc}"})
else: else:
key.revoked_at = timezone.now() key.revoked_at = timezone.now()
revoke_certificate_for_key(key)
key.save() key.save()
return _key_to_out(key) return _key_to_out(key)
@router.delete("/{key_id}", response={204: None}) @router.delete("/{key_id}", response={204: None})
def delete_key(request: HttpRequest, key_id: int): def delete_key(request: HttpRequest, key_id: int):
"""Revoke an SSH key if permitted (soft delete).""" """Revoke (soft delete) an SSH key.
_require_authenticated(request)
Auth: required.
Permissions: requires `keys.delete_sshkey` on the object.
Behavior: sets is_active false and revoked_at if key is active.
Rationale: removes key access while preserving auditability.
"""
require_authenticated(request)
try: try:
key = SSHKey.objects.get(id=key_id) key = SSHKey.objects.get(id=key_id)
except SSHKey.DoesNotExist: except SSHKey.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if not _is_admin(request) and key.user_id != request.user.id: if not request.user.has_perm("keys.delete_sshkey", key):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
if key.is_active: if key.is_active:
key.is_active = False key.is_active = False
key.revoked_at = timezone.now() key.revoked_at = timezone.now()
key.save(update_fields=["is_active", "revoked_at"]) key.save(update_fields=["is_active", "revoked_at"])
revoke_certificate_for_key(key)
return 204, None return 204, None
return router return router

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
from typing import List, Optional from typing import List, Optional
from django.db import IntegrityError
from django.http import HttpRequest from django.http import HttpRequest
from ninja import File, Form, Router, Schema from ninja import Router, Schema
from ninja.files import UploadedFile
from ninja.errors import HttpError from ninja.errors import HttpError
from guardian.shortcuts import get_objects_for_user, get_perms
from apps.core.rbac import require_authenticated, require_perms
from apps.servers.models import Server from apps.servers.models import Server
@@ -20,26 +20,8 @@ class ServerOut(Schema):
initial: str initial: str
class ServerCreate(Schema):
display_name: str
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
class ServerUpdate(Schema): class ServerUpdate(Schema):
display_name: Optional[str] = None display_name: Optional[str] = None
hostname: Optional[str] = None
ipv4: Optional[str] = None
ipv6: Optional[str] = None
def _require_admin(request: HttpRequest) -> None:
user = request.user
if not getattr(user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
if not (user.is_staff or user.is_superuser):
raise HttpError(403, "Forbidden")
def build_router() -> Router: def build_router() -> Router:
@@ -47,8 +29,20 @@ def build_router() -> Router:
@router.get("/", response=List[ServerOut]) @router.get("/", response=List[ServerOut])
def list_servers(request: HttpRequest): def list_servers(request: HttpRequest):
"""List servers visible to authenticated users.""" """List servers the caller can view.
servers = Server.objects.all()
Auth: required.
Permissions: requires `servers.view_server` via object permissions.
Behavior: returns only servers the user can see via object perms.
Rationale: drives the server dashboard and access-aware navigation.
"""
require_authenticated(request)
servers = get_objects_for_user(
request.user,
"servers.view_server",
klass=Server,
accept_global_perms=False,
)
return [ return [
{ {
"id": s.id, "id": s.id,
@@ -64,61 +58,20 @@ def build_router() -> Router:
@router.get("/{server_id}", response=ServerOut) @router.get("/{server_id}", response=ServerOut)
def get_server(request: HttpRequest, server_id: int): def get_server(request: HttpRequest, server_id: int):
"""Get server details by id.""" """Get a server record by id.
Auth: required.
Permissions: requires `servers.view_server` via object permissions.
Rationale: used by server detail views and API clients inspecting
server metadata (hostname/IPs populated by the agent).
"""
require_authenticated(request)
try: try:
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
return { if "view_server" not in get_perms(request.user, server):
"id": server.id, raise HttpError(403, "Forbidden")
"display_name": server.display_name,
"hostname": server.hostname,
"ipv4": server.ipv4,
"ipv6": server.ipv6,
"image_url": server.image_url,
"initial": server.initial,
}
@router.post("/", response=ServerOut)
def create_server_json(request: HttpRequest, payload: ServerCreate):
"""Create a server using JSON payload (admin only)."""
_require_admin(request)
server = Server.objects.create(
display_name=payload.display_name.strip(),
hostname=(payload.hostname or "").strip() or None,
ipv4=(payload.ipv4 or "").strip() or None,
ipv6=(payload.ipv6 or "").strip() or None,
)
return {
"id": server.id,
"display_name": server.display_name,
"hostname": server.hostname,
"ipv4": server.ipv4,
"ipv6": server.ipv6,
"image_url": server.image_url,
"initial": server.initial,
}
@router.post("/upload", response=ServerOut)
def create_server_multipart(
request: HttpRequest,
display_name: str = Form(...),
hostname: Optional[str] = Form(None),
ipv4: Optional[str] = Form(None),
ipv6: Optional[str] = Form(None),
image: Optional[UploadedFile] = File(None),
):
"""Create a server with optional image upload (admin only)."""
_require_admin(request)
server = Server(
display_name=display_name.strip(),
hostname=(hostname or "").strip() or None,
ipv4=(ipv4 or "").strip() or None,
ipv6=(ipv6 or "").strip() or None,
)
if image:
server.image.save(image.name, image) # type: ignore[arg-type]
server.save()
return { return {
"id": server.id, "id": server.id,
"display_name": server.display_name, "display_name": server.display_name,
@@ -131,34 +84,26 @@ def build_router() -> Router:
@router.patch("/{server_id}", response=ServerOut) @router.patch("/{server_id}", response=ServerOut)
def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate): def update_server(request: HttpRequest, server_id: int, payload: ServerUpdate):
"""Update server fields (admin only).""" """Update the server display name (admin only).
_require_admin(request)
if ( Auth: required.
payload.display_name is None Permissions: requires `servers.change_server`.
and payload.hostname is None Behavior: only display_name is editable via API; host/IP data is owned
and payload.ipv4 is None by the agent heartbeat to avoid conflicting sources of truth.
and payload.ipv6 is None Rationale: allows human-friendly naming without bypassing enrollment.
): """
require_perms(request, "servers.change_server")
if payload.display_name is None:
raise HttpError(422, {"detail": "No fields provided."}) raise HttpError(422, {"detail": "No fields provided."})
try: try:
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if payload.display_name is not None: display_name = payload.display_name.strip()
display_name = payload.display_name.strip() if not display_name:
if not display_name: raise HttpError(422, {"display_name": ["Display name cannot be empty."]})
raise HttpError(422, {"display_name": ["Display name cannot be empty."]}) server.display_name = display_name
server.display_name = display_name server.save(update_fields=["display_name"])
if payload.hostname is not None:
server.hostname = (payload.hostname or "").strip() or None
if payload.ipv4 is not None:
server.ipv4 = (payload.ipv4 or "").strip() or None
if payload.ipv6 is not None:
server.ipv6 = (payload.ipv6 or "").strip() or None
try:
server.save()
except IntegrityError:
raise HttpError(422, {"detail": "Unique constraint violated."})
return { return {
"id": server.id, "id": server.id,
"display_name": server.display_name, "display_name": server.display_name,
@@ -169,17 +114,6 @@ def build_router() -> Router:
"initial": server.initial, "initial": server.initial,
} }
@router.delete("/{server_id}", response={204: None})
def delete_server(request: HttpRequest, server_id: int):
"""Delete a server by id (admin only)."""
_require_admin(request)
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Not Found")
server.delete()
return 204, None
return router return router

View File

@@ -2,6 +2,8 @@ from typing import Literal, TypedDict
from ninja import Router from ninja import Router
from apps.core.rbac import require_authenticated
class HealthResponse(TypedDict): class HealthResponse(TypedDict):
status: Literal["ok"] status: Literal["ok"]
@@ -11,8 +13,16 @@ def build_router() -> Router:
router = Router() router = Router()
@router.get("/health", response=HealthResponse) @router.get("/health", response=HealthResponse)
def health() -> HealthResponse: def health(request) -> HealthResponse:
"""Health check endpoint for service monitoring.""" """Return application liveness for internal monitoring.
Auth: required (session or JWT). This is intentionally protected to avoid
exposing internal status to unauthenticated callers.
Behavior: returns a static {"status": "ok"} if the app stack is reachable.
Rationale: used by uptime checks and deployments to confirm the API
process is running and can authenticate requests.
"""
require_authenticated(request)
return {"status": "ok"} return {"status": "ok"}
return router return router

View File

@@ -10,6 +10,7 @@ from ninja import Query, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import Field from pydantic import Field
from apps.core.rbac import require_perms
from apps.servers.models import Server from apps.servers.models import Server
from apps.telemetry.models import TelemetryEvent from apps.telemetry.models import TelemetryEvent
@@ -51,14 +52,6 @@ class TelemetryQuery(Schema):
success: Optional[bool] = None success: Optional[bool] = None
def _require_admin(request: HttpRequest) -> None:
user = request.user
if not getattr(user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
if not (user.is_staff or user.is_superuser):
raise HttpError(403, "Forbidden")
def _event_to_out(event: TelemetryEvent) -> TelemetryOut: def _event_to_out(event: TelemetryEvent) -> TelemetryOut:
return TelemetryOut( return TelemetryOut(
id=event.id, id=event.id,
@@ -78,8 +71,14 @@ def build_router() -> Router:
@router.get("/", response=List[TelemetryOut]) @router.get("/", response=List[TelemetryOut])
def list_events(request: HttpRequest, filters: TelemetryQuery = Query(...)): def list_events(request: HttpRequest, filters: TelemetryQuery = Query(...)):
"""List telemetry events with filters (admin only).""" """List telemetry events emitted by the platform and agents.
_require_admin(request)
Auth: required.
Permissions: requires `telemetry.view_telemetryevent`.
Filters: event_type, server_id, user_id, success.
Rationale: supports operational dashboards and audit-style timelines.
"""
require_perms(request, "telemetry.view_telemetryevent")
qs = TelemetryEvent.objects.order_by("-created_at") qs = TelemetryEvent.objects.order_by("-created_at")
if filters.event_type: if filters.event_type:
qs = qs.filter(event_type=filters.event_type) qs = qs.filter(event_type=filters.event_type)
@@ -94,8 +93,15 @@ def build_router() -> Router:
@router.post("/", response=TelemetryOut) @router.post("/", response=TelemetryOut)
def create_event(request: HttpRequest, payload: TelemetryCreateIn): def create_event(request: HttpRequest, payload: TelemetryCreateIn):
"""Create a telemetry event entry (admin only).""" """Create a telemetry event entry.
_require_admin(request)
Auth: required.
Permissions: requires `telemetry.add_telemetryevent`.
Behavior: validates server/user references and normalizes source.
Rationale: used by internal automation; if external clients are not
expected to emit telemetry, this endpoint can be restricted further.
"""
require_perms(request, "telemetry.add_telemetryevent")
server = None server = None
if payload.server_id: if payload.server_id:
try: try:
@@ -122,8 +128,13 @@ def build_router() -> Router:
@router.get("/summary", response=TelemetrySummaryOut) @router.get("/summary", response=TelemetrySummaryOut)
def summary(request: HttpRequest): def summary(request: HttpRequest):
"""Return a high-level telemetry summary (admin only).""" """Return a high-level success/failure summary.
_require_admin(request)
Auth: required.
Permissions: requires `telemetry.view_telemetryevent`.
Rationale: feeds dashboard widgets without pulling full event lists.
"""
require_perms(request, "telemetry.view_telemetryevent")
totals = TelemetryEvent.objects.aggregate( totals = TelemetryEvent.objects.aggregate(
total=Count("id"), total=Count("id"),
success=Count("id", filter=models.Q(success=True)), success=Count("id", filter=models.Q(success=True)),

View File

@@ -9,11 +9,13 @@ from ninja import Query, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import EmailStr, Field from pydantic import EmailStr, Field
from apps.core.rbac import ROLE_USER, get_user_role, require_perms, set_user_role
class UserCreateIn(Schema): class UserCreateIn(Schema):
email: EmailStr email: EmailStr
password: str = Field(min_length=8) password: str = Field(min_length=8)
role: Literal["admin", "user"] role: Literal["administrator", "operator", "auditor", "user", "admin"]
class UserListOut(Schema): class UserListOut(Schema):
@@ -33,7 +35,7 @@ class UserDetailOut(Schema):
class UserUpdateIn(Schema): class UserUpdateIn(Schema):
email: EmailStr | None = None email: EmailStr | None = None
password: str | None = Field(default=None, min_length=8) password: str | None = Field(default=None, min_length=8)
role: Literal["admin", "user"] | None = None role: Literal["administrator", "operator", "auditor", "user", "admin"] | None = None
is_active: bool | None = None is_active: bool | None = None
@@ -42,25 +44,8 @@ class UsersQuery(Schema):
offset: int = Field(default=0, ge=0) offset: int = Field(default=0, ge=0)
def _require_admin(request: HttpRequest) -> None:
user = request.user
if not getattr(user, "is_authenticated", False):
raise HttpError(403, "Forbidden")
if not (user.is_staff or user.is_superuser):
raise HttpError(403, "Forbidden")
def _role_from_user(user) -> str: def _role_from_user(user) -> str:
return "admin" if (user.is_staff or user.is_superuser) else "user" return get_user_role(user) or ROLE_USER
def _apply_role(user, role: str) -> None:
if role == "admin":
user.is_staff = True
user.is_superuser = True
else:
user.is_staff = False
user.is_superuser = False
def build_router() -> Router: def build_router() -> Router:
@@ -68,19 +53,31 @@ def build_router() -> Router:
@router.post("/", response=UserDetailOut) @router.post("/", response=UserDetailOut)
def create_user(request: HttpRequest, payload: UserCreateIn): def create_user(request: HttpRequest, payload: UserCreateIn):
"""Create a user with role and password (admin only).""" """Create a platform user and assign a Keywarden role.
_require_admin(request)
Auth: required.
Permissions: requires `auth.add_user` (admin/operator).
Behavior: uses email as username, hashes the password, and assigns a
role which maps to Keywarden group permissions.
Rationale: enables automation and external admin workflows; mirrors
the admin UI user creation flow.
"""
require_perms(request, "auth.add_user")
User = get_user_model() User = get_user_model()
email = payload.email.strip().lower() email = payload.email.strip().lower()
if User.objects.filter(email__iexact=email).exists(): if User.objects.filter(email__iexact=email).exists():
raise HttpError(422, {"email": ["Email already exists."]}) raise HttpError(422, {"email": ["Email already exists."]})
user = User(username=email, email=email, is_active=True) user = User(username=email, email=email, is_active=True)
_apply_role(user, payload.role)
user.set_password(payload.password) user.set_password(payload.password)
try: try:
user.save() user.save()
except IntegrityError: except IntegrityError:
raise HttpError(422, {"email": ["Email already exists."]}) raise HttpError(422, {"email": ["Email already exists."]})
try:
set_user_role(user, payload.role)
except ValueError:
raise HttpError(422, {"role": ["Invalid role."]})
user.save()
return { return {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
@@ -90,8 +87,14 @@ def build_router() -> Router:
@router.get("/", response=List[UserListOut]) @router.get("/", response=List[UserListOut])
def list_users(request: HttpRequest, pagination: UsersQuery = Query(...)): def list_users(request: HttpRequest, pagination: UsersQuery = Query(...)):
"""List users with pagination (admin only).""" """List users for administrative visibility and access management.
_require_admin(request)
Auth: required.
Permissions: requires `auth.view_user`.
Pagination: limit + offset.
Rationale: used by admin UI and automation to audit user access.
"""
require_perms(request, "auth.view_user")
User = get_user_model() User = get_user_model()
qs = User.objects.order_by("id")[pagination.offset : pagination.offset + pagination.limit] qs = User.objects.order_by("id")[pagination.offset : pagination.offset + pagination.limit]
return [ return [
@@ -106,8 +109,13 @@ def build_router() -> Router:
@router.get("/{user_id}", response=UserDetailOut) @router.get("/{user_id}", response=UserDetailOut)
def get_user(request: HttpRequest, user_id: int): def get_user(request: HttpRequest, user_id: int):
"""Get user details by id (admin only).""" """Fetch a single user record for inspection.
_require_admin(request)
Auth: required.
Permissions: requires `auth.view_user`.
Rationale: used by admin detail views and automation scripts.
"""
require_perms(request, "auth.view_user")
User = get_user_model() User = get_user_model()
try: try:
user = User.objects.get(id=user_id) user = User.objects.get(id=user_id)
@@ -122,8 +130,14 @@ def build_router() -> Router:
@router.patch("/{user_id}", response=UserDetailOut) @router.patch("/{user_id}", response=UserDetailOut)
def update_user(request: HttpRequest, user_id: int, payload: UserUpdateIn): def update_user(request: HttpRequest, user_id: int, payload: UserUpdateIn):
"""Update user fields such as role, email, or status (admin only).""" """Update user identity, role, password, or activation state.
_require_admin(request)
Auth: required.
Permissions: requires `auth.change_user` (admin).
Side effects: role changes update Keywarden role/group mappings.
Rationale: required for role delegation and account lifecycle control.
"""
require_perms(request, "auth.change_user")
if payload.email is None and payload.password is None and payload.role is None and payload.is_active is None: if payload.email is None and payload.password is None and payload.role is None and payload.is_active is None:
raise HttpError(422, {"detail": "No fields provided."}) raise HttpError(422, {"detail": "No fields provided."})
User = get_user_model() User = get_user_model()
@@ -140,7 +154,10 @@ def build_router() -> Router:
if payload.password is not None: if payload.password is not None:
user.set_password(payload.password) user.set_password(payload.password)
if payload.role is not None: if payload.role is not None:
_apply_role(user, payload.role) try:
set_user_role(user, payload.role)
except ValueError:
raise HttpError(422, {"role": ["Invalid role."]})
if payload.is_active is not None: if payload.is_active is not None:
user.is_active = payload.is_active user.is_active = payload.is_active
user.save() user.save()
@@ -151,18 +168,6 @@ def build_router() -> Router:
"is_active": user.is_active, "is_active": user.is_active,
} }
@router.delete("/{user_id}", response={204: None})
def delete_user(request: HttpRequest, user_id: int):
"""Delete a user by id (admin only)."""
_require_admin(request)
User = get_user_model()
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
raise HttpError(404, "Not Found")
user.delete()
return 204, None
return router return router

View File

@@ -0,0 +1,19 @@
import os
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.core.asgi import get_asgi_application
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "keywarden.settings.dev")
django_app = get_asgi_application()
from .routing import websocket_urlpatterns # noqa: E402
application = ProtocolTypeRouter(
{
"http": django_app,
"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns)),
}
)

9
app/keywarden/celery.py Normal file
View File

@@ -0,0 +1,9 @@
import os
from celery import Celery
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "keywarden.settings.dev")
app = Celery("keywarden")
app.config_from_object("django.conf:settings", namespace="CELERY")
app.autodiscover_tasks()

7
app/keywarden/routing.py Normal file
View File

@@ -0,0 +1,7 @@
from django.urls import re_path
from apps.servers.consumers import ShellConsumer
websocket_urlpatterns = [
re_path(r"^ws/servers/(?P<server_id>\d+)/shell/$", ShellConsumer.as_asgi()),
]

Some files were not shown because too many files have changed in this diff Show More