diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index c80c2a28..d1374ec5 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -2,7 +2,6 @@ package cli import ( "encoding/json" - "errors" "fmt" "net" "net/http" @@ -19,6 +18,7 @@ const ( errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined") errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined") errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined") + errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined") refreshTTL = 60 * time.Minute ) @@ -69,7 +69,7 @@ func mockOIDC() error { userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return errors.New("MOCKOIDC_USERS not defined") + return errMockOidcUsersNotDefined } var users []mockoidc.MockUser diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 9f0954c6..084548a9 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -14,6 +14,12 @@ import ( "google.golang.org/grpc/status" ) +// Sentinel errors for CLI commands. +var ( + ErrNameOrIDRequired = errors.New("--name or --identifier flag is required") + ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID") +) + func usernameAndIDFlag(cmd *cobra.Command) { cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)") cmd.Flags().StringP("name", "n", "", "Username") @@ -26,10 +32,9 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { - err := errors.New("--name or --identifier flag is required") ErrorOutput( - err, - "Cannot rename user: "+status.Convert(err).Message(), + ErrNameOrIDRequired, + "Cannot rename user: "+status.Convert(ErrNameOrIDRequired).Message(), "", ) } @@ -149,7 +154,7 @@ var destroyUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), @@ -277,7 +282,7 @@ var renameUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 81f1d729..698e9d54 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -27,6 +27,7 @@ var ( ErrTestFailed = errors.New("test failed") ErrUnexpectedContainerWait = errors.New("unexpected end of container wait") ErrNoDockerContext = errors.New("no docker context found") + ErrMemoryLimitExceeded = errors.New("container exceeded memory limits") ) // runTestContainer executes integration tests in a Docker container. @@ -151,7 +152,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) } - return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, len(violations)) } } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index e80ee8d1..00a6cc4f 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -18,6 +18,9 @@ import ( "github.com/docker/docker/client" ) +// Sentinel errors for stats collection. +var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") + // ContainerStats represents statistics for a single container. type ContainerStats struct { ContainerID string @@ -63,7 +66,7 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver defer sc.mutex.Unlock() if sc.collectionStarted { - return errors.New("stats collection already started") + return ErrStatsCollectionAlreadyStarted } sc.collectionStarted = true diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 7c818a75..04f9a621 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -36,6 +36,7 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") + ErrNameNotUnique = errors.New("name is not unique") ) // ListPeers returns peers of node, regardless of any Policy or if the node is expired. @@ -288,7 +289,7 @@ func RenameNode(tx *gorm.DB, } if count > 0 { - return errors.New("name is not unique") + return ErrNameNotUnique } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -670,7 +671,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) Hostname: nodeName, UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: new(pak.ID), + AuthKeyID: &pak.ID, } err = hsdb.DB.Save(node).Error diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 102c0e9c..b1d294ea 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -3,12 +3,20 @@ package db import ( "context" "encoding" + "errors" "fmt" "reflect" "gorm.io/gorm/schema" ) +// Sentinel errors for text serialisation. +var ( + ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value") + ErrUnsupportedType = errors.New("unsupported type") + ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") +) + // Got from https://github.com/xdg-go/strum/blob/main/types.go var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() @@ -49,7 +57,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect case string: bytes = []byte(v) default: - return fmt.Errorf("failed to unmarshal text value: %#v", dbValue) + return fmt.Errorf("%w: %#v", ErrTextUnmarshalFailed, dbValue) } if isTextUnmarshaler(fieldValue) { @@ -75,7 +83,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect return nil } else { - return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + return fmt.Errorf("%w: %T", ErrUnsupportedType, fieldValue.Interface()) } } @@ -99,6 +107,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec return string(b), nil default: - return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v) + return nil, fmt.Errorf("%w, got %T", ErrTextMarshalerOnly, v) } } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 650dbd49..be073999 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,6 +15,8 @@ var ( ErrUserExists = errors.New("user already exists") ErrUserNotFound = errors.New("user not found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrTooManyWhereArgs = errors.New("expect 0 or 1 where User structs") + ErrMultipleUsers = errors.New("expected exactly one user") ) func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { @@ -153,7 +155,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { // ListUsers gets all the existing users. func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { if len(where) > 1 { - return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + return nil, fmt.Errorf("%w, got %d", ErrTooManyWhereArgs, len(where)) } var user *types.User @@ -182,7 +184,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { } if len(users) != 1 { - return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + return nil, fmt.Errorf("%w, found %d", ErrMultipleUsers, len(users)) } return &users[0], nil diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 5d16c675..7cd88abe 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "os" "sync" @@ -15,6 +16,9 @@ import ( "tailscale.com/util/set" ) +// Sentinel errors for extra records. +var ErrPathIsDirectory = errors.New("path is a directory, only file is supported") + type ExtraRecordsMan struct { mu sync.RWMutex records set.Set[tailcfg.DNSRecord] @@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) { } if fi.IsDir() { - return nil, fmt.Errorf("path is a directory, only file is supported: %s", path) + return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path) } records, hash, err := readExtraRecordsFromPath(path) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 869fe3f3..d8e83154 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -31,6 +31,9 @@ const ( earlyPayloadMagic = "\xff\xff\xffTS" ) +// Sentinel errors for noise server. +var ErrUnsupportedClientVersion = errors.New("unsupported client version") + type noiseServer struct { headscale *Headscale @@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler( } func unsupportedClientError(version tailcfg.CapabilityVersion) error { - return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version) + return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version) } func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index efce647d..82cf9d58 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -13,6 +13,9 @@ import ( "tailscale.com/types/logger" ) +// Sentinel errors for tailsql service. +var ErrNoCertDomains = errors.New("no cert domains available for HTTPS") + func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { opts := tailsql.Options{ Hostname: "tailsql-headscale", @@ -71,7 +74,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return errors.New("no cert domains available for HTTPS") + return ErrNoCertDomains } base := "https://" + certDomains[0] go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index be3756a0..e0f4fcdd 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -20,7 +20,10 @@ const ( DatabaseSqlite = "sqlite3" ) -var ErrCannotParsePrefix = errors.New("cannot parse prefix") +var ( + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidRegIDLength = errors.New("registration ID has invalid length") +) type StateUpdateType int @@ -175,7 +178,7 @@ func MustRegistrationID() RegistrationID { func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { - return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + return "", fmt.Errorf("%w: expected %d characters", ErrInvalidRegIDLength, RegistrationIDLength) } return RegistrationID(str), nil diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index fffe166d..e947e104 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -33,10 +33,12 @@ const ( ) var ( - errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") - errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") - errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") - errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy") ) type IPAllocationStrategy string @@ -929,7 +931,7 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, errNoPrefixConfigured } allocStr := viper.GetString("prefixes.allocation") @@ -941,7 +943,8 @@ func LoadServerConfig() (*Config, error) { alloc = IPAllocationStrategyRandom default: return nil, fmt.Errorf( - "config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", + "%w: %q, allowed options: %s, %s", + errInvalidAllocationStrategy, allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 5140bc44..ea96284c 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -844,7 +844,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer // GetFQDN returns the fully qualified domain name for the node. func (nv NodeView) GetFQDN(baseDomain string) (string, error) { if !nv.Valid() { - return "", errors.New("failed to create valid FQDN: node view is invalid") + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrInvalidNodeView) } return nv.ж.GetFQDN(baseDomain) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index dbcf4f44..c724c909 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -4,6 +4,7 @@ import ( "cmp" "database/sql" "encoding/json" + "errors" "fmt" "net/mail" "net/url" @@ -18,6 +19,9 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for user types. +var ErrCannotParseBool = errors.New("could not parse value as boolean") + type UserID uint64 type Users []User @@ -224,7 +228,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { *bit = FlexibleBoolean(pv) default: - return fmt.Errorf("could not parse %v as boolean", v) + return fmt.Errorf("%w: %v", ErrCannotParseBool, v) } return nil diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index dcd58528..bc48f592 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -26,6 +26,21 @@ var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var ErrInvalidHostName = errors.New("invalid hostname") +// Sentinel errors for username validation. +var ( + ErrUsernameTooShort = errors.New("username must be at least 2 characters long") + ErrUsernameMustStartLetter = errors.New("username must start with a letter") + ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") + ErrUsernameInvalidChar = errors.New("username contains invalid character") +) + +// Sentinel errors for hostname validation. +var ( + ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters") + ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen") + ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot") +) + // ValidateUsername checks if a username is valid. // It must be at least 2 characters long, start with a letter, and contain // only letters, numbers, hyphens, dots, and underscores. @@ -34,12 +49,12 @@ var ErrInvalidHostName = errors.New("invalid hostname") func ValidateUsername(username string) error { // Ensure the username meets the minimum length requirement if len(username) < 2 { - return errors.New("username must be at least 2 characters long") + return ErrUsernameTooShort } // Ensure the username starts with a letter if !unicode.IsLetter(rune(username[0])) { - return errors.New("username must start with a letter") + return ErrUsernameMustStartLetter } atCount := 0 @@ -55,10 +70,10 @@ func ValidateUsername(username string) error { case char == '@': atCount++ if atCount > 1 { - return errors.New("username cannot contain more than one '@'") + return ErrUsernameTooManyAt } default: - return fmt.Errorf("username contains invalid character: '%c'", char) + return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char) } } @@ -70,10 +85,7 @@ func ValidateUsername(username string) error { // The hostname must already be lowercase and contain only valid characters. func ValidateHostname(name string) error { if len(name) < 2 { - return fmt.Errorf( - "hostname %q is too short, must be at least 2 characters", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameTooShort, name) } if len(name) > LabelHostnameLength { return fmt.Errorf( @@ -90,17 +102,11 @@ func ValidateHostname(name string) error { } if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { - return fmt.Errorf( - "hostname %q cannot start or end with a hyphen", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameHyphenEnds, name) } if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { - return fmt.Errorf( - "hostname %q cannot start or end with a dot", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameDotEnds, name) } if invalidDNSRegex.MatchString(name) { diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 53189656..b4ca0c51 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -16,6 +16,19 @@ import ( "tailscale.com/util/cmpver" ) +// Sentinel errors for URL parsing. +var ( + ErrMultipleURLsFound = errors.New("multiple URLs found") + ErrNoURLFound = errors.New("no URL found") +) + +// Sentinel errors for traceroute parsing. +var ( + ErrTracerouteEmpty = errors.New("empty traceroute output") + ErrTracerouteHeader = errors.New("parsing traceroute header") + ErrTracerouteNotReached = errors.New("traceroute did not reach target") +) + func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { if cmpver.Compare(minimum, toCheck) <= 0 || toCheck == "unstable" || @@ -37,7 +50,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { line = strings.TrimSpace(line) if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { if urlStr != "" { - return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) + return nil, fmt.Errorf("%w: %s and %s", ErrMultipleURLsFound, urlStr, line) } urlStr = line @@ -45,7 +58,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { } if urlStr == "" { - return nil, errors.New("no URL found") + return nil, ErrNoURLFound } loginURL, err := url.Parse(urlStr) @@ -91,7 +104,7 @@ type Traceroute struct { func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { - return Traceroute{}, errors.New("empty traceroute output") + return Traceroute{}, ErrTracerouteEmpty } // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) @@ -99,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { headerMatches := headerRegex.FindStringSubmatch(lines[0]) if len(headerMatches) < 2 { - return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeader, lines[0]) } hostname := headerMatches[1] @@ -255,7 +268,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // If we didn't reach the target, it's unsuccessful if !result.Success { - result.Err = errors.New("traceroute did not reach target") + result.Err = ErrTracerouteNotReached } return result, nil diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index ec72250e..22788bff 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1,7 +1,6 @@ package util import ( - "errors" "net/netip" "strings" "testing" @@ -322,7 +321,7 @@ func TestParseTraceroute(t *testing.T) { }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, }, @@ -490,7 +489,7 @@ over a maximum of 30 hops: }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, },