cli,hscontrol: use ID-based preauthkey operations

This commit is contained in:
Kristoffer Dalby 2026-01-07 13:36:51 +01:00 committed by Kristoffer Dalby
parent 8631581852
commit 1325fd8b27
6 changed files with 43 additions and 112 deletions

View file

@ -20,17 +20,6 @@ const (
func init() {
rootCmd.AddCommand(preauthkeysCmd)
preauthkeysCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User")
pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
pakNamespaceFlag.Deprecated = deprecateNamespaceMessage
pakNamespaceFlag.Hidden = true
err := preauthkeysCmd.MarkPersistentFlagRequired("user")
if err != nil {
log.Fatal().Err(err).Msg("")
}
preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
@ -43,6 +32,9 @@ func init() {
StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)")
createPreAuthKeyCmd.Flags().
StringSlice("tags", []string{}, "Tags to automatically assign to node")
createPreAuthKeyCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
expirePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID")
deletePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID")
}
var preauthkeysCmd = &cobra.Command{
@ -53,25 +45,16 @@ var preauthkeysCmd = &cobra.Command{
var listPreAuthKeys = &cobra.Command{
Use: "list",
Short: "List the preauthkeys for this user",
Short: "List all preauthkeys",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
User: user,
}
response, err := client.ListPreAuthKeys(ctx, request)
response, err := client.ListPreAuthKeys(ctx, &v1.ListPreAuthKeysRequest{})
if err != nil {
ErrorOutput(
err,
@ -137,16 +120,12 @@ var listPreAuthKeys = &cobra.Command{
var createPreAuthKeyCmd = &cobra.Command{
Use: "create",
Short: "Creates a new preauthkey in the specified user",
Short: "Creates a new preauthkey",
Aliases: []string{"c", "new"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
user, _ := cmd.Flags().GetUint64("user")
reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
tags, _ := cmd.Flags().GetStringSlice("tags")
@ -195,7 +174,7 @@ var createPreAuthKeyCmd = &cobra.Command{
}
var expirePreAuthKeyCmd = &cobra.Command{
Use: "expire KEY",
Use: "expire",
Short: "Expire a preauthkey",
Aliases: []string{"revoke", "exp", "e"},
Args: func(cmd *cobra.Command, args []string) error {
@ -207,18 +186,14 @@ var expirePreAuthKeyCmd = &cobra.Command{
},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
id, _ := cmd.Flags().GetUint64("id")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ExpirePreAuthKeyRequest{
User: user,
Key: args[0],
Id: id,
}
response, err := client.ExpirePreAuthKey(ctx, request)
@ -235,7 +210,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
}
var deletePreAuthKeyCmd = &cobra.Command{
Use: "delete KEY",
Use: "delete",
Short: "Delete a preauthkey",
Aliases: []string{"del", "rm", "d"},
Args: func(cmd *cobra.Command, args []string) error {
@ -247,18 +222,14 @@ var deletePreAuthKeyCmd = &cobra.Command{
},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
id, _ := cmd.Flags().GetUint64("id")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.DeletePreAuthKeyRequest{
User: user,
Key: args[0],
Id: id,
}
response, err := client.DeletePreAuthKey(ctx, request)

View file

@ -57,10 +57,6 @@ func CreatePreAuthKey(
return nil, ErrPreAuthKeyNotTaggedOrOwned
}
// If uid != nil && len(aclTags) > 0:
// Both are allowed: UserID tracks "created by", tags define node ownership
// This is valid per the new model
var (
user *types.User
userID *uint
@ -158,22 +154,17 @@ func CreatePreAuthKey(
}, nil
}
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeysByUser(rx, uid)
return ListPreAuthKeys(rx)
})
}
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
user, err := GetUserByID(tx, uid)
if err != nil {
return nil, err
}
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) {
var keys []types.PreAuthKey
keys := []types.PreAuthKey{}
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
err := tx.Preload("User").Find(&keys).Error
if err != nil {
return nil, err
}
@ -298,34 +289,35 @@ func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. This also clears the auth_key_id on any nodes that reference
// this key.
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
return tx.Transaction(func(db *gorm.DB) error {
// First, clear the foreign key reference on any nodes using this key
err := db.Model(&types.Node{}).
Where("auth_key_id = ?", pak.ID).
Where("auth_key_id = ?", id).
Update("auth_key_id", nil).Error
if err != nil {
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
}
// Then delete the pre-auth key
if result := db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error
if err != nil {
return err
}
return nil
})
}
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error {
return hsdb.Write(func(tx *gorm.DB) error {
return ExpirePreAuthKey(tx, k)
return ExpirePreAuthKey(tx, id)
})
}
func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error {
func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DestroyPreAuthKey(tx, *k)
return DestroyPreAuthKey(tx, id)
})
}
@ -341,7 +333,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
}

View file

@ -41,7 +41,7 @@ func TestCreatePreAuthKey(t *testing.T) {
assert.NotEmpty(t, key.Key)
// List keys for the user
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
keys, err := db.ListPreAuthKeys()
require.NoError(t, err)
assert.Len(t, keys, 1)
@ -49,15 +49,6 @@ func TestCreatePreAuthKey(t *testing.T) {
assert.Equal(t, user.ID, keys[0].User.ID)
},
},
{
name: "error_list_invalid_user_id",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err := db.ListPreAuthKeys(1000000)
assert.Error(t, err)
},
},
}
for _, tt := range tests {
@ -101,7 +92,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
require.NoError(t, err)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
listedPaks, err := db.ListPreAuthKeys()
require.NoError(t, err)
require.Len(t, listedPaks, 1)

View file

@ -58,12 +58,12 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
return ErrUserStillHasNodes
}
keys, err := ListPreAuthKeysByUser(tx, uid)
keys, err := ListPreAuthKeys(tx)
if err != nil {
return err
}
for _, key := range keys {
err = DestroyPreAuthKey(tx, key)
err = DestroyPreAuthKey(tx, key.ID)
if err != nil {
return err
}

View file

@ -184,16 +184,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
if err != nil {
return nil, err
}
if uint64(preAuthKey.User.ID) != request.GetUser() {
return nil, fmt.Errorf("preauth key does not belong to user")
}
err = api.h.state.ExpirePreAuthKey(preAuthKey)
err := api.h.state.ExpirePreAuthKey(request.GetId())
if err != nil {
return nil, err
}
@ -205,16 +196,7 @@ func (api headscaleV1APIServer) DeletePreAuthKey(
ctx context.Context,
request *v1.DeletePreAuthKeyRequest,
) (*v1.DeletePreAuthKeyResponse, error) {
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
if err != nil {
return nil, err
}
if uint64(preAuthKey.User.ID) != request.GetUser() {
return nil, fmt.Errorf("preauth key does not belong to user")
}
err = api.h.state.DeletePreAuthKey(preAuthKey)
err := api.h.state.DeletePreAuthKey(request.GetId())
if err != nil {
return nil, err
}
@ -226,12 +208,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context,
request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
preAuthKeys, err := api.h.state.ListPreAuthKeys()
if err != nil {
return nil, err
}

View file

@ -1036,18 +1036,18 @@ func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) {
}
// ListPreAuthKeys returns all pre-authentication keys for a user.
func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) {
return s.db.ListPreAuthKeys(userID)
func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return s.db.ListPreAuthKeys()
}
// ExpirePreAuthKey marks a pre-authentication key as expired.
func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error {
return s.db.ExpirePreAuthKey(preAuthKey)
func (s *State) ExpirePreAuthKey(id uint64) error {
return s.db.ExpirePreAuthKey(id)
}
// DeletePreAuthKey permanently deletes a pre-authentication key.
func (s *State) DeletePreAuthKey(preAuthKey *types.PreAuthKey) error {
return s.db.DeletePreAuthKey(preAuthKey)
func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.