diff --git a/agent/cmd/keywarden-agent/main.go b/agent/cmd/keywarden-agent/main.go index 5eb6621..447e714 100644 --- a/agent/cmd/keywarden-agent/main.go +++ b/agent/cmd/keywarden-agent/main.go @@ -82,7 +82,7 @@ func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) log.Printf("host update error: %v", err) } } - if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil { + if err := apiClient.SyncAccounts(ctx, cfg); err != nil { log.Printf("sync accounts error: %v", err) } if err := shipLogs(ctx, apiClient, cfg); err != nil { diff --git a/agent/internal/accounts/sync.go b/agent/internal/accounts/sync.go new file mode 100644 index 0000000..bb3c0f3 --- /dev/null +++ b/agent/internal/accounts/sync.go @@ -0,0 +1,323 @@ +package accounts + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "keywarden/agent/internal/config" +) + +const ( + stateFileName = "accounts.json" + maxUsernameLen = 32 + passwdFilePath = "/etc/passwd" + sshDirName = ".ssh" + authKeysName = "authorized_keys" +) + +type AccessUser struct { + UserID int + Username string + Email string + Keys []string +} + +type ReportAccount struct { + UserID int `json:"user_id"` + SystemUser string `json:"system_username"` + Present bool `json:"present"` +} + +type Result struct { + Applied int + Revoked int + Accounts []ReportAccount +} + +type managedAccount struct { + UserID int `json:"user_id"` + SystemUser string `json:"system_username"` +} + +type state struct { + Users map[string]managedAccount `json:"users"` +} + +type passwdEntry struct { + UID int + GID int + Home string +} + +func Sync(policy config.AccountPolicy, stateDir string, users []AccessUser) (Result, error) { + result := Result{} + statePath := filepath.Join(stateDir, stateFileName) + current, err := loadState(statePath) + if err != nil { + return result, err + } + + desired := make(map[int]managedAccount, len(users)) + userIndex := make(map[int]AccessUser, len(users)) + for _, user := range users { + systemUser := renderUsername(policy.UsernameTemplate, user.Username, user.UserID) + desired[user.UserID] = managedAccount{UserID: user.UserID, SystemUser: systemUser} + userIndex[user.UserID] = user + } + + var syncErr error + for _, account := range current.Users { + if _, ok := desired[account.UserID]; ok { + continue + } + if err := revokeUser(account.SystemUser, policy); err != nil && syncErr == nil { + syncErr = err + } + result.Revoked++ + } + + for userID, account := range desired { + accessUser := userIndex[userID] + present, err := ensureAccount(account.SystemUser, policy, accessUser.Keys) + if err != nil && syncErr == nil { + syncErr = err + } + if present { + result.Applied++ + } + result.Accounts = append(result.Accounts, ReportAccount{ + UserID: userID, + SystemUser: account.SystemUser, + Present: present, + }) + } + + if err := saveState(statePath, desired); err != nil && syncErr == nil { + syncErr = err + } + return result, syncErr +} + +func loadState(path string) (state, error) { + st := state{Users: map[string]managedAccount{}} + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return st, nil + } + return st, fmt.Errorf("read state: %w", err) + } + if err := json.Unmarshal(data, &st); err != nil { + return st, fmt.Errorf("parse state: %w", err) + } + if st.Users == nil { + st.Users = map[string]managedAccount{} + } + return st, nil +} + +func saveState(path string, desired map[int]managedAccount) error { + st := state{Users: map[string]managedAccount{}} + for id, account := range desired { + st.Users[strconv.Itoa(id)] = account + } + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return fmt.Errorf("encode state: %w", err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("create state dir: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write state: %w", err) + } + return nil +} + +func renderUsername(template string, username string, userID int) string { + raw := strings.ReplaceAll(template, "{{username}}", username) + raw = strings.ReplaceAll(raw, "{{user_id}}", strconv.Itoa(userID)) + clean := sanitizeUsername(raw) + if len(clean) > maxUsernameLen { + clean = clean[:maxUsernameLen] + } + if clean == "" { + clean = fmt.Sprintf("kw_%d", userID) + } + return clean +} + +func sanitizeUsername(raw string) string { + raw = strings.ToLower(raw) + var b strings.Builder + b.Grow(len(raw)) + for _, r := range raw { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + b.WriteRune(r) + continue + } + b.WriteByte('_') + } + out := strings.Trim(b.String(), "-_") + if out == "" { + return "" + } + if strings.HasPrefix(out, "-") { + return "kw" + out + } + return out +} + +func userExists(username string) (bool, error) { + cmd := exec.Command("id", "-u", username) + if err := cmd.Run(); err != nil { + if _, ok := err.(*exec.ExitError); ok { + return false, nil + } + return false, err + } + return true, nil +} + +func ensureAccount(username string, policy config.AccountPolicy, keys []string) (bool, error) { + exists, err := userExists(username) + if err != nil { + return false, err + } + if !exists { + if err := createUser(username, policy); err != nil { + return false, err + } + } + if err := lockPassword(username); err != nil { + return true, err + } + if err := writeAuthorizedKeys(username, keys); err != nil { + return true, err + } + return true, nil +} + +func createUser(username string, policy config.AccountPolicy) error { + args := []string{"-U"} + if policy.CreateHome { + args = append(args, "-m") + } else { + args = append(args, "-M") + } + if policy.DefaultShell != "" { + args = append(args, "-s", policy.DefaultShell) + } + args = append(args, username) + cmd := exec.Command("useradd", args...) + if err := cmd.Run(); err != nil { + return fmt.Errorf("useradd %s: %w", username, err) + } + return nil +} + +func lockPassword(username string) error { + cmd := exec.Command("usermod", "-L", username) + if err := cmd.Run(); err != nil { + return fmt.Errorf("lock password %s: %w", username, err) + } + return nil +} + +func revokeUser(username string, policy config.AccountPolicy) error { + exists, err := userExists(username) + if err != nil { + return err + } + if !exists { + return nil + } + var revokeErr error + if policy.LockOnRevoke { + if err := lockPassword(username); err != nil { + revokeErr = err + } + } + if err := writeAuthorizedKeys(username, nil); err != nil && revokeErr == nil { + revokeErr = err + } + return revokeErr +} + +func writeAuthorizedKeys(username string, keys []string) error { + entry, err := lookupUser(username) + if err != nil { + return err + } + if entry.Home == "" { + return fmt.Errorf("missing home dir for %s", username) + } + sshDir := filepath.Join(entry.Home, sshDirName) + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("mkdir %s: %w", sshDir, err) + } + if err := os.Chmod(sshDir, 0o700); err != nil { + return fmt.Errorf("chmod %s: %w", sshDir, err) + } + if err := os.Chown(sshDir, entry.UID, entry.GID); err != nil { + return fmt.Errorf("chown %s: %w", sshDir, err) + } + authKeysPath := filepath.Join(sshDir, authKeysName) + payload := strings.TrimSpace(strings.Join(keys, "\n")) + if payload != "" { + payload += "\n" + } + if err := os.WriteFile(authKeysPath, []byte(payload), 0o600); err != nil { + return fmt.Errorf("write %s: %w", authKeysPath, err) + } + if err := os.Chown(authKeysPath, entry.UID, entry.GID); err != nil { + return fmt.Errorf("chown %s: %w", authKeysPath, err) + } + return nil +} + +func lookupUser(username string) (passwdEntry, error) { + file, err := os.Open(passwdFilePath) + if err != nil { + return passwdEntry{}, fmt.Errorf("open passwd: %w", err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.SplitN(line, ":", 7) + if len(fields) < 7 { + continue + } + if fields[0] != username { + continue + } + uid, err := strconv.Atoi(fields[2]) + if err != nil { + return passwdEntry{}, fmt.Errorf("parse uid for %s: %w", username, err) + } + gid, err := strconv.Atoi(fields[3]) + if err != nil { + return passwdEntry{}, fmt.Errorf("parse gid for %s: %w", username, err) + } + return passwdEntry{ + UID: uid, + GID: gid, + Home: fields[5], + }, nil + } + if err := scanner.Err(); err != nil { + return passwdEntry{}, fmt.Errorf("scan passwd: %w", err) + } + return passwdEntry{}, fmt.Errorf("user %s not found", username) +} diff --git a/agent/internal/client/client.go b/agent/internal/client/client.go index 69b5950..f71c34b 100644 --- a/agent/internal/client/client.go +++ b/agent/internal/client/client.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "keywarden/agent/internal/accounts" "keywarden/agent/internal/config" ) @@ -81,6 +82,32 @@ type EnrollResponse struct { DisplayName string `json:"display_name,omitempty"` } +type AccountKey struct { + PublicKey string `json:"public_key"` + Fingerprint string `json:"fingerprint"` +} + +type AccountAccess struct { + UserID int `json:"user_id"` + Username string `json:"username"` + Email string `json:"email"` + Keys []AccountKey `json:"keys"` +} + +type AccountSyncEntry struct { + UserID int `json:"user_id"` + SystemUsername string `json:"system_username"` + Present bool `json:"present"` +} + +type SyncReportRequest struct { + AppliedCount int `json:"applied_count"` + RevokedCount int `json:"revoked_count"` + Message string `json:"message,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Accounts []AccountSyncEntry `json:"accounts,omitempty"` +} + func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) { baseURL := strings.TrimRight(serverURL, "/") if baseURL == "" { @@ -114,10 +141,103 @@ func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollRe return &out, nil } -func (c *Client) SyncAccounts(ctx context.Context, serverID string) error { - _ = ctx - _ = serverID - // TODO: call API to fetch account policy + approved access list. +func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error { + if cfg == nil { + return errors.New("config required for account sync") + } + users, err := c.FetchAccountAccess(ctx, cfg.ServerID) + if err != nil { + return err + } + accessUsers := make([]accounts.AccessUser, 0, len(users)) + for _, user := range users { + keys := make([]string, 0, len(user.Keys)) + for _, key := range user.Keys { + if strings.TrimSpace(key.PublicKey) == "" { + continue + } + keys = append(keys, strings.TrimSpace(key.PublicKey)) + } + accessUsers = append(accessUsers, accounts.AccessUser{ + UserID: user.UserID, + Username: user.Username, + Email: user.Email, + Keys: keys, + }) + } + result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers) + report := SyncReportRequest{ + AppliedCount: result.Applied, + RevokedCount: result.Revoked, + Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)), + } + for _, account := range result.Accounts { + report.Accounts = append(report.Accounts, AccountSyncEntry{ + UserID: account.UserID, + SystemUsername: account.SystemUser, + Present: account.Present, + }) + } + if syncErr != nil { + report.Message = syncErr.Error() + } + if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil { + if syncErr != nil { + return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr) + } + return err + } + return syncErr +} + +func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + c.baseURL+"/agent/servers/"+serverID+"/accounts", + nil, + ) + if err != nil { + return nil, fmt.Errorf("build account access request: %w", err) + } + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch account access: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } + var out []AccountAccess + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("decode account access: %w", err) + } + return out, nil +} + +func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error { + body, err := json.Marshal(report) + if err != nil { + return fmt.Errorf("encode sync report: %w", err) + } + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + c.baseURL+"/agent/servers/"+serverID+"/sync-report", + bytes.NewReader(body), + ) + if err != nil { + return fmt.Errorf("build sync report: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("send sync report: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } return nil } diff --git a/agent/keywarden-agent b/agent/keywarden-agent index c724980..3072405 100755 Binary files a/agent/keywarden-agent and b/agent/keywarden-agent differ diff --git a/app/apps/servers/migrations/0004_server_account.py b/app/apps/servers/migrations/0004_server_account.py new file mode 100644 index 0000000..447f786 --- /dev/null +++ b/app/apps/servers/migrations/0004_server_account.py @@ -0,0 +1,59 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + + dependencies = [ + ("servers", "0003_agent_ca"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ServerAccount", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("system_username", models.CharField(max_length=128)), + ("is_present", models.BooleanField(db_index=True, default=False)), + ("last_synced_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("created_at", models.DateTimeField(default=django.utils.timezone.now, editable=False)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "server", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="accounts", + to="servers.server", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="server_accounts", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Server account", + "verbose_name_plural": "Server accounts", + "ordering": ["server_id", "user_id"], + }, + ), + migrations.AddConstraint( + model_name="serveraccount", + constraint=models.UniqueConstraint(fields=("server", "user"), name="unique_server_account"), + ), + migrations.AddIndex( + model_name="serveraccount", + index=models.Index(fields=["server", "user"], name="servers_account_user_idx"), + ), + migrations.AddIndex( + model_name="serveraccount", + index=models.Index(fields=["server", "is_present"], name="servers_account_present_idx"), + ), + ] diff --git a/app/apps/servers/models.py b/app/apps/servers/models.py index 1efe9ed..37215aa 100644 --- a/app/apps/servers/models.py +++ b/app/apps/servers/models.py @@ -157,3 +157,30 @@ class AgentCertificateAuthority(models.Model): self.key_pem = key_pem self.fingerprint = cert.fingerprint(hashes.SHA256()).hex() self.serial = format(cert.serial_number, "x") + + +class ServerAccount(models.Model): + server = models.ForeignKey(Server, on_delete=models.CASCADE, related_name="accounts") + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="server_accounts" + ) + system_username = models.CharField(max_length=128) + is_present = models.BooleanField(default=False, db_index=True) + last_synced_at = models.DateTimeField(default=timezone.now, editable=False) + created_at = models.DateTimeField(default=timezone.now, editable=False) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + verbose_name = "Server account" + verbose_name_plural = "Server accounts" + constraints = [ + models.UniqueConstraint(fields=["server", "user"], name="unique_server_account") + ] + indexes = [ + models.Index(fields=["server", "user"], name="servers_account_user_idx"), + models.Index(fields=["server", "is_present"], name="servers_account_present_idx"), + ] + ordering = ["server_id", "user_id"] + + def __str__(self) -> str: + return f"{self.system_username} ({self.server_id})" diff --git a/app/apps/servers/templates/servers/detail.html b/app/apps/servers/templates/servers/detail.html index efa2184..29023da 100644 --- a/app/apps/servers/templates/servers/detail.html +++ b/app/apps/servers/templates/servers/detail.html @@ -33,6 +33,18 @@ {% endif %} +