Initial linux agent and api functionality for enrolling servers

This commit is contained in:
2026-01-25 22:24:20 +00:00
parent 66ffa3d3fb
commit 4885622d6a
23 changed files with 1351 additions and 50 deletions

View File

@@ -0,0 +1,132 @@
package client
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"keywarden/agent/internal/config"
)
const defaultTimeout = 15 * time.Second
type Client struct {
baseURL string
http *http.Client
}
func New(cfg *config.Config) (*Client, error) {
baseURL := strings.TrimRight(cfg.ServerURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath())
if err != nil {
return nil, fmt.Errorf("load client cert: %w", err)
}
caData, err := os.ReadFile(cfg.CACertPath())
if err != nil {
return nil, fmt.Errorf("read ca cert: %w", err)
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse ca cert")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caPool,
MinVersion: tls.VersionTLS12,
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
}
httpClient := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
return &Client{baseURL: baseURL, http: httpClient}, nil
}
type EnrollRequest struct {
Token string `json:"token"`
CSRPEM string `json:"csr_pem"`
Host string `json:"host"`
AgentID string `json:"agent_id,omitempty"`
}
type EnrollResponse struct {
ServerID string `json:"server_id"`
ClientCert string `json:"client_cert_pem"`
CACert string `json:"ca_cert_pem"`
SyncProfile string `json:"sync_profile,omitempty"`
DisplayName string `json:"display_name,omitempty"`
}
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
baseURL := strings.TrimRight(serverURL, "/")
if baseURL == "" {
return nil, errors.New("server url is required")
}
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("encode enroll request: %w", err)
}
httpClient := &http.Client{Timeout: defaultTimeout}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build enroll request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("enroll request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("enroll failed: status %s", resp.Status)
}
var out EnrollResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode enroll response: %w", err)
}
if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" {
return nil, errors.New("enroll response missing required fields")
}
return &out, nil
}
func (c *Client) SyncAccounts(ctx context.Context, serverID string) error {
_ = ctx
_ = serverID
// TODO: call API to fetch account policy + approved access list.
return nil
}
func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload))
if err != nil {
return fmt.Errorf("build log request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send log batch: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return fmt.Errorf("log batch failed: status %s", resp.Status)
}
return nil
}

View File

@@ -0,0 +1,148 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
)
const (
DefaultConfigPath = "/etc/keywarden/agent.json"
DefaultStateDir = "/var/lib/keywarden-agent"
DefaultSyncIntervalSeconds = 30
DefaultLogBatchSize = 500
DefaultUsernameTemplate = "{{username}}_{{user_id}}"
DefaultShell = "/bin/bash"
DefaultAdminGroup = "sudo"
)
type AccountPolicy struct {
UsernameTemplate string `json:"username_template"`
DefaultShell string `json:"default_shell"`
AdminGroup string `json:"admin_group"`
CreateHome bool `json:"create_home"`
LockOnRevoke bool `json:"lock_on_revoke"`
}
type Config struct {
ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"`
AccountPolicy AccountPolicy `json:"account_policy,omitempty"`
}
func LoadOrInit(path string, serverURL string) (*Config, error) {
if path == "" {
path = DefaultConfigPath
}
data, err := os.ReadFile(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("read config: %w", err)
}
if serverURL == "" {
return nil, errors.New("server url required for first boot")
}
cfg := &Config{ServerURL: serverURL}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
if err := Save(path, cfg); err != nil {
return nil, err
}
return cfg, nil
}
cfg := &Config{}
if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
}
return cfg, nil
}
func Save(path string, cfg *Config) error {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("encode config: %w", err)
}
if err := os.MkdirAll(dir(path), 0o755); err != nil {
return fmt.Errorf("create config dir: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write config: %w", err)
}
return nil
}
func applyDefaults(cfg *Config) {
if cfg.SyncIntervalSeconds <= 0 {
cfg.SyncIntervalSeconds = DefaultSyncIntervalSeconds
}
if cfg.LogBatchSize <= 0 {
cfg.LogBatchSize = DefaultLogBatchSize
}
if cfg.StateDir == "" {
cfg.StateDir = DefaultStateDir
}
if cfg.AccountPolicy.UsernameTemplate == "" {
cfg.AccountPolicy.UsernameTemplate = DefaultUsernameTemplate
}
if cfg.AccountPolicy.DefaultShell == "" {
cfg.AccountPolicy.DefaultShell = DefaultShell
}
if cfg.AccountPolicy.AdminGroup == "" {
cfg.AccountPolicy.AdminGroup = DefaultAdminGroup
}
}
func validate(cfg *Config, requireServerID bool) error {
var missing []string
if cfg.ServerURL == "" {
missing = append(missing, "server_url")
}
if requireServerID && cfg.ServerID == "" {
missing = append(missing, "server_id")
}
if len(missing) > 0 {
return fmt.Errorf("missing required config fields: %v", missing)
}
if cfg.SyncIntervalSeconds < 5 {
return errors.New("sync_interval_seconds must be >= 5")
}
return nil
}
func (c *Config) ClientCertPath() string {
return c.StateDir + "/agent.crt"
}
func (c *Config) ClientKeyPath() string {
return c.StateDir + "/agent.key"
}
func (c *Config) CACertPath() string {
return c.StateDir + "/ca.crt"
}
func (c *Config) LogCursorPath() string {
return c.StateDir + "/journal.cursor"
}
func (c *Config) LogSpoolDir() string {
return c.StateDir + "/spool"
}
func dir(path string) string {
if idx := strings.LastIndex(path, string(os.PathSeparator)); idx != -1 {
return path[:idx]
}
return "."
}

View File

@@ -0,0 +1,177 @@
package logs
import (
"context"
"strings"
"time"
"github.com/coreos/go-systemd/v22/sdjournal"
)
const defaultLimit = 500
type Collector struct {
matches []string
}
func NewCollector() *Collector {
return &Collector{matches: defaultMatches()}
}
func (c *Collector) Collect(ctx context.Context, cursor string, limit int) ([]Event, string, error) {
if limit <= 0 {
limit = defaultLimit
}
j, err := sdjournal.NewJournal()
if err != nil {
return nil, "", err
}
defer j.Close()
for i, match := range c.matches {
if i > 0 {
if err := j.AddDisjunction(); err != nil {
return nil, "", err
}
}
if err := j.AddMatch(match); err != nil {
return nil, "", err
}
}
if cursor != "" {
if err := j.SeekCursor(cursor); err == nil {
_, _ = j.Next()
}
} else {
_ = j.SeekTail()
_, _ = j.Next()
}
var events []Event
var nextCursor string
for len(events) < limit {
select {
case <-ctx.Done():
return events, nextCursor, ctx.Err()
default:
}
n, err := j.Next()
if err != nil {
return events, nextCursor, err
}
if n == 0 {
break
}
entry, err := j.GetEntry()
if err != nil {
return events, nextCursor, err
}
event := fromEntry(entry)
events = append(events, event)
nextCursor = entry.Cursor
}
return events, nextCursor, nil
}
func defaultMatches() []string {
return []string{
"_SYSTEMD_UNIT=sshd.service",
"_SYSTEMD_UNIT=sudo.service",
"_SYSTEMD_UNIT=systemd-networkd.service",
"_SYSTEMD_UNIT=NetworkManager.service",
"_SYSTEMD_UNIT=systemd-logind.service",
"_TRANSPORT=kernel",
}
}
func fromEntry(entry *sdjournal.JournalEntry) Event {
ts := time.Unix(0, int64(entry.RealtimeTimestamp)*int64(time.Microsecond))
event := NewEvent(ts)
fields := entry.Fields
unit := fields["_SYSTEMD_UNIT"]
message := fields["MESSAGE"]
identifier := fields["SYSLOG_IDENTIFIER"]
event.Unit = unit
event.Message = message
event.Priority = fields["PRIORITY"]
event.Hostname = fields["_HOSTNAME"]
event.Fields = fields
event.Category = categorize(unit, identifier, fields)
event.EventType, event.Username, event.SourceIP, event.SessionID = parseMessage(event.Category, message)
if event.EventType == "" {
event.EventType = defaultEventType(event.Category)
}
return event
}
func categorize(unit string, identifier string, fields map[string]string) string {
switch {
case unit == "sshd.service" || identifier == "sshd":
return "access"
case unit == "sudo.service" || identifier == "sudo":
return "auth"
case unit == "systemd-networkd.service" || identifier == "NetworkManager":
return "network"
case fields["_TRANSPORT"] == "kernel":
return "system"
default:
return "system"
}
}
func defaultEventType(category string) string {
switch category {
case "access":
return "ssh"
case "auth":
return "auth"
case "network":
return "network"
default:
return "system"
}
}
func parseMessage(category string, msg string) (eventType string, username string, sourceIP string, sessionID string) {
if msg == "" {
return "", "", "", ""
}
lower := strings.ToLower(msg)
if category == "access" {
switch {
case strings.Contains(lower, "accepted"):
eventType = "ssh.login.success"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "failed password"):
eventType = "ssh.login.fail"
username = extractBetween(msg, "for ", " from")
sourceIP = extractBetween(msg, "from ", " port")
case strings.Contains(lower, "session opened"):
eventType = "ssh.session.open"
username = extractBetween(msg, "for user ", " by")
case strings.Contains(lower, "session closed"):
eventType = "ssh.session.close"
username = extractBetween(msg, "for user ", " by")
}
}
return eventType, strings.TrimSpace(username), strings.TrimSpace(sourceIP), strings.TrimSpace(sessionID)
}
func extractBetween(msg string, start string, end string) string {
startIdx := strings.Index(msg, start)
if startIdx == -1 {
return ""
}
startIdx += len(start)
rest := msg[startIdx:]
endIdx := strings.Index(rest, end)
if endIdx == -1 {
return strings.TrimSpace(rest)
}
return strings.TrimSpace(rest[:endIdx])
}

View File

@@ -0,0 +1,24 @@
package logs
import (
"os"
"strings"
)
func ReadCursor(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func WriteCursor(path string, cursor string) error {
if cursor == "" {
return nil
}
return os.WriteFile(path, []byte(cursor+"\n"), 0o600)
}

View File

@@ -0,0 +1,53 @@
package logs
import (
"fmt"
"os"
"path/filepath"
"sort"
"time"
)
func SaveSpool(dir string, payload []byte) error {
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
name := fmt.Sprintf("%d.json", time.Now().UnixNano())
tmp := filepath.Join(dir, name+".tmp")
final := filepath.Join(dir, name)
if err := os.WriteFile(tmp, payload, 0o600); err != nil {
return err
}
return os.Rename(tmp, final)
}
func DrainSpool(dir string, send func([]byte) error) error {
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var files []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
files = append(files, filepath.Join(dir, entry.Name()))
}
sort.Strings(files)
for _, path := range files {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if err := send(data); err != nil {
return err
}
if err := os.Remove(path); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,23 @@
package logs
import "time"
type Event struct {
Timestamp string `json:"timestamp"`
Category string `json:"category"`
EventType string `json:"event_type"`
Unit string `json:"unit,omitempty"`
Priority string `json:"priority,omitempty"`
Hostname string `json:"hostname,omitempty"`
Username string `json:"username,omitempty"`
Principal string `json:"principal,omitempty"`
SourceIP string `json:"source_ip,omitempty"`
SessionID string `json:"session_id,omitempty"`
Message string `json:"message,omitempty"`
Raw string `json:"raw,omitempty"`
Fields map[string]string `json:"fields,omitempty"`
}
func NewEvent(ts time.Time) Event {
return Event{Timestamp: ts.UTC().Format(time.RFC3339Nano)}
}

View File

@@ -0,0 +1,7 @@
package version
var (
Version = "0.0.1-dev"
Commit = ""
BuildDate = ""
)