From 1325fd8b271c4cf0fa274e10541775e4519cb236 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 7 Jan 2026 13:36:51 +0100 Subject: [PATCH] cli,hscontrol: use ID-based preauthkey operations --- cmd/headscale/cli/preauthkeys.go | 55 ++++++++----------------------- hscontrol/db/preauth_keys.go | 42 ++++++++++------------- hscontrol/db/preauth_keys_test.go | 13 ++------ hscontrol/db/users.go | 4 +-- hscontrol/grpcv1.go | 29 ++-------------- hscontrol/state/state.go | 12 +++---- 6 files changed, 43 insertions(+), 112 deletions(-) diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index 4762c10e..6acb346d 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -20,17 +20,6 @@ const ( func init() { rootCmd.AddCommand(preauthkeysCmd) - preauthkeysCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)") - - preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User") - pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace") - pakNamespaceFlag.Deprecated = deprecateNamespaceMessage - pakNamespaceFlag.Hidden = true - - err := preauthkeysCmd.MarkPersistentFlagRequired("user") - if err != nil { - log.Fatal().Err(err).Msg("") - } preauthkeysCmd.AddCommand(listPreAuthKeys) preauthkeysCmd.AddCommand(createPreAuthKeyCmd) preauthkeysCmd.AddCommand(expirePreAuthKeyCmd) @@ -43,6 +32,9 @@ func init() { StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)") createPreAuthKeyCmd.Flags(). StringSlice("tags", []string{}, "Tags to automatically assign to node") + createPreAuthKeyCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)") + expirePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID") + deletePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID") } var preauthkeysCmd = &cobra.Command{ @@ -53,25 +45,16 @@ var preauthkeysCmd = &cobra.Command{ var listPreAuthKeys = &cobra.Command{ Use: "list", - Short: "List the preauthkeys for this user", + Short: "List all preauthkeys", Aliases: []string{"ls", "show"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetUint64("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - request := &v1.ListPreAuthKeysRequest{ - User: user, - } - - response, err := client.ListPreAuthKeys(ctx, request) + response, err := client.ListPreAuthKeys(ctx, &v1.ListPreAuthKeysRequest{}) if err != nil { ErrorOutput( err, @@ -137,16 +120,12 @@ var listPreAuthKeys = &cobra.Command{ var createPreAuthKeyCmd = &cobra.Command{ Use: "create", - Short: "Creates a new preauthkey in the specified user", + Short: "Creates a new preauthkey", Aliases: []string{"c", "new"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetUint64("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - } - + user, _ := cmd.Flags().GetUint64("user") reusable, _ := cmd.Flags().GetBool("reusable") ephemeral, _ := cmd.Flags().GetBool("ephemeral") tags, _ := cmd.Flags().GetStringSlice("tags") @@ -195,7 +174,7 @@ var createPreAuthKeyCmd = &cobra.Command{ } var expirePreAuthKeyCmd = &cobra.Command{ - Use: "expire KEY", + Use: "expire", Short: "Expire a preauthkey", Aliases: []string{"revoke", "exp", "e"}, Args: func(cmd *cobra.Command, args []string) error { @@ -207,18 +186,14 @@ var expirePreAuthKeyCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetUint64("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - } + id, _ := cmd.Flags().GetUint64("id") ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.ExpirePreAuthKeyRequest{ - User: user, - Key: args[0], + Id: id, } response, err := client.ExpirePreAuthKey(ctx, request) @@ -235,7 +210,7 @@ var expirePreAuthKeyCmd = &cobra.Command{ } var deletePreAuthKeyCmd = &cobra.Command{ - Use: "delete KEY", + Use: "delete", Short: "Delete a preauthkey", Aliases: []string{"del", "rm", "d"}, Args: func(cmd *cobra.Command, args []string) error { @@ -247,18 +222,14 @@ var deletePreAuthKeyCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetUint64("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - } + id, _ := cmd.Flags().GetUint64("id") ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.DeletePreAuthKeyRequest{ - User: user, - Key: args[0], + Id: id, } response, err := client.DeletePreAuthKey(ctx, request) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 105495f1..a7b848c7 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -57,10 +57,6 @@ func CreatePreAuthKey( return nil, ErrPreAuthKeyNotTaggedOrOwned } - // If uid != nil && len(aclTags) > 0: - // Both are allowed: UserID tracks "created by", tags define node ownership - // This is valid per the new model - var ( user *types.User userID *uint @@ -158,22 +154,17 @@ func CreatePreAuthKey( }, nil } -func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - return ListPreAuthKeysByUser(rx, uid) + return ListPreAuthKeys(rx) }) } -// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. -func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { - user, err := GetUserByID(tx, uid) - if err != nil { - return nil, err - } +// ListPreAuthKeys returns the list of PreAuthKeys for a user. +func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) { + var keys []types.PreAuthKey - keys := []types.PreAuthKey{} - - err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error + err := tx.Preload("User").Find(&keys).Error if err != nil { return nil, err } @@ -298,34 +289,35 @@ func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. This also clears the auth_key_id on any nodes that reference // this key. -func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { +func DestroyPreAuthKey(tx *gorm.DB, id uint64) error { return tx.Transaction(func(db *gorm.DB) error { // First, clear the foreign key reference on any nodes using this key err := db.Model(&types.Node{}). - Where("auth_key_id = ?", pak.ID). + Where("auth_key_id = ?", id). Update("auth_key_id", nil).Error if err != nil { return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err) } // Then delete the pre-auth key - if result := db.Unscoped().Delete(pak); result.Error != nil { - return result.Error + err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error + if err != nil { + return err } return nil }) } -func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { +func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error { return hsdb.Write(func(tx *gorm.DB) error { - return ExpirePreAuthKey(tx, k) + return ExpirePreAuthKey(tx, id) }) } -func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error { +func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error { return hsdb.Write(func(tx *gorm.DB) error { - return DestroyPreAuthKey(tx, *k) + return DestroyPreAuthKey(tx, id) }) } @@ -341,7 +333,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { } // MarkExpirePreAuthKey marks a PreAuthKey as expired. -func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { +func ExpirePreAuthKey(tx *gorm.DB, id uint64) error { now := time.Now() - return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error + return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 643b579c..7c5dcbd7 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -41,7 +41,7 @@ func TestCreatePreAuthKey(t *testing.T) { assert.NotEmpty(t, key.Key) // List keys for the user - keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) + keys, err := db.ListPreAuthKeys() require.NoError(t, err) assert.Len(t, keys, 1) @@ -49,15 +49,6 @@ func TestCreatePreAuthKey(t *testing.T) { assert.Equal(t, user.ID, keys[0].User.ID) }, }, - { - name: "error_list_invalid_user_id", - test: func(t *testing.T, db *HSDatabase) { - t.Helper() - - _, err := db.ListPreAuthKeys(1000000) - assert.Error(t, err) - }, - }, } for _, tt := range tests { @@ -101,7 +92,7 @@ func TestPreAuthKeyACLTags(t *testing.T) { _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate) require.NoError(t, err) - listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) + listedPaks, err := db.ListPreAuthKeys() require.NoError(t, err) require.Len(t, listedPaks, 1) diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 92c3292d..6aff9ed1 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -58,12 +58,12 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error { return ErrUserStillHasNodes } - keys, err := ListPreAuthKeysByUser(tx, uid) + keys, err := ListPreAuthKeys(tx) if err != nil { return err } for _, key := range keys { - err = DestroyPreAuthKey(tx, key) + err = DestroyPreAuthKey(tx, key.ID) if err != nil { return err } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 6c384201..f928fde2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -184,16 +184,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.state.GetPreAuthKey(request.Key) - if err != nil { - return nil, err - } - - if uint64(preAuthKey.User.ID) != request.GetUser() { - return nil, fmt.Errorf("preauth key does not belong to user") - } - - err = api.h.state.ExpirePreAuthKey(preAuthKey) + err := api.h.state.ExpirePreAuthKey(request.GetId()) if err != nil { return nil, err } @@ -205,16 +196,7 @@ func (api headscaleV1APIServer) DeletePreAuthKey( ctx context.Context, request *v1.DeletePreAuthKeyRequest, ) (*v1.DeletePreAuthKeyResponse, error) { - preAuthKey, err := api.h.state.GetPreAuthKey(request.Key) - if err != nil { - return nil, err - } - - if uint64(preAuthKey.User.ID) != request.GetUser() { - return nil, fmt.Errorf("preauth key does not belong to user") - } - - err = api.h.state.DeletePreAuthKey(preAuthKey) + err := api.h.state.DeletePreAuthKey(request.GetId()) if err != nil { return nil, err } @@ -226,12 +208,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - user, err := api.h.state.GetUserByID(types.UserID(request.GetUser())) - if err != nil { - return nil, err - } - - preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID)) + preAuthKeys, err := api.h.state.ListPreAuthKeys() if err != nil { return nil, err } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 0a5fabdb..b8181079 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1036,18 +1036,18 @@ func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) { } // ListPreAuthKeys returns all pre-authentication keys for a user. -func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) { - return s.db.ListPreAuthKeys(userID) +func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) { + return s.db.ListPreAuthKeys() } // ExpirePreAuthKey marks a pre-authentication key as expired. -func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error { - return s.db.ExpirePreAuthKey(preAuthKey) +func (s *State) ExpirePreAuthKey(id uint64) error { + return s.db.ExpirePreAuthKey(id) } // DeletePreAuthKey permanently deletes a pre-authentication key. -func (s *State) DeletePreAuthKey(preAuthKey *types.PreAuthKey) error { - return s.db.DeletePreAuthKey(preAuthKey) +func (s *State) DeletePreAuthKey(id uint64) error { + return s.db.DeletePreAuthKey(id) } // GetRegistrationCacheEntry retrieves a node registration from cache.