2021-02-24 15:24:45 -05:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
|
|
|
"sort"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
2021-02-26 16:17:04 -05:00
|
|
|
"github.com/h44z/wg-portal/internal/authentication"
|
2021-02-24 15:24:45 -05:00
|
|
|
"github.com/h44z/wg-portal/internal/users"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
|
|
)
|
|
|
|
|
2021-02-26 16:17:04 -05:00
|
|
|
// AuthManager keeps track of available authentication providers.
|
2021-02-24 15:24:45 -05:00
|
|
|
type AuthManager struct {
|
|
|
|
Server *Server
|
|
|
|
Group *gin.RouterGroup // basic group for all providers (/auth)
|
|
|
|
providers []authentication.AuthProvider
|
|
|
|
UserManager *users.Manager
|
|
|
|
}
|
|
|
|
|
|
|
|
// RegisterProvider register auth provider
|
|
|
|
func (auth *AuthManager) RegisterProvider(provider authentication.AuthProvider) {
|
|
|
|
name := provider.GetName()
|
|
|
|
if auth.GetProvider(name) != nil {
|
|
|
|
logrus.Warnf("auth provider %v already registered", name)
|
|
|
|
}
|
|
|
|
|
|
|
|
provider.SetupRoutes(auth.Group)
|
|
|
|
auth.providers = append(auth.providers, provider)
|
|
|
|
}
|
|
|
|
|
|
|
|
// RegisterProviderWithoutError register auth provider if err is nil
|
|
|
|
func (auth *AuthManager) RegisterProviderWithoutError(provider authentication.AuthProvider, err error) {
|
|
|
|
if err != nil {
|
|
|
|
logrus.Errorf("skipping provider registration: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
auth.RegisterProvider(provider)
|
|
|
|
}
|
|
|
|
|
2021-02-26 16:17:04 -05:00
|
|
|
// GetProvider get provider by name
|
2021-02-24 15:24:45 -05:00
|
|
|
func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider {
|
|
|
|
for _, provider := range auth.providers {
|
|
|
|
if provider.GetName() == name {
|
|
|
|
return provider
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2021-02-26 16:17:04 -05:00
|
|
|
// GetProviders return registered providers.
|
|
|
|
// Returned providers are ordered by provider priority.
|
2021-02-24 15:24:45 -05:00
|
|
|
func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) {
|
|
|
|
for _, provider := range auth.providers {
|
|
|
|
providers = append(providers, provider)
|
|
|
|
}
|
2021-02-26 16:17:04 -05:00
|
|
|
|
|
|
|
// order by priority
|
|
|
|
sort.SliceStable(providers, func(i, j int) bool {
|
|
|
|
return providers[i].GetPriority() < providers[j].GetPriority()
|
|
|
|
})
|
|
|
|
|
2021-02-24 15:24:45 -05:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2021-02-26 16:17:04 -05:00
|
|
|
// GetProvidersForType return registered providers for the given type.
|
|
|
|
// Returned providers are ordered by provider priority.
|
2021-02-24 15:24:45 -05:00
|
|
|
func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) {
|
|
|
|
for _, provider := range auth.providers {
|
|
|
|
if provider.GetType() == typ {
|
|
|
|
providers = append(providers, provider)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// order by priority
|
|
|
|
sort.SliceStable(providers, func(i, j int) bool {
|
|
|
|
return providers[i].GetPriority() < providers[j].GetPriority()
|
|
|
|
})
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewAuthManager(server *Server) *AuthManager {
|
|
|
|
m := &AuthManager{
|
|
|
|
Server: server,
|
|
|
|
}
|
|
|
|
|
|
|
|
m.Group = m.Server.server.Group("/auth")
|
|
|
|
|
|
|
|
return m
|
|
|
|
}
|