mirror of
https://github.com/photoprism/photoprism.git
synced 2026-01-23 02:24:24 +00:00
OIDC: Refactor /internal/auth/oidc package #782
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
11b04bcbe7
commit
cc920698a2
5 changed files with 52 additions and 46 deletions
|
|
@ -17,7 +17,6 @@ import (
|
|||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
|
|
@ -26,35 +25,30 @@ const (
|
|||
AdminRole = "photoprism_admin"
|
||||
)
|
||||
|
||||
var log = event.Log
|
||||
|
||||
type Client struct {
|
||||
rp.RelyingParty
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl string, debug bool) (result *Client, err error) {
|
||||
log.Debugf("oidc: Provider Params: %s %s %s %s", iss.String(), clientId, clientSecret, siteUrl)
|
||||
|
||||
u, err := url.Parse(siteUrl)
|
||||
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Debug(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, config.OidcRedirectUri)
|
||||
log.Debugf("oidc: redirect uri %s", u.String())
|
||||
|
||||
var hashKey, encryptKey []byte
|
||||
|
||||
if hashKey, err = rnd.RandomBytes(16); err != nil {
|
||||
log.Errorf("oidc: %q (create hash key)", err)
|
||||
log.Debugf("oidc: %q (create hash key)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encryptKey, err = rnd.RandomBytes(16); err != nil {
|
||||
log.Errorf("oidc: %q (create encrypt key)", err)
|
||||
log.Debugf("oidc: %q (create encrypt key)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -68,7 +62,7 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
|
|||
rp.WithIssuedAtOffset(5 * time.Second),
|
||||
),
|
||||
rp.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
||||
log.Errorf("oidc: %s: %s (state: %s)", errorType, errorDesc, state)
|
||||
log.Debugf("oidc: %s: %s (state: %s)", errorType, errorDesc, state)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Header().Add("oidc_error", fmt.Sprintf("oidc: %s", errorDesc))
|
||||
}),
|
||||
|
|
@ -77,7 +71,7 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
|
|||
discover, err := client.Discover(iss.String(), httpClient)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("oidc: %q (discover)", err)
|
||||
log.Debugf("oidc: %q (discover)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -92,11 +86,11 @@ func NewClient(iss *url.URL, clientId, clientSecret, customScopes, siteUrl strin
|
|||
provider, err := rp.NewRelyingPartyOIDC(iss.String(), clientId, clientSecret, u.String(), scopes, clientOpt...)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("oidc: %s (issuer)", err)
|
||||
log.Debugf("oidc: %s (issuer)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("oidc: PKCE enabled %v", provider.IsPKCE())
|
||||
log.Tracef("oidc: pkce enabled %v", provider.IsPKCE())
|
||||
|
||||
return &Client{
|
||||
provider,
|
||||
|
|
@ -119,19 +113,19 @@ func (c *Client) CodeExchangeUserInfo(ctx *gin.Context) (userInfo oidc.UserInfo,
|
|||
tokens = t
|
||||
}
|
||||
|
||||
//you could also just take the access_token and id_token without calling the userinfo endpoint:
|
||||
//
|
||||
//tokeninfoClosure := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
// log.Infof("IDTOKEN: %q\n\n" , tokens.IDToken)
|
||||
// log.Infof("ACCESSTOKEN: %q\n\n" , tokens.AccessToken)
|
||||
// log.Infof("REFRESHTOKEN: %q\n\n" , tokens.RefreshToken)
|
||||
//}
|
||||
/*
|
||||
You could also just take the access_token and id_token without calling the userinfo endpoint, e.g.:
|
||||
|
||||
tokeninfoClosure := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
log.Infof("IDTOKEN: %q\n\n" , tokens.IDToken)
|
||||
log.Infof("ACCESSTOKEN: %q\n\n" , tokens.AccessToken)
|
||||
log.Infof("REFRESHTOKEN: %q\n\n" , tokens.RefreshToken)
|
||||
*/
|
||||
|
||||
handle := rp.CodeExchangeHandler(rp.UserinfoCallback(userinfoClosure), c)
|
||||
//handle := rp.CodeExchangeHandler(tokeninfoClosure, c)
|
||||
|
||||
handle(ctx.Writer, ctx.Request)
|
||||
|
||||
// log.Debugf("oidc: current request state: %v", ctx.Writer.Status())
|
||||
if sc := ctx.Writer.Status(); sc != 0 && sc != http.StatusOK {
|
||||
if oidcErr := ctx.Writer.Header().Get("oidc_error"); oidcErr == "" {
|
||||
return userInfo, tokens, errors.New("tailed to exchange the authentication code and retrieve the user information")
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func UsernameFromUserInfo(userinfo UserInfo) (username string) {
|
|||
} else if len(userinfo.GetEmail()) >= 4 {
|
||||
username = userinfo.GetEmail()
|
||||
} else {
|
||||
log.Error("oidc: no username found")
|
||||
log.Debug("oidc: no username found")
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -30,28 +28,16 @@ type LoggingRoundTripper struct {
|
|||
proxy http.RoundTripper
|
||||
}
|
||||
|
||||
func (lrt LoggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, e error) {
|
||||
// Do "before sending requests" actions here.
|
||||
log.Debugf("sending request to %s", req.URL.String())
|
||||
func (lrt LoggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
log.Tracef("oidc: %s %s", req.Method, req.URL.String())
|
||||
|
||||
// Send the request, get the response (or the error)
|
||||
res, e = lrt.proxy.RoundTrip(req)
|
||||
// Send request.
|
||||
res, err = lrt.proxy.RoundTrip(req)
|
||||
|
||||
// Handle the result.
|
||||
if e != nil {
|
||||
log.Errorf("http error: %s", e)
|
||||
} else {
|
||||
log.Debugf("http response: %s", res.Status)
|
||||
|
||||
// Copy body into buffer for logging
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := io.Copy(buf, res.Body)
|
||||
if err != nil {
|
||||
log.Errorf("http buffer error: %s", err)
|
||||
}
|
||||
// log.Debugf("Header: %s\n", res.Header)
|
||||
// log.Debugf("Reponse Body: %s\n", buf.String())
|
||||
res.Body = io.NopCloser(buf)
|
||||
// Log error, if any.
|
||||
if err != nil {
|
||||
log.Debugf("oidc: request to %s has failed (%s)", req.URL.String(), err)
|
||||
}
|
||||
return
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,3 +23,7 @@ Additional information can be found in our Developer Guide:
|
|||
<https://docs.photoprism.app/developer-guide/>
|
||||
*/
|
||||
package oidc
|
||||
|
||||
import "github.com/photoprism/photoprism/internal/event"
|
||||
|
||||
var log = event.Log
|
||||
|
|
|
|||
22
internal/auth/oidc/oidc_test.go
Normal file
22
internal/auth/oidc/oidc_test.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Init test logger.
|
||||
log = logrus.StandardLogger()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
event.AuditLog = log
|
||||
|
||||
// Run unit tests.
|
||||
code := m.Run()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue