Agent retries on connection loss, sends connection info (v4 v6) Uses system CA for mTLS. Removed server endpoints.
This commit is contained in:
@@ -20,4 +20,6 @@ You can also pass `KEYWARDEN_SERVER_URL` and `KEYWARDEN_ENROLL_TOKEN` as environ
|
||||
|
||||
On first boot, the agent will create a config file if it does not exist. Only `server_url` is required for bootstrapping.
|
||||
|
||||
If the Keywarden server uses a private TLS CA, set `server_ca_path` (or `KEYWARDEN_SERVER_CA_PATH`) to the CA PEM file so the agent can verify the server certificate.
|
||||
|
||||
See `config.example.json`.
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"keywarden/agent/internal/client"
|
||||
"keywarden/agent/internal/config"
|
||||
"keywarden/agent/internal/host"
|
||||
"keywarden/agent/internal/logs"
|
||||
"keywarden/agent/internal/version"
|
||||
)
|
||||
@@ -74,11 +75,22 @@ func main() {
|
||||
}
|
||||
|
||||
func runOnce(ctx context.Context, apiClient *client.Client, cfg *config.Config) {
|
||||
if err := reportHost(ctx, apiClient, cfg); err != nil {
|
||||
if client.IsRetriable(err) {
|
||||
log.Printf("host update deferred; will retry: %v", err)
|
||||
} else {
|
||||
log.Printf("host update error: %v", err)
|
||||
}
|
||||
}
|
||||
if err := apiClient.SyncAccounts(ctx, cfg.ServerID); err != nil {
|
||||
log.Printf("sync accounts error: %v", err)
|
||||
}
|
||||
if err := shipLogs(ctx, apiClient, cfg); err != nil {
|
||||
log.Printf("log shipping error: %v", err)
|
||||
if client.IsRetriable(err) {
|
||||
log.Printf("log shipping deferred; will retry: %v", err)
|
||||
} else {
|
||||
log.Printf("log shipping error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +106,9 @@ func ensureDirs(cfg *config.Config) error {
|
||||
|
||||
func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||
send := func(payload []byte) error {
|
||||
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
||||
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
|
||||
return apiClient.SendLogBatch(ctx, cfg.ServerID, payload)
|
||||
})
|
||||
}
|
||||
if err := logs.DrainSpool(cfg.LogSpoolDir(), send); err != nil {
|
||||
return err
|
||||
@@ -128,6 +142,17 @@ func shipLogs(ctx context.Context, apiClient *client.Client, cfg *config.Config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func reportHost(ctx context.Context, apiClient *client.Client, cfg *config.Config) error {
|
||||
info := host.Detect()
|
||||
return retry(ctx, []time.Duration{250 * time.Millisecond, time.Second, 2 * time.Second}, func() error {
|
||||
return apiClient.UpdateHost(ctx, cfg.ServerID, client.HeartbeatRequest{
|
||||
Host: info.Hostname,
|
||||
IPv4: info.IPv4,
|
||||
IPv6: info.IPv6,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func pickServerURL(flagValue string) string {
|
||||
if flagValue != "" {
|
||||
return flagValue
|
||||
@@ -159,11 +184,14 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, _ := os.Hostname()
|
||||
info := host.Detect()
|
||||
hostname := info.Hostname
|
||||
resp, err := client.Enroll(context.Background(), cfg.ServerURL, client.EnrollRequest{
|
||||
Token: enrollToken,
|
||||
CSRPEM: csrPEM,
|
||||
Host: hostname,
|
||||
IPv4: info.IPv4,
|
||||
IPv6: info.IPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -181,6 +209,28 @@ func bootstrapIfNeeded(cfg *config.Config, configPath string, enrollToken string
|
||||
return nil
|
||||
}
|
||||
|
||||
func retry(ctx context.Context, delays []time.Duration, fn func() error) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= len(delays); attempt++ {
|
||||
if attempt > 0 {
|
||||
if !client.IsRetriable(lastErr) {
|
||||
return lastErr
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delays[attempt-1]):
|
||||
}
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func generateKey(path string) error {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"server_url": "https://keywarden.dev.ntbx.io/api/v1",
|
||||
"server_id": "4",
|
||||
"server_ca_path": "",
|
||||
"sync_interval_seconds": 30,
|
||||
"log_batch_size": 500,
|
||||
"state_dir": "/var/lib/keywarden-agent",
|
||||
@@ -11,4 +12,4 @@
|
||||
"create_home": true,
|
||||
"lock_on_revoke": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load client cert: %w", err)
|
||||
}
|
||||
caData, err := os.ReadFile(cfg.CACertPath())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ca cert: %w", err)
|
||||
caPool, err := x509.SystemCertPool()
|
||||
if err != nil || caPool == nil {
|
||||
caPool = x509.NewCertPool()
|
||||
}
|
||||
caPool := x509.NewCertPool()
|
||||
if !caPool.AppendCertsFromPEM(caData) {
|
||||
return nil, errors.New("parse ca cert")
|
||||
if cfg.ServerCAPath != "" {
|
||||
caData, err := os.ReadFile(cfg.ServerCAPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read server ca cert: %w", err)
|
||||
}
|
||||
if !caPool.AppendCertsFromPEM(caData) {
|
||||
return nil, errors.New("parse server ca cert")
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
@@ -63,14 +68,16 @@ type EnrollRequest struct {
|
||||
Token string `json:"token"`
|
||||
CSRPEM string `json:"csr_pem"`
|
||||
Host string `json:"host"`
|
||||
IPv4 string `json:"ipv4,omitempty"`
|
||||
IPv6 string `json:"ipv6,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
}
|
||||
|
||||
type EnrollResponse struct {
|
||||
ServerID string `json:"server_id"`
|
||||
ServerID string `json:"server_id"`
|
||||
ClientCert string `json:"client_cert_pem"`
|
||||
CACert string `json:"ca_cert_pem"`
|
||||
SyncProfile string `json:"sync_profile,omitempty"`
|
||||
SyncProfile string `json:"sync_profile,omitempty"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
}
|
||||
|
||||
@@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("log batch failed: status %s", resp.Status)
|
||||
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type HeartbeatRequest struct {
|
||||
Host string `json:"host,omitempty"`
|
||||
IPv4 string `json:"ipv4,omitempty"`
|
||||
IPv6 string `json:"ipv6,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode host update: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build host update: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send host update: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
36
agent/internal/client/errors.go
Normal file
36
agent/internal/client/errors.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type HTTPStatusError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
}
|
||||
|
||||
func (e *HTTPStatusError) Error() string {
|
||||
return "remote status " + e.Status
|
||||
}
|
||||
|
||||
func IsRetriable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var statusErr *HTTPStatusError
|
||||
if errors.As(err, &statusErr) {
|
||||
switch statusErr.StatusCode {
|
||||
case 404, 408, 429, 500, 502, 503, 504:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
@@ -29,6 +29,7 @@ type AccountPolicy struct {
|
||||
type Config struct {
|
||||
ServerURL string `json:"server_url"`
|
||||
ServerID string `json:"server_id,omitempty"`
|
||||
ServerCAPath string `json:"server_ca_path,omitempty"`
|
||||
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
|
||||
LogBatchSize int `json:"log_batch_size,omitempty"`
|
||||
StateDir string `json:"state_dir,omitempty"`
|
||||
@@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
|
||||
if serverURL == "" {
|
||||
return nil, errors.New("server url required for first boot")
|
||||
}
|
||||
cfg := &Config{ServerURL: serverURL}
|
||||
cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
|
||||
applyDefaults(cfg)
|
||||
if err := validate(cfg, false); err != nil {
|
||||
return nil, err
|
||||
@@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
|
||||
if err := json.Unmarshal(data, cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
if cfg.ServerCAPath == "" {
|
||||
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
|
||||
}
|
||||
applyDefaults(cfg)
|
||||
if err := validate(cfg, false); err != nil {
|
||||
return nil, err
|
||||
|
||||
57
agent/internal/host/host.go
Normal file
57
agent/internal/host/host.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Info struct {
|
||||
Hostname string
|
||||
IPv4 string
|
||||
IPv6 string
|
||||
}
|
||||
|
||||
func Detect() Info {
|
||||
hostname, _ := os.Hostname()
|
||||
info := Info{Hostname: hostname}
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
continue
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
if info.IPv4 == "" {
|
||||
info.IPv4 = ip4.String()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ip.To16() != nil && info.IPv6 == "" {
|
||||
info.IPv6 = ip.String()
|
||||
}
|
||||
}
|
||||
if info.IPv4 != "" && info.IPv6 != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
Binary file not shown.
Reference in New Issue
Block a user