diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 958902a2..ced8531c 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -1,7 +1,6 @@ package v2 import ( - "errors" "fmt" "slices" "time" @@ -14,8 +13,6 @@ import ( "tailscale.com/types/views" ) -var ErrInvalidAction = errors.New("invalid action") - // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) compileFilterRules( @@ -149,7 +146,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, errors.New("autogroup:self cannot be used in sources") + return nil, ErrAutogroupSelfInSource } ips, err := src.Resolve(pol, users, nodes) diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 41e7e0d9..4ff5dd1a 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -36,6 +36,73 @@ var ErrCircularReference = errors.New("circular reference detected") var ErrUndefinedTagReference = errors.New("references undefined tag") +// Sentinel errors for type/alias validation. +var ( + ErrUnknownAliasType = errors.New("unknown alias type") + ErrUnknownOwnerType = errors.New("unknown owner type") + ErrUnknownAutoApproverType = errors.New("unknown auto approver type") + ErrInvalidAlias = errors.New("invalid alias") + ErrInvalidAutoApprover = errors.New("invalid auto approver") + ErrInvalidOwner = errors.New("invalid owner") +) + +// Sentinel errors for format validation. +var ( + ErrUsernameMissingAt = errors.New("username must contain @") + ErrGroupMissingPrefix = errors.New("group must start with 'group:'") + ErrTagMissingPrefix = errors.New("tag must start with 'tag:'") + ErrInvalidHostname = errors.New("invalid hostname") + ErrInvalidPrefix = errors.New("invalid prefix") + ErrInvalidAutoGroup = errors.New("invalid autogroup") + ErrInvalidAction = errors.New("invalid action") + ErrInvalidSSHAction = errors.New("invalid SSH action") + ErrInvalidProtocol = errors.New("invalid protocol") + ErrProtocolOutOfRange = errors.New("protocol number out of range") + ErrLeadingZeroProtocol = errors.New("leading zero not permitted in protocol number") + ErrHostportMissingColon = errors.New("hostport must contain a colon") + ErrUnsupportedType = errors.New("unsupported type") +) + +// Sentinel errors for resolution/lookup failures. +var ( + ErrUserNotFound = errors.New("user not found") + ErrMultipleUsersFound = errors.New("multiple users found") + ErrHostNotResolved = errors.New("unable to resolve host") + ErrGroupNotDefined = errors.New("group not defined in policy") + ErrTagNotDefined = errors.New("tag not defined in policy") + ErrHostNotDefined = errors.New("host not defined in policy") + ErrInvalidIPAddress = errors.New("invalid IP address") + ErrNestedGroups = errors.New("nested groups not allowed") + ErrInvalidGroupMember = errors.New("invalid group member type") + ErrGroupValueNotArray = errors.New("group value must be an array") + ErrAutoApproverNotAlias = errors.New("auto approver is not an alias") +) + +// Sentinel errors for autogroup context validation. +var ( + ErrAutogroupInternetInSource = errors.New("autogroup:internet can only be used in ACL destinations") + ErrAutogroupSelfInSource = errors.New("autogroup:self can only be used in ACL destinations") + ErrAutogroupNotSupportedSource = errors.New("autogroup not supported for source") + ErrAutogroupNotSupportedDest = errors.New("autogroup not supported for destination") + ErrAutogroupNotSupportedSSH = errors.New("autogroup not supported for SSH") + ErrAutogroupNotSupported = errors.New("autogroup not supported in headscale") + ErrAliasNotSupportedSSH = errors.New("alias type not supported for SSH") +) + +// Sentinel errors for SSH aliases. +var ( + ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source") + ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination") + ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type") + ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type") +) + +// Sentinel errors for policy parsing. +var ( + ErrUnknownField = errors.New("unknown field in policy") + ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports") +) + type Asterix int func (a Asterix) Validate() error { @@ -75,7 +142,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { case Asterix: alias = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } // If no ports are specified @@ -126,7 +193,7 @@ func (u Username) Validate() error { return nil } - return fmt.Errorf("Username has to contain @, got: %q", u) + return fmt.Errorf("%w: got %q", ErrUsernameMissingAt, u) } func (u *Username) String() string { @@ -186,11 +253,11 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } if len(potentialUsers) == 0 { - return types.User{}, fmt.Errorf("user with token %q not found", u.String()) + return types.User{}, fmt.Errorf("%w: token %q", ErrUserNotFound, u.String()) } if len(potentialUsers) > 1 { - return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String()) + return types.User{}, fmt.Errorf("%w: token %q found %s", ErrMultipleUsersFound, u.String(), potentialUsers.String()) } return potentialUsers[0], nil @@ -234,7 +301,7 @@ func (g Group) Validate() error { return nil } - return fmt.Errorf(`Group has to start with "group:", got: %q`, g) + return fmt.Errorf("%w: got %q", ErrGroupMissingPrefix, g) } func (g *Group) UnmarshalJSON(b []byte) error { @@ -299,7 +366,7 @@ func (t Tag) Validate() error { return nil } - return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) + return fmt.Errorf("%w: got %q", ErrTagMissingPrefix, t) } func (t *Tag) UnmarshalJSON(b []byte) error { @@ -349,7 +416,7 @@ func (h Host) Validate() error { return nil } - return fmt.Errorf("Hostname %q is invalid", h) + return fmt.Errorf("%w: %q", ErrInvalidHostname, h) } func (h *Host) UnmarshalJSON(b []byte) error { @@ -369,7 +436,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView pref, ok := p.Hosts[h] if !ok { - return nil, fmt.Errorf("unable to resolve host: %q", h) + return nil, fmt.Errorf("%w: %q", ErrHostNotResolved, h) } err := pref.Validate() @@ -406,7 +473,7 @@ func (p Prefix) Validate() error { return nil } - return fmt.Errorf("Prefix %q is invalid", p) + return fmt.Errorf("%w: %q", ErrInvalidPrefix, p) } func (p Prefix) String() string { @@ -510,7 +577,7 @@ func (ag AutoGroup) Validate() error { return nil } - return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutoGroup, ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { @@ -570,7 +637,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type return nil, ErrAutogroupSelfRequiresPerNodeResolution default: - return nil, fmt.Errorf("unknown autogroup %q", ag) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoGroup, ag) } } @@ -626,7 +693,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { ve.Ports = ports } else { - return errors.New(`hostport must contain a colon (":")`) + return ErrHostportMissingColon } ve.Alias, err = parseAlias(vs) @@ -639,7 +706,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } default: - return fmt.Errorf("type %T not supported", vs) + return fmt.Errorf("%w: %T", ErrUnsupportedType, vs) } return nil @@ -694,15 +761,7 @@ func parseAlias(vs string) (Alias, error) { return new(Host(vs)), nil } - return nil, fmt.Errorf(`Invalid alias %q. An alias must be one of the following types: -- wildcard (*) -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") -- autogroup (starting with "autogroup:") -- host - -Please check the format and try again.`, vs) + return nil, fmt.Errorf("%w: %q", ErrInvalidAlias, vs) } // AliasEnc is used to deserialize a Alias. @@ -764,7 +823,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } } @@ -850,7 +909,7 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { case *Group: approvers[i] = string(*v) default: - return nil, fmt.Errorf("unknown auto approver type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAutoApproverType, v) } } @@ -867,12 +926,7 @@ func parseAutoApprover(s string) (AutoApprover, error) { return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoApprover, s) } // AutoApproverEnc is used to deserialize a AutoApprover. @@ -949,7 +1003,7 @@ func (o Owners) MarshalJSON() ([]byte, error) { case *Tag: owners[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -966,12 +1020,7 @@ func parseOwner(s string) (Owner, error) { return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidOwner, s) } type Usernames []Username @@ -990,7 +1039,7 @@ func (g Groups) Contains(group *Group) error { } } - return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group) + return fmt.Errorf("%w: %q", ErrGroupNotDefined, group) } // UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure @@ -1025,15 +1074,15 @@ func (g *Groups) UnmarshalJSON(b []byte) error { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) } else { - return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) + return fmt.Errorf("%w: group %q got %T", ErrInvalidGroupMember, key, item) } } rawGroups[key] = stringSlice case string: - return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) + return fmt.Errorf("%w: group %q got string %q", ErrGroupValueNotArray, key, v) default: - return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v) + return fmt.Errorf("%w: group %q got %T", ErrGroupValueNotArray, key, v) } } @@ -1048,7 +1097,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { username := Username(u) if err := username.Validate(); err != nil { if isGroup(u) { - return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group) + return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group) } return err @@ -1082,7 +1131,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { var prefix Prefix if err := prefix.parseString(value); err != nil { - return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, key, value) + return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value) } (*h)[host] = prefix @@ -1131,7 +1180,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { case *Tag: ownerStrs[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -1155,7 +1204,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error { } } - return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner) + return fmt.Errorf("%w: %q", ErrTagNotDefined, tagOwner) } type AutoApproverPolicy struct { @@ -1208,7 +1257,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1223,7 +1272,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1280,7 +1329,7 @@ func (a *Action) UnmarshalJSON(b []byte) error { case "accept": *a = ActionAccept default: - return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + return fmt.Errorf("%w: %q, must be %q", ErrInvalidAction, str, ActionAccept) } return nil @@ -1305,7 +1354,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { case "check": *a = SSHActionCheck default: - return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str) } return nil @@ -1443,23 +1492,23 @@ func (p Protocol) validate() error { return nil case ProtocolWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return fmt.Errorf("%w: use protocol number 0-255 or protocol name", ErrInvalidProtocol) default: // Try to parse as a numeric protocol number str := string(p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { - return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + return fmt.Errorf("%w: %q", ErrLeadingZeroProtocol, str) } protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocol, p) } if protocolNumber < 0 || protocolNumber > 255 { - return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + return fmt.Errorf("%w: %d", ErrProtocolOutOfRange, protocolNumber) } return nil @@ -1577,7 +1626,7 @@ func validateAutogroupSupported(ag *AutoGroup) error { } if slices.Contains(autogroupNotSupported, *ag) { - return fmt.Errorf("autogroup %q is not supported in headscale", *ag) + return fmt.Errorf("%w: %q", ErrAutogroupNotSupported, *ag) } return nil @@ -1589,15 +1638,15 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return ErrAutogroupInternetInSource } if src.Is(AutoGroupSelf) { - return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`) + return ErrAutogroupSelfInSource } if !slices.Contains(autogroupForSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSource, *src, autogroupForSrc) } return nil @@ -1609,7 +1658,7 @@ func validateAutogroupForDst(dst *AutoGroup) error { } if !slices.Contains(autogroupForDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedDest, *dst, autogroupForDst) } return nil @@ -1621,11 +1670,11 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH source", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc) + return fmt.Errorf("%w: %q for SSH sources, can be %v", ErrAutogroupNotSupportedSSH, *src, autogroupForSSHSrc) } return nil @@ -1637,11 +1686,11 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH destination", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst) + return fmt.Errorf("%w: %q for SSH destinations, can be %v", ErrAutogroupNotSupportedSSH, *dst, autogroupForSSHDst) } return nil @@ -1653,7 +1702,7 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { } if !slices.Contains(autogroupForSSHUser, *user) { - return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser) + return fmt.Errorf("%w: %q for SSH user, can be %v", ErrAutogroupNotSupportedSSH, *user, autogroupForSSHUser) } return nil @@ -1678,7 +1727,7 @@ func (p *Policy) validate() error { case *Host: h := src if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: ag := src @@ -1710,7 +1759,7 @@ func (p *Policy) validate() error { case *Host: h := dst.Alias.(*Host) if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: ag := dst.Alias.(*AutoGroup) @@ -1915,10 +1964,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { case *Username, *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH source", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHSrc, alias.Alias) } } @@ -1946,10 +1992,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { Asterix: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH destination", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHDst, alias.Alias) } } @@ -1976,7 +2019,7 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHDstAliasType, v) } } @@ -2003,7 +2046,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHSrcAliasType, v) } } @@ -2077,11 +2120,11 @@ func unmarshalPolicy(b []byte) (*Policy, error) { ast.Standardize() if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { - if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName { + if serr, ok := errors.AsType[*json.SemanticError](err); ok && errors.Is(serr.Err, json.ErrUnknownName) { ptr := serr.JSONPointer name := ptr.LastToken() - return nil, fmt.Errorf("unknown field %q", name) + return nil, fmt.Errorf("%w: %q", ErrUnknownField, name) } return nil, fmt.Errorf("parsing policy from bytes: %w", err) @@ -2109,7 +2152,7 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi for _, portRange := range dst.Ports { // Check if it's not a wildcard port (0-65535) if portRange.First != 0 || portRange.Last != 65535 { - return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + return fmt.Errorf("%w: %q only allows \"*\"", ErrProtocolNoSpecificPorts, protocol) } } } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 79d005a3..b5e5a210 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -366,7 +366,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: "alias v2.Asterix is not supported for SSH source", + wantErr: "alias type not supported for SSH source: v2.Asterix", }, { name: "invalid-username", @@ -380,7 +380,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Username has to contain @, got: "invalid"`, + wantErr: `username must contain @: got "invalid"`, }, { name: "invalid-group", @@ -393,7 +393,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group has to start with "group:", got: "grou:example"`, + wantErr: `group must start with 'group:': got "grou:example"`, }, { name: "group-in-group", @@ -408,7 +408,7 @@ func TestUnmarshalPolicy(t *testing.T) { } `, // wantErr: `Username has to contain @, got: "group:inner"`, - wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`, + wantErr: `nested groups not allowed: found "group:inner" inside "group:example"`, }, { name: "invalid-addr", @@ -419,7 +419,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`, + wantErr: `invalid IP address: hostname "derp" value "10.0"`, }, { name: "invalid-prefix", @@ -430,7 +430,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`, + wantErr: `invalid IP address: hostname "derp" value "10.0/42"`, }, // TODO(kradalby): Figure out why this doesn't work. // { @@ -459,7 +459,7 @@ func TestUnmarshalPolicy(t *testing.T) { ], } `, - wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, + wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, }, { name: "undefined-hostname-errors-2490", @@ -478,7 +478,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `host not defined in policy: "user1" - please define or remove the reference`, }, { name: "defined-hostname-does-not-err-2490", @@ -571,7 +571,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`, + wantErr: `autogroup:internet can only be used in ACL destinations`, }, { name: "autogroup:internet-in-ssh-src-not-allowed", @@ -590,7 +590,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH source`, }, { name: "autogroup:internet-in-ssh-dst-not-allowed", @@ -609,7 +609,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH destination`, }, { name: "ssh-basic", @@ -760,7 +760,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-dst", @@ -779,7 +779,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-ssh-src", @@ -798,7 +798,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-tagOwner", @@ -809,7 +809,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-route", @@ -822,7 +822,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-exitnode", @@ -833,7 +833,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "tag-must-be-defined-acl-src", @@ -852,7 +852,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-dst", @@ -871,7 +871,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-src", @@ -890,7 +890,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-dst", @@ -912,7 +912,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-route", @@ -925,7 +925,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-exitnode", @@ -936,7 +936,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "missing-dst-port-is-err", @@ -955,7 +955,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `hostport must contain a colon (":")`, + wantErr: `hostport must contain a colon`, }, { name: "dst-port-zero-is-err", @@ -985,7 +985,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "rules"`, + wantErr: `unknown field in policy: "rules"`, }, { name: "disallow-unsupported-fields-nested", @@ -1008,7 +1008,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "invalid-group-datatype", @@ -1020,7 +1020,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`, + wantErr: `group value must be an array: group "group:invalid" got string "should fail"`, }, { name: "invalid-group-name-and-datatype-fails-on-name-first", @@ -1032,7 +1032,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "disallow-unsupported-fields-hosts-level", @@ -1044,7 +1044,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + wantErr: `invalid IP address: hostname "INVALID_HOST_FIELD" value "should fail"`, }, { name: "disallow-unsupported-fields-tagowners-level", @@ -1056,7 +1056,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + wantErr: `tag must start with 'tag:': got "INVALID_TAG_FIELD"`, }, { name: "disallow-unsupported-fields-acls-level", @@ -1073,7 +1073,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_ACL_FIELD"`, + wantErr: `unknown field in policy: "INVALID_ACL_FIELD"`, }, { name: "disallow-unsupported-fields-ssh-level", @@ -1090,7 +1090,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_SSH_FIELD"`, + wantErr: `unknown field in policy: "INVALID_SSH_FIELD"`, }, { name: "disallow-unsupported-fields-policy-level", @@ -1107,7 +1107,7 @@ func TestUnmarshalPolicy(t *testing.T) { "INVALID_POLICY_FIELD": "should fail at policy level" } `, - wantErr: `unknown field "INVALID_POLICY_FIELD"`, + wantErr: `unknown field in policy: "INVALID_POLICY_FIELD"`, }, { name: "disallow-unsupported-fields-autoapprovers-level", @@ -1122,7 +1122,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + wantErr: `unknown field in policy: "INVALID_AUTO_APPROVER_FIELD"`, }, // headscale-admin uses # in some field names to add metadata, so we will ignore // those to ensure it doesnt break. @@ -1181,7 +1181,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "proto"`, + wantErr: `unknown field in policy: "proto"`, }, { name: "protocol-wildcard-not-allowed", @@ -1197,7 +1197,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`, + wantErr: `invalid protocol: use protocol number 0-255 or protocol name`, }, { name: "protocol-case-insensitive-uppercase", @@ -1277,7 +1277,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `leading 0 not permitted in protocol number "0"`, + wantErr: `leading zero not permitted in protocol number: "0"`, }, { name: "protocol-empty-applies-to-tcp-udp-only", @@ -1324,7 +1324,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "icmp" only allows "*"`, }, { name: "protocol-icmp-with-wildcard-port-allowed", @@ -1372,7 +1372,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "gre" only allows "*"`, }, { name: "protocol-tcp-with-specific-port-allowed", @@ -1836,7 +1836,7 @@ func TestResolvePolicy(t *testing.T) { IPv4: ap("100.100.101.103"), }, }, - wantErr: `user with token "invaliduser@" not found`, + wantErr: `user not found: token "invaliduser@"`, }, { name: "invalid-tag", @@ -1999,7 +1999,7 @@ func TestResolvePolicy(t *testing.T) { { name: "autogroup-invalid", toResolve: new(AutoGroup("autogroup:invalid")), - wantErr: "unknown autogroup", + wantErr: "invalid autogroup", }, } @@ -2670,7 +2670,7 @@ func TestNodeCanHaveTag(t *testing.T) { node: nodes[0], tag: "tag:test", want: false, - wantErr: "Username has to contain @", + wantErr: "username must contain @", }, { name: "node-cannot-have-tag", @@ -3248,7 +3248,8 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { _, err := unmarshalPolicy([]byte(policyJSON)) require.Error(t, err) - assert.Contains(t, err.Error(), `invalid action "deny"`) + assert.Contains(t, err.Error(), `invalid action`) + assert.Contains(t, err.Error(), `deny`) } // Helper function to parse aliases for testing. diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index 80de52bc..3fb0d38b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -9,6 +9,18 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for port and destination parsing. +var ( + ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port") + ErrInputStartsWithColon = errors.New("input cannot start with a colon character") + ErrInputEndsWithColon = errors.New("input cannot end with a colon character") + ErrInvalidPortRange = errors.New("invalid port range format") + ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port") + ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard") + ErrInvalidPortNumber = errors.New("invalid port number") + ErrPortOutOfRange = errors.New("port number out of range") +) + // splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. func splitDestinationAndPort(input string) (string, string, error) { // Find the last occurrence of the colon character @@ -16,15 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) { // Check if the colon character is present and not at the beginning or end of the string if lastColonIndex == -1 { - return "", "", errors.New("input must contain a colon character separating destination and port") + return "", "", ErrInputMissingColon } if lastColonIndex == 0 { - return "", "", errors.New("input cannot start with a colon character") + return "", "", ErrInputStartsWithColon } if lastColonIndex == len(input)-1 { - return "", "", errors.New("input cannot end with a colon character") + return "", "", ErrInputEndsWithColon } // Split the string into destination and port based on the last colon @@ -52,7 +64,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { return e == "" }) if len(rangeParts) != 2 { - return nil, errors.New("invalid port range format") + return nil, ErrInvalidPortRange } first, err := parsePort(rangeParts[0]) @@ -66,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if first > last { - return nil, errors.New("invalid port range: first port is greater than last port") + return nil, ErrPortRangeInverted } portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) @@ -77,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if port < 1 { - return nil, errors.New("first port must be >0, or use '*' for wildcard") + return nil, ErrPortMustBePositive } portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) @@ -91,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { func parsePort(portStr string) (uint16, error) { port, err := strconv.Atoi(portStr) if err != nil { - return 0, errors.New("invalid port number") + return 0, ErrInvalidPortNumber } if port < 0 || port > 65535 { - return 0, errors.New("port number out of range") + return 0, ErrPortOutOfRange } return uint16(port), nil diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index a845e7a9..2ce95537 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) { {"tag:api-server:443", "tag:api-server", "443", nil}, {"example-host-1:*", "example-host-1", "*", nil}, {"hostname:80-90", "hostname", "80-90", nil}, - {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, - {":invalid", "", "", errors.New("input cannot start with a colon character")}, - {"invalid:", "", "", errors.New("input cannot end with a colon character")}, + {"invalidinput", "", "", ErrInputMissingColon}, + {":invalid", "", "", ErrInputStartsWithColon}, + {"invalid:", "", "", ErrInputEndsWithColon}, } for _, testCase := range testCases { dst, port, err := splitDestinationAndPort(testCase.input) - if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) { + if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) { t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)", testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr) } @@ -42,27 +42,23 @@ func TestParsePort(t *testing.T) { tests := []struct { input string expected uint16 - err string + err error }{ - {"80", 80, ""}, - {"0", 0, ""}, - {"65535", 65535, ""}, - {"-1", 0, "port number out of range"}, - {"65536", 0, "port number out of range"}, - {"abc", 0, "invalid port number"}, - {"", 0, "invalid port number"}, + {"80", 80, nil}, + {"0", 0, nil}, + {"65535", 65535, nil}, + {"-1", 0, ErrPortOutOfRange}, + {"65536", 0, ErrPortOutOfRange}, + {"abc", 0, ErrInvalidPortNumber}, + {"", 0, ErrInvalidPortNumber}, } for _, test := range tests { result, err := parsePort(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) - } - if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -73,32 +69,28 @@ func TestParsePortRange(t *testing.T) { tests := []struct { input string expected []tailcfg.PortRange - err string + err error }{ - {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, - {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, - {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, - {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, - {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, - {"80-", nil, "invalid port range format"}, - {"-90", nil, "invalid port range format"}, - {"80-90,", nil, "invalid port number"}, - {"80,90-", nil, "invalid port range format"}, - {"80-90,abc", nil, "invalid port number"}, - {"80-90,65536", nil, "port number out of range"}, - {"80-90,90-80", nil, "invalid port range: first port is greater than last port"}, + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil}, + {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil}, + {"80-", nil, ErrInvalidPortRange}, + {"-90", nil, ErrInvalidPortRange}, + {"80-90,", nil, ErrInvalidPortNumber}, + {"80,90-", nil, ErrInvalidPortRange}, + {"80-90,abc", nil, ErrInvalidPortNumber}, + {"80-90,65536", nil, ErrPortOutOfRange}, + {"80-90,90-80", nil, ErrPortRangeInverted}, } for _, test := range tests { result, err := parsePortRange(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err) - } - if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) }