diff --git a/hscontrol/db/api_key_test.go b/hscontrol/db/api_key_test.go index 5b1f1f1d..a34dd94b 100644 --- a/hscontrol/db/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -246,3 +246,30 @@ func TestAPIKeyWithPrefix(t *testing.T) { }) } } + +func TestGetAPIKeyByID(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Create an API key + _, apiKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + require.NotNil(t, apiKey) + + // Retrieve by ID + retrievedKey, err := db.GetAPIKeyByID(apiKey.ID) + require.NoError(t, err) + require.NotNil(t, retrievedKey) + assert.Equal(t, apiKey.ID, retrievedKey.ID) + assert.Equal(t, apiKey.Prefix, retrievedKey.Prefix) +} + +func TestGetAPIKeyByIDNotFound(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Try to get a non-existent key by ID + key, err := db.GetAPIKeyByID(99999) + require.Error(t, err) + assert.Nil(t, key) +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c969208f..a35a73af 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -577,14 +577,35 @@ func (api headscaleV1APIServer) CreateApiKey( return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil } +// apiKeyIdentifier is implemented by requests that identify an API key. +type apiKeyIdentifier interface { + GetId() uint64 + GetPrefix() string +} + +// getAPIKey retrieves an API key by ID or prefix from the request. +// Returns InvalidArgument if neither or both are provided. +func (api headscaleV1APIServer) getAPIKey(req apiKeyIdentifier) (*types.APIKey, error) { + hasID := req.GetId() != 0 + hasPrefix := req.GetPrefix() != "" + + switch { + case hasID && hasPrefix: + return nil, status.Error(codes.InvalidArgument, "provide either id or prefix, not both") + case hasID: + return api.h.state.GetAPIKeyByID(req.GetId()) + case hasPrefix: + return api.h.state.GetAPIKey(req.GetPrefix()) + default: + return nil, status.Error(codes.InvalidArgument, "must provide id or prefix") + } +} + func (api headscaleV1APIServer) ExpireApiKey( ctx context.Context, request *v1.ExpireApiKeyRequest, ) (*v1.ExpireApiKeyResponse, error) { - var apiKey *types.APIKey - var err error - - apiKey, err = api.h.state.GetAPIKey(request.Prefix) + apiKey, err := api.getAPIKey(request) if err != nil { return nil, err } @@ -622,12 +643,7 @@ func (api headscaleV1APIServer) DeleteApiKey( ctx context.Context, request *v1.DeleteApiKeyRequest, ) (*v1.DeleteApiKeyResponse, error) { - var ( - apiKey *types.APIKey - err error - ) - - apiKey, err = api.h.state.GetAPIKey(request.Prefix) + apiKey, err := api.getAPIKey(request) if err != nil { return nil, err } diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 2c6417ce..4cf5b7d4 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -280,3 +280,189 @@ func TestDeleteUser_ReturnsProperChangeSignal(t *testing.T) { require.NoError(t, err, "DeleteUser should succeed") assert.False(t, changeSignal.IsEmpty(), "DeleteUser should return a non-empty change signal (issue #2967)") } + +// TestExpireApiKey_ByID tests that API keys can be expired by ID. +func TestExpireApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Expire by ID + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is expired (expiration is set to now or in the past) + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + assert.NotNil(t, listResp.GetApiKeys()[0].GetExpiration(), "expiration should be set") +} + +// TestExpireApiKey_ByPrefix tests that API keys can still be expired by prefix. +func TestExpireApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Expire by prefix + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) +} + +// TestDeleteApiKey_ByID tests that API keys can be deleted by ID. +func TestDeleteApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Delete by ID + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestDeleteApiKey_ByPrefix tests that API keys can still be deleted by prefix. +func TestDeleteApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Delete by prefix + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestExpireApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. +func TestExpireApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestDeleteApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. +func TestDeleteApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestExpireApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestExpireApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +} + +// TestDeleteApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestDeleteApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +}