Initial linux agent and api functionality for enrolling servers
This commit is contained in:
223
agent/cmd/keywarden-agent/main.go
Normal file
223
agent/cmd/keywarden-agent/main.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"keywarden/agent/internal/client"
|
||||
"keywarden/agent/internal/config"
|
||||
"keywarden/agent/internal/logs"
|
||||
"keywarden/agent/internal/version"
|
||||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", config.DefaultConfigPath, "Path to agent config JSON")
|
||||
serverURL := flag.String("server-url", "", "Keywarden server URL (first boot)")
|
||||
enrollToken := flag.String("enroll-token", "", "Enrollment token (first boot)")
|
||||
showVersion := flag.Bool("version", false, "Print version and exit")
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("keywarden-agent %s (commit %s, built %s)\n", version.Version, version.Commit, version.BuildDate)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadOrInit(*configPath, pickServerURL(*serverURL))
|
||||
if err != nil {
|
||||
log.Fatalf("config error: %v", err)
|
||||
}
|
||||
if err := ensureDirs(cfg); err != nil {
|
||||
log.Fatalf("state dir error: %v", err)
|
||||
}
|
||||
|
||||
if err := bootstrapIfNeeded(cfg, *configPath, pickEnrollToken(*enrollToken)); err != nil {
|
||||
log.Fatalf("bootstrap error: %v", err)
|
||||
}
|
||||
|
||||
apiClient, err := client.New(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("client error: %v", err)
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
interval := time.Duration(cfg.SyncIntervalSeconds) * time.Second
|
||||
log.Printf("keywarden-agent started: server_id=%s interval=%s", cfg.ServerID, interval)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runOnce(ctx, apiClient, cfg)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("shutdown requested")
|
||||
return
|
||||
case <-ticker.C:
|
||||
runOnce(ctx, apiClient, cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
|
||||
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
||||
log.Printf("sync accounts error: %v", err)
|
||||
}
|
||||
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
||||
log.Printf("log shipping error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureDirs(cfg *config.Config) error {
|
||||
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(cfg.LogSpoolDir(), 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||
send := func(payload []byte) error {
|
||||
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
||||
}
|
||||
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cursor, err := logs.ReadCursor(cfg.LogCursorPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
collector := logs.NewCollector()
|
||||
events, nextCursor, err := collector.Collect(ctx, cursor, cfg.LogBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(events)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := send(payload); err != nil {
|
||||
if spoolErr := logs.SaveSpool(cfg.LogSpoolDir(), payload); spoolErr != nil {
|
||||
return spoolErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := logs.WriteCursor(cfg.LogCursorPath(), nextCursor); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pickServerURL(flagValue string) string {
|
||||
if flagValue != "" {
|
||||
return flagValue
|
||||
}
|
||||
return os.Getenv("KEYWARDEN_SERVER_URL")
|
||||
}
|
||||
|
||||
func pickEnrollToken(flagValue string) string {
|
||||
if flagValue != "" {
|
||||
return flagValue
|
||||
}
|
||||
return os.Getenv("KEYWARDEN_ENROLL_TOKEN")
|
||||
}
|
||||
|
||||
func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string) error {
|
||||
if cfg.ServerID != "" && fileExists(cfg.ClientCertPath()) && fileExists(cfg.CACertPath()) {
|
||||
return nil
|
||||
}
|
||||
if enrollToken == "" {
|
||||
return fmt.Errorf("missing enrollment token; set KEYWARDEN_ENROLL_TOKEN or -enroll-token")
|
||||
}
|
||||
keyPath := cfg.ClientKeyPath()
|
||||
if !fileExists(keyPath) {
|
||||
if err := generateKey(keyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
csrPEM, err := buildCSR(keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, _ := os.Hostname()
|
||||
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
|
||||
Token: enrollToken,
|
||||
CSRPEM: csrPEM,
|
||||
Host: hostname,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(cfg.ClientCertPath(), []byte(resp.ClientCert), 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(cfg.CACertPath(), []byte(resp.CACert), 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.ServerID = resp.ServerID
|
||||
if err := config.Save(configPath, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateKey(path string) error {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyDER := x509.MarshalPKCS1PrivateKey(key)
|
||||
block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}
|
||||
data := pem.EncodeToMemory(block)
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
func buildCSR(keyPath string) (string, error) {
|
||||
keyData, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, _ := pem.Decode(keyData)
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return "", fmt.Errorf("invalid private key")
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
csrTemplate := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "keywarden-agent"}}
|
||||
csrDER, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
csrBlock := &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrDER}
|
||||
return string(pem.EncodeToMemory(csrBlock)), nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return !info.IsDir()
|
||||
}
|
||||
Reference in New Issue
Block a user