mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-22 18:18:00 +00:00
cli,hscontrol: use ID-based preauthkey operations
This commit is contained in:
parent
8631581852
commit
1325fd8b27
6 changed files with 43 additions and 112 deletions
|
|
@ -20,17 +20,6 @@ const (
|
|||
|
||||
func init() {
|
||||
rootCmd.AddCommand(preauthkeysCmd)
|
||||
preauthkeysCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
|
||||
|
||||
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User")
|
||||
pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
|
||||
pakNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
pakNamespaceFlag.Hidden = true
|
||||
|
||||
err := preauthkeysCmd.MarkPersistentFlagRequired("user")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
preauthkeysCmd.AddCommand(listPreAuthKeys)
|
||||
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
|
||||
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
|
||||
|
|
@ -43,6 +32,9 @@ func init() {
|
|||
StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)")
|
||||
createPreAuthKeyCmd.Flags().
|
||||
StringSlice("tags", []string{}, "Tags to automatically assign to node")
|
||||
createPreAuthKeyCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
|
||||
expirePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID")
|
||||
deletePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID")
|
||||
}
|
||||
|
||||
var preauthkeysCmd = &cobra.Command{
|
||||
|
|
@ -53,25 +45,16 @@ var preauthkeysCmd = &cobra.Command{
|
|||
|
||||
var listPreAuthKeys = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List the preauthkeys for this user",
|
||||
Short: "List all preauthkeys",
|
||||
Aliases: []string{"ls", "show"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
request := &v1.ListPreAuthKeysRequest{
|
||||
User: user,
|
||||
}
|
||||
|
||||
response, err := client.ListPreAuthKeys(ctx, request)
|
||||
response, err := client.ListPreAuthKeys(ctx, &v1.ListPreAuthKeysRequest{})
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
|
@ -137,16 +120,12 @@ var listPreAuthKeys = &cobra.Command{
|
|||
|
||||
var createPreAuthKeyCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Creates a new preauthkey in the specified user",
|
||||
Short: "Creates a new preauthkey",
|
||||
Aliases: []string{"c", "new"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
|
||||
user, _ := cmd.Flags().GetUint64("user")
|
||||
reusable, _ := cmd.Flags().GetBool("reusable")
|
||||
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
|
||||
tags, _ := cmd.Flags().GetStringSlice("tags")
|
||||
|
|
@ -195,7 +174,7 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
var expirePreAuthKeyCmd = &cobra.Command{
|
||||
Use: "expire KEY",
|
||||
Use: "expire",
|
||||
Short: "Expire a preauthkey",
|
||||
Aliases: []string{"revoke", "exp", "e"},
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
|
|
@ -207,18 +186,14 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
id, _ := cmd.Flags().GetUint64("id")
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
request := &v1.ExpirePreAuthKeyRequest{
|
||||
User: user,
|
||||
Key: args[0],
|
||||
Id: id,
|
||||
}
|
||||
|
||||
response, err := client.ExpirePreAuthKey(ctx, request)
|
||||
|
|
@ -235,7 +210,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
var deletePreAuthKeyCmd = &cobra.Command{
|
||||
Use: "delete KEY",
|
||||
Use: "delete",
|
||||
Short: "Delete a preauthkey",
|
||||
Aliases: []string{"del", "rm", "d"},
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
|
|
@ -247,18 +222,14 @@ var deletePreAuthKeyCmd = &cobra.Command{
|
|||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
id, _ := cmd.Flags().GetUint64("id")
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
request := &v1.DeletePreAuthKeyRequest{
|
||||
User: user,
|
||||
Key: args[0],
|
||||
Id: id,
|
||||
}
|
||||
|
||||
response, err := client.DeletePreAuthKey(ctx, request)
|
||||
|
|
|
|||
|
|
@ -57,10 +57,6 @@ func CreatePreAuthKey(
|
|||
return nil, ErrPreAuthKeyNotTaggedOrOwned
|
||||
}
|
||||
|
||||
// If uid != nil && len(aclTags) > 0:
|
||||
// Both are allowed: UserID tracks "created by", tags define node ownership
|
||||
// This is valid per the new model
|
||||
|
||||
var (
|
||||
user *types.User
|
||||
userID *uint
|
||||
|
|
@ -158,22 +154,17 @@ func CreatePreAuthKey(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeysByUser(rx, uid)
|
||||
return ListPreAuthKeys(rx)
|
||||
})
|
||||
}
|
||||
|
||||
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
||||
user, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
var keys []types.PreAuthKey
|
||||
|
||||
keys := []types.PreAuthKey{}
|
||||
|
||||
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
|
||||
err := tx.Preload("User").Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -298,34 +289,35 @@ func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
|
|||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist. This also clears the auth_key_id on any nodes that reference
|
||||
// this key.
|
||||
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
|
||||
func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
return tx.Transaction(func(db *gorm.DB) error {
|
||||
// First, clear the foreign key reference on any nodes using this key
|
||||
err := db.Model(&types.Node{}).
|
||||
Where("auth_key_id = ?", pak.ID).
|
||||
Where("auth_key_id = ?", id).
|
||||
Update("auth_key_id", nil).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
|
||||
}
|
||||
|
||||
// Then delete the pre-auth key
|
||||
if result := db.Unscoped().Delete(pak); result.Error != nil {
|
||||
return result.Error
|
||||
err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return ExpirePreAuthKey(tx, k)
|
||||
return ExpirePreAuthKey(tx, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DestroyPreAuthKey(tx, *k)
|
||||
return DestroyPreAuthKey(tx, id)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -341,7 +333,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
|||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ func TestCreatePreAuthKey(t *testing.T) {
|
|||
assert.NotEmpty(t, key.Key)
|
||||
|
||||
// List keys for the user
|
||||
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
keys, err := db.ListPreAuthKeys()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
|
|
@ -49,15 +49,6 @@ func TestCreatePreAuthKey(t *testing.T) {
|
|||
assert.Equal(t, user.ID, keys[0].User.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error_list_invalid_user_id",
|
||||
test: func(t *testing.T, db *HSDatabase) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.ListPreAuthKeys(1000000)
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -101,7 +92,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
|
|||
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
|
||||
require.NoError(t, err)
|
||||
|
||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
listedPaks, err := db.ListPreAuthKeys()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listedPaks, 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,12 +58,12 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
|||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := ListPreAuthKeysByUser(tx, uid)
|
||||
keys, err := ListPreAuthKeys(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = DestroyPreAuthKey(tx, key)
|
||||
err = DestroyPreAuthKey(tx, key.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,16 +184,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
|||
ctx context.Context,
|
||||
request *v1.ExpirePreAuthKeyRequest,
|
||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uint64(preAuthKey.User.ID) != request.GetUser() {
|
||||
return nil, fmt.Errorf("preauth key does not belong to user")
|
||||
}
|
||||
|
||||
err = api.h.state.ExpirePreAuthKey(preAuthKey)
|
||||
err := api.h.state.ExpirePreAuthKey(request.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -205,16 +196,7 @@ func (api headscaleV1APIServer) DeletePreAuthKey(
|
|||
ctx context.Context,
|
||||
request *v1.DeletePreAuthKeyRequest,
|
||||
) (*v1.DeletePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uint64(preAuthKey.User.ID) != request.GetUser() {
|
||||
return nil, fmt.Errorf("preauth key does not belong to user")
|
||||
}
|
||||
|
||||
err = api.h.state.DeletePreAuthKey(preAuthKey)
|
||||
err := api.h.state.DeletePreAuthKey(request.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -226,12 +208,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
|||
ctx context.Context,
|
||||
request *v1.ListPreAuthKeysRequest,
|
||||
) (*v1.ListPreAuthKeysResponse, error) {
|
||||
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
|
||||
preAuthKeys, err := api.h.state.ListPreAuthKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1036,18 +1036,18 @@ func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) {
|
|||
}
|
||||
|
||||
// ListPreAuthKeys returns all pre-authentication keys for a user.
|
||||
func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) {
|
||||
return s.db.ListPreAuthKeys(userID)
|
||||
func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return s.db.ListPreAuthKeys()
|
||||
}
|
||||
|
||||
// ExpirePreAuthKey marks a pre-authentication key as expired.
|
||||
func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error {
|
||||
return s.db.ExpirePreAuthKey(preAuthKey)
|
||||
func (s *State) ExpirePreAuthKey(id uint64) error {
|
||||
return s.db.ExpirePreAuthKey(id)
|
||||
}
|
||||
|
||||
// DeletePreAuthKey permanently deletes a pre-authentication key.
|
||||
func (s *State) DeletePreAuthKey(preAuthKey *types.PreAuthKey) error {
|
||||
return s.db.DeletePreAuthKey(preAuthKey)
|
||||
func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetRegistrationCacheEntry retrieves a node registration from cache.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue