287 lines
8.0 KiB
Go
287 lines
8.0 KiB
Go
package client
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"keywarden/agent/internal/accounts"
	"keywarden/agent/internal/config"
)
|
|
|
|
// defaultTimeout is the overall per-request limit (connect, TLS handshake,
// request write and response read) applied to the http.Client values this
// package constructs.
const defaultTimeout time.Duration = 15 * time.Second
|
|
|
|
// Client calls the agent endpoints of the Keywarden server over TLS,
// presenting the agent's client certificate. Construct it with New; the zero
// value has no HTTP client and is not usable.
type Client struct {
	// baseURL is the server URL without trailing slashes; endpoint paths that
	// begin with "/agent/..." are appended to it verbatim.
	baseURL string

	// http is preconfigured with the client certificate, the trusted CA pool
	// and the request timeout.
	http *http.Client
}
|
|
|
|
func New(cfg *config.Config) (*Client, error) {
|
|
baseURL := strings.TrimRight(cfg.ServerURL, "/")
|
|
if baseURL == "" {
|
|
return nil, errors.New("server url is required")
|
|
}
|
|
cert, err := tls.LoadX509KeyPair(cfg.ClientCertPath(), cfg.ClientKeyPath())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load client cert: %w", err)
|
|
}
|
|
caPool, err := x509.SystemCertPool()
|
|
if err != nil || caPool == nil {
|
|
caPool = x509.NewCertPool()
|
|
}
|
|
if cfg.ServerCAPath != "" {
|
|
caData, err := os.ReadFile(cfg.ServerCAPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read server ca cert: %w", err)
|
|
}
|
|
if !caPool.AppendCertsFromPEM(caData) {
|
|
return nil, errors.New("parse server ca cert")
|
|
}
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
RootCAs: caPool,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
transport := &http.Transport{
|
|
TLSClientConfig: tlsConfig,
|
|
}
|
|
|
|
httpClient := &http.Client{
|
|
Timeout: defaultTimeout,
|
|
Transport: transport,
|
|
}
|
|
|
|
return &Client{baseURL: baseURL, http: httpClient}, nil
|
|
}
|
|
|
|
// Request and response bodies exchanged with the Keywarden server. The JSON
// tags define the wire format. encoding/json emits keys in field declaration
// order, so field order also fixes the key order of encoded bodies.
type (
	// EnrollRequest is the JSON body Enroll POSTs to /agent/enroll.
	EnrollRequest struct {
		Token   string `json:"token"`
		CSRPEM  string `json:"csr_pem"` // certificate signing request, PEM-encoded
		Host    string `json:"host"`
		IPv4    string `json:"ipv4,omitempty"`
		IPv6    string `json:"ipv6,omitempty"`
		AgentID string `json:"agent_id,omitempty"`
	}

	// EnrollResponse is decoded from a successful /agent/enroll reply.
	// Enroll rejects replies missing ServerID, ClientCert or CACert.
	EnrollResponse struct {
		ServerID    string `json:"server_id"`
		ClientCert  string `json:"client_cert_pem"` // PEM-encoded
		CACert      string `json:"ca_cert_pem"`     // PEM-encoded
		SyncProfile string `json:"sync_profile,omitempty"`
		DisplayName string `json:"display_name,omitempty"`
	}

	// AccountKey is one public key attached to an AccountAccess entry.
	AccountKey struct {
		PublicKey   string `json:"public_key"`
		Fingerprint string `json:"fingerprint"`
	}

	// AccountAccess is one element of the list FetchAccountAccess decodes from
	// GET /agent/servers/{id}/accounts. SyncAccounts forwards the trimmed,
	// non-blank PublicKey values of Keys to accounts.Sync.
	AccountAccess struct {
		UserID   int          `json:"user_id"`
		Username string       `json:"username"`
		Email    string       `json:"email"`
		Keys     []AccountKey `json:"keys"`
	}

	// AccountSyncEntry describes one account in a sync report. SyncAccounts
	// fills it from each account returned by accounts.Sync.
	AccountSyncEntry struct {
		UserID         int    `json:"user_id"`
		SystemUsername string `json:"system_username"`
		Present        bool   `json:"present"`
	}

	// SyncReportRequest is the JSON body SendSyncReport POSTs to
	// /agent/servers/{id}/sync-report. Message carries the sync error text, if
	// any. Empty optional fields are omitted from the encoding.
	SyncReportRequest struct {
		AppliedCount int                `json:"applied_count"`
		RevokedCount int                `json:"revoked_count"`
		Message      string             `json:"message,omitempty"`
		Metadata     map[string]any     `json:"metadata,omitempty"`
		Accounts     []AccountSyncEntry `json:"accounts,omitempty"`
	}
)
|
|
|
|
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
|
|
baseURL := strings.TrimRight(serverURL, "/")
|
|
if baseURL == "" {
|
|
return nil, errors.New("server url is required")
|
|
}
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("encode enroll request: %w", err)
|
|
}
|
|
httpClient := &http.Client{Timeout: defaultTimeout}
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/agent/enroll", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build enroll request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
resp, err := httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("enroll request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("enroll failed: status %s", resp.Status)
|
|
}
|
|
var out EnrollResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode enroll response: %w", err)
|
|
}
|
|
if out.ServerID == "" || out.ClientCert == "" || out.CACert == "" {
|
|
return nil, errors.New("enroll response missing required fields")
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func (c *Client) SyncAccounts(ctx context.Context, cfg *config.Config) error {
|
|
if cfg == nil {
|
|
return errors.New("config required for account sync")
|
|
}
|
|
users, err := c.FetchAccountAccess(ctx, cfg.ServerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
accessUsers := make([]accounts.AccessUser, 0, len(users))
|
|
for _, user := range users {
|
|
keys := make([]string, 0, len(user.Keys))
|
|
for _, key := range user.Keys {
|
|
if strings.TrimSpace(key.PublicKey) == "" {
|
|
continue
|
|
}
|
|
keys = append(keys, strings.TrimSpace(key.PublicKey))
|
|
}
|
|
accessUsers = append(accessUsers, accounts.AccessUser{
|
|
UserID: user.UserID,
|
|
Username: user.Username,
|
|
Email: user.Email,
|
|
Keys: keys,
|
|
})
|
|
}
|
|
result, syncErr := accounts.Sync(cfg.AccountPolicy, cfg.StateDir, accessUsers)
|
|
report := SyncReportRequest{
|
|
AppliedCount: result.Applied,
|
|
RevokedCount: result.Revoked,
|
|
Accounts: make([]AccountSyncEntry, 0, len(result.Accounts)),
|
|
}
|
|
for _, account := range result.Accounts {
|
|
report.Accounts = append(report.Accounts, AccountSyncEntry{
|
|
UserID: account.UserID,
|
|
SystemUsername: account.SystemUser,
|
|
Present: account.Present,
|
|
})
|
|
}
|
|
if syncErr != nil {
|
|
report.Message = syncErr.Error()
|
|
}
|
|
if err := c.SendSyncReport(ctx, cfg.ServerID, report); err != nil {
|
|
if syncErr != nil {
|
|
return fmt.Errorf("sync report failed: %w (sync error: %v)", err, syncErr)
|
|
}
|
|
return err
|
|
}
|
|
return syncErr
|
|
}
|
|
|
|
func (c *Client) FetchAccountAccess(ctx context.Context, serverID string) ([]AccountAccess, error) {
|
|
req, err := http.NewRequestWithContext(
|
|
ctx,
|
|
http.MethodGet,
|
|
c.baseURL+"/agent/servers/"+serverID+"/accounts",
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build account access request: %w", err)
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetch account access: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
|
}
|
|
var out []AccountAccess
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return nil, fmt.Errorf("decode account access: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (c *Client) SendSyncReport(ctx context.Context, serverID string, report SyncReportRequest) error {
|
|
body, err := json.Marshal(report)
|
|
if err != nil {
|
|
return fmt.Errorf("encode sync report: %w", err)
|
|
}
|
|
req, err := http.NewRequestWithContext(
|
|
ctx,
|
|
http.MethodPost,
|
|
c.baseURL+"/agent/servers/"+serverID+"/sync-report",
|
|
bytes.NewReader(body),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("build sync report: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("send sync report: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) SendLogBatch(ctx context.Context, serverID string, payload []byte) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/logs", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return fmt.Errorf("build log request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("send log batch: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HeartbeatRequest is the JSON body UpdateHost POSTs to the heartbeat
// endpoint. Every field is optional, and empty values are left out of the
// encoded body.
type HeartbeatRequest struct {
	Host string `json:"host,omitempty"` // host name
	IPv4 string `json:"ipv4,omitempty"` // IPv4 address
	IPv6 string `json:"ipv6,omitempty"` // IPv6 address
}
|
|
|
|
func (c *Client) UpdateHost(ctx context.Context, serverID string, reqBody HeartbeatRequest) error {
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return fmt.Errorf("encode host update: %w", err)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/agent/servers/"+serverID+"/heartbeat", bytes.NewReader(body))
|
|
if err != nil {
|
|
return fmt.Errorf("build host update: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("send host update: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 300 {
|
|
return &HTTPStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
|
|
}
|
|
return nil
|
|
}
|