Commit b55cee4d authored by AndreasHuber-CH's avatar AndreasHuber-CH Committed by GitHub

plugin/rewrite: Allow configuring min and max TTL values when rewriting TTL (#5508)

parent 1a31b35b
...@@ -311,7 +311,27 @@ The syntax for the TTL rewrite rule is as follows. The meaning of ...@@ -311,7 +311,27 @@ The syntax for the TTL rewrite rule is as follows. The meaning of
An omitted type is defaulted to `exact`. An omitted type is defaulted to `exact`.
``` ```
rewrite [continue|stop] ttl [exact|prefix|suffix|substring|regex] STRING SECONDS rewrite [continue|stop] ttl [exact|prefix|suffix|substring|regex] STRING [SECONDS|MIN-MAX]
```
It is possible to supply a range of TTL values in the `SECONDS` parameters instead of a single value.
If a range is supplied, the TTL value is set to `MIN` if it is below, or set to `MAX` if it is above.
The TTL value is left unchanged if it is already inside the provided range.
The ranges can be unbounded on either side.
TTL examples with ranges:
```
# rewrite TTL to be between 30s and 300s
rewrite ttl example.com. 30-300
# cap TTL at 30s
rewrite ttl example.com. -30 # equivalent to rewrite ttl example.com. 0-30
# increase TTL to a minimum of 30s
rewrite ttl example.com. 30-
# set TTL to 30s
rewrite ttl example.com. 30 # equivalent to rewrite ttl example.com. 30-30
``` ```
## EDNS0 Options ## EDNS0 Options
......
...@@ -14,11 +14,16 @@ import ( ...@@ -14,11 +14,16 @@ import (
) )
type ttlResponseRule struct { type ttlResponseRule struct {
TTL uint32 minTTL uint32
maxTTL uint32
} }
func (r *ttlResponseRule) RewriteResponse(rr dns.RR) { func (r *ttlResponseRule) RewriteResponse(rr dns.RR) {
rr.Header().Ttl = r.TTL if rr.Header().Ttl < r.minTTL {
rr.Header().Ttl = r.minTTL
} else if rr.Header().Ttl > r.maxTTL {
rr.Header().Ttl = r.maxTTL
}
} }
type ttlRuleBase struct { type ttlRuleBase struct {
...@@ -26,10 +31,10 @@ type ttlRuleBase struct { ...@@ -26,10 +31,10 @@ type ttlRuleBase struct {
response ttlResponseRule response ttlResponseRule
} }
func newTTLRuleBase(nextAction string, ttl uint32) ttlRuleBase { func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase {
return ttlRuleBase{ return ttlRuleBase{
nextAction: nextAction, nextAction: nextAction,
response: ttlResponseRule{TTL: ttl}, response: ttlResponseRule{minTTL: minTtl, maxTTL: maxTtl},
} }
} }
...@@ -108,7 +113,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { ...@@ -108,7 +113,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
if len(args) == 3 { if len(args) == 3 {
s = args[2] s = args[2]
} }
ttl, valid := isValidTTL(s) minTtl, maxTtl, valid := isValidTTL(s)
if !valid { if !valid {
return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s) return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s)
} }
...@@ -116,22 +121,22 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { ...@@ -116,22 +121,22 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
switch strings.ToLower(args[0]) { switch strings.ToLower(args[0]) {
case ExactMatch: case ExactMatch:
return &exactTTLRule{ return &exactTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(), plugin.Name(args[1]).Normalize(),
}, nil }, nil
case PrefixMatch: case PrefixMatch:
return &prefixTTLRule{ return &prefixTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(), plugin.Name(args[1]).Normalize(),
}, nil }, nil
case SuffixMatch: case SuffixMatch:
return &suffixTTLRule{ return &suffixTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(), plugin.Name(args[1]).Normalize(),
}, nil }, nil
case SubstringMatch: case SubstringMatch:
return &substringTTLRule{ return &substringTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(), plugin.Name(args[1]).Normalize(),
}, nil }, nil
case RegexMatch: case RegexMatch:
...@@ -140,7 +145,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { ...@@ -140,7 +145,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1]) return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1])
} }
return &regexTTLRule{ return &regexTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
regexPattern, regexPattern,
}, nil }, nil
default: default:
...@@ -151,22 +156,50 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { ...@@ -151,22 +156,50 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
return nil, fmt.Errorf("many few arguments for a ttl rule") return nil, fmt.Errorf("many few arguments for a ttl rule")
} }
return &exactTTLRule{ return &exactTTLRule{
newTTLRuleBase(nextAction, ttl), newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[0]).Normalize(), plugin.Name(args[0]).Normalize(),
}, nil }, nil
} }
// validTTL returns true if v is valid TTL value. // validTTL returns true if v is valid TTL value.
func isValidTTL(v string) (uint32, bool) { func isValidTTL(v string) (uint32, uint32, bool) {
i, err := strconv.Atoi(v) s := strings.Split(v, "-")
if err != nil { if len(s) == 1 {
return uint32(0), false i, err := strconv.ParseUint(s[0], 10, 32)
} if err != nil {
if i > 2147483647 { return 0, 0, false
return uint32(0), false }
return uint32(i), uint32(i), true
} }
if i < 0 { if len(s) == 2 {
return uint32(0), false var min, max uint64
var err error
if s[0] == "" {
min = 0
} else {
min, err = strconv.ParseUint(s[0], 10, 32)
if err != nil {
return 0, 0, false
}
}
if s[1] == "" {
if s[0] == "" {
// explicitly reject ttl directive "-" that would otherwise be interpreted
// as 0-2147483647 which is pretty useless
return 0, 0, false
}
max = 2147483647
} else {
max, err = strconv.ParseUint(s[1], 10, 32)
if err != nil {
return 0, 0, false
}
}
if min > max {
// reject invalid range
return 0, 0, false
}
return uint32(min), uint32(max), true
} }
return uint32(i), true return 0, 0, false
} }
...@@ -32,7 +32,14 @@ func TestNewTTLRule(t *testing.T) { ...@@ -32,7 +32,14 @@ func TestNewTTLRule(t *testing.T) {
{"continue", []string{"regex", `(srv1)\.(coredns)\.(rocks)`, "35"}, false}, {"continue", []string{"regex", `(srv1)\.(coredns)\.(rocks)`, "35"}, false},
{"stop", []string{"srv1.coredns.rocks", "12345678901234567890"}, true}, {"stop", []string{"srv1.coredns.rocks", "12345678901234567890"}, true},
{"stop", []string{"srv1.coredns.rocks", "coredns.rocks"}, true}, {"stop", []string{"srv1.coredns.rocks", "coredns.rocks"}, true},
{"stop", []string{"srv1.coredns.rocks", "-1"}, true}, {"stop", []string{"srv1.coredns.rocks", "#1"}, true},
{"stop", []string{"range.coredns.rocks", "1-2"}, false},
{"stop", []string{"ceil.coredns.rocks", "-2"}, false},
{"stop", []string{"floor.coredns.rocks", "1-"}, false},
{"stop", []string{"range.coredns.rocks", "2-2"}, false},
{"stop", []string{"invalid.coredns.rocks", "-"}, true},
{"stop", []string{"invalid.coredns.rocks", "2-1"}, true},
{"stop", []string{"invalid.coredns.rocks", "5-10-20"}, true},
} }
for i, tc := range tests { for i, tc := range tests {
failed := false failed := false
...@@ -78,6 +85,9 @@ func TestTtlRewrite(t *testing.T) { ...@@ -78,6 +85,9 @@ func TestTtlRewrite(t *testing.T) {
{[]string{"stop", "ttl", "substring", "rv50", "50"}, reflect.TypeOf(&substringTTLRule{})}, {[]string{"stop", "ttl", "substring", "rv50", "50"}, reflect.TypeOf(&substringTTLRule{})},
{[]string{"stop", "ttl", "regex", `(srv10)\.(coredns)\.(rocks)`, "10"}, reflect.TypeOf(&regexTTLRule{})}, {[]string{"stop", "ttl", "regex", `(srv10)\.(coredns)\.(rocks)`, "10"}, reflect.TypeOf(&regexTTLRule{})},
{[]string{"stop", "ttl", "regex", `(srv20)\.(coredns)\.(rocks)`, "20"}, reflect.TypeOf(&regexTTLRule{})}, {[]string{"stop", "ttl", "regex", `(srv20)\.(coredns)\.(rocks)`, "20"}, reflect.TypeOf(&regexTTLRule{})},
{[]string{"stop", "ttl", "range.example.com.", "30-300"}, reflect.TypeOf(&exactTTLRule{})},
{[]string{"stop", "ttl", "ceil.example.com.", "-11"}, reflect.TypeOf(&exactTTLRule{})},
{[]string{"stop", "ttl", "floor.example.com.", "5-"}, reflect.TypeOf(&exactTTLRule{})},
} }
for i, r := range ruleset { for i, r := range ruleset {
rule, err := newRule(r.args...) rule, err := newRule(r.args...)
...@@ -112,6 +122,13 @@ func doTTLTests(rules []Rule, t *testing.T) { ...@@ -112,6 +122,13 @@ func doTTLTests(rules []Rule, t *testing.T) {
test.A("srv20.coredns.rocks. 5 IN A 10.0.0.22"), test.A("srv20.coredns.rocks. 5 IN A 10.0.0.22"),
test.A("srv20.coredns.rocks. 5 IN A 10.0.0.23"), test.A("srv20.coredns.rocks. 5 IN A 10.0.0.23"),
}, 20, false}, }, 20, false},
{"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 5 IN A 10.0.0.1")}, 30, false},
{"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 55 IN A 10.0.0.1")}, 55, false},
{"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 500 IN A 10.0.0.1")}, 300, false},
{"ceil.example.com.", dns.TypeA, []dns.RR{test.A("ceil.example.com. 5 IN A 10.0.0.1")}, 5, false},
{"ceil.example.com.", dns.TypeA, []dns.RR{test.A("ceil.example.com. 15 IN A 10.0.0.1")}, 11, false},
{"floor.example.com.", dns.TypeA, []dns.RR{test.A("floor.example.com. 0 IN A 10.0.0.1")}, 5, false},
{"floor.example.com.", dns.TypeA, []dns.RR{test.A("floor.example.com. 30 IN A 10.0.0.1")}, 30, false},
} }
ctx := context.TODO() ctx := context.TODO()
for i, tc := range tests { for i, tc := range tests {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment