From 3843036d13947db7f8b91a03620304d780b2f4a3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:13:52 +0000 Subject: [PATCH] all: use context-aware methods for exec, database, and HTTP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct calls with context-aware versions: - exec.Command → exec.CommandContext - db.Exec → db.ExecContext - db.Ping → db.PingContext - db.QueryRow → db.QueryRowContext - http.NewRequest → http.NewRequestWithContext - net.LookupIP → net.DefaultResolver.LookupIPAddr --- cmd/hi/docker.go | 2 +- cmd/hi/doctor.go | 6 +++--- hscontrol/db/db.go | 2 +- hscontrol/db/db_test.go | 5 +++-- hscontrol/db/sqliteconfig/integration_test.go | 17 +++++++++-------- hscontrol/derp/server/derp_server.go | 6 +++--- integration/api_auth_test.go | 9 +++++---- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index fbc2dba6..81f1d729 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -475,7 +475,7 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { - cmd := exec.Command("docker", "context", "inspect") + cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect") output, err := cmd.Output() if err != nil { diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 8ebda159..2bfc41fd 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -265,7 +265,7 @@ func checkGoInstallation() DoctorResult { } } - cmd := exec.Command("go", "version") + cmd := exec.CommandContext(context.Background(), "go", "version") output, err := cmd.Output() if err != nil { @@ -287,7 +287,7 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. func checkGitRepository() DoctorResult { - cmd := exec.Command("git", "rev-parse", "--git-dir") + cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir") err := cmd.Run() if err != nil { @@ -320,7 +320,7 @@ func checkRequiredFiles() DoctorResult { var missingFiles []string for _, file := range requiredFiles { - cmd := exec.Command("test", "-e", file) + cmd := exec.CommandContext(context.Background(), "test", "-e", file) if err := cmd.Run(); err != nil { missingFiles = append(missingFiles, file) } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 1ef767ce..ff9379c1 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -1035,7 +1035,7 @@ func (hsdb *HSDatabase) Close() error { } if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { - _, _ = db.Exec("VACUUM") + _, _ = db.ExecContext(context.Background(), "VACUUM") } return db.Close() diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 47a527b9..f93b9ef8 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "os" "os/exec" @@ -177,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { return err } - _, err = db.Exec(string(schemaContent)) + _, err = db.ExecContext(context.Background(), string(schemaContent)) return err } @@ -322,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { } // Construct the pg_restore command - cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) // Set the output streams cmd.Stdout = os.Stdout diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index b411daeb..fa39f958 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -1,6 +1,7 @@ package sqliteconfig import ( + "context" "database/sql" "path/filepath" "strings" @@ -101,7 +102,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection - if err := db.Ping(); err != nil { + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -112,7 +113,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { query := "PRAGMA " + pragma - err := db.QueryRow(query).Scan(&actualValue) + err := db.QueryRowContext(context.Background(), query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) } @@ -180,23 +181,23 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` - if _, err := db.Exec(schema); err != nil { + if _, err := db.ExecContext(context.Background(), schema); err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record - if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { t.Fatalf("Failed to insert parent: %v", err) } // Test 1: Valid foreign key should work - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + _, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") if err != nil { t.Fatalf("Valid foreign key insert failed: %v", err) } // Test 2: Invalid foreign key should fail - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + _, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") if err == nil { t.Error("Expected foreign key constraint violation, but insert succeeded") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -206,7 +207,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } // Test 3: Deleting referenced parent should fail - _, err = db.Exec("DELETE FROM parent WHERE id = 1") + _, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1") if err == nil { t.Error("Expected foreign key constraint violation when deleting referenced parent") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -252,7 +253,7 @@ func TestJournalModeValidation(t *testing.T) { var actualMode string - err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) } diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index bf292d03..562061e2 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -99,12 +99,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { - ips, err := net.LookupIP(host) + addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host) if err != nil { log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host) - } else if len(ips) > 0 { + } else if len(addrs) > 0 { // Use the first IP address - ipStr := ips[0].String() + ipStr := addrs[0].IP.String() log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr) host = ipStr } diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 825f3d17..ed4a1f4d 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -78,7 +79,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_NoAuthHeader", func(t *testing.T) { // Test 1: Request without any Authorization header // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) resp, err := client.Do(req) @@ -130,7 +131,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "InvalidToken") @@ -164,7 +165,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Test 3: Request with Bearer prefix but invalid token // Expected: Should return 401 with ONLY "Unauthorized" text, no user data // Note: Both malformed and properly formatted invalid tokens should return 401 - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer invalid-token-12345") @@ -197,7 +198,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_ValidAPIKey", func(t *testing.T) { // Test 4: Request with valid API key // Expected: Should return 200 with user data (this is the authorized case) - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+validAPIKey)