From 5b84ea3be8a4bf5304b8df44ca5e567362941620 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 16 Jan 2026 13:57:49 +0000 Subject: [PATCH] grpc: support expire/delete API keys by ID Update ExpireApiKey and DeleteApiKey handlers to accept either ID or prefix for identifying the API key. Returns InvalidArgument error if neither or both are provided. Add tests for: - Expire by ID - Expire by prefix (backwards compatibility) - Delete by ID - Delete by prefix (backwards compatibility) - Error when neither ID nor prefix provided - Error when both ID and prefix provided Updates #2986 --- hscontrol/db/api_key_test.go | 27 +++++ hscontrol/grpcv1.go | 36 +++++-- hscontrol/grpcv1_test.go | 186 +++++++++++++++++++++++++++++++++++ 3 files changed, 239 insertions(+), 10 deletions(-) diff --git a/hscontrol/db/api_key_test.go b/hscontrol/db/api_key_test.go index 5b1f1f1d..a34dd94b 100644 --- a/hscontrol/db/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -246,3 +246,30 @@ func TestAPIKeyWithPrefix(t *testing.T) { }) } } + +func TestGetAPIKeyByID(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Create an API key + _, apiKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + require.NotNil(t, apiKey) + + // Retrieve by ID + retrievedKey, err := db.GetAPIKeyByID(apiKey.ID) + require.NoError(t, err) + require.NotNil(t, retrievedKey) + assert.Equal(t, apiKey.ID, retrievedKey.ID) + assert.Equal(t, apiKey.Prefix, retrievedKey.Prefix) +} + +func TestGetAPIKeyByIDNotFound(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Try to get a non-existent key by ID + key, err := db.GetAPIKeyByID(99999) + require.Error(t, err) + assert.Nil(t, key) +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c969208f..a35a73af 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -577,14 +577,35 @@ func (api headscaleV1APIServer) CreateApiKey( return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil } +// apiKeyIdentifier is implemented by requests that identify an API key. +type apiKeyIdentifier interface { + GetId() uint64 + GetPrefix() string +} + +// getAPIKey retrieves an API key by ID or prefix from the request. +// Returns InvalidArgument if neither or both are provided. +func (api headscaleV1APIServer) getAPIKey(req apiKeyIdentifier) (*types.APIKey, error) { + hasID := req.GetId() != 0 + hasPrefix := req.GetPrefix() != "" + + switch { + case hasID && hasPrefix: + return nil, status.Error(codes.InvalidArgument, "provide either id or prefix, not both") + case hasID: + return api.h.state.GetAPIKeyByID(req.GetId()) + case hasPrefix: + return api.h.state.GetAPIKey(req.GetPrefix()) + default: + return nil, status.Error(codes.InvalidArgument, "must provide id or prefix") + } +} + func (api headscaleV1APIServer) ExpireApiKey( ctx context.Context, request *v1.ExpireApiKeyRequest, ) (*v1.ExpireApiKeyResponse, error) { - var apiKey *types.APIKey - var err error - - apiKey, err = api.h.state.GetAPIKey(request.Prefix) + apiKey, err := api.getAPIKey(request) if err != nil { return nil, err } @@ -622,12 +643,7 @@ func (api headscaleV1APIServer) DeleteApiKey( ctx context.Context, request *v1.DeleteApiKeyRequest, ) (*v1.DeleteApiKeyResponse, error) { - var ( - apiKey *types.APIKey - err error - ) - - apiKey, err = api.h.state.GetAPIKey(request.Prefix) + apiKey, err := api.getAPIKey(request) if err != nil { return nil, err } diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 2c6417ce..4cf5b7d4 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -280,3 +280,189 @@ func TestDeleteUser_ReturnsProperChangeSignal(t *testing.T) { require.NoError(t, err, "DeleteUser should succeed") assert.False(t, changeSignal.IsEmpty(), "DeleteUser should return a non-empty change signal (issue #2967)") } + +// TestExpireApiKey_ByID tests that API keys can be expired by ID. +func TestExpireApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Expire by ID + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is expired (expiration is set to now or in the past) + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + assert.NotNil(t, listResp.GetApiKeys()[0].GetExpiration(), "expiration should be set") +} + +// TestExpireApiKey_ByPrefix tests that API keys can still be expired by prefix. +func TestExpireApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Expire by prefix + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) +} + +// TestDeleteApiKey_ByID tests that API keys can be deleted by ID. +func TestDeleteApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Delete by ID + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestDeleteApiKey_ByPrefix tests that API keys can still be deleted by prefix. +func TestDeleteApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Delete by prefix + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestExpireApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. +func TestExpireApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestDeleteApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. +func TestDeleteApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestExpireApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestExpireApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +} + +// TestDeleteApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestDeleteApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +}