Compare commits

...

2 Commits

10 changed files with 705 additions and 67 deletions

View File

@@ -82,7 +82,7 @@ func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config)
log.Printf("host update error: %v", err) log.Printf("host update error: %v", err)
} }
} }
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { if err := apiClient.SyncAccounts(ctx, cfg); err != nil {
log.Printf("sync accounts error: %v", err) log.Printf("sync accounts error: %v", err)
} }
if err := shipLogs(ctx, apiClient, cfg); err != nil { if err := shipLogs(ctx, apiClient, cfg); err != nil {

View File

@@ -0,0 +1,323 @@
package accounts
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"keywarden/agent/internal/config"
)
const (
stateFileName = "accounts.json"
maxUsernameLen = 32
passwdFilePath = "/etc/passwd"
sshDirName = ".ssh"
authKeysName = "authorized_keys"
)
type AccessUser struct {
UserID int
Username string
Email string
Keys []string
}
type ReportAccount struct {
UserID int `json:"user_id"`
SystemUser string `json:"system_username"`
Present bool `json:"present"`
}
type Result struct {
Applied int
Revoked int
Accounts []ReportAccount
}
type managedAccount struct {
UserID int `json:"user_id"`
SystemUser string `json:"system_username"`
}
type state struct {
Users map[string]managedAccount `json:"users"`
}
type passwdEntry struct {
UID int
GID int
Home string
}
func Sync(policy config.AccountPolicy, stateDir string, users []AccessUser) (Result, error) {
result := Result{}
statePath := filepath.Join(stateDir, stateFileName)
current, err := loadState(statePath)
if err != nil {
return result, err
}
desired := make(map[int]managedAccount, len(users))
userIndex := make(map[int]AccessUser, len(users))
for _, user := range users {
systemUser := renderUsername(policy.UsernameTemplate, user.Username, user.UserID)
desired[user.UserID] = managedAccount{UserID: user.UserID, SystemUser: systemUser}
userIndex[user.UserID] = user
}
var syncErr error
for _, account := range current.Users {
if _, ok := desired[account.UserID]; ok {
continue
}
if err := revokeUser(account.SystemUser, policy); err != nil && syncErr == nil {
syncErr = err
}
result.Revoked++
}
for userID, account := range desired {
accessUser := userIndex[userID]
present, err := ensureAccount(account.SystemUser, policy, accessUser.Keys)
if err != nil && syncErr == nil {
syncErr = err
}
if present {
result.Applied++
}
result.Accounts = append(result.Accounts, ReportAccount{
UserID: userID,
SystemUser: account.SystemUser,
Present: present,
})
}
if err := saveState(statePath, desired); err != nil && syncErr == nil {
syncErr = err
}
return result, syncErr
}
func loadState(path string) (state, error) {
st := state{Users: map[string]managedAccount{}}
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return st, nil
}
return st, fmt.Errorf("read state: %w", err)
}
if err := json.Unmarshal(data, &st); err != nil {
return st, fmt.Errorf("parse state: %w", err)
}
if st.Users == nil {
st.Users = map[string]managedAccount{}
}
return st, nil
}
func saveState(path string, desired map[int]managedAccount) error {
st := state{Users: map[string]managedAccount{}}
for id, account := range desired {
st.Users[strconv.Itoa(id)] = account
}
data, err := json.MarshalIndent(st, "", " ")
if err != nil {
return fmt.Errorf("encode state: %w", err)
}
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("create state dir: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write state: %w", err)
}
return nil
}
func renderUsername(template string, username string, userID int) string {
raw := strings.ReplaceAll(template, "{{username}}", username)
raw = strings.ReplaceAll(raw, "{{user_id}}", strconv.Itoa(userID))
clean := sanitizeUsername(raw)
if len(clean) > maxUsernameLen {
clean = clean[:maxUsernameLen]
}
if clean == "" {
clean = fmt.Sprintf("kw_%d", userID)
}
return clean
}
func sanitizeUsername(raw string) string {
raw = strings.ToLower(raw)
var b strings.Builder
b.Grow(len(raw))
for _, r := range raw {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
b.WriteRune(r)
continue
}
b.WriteByte('_')
}
out := strings.Trim(b.String(), "-_")
if out == "" {
return ""
}
if strings.HasPrefix(out, "-") {
return "kw" + out
}
return out
}
func userExists(username string) (bool, error) {
cmd := exec.Command("id", "-u", username)
if err := cmd.Run(); err != nil {
if _, ok := err.(*exec.ExitError); ok {
return false, nil
}
return false, err
}
return true, nil
}
func ensureAccount(username string, policy config.AccountPolicy, keys []string) (bool, error) {
exists, err := userExists(username)
if err != nil {
return false, err
}
if !exists {
if err := createUser(username, policy); err != nil {
return false, err
}
}
if err := lockPassword(username); err != nil {
return true, err
}
if err := writeAuthorizedKeys(username, keys); err != nil {
return true, err
}
return true, nil
}
func createUser(username string, policy config.AccountPolicy) error {
args := []string{"-U"}
if policy.CreateHome {
args = append(args, "-m")
} else {
args = append(args, "-M")
}
if policy.DefaultShell != "" {
args = append(args, "-s", policy.DefaultShell)
}
args = append(args, username)
cmd := exec.Command("useradd", args...)
if err := cmd.Run(); err != nil {
return fmt.Errorf("useradd %s: %w", username, err)
}
return nil
}
func lockPassword(username string) error {
cmd := exec.Command("usermod", "-L", username)
if err := cmd.Run(); err != nil {
return fmt.Errorf("lock password %s: %w", username, err)
}
return nil
}
func revokeUser(username string, policy config.AccountPolicy) error {
exists, err := userExists(username)
if err != nil {
return err
}
if !exists {
return nil
}
var revokeErr error
if policy.LockOnRevoke {
if err := lockPassword(username); err != nil {
revokeErr = err
}
}
if err := writeAuthorizedKeys(username, nil); err != nil && revokeErr == nil {
revokeErr = err
}
return revokeErr
}
func writeAuthorizedKeys(username string, keys []string) error {
entry, err := lookupUser(username)
if err != nil {
return err
}
if entry.Home == "" {
return fmt.Errorf("missing home dir for %s", username)
}
sshDir := filepath.Join(entry.Home, sshDirName)
if err := os.MkdirAll(sshDir, 0o700); err != nil {
return fmt.Errorf("mkdir %s: %w", sshDir, err)
}
if err := os.Chmod(sshDir, 0o700); err != nil {
return fmt.Errorf("chmod %s: %w", sshDir, err)
}
if err := os.Chown(sshDir, entry.UID, entry.GID); err != nil {
return fmt.Errorf("chown %s: %w", sshDir, err)
}
authKeysPath := filepath.Join(sshDir, authKeysName)
payload := strings.TrimSpace(strings.Join(keys, "\n"))
if payload != "" {
payload += "\n"
}
if err := os.WriteFile(authKeysPath, []byte(payload), 0o600); err != nil {
return fmt.Errorf("write %s: %w", authKeysPath, err)
}
if err := os.Chown(authKeysPath, entry.UID, entry.GID); err != nil {
return fmt.Errorf("chown %s: %w", authKeysPath, err)
}
return nil
}
func lookupUser(username string) (passwdEntry, error) {
file, err := os.Open(passwdFilePath)
if err != nil {
return passwdEntry{}, fmt.Errorf("open passwd: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.SplitN(line, ":", 7)
if len(fields) < 7 {
continue
}
if fields[0] != username {
continue
}
uid, err := strconv.Atoi(fields[2])
if err != nil {
return passwdEntry{}, fmt.Errorf("parse uid for %s: %w", username, err)
}
gid, err := strconv.Atoi(fields[3])
if err != nil {
return passwdEntry{}, fmt.Errorf("parse gid for %s: %w", username, err)
}
return passwdEntry{
UID: uid,
GID: gid,
Home: fields[5],
}, nil
}
if err := scanner.Err(); err != nil {
return passwdEntry{}, fmt.Errorf("scan passwd: %w", err)
}
return passwdEntry{}, fmt.Errorf("user %s not found", username)
}

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
"keywarden/agent/internal/accounts"
"keywarden/agent/internal/config" "keywarden/agent/internal/config"
) )
@@ -81,6 +82,32 @@ type EnrollResponse struct {
DisplayName string `json:"display_name,omitempty"` DisplayName string `json:"display_name,omitempty"`
} }
type AccountKey struct {
PublicKey string `json:"public_key"`
Fingerprint string `json:"fingerprint"`
}
type AccountAccess struct {
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
Keys []AccountKey `json:"keys"`
}
type AccountSyncEntry struct {
UserID int `json:"user_id"`
SystemUsername string `json:"system_username"`
Present bool `json:"present"`
}
type SyncReportRequest struct {
AppliedCount int `json:"applied_count"`
RevokedCount int `json:"revoked_count"`
Message string `json:"message,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Accounts []AccountSyncEntry `json:"accounts,omitempty"`
}
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) { func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
baseURL := strings.TrimRight(serverURL, "/") baseURL := strings.TrimRight(serverURL, "/")
if baseURL == "" { if baseURL == "" {
@@ -114,10 +141,103 @@ func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollRe
return &out, nil return &out, nil
} }
func (c *Client) SyncAccounts(ctx context.Context, serverID string) error { func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error {
_ = ctx if cfg == nil {
_ = serverID return errors.New("config required for account sync")
// TODO: call API to fetch account policy + approved access list. }
users, err := c.FetchAccountAccess(ctx, cfg.ServerID)
if err != nil {
return err
}
accessUsers := make([]accounts.AccessUser, 0, len(users))
for _, user := range users {
keys := make([]string, 0, len(user.Keys))
for _, key := range user.Keys {
if strings.TrimSpace(key.PublicKey) == "" {
continue
}
keys = append(keys, strings.TrimSpace(key.PublicKey))
}
accessUsers = append(accessUsers, accounts.AccessUser{
UserID: user.UserID,
Username: user.Username,
Email: user.Email,
Keys: keys,
})
}
result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers)
report := SyncReportRequest{
AppliedCount: result.Applied,
RevokedCount: result.Revoked,
Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)),
}
for _, account := range result.Accounts {
report.Accounts = append(report.Accounts, AccountSyncEntry{
UserID: account.UserID,
SystemUsername: account.SystemUser,
Present: account.Present,
})
}
if syncErr != nil {
report.Message = syncErr.Error()
}
if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil {
if syncErr != nil {
return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr)
}
return err
}
return syncErr
}
func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
c.baseURL+"/agent/servers/"+serverID+"/accounts",
nil,
)
if err != nil {
return nil, fmt.Errorf("build account access request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch account access: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
var out []AccountAccess
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode account access: %w", err)
}
return out, nil
}
func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error {
body, err := json.Marshal(report)
if err != nil {
return fmt.Errorf("encode sync report: %w", err)
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
c.baseURL+"/agent/servers/"+serverID+"/sync-report",
bytes.NewReader(body),
)
if err != nil {
return fmt.Errorf("build sync report: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send sync report: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil return nil
} }

Binary file not shown.

View File

@@ -0,0 +1,59 @@
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
("servers", "0003_agent_ca"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="ServerAccount",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("system_username", models.CharField(max_length=128)),
("is_present", models.BooleanField(db_index=True, default=False)),
("last_synced_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"server",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="accounts",
to="servers.server",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="server_accounts",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"verbose_name": "Server account",
"verbose_name_plural": "Server accounts",
"ordering": ["server_id", "user_id"],
},
),
migrations.AddConstraint(
model_name="serveraccount",
constraint=models.UniqueConstraint(fields=("server", "user"), name="unique_server_account"),
),
migrations.AddIndex(
model_name="serveraccount",
index=models.Index(fields=["server", "user"], name="servers_account_user_idx"),
),
migrations.AddIndex(
model_name="serveraccount",
index=models.Index(fields=["server", "is_present"], name="servers_account_present_idx"),
),
]

View File

@@ -157,3 +157,30 @@ class AgentCertificateAuthority(models.Model):
self.key_pem = key_pem self.key_pem = key_pem
self.fingerprint = cert.fingerprint(hashes.SHA256()).hex() self.fingerprint = cert.fingerprint(hashes.SHA256()).hex()
self.serial = format(cert.serial_number, "x") self.serial = format(cert.serial_number, "x")
class ServerAccount(models.Model):
server = models.ForeignKey(Server, on_delete=models.CASCADE, related_name="accounts")
user = models.ForeignKey(
settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="server_accounts"
)
system_username = models.CharField(max_length=128)
is_present = models.BooleanField(default=False, db_index=True)
last_synced_at = models.DateTimeField(default=timezone.now, editable=False)
created_at = models.DateTimeField(default=timezone.now, editable=False)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
verbose_name = "Server account"
verbose_name_plural = "Server accounts"
constraints = [
models.UniqueConstraint(fields=["server", "user"], name="unique_server_account")
]
indexes = [
models.Index(fields=["server", "user"], name="servers_account_user_idx"),
models.Index(fields=["server", "is_present"], name="servers_account_present_idx"),
]
ordering = ["server_id", "user_id"]
def __str__(self) -> str:
return f"{self.system_username} ({self.server_id})"

View File

@@ -33,6 +33,24 @@
{% endif %} {% endif %}
</dd> </dd>
</div> </div>
<div class="flex items-center justify-between">
<dt>Account name</dt>
<dd class="font-medium text-gray-900">
{% if system_username %}
{{ system_username }}
{% else %}
Unknown
{% endif %}
</dd>
</div>
<div class="flex items-center justify-between">
<dt>Certificate</dt>
<dd class="font-medium text-gray-900">
<span class="inline-flex items-center rounded-full border border-gray-200 bg-gray-50 px-2 py-1 text-xs font-semibold text-gray-500">
Download coming soon
</span>
</dd>
</div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<dt>Last accessed</dt> <dt>Last accessed</dt>
<dd class="font-medium text-gray-900"> <dd class="font-medium text-gray-900">

View File

@@ -5,24 +5,21 @@ from django.db.models import Q
from django.http import Http404 from django.http import Http404
from django.shortcuts import render from django.shortcuts import render
from django.utils import timezone from django.utils import timezone
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user, get_perms
from apps.access.models import AccessRequest from apps.access.models import AccessRequest
from apps.servers.models import Server from apps.servers.models import Server, ServerAccount
@login_required(login_url="/accounts/login/") @login_required(login_url="/accounts/login/")
def dashboard(request): def dashboard(request):
now = timezone.now() now = timezone.now()
if request.user.has_perm("servers.view_server"): server_qs = get_objects_for_user(
server_qs = Server.objects.all() request.user,
else: "servers.view_server",
server_qs = get_objects_for_user( klass=Server,
request.user, accept_global_perms=False,
"servers.view_server", )
klass=Server,
accept_global_perms=False,
)
access_qs = ( access_qs = (
AccessRequest.objects.select_related("server") AccessRequest.objects.select_related("server")
@@ -66,9 +63,7 @@ def detail(request, server_id: int):
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
raise Http404("Server not found") raise Http404("Server not found")
if not request.user.has_perm("servers.view_server", server) and not request.user.has_perm( if "view_server" not in get_perms(request.user, server):
"servers.view_server"
):
raise Http404("Server not found") raise Http404("Server not found")
access = ( access = (
@@ -82,9 +77,13 @@ def detail(request, server_id: int):
.first() .first()
) )
account = ServerAccount.objects.filter(server=server, user=request.user).first()
context = { context = {
"server": server, "server": server,
"expires_at": access.expires_at if access else None, "expires_at": access.expires_at if access else None,
"last_accessed": None, "last_accessed": None,
"account_present": account.is_present if account else None,
"account_synced_at": account.last_synced_at if account else None,
"system_username": account.system_username if account else None,
} }
return render(request, "servers/detail.html", context) return render(request, "servers/detail.html", context)

View File

@@ -5,20 +5,27 @@ from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv4_address, validate_ipv6_address from django.core.validators import validate_ipv4_address, validate_ipv6_address
from django.db import IntegrityError, models, transaction from django.db import IntegrityError, transaction
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import timezone from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from ninja import Body, Router, Schema from ninja import Body, Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from pydantic import Field from pydantic import Field
from guardian.shortcuts import get_users_with_perms
from apps.core.rbac import require_perms from apps.core.rbac import require_perms
from apps.access.models import AccessRequest
from apps.keys.models import SSHKey from apps.keys.models import SSHKey
from apps.servers.models import AgentCertificateAuthority, EnrollmentToken, Server, hostname_validator from apps.servers.models import (
AgentCertificateAuthority,
EnrollmentToken,
Server,
ServerAccount,
hostname_validator,
)
from apps.telemetry.models import TelemetryEvent from apps.telemetry.models import TelemetryEvent
@@ -30,11 +37,30 @@ class AuthorizedKeyOut(Schema):
fingerprint: str fingerprint: str
class AccountKeyOut(Schema):
public_key: str
fingerprint: str
class AccountAccessOut(Schema):
user_id: int
username: str
email: str
keys: List[AccountKeyOut]
class AccountSyncIn(Schema):
user_id: int
system_username: str
present: bool
class SyncReportIn(Schema): class SyncReportIn(Schema):
applied_count: int = Field(default=0, ge=0) applied_count: int = Field(default=0, ge=0)
revoked_count: int = Field(default=0, ge=0) revoked_count: int = Field(default=0, ge=0)
message: Optional[str] = None message: Optional[str] = None
metadata: dict = Field(default_factory=dict) metadata: dict = Field(default_factory=dict)
accounts: List[AccountSyncIn] = Field(default_factory=list)
class SyncReportOut(Schema): class SyncReportOut(Schema):
@@ -152,42 +178,55 @@ def build_router() -> Router:
"""Resolve the effective authorized_keys list for a server. """Resolve the effective authorized_keys list for a server.
Auth: required (admin/operator via API). Auth: required (admin/operator via API).
Permissions: requires view access to servers, keys, and access requests. Permissions: requires view access to servers and keys.
Behavior: combines approved access requests with active SSH keys to Behavior: uses server object permissions + active SSH keys to produce
produce the exact key list the agent should deploy to the server. the exact key list the agent should deploy to the server.
Rationale: this is the policy enforcement point for per-user access. Rationale: this is the policy enforcement point for per-user access.
""" """
require_perms( require_perms(
request, request,
"servers.view_server", "servers.view_server",
"keys.view_sshkey", "keys.view_sshkey",
"access.view_accessrequest",
)
try:
server = Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
now = timezone.now()
access_qs = AccessRequest.objects.select_related("requester").filter(
server=server,
status=AccessRequest.Status.APPROVED,
)
access_qs = access_qs.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=now))
users = [req.requester for req in access_qs if req.requester and req.requester.is_active]
keys = SSHKey.objects.select_related("user").filter(
user__in=users,
is_active=True,
revoked_at__isnull=True,
) )
server = _get_server_or_404(server_id)
users = _resolve_access_users(server)
key_map = _key_map_for_users(users)
output: list[AuthorizedKeyOut] = []
for user in users:
for key in key_map.get(user.id, []):
output.append(
AuthorizedKeyOut(
user_id=user.id,
username=user.username,
email=user.email or "",
public_key=key.public_key,
fingerprint=key.fingerprint,
)
)
return output
@router.get("/servers/{server_id}/accounts", response=List[AccountAccessOut], auth=None)
def account_access(request: HttpRequest, server_id: int):
"""List accounts that should exist on a server.
Auth: mTLS expected at the edge (no session/JWT).
Behavior: resolves active users with server object perms and their keys.
Rationale: drives agent-side account provisioning.
"""
server = _get_server_or_404(server_id)
users = _resolve_access_users(server)
key_map = _key_map_for_users(users)
return [ return [
AuthorizedKeyOut( AccountAccessOut(
user_id=key.user_id, user_id=user.id,
username=key.user.username, username=user.username,
email=key.user.email or "", email=user.email or "",
public_key=key.public_key, keys=[
fingerprint=key.fingerprint, AccountKeyOut(public_key=key.public_key, fingerprint=key.fingerprint)
for key in key_map.get(user.id, [])
],
) )
for key in keys for user in users
] ]
@router.post("/servers/{server_id}/sync-report", response=SyncReportOut, auth=None) @router.post("/servers/{server_id}/sync-report", response=SyncReportOut, auth=None)
@@ -216,6 +255,8 @@ def build_router() -> Router:
**(payload.metadata or {}), **(payload.metadata or {}),
}, },
) )
if payload.accounts:
_update_server_accounts(server, payload.accounts)
return SyncReportOut(status="ok") return SyncReportOut(status="ok")
@router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None) @router.post("/servers/{server_id}/logs", response=LogIngestOut, auth=None)
@@ -277,6 +318,62 @@ def build_router() -> Router:
return router return router
def _get_server_or_404(server_id: int) -> Server:
try:
return Server.objects.get(id=server_id)
except Server.DoesNotExist:
raise HttpError(404, "Server not found")
def _resolve_access_users(server: Server) -> list:
users = list(
get_users_with_perms(
server,
only_with_perms_in=["view_server"],
with_group_users=True,
with_superusers=False,
)
)
active = [user for user in users if getattr(user, "is_active", False)]
return sorted(active, key=lambda user: (user.username or "", user.id))
def _key_map_for_users(users: list) -> dict[int, list[SSHKey]]:
if not users:
return {}
keys = SSHKey.objects.select_related("user").filter(
user__in=users,
is_active=True,
revoked_at__isnull=True,
)
key_map: dict[int, list[SSHKey]] = {}
for key in keys:
key_map.setdefault(key.user_id, []).append(key)
return key_map
def _update_server_accounts(server: Server, accounts: list[AccountSyncIn]) -> None:
user_ids = {account.user_id for account in accounts}
if not user_ids:
return
User = get_user_model()
users = {user.id: user for user in User.objects.filter(id__in=user_ids)}
now = timezone.now()
for account in accounts:
user = users.get(account.user_id)
if not user:
continue
ServerAccount.objects.update_or_create(
server=server,
user=user,
defaults={
"system_username": account.system_username,
"is_present": account.present,
"last_synced_at": now,
},
)
def _load_agent_ca() -> tuple[x509.Certificate, object, str]: def _load_agent_ca() -> tuple[x509.Certificate, object, str]:
ca = ( ca = (
AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True) AgentCertificateAuthority.objects.filter(is_active=True, revoked_at__isnull=True)

View File

@@ -5,8 +5,8 @@ from typing import List, Optional
from django.http import HttpRequest from django.http import HttpRequest
from ninja import Router, Schema from ninja import Router, Schema
from ninja.errors import HttpError from ninja.errors import HttpError
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user, get_perms
from apps.core.rbac import require_perms from apps.core.rbac import require_authenticated, require_perms
from apps.servers.models import Server from apps.servers.models import Server
@@ -32,20 +32,17 @@ def build_router() -> Router:
"""List servers the caller can view. """List servers the caller can view.
Auth: required. Auth: required.
Permissions: requires `servers.view_server` globally or per-object. Permissions: requires `servers.view_server` via object permissions.
Behavior: returns only servers the user can see via object perms. Behavior: returns only servers the user can see via object perms.
Rationale: drives the server dashboard and access-aware navigation. Rationale: drives the server dashboard and access-aware navigation.
""" """
require_perms(request, "servers.view_server") require_authenticated(request)
if request.user.has_perm("servers.view_server"): servers = get_objects_for_user(
servers = Server.objects.all() request.user,
else: "servers.view_server",
servers = get_objects_for_user( klass=Server,
request.user, accept_global_perms=False,
"servers.view_server", )
klass=Server,
accept_global_perms=False,
)
return [ return [
{ {
"id": s.id, "id": s.id,
@@ -64,18 +61,16 @@ def build_router() -> Router:
"""Get a server record by id. """Get a server record by id.
Auth: required. Auth: required.
Permissions: requires `servers.view_server` globally or per-object. Permissions: requires `servers.view_server` via object permissions.
Rationale: used by server detail views and API clients inspecting Rationale: used by server detail views and API clients inspecting
server metadata (hostname/IPs populated by the agent). server metadata (hostname/IPs populated by the agent).
""" """
require_perms(request, "servers.view_server") require_authenticated(request)
try: try:
server = Server.objects.get(id=server_id) server = Server.objects.get(id=server_id)
except Server.DoesNotExist: except Server.DoesNotExist:
raise HttpError(404, "Not Found") raise HttpError(404, "Not Found")
if not request.user.has_perm("servers.view_server", server) and not request.user.has_perm( if "view_server" not in get_perms(request.user, server):
"servers.view_server"
):
raise HttpError(403, "Forbidden") raise HttpError(403, "Forbidden")
return { return {
"id": server.id, "id": server.id,