Agent retries on connection loss, sends connection info (v4 v6) Uses system CA for mTLS. Removed server endpoints.

This commit is contained in:
2026-01-26 01:13:51 +00:00
parent e7d20360a2
commit 69802f3ece
11 changed files with 278 additions and 80 deletions

View File

@@ -32,13 +32,18 @@ func New(cfg *config.Config) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("load client cert: %w", err)
}
caData, err := os.ReadFile(cfg.CACertPath())
if err != nil {
return nil, fmt.Errorf("read ca cert: %w", err)
caPool, err := x509.SystemCertPool()
if err != nil || caPool == nil {
caPool = x509.NewCertPool()
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse ca cert")
if cfg.ServerCAPath != "" {
caData, err := os.ReadFile(cfg.ServerCAPath)
if err != nil {
return nil, fmt.Errorf("read server ca cert: %w", err)
}
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("parse server ca cert")
}
}
tlsConfig := &tls.Config{
@@ -63,14 +68,16 @@ type EnrollRequest struct {
Token string `json:"token"`
CSRPEM string `json:"csr_pem"`
Host string `json:"host"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
AgentID string `json:"agent_id,omitempty"`
}
type EnrollResponse struct {
ServerID string `json:"server_id"`
ServerID string `json:"server_id"`
ClientCert string `json:"client_cert_pem"`
CACert string `json:"ca_cert_pem"`
SyncProfile string `json:"sync_profile,omitempty"`
SyncProfile string `json:"sync_profile,omitempty"`
DisplayName string `json:"display_name,omitempty"`
}
@@ -126,7 +133,34 @@ func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []by
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return fmt.Errorf("log batch failed: status %s", resp.Status)
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}
type HeartbeatRequest struct {
Host string `json:"host,omitempty"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
}
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
body, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("encode host update: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build host update: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("send host update: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
}
return nil
}

View File

@@ -0,0 +1,36 @@
package client
import (
"context"
"errors"
"net"
)
type HTTPStatusError struct {
StatusCode int
Status string
}
func (e *HTTPStatusError) Error() string {
return "remote status " + e.Status
}
func IsRetriable(err error) bool {
if err == nil {
return false
}
var statusErr *HTTPStatusError
if errors.As(err, &statusErr) {
switch statusErr.StatusCode {
case 404, 408, 429, 500, 502, 503, 504:
return true
default:
return false
}
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr)
}

View File

@@ -29,6 +29,7 @@ type AccountPolicy struct {
type Config struct {
ServerURL string `json:"server_url"`
ServerID string `json:"server_id,omitempty"`
ServerCAPath string `json:"server_ca_path,omitempty"`
SyncIntervalSeconds int `json:"sync_interval_seconds,omitempty"`
LogBatchSize int `json:"log_batch_size,omitempty"`
StateDir string `json:"state_dir,omitempty"`
@@ -47,7 +48,7 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if serverURL == "" {
return nil, errors.New("server url required for first boot")
}
cfg := &Config{ServerURL: serverURL}
cfg := &Config{ServerURL: serverURL, ServerCAPath: os.Getenv("KEYWARDEN_SERVER_CA_PATH")}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err
@@ -61,6 +62,9 @@ func LoadOrInit(path string, serverURL string) (*Config, error) {
if err := json.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.ServerCAPath == "" {
cfg.ServerCAPath = os.Getenv("KEYWARDEN_SERVER_CA_PATH")
}
applyDefaults(cfg)
if err := validate(cfg, false); err != nil {
return nil, err

View File

@@ -0,0 +1,57 @@
package host
import (
"net"
"os"
)
type Info struct {
Hostname string
IPv4 string
IPv6 string
}
func Detect() Info {
hostname, _ := os.Hostname()
info := Info{Hostname: hostname}
ifaces, err := net.Interfaces()
if err != nil {
return info
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip == nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
if ip4 := ip.To4(); ip4 != nil {
if info.IPv4 == "" {
info.IPv4 = ip4.String()
}
continue
}
if ip.To16() != nil && info.IPv6 == "" {
info.IPv6 = ip.String()
}
}
if info.IPv4 != "" && info.IPv6 != "" {
break
}
}
return info
}