OIDC: Refactor /internal/auth/oidc package #782

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2024-07-02 14:38:29 +02:00
parent 11b04bcbe7
commit cc920698a2
5 changed files with 52 additions and 46 deletions

View file

@ -17,7 +17,6 @@ import (
"github.com/zitadel/oidc/pkg/oidc"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/pkg/rnd"
)
@ -26,35 +25,30 @@ const (
AdminRole = "photoprism_admin"
)
var log = event.Log
type Client struct {
rp.RelyingParty
debug bool
}
func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl string, debug bool) (result *Client, err error) {
log.Debugf("oidc: Provider Params: %s %s %s %s", iss.String(), clientId, clientSecret, siteUrl)
u, err := url.Parse(siteUrl)
if err != nil {
log.Error(err)
log.Debug(err)
return nil, err
}
u.Path = path.Join(u.Path, config.OidcRedirectUri)
log.Debugf("oidc: redirect uri %s", u.String())
var hashKey, encryptKey []byte
if hashKey, err = rnd.RandomBytes(16); err != nil {
log.Errorf("oidc: %q (create hash key)", err)
log.Debugf("oidc: %q (create hash key)", err)
return nil, err
}
if encryptKey, err = rnd.RandomBytes(16); err != nil {
log.Errorf("oidc: %q (create encrypt key)", err)
log.Debugf("oidc: %q (create encrypt key)", err)
return nil, err
}
@ -68,7 +62,7 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
rp.WithIssuedAtOffset(5 * time.Second),
),
rp.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
log.Errorf("oidc: %s: %s (state: %s)", errorType, errorDesc, state)
log.Debugf("oidc: %s: %s (state: %s)", errorType, errorDesc, state)
w.WriteHeader(http.StatusInternalServerError)
w.Header().Add("oidc_error", fmt.Sprintf("oidc: %s", errorDesc))
}),
@ -77,7 +71,7 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
discover, err := client.Discover(iss.String(), httpClient)
if err != nil {
log.Errorf("oidc: %q (discover)", err)
log.Debugf("oidc: %q (discover)", err)
return nil, err
}
@ -92,11 +86,11 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
provider, err := rp.NewRelyingPartyOIDC(iss.String(), clientId, clientSecret, u.String(), scopes, clientOpt...)
if err != nil {
log.Errorf("oidc: %s (issuer)", err)
log.Debugf("oidc: %s (issuer)", err)
return nil, err
}
log.Debugf("oidc: PKCE enabled %v", provider.IsPKCE())
log.Tracef("oidc: pkce enabled %v", provider.IsPKCE())
return &Client{
provider,
@ -119,19 +113,19 @@ func (c *Client) CodeExchangeUserInfo(ctx *gin.Context) (userInfo oidc.UserInfo,
tokens = t
}
//you could also just take the access_token and id_token without calling the userinfo endpoint:
//
//tokeninfoClosure := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
// log.Infof("IDTOKEN: %q\n\n" , tokens.IDToken)
// log.Infof("ACCESSTOKEN: %q\n\n" , tokens.AccessToken)
// log.Infof("REFRESHTOKEN: %q\n\n" , tokens.RefreshToken)
//}
/*
You could also just take the access_token and id_token without calling the userinfo endpoint, e.g.:
tokeninfoClosure := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
log.Infof("IDTOKEN: %q\n\n" , tokens.IDToken)
log.Infof("ACCESSTOKEN: %q\n\n" , tokens.AccessToken)
log.Infof("REFRESHTOKEN: %q\n\n" , tokens.RefreshToken)
*/
handle := rp.CodeExchangeHandler(rp.UserinfoCallback(userinfoClosure), c)
//handle := rp.CodeExchangeHandler(tokeninfoClosure, c)
handle(ctx.Writer, ctx.Request)
// log.Debugf("oidc: current request state: %v", ctx.Writer.Status())
if sc := ctx.Writer.Status(); sc != 0 && sc != http.StatusOK {
if oidcErr := ctx.Writer.Header().Get("oidc_error"); oidcErr == "" {
return userInfo, tokens, errors.New("tailed to exchange the authentication code and retrieve the user information")

View file

@ -23,7 +23,7 @@ func UsernameFromUserInfo(userinfo UserInfo) (username string) {
} else if len(userinfo.GetEmail()) >= 4 {
username = userinfo.GetEmail()
} else {
log.Error("oidc: no username found")
log.Debug("oidc: no username found")
}
return username
}

View file

@ -1,8 +1,6 @@
package oidc
import (
"bytes"
"io"
"net/http"
"time"
)
@ -30,28 +28,16 @@ type LoggingRoundTripper struct {
proxy http.RoundTripper
}
func (lrt LoggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, e error) {
// Do "before sending requests" actions here.
log.Debugf("sending request to %s", req.URL.String())
func (lrt LoggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
log.Tracef("oidc: %s %s", req.Method, req.URL.String())
// Send the request, get the response (or the error)
res, e = lrt.proxy.RoundTrip(req)
// Send request.
res, err = lrt.proxy.RoundTrip(req)
// Handle the result.
if e != nil {
log.Errorf("http error: %s", e)
} else {
log.Debugf("http response: %s", res.Status)
// Copy body into buffer for logging
buf := new(bytes.Buffer)
_, err := io.Copy(buf, res.Body)
if err != nil {
log.Errorf("http buffer error: %s", err)
}
// log.Debugf("Header: %s\n", res.Header)
// log.Debugf("Reponse Body: %s\n", buf.String())
res.Body = io.NopCloser(buf)
// Log error, if any.
if err != nil {
log.Debugf("oidc: request to %s has failed (%s)", req.URL.String(), err)
}
return
return res, err
}

View file

@ -23,3 +23,7 @@ Additional information can be found in our Developer Guide:
<https://docs.photoprism.app/developer-guide/>
*/
package oidc
import "github.com/photoprism/photoprism/internal/event"
var log = event.Log

View file

@ -0,0 +1,22 @@
package oidc
import (
"os"
"testing"
"github.com/sirupsen/logrus"
"github.com/photoprism/photoprism/internal/event"
)
func TestMain(m *testing.M) {
// Init test logger.
log = logrus.StandardLogger()
log.SetLevel(logrus.TraceLevel)
event.AuditLog = log
// Run unit tests.
code := m.Run()
os.Exit(code)
}