mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-23 02:24:10 +00:00
hscontrol/policy/v2: define sentinel errors
Add comprehensive sentinel errors for all error conditions in the policy
engine and use consistent error wrapping patterns with fmt.Errorf("%w: ...).
Update test expectations to match the new error message formats.
This commit is contained in:
parent
144c79aedf
commit
25fdad3949
5 changed files with 213 additions and 168 deletions
|
|
@ -1,7 +1,6 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
|
@ -14,8 +13,6 @@ import (
|
|||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
var ErrInvalidAction = errors.New("invalid action")
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) compileFilterRules(
|
||||
|
|
@ -149,7 +146,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
|
||||
for _, src := range acl.Sources {
|
||||
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return nil, errors.New("autogroup:self cannot be used in sources")
|
||||
return nil, ErrAutogroupSelfInSource
|
||||
}
|
||||
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
|
|
|
|||
|
|
@ -36,6 +36,73 @@ var ErrCircularReference = errors.New("circular reference detected")
|
|||
|
||||
var ErrUndefinedTagReference = errors.New("references undefined tag")
|
||||
|
||||
// Sentinel errors for type/alias validation.
|
||||
var (
|
||||
ErrUnknownAliasType = errors.New("unknown alias type")
|
||||
ErrUnknownOwnerType = errors.New("unknown owner type")
|
||||
ErrUnknownAutoApproverType = errors.New("unknown auto approver type")
|
||||
ErrInvalidAlias = errors.New("invalid alias")
|
||||
ErrInvalidAutoApprover = errors.New("invalid auto approver")
|
||||
ErrInvalidOwner = errors.New("invalid owner")
|
||||
)
|
||||
|
||||
// Sentinel errors for format validation.
|
||||
var (
|
||||
ErrUsernameMissingAt = errors.New("username must contain @")
|
||||
ErrGroupMissingPrefix = errors.New("group must start with 'group:'")
|
||||
ErrTagMissingPrefix = errors.New("tag must start with 'tag:'")
|
||||
ErrInvalidHostname = errors.New("invalid hostname")
|
||||
ErrInvalidPrefix = errors.New("invalid prefix")
|
||||
ErrInvalidAutoGroup = errors.New("invalid autogroup")
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
ErrInvalidSSHAction = errors.New("invalid SSH action")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol")
|
||||
ErrProtocolOutOfRange = errors.New("protocol number out of range")
|
||||
ErrLeadingZeroProtocol = errors.New("leading zero not permitted in protocol number")
|
||||
ErrHostportMissingColon = errors.New("hostport must contain a colon")
|
||||
ErrUnsupportedType = errors.New("unsupported type")
|
||||
)
|
||||
|
||||
// Sentinel errors for resolution/lookup failures.
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrMultipleUsersFound = errors.New("multiple users found")
|
||||
ErrHostNotResolved = errors.New("unable to resolve host")
|
||||
ErrGroupNotDefined = errors.New("group not defined in policy")
|
||||
ErrTagNotDefined = errors.New("tag not defined in policy")
|
||||
ErrHostNotDefined = errors.New("host not defined in policy")
|
||||
ErrInvalidIPAddress = errors.New("invalid IP address")
|
||||
ErrNestedGroups = errors.New("nested groups not allowed")
|
||||
ErrInvalidGroupMember = errors.New("invalid group member type")
|
||||
ErrGroupValueNotArray = errors.New("group value must be an array")
|
||||
ErrAutoApproverNotAlias = errors.New("auto approver is not an alias")
|
||||
)
|
||||
|
||||
// Sentinel errors for autogroup context validation.
|
||||
var (
|
||||
ErrAutogroupInternetInSource = errors.New("autogroup:internet can only be used in ACL destinations")
|
||||
ErrAutogroupSelfInSource = errors.New("autogroup:self can only be used in ACL destinations")
|
||||
ErrAutogroupNotSupportedSource = errors.New("autogroup not supported for source")
|
||||
ErrAutogroupNotSupportedDest = errors.New("autogroup not supported for destination")
|
||||
ErrAutogroupNotSupportedSSH = errors.New("autogroup not supported for SSH")
|
||||
ErrAutogroupNotSupported = errors.New("autogroup not supported in headscale")
|
||||
ErrAliasNotSupportedSSH = errors.New("alias type not supported for SSH")
|
||||
)
|
||||
|
||||
// Sentinel errors for SSH aliases.
|
||||
var (
|
||||
ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source")
|
||||
ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination")
|
||||
ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type")
|
||||
ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type")
|
||||
)
|
||||
|
||||
// Sentinel errors for policy parsing.
|
||||
var (
|
||||
ErrUnknownField = errors.New("unknown field in policy")
|
||||
ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports")
|
||||
)
|
||||
|
||||
type Asterix int
|
||||
|
||||
func (a Asterix) Validate() error {
|
||||
|
|
@ -75,7 +142,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
|
|||
case Asterix:
|
||||
alias = "*"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown alias type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v)
|
||||
}
|
||||
|
||||
// If no ports are specified
|
||||
|
|
@ -126,7 +193,7 @@ func (u Username) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Username has to contain @, got: %q", u)
|
||||
return fmt.Errorf("%w: got %q", ErrUsernameMissingAt, u)
|
||||
}
|
||||
|
||||
func (u *Username) String() string {
|
||||
|
|
@ -186,11 +253,11 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
|
|||
}
|
||||
|
||||
if len(potentialUsers) == 0 {
|
||||
return types.User{}, fmt.Errorf("user with token %q not found", u.String())
|
||||
return types.User{}, fmt.Errorf("%w: token %q", ErrUserNotFound, u.String())
|
||||
}
|
||||
|
||||
if len(potentialUsers) > 1 {
|
||||
return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String())
|
||||
return types.User{}, fmt.Errorf("%w: token %q found %s", ErrMultipleUsersFound, u.String(), potentialUsers.String())
|
||||
}
|
||||
|
||||
return potentialUsers[0], nil
|
||||
|
|
@ -234,7 +301,7 @@ func (g Group) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
|
||||
return fmt.Errorf("%w: got %q", ErrGroupMissingPrefix, g)
|
||||
}
|
||||
|
||||
func (g *Group) UnmarshalJSON(b []byte) error {
|
||||
|
|
@ -299,7 +366,7 @@ func (t Tag) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
|
||||
return fmt.Errorf("%w: got %q", ErrTagMissingPrefix, t)
|
||||
}
|
||||
|
||||
func (t *Tag) UnmarshalJSON(b []byte) error {
|
||||
|
|
@ -349,7 +416,7 @@ func (h Host) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Hostname %q is invalid", h)
|
||||
return fmt.Errorf("%w: %q", ErrInvalidHostname, h)
|
||||
}
|
||||
|
||||
func (h *Host) UnmarshalJSON(b []byte) error {
|
||||
|
|
@ -369,7 +436,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView
|
|||
|
||||
pref, ok := p.Hosts[h]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to resolve host: %q", h)
|
||||
return nil, fmt.Errorf("%w: %q", ErrHostNotResolved, h)
|
||||
}
|
||||
|
||||
err := pref.Validate()
|
||||
|
|
@ -406,7 +473,7 @@ func (p Prefix) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Prefix %q is invalid", p)
|
||||
return fmt.Errorf("%w: %q", ErrInvalidPrefix, p)
|
||||
}
|
||||
|
||||
func (p Prefix) String() string {
|
||||
|
|
@ -510,7 +577,7 @@ func (ag AutoGroup) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
|
||||
return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutoGroup, ag, autogroups)
|
||||
}
|
||||
|
||||
func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
|
||||
|
|
@ -570,7 +637,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type
|
|||
return nil, ErrAutogroupSelfRequiresPerNodeResolution
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown autogroup %q", ag)
|
||||
return nil, fmt.Errorf("%w: %q", ErrInvalidAutoGroup, ag)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -626,7 +693,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
|||
|
||||
ve.Ports = ports
|
||||
} else {
|
||||
return errors.New(`hostport must contain a colon (":")`)
|
||||
return ErrHostportMissingColon
|
||||
}
|
||||
|
||||
ve.Alias, err = parseAlias(vs)
|
||||
|
|
@ -639,7 +706,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", vs)
|
||||
return fmt.Errorf("%w: %T", ErrUnsupportedType, vs)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -694,15 +761,7 @@ func parseAlias(vs string) (Alias, error) {
|
|||
return new(Host(vs)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`Invalid alias %q. An alias must be one of the following types:
|
||||
- wildcard (*)
|
||||
- user (containing an "@")
|
||||
- group (starting with "group:")
|
||||
- tag (starting with "tag:")
|
||||
- autogroup (starting with "autogroup:")
|
||||
- host
|
||||
|
||||
Please check the format and try again.`, vs)
|
||||
return nil, fmt.Errorf("%w: %q", ErrInvalidAlias, vs)
|
||||
}
|
||||
|
||||
// AliasEnc is used to deserialize a Alias.
|
||||
|
|
@ -764,7 +823,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
|
|||
case Asterix:
|
||||
aliases[i] = "*"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown alias type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -850,7 +909,7 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) {
|
|||
case *Group:
|
||||
approvers[i] = string(*v)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown auto approver type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownAutoApproverType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -867,12 +926,7 @@ func parseAutoApprover(s string) (AutoApprover, error) {
|
|||
return new(Tag(s)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types:
|
||||
- user (containing an "@")
|
||||
- group (starting with "group:")
|
||||
- tag (starting with "tag:")
|
||||
|
||||
Please check the format and try again.`, s)
|
||||
return nil, fmt.Errorf("%w: %q", ErrInvalidAutoApprover, s)
|
||||
}
|
||||
|
||||
// AutoApproverEnc is used to deserialize a AutoApprover.
|
||||
|
|
@ -949,7 +1003,7 @@ func (o Owners) MarshalJSON() ([]byte, error) {
|
|||
case *Tag:
|
||||
owners[i] = string(*v)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown owner type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -966,12 +1020,7 @@ func parseOwner(s string) (Owner, error) {
|
|||
return new(Tag(s)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
|
||||
- user (containing an "@")
|
||||
- group (starting with "group:")
|
||||
- tag (starting with "tag:")
|
||||
|
||||
Please check the format and try again.`, s)
|
||||
return nil, fmt.Errorf("%w: %q", ErrInvalidOwner, s)
|
||||
}
|
||||
|
||||
type Usernames []Username
|
||||
|
|
@ -990,7 +1039,7 @@ func (g Groups) Contains(group *Group) error {
|
|||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group)
|
||||
return fmt.Errorf("%w: %q", ErrGroupNotDefined, group)
|
||||
}
|
||||
|
||||
// UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure
|
||||
|
|
@ -1025,15 +1074,15 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
if str, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, str)
|
||||
} else {
|
||||
return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item)
|
||||
return fmt.Errorf("%w: group %q got %T", ErrInvalidGroupMember, key, item)
|
||||
}
|
||||
}
|
||||
|
||||
rawGroups[key] = stringSlice
|
||||
case string:
|
||||
return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v)
|
||||
return fmt.Errorf("%w: group %q got string %q", ErrGroupValueNotArray, key, v)
|
||||
default:
|
||||
return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v)
|
||||
return fmt.Errorf("%w: group %q got %T", ErrGroupValueNotArray, key, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1048,7 +1097,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
username := Username(u)
|
||||
if err := username.Validate(); err != nil {
|
||||
if isGroup(u) {
|
||||
return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group)
|
||||
return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group)
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
@ -1082,7 +1131,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error {
|
|||
|
||||
var prefix Prefix
|
||||
if err := prefix.parseString(value); err != nil {
|
||||
return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, key, value)
|
||||
return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value)
|
||||
}
|
||||
|
||||
(*h)[host] = prefix
|
||||
|
|
@ -1131,7 +1180,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) {
|
|||
case *Tag:
|
||||
ownerStrs[i] = string(*v)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown owner type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1155,7 +1204,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
|
|||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner)
|
||||
return fmt.Errorf("%w: %q", ErrTagNotDefined, tagOwner)
|
||||
}
|
||||
|
||||
type AutoApproverPolicy struct {
|
||||
|
|
@ -1208,7 +1257,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
aa, ok := autoApprover.(Alias)
|
||||
if !ok {
|
||||
// Should never happen
|
||||
return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover)
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover)
|
||||
}
|
||||
// If it does not resolve, that means the autoApprover is not associated with any IP addresses.
|
||||
ips, _ := aa.Resolve(p, users, nodes)
|
||||
|
|
@ -1223,7 +1272,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
aa, ok := autoApprover.(Alias)
|
||||
if !ok {
|
||||
// Should never happen
|
||||
return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover)
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover)
|
||||
}
|
||||
// If it does not resolve, that means the autoApprover is not associated with any IP addresses.
|
||||
ips, _ := aa.Resolve(p, users, nodes)
|
||||
|
|
@ -1280,7 +1329,7 @@ func (a *Action) UnmarshalJSON(b []byte) error {
|
|||
case "accept":
|
||||
*a = ActionAccept
|
||||
default:
|
||||
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
|
||||
return fmt.Errorf("%w: %q, must be %q", ErrInvalidAction, str, ActionAccept)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1305,7 +1354,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error {
|
|||
case "check":
|
||||
*a = SSHActionCheck
|
||||
default:
|
||||
return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str)
|
||||
return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1443,23 +1492,23 @@ func (p Protocol) validate() error {
|
|||
return nil
|
||||
case ProtocolWildcard:
|
||||
// Wildcard "*" is not allowed - Tailscale rejects it
|
||||
return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
|
||||
return fmt.Errorf("%w: use protocol number 0-255 or protocol name", ErrInvalidProtocol)
|
||||
default:
|
||||
// Try to parse as a numeric protocol number
|
||||
str := string(p)
|
||||
|
||||
// Check for leading zeros (not allowed by Tailscale)
|
||||
if str == "0" || (len(str) > 1 && str[0] == '0') {
|
||||
return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str)
|
||||
return fmt.Errorf("%w: %q", ErrLeadingZeroProtocol, str)
|
||||
}
|
||||
|
||||
protocolNumber, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p)
|
||||
return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocol, p)
|
||||
}
|
||||
|
||||
if protocolNumber < 0 || protocolNumber > 255 {
|
||||
return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber)
|
||||
return fmt.Errorf("%w: %d", ErrProtocolOutOfRange, protocolNumber)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1577,7 +1626,7 @@ func validateAutogroupSupported(ag *AutoGroup) error {
|
|||
}
|
||||
|
||||
if slices.Contains(autogroupNotSupported, *ag) {
|
||||
return fmt.Errorf("autogroup %q is not supported in headscale", *ag)
|
||||
return fmt.Errorf("%w: %q", ErrAutogroupNotSupported, *ag)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1589,15 +1638,15 @@ func validateAutogroupForSrc(src *AutoGroup) error {
|
|||
}
|
||||
|
||||
if src.Is(AutoGroupInternet) {
|
||||
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
|
||||
return ErrAutogroupInternetInSource
|
||||
}
|
||||
|
||||
if src.Is(AutoGroupSelf) {
|
||||
return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`)
|
||||
return ErrAutogroupSelfInSource
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSrc, *src) {
|
||||
return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc)
|
||||
return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSource, *src, autogroupForSrc)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1609,7 +1658,7 @@ func validateAutogroupForDst(dst *AutoGroup) error {
|
|||
}
|
||||
|
||||
if !slices.Contains(autogroupForDst, *dst) {
|
||||
return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst)
|
||||
return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedDest, *dst, autogroupForDst)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1621,11 +1670,11 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
|
|||
}
|
||||
|
||||
if src.Is(AutoGroupInternet) {
|
||||
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
|
||||
return fmt.Errorf("%w: autogroup:internet in SSH source", ErrAutogroupNotSupportedSSH)
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSSHSrc, *src) {
|
||||
return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc)
|
||||
return fmt.Errorf("%w: %q for SSH sources, can be %v", ErrAutogroupNotSupportedSSH, *src, autogroupForSSHSrc)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1637,11 +1686,11 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
|
|||
}
|
||||
|
||||
if dst.Is(AutoGroupInternet) {
|
||||
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
|
||||
return fmt.Errorf("%w: autogroup:internet in SSH destination", ErrAutogroupNotSupportedSSH)
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSSHDst, *dst) {
|
||||
return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst)
|
||||
return fmt.Errorf("%w: %q for SSH destinations, can be %v", ErrAutogroupNotSupportedSSH, *dst, autogroupForSSHDst)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1653,7 +1702,7 @@ func validateAutogroupForSSHUser(user *AutoGroup) error {
|
|||
}
|
||||
|
||||
if !slices.Contains(autogroupForSSHUser, *user) {
|
||||
return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser)
|
||||
return fmt.Errorf("%w: %q for SSH user, can be %v", ErrAutogroupNotSupportedSSH, *user, autogroupForSSHUser)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -1678,7 +1727,7 @@ func (p *Policy) validate() error {
|
|||
case *Host:
|
||||
h := src
|
||||
if !p.Hosts.exist(*h) {
|
||||
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
|
||||
errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h))
|
||||
}
|
||||
case *AutoGroup:
|
||||
ag := src
|
||||
|
|
@ -1710,7 +1759,7 @@ func (p *Policy) validate() error {
|
|||
case *Host:
|
||||
h := dst.Alias.(*Host)
|
||||
if !p.Hosts.exist(*h) {
|
||||
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
|
||||
errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h))
|
||||
}
|
||||
case *AutoGroup:
|
||||
ag := dst.Alias.(*AutoGroup)
|
||||
|
|
@ -1915,10 +1964,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
|||
case *Username, *Group, *Tag, *AutoGroup:
|
||||
(*a)[i] = alias.Alias
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"alias %T is not supported for SSH source",
|
||||
alias.Alias,
|
||||
)
|
||||
return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHSrc, alias.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1946,10 +1992,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
|||
Asterix:
|
||||
(*a)[i] = alias.Alias
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"alias %T is not supported for SSH destination",
|
||||
alias.Alias,
|
||||
)
|
||||
return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHDst, alias.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1976,7 +2019,7 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) {
|
|||
case Asterix:
|
||||
aliases[i] = "*"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SSH destination alias type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownSSHDstAliasType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2003,7 +2046,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
|
|||
case Asterix:
|
||||
aliases[i] = "*"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SSH source alias type: %T", v)
|
||||
return nil, fmt.Errorf("%w: %T", ErrUnknownSSHSrcAliasType, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2077,11 +2120,11 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
|
|||
ast.Standardize()
|
||||
|
||||
if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil {
|
||||
if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName {
|
||||
if serr, ok := errors.AsType[*json.SemanticError](err); ok && errors.Is(serr.Err, json.ErrUnknownName) {
|
||||
ptr := serr.JSONPointer
|
||||
name := ptr.LastToken()
|
||||
|
||||
return nil, fmt.Errorf("unknown field %q", name)
|
||||
return nil, fmt.Errorf("%w: %q", ErrUnknownField, name)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
|
||||
|
|
@ -2109,7 +2152,7 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi
|
|||
for _, portRange := range dst.Ports {
|
||||
// Check if it's not a wildcard port (0-65535)
|
||||
if portRange.First != 0 || portRange.Last != 65535 {
|
||||
return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol)
|
||||
return fmt.Errorf("%w: %q only allows \"*\"", ErrProtocolNoSpecificPorts, protocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: "alias v2.Asterix is not supported for SSH source",
|
||||
wantErr: "alias type not supported for SSH source: v2.Asterix",
|
||||
},
|
||||
{
|
||||
name: "invalid-username",
|
||||
|
|
@ -380,7 +380,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Username has to contain @, got: "invalid"`,
|
||||
wantErr: `username must contain @: got "invalid"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-group",
|
||||
|
|
@ -393,7 +393,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group has to start with "group:", got: "grou:example"`,
|
||||
wantErr: `group must start with 'group:': got "grou:example"`,
|
||||
},
|
||||
{
|
||||
name: "group-in-group",
|
||||
|
|
@ -408,7 +408,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
`,
|
||||
// wantErr: `Username has to contain @, got: "group:inner"`,
|
||||
wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`,
|
||||
wantErr: `nested groups not allowed: found "group:inner" inside "group:example"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-addr",
|
||||
|
|
@ -419,7 +419,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`,
|
||||
wantErr: `invalid IP address: hostname "derp" value "10.0"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-prefix",
|
||||
|
|
@ -430,7 +430,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`,
|
||||
wantErr: `invalid IP address: hostname "derp" value "10.0/42"`,
|
||||
},
|
||||
// TODO(kradalby): Figure out why this doesn't work.
|
||||
// {
|
||||
|
|
@ -459,7 +459,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
],
|
||||
}
|
||||
`,
|
||||
wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`,
|
||||
wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`,
|
||||
},
|
||||
{
|
||||
name: "undefined-hostname-errors-2490",
|
||||
|
|
@ -478,7 +478,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `host not defined in policy: "user1" - please define or remove the reference`,
|
||||
},
|
||||
{
|
||||
name: "defined-hostname-does-not-err-2490",
|
||||
|
|
@ -571,7 +571,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup:internet can only be used in ACL destinations`,
|
||||
},
|
||||
{
|
||||
name: "autogroup:internet-in-ssh-src-not-allowed",
|
||||
|
|
@ -590,7 +590,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup not supported for SSH: autogroup:internet in SSH source`,
|
||||
},
|
||||
{
|
||||
name: "autogroup:internet-in-ssh-dst-not-allowed",
|
||||
|
|
@ -609,7 +609,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup not supported for SSH: autogroup:internet in SSH destination`,
|
||||
},
|
||||
{
|
||||
name: "ssh-basic",
|
||||
|
|
@ -760,7 +760,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-dst",
|
||||
|
|
@ -779,7 +779,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-ssh-src",
|
||||
|
|
@ -798,7 +798,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-tagOwner",
|
||||
|
|
@ -809,7 +809,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-autoapprover-route",
|
||||
|
|
@ -822,7 +822,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-autoapprover-exitnode",
|
||||
|
|
@ -833,7 +833,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-src",
|
||||
|
|
@ -852,7 +852,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-dst",
|
||||
|
|
@ -871,7 +871,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-ssh-src",
|
||||
|
|
@ -890,7 +890,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-ssh-dst",
|
||||
|
|
@ -912,7 +912,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-autoapprover-route",
|
||||
|
|
@ -925,7 +925,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-autoapprover-exitnode",
|
||||
|
|
@ -936,7 +936,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "missing-dst-port-is-err",
|
||||
|
|
@ -955,7 +955,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `hostport must contain a colon (":")`,
|
||||
wantErr: `hostport must contain a colon`,
|
||||
},
|
||||
{
|
||||
name: "dst-port-zero-is-err",
|
||||
|
|
@ -985,7 +985,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "rules"`,
|
||||
wantErr: `unknown field in policy: "rules"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-nested",
|
||||
|
|
@ -1008,7 +1008,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`,
|
||||
wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-group-datatype",
|
||||
|
|
@ -1020,7 +1020,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`,
|
||||
wantErr: `group value must be an array: group "group:invalid" got string "should fail"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-group-name-and-datatype-fails-on-name-first",
|
||||
|
|
@ -1032,7 +1032,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`,
|
||||
wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-hosts-level",
|
||||
|
|
@ -1044,7 +1044,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`,
|
||||
wantErr: `invalid IP address: hostname "INVALID_HOST_FIELD" value "should fail"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-tagowners-level",
|
||||
|
|
@ -1056,7 +1056,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`,
|
||||
wantErr: `tag must start with 'tag:': got "INVALID_TAG_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-acls-level",
|
||||
|
|
@ -1073,7 +1073,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_ACL_FIELD"`,
|
||||
wantErr: `unknown field in policy: "INVALID_ACL_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-ssh-level",
|
||||
|
|
@ -1090,7 +1090,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_SSH_FIELD"`,
|
||||
wantErr: `unknown field in policy: "INVALID_SSH_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-policy-level",
|
||||
|
|
@ -1107,7 +1107,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
"INVALID_POLICY_FIELD": "should fail at policy level"
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_POLICY_FIELD"`,
|
||||
wantErr: `unknown field in policy: "INVALID_POLICY_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-autoapprovers-level",
|
||||
|
|
@ -1122,7 +1122,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`,
|
||||
wantErr: `unknown field in policy: "INVALID_AUTO_APPROVER_FIELD"`,
|
||||
},
|
||||
// headscale-admin uses # in some field names to add metadata, so we will ignore
|
||||
// those to ensure it doesnt break.
|
||||
|
|
@ -1181,7 +1181,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "proto"`,
|
||||
wantErr: `unknown field in policy: "proto"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-wildcard-not-allowed",
|
||||
|
|
@ -1197,7 +1197,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`,
|
||||
wantErr: `invalid protocol: use protocol number 0-255 or protocol name`,
|
||||
},
|
||||
{
|
||||
name: "protocol-case-insensitive-uppercase",
|
||||
|
|
@ -1277,7 +1277,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `leading 0 not permitted in protocol number "0"`,
|
||||
wantErr: `leading zero not permitted in protocol number: "0"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-empty-applies-to-tcp-udp-only",
|
||||
|
|
@ -1324,7 +1324,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`,
|
||||
wantErr: `protocol does not support specific ports: "icmp" only allows "*"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-icmp-with-wildcard-port-allowed",
|
||||
|
|
@ -1372,7 +1372,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`,
|
||||
wantErr: `protocol does not support specific ports: "gre" only allows "*"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-tcp-with-specific-port-allowed",
|
||||
|
|
@ -1836,7 +1836,7 @@ func TestResolvePolicy(t *testing.T) {
|
|||
IPv4: ap("100.100.101.103"),
|
||||
},
|
||||
},
|
||||
wantErr: `user with token "invaliduser@" not found`,
|
||||
wantErr: `user not found: token "invaliduser@"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-tag",
|
||||
|
|
@ -1999,7 +1999,7 @@ func TestResolvePolicy(t *testing.T) {
|
|||
{
|
||||
name: "autogroup-invalid",
|
||||
toResolve: new(AutoGroup("autogroup:invalid")),
|
||||
wantErr: "unknown autogroup",
|
||||
wantErr: "invalid autogroup",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -2670,7 +2670,7 @@ func TestNodeCanHaveTag(t *testing.T) {
|
|||
node: nodes[0],
|
||||
tag: "tag:test",
|
||||
want: false,
|
||||
wantErr: "Username has to contain @",
|
||||
wantErr: "username must contain @",
|
||||
},
|
||||
{
|
||||
name: "node-cannot-have-tag",
|
||||
|
|
@ -3248,7 +3248,8 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
|
|||
|
||||
_, err := unmarshalPolicy([]byte(policyJSON))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `invalid action "deny"`)
|
||||
assert.Contains(t, err.Error(), `invalid action`)
|
||||
assert.Contains(t, err.Error(), `deny`)
|
||||
}
|
||||
|
||||
// Helper function to parse aliases for testing.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,18 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Sentinel errors for port and destination parsing.
|
||||
var (
|
||||
ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")
|
||||
ErrInputStartsWithColon = errors.New("input cannot start with a colon character")
|
||||
ErrInputEndsWithColon = errors.New("input cannot end with a colon character")
|
||||
ErrInvalidPortRange = errors.New("invalid port range format")
|
||||
ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port")
|
||||
ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard")
|
||||
ErrInvalidPortNumber = errors.New("invalid port number")
|
||||
ErrPortOutOfRange = errors.New("port number out of range")
|
||||
)
|
||||
|
||||
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||
func splitDestinationAndPort(input string) (string, string, error) {
|
||||
// Find the last occurrence of the colon character
|
||||
|
|
@ -16,15 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) {
|
|||
|
||||
// Check if the colon character is present and not at the beginning or end of the string
|
||||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
return "", "", ErrInputMissingColon
|
||||
}
|
||||
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
return "", "", ErrInputStartsWithColon
|
||||
}
|
||||
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
return "", "", ErrInputEndsWithColon
|
||||
}
|
||||
|
||||
// Split the string into destination and port based on the last colon
|
||||
|
|
@ -52,7 +64,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
return e == ""
|
||||
})
|
||||
if len(rangeParts) != 2 {
|
||||
return nil, errors.New("invalid port range format")
|
||||
return nil, ErrInvalidPortRange
|
||||
}
|
||||
|
||||
first, err := parsePort(rangeParts[0])
|
||||
|
|
@ -66,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
}
|
||||
|
||||
if first > last {
|
||||
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||
return nil, ErrPortRangeInverted
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||
|
|
@ -77,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
}
|
||||
|
||||
if port < 1 {
|
||||
return nil, errors.New("first port must be >0, or use '*' for wildcard")
|
||||
return nil, ErrPortMustBePositive
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||
|
|
@ -91,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
func parsePort(portStr string) (uint16, error) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid port number")
|
||||
return 0, ErrInvalidPortNumber
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, errors.New("port number out of range")
|
||||
return 0, ErrPortOutOfRange
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
|
|
|
|||
|
|
@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) {
|
|||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
{"invalidinput", "", "", ErrInputMissingColon},
|
||||
{":invalid", "", "", ErrInputStartsWithColon},
|
||||
{"invalid:", "", "", ErrInputEndsWithColon},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) {
|
||||
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||
}
|
||||
|
|
@ -42,27 +42,23 @@ func TestParsePort(t *testing.T) {
|
|||
tests := []struct {
|
||||
input string
|
||||
expected uint16
|
||||
err string
|
||||
err error
|
||||
}{
|
||||
{"80", 80, ""},
|
||||
{"0", 0, ""},
|
||||
{"65535", 65535, ""},
|
||||
{"-1", 0, "port number out of range"},
|
||||
{"65536", 0, "port number out of range"},
|
||||
{"abc", 0, "invalid port number"},
|
||||
{"", 0, "invalid port number"},
|
||||
{"80", 80, nil},
|
||||
{"0", 0, nil},
|
||||
{"65535", 65535, nil},
|
||||
{"-1", 0, ErrPortOutOfRange},
|
||||
{"65536", 0, ErrPortOutOfRange},
|
||||
{"abc", 0, ErrInvalidPortNumber},
|
||||
{"", 0, ErrInvalidPortNumber},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePort(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
|
|
@ -73,32 +69,28 @@ func TestParsePortRange(t *testing.T) {
|
|||
tests := []struct {
|
||||
input string
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
err error
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
{"80-90,", nil, "invalid port number"},
|
||||
{"80,90-", nil, "invalid port range format"},
|
||||
{"80-90,abc", nil, "invalid port number"},
|
||||
{"80-90,65536", nil, "port number out of range"},
|
||||
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil},
|
||||
{"80-", nil, ErrInvalidPortRange},
|
||||
{"-90", nil, ErrInvalidPortRange},
|
||||
{"80-90,", nil, ErrInvalidPortNumber},
|
||||
{"80,90-", nil, ErrInvalidPortRange},
|
||||
{"80-90,abc", nil, ErrInvalidPortNumber},
|
||||
{"80-90,65536", nil, ErrPortOutOfRange},
|
||||
{"80-90,90-80", nil, ErrPortRangeInverted},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePortRange(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue