diff --git a/agent/cmd/keywarden-agent/main.go b/agent/cmd/keywarden-agent/main.go index 5eb6621..447e714 100644 --- a/agent/cmd/keywarden-agent/main.go +++ b/agent/cmd/keywarden-agent/main.go @@ -82,7 +82,7 @@ func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) log.Printf("host update error: %v", err) } } - if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { + if err := apiClient.SyncAccounts(ctx, cfg); err != nil { log.Printf("sync accounts error: %v", err) } if err := shipLogs(ctx, apiClient, cfg); err != nil { diff --git a/agent/internal/accounts/sync.go b/agent/internal/accounts/sync.go new file mode 100644 index 0000000..bb3c0f3 --- /dev/null +++ b/agent/internal/accounts/sync.go @@ -0,0 +1,323 @@ +package accounts + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "keywarden/agent/internal/config" +) + +const ( + stateFileName = "accounts.json" + maxUsernameLen = 32 + passwdFilePath = "/etc/passwd" + sshDirName = ".ssh" + authKeysName = "authorized_keys" +) + +type AccessUser struct { + UserID int + Username string + Email string + Keys []string +} + +type ReportAccount struct { + UserID int `json:"user_id"` + SystemUser string `json:"system_username"` + Present bool `json:"present"` +} + +type Result struct { + Applied int + Revoked int + Accounts []ReportAccount +} + +type managedAccount struct { + UserID int `json:"user_id"` + SystemUser string `json:"system_username"` +} + +type state struct { + Users map[string]managedAccount `json:"users"` +} + +type passwdEntry struct { + UID int + GID int + Home string +} + +func Sync(policy config.AccountPolicy, stateDir string, users []AccessUser) (Result, error) { + result := Result{} + statePath := filepath.Join(stateDir, stateFileName) + current, err := loadState(statePath) + if err != nil { + return result, err + } + + desired := make(map[int]managedAccount, len(users)) + userIndex := make(map[int]AccessUser, len(users)) + for _, user := range users { + systemUser := renderUsername(policy.UsernameTemplate, user.Username, user.UserID) + desired[user.UserID] = managedAccount{UserID: user.UserID, SystemUser: systemUser} + userIndex[user.UserID] = user + } + + var syncErr error + for _, account := range current.Users { + if _, ok := desired[account.UserID]; ok { + continue + } + if err := revokeUser(account.SystemUser, policy); err != nil && syncErr == nil { + syncErr = err + } + result.Revoked++ + } + + for userID, account := range desired { + accessUser := userIndex[userID] + present, err := ensureAccount(account.SystemUser, policy, accessUser.Keys) + if err != nil && syncErr == nil { + syncErr = err + } + if present { + result.Applied++ + } + result.Accounts = append(result.Accounts, ReportAccount{ + UserID: userID, + SystemUser: account.SystemUser, + Present: present, + }) + } + + if err := saveState(statePath, desired); err != nil && syncErr == nil { + syncErr = err + } + return result, syncErr +} + +func loadState(path string) (state, error) { + st := state{Users: map[string]managedAccount{}} + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return st, nil + } + return st, fmt.Errorf("read state: %w", err) + } + if err := json.Unmarshal(data, &st); err != nil { + return st, fmt.Errorf("parse state: %w", err) + } + if st.Users == nil { + st.Users = map[string]managedAccount{} + } + return st, nil +} + +func saveState(path string, desired map[int]managedAccount) error { + st := state{Users: map[string]managedAccount{}} + for id, account := range desired { + st.Users[strconv.Itoa(id)] = account + } + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return fmt.Errorf("encode state: %w", err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("create state dir: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write state: %w", err) + } + return nil +} + +func renderUsername(template string, username string, userID int) string { + raw := strings.ReplaceAll(template, "{{username}}", username) + raw = strings.ReplaceAll(raw, "{{user_id}}", strconv.Itoa(userID)) + clean := sanitizeUsername(raw) + if len(clean) > maxUsernameLen { + clean = clean[:maxUsernameLen] + } + if clean == "" { + clean = fmt.Sprintf("kw_%d", userID) + } + return clean +} + +func sanitizeUsername(raw string) string { + raw = strings.ToLower(raw) + var b strings.Builder + b.Grow(len(raw)) + for _, r := range raw { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + b.WriteRune(r) + continue + } + b.WriteByte('_') + } + out := strings.Trim(b.String(), "-_") + if out == "" { + return "" + } + if strings.HasPrefix(out, "-") { + return "kw" + out + } + return out +} + +func userExists(username string) (bool, error) { + cmd := exec.Command("id", "-u", username) + if err := cmd.Run(); err != nil { + if _, ok := err.(*exec.ExitError); ok { + return false, nil + } + return false, err + } + return true, nil +} + +func ensureAccount(username string, policy config.AccountPolicy, keys []string) (bool, error) { + exists, err := userExists(username) + if err != nil { + return false, err + } + if !exists { + if err := createUser(username, policy); err != nil { + return false, err + } + } + if err := lockPassword(username); err != nil { + return true, err + } + if err := writeAuthorizedKeys(username, keys); err != nil { + return true, err + } + return true, nil +} + +func createUser(username string, policy config.AccountPolicy) error { + args := []string{"-U"} + if policy.CreateHome { + args = append(args, "-m") + } else { + args = append(args, "-M") + } + if policy.DefaultShell != "" { + args = append(args, "-s", policy.DefaultShell) + } + args = append(args, username) + cmd := exec.Command("useradd", args...) + if err := cmd.Run(); err != nil { + return fmt.Errorf("useradd %s: %w", username, err) + } + return nil +} + +func lockPassword(username string) error { + cmd := exec.Command("usermod", "-L", username) + if err := cmd.Run(); err != nil { + return fmt.Errorf("lock password %s: %w", username, err) + } + return nil +} + +func revokeUser(username string, policy config.AccountPolicy) error { + exists, err := userExists(username) + if err != nil { + return err + } + if !exists { + return nil + } + var revokeErr error + if policy.LockOnRevoke { + if err := lockPassword(username); err != nil { + revokeErr = err + } + } + if err := writeAuthorizedKeys(username, nil); err != nil && revokeErr == nil { + revokeErr = err + } + return revokeErr +} + +func writeAuthorizedKeys(username string, keys []string) error { + entry, err := lookupUser(username) + if err != nil { + return err + } + if entry.Home == "" { + return fmt.Errorf("missing home dir for %s", username) + } + sshDir := filepath.Join(entry.Home, sshDirName) + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("mkdir %s: %w", sshDir, err) + } + if err := os.Chmod(sshDir, 0o700); err != nil { + return fmt.Errorf("chmod %s: %w", sshDir, err) + } + if err := os.Chown(sshDir, entry.UID, entry.GID); err != nil { + return fmt.Errorf("chown %s: %w", sshDir, err) + } + authKeysPath := filepath.Join(sshDir, authKeysName) + payload := strings.TrimSpace(strings.Join(keys, "\n")) + if payload != "" { + payload += "\n" + } + if err := os.WriteFile(authKeysPath, []byte(payload), 0o600); err != nil { + return fmt.Errorf("write %s: %w", authKeysPath, err) + } + if err := os.Chown(authKeysPath, entry.UID, entry.GID); err != nil { + return fmt.Errorf("chown %s: %w", authKeysPath, err) + } + return nil +} + +func lookupUser(username string) (passwdEntry, error) { + file, err := os.Open(passwdFilePath) + if err != nil { + return passwdEntry{}, fmt.Errorf("open passwd: %w", err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.SplitN(line, ":", 7) + if len(fields) < 7 { + continue + } + if fields[0] != username { + continue + } + uid, err := strconv.Atoi(fields[2]) + if err != nil { + return passwdEntry{}, fmt.Errorf("parse uid for %s: %w", username, err) + } + gid, err := strconv.Atoi(fields[3]) + if err != nil { + return passwdEntry{}, fmt.Errorf("parse gid for %s: %w", username, err) + } + return passwdEntry{ + UID: uid, + GID: gid, + Home: fields[5], + }, nil + } + if err := scanner.Err(); err != nil { + return passwdEntry{}, fmt.Errorf("scan passwd: %w", err) + } + return passwdEntry{}, fmt.Errorf("user %s not found", username) +} diff --git a/agent/internal/client/client.go b/agent/internal/client/client.go index 69b5950..f71c34b 100644 --- a/agent/internal/client/client.go +++ b/agent/internal/client/client.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "keywarden/agent/internal/accounts" "keywarden/agent/internal/config" ) @@ -81,6 +82,32 @@ type EnrollResponse struct { DisplayName string `json:"display_name,omitempty"` } +type AccountKey struct { + PublicKey string `json:"public_key"` + Fingerprint string `json:"fingerprint"` +} + +type AccountAccess struct { + UserID int `json:"user_id"` + Username string `json:"username"` + Email string `json:"email"` + Keys []AccountKey `json:"keys"` +} + +type AccountSyncEntry struct { + UserID int `json:"user_id"` + SystemUsername string `json:"system_username"` + Present bool `json:"present"` +} + +type SyncReportRequest struct { + AppliedCount int `json:"applied_count"` + RevokedCount int `json:"revoked_count"` + Message string `json:"message,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Accounts []AccountSyncEntry `json:"accounts,omitempty"` +} + func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) { baseURL := strings.TrimRight(serverURL, "/") if baseURL == "" { @@ -114,10 +141,103 @@ func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollRe return &out, nil } -func (c *Client) SyncAccounts(ctx context.Context, serverID string) error { - _ = ctx - _ = serverID - // TODO: call API to fetch account policy + approved access list. +func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error { + if cfg == nil { + return errors.New("config required for account sync") + } + users, err := c.FetchAccountAccess(ctx, cfg.ServerID) + if err != nil { + return err + } + accessUsers := make([]accounts.AccessUser, 0, len(users)) + for _, user := range users { + keys := make([]string, 0, len(user.Keys)) + for _, key := range user.Keys { + if strings.TrimSpace(key.PublicKey) == "" { + continue + } + keys = append(keys, strings.TrimSpace(key.PublicKey)) + } + accessUsers = append(accessUsers, accounts.AccessUser{ + UserID: user.UserID, + Username: user.Username, + Email: user.Email, + Keys: keys, + }) + } + result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers) + report := SyncReportRequest{ + AppliedCount: result.Applied, + RevokedCount: result.Revoked, + Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)), + } + for _, account := range result.Accounts { + report.Accounts = append(report.Accounts, AccountSyncEntry{ + UserID: account.UserID, + SystemUsername: account.SystemUser, + Present: account.Present, + }) + } + if syncErr != nil { + report.Message = syncErr.Error() + } + if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil { + if syncErr != nil { + return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr) + } + return err + } + return syncErr +} + +func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + c.baseURL+"/agent/servers/"+serverID+"/accounts", + nil, + ) + if err != nil { + return nil, fmt.Errorf("build account access request: %w", err) + } + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch account access: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } + var out []AccountAccess + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("decode account access: %w", err) + } + return out, nil +} + +func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error { + body, err := json.Marshal(report) + if err != nil { + return fmt.Errorf("encode sync report: %w", err) + } + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + c.baseURL+"/agent/servers/"+serverID+"/sync-report", + bytes.NewReader(body), + ) + if err != nil { + return fmt.Errorf("build sync report: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("send sync report: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } return nil } diff --git a/agent/keywarden-agent b/agent/keywarden-agent index c724980..3072405 100755 Binary files a/agent/keywarden-agent and b/agent/keywarden-agent differ diff --git a/app/apps/servers/migrations/0004_server_account.py b/app/apps/servers/migrations/0004_server_account.py new file mode 100644 index 0000000..447f786 --- /dev/null +++ b/app/apps/servers/migrations/0004_server_account.py @@ -0,0 +1,59 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + + dependencies = [ + ("servers", "0003_agent_ca"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ServerAccount", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("system_username", models.CharField(max_length=128)), + ("is_present", models.BooleanField(db_index=True, default=False)), + ("last_synced_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "server", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="accounts", + to="servers.server", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="server_accounts", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Server account", + "verbose_name_plural": "Server accounts", + "ordering": ["server_id", "user_id"], + }, + ), + migrations.AddConstraint( + model_name="serveraccount", + constraint=models.UniqueConstraint(fields=("server", "user"), name="unique_server_account"), + ), + migrations.AddIndex( + model_name="serveraccount", + index=models.Index(fields=["server", "user"], name="servers_account_user_idx"), + ), + migrations.AddIndex( + model_name="serveraccount", + index=models.Index(fields=["server", "is_present"], name="servers_account_present_idx"), + ), + ] diff --git a/app/apps/servers/models.py b/app/apps/servers/models.py index 1efe9ed..37215aa 100644 --- a/app/apps/servers/models.py +++ b/app/apps/servers/models.py @@ -157,3 +157,30 @@ class AgentCertificateAuthority(models.Model): self.key_pem = key_pem self.fingerprint = cert.fingerprint(hashes.SHA256()).hex() self.serial = format(cert.serial_number, "x") + + +class ServerAccount(models.Model): + server = models.ForeignKey(Server, on_delete=models.CASCADE, related_name="accounts") + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="server_accounts" + ) + system_username = models.CharField(max_length=128) + is_present = models.BooleanField(default=False, db_index=True) + last_synced_at = models.DateTimeField(default=timezone.now, editable=False) + created_at = models.DateTimeField(default=timezone.now, editable=False) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + verbose_name = "Server account" + verbose_name_plural = "Server accounts" + constraints = [ + models.UniqueConstraint(fields=["server", "user"], name="unique_server_account") + ] + indexes = [ + models.Index(fields=["server", "user"], name="servers_account_user_idx"), + models.Index(fields=["server", "is_present"], name="servers_account_present_idx"), + ] + ordering = ["server_id", "user_id"] + + def __str__(self) -> str: + return f"{self.system_username} ({self.server_id})" diff --git a/app/apps/servers/templates/servers/detail.html b/app/apps/servers/templates/servers/detail.html index efa2184..29023da 100644 --- a/app/apps/servers/templates/servers/detail.html +++ b/app/apps/servers/templates/servers/detail.html @@ -33,6 +33,18 @@ {% endif %} +
+
Account on server
+
+ {% if account_present is None %} + Unknown + {% elif account_present %} + Present + {% else %} + Missing + {% endif %} +
+
Last accessed
diff --git a/app/apps/servers/views.py b/app/apps/servers/views.py index cc1eed4..7d9291a 100644 --- a/app/apps/servers/views.py +++ b/app/apps/servers/views.py @@ -5,24 +5,21 @@ from django.db.models import Q from django.http import Http404 from django.shortcuts import render from django.utils import timezone -from guardian.shortcuts import get_objects_for_user +from guardian.shortcuts import get_objects_for_user, get_perms from apps.access.models import AccessRequest -from apps.servers.models import Server +from apps.servers.models import Server, ServerAccount @login_required(login_url="/accounts/login/") def dashboard(request): now = timezone.now() - if request.user.has_perm("servers.view_server"): - server_qs = Server.objects.all() - else: - server_qs = get_objects_for_user( - request.user, - "servers.view_server", - klass=Server, - accept_global_perms=False, - ) + server_qs = get_objects_for_user( + request.user, + "servers.view_server", + klass=Server, + accept_global_perms=False, + ) access_qs = ( AccessRequest.objects.select_related("server") @@ -66,9 +63,7 @@ def detail(request, server_id: int): server = Server.objects.get(id=server_id) except Server.DoesNotExist: raise Http404("Server not found") - if not request.user.has_perm("servers.view_server", server) and not request.user.has_perm( - "servers.view_server" - ): + if "view_server" not in get_perms(request.user, server): raise Http404("Server not found") access = ( @@ -82,9 +77,13 @@ def detail(request, server_id: int): .first() ) + account = ServerAccount.objects.filter(server=server, user=request.user).first() context = { "server": server, "expires_at": access.expires_at if access else None, "last_accessed": None, + "account_present": account.is_present if account else None, + "account_synced_at": account.last_synced_at if account else None, + "system_username": account.system_username if account else None, } return render(request, "servers/detail.html", context) diff --git a/app/keywarden/api/routers/agent.py b/app/keywarden/api/routers/agent.py index 3215b01..b6f38ee 100644 --- a/app/keywarden/api/routers/agent.py +++ b/app/keywarden/api/routers/agent.py @@ -5,20 +5,27 @@ from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID from django.conf import settings +from django.contrib.auth import get_user_model from django.core.exceptions import ValidationError from django.core.validators import validate_ipv4_address, validate_ipv6_address -from django.db import IntegrityError, models, transaction +from django.db import IntegrityError, transaction from django.http import HttpRequest from django.utils import timezone from django.views.decorators.csrf import csrf_exempt from ninja import Body, Router, Schema from ninja.errors import HttpError from pydantic import Field +from guardian.shortcuts import get_users_with_perms from apps.core.rbac import require_perms -from apps.access.models import AccessRequest from apps.keys.models import SSHKey -from apps.servers.models import AgentCertificateAuthority, EnrollmentToken, Server, hostname_validator +from apps.servers.models import ( + AgentCertificateAuthority, + EnrollmentToken, + Server, + ServerAccount, + hostname_validator, +) from apps.telemetry.models import TelemetryEvent @@ -30,11 +37,30 @@ class AuthorizedKeyOut(Schema): fingerprint: str +class AccountKeyOut(Schema): + public_key: str + fingerprint: str + + +class AccountAccessOut(Schema): + user_id: int + username: str + email: str + keys: List[AccountKeyOut] + + +class AccountSyncIn(Schema): + user_id: int + system_username: str + present: bool + + class SyncReportIn(Schema): applied_count: int = Field(default=0, ge=0) revoked_count: int = Field(default=0, ge=0) message: Optional[str] = None metadata: dict = Field(default_factory=dict) + accounts: List[AccountSyncIn] = Field(default_factory=list) class SyncReportOut(Schema): @@ -152,42 +178,55 @@ def build_router() -> Router: """Resolve the effective authorized_keys list for a server. Auth: required (admin/operator via API). - Permissions: requires view access to servers, keys, and access requests. - Behavior: combines approved access requests with active SSH keys to - produce the exact key list the agent should deploy to the server. + Permissions: requires view access to servers and keys. + Behavior: uses server object permissions + active SSH keys to produce + the exact key list the agent should deploy to the server. Rationale: this is the policy enforcement point for per-user access. """ require_perms( request, "servers.view_server", "keys.view_sshkey", - "access.view_accessrequest", - ) - try: - server = Server.objects.get(id=server_id) - except Server.DoesNotExist: - raise HttpError(404, "Server not found") - now = timezone.now() - access_qs = AccessRequest.objects.select_related("requester").filter( - server=server, - status=AccessRequest.Status.APPROVED, - ) - access_qs = access_qs.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=now)) - users = [req.requester for req in access_qs if req.requester and req.requester.is_active] - keys = SSHKey.objects.select_related("user").filter( - user__in=users, - is_active=True, - revoked_at__isnull=True, ) + server = _get_server_or_404(server_id) + users = _resolve_access_users(server) + key_map = _key_map_for_users(users) + output: list[AuthorizedKeyOut] = [] + for user in users: + for key in key_map.get(user.id, []): + output.append( + AuthorizedKeyOut( + user_id=user.id, + username=user.username, + email=user.email or "", + public_key=key.public_key, + fingerprint=key.fingerprint, + ) + ) + return output + + @router.get("/servers/{server_id}/accounts", response=List[AccountAccessOut], auth=None) + def account_access(request: HttpRequest, server_id: int): + """List accounts that should exist on a server. + + Auth: mTLS expected at the edge (no session/JWT). + Behavior: resolves active users with server object perms and their keys. + Rationale: drives agent-side account provisioning. + """ + server = _get_server_or_404(server_id) + users = _resolve_access_users(server) + key_map = _key_map_for_users(users) return [ - AuthorizedKeyOut( - user_id=key.user_id, - username=key.user.username, - email=key.user.email or "", - public_key=key.public_key, - fingerprint=key.fingerprint, + AccountAccessOut( + user_id=user.id, + username=user.username, + email=user.email or "", + keys=[ + AccountKeyOut(public_key=key.public_key, fingerprint=key.fingerprint) + for key in key_map.get(user.id, []) + ], ) - for key in keys + for user in users ] @router.post("/servers/{server_id}/sync-report", response=SyncReportOut, auth=None) @@ -216,6 +255,8 @@ def build_router() -> Router: **(payload.metadata or {}), }, ) + if payload.accounts: + _update_server_accounts(server, payload.accounts) return SyncReportOut(status="ok") @router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None) @@ -277,6 +318,62 @@ def build_router() -> Router: return router +def _get_server_or_404(server_id: int) -> Server: + try: + return Server.objects.get(id=server_id) + except Server.DoesNotExist: + raise HttpError(404, "Server not found") + + +def _resolve_access_users(server: Server) -> list: + users = list( + get_users_with_perms( + server, + only_with_perms_in=["view_server"], + with_group_users=True, + with_superusers=False, + ) + ) + active = [user for user in users if getattr(user, "is_active", False)] + return sorted(active, key=lambda user: (user.username or "", user.id)) + + +def _key_map_for_users(users: list) -> dict[int, list[SSHKey]]: + if not users: + return {} + keys = SSHKey.objects.select_related("user").filter( + user__in=users, + is_active=True, + revoked_at__isnull=True, + ) + key_map: dict[int, list[SSHKey]] = {} + for key in keys: + key_map.setdefault(key.user_id, []).append(key) + return key_map + + +def _update_server_accounts(server: Server, accounts: list[AccountSyncIn]) -> None: + user_ids = {account.user_id for account in accounts} + if not user_ids: + return + User = get_user_model() + users = {user.id: user for user in User.objects.filter(id__in=user_ids)} + now = timezone.now() + for account in accounts: + user = users.get(account.user_id) + if not user: + continue + ServerAccount.objects.update_or_create( + server=server, + user=user, + defaults={ + "system_username": account.system_username, + "is_present": account.present, + "last_synced_at": now, + }, + ) + + def _load_agent_ca() -> tuple[x509.Certificate, object, str]: ca = ( AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True) diff --git a/app/keywarden/api/routers/servers.py b/app/keywarden/api/routers/servers.py index 81f1845..9463aff 100644 --- a/app/keywarden/api/routers/servers.py +++ b/app/keywarden/api/routers/servers.py @@ -5,8 +5,8 @@ from typing import List, Optional from django.http import HttpRequest from ninja import Router, Schema from ninja.errors import HttpError -from guardian.shortcuts import get_objects_for_user -from apps.core.rbac import require_perms +from guardian.shortcuts import get_objects_for_user, get_perms +from apps.core.rbac import require_authenticated, require_perms from apps.servers.models import Server @@ -32,20 +32,17 @@ def build_router() -> Router: """List servers the caller can view. Auth: required. - Permissions: requires `servers.view_server` globally or per-object. + Permissions: requires `servers.view_server` via object permissions. Behavior: returns only servers the user can see via object perms. Rationale: drives the server dashboard and access-aware navigation. """ - require_perms(request, "servers.view_server") - if request.user.has_perm("servers.view_server"): - servers = Server.objects.all() - else: - servers = get_objects_for_user( - request.user, - "servers.view_server", - klass=Server, - accept_global_perms=False, - ) + require_authenticated(request) + servers = get_objects_for_user( + request.user, + "servers.view_server", + klass=Server, + accept_global_perms=False, + ) return [ { "id": s.id, @@ -64,18 +61,16 @@ def build_router() -> Router: """Get a server record by id. Auth: required. - Permissions: requires `servers.view_server` globally or per-object. + Permissions: requires `servers.view_server` via object permissions. Rationale: used by server detail views and API clients inspecting server metadata (hostname/IPs populated by the agent). """ - require_perms(request, "servers.view_server") + require_authenticated(request) try: server = Server.objects.get(id=server_id) except Server.DoesNotExist: raise HttpError(404, "Not Found") - if not request.user.has_perm("servers.view_server", server) and not request.user.has_perm( - "servers.view_server" - ): + if "view_server" not in get_perms(request.user, server): raise HttpError(403, "Forbidden") return { "id": server.id,