diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index d814b2f..057d27a 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -12,7 +12,7 @@ import ( "github.com/sirupsen/logrus" ) -var Version string = "unknown (local build)" +var Version = "unknown (local build)" func main() { _ = setupLogger(logrus.StandardLogger()) @@ -20,7 +20,7 @@ func main() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - logrus.Infof("Starting WireGuard Portal Server [%s]...", Version) + logrus.Infof("starting WireGuard Portal Server [%s]...", Version) // Context for clean shutdown ctx, cancel := context.WithCancel(context.Background()) @@ -28,7 +28,7 @@ func main() { service := server.Server{} if err := service.Setup(ctx); err != nil { - logrus.Fatalf("Setup failed: %v", err) + logrus.Fatalf("setup failed: %v", err) } // Attach signal handlers to context @@ -44,10 +44,10 @@ func main() { <-ctx.Done() // Wait until the context gets canceled // Give goroutines some time to stop gracefully - logrus.Info("Stopping WireGuard Portal Server...") + logrus.Info("stopping WireGuard Portal Server...") time.Sleep(2 * time.Second) - logrus.Infof("Stopped WireGuard Portal Server...") + logrus.Infof("stopped WireGuard Portal Server...") logrus.Exit(0) } diff --git a/internal/authentication/provider.go b/internal/authentication/provider.go index b8420c4..a1e57d1 100644 --- a/internal/authentication/provider.go +++ b/internal/authentication/provider.go @@ -4,10 +4,10 @@ import ( "github.com/gin-gonic/gin" ) +// AuthContext contains all information that the AuthProvider needs to perform the authentication. type AuthContext struct { - Provider AuthProvider Username string // email or username - Password string // optional for OIDC + Password string Callback string // callback for OIDC } @@ -18,6 +18,7 @@ const ( AuthProviderTypeOauth AuthProviderType = "oauth" ) +// AuthProvider is a interface that can be implemented by different authentication providers like LDAP, OAUTH, ... type AuthProvider interface { GetName() string GetType() AuthProviderType diff --git a/internal/authentication/providers/ldap/provider.go b/internal/authentication/providers/ldap/provider.go index 9fa6e83..7e6dc55 100644 --- a/internal/authentication/providers/ldap/provider.go +++ b/internal/authentication/providers/ldap/provider.go @@ -13,7 +13,7 @@ import ( "github.com/pkg/errors" ) -// Provider provide login with password method +// Provider implements a password login method for an LDAP backend. type Provider struct { config *ldapconfig.Config } diff --git a/internal/authentication/providers/password/provider.go b/internal/authentication/providers/password/provider.go index 4a6d95d..ca63ee5 100644 --- a/internal/authentication/providers/password/provider.go +++ b/internal/authentication/providers/password/provider.go @@ -14,7 +14,7 @@ import ( "gorm.io/gorm" ) -// Provider provide login with password method +// Provider implements a password login method for a database backend. type Provider struct { db *gorm.DB } diff --git a/internal/authentication/user.go b/internal/authentication/user.go index e4b0af4..a5afcfc 100644 --- a/internal/authentication/user.go +++ b/internal/authentication/user.go @@ -1,5 +1,6 @@ package authentication +// User represents the data that can be retrieved from authentication backends. type User struct { Email string IsAdmin bool diff --git a/internal/common/configuration.go b/internal/common/configuration.go index 8e8a1cf..2c9f373 100644 --- a/internal/common/configuration.go +++ b/internal/common/configuration.go @@ -1,7 +1,6 @@ package common import ( - "errors" "os" "reflect" "runtime" @@ -10,13 +9,14 @@ import ( "github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/wireguard" "github.com/kelseyhightower/envconfig" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) var ErrInvalidSpecification = errors.New("specification must be a struct pointer") -// LoadConfigFile parses yaml files. It uses to yaml annotation to store the data in a struct. +// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct. func loadConfigFile(cfg interface{}, filename string) error { s := reflect.ValueOf(cfg) @@ -30,24 +30,24 @@ func loadConfigFile(cfg interface{}, filename string) error { f, err := os.Open(filename) if err != nil { - return err + return errors.Wrapf(err, "failed to open config file %s", filename) } defer f.Close() decoder := yaml.NewDecoder(f) err = decoder.Decode(cfg) if err != nil { - return err + return errors.Wrapf(err, "failed to decode config file %s", filename) } return nil } -// LoadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct. +// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct. func loadConfigEnv(cfg interface{}) error { err := envconfig.Process("", cfg) if err != nil { - return err + return errors.Wrap(err, "failed to process environment config") } return nil @@ -124,7 +124,7 @@ func NewConfig() *Config { } if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" { - logrus.Warnf("Managing IP addresses only works on linux! Feature disabled.") + logrus.Warnf("managing IP addresses only works on linux, feature disabled...") cfg.WG.ManageIPAddresses = false } diff --git a/internal/common/email.go b/internal/common/email.go index a70e6d2..e25da1d 100644 --- a/internal/common/email.go +++ b/internal/common/email.go @@ -26,7 +26,7 @@ type MailAttachment struct { Embedded bool } -// SendEmailWithAttachments sends a mail with attachments. +// SendEmailWithAttachments sends a mail with optional attachments. func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body string, htmlBody string, receivers []string, attachments []MailAttachment) error { e := email.NewEmail() diff --git a/internal/common/util.go b/internal/common/util.go index dbdf60d..fc974b1 100644 --- a/internal/common/util.go +++ b/internal/common/util.go @@ -40,6 +40,8 @@ func IsIPv6(address string) bool { return ip.To4() == nil } +// ParseStringList converts a comma separated string into a list of strings. +// It also trims spaces from each element of the list. func ParseStringList(lst string) []string { tokens := strings.Split(lst, ",") validatedTokens := make([]string, 0, len(tokens)) @@ -53,6 +55,7 @@ func ParseStringList(lst string) []string { return validatedTokens } +// ListToString converts a list of strings into a comma separated string. func ListToString(lst []string) string { return strings.Join(lst, ", ") } diff --git a/internal/ldap/config.go b/internal/ldap/config.go index bb31b20..22caa9f 100644 --- a/internal/ldap/config.go +++ b/internal/ldap/config.go @@ -23,5 +23,5 @@ type Config struct { GroupMemberAttribute string `yaml:"attrGroups" envconfig:"LDAP_ATTR_GROUPS"` DisabledAttribute string `yaml:"attrDisabled" envconfig:"LDAP_ATTR_DISABLED"` - AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` + AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` // Members of this group receive admin rights in WG-Portal } diff --git a/internal/ldap/ldap.go b/internal/ldap/ldap.go index f6f038c..04a8d06 100644 --- a/internal/ldap/ldap.go +++ b/internal/ldap/ldap.go @@ -18,20 +18,20 @@ type RawLdapData struct { func Open(cfg *Config) (*ldap.Conn, error) { conn, err := ldap.DialURL(cfg.URL) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to connect to LDAP") } if cfg.StartTLS { // Reconnect with TLS err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to star TLS on connection") } } err = conn.Bind(cfg.BindUser, cfg.BindPass) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to bind to LDAP") } return conn, nil diff --git a/internal/server/auth.go b/internal/server/auth.go index d5e3ed8..0e1cc72 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -3,14 +3,13 @@ package server import ( "sort" - "github.com/h44z/wg-portal/internal/authentication" - "github.com/gin-gonic/gin" + "github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/users" "github.com/sirupsen/logrus" ) -// Auth auth struct +// AuthManager keeps track of available authentication providers. type AuthManager struct { Server *Server Group *gin.RouterGroup // basic group for all providers (/auth) @@ -38,7 +37,7 @@ func (auth *AuthManager) RegisterProviderWithoutError(provider authentication.Au auth.RegisterProvider(provider) } -// GetProvider get provider with name +// GetProvider get provider by name func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider { for _, provider := range auth.providers { if provider.GetName() == name { @@ -48,15 +47,23 @@ func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider { return nil } -// GetProviders return registered providers +// GetProviders return registered providers. +// Returned providers are ordered by provider priority. func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) { for _, provider := range auth.providers { providers = append(providers, provider) } + + // order by priority + sort.SliceStable(providers, func(i, j int) bool { + return providers[i].GetPriority() < providers[j].GetPriority() + }) + return } -// GetProviders return registered providers +// GetProvidersForType return registered providers for the given type. +// Returned providers are ordered by provider priority. func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) { for _, provider := range auth.providers { if provider.GetType() == typ { diff --git a/internal/server/handlers_auth.go b/internal/server/handlers_auth.go index e3c7f72..bb70b95 100644 --- a/internal/server/handlers_auth.go +++ b/internal/server/handlers_auth.go @@ -8,7 +8,6 @@ import ( "github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/users" "github.com/sirupsen/logrus" - "gorm.io/gorm" ) func (s *Server) GetLogin(c *gin.Context) { @@ -85,10 +84,6 @@ func (s *Server) PostLogin(c *gin.Context) { loginProvider = provider // create new user in the database (or reactivate him) - if user, err = s.users.GetOrCreateUserUnscoped(email); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to create new user") - return - } userData, err := loginProvider.GetUserModel(&authentication.AuthContext{ Username: email, }) @@ -96,23 +91,25 @@ func (s *Server) PostLogin(c *gin.Context) { s.GetHandleError(c, http.StatusInternalServerError, "login error", err.Error()) return } - user.Firstname = userData.Firstname - user.Lastname = userData.Lastname - user.Email = userData.Email - user.Phone = userData.Phone - user.IsAdmin = userData.IsAdmin - user.Source = users.UserSource(loginProvider.GetName()) - user.DeletedAt = gorm.DeletedAt{} // reset deleted flag - if err = s.users.UpdateUser(user); err != nil { + if err := s.CreateUser(users.User{ + Email: userData.Email, + Source: users.UserSource(loginProvider.GetName()), + IsAdmin: userData.IsAdmin, + Firstname: userData.Firstname, + Lastname: userData.Lastname, + Phone: userData.Phone, + }); err != nil { s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data") return } + + user = s.users.GetUser(username) break } } // Check if user is authenticated - if email == "" || loginProvider == nil { + if email == "" || loginProvider == nil || user == nil { c.Redirect(http.StatusSeeOther, "/auth/login?err=authfail") return } @@ -126,17 +123,9 @@ func (s *Server) PostLogin(c *gin.Context) { sessionData.Lastname = user.Lastname // Check if user already has a peer setup, if not create one - if s.config.Core.CreateDefaultPeer { - peers := s.peers.GetPeersByMail(sessionData.Email) - if len(peers) == 0 { // Create vpn peer - err := s.CreatePeer(Peer{ - Identifier: sessionData.Firstname + " " + sessionData.Lastname + " (Default)", - Email: sessionData.Email, - CreatedBy: sessionData.Email, - UpdatedBy: sessionData.Email, - }) - logrus.Errorf("Failed to automatically create vpn peer for %s: %v", sessionData.Email, err) - } + if err := s.CreateUserDefaultPeer(user.Email); err != nil { + // Not a fatal error, just log it... + logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err) } if err := UpdateSessionData(c, sessionData); err != nil { diff --git a/internal/server/handlers_common.go b/internal/server/handlers_common.go index ba1d0ec..0ff0885 100644 --- a/internal/server/handlers_common.go +++ b/internal/server/handlers_common.go @@ -4,6 +4,8 @@ import ( "net/http" "strconv" + "github.com/pkg/errors" + "github.com/gin-gonic/gin" ) @@ -145,7 +147,7 @@ func (s *Server) updateFormInSession(c *gin.Context, formData interface{}) error currentSession.FormData = formData if err := UpdateSessionData(c, currentSession); err != nil { - return err + return errors.WithMessage(err, "failed to update form in session") } return nil @@ -158,13 +160,13 @@ func (s *Server) setNewPeerFormInSession(c *gin.Context) (SessionData, error) { if currentSession.FormData == nil || c.Query("formerr") == "" { user, err := s.PrepareNewPeer() if err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to prepare new peer") } currentSession.FormData = user } if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to update peer form in session") } return currentSession, nil @@ -179,7 +181,7 @@ func (s *Server) setFormInSession(c *gin.Context, formData interface{}) (Session } if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to set form in session") } return currentSession, nil diff --git a/internal/server/handlers_user.go b/internal/server/handlers_user.go index c0a4400..c96a37c 100644 --- a/internal/server/handlers_user.go +++ b/internal/server/handlers_user.go @@ -7,7 +7,6 @@ import ( "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/internal/users" - "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) @@ -141,31 +140,7 @@ func (s *Server) PostAdminUsersEdit(c *gin.Context) { } formUser.IsAdmin = c.PostForm("isadmin") == "true" - // Update peers - if disabled != currentUser.DeletedAt.Valid { - if disabled { - // disable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(currentUser.Email) { - now := time.Now() - peer.DeactivatedAt = &now - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update deactivated peer %s: %v", peer.PublicKey, err) - } - } - } else { - // enable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(currentUser.Email) { - now := time.Now() - peer.DeactivatedAt = nil - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update activated peer %s: %v", peer.PublicKey, err) - } - } - } - } - - // Update in database - if err := s.users.UpdateUser(&formUser); err != nil { + if err := s.UpdateUser(formUser); err != nil { _ = s.updateFormInSession(c, formUser) SetFlashMessage(c, "failed to update user: "+err.Error(), "danger") c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=update") @@ -242,28 +217,14 @@ func (s *Server) PostAdminUsersCreate(c *gin.Context) { } formUser.IsAdmin = c.PostForm("isadmin") == "true" formUser.Source = users.UserSourceDatabase - if err := s.users.CreateUser(&formUser); err != nil { - formUser.CreatedAt = time.Time{} // reset created time + + if err := s.CreateUser(formUser); err != nil { _ = s.updateFormInSession(c, formUser) SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") return } - // Check if user already has a peer setup, if not create one - if s.config.Core.CreateDefaultPeer { - peers := s.peers.GetPeersByMail(formUser.Email) - if len(peers) == 0 { // Create vpn peer - err := s.CreatePeer(Peer{ - Identifier: formUser.Firstname + " " + formUser.Lastname + " (Default)", - Email: formUser.Email, - CreatedBy: formUser.Email, - UpdatedBy: formUser.Email, - }) - logrus.Errorf("Failed to automatically create vpn peer for %s: %v", formUser.Email, err) - } - } - SetFlashMessage(c, "user created successfully", "success") c.Redirect(http.StatusSeeOther, "/admin/users/") } diff --git a/internal/server/peermanager.go b/internal/server/peermanager.go index e26785f..ad95fd1 100644 --- a/internal/server/peermanager.go +++ b/internal/server/peermanager.go @@ -124,7 +124,7 @@ func (p Peer) GetConfig() wgtypes.PeerConfig { func (p Peer) GetConfigFile(device Device) ([]byte, error) { tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse client template") } var tplBuff bytes.Buffer @@ -137,7 +137,7 @@ func (p Peer) GetConfigFile(device Device) ([]byte, error) { Server: device, }) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to execute client template") } return tplBuff.Bytes(), nil @@ -149,7 +149,7 @@ func (p Peer) GetQRCode() ([]byte, error) { logrus.WithFields(logrus.Fields{ "err": err, }).Error("failed to create qrcode") - return nil, err + return nil, errors.Wrap(err, "failed to encode qrcode") } return png, nil } @@ -247,7 +247,7 @@ func (d Device) GetConfig() wgtypes.Config { func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse server template") } var tplBuff bytes.Buffer @@ -260,7 +260,7 @@ func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { Server: d, }) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to execute server template") } return tplBuff.Bytes(), nil @@ -582,7 +582,7 @@ func (u *PeerManager) CreatePeer(peer Peer) error { res := u.db.Create(&peer) if res.Error != nil { logrus.Errorf("failed to create peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to create peer") } return nil @@ -596,7 +596,7 @@ func (u *PeerManager) UpdatePeer(peer Peer) error { res := u.db.Save(&peer) if res.Error != nil { logrus.Errorf("failed to update peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to update peer") } return nil @@ -606,7 +606,7 @@ func (u *PeerManager) DeletePeer(peer Peer) error { res := u.db.Delete(&peer) if res.Error != nil { logrus.Errorf("failed to delete peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to delete peer") } return nil @@ -621,7 +621,7 @@ func (u *PeerManager) UpdateDevice(device Device) error { res := u.db.Save(&device) if res.Error != nil { logrus.Errorf("failed to update device: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to update device") } return nil @@ -637,7 +637,7 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) { } ip, _, err := net.ParseCIDR(cidr) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse cidr") } reservedIps = append(reservedIps, ip.String()) } @@ -650,7 +650,7 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) { } ip, _, err := net.ParseCIDR(cidr) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse cidr") } reservedIps = append(reservedIps, ip.String()) @@ -691,11 +691,11 @@ func (u *PeerManager) IsIPReserved(cidr string) bool { func (u *PeerManager) GetAvailableIp(cidr string) (string, error) { reserved, err := u.GetAllReservedIps() if err != nil { - return "", err + return "", errors.WithMessage(err, "failed to get all reserved IP addresses") } ip, ipnet, err := net.ParseCIDR(cidr) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to parse cidr") } // this two addresses are not usable diff --git a/internal/server/core.go b/internal/server/server.go similarity index 92% rename from internal/server/core.go rename to internal/server/server.go index c095efc..096c7e0 100644 --- a/internal/server/core.go +++ b/internal/server/server.go @@ -84,8 +84,8 @@ func (s *Server) Setup(ctx context.Context) error { dir := s.getExecutableDirectory() rDir, _ := filepath.Abs(filepath.Dir(os.Args[0])) - logrus.Infof("Real working directory: %s", rDir) - logrus.Infof("Current working directory: %s", dir) + logrus.Infof("real working directory: %s", rDir) + logrus.Infof("current working directory: %s", dir) // Init rand rand.Seed(time.Now().UnixNano()) @@ -166,7 +166,7 @@ func (s *Server) Setup(ctx context.Context) error { return errors.Wrap(err, "unable to pare mail template") } - logrus.Infof("Setup of service completed!") + logrus.Infof("setup of service completed!") return nil } @@ -201,7 +201,7 @@ func (s *Server) Run() { func (s *Server) getExecutableDirectory() string { dir, err := filepath.Abs(filepath.Dir(os.Args[0])) if err != nil { - logrus.Errorf("Failed to get executable directory: %v", err) + logrus.Errorf("failed to get executable directory: %v", err) } if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) { @@ -240,7 +240,7 @@ func GetSessionData(c *gin.Context) SessionData { } session.Set(SessionIdentifier, sessionData) if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session: %v", err) + logrus.Errorf("failed to store session: %v", err) } } @@ -251,7 +251,7 @@ func GetFlashes(c *gin.Context) []FlashData { session := sessions.Default(c) flashes := session.Flashes() if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session after setting flash: %v", err) + logrus.Errorf("failed to store session after setting flash: %v", err) } flashData := make([]FlashData, len(flashes)) @@ -266,8 +266,8 @@ func UpdateSessionData(c *gin.Context, data SessionData) error { session := sessions.Default(c) session.Set(SessionIdentifier, data) if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session: %v", err) - return err + logrus.Errorf("failed to store session: %v", err) + return errors.Wrap(err, "failed to store session") } return nil } @@ -276,8 +276,8 @@ func DestroySessionData(c *gin.Context) error { session := sessions.Default(c) session.Delete(SessionIdentifier) if err := session.Save(); err != nil { - logrus.Errorf("Failed to destroy session: %v", err) - return err + logrus.Errorf("failed to destroy session: %v", err) + return errors.Wrap(err, "failed to destroy session") } return nil } @@ -289,7 +289,7 @@ func SetFlashMessage(c *gin.Context, message, typ string) { Type: typ, }) if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session after setting flash: %v", err) + logrus.Errorf("failed to store session after setting flash: %v", err) } } diff --git a/internal/server/helper.go b/internal/server/server_helper.go similarity index 52% rename from internal/server/helper.go rename to internal/server/server_helper.go index 0b59aa4..41fd7dc 100644 --- a/internal/server/helper.go +++ b/internal/server/server_helper.go @@ -8,8 +8,11 @@ import ( "time" "github.com/h44z/wg-portal/internal/common" + "github.com/h44z/wg-portal/internal/users" "github.com/pkg/errors" + "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "gorm.io/gorm" ) func (s *Server) PrepareNewPeer() (Peer, error) { @@ -22,18 +25,18 @@ func (s *Server) PrepareNewPeer() (Peer, error) { for i := range device.IPs { freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) if err != nil { - return Peer{}, err + return Peer{}, errors.WithMessage(err, "failed to get available IP addresses") } peer.IPs[i] = freeIP } peer.IPsStr = common.ListToString(peer.IPs) psk, err := wgtypes.GenerateKey() if err != nil { - return Peer{}, err + return Peer{}, errors.Wrap(err, "failed to generate key") } key, err := wgtypes.GeneratePrivateKey() if err != nil { - return Peer{}, err + return Peer{}, errors.Wrap(err, "failed to generate private key") } peer.PresharedKey = psk.String() peer.PrivateKey = key.String() @@ -57,18 +60,18 @@ func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool for i := range device.IPs { freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) if err != nil { - return err + return errors.WithMessage(err, "failed to get available IP addresses") } peer.IPs[i] = freeIP } peer.IPsStr = common.ListToString(peer.IPs) psk, err := wgtypes.GenerateKey() if err != nil { - return err + return errors.Wrap(err, "failed to generate key") } key, err := wgtypes.GeneratePrivateKey() if err != nil { - return err + return errors.Wrap(err, "failed to generate private key") } peer.PresharedKey = psk.String() peer.PrivateKey = key.String() @@ -92,7 +95,7 @@ func (s *Server) CreatePeer(peer Peer) error { for i := range device.IPs { freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) if err != nil { - return err + return errors.WithMessage(err, "failed to get available IP addresses") } peer.IPs[i] = freeIP } @@ -101,11 +104,11 @@ func (s *Server) CreatePeer(peer Peer) error { if peer.PrivateKey == "" { // if private key is empty create a new one psk, err := wgtypes.GenerateKey() if err != nil { - return err + return errors.Wrap(err, "failed to generate key") } key, err := wgtypes.GeneratePrivateKey() if err != nil { - return err + return errors.Wrap(err, "failed to generate private key") } peer.PresharedKey = psk.String() peer.PrivateKey = key.String() @@ -116,13 +119,13 @@ func (s *Server) CreatePeer(peer Peer) error { // Create WireGuard interface if peer.DeactivatedAt == nil { if err := s.wg.AddPeer(peer.GetConfig()); err != nil { - return err + return errors.WithMessage(err, "failed to add WireGuard peer") } } // Create in database if err := s.peers.CreatePeer(peer); err != nil { - return err + return errors.WithMessage(err, "failed to create peer") } return s.WriteWireGuardConfigFile() @@ -142,12 +145,12 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { err = s.wg.AddPeer(peer.GetConfig()) } if err != nil { - return err + return errors.WithMessage(err, "failed to update WireGuard peer") } // Update in database if err := s.peers.UpdatePeer(peer); err != nil { - return err + return errors.WithMessage(err, "failed to update peer") } return s.WriteWireGuardConfigFile() @@ -156,12 +159,12 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { func (s *Server) DeletePeer(peer Peer) error { // Delete WireGuard peer if err := s.wg.RemovePeer(peer.PublicKey); err != nil { - return err + return errors.WithMessage(err, "failed to remove WireGuard peer") } // Delete in database if err := s.peers.DeletePeer(peer); err != nil { - return err + return errors.WithMessage(err, "failed to remove peer") } return s.WriteWireGuardConfigFile() @@ -173,7 +176,7 @@ func (s *Server) RestoreWireGuardInterface() error { for i := range activePeers { if activePeers[i].Peer == nil { if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil { - return err + return errors.WithMessage(err, "failed to add WireGuard peer") } } } @@ -186,16 +189,109 @@ func (s *Server) WriteWireGuardConfigFile() error { return nil // writing disabled } if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil { - return err + return errors.Wrap(err, "failed to check WireGuard config access rights") } device := s.peers.GetDevice() cfg, err := device.GetConfigFile(s.peers.GetActivePeers()) if err != nil { - return err + return errors.WithMessage(err, "failed to get config file") } if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil { - return err + return errors.Wrap(err, "failed to write WireGuard config file") } return nil } + +func (s *Server) CreateUser(user users.User) error { + if user.Email == "" { + return errors.New("cannot create user with empty email address") + } + + // Check if user already exists, if so re-enable + if existingUser := s.users.GetUserUnscoped(user.Email); existingUser != nil { + user.DeletedAt = gorm.DeletedAt{} // reset deleted flag to enable that user again + return s.UpdateUser(user) + } + + // Create user in database + if err := s.users.CreateUser(&user); err != nil { + return errors.WithMessage(err, "failed to create user in manager") + } + + // Check if user already has a peer setup, if not, create one + return s.CreateUserDefaultPeer(user.Email) +} + +func (s *Server) UpdateUser(user users.User) error { + if user.DeletedAt.Valid { + return s.DeleteUser(user) + } + + currentUser := s.users.GetUserUnscoped(user.Email) + + // Update in database + if err := s.users.UpdateUser(&user); err != nil { + return errors.WithMessage(err, "failed to update user in manager") + } + + // If user was deleted (disabled), reactivate it's peers + if currentUser.DeletedAt.Valid { + for _, peer := range s.peers.GetPeersByMail(user.Email) { + now := time.Now() + peer.DeactivatedAt = nil + if err := s.UpdatePeer(peer, now); err != nil { + logrus.Errorf("failed to update (re)activated peer %s for %s: %v", peer.PublicKey, user.Email, err) + } + } + } + + return nil +} + +func (s *Server) DeleteUser(user users.User) error { + currentUser := s.users.GetUserUnscoped(user.Email) + + // Update in database + if err := s.users.DeleteUser(&user); err != nil { + return errors.WithMessage(err, "failed to delete user in manager") + } + + // If user was active, disable it's peers + if !currentUser.DeletedAt.Valid { + for _, peer := range s.peers.GetPeersByMail(user.Email) { + now := time.Now() + peer.DeactivatedAt = &now + if err := s.UpdatePeer(peer, now); err != nil { + logrus.Errorf("failed to update deactivated peer %s for %s: %v", peer.PublicKey, user.Email, err) + } + } + } + + return nil +} + +func (s *Server) CreateUserDefaultPeer(email string) error { + // Check if user is active, if not, quit + var existingUser *users.User + if existingUser = s.users.GetUser(email); existingUser == nil { + return nil + } + + // Check if user already has a peer setup, if not, create one + if s.config.Core.CreateDefaultPeer { + peers := s.peers.GetPeersByMail(email) + if len(peers) == 0 { // Create default vpn peer + if err := s.CreatePeer(Peer{ + Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)", + Email: existingUser.Email, + CreatedBy: existingUser.Email, + UpdatedBy: existingUser.Email, + }); err != nil { + return errors.WithMessagef(err, "failed to automatically create vpn peer for %s", email) + } + } + } + + return nil +} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 976b841..f7f2a78 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -1,9 +1,10 @@ package wireguard import ( - "fmt" "sync" + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -18,7 +19,7 @@ func (m *Manager) Init() error { var err error m.wg, err = wgctrl.New() if err != nil { - return fmt.Errorf("could not create WireGuard client: %w", err) + return errors.Wrap(err, "could not create WireGuard client") } return nil @@ -27,7 +28,7 @@ func (m *Manager) Init() error { func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) { dev, err := m.wg.Device(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not get WireGuard device: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard device") } return dev, nil @@ -39,7 +40,7 @@ func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) { dev, err := m.wg.Device(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not get WireGuard device: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard device") } return dev.Peers, nil @@ -51,12 +52,12 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) { publicKey, err := wgtypes.ParseKey(pubKey) if err != nil { - return nil, fmt.Errorf("invalid public key: %w", err) + return nil, errors.Wrap(err, "invalid public key") } peers, err := m.GetPeerList() if err != nil { - return nil, fmt.Errorf("could not get WireGuard peers: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard peers") } for _, peer := range peers { @@ -65,7 +66,7 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) { } } - return nil, fmt.Errorf("could not find WireGuard peer: %s", pubKey) + return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey) } func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { @@ -74,7 +75,7 @@ func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil @@ -87,7 +88,7 @@ func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error { cfg.UpdateOnly = true err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil @@ -99,7 +100,7 @@ func (m *Manager) RemovePeer(pubKey string) error { publicKey, err := wgtypes.ParseKey(pubKey) if err != nil { - return fmt.Errorf("invalid public key: %w", err) + return errors.Wrap(err, "invalid public key") } peer := wgtypes.PeerConfig{ @@ -109,7 +110,7 @@ func (m *Manager) RemovePeer(pubKey string) error { err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil diff --git a/internal/wireguard/net.go b/internal/wireguard/net.go index 92526f8..0b9e68b 100644 --- a/internal/wireguard/net.go +++ b/internal/wireguard/net.go @@ -4,6 +4,8 @@ import ( "fmt" "net" + "github.com/pkg/errors" + "github.com/milosgajdos/tenus" ) @@ -12,18 +14,18 @@ const DefaultMTU = 1420 func (m *Manager) GetIPAddress() ([]string, error) { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // Get golang net.interface iface := wgInterface.NetInterface() if iface == nil { // Not sure if this check is really necessary - return nil, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err) + return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface") } addrs, err := iface.Addrs() if err != nil { - return nil, fmt.Errorf("could not retrieve WireGuard ip addresses: %w", err) + return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses") } ipAddresses := make([]string, 0, len(addrs)) @@ -53,22 +55,22 @@ func (m *Manager) GetIPAddress() ([]string, error) { func (m *Manager) SetIPAddress(cidrs []string) error { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // First remove existing IP addresses existingIPs, err := m.GetIPAddress() if err != nil { - return err + return errors.Wrap(err, "could not retrieve IP addresses") } for _, cidr := range existingIPs { wgIp, wgIpNet, err := net.ParseCIDR(cidr) if err != nil { - return fmt.Errorf("unable to parse cidr %s: %w", cidr, err) + return errors.Wrapf(err, "unable to parse cidr %s", cidr) } if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil { - return fmt.Errorf("failed to unset ip %s: %w", cidr, err) + return errors.Wrapf(err, "failed to unset ip %s", cidr) } } @@ -76,11 +78,11 @@ func (m *Manager) SetIPAddress(cidrs []string) error { for _, cidr := range cidrs { wgIp, wgIpNet, err := net.ParseCIDR(cidr) if err != nil { - return fmt.Errorf("unable to parse cidr %s: %w", cidr, err) + return errors.Wrapf(err, "unable to parse cidr %s", cidr) } if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil { - return fmt.Errorf("failed to set ip %s: %w", cidr, err) + return errors.Wrapf(err, "failed to set ip %s", cidr) } } @@ -90,13 +92,13 @@ func (m *Manager) SetIPAddress(cidrs []string) error { func (m *Manager) GetMTU() (int, error) { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return 0, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // Get golang net.interface iface := wgInterface.NetInterface() if iface == nil { // Not sure if this check is really necessary - return 0, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err) + return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface") } return iface.MTU, nil @@ -105,7 +107,7 @@ func (m *Manager) GetMTU() (int, error) { func (m *Manager) SetMTU(mtu int) error { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } if mtu == 0 { @@ -113,7 +115,7 @@ func (m *Manager) SetMTU(mtu int) error { } if err := wgInterface.SetLinkMTU(mtu); err != nil { - return fmt.Errorf("could not set MTU on interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName) } return nil