Files
elmprodvpn/selective-vpn-api/app/transport_policy_idempotency_state.go

210 lines
5.9 KiB
Go

package app
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// Tunables and scope names for the transport-policy idempotency store.
const (
	// transportPolicyIdempotencyTTL is the age beyond which a stored record
	// counts as expired.
	transportPolicyIdempotencyTTL = 24 * time.Hour
	// transportPolicyIdempotencyMaxItems caps the number of records kept
	// after normalization.
	transportPolicyIdempotencyMaxItems = 256
	// Scope strings; presumably passed as the scope argument by the apply
	// and rollback handlers (callers are outside this section).
	transportPolicyIdempotencyApplyScope    = "transport_policy_apply"
	transportPolicyIdempotencyRollbackScope = "transport_policy_rollback"
)
// transportPolicyIdempotencyRecord is one persisted idempotency entry.
//
// Scope and Key together identify the entry, RequestHash is the hash of the
// request the entry was created for, Response is the response stored for
// replay, and CreatedAt is an RFC 3339 timestamp used for expiry and
// ordering. Field order and JSON tags define the on-disk layout.
type transportPolicyIdempotencyRecord struct {
	Key         string                  `json:"key"`
	Scope       string                  `json:"scope"`
	RequestHash string                  `json:"request_hash,omitempty"`
	Response    TransportPolicyResponse `json:"response"`
	CreatedAt   string                  `json:"created_at,omitempty"`
}
// transportPolicyIdempotencyState is the JSON document persisted at the
// idempotency state path: a format version, the RFC 3339 time of the last
// save, and the stored records.
type transportPolicyIdempotencyState struct {
	Version   int                                `json:"version"`
	UpdatedAt string                             `json:"updated_at,omitempty"`
	Items     []transportPolicyIdempotencyRecord `json:"items,omitempty"`
}
// transportPolicyIdempotencyLookup is the outcome of an idempotency lookup.
//
// Replay is set when a record with the same scope, key and request hash
// exists; Response then holds the stored response. Conflict is set when the
// scope and key match but the request hash differs; Response then holds an
// error response. Both flags are false when nothing matched.
type transportPolicyIdempotencyLookup struct {
	Replay   bool
	Conflict bool
	Response TransportPolicyResponse
}
// normalizeTransportIdempotencyKey returns raw with leading and trailing
// white space removed. The result may be empty.
func normalizeTransportIdempotencyKey(raw string) string {
	key := strings.TrimSpace(raw)
	return key
}
// hashTransportPolicyMutationRequest returns the lowercase hex SHA-256 digest
// of the JSON encoding of v.
//
// If v cannot be JSON-encoded the result is the empty string. Hashing the
// empty input instead would make every unencodable payload share one digest
// and look like a replay of any other. The lookup and save helpers in this
// file skip idempotency handling when the request hash is empty.
func hashTransportPolicyMutationRequest(v any) string {
	data, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}
// loadTransportPolicyIdempotencyState reads the idempotency state file and
// returns its normalized contents.
//
// A missing, unreadable or undecodable file yields an empty state stamped
// with the current state version. When normalization reports a change, the
// cleaned state is written back on a best-effort basis.
func loadTransportPolicyIdempotencyState() transportPolicyIdempotencyState {
	raw, readErr := os.ReadFile(transportPolicyIdempotencyStatePath)
	if readErr != nil {
		return transportPolicyIdempotencyState{Version: transportStateVersion}
	}

	loaded := transportPolicyIdempotencyState{Version: transportStateVersion}
	if decodeErr := json.Unmarshal(raw, &loaded); decodeErr != nil {
		// Discard whatever was partially decoded.
		return transportPolicyIdempotencyState{Version: transportStateVersion}
	}
	if loaded.Version == 0 {
		loaded.Version = transportStateVersion
	}

	cleaned, dirty := normalizeTransportPolicyIdempotencyState(loaded, time.Now().UTC())
	if dirty {
		// Best effort: a failed write-back must not fail the load.
		_ = saveTransportPolicyIdempotencyState(cleaned)
	}
	return cleaned
}
// saveTransportPolicyIdempotencyState writes st to the idempotency state
// path as indented JSON.
//
// Version and UpdatedAt are stamped on the local copy before encoding. The
// parent directory is created if needed. The data is written to a sibling
// ".tmp" file and then renamed over the target, so the target is replaced in
// one step rather than overwritten in place. If the write or the rename
// fails, the temporary file is removed so it does not linger on disk.
//
// It returns the first error from encoding, directory creation, writing or
// renaming.
func saveTransportPolicyIdempotencyState(st transportPolicyIdempotencyState) error {
	st.Version = transportStateVersion
	st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	data, err := json.MarshalIndent(st, "", " ")
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(transportPolicyIdempotencyStatePath), 0o755); err != nil {
		return err
	}
	tmp := transportPolicyIdempotencyStatePath + ".tmp"
	if err := os.WriteFile(tmp, data, 0o644); err != nil {
		// The file may exist with partial content; cleanup is best effort.
		_ = os.Remove(tmp)
		return err
	}
	if err := os.Rename(tmp, transportPolicyIdempotencyStatePath); err != nil {
		// Cleanup is best effort; the rename error is what the caller needs.
		_ = os.Remove(tmp)
		return err
	}
	return nil
}
// normalizeTransportPolicyIdempotencyState returns a cleaned copy of st and
// reports whether cleaning altered the record set.
//
// Normalization does the following:
//   - stamps the current state version;
//   - trims Key, Scope and RequestHash of every record;
//   - drops records where any of those three is empty after trimming;
//   - drops records that are expired relative to now;
//   - orders records newest first by their CreatedAt string (records with an
//     empty CreatedAt sort last);
//   - keeps at most transportPolicyIdempotencyMaxItems records.
//
// The sort is stable, so records with the same CreatedAt keep their relative
// order and the choice of which ones fall off the end at the size cap is
// deterministic.
//
// The returned flag is true when a record was dropped, when trimming changed
// a kept record, or when the list was truncated. Reordering alone does not
// set it.
func normalizeTransportPolicyIdempotencyState(st transportPolicyIdempotencyState, now time.Time) (transportPolicyIdempotencyState, bool) {
	changed := false
	st.Version = transportStateVersion
	out := make([]transportPolicyIdempotencyRecord, 0, len(st.Items))
	for _, raw := range st.Items {
		rec := raw
		rec.Key = normalizeTransportIdempotencyKey(rec.Key)
		rec.Scope = strings.TrimSpace(rec.Scope)
		rec.RequestHash = strings.TrimSpace(rec.RequestHash)
		if rec.Key == "" || rec.Scope == "" || rec.RequestHash == "" {
			changed = true
			continue
		}
		if transportPolicyIdempotencyExpired(rec, now) {
			changed = true
			continue
		}
		// A kept record whose fields were trimmed differs from what was
		// loaded, so the caller should persist the cleaned form.
		if rec.Key != raw.Key || rec.Scope != raw.Scope || rec.RequestHash != raw.RequestHash {
			changed = true
		}
		out = append(out, rec)
	}
	// Timestamps are compared as strings; equal ones keep their input order.
	sort.SliceStable(out, func(i, j int) bool {
		return strings.TrimSpace(out[i].CreatedAt) > strings.TrimSpace(out[j].CreatedAt)
	})
	if len(out) > transportPolicyIdempotencyMaxItems {
		out = out[:transportPolicyIdempotencyMaxItems]
		changed = true
	}
	st.Items = out
	return st, changed
}
// transportPolicyIdempotencyExpired reports whether rec is older than
// transportPolicyIdempotencyTTL at the instant now.
//
// It returns false when the TTL is not positive, or when CreatedAt is empty
// or is not a valid RFC 3339 timestamp, so such records never expire here.
func transportPolicyIdempotencyExpired(rec transportPolicyIdempotencyRecord, now time.Time) bool {
	stamp := strings.TrimSpace(rec.CreatedAt)
	if transportPolicyIdempotencyTTL <= 0 || stamp == "" {
		return false
	}
	createdAt, err := time.Parse(time.RFC3339, stamp)
	if err != nil {
		return false
	}
	age := now.Sub(createdAt)
	return age > transportPolicyIdempotencyTTL
}
// lookupTransportPolicyIdempotencyLocked checks the persisted idempotency
// state for a record matching scope and key.
//
// All three arguments are trimmed first; if any is empty the zero lookup is
// returned without reading state. The first record with the same scope and
// key decides the outcome: an equal request hash yields Replay with the
// stored response, a different hash yields Conflict with an
// IDEMPOTENCY_KEY_REUSED error response. With no matching record the zero
// lookup is returned.
//
// NOTE(review): the "Locked" suffix suggests the caller must hold a lock
// guarding the state file; no lock is taken here. Confirm against callers.
func lookupTransportPolicyIdempotencyLocked(scope, key, requestHash string) transportPolicyIdempotencyLookup {
	var miss transportPolicyIdempotencyLookup

	key = normalizeTransportIdempotencyKey(key)
	scope = strings.TrimSpace(scope)
	requestHash = strings.TrimSpace(requestHash)
	if key == "" || scope == "" || requestHash == "" {
		return miss
	}

	state := loadTransportPolicyIdempotencyState()
	for i := range state.Items {
		stored := &state.Items[i]
		if stored.Scope != scope || stored.Key != key {
			continue
		}
		if stored.RequestHash != requestHash {
			// Same key, different payload.
			return transportPolicyIdempotencyLookup{
				Conflict: true,
				Response: TransportPolicyResponse{
					OK:      false,
					Message: "idempotency key already used for different request payload",
					Code:    "IDEMPOTENCY_KEY_REUSED",
				},
			}
		}
		return transportPolicyIdempotencyLookup{
			Replay:   true,
			Response: stored.Response,
		}
	}
	return miss
}
// saveTransportPolicyIdempotencyLocked stores resp under (scope, key) with
// the given request hash and persists the updated state.
//
// All three string arguments are trimmed first; if any is empty nothing is
// stored and nil is returned. An existing record with the same scope and key
// is overwritten in place, otherwise the record is appended. The state is
// normalized before and after the update, then saved. The returned error is
// the one from saving.
//
// NOTE(review): the "Locked" suffix suggests the caller must hold a lock
// guarding the state file; no lock is taken here. Confirm against callers.
func saveTransportPolicyIdempotencyLocked(scope, key, requestHash string, resp TransportPolicyResponse) error {
	key = normalizeTransportIdempotencyKey(key)
	scope = strings.TrimSpace(scope)
	requestHash = strings.TrimSpace(requestHash)
	if key == "" || scope == "" || requestHash == "" {
		return nil
	}

	now := time.Now().UTC()
	state, _ := normalizeTransportPolicyIdempotencyState(loadTransportPolicyIdempotencyState(), now)

	entry := transportPolicyIdempotencyRecord{
		Key:         key,
		Scope:       scope,
		RequestHash: requestHash,
		Response:    resp,
		CreatedAt:   now.Format(time.RFC3339),
	}

	// Find the first record that already uses this scope and key.
	idx := -1
	for i := range state.Items {
		if state.Items[i].Scope == scope && state.Items[i].Key == key {
			idx = i
			break
		}
	}
	if idx >= 0 {
		state.Items[idx] = entry
	} else {
		state.Items = append(state.Items, entry)
	}

	// Re-normalize so the new record is ordered and the size cap applied.
	state, _ = normalizeTransportPolicyIdempotencyState(state, now)
	return saveTransportPolicyIdempotencyState(state)
}
// persistTransportPolicyIdempotencyLocked saves the idempotency record and
// never fails the caller. A save error is passed to
// appendTraceLineRateLimited under the "transport" tag with a 5-second
// interval argument.
func persistTransportPolicyIdempotencyLocked(scope, key, requestHash string, resp TransportPolicyResponse) {
	err := saveTransportPolicyIdempotencyLocked(scope, key, requestHash, resp)
	if err == nil {
		return
	}
	msg := "policy idempotency save warning: " + err.Error()
	appendTraceLineRateLimited("transport", msg, 5*time.Second)
}