mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-23 02:24:10 +00:00
all: use context-aware methods for exec, database, and HTTP
Replace direct calls with context-aware versions: - exec.Command → exec.CommandContext - db.Exec → db.ExecContext - db.Ping → db.PingContext - db.QueryRow → db.QueryRowContext - http.NewRequest → http.NewRequestWithContext - net.LookupIP → net.DefaultResolver.LookupIPAddr
This commit is contained in:
parent
676273ee9d
commit
3843036d13
7 changed files with 25 additions and 22 deletions
|
|
@ -475,7 +475,7 @@ func createDockerClient() (*client.Client, error) {
|
|||
|
||||
// getCurrentDockerContext retrieves the current Docker context information.
|
||||
func getCurrentDockerContext() (*DockerContext, error) {
|
||||
cmd := exec.Command("docker", "context", "inspect")
|
||||
cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -265,7 +265,7 @@ func checkGoInstallation() DoctorResult {
|
|||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "version")
|
||||
cmd := exec.CommandContext(context.Background(), "go", "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
|
|
@ -287,7 +287,7 @@ func checkGoInstallation() DoctorResult {
|
|||
|
||||
// checkGitRepository verifies we're in a git repository.
|
||||
func checkGitRepository() DoctorResult {
|
||||
cmd := exec.Command("git", "rev-parse", "--git-dir")
|
||||
cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir")
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
|
|
@ -320,7 +320,7 @@ func checkRequiredFiles() DoctorResult {
|
|||
var missingFiles []string
|
||||
|
||||
for _, file := range requiredFiles {
|
||||
cmd := exec.Command("test", "-e", file)
|
||||
cmd := exec.CommandContext(context.Background(), "test", "-e", file)
|
||||
if err := cmd.Run(); err != nil {
|
||||
missingFiles = append(missingFiles, file)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1035,7 +1035,7 @@ func (hsdb *HSDatabase) Close() error {
|
|||
}
|
||||
|
||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||
_, _ = db.Exec("VACUUM")
|
||||
_, _ = db.ExecContext(context.Background(), "VACUUM")
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
|
@ -177,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(schemaContent))
|
||||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -322,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
|||
}
|
||||
|
||||
// Construct the pg_restore command
|
||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
|
||||
// Set the output streams
|
||||
cmd.Stdout = os.Stdout
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqliteconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
@ -101,7 +102,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
if err := db.PingContext(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -112,7 +113,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||
|
||||
query := "PRAGMA " + pragma
|
||||
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
err := db.QueryRowContext(context.Background(), query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
}
|
||||
|
|
@ -180,23 +181,23 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
if _, err := db.ExecContext(context.Background(), schema); err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
// Insert parent record
|
||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
t.Fatalf("Failed to insert parent: %v", err)
|
||||
}
|
||||
|
||||
// Test 1: Valid foreign key should work
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Test 2: Invalid foreign key should fail
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
|
|
@ -206,7 +207,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test 3: Deleting referenced parent should fail
|
||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
||||
_, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
|
|
@ -252,7 +253,7 @@ func TestJournalModeValidation(t *testing.T) {
|
|||
|
||||
var actualMode string
|
||||
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,12 +99,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
|
||||
// If debug flag is set, resolve hostname to IP address
|
||||
if debugUseDERPIP {
|
||||
ips, err := net.LookupIP(host)
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||
} else if len(ips) > 0 {
|
||||
} else if len(addrs) > 0 {
|
||||
// Use the first IP address
|
||||
ipStr := ips[0].String()
|
||||
ipStr := addrs[0].IP.String()
|
||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
|
||||
host = ipStr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
|
@ -78,7 +79,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
t.Run("HTTP_NoAuthHeader", func(t *testing.T) {
|
||||
// Test 1: Request without any Authorization header
|
||||
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
|
@ -130,7 +131,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) {
|
||||
// Test 2: Request with invalid Authorization header (missing "Bearer " prefix)
|
||||
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "InvalidToken")
|
||||
|
||||
|
|
@ -164,7 +165,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
// Test 3: Request with Bearer prefix but invalid token
|
||||
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
|
||||
// Note: Both malformed and properly formatted invalid tokens should return 401
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token-12345")
|
||||
|
||||
|
|
@ -197,7 +198,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
t.Run("HTTP_ValidAPIKey", func(t *testing.T) {
|
||||
// Test 4: Request with valid API key
|
||||
// Expected: Should return 200 with user data (this is the authorized case)
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+validAPIKey)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue