object‑permission–driven server access; agent‑managed account provisioning with presence reporting
This commit is contained in:
@@ -82,7 +82,7 @@ func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config)
|
||||
log.Printf("host update error: %v", err)
|
||||
}
|
||||
}
|
||||
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
||||
if err := apiClient.SyncAccounts(ctx, cfg); err != nil {
|
||||
log.Printf("sync accounts error: %v", err)
|
||||
}
|
||||
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
||||
|
||||
323
agent/internal/accounts/sync.go
Normal file
323
agent/internal/accounts/sync.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"keywarden/agent/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
stateFileName = "accounts.json"
|
||||
maxUsernameLen = 32
|
||||
passwdFilePath = "/etc/passwd"
|
||||
sshDirName = ".ssh"
|
||||
authKeysName = "authorized_keys"
|
||||
)
|
||||
|
||||
type AccessUser struct {
|
||||
UserID int
|
||||
Username string
|
||||
Email string
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type ReportAccount struct {
|
||||
UserID int `json:"user_id"`
|
||||
SystemUser string `json:"system_username"`
|
||||
Present bool `json:"present"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Applied int
|
||||
Revoked int
|
||||
Accounts []ReportAccount
|
||||
}
|
||||
|
||||
type managedAccount struct {
|
||||
UserID int `json:"user_id"`
|
||||
SystemUser string `json:"system_username"`
|
||||
}
|
||||
|
||||
type state struct {
|
||||
Users map[string]managedAccount `json:"users"`
|
||||
}
|
||||
|
||||
type passwdEntry struct {
|
||||
UID int
|
||||
GID int
|
||||
Home string
|
||||
}
|
||||
|
||||
func Sync(policy config.AccountPolicy, stateDir string, users []AccessUser) (Result, error) {
|
||||
result := Result{}
|
||||
statePath := filepath.Join(stateDir, stateFileName)
|
||||
current, err := loadState(statePath)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
desired := make(map[int]managedAccount, len(users))
|
||||
userIndex := make(map[int]AccessUser, len(users))
|
||||
for _, user := range users {
|
||||
systemUser := renderUsername(policy.UsernameTemplate, user.Username, user.UserID)
|
||||
desired[user.UserID] = managedAccount{UserID: user.UserID, SystemUser: systemUser}
|
||||
userIndex[user.UserID] = user
|
||||
}
|
||||
|
||||
var syncErr error
|
||||
for _, account := range current.Users {
|
||||
if _, ok := desired[account.UserID]; ok {
|
||||
continue
|
||||
}
|
||||
if err := revokeUser(account.SystemUser, policy); err != nil && syncErr == nil {
|
||||
syncErr = err
|
||||
}
|
||||
result.Revoked++
|
||||
}
|
||||
|
||||
for userID, account := range desired {
|
||||
accessUser := userIndex[userID]
|
||||
present, err := ensureAccount(account.SystemUser, policy, accessUser.Keys)
|
||||
if err != nil && syncErr == nil {
|
||||
syncErr = err
|
||||
}
|
||||
if present {
|
||||
result.Applied++
|
||||
}
|
||||
result.Accounts = append(result.Accounts, ReportAccount{
|
||||
UserID: userID,
|
||||
SystemUser: account.SystemUser,
|
||||
Present: present,
|
||||
})
|
||||
}
|
||||
|
||||
if err := saveState(statePath, desired); err != nil && syncErr == nil {
|
||||
syncErr = err
|
||||
}
|
||||
return result, syncErr
|
||||
}
|
||||
|
||||
func loadState(path string) (state, error) {
|
||||
st := state{Users: map[string]managedAccount{}}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return st, nil
|
||||
}
|
||||
return st, fmt.Errorf("read state: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &st); err != nil {
|
||||
return st, fmt.Errorf("parse state: %w", err)
|
||||
}
|
||||
if st.Users == nil {
|
||||
st.Users = map[string]managedAccount{}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func saveState(path string, desired map[int]managedAccount) error {
|
||||
st := state{Users: map[string]managedAccount{}}
|
||||
for id, account := range desired {
|
||||
st.Users[strconv.Itoa(id)] = account
|
||||
}
|
||||
data, err := json.MarshalIndent(st, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode state: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("create state dir: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||
return fmt.Errorf("write state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func renderUsername(template string, username string, userID int) string {
|
||||
raw := strings.ReplaceAll(template, "{{username}}", username)
|
||||
raw = strings.ReplaceAll(raw, "{{user_id}}", strconv.Itoa(userID))
|
||||
clean := sanitizeUsername(raw)
|
||||
if len(clean) > maxUsernameLen {
|
||||
clean = clean[:maxUsernameLen]
|
||||
}
|
||||
if clean == "" {
|
||||
clean = fmt.Sprintf("kw_%d", userID)
|
||||
}
|
||||
return clean
|
||||
}
|
||||
|
||||
func sanitizeUsername(raw string) string {
|
||||
raw = strings.ToLower(raw)
|
||||
var b strings.Builder
|
||||
b.Grow(len(raw))
|
||||
for _, r := range raw {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
b.WriteRune(r)
|
||||
continue
|
||||
}
|
||||
b.WriteByte('_')
|
||||
}
|
||||
out := strings.Trim(b.String(), "-_")
|
||||
if out == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(out, "-") {
|
||||
return "kw" + out
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userExists(username string) (bool, error) {
|
||||
cmd := exec.Command("id", "-u", username)
|
||||
if err := cmd.Run(); err != nil {
|
||||
if _, ok := err.(*exec.ExitError); ok {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ensureAccount(username string, policy config.AccountPolicy, keys []string) (bool, error) {
|
||||
exists, err := userExists(username)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !exists {
|
||||
if err := createUser(username, policy); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
if err := lockPassword(username); err != nil {
|
||||
return true, err
|
||||
}
|
||||
if err := writeAuthorizedKeys(username, keys); err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func createUser(username string, policy config.AccountPolicy) error {
|
||||
args := []string{"-U"}
|
||||
if policy.CreateHome {
|
||||
args = append(args, "-m")
|
||||
} else {
|
||||
args = append(args, "-M")
|
||||
}
|
||||
if policy.DefaultShell != "" {
|
||||
args = append(args, "-s", policy.DefaultShell)
|
||||
}
|
||||
args = append(args, username)
|
||||
cmd := exec.Command("useradd", args...)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("useradd %s: %w", username, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lockPassword(username string) error {
|
||||
cmd := exec.Command("usermod", "-L", username)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("lock password %s: %w", username, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func revokeUser(username string, policy config.AccountPolicy) error {
|
||||
exists, err := userExists(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
var revokeErr error
|
||||
if policy.LockOnRevoke {
|
||||
if err := lockPassword(username); err != nil {
|
||||
revokeErr = err
|
||||
}
|
||||
}
|
||||
if err := writeAuthorizedKeys(username, nil); err != nil && revokeErr == nil {
|
||||
revokeErr = err
|
||||
}
|
||||
return revokeErr
|
||||
}
|
||||
|
||||
func writeAuthorizedKeys(username string, keys []string) error {
|
||||
entry, err := lookupUser(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if entry.Home == "" {
|
||||
return fmt.Errorf("missing home dir for %s", username)
|
||||
}
|
||||
sshDir := filepath.Join(entry.Home, sshDirName)
|
||||
if err := os.MkdirAll(sshDir, 0o700); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", sshDir, err)
|
||||
}
|
||||
if err := os.Chmod(sshDir, 0o700); err != nil {
|
||||
return fmt.Errorf("chmod %s: %w", sshDir, err)
|
||||
}
|
||||
if err := os.Chown(sshDir, entry.UID, entry.GID); err != nil {
|
||||
return fmt.Errorf("chown %s: %w", sshDir, err)
|
||||
}
|
||||
authKeysPath := filepath.Join(sshDir, authKeysName)
|
||||
payload := strings.TrimSpace(strings.Join(keys, "\n"))
|
||||
if payload != "" {
|
||||
payload += "\n"
|
||||
}
|
||||
if err := os.WriteFile(authKeysPath, []byte(payload), 0o600); err != nil {
|
||||
return fmt.Errorf("write %s: %w", authKeysPath, err)
|
||||
}
|
||||
if err := os.Chown(authKeysPath, entry.UID, entry.GID); err != nil {
|
||||
return fmt.Errorf("chown %s: %w", authKeysPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupUser(username string) (passwdEntry, error) {
|
||||
file, err := os.Open(passwdFilePath)
|
||||
if err != nil {
|
||||
return passwdEntry{}, fmt.Errorf("open passwd: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
fields := strings.SplitN(line, ":", 7)
|
||||
if len(fields) < 7 {
|
||||
continue
|
||||
}
|
||||
if fields[0] != username {
|
||||
continue
|
||||
}
|
||||
uid, err := strconv.Atoi(fields[2])
|
||||
if err != nil {
|
||||
return passwdEntry{}, fmt.Errorf("parse uid for %s: %w", username, err)
|
||||
}
|
||||
gid, err := strconv.Atoi(fields[3])
|
||||
if err != nil {
|
||||
return passwdEntry{}, fmt.Errorf("parse gid for %s: %w", username, err)
|
||||
}
|
||||
return passwdEntry{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
Home: fields[5],
|
||||
}, nil
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return passwdEntry{}, fmt.Errorf("scan passwd: %w", err)
|
||||
}
|
||||
return passwdEntry{}, fmt.Errorf("user %s not found", username)
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"keywarden/agent/internal/accounts"
|
||||
"keywarden/agent/internal/config"
|
||||
)
|
||||
|
||||
@@ -81,6 +82,32 @@ type EnrollResponse struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
}
|
||||
|
||||
type AccountKey struct {
|
||||
PublicKey string `json:"public_key"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
}
|
||||
|
||||
type AccountAccess struct {
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Keys []AccountKey `json:"keys"`
|
||||
}
|
||||
|
||||
type AccountSyncEntry struct {
|
||||
UserID int `json:"user_id"`
|
||||
SystemUsername string `json:"system_username"`
|
||||
Present bool `json:"present"`
|
||||
}
|
||||
|
||||
type SyncReportRequest struct {
|
||||
AppliedCount int `json:"applied_count"`
|
||||
RevokedCount int `json:"revoked_count"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Accounts []AccountSyncEntry `json:"accounts,omitempty"`
|
||||
}
|
||||
|
||||
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
|
||||
baseURL := strings.TrimRight(serverURL, "/")
|
||||
if baseURL == "" {
|
||||
@@ -114,10 +141,103 @@ func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollRe
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *Client) SyncAccounts(ctx context.Context, serverID string) error {
|
||||
_ = ctx
|
||||
_ = serverID
|
||||
// TODO: call API to fetch account policy + approved access list.
|
||||
func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error {
|
||||
if cfg == nil {
|
||||
return errors.New("config required for account sync")
|
||||
}
|
||||
users, err := c.FetchAccountAccess(ctx, cfg.ServerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessUsers := make([]accounts.AccessUser, 0, len(users))
|
||||
for _, user := range users {
|
||||
keys := make([]string, 0, len(user.Keys))
|
||||
for _, key := range user.Keys {
|
||||
if strings.TrimSpace(key.PublicKey) == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, strings.TrimSpace(key.PublicKey))
|
||||
}
|
||||
accessUsers = append(accessUsers, accounts.AccessUser{
|
||||
UserID: user.UserID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Keys: keys,
|
||||
})
|
||||
}
|
||||
result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers)
|
||||
report := SyncReportRequest{
|
||||
AppliedCount: result.Applied,
|
||||
RevokedCount: result.Revoked,
|
||||
Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)),
|
||||
}
|
||||
for _, account := range result.Accounts {
|
||||
report.Accounts = append(report.Accounts, AccountSyncEntry{
|
||||
UserID: account.UserID,
|
||||
SystemUsername: account.SystemUser,
|
||||
Present: account.Present,
|
||||
})
|
||||
}
|
||||
if syncErr != nil {
|
||||
report.Message = syncErr.Error()
|
||||
}
|
||||
if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil {
|
||||
if syncErr != nil {
|
||||
return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return syncErr
|
||||
}
|
||||
|
||||
func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
c.baseURL+"/agent/servers/"+serverID+"/accounts",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build account access request: %w", err)
|
||||
}
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch account access: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||
}
|
||||
var out []AccountAccess
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, fmt.Errorf("decode account access: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error {
|
||||
body, err := json.Marshal(report)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode sync report: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
c.baseURL+"/agent/servers/"+serverID+"/sync-report",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build sync report: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send sync report: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user