mirror of
https://github.com/DJSundog/wg-portal.git
synced 2024-11-23 07:03:50 -05:00
cleanup
This commit is contained in:
parent
8ea82c1916
commit
9faa459c44
@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Version string = "unknown (local build)"
|
var Version = "unknown (local build)"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
_ = setupLogger(logrus.StandardLogger())
|
_ = setupLogger(logrus.StandardLogger())
|
||||||
@ -20,7 +20,7 @@ func main() {
|
|||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||||
|
|
||||||
logrus.Infof("Starting WireGuard Portal Server [%s]...", Version)
|
logrus.Infof("starting WireGuard Portal Server [%s]...", Version)
|
||||||
|
|
||||||
// Context for clean shutdown
|
// Context for clean shutdown
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@ -28,7 +28,7 @@ func main() {
|
|||||||
|
|
||||||
service := server.Server{}
|
service := server.Server{}
|
||||||
if err := service.Setup(ctx); err != nil {
|
if err := service.Setup(ctx); err != nil {
|
||||||
logrus.Fatalf("Setup failed: %v", err)
|
logrus.Fatalf("setup failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attach signal handlers to context
|
// Attach signal handlers to context
|
||||||
@ -44,10 +44,10 @@ func main() {
|
|||||||
<-ctx.Done() // Wait until the context gets canceled
|
<-ctx.Done() // Wait until the context gets canceled
|
||||||
|
|
||||||
// Give goroutines some time to stop gracefully
|
// Give goroutines some time to stop gracefully
|
||||||
logrus.Info("Stopping WireGuard Portal Server...")
|
logrus.Info("stopping WireGuard Portal Server...")
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
logrus.Infof("Stopped WireGuard Portal Server...")
|
logrus.Infof("stopped WireGuard Portal Server...")
|
||||||
logrus.Exit(0)
|
logrus.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,10 +4,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AuthContext contains all information that the AuthProvider needs to perform the authentication.
|
||||||
type AuthContext struct {
|
type AuthContext struct {
|
||||||
Provider AuthProvider
|
|
||||||
Username string // email or username
|
Username string // email or username
|
||||||
Password string // optional for OIDC
|
Password string
|
||||||
Callback string // callback for OIDC
|
Callback string // callback for OIDC
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -18,6 +18,7 @@ const (
|
|||||||
AuthProviderTypeOauth AuthProviderType = "oauth"
|
AuthProviderTypeOauth AuthProviderType = "oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AuthProvider is a interface that can be implemented by different authentication providers like LDAP, OAUTH, ...
|
||||||
type AuthProvider interface {
|
type AuthProvider interface {
|
||||||
GetName() string
|
GetName() string
|
||||||
GetType() AuthProviderType
|
GetType() AuthProviderType
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider provide login with password method
|
// Provider implements a password login method for an LDAP backend.
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
config *ldapconfig.Config
|
config *ldapconfig.Config
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider provide login with password method
|
// Provider implements a password login method for a database backend.
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package authentication
|
package authentication
|
||||||
|
|
||||||
|
// User represents the data that can be retrieved from authentication backends.
|
||||||
type User struct {
|
type User struct {
|
||||||
Email string
|
Email string
|
||||||
IsAdmin bool
|
IsAdmin bool
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -10,13 +9,14 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/users"
|
"github.com/h44z/wg-portal/internal/users"
|
||||||
"github.com/h44z/wg-portal/internal/wireguard"
|
"github.com/h44z/wg-portal/internal/wireguard"
|
||||||
"github.com/kelseyhightower/envconfig"
|
"github.com/kelseyhightower/envconfig"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
|
var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
|
||||||
|
|
||||||
// LoadConfigFile parses yaml files. It uses to yaml annotation to store the data in a struct.
|
// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct.
|
||||||
func loadConfigFile(cfg interface{}, filename string) error {
|
func loadConfigFile(cfg interface{}, filename string) error {
|
||||||
s := reflect.ValueOf(cfg)
|
s := reflect.ValueOf(cfg)
|
||||||
|
|
||||||
@ -30,24 +30,24 @@ func loadConfigFile(cfg interface{}, filename string) error {
|
|||||||
|
|
||||||
f, err := os.Open(filename)
|
f, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrapf(err, "failed to open config file %s", filename)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
decoder := yaml.NewDecoder(f)
|
decoder := yaml.NewDecoder(f)
|
||||||
err = decoder.Decode(cfg)
|
err = decoder.Decode(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrapf(err, "failed to decode config file %s", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
|
// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
|
||||||
func loadConfigEnv(cfg interface{}) error {
|
func loadConfigEnv(cfg interface{}) error {
|
||||||
err := envconfig.Process("", cfg)
|
err := envconfig.Process("", cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to process environment config")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -124,7 +124,7 @@ func NewConfig() *Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" {
|
if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" {
|
||||||
logrus.Warnf("Managing IP addresses only works on linux! Feature disabled.")
|
logrus.Warnf("managing IP addresses only works on linux, feature disabled...")
|
||||||
cfg.WG.ManageIPAddresses = false
|
cfg.WG.ManageIPAddresses = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ type MailAttachment struct {
|
|||||||
Embedded bool
|
Embedded bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendEmailWithAttachments sends a mail with attachments.
|
// SendEmailWithAttachments sends a mail with optional attachments.
|
||||||
func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body string, htmlBody string, receivers []string, attachments []MailAttachment) error {
|
func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body string, htmlBody string, receivers []string, attachments []MailAttachment) error {
|
||||||
e := email.NewEmail()
|
e := email.NewEmail()
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ func IsIPv6(address string) bool {
|
|||||||
return ip.To4() == nil
|
return ip.To4() == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseStringList converts a comma separated string into a list of strings.
|
||||||
|
// It also trims spaces from each element of the list.
|
||||||
func ParseStringList(lst string) []string {
|
func ParseStringList(lst string) []string {
|
||||||
tokens := strings.Split(lst, ",")
|
tokens := strings.Split(lst, ",")
|
||||||
validatedTokens := make([]string, 0, len(tokens))
|
validatedTokens := make([]string, 0, len(tokens))
|
||||||
@ -53,6 +55,7 @@ func ParseStringList(lst string) []string {
|
|||||||
return validatedTokens
|
return validatedTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListToString converts a list of strings into a comma separated string.
|
||||||
func ListToString(lst []string) string {
|
func ListToString(lst []string) string {
|
||||||
return strings.Join(lst, ", ")
|
return strings.Join(lst, ", ")
|
||||||
}
|
}
|
||||||
|
@ -23,5 +23,5 @@ type Config struct {
|
|||||||
GroupMemberAttribute string `yaml:"attrGroups" envconfig:"LDAP_ATTR_GROUPS"`
|
GroupMemberAttribute string `yaml:"attrGroups" envconfig:"LDAP_ATTR_GROUPS"`
|
||||||
DisabledAttribute string `yaml:"attrDisabled" envconfig:"LDAP_ATTR_DISABLED"`
|
DisabledAttribute string `yaml:"attrDisabled" envconfig:"LDAP_ATTR_DISABLED"`
|
||||||
|
|
||||||
AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"`
|
AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` // Members of this group receive admin rights in WG-Portal
|
||||||
}
|
}
|
||||||
|
@ -18,20 +18,20 @@ type RawLdapData struct {
|
|||||||
func Open(cfg *Config) (*ldap.Conn, error) {
|
func Open(cfg *Config) (*ldap.Conn, error) {
|
||||||
conn, err := ldap.DialURL(cfg.URL)
|
conn, err := ldap.DialURL(cfg.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to connect to LDAP")
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.StartTLS {
|
if cfg.StartTLS {
|
||||||
// Reconnect with TLS
|
// Reconnect with TLS
|
||||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to star TLS on connection")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.Bind(cfg.BindUser, cfg.BindPass)
|
err = conn.Bind(cfg.BindUser, cfg.BindPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to bind to LDAP")
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
@ -3,14 +3,13 @@ package server
|
|||||||
import (
|
import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/authentication"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/h44z/wg-portal/internal/authentication"
|
||||||
"github.com/h44z/wg-portal/internal/users"
|
"github.com/h44z/wg-portal/internal/users"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Auth auth struct
|
// AuthManager keeps track of available authentication providers.
|
||||||
type AuthManager struct {
|
type AuthManager struct {
|
||||||
Server *Server
|
Server *Server
|
||||||
Group *gin.RouterGroup // basic group for all providers (/auth)
|
Group *gin.RouterGroup // basic group for all providers (/auth)
|
||||||
@ -38,7 +37,7 @@ func (auth *AuthManager) RegisterProviderWithoutError(provider authentication.Au
|
|||||||
auth.RegisterProvider(provider)
|
auth.RegisterProvider(provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvider get provider with name
|
// GetProvider get provider by name
|
||||||
func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider {
|
func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider {
|
||||||
for _, provider := range auth.providers {
|
for _, provider := range auth.providers {
|
||||||
if provider.GetName() == name {
|
if provider.GetName() == name {
|
||||||
@ -48,15 +47,23 @@ func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProviders return registered providers
|
// GetProviders return registered providers.
|
||||||
|
// Returned providers are ordered by provider priority.
|
||||||
func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) {
|
func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) {
|
||||||
for _, provider := range auth.providers {
|
for _, provider := range auth.providers {
|
||||||
providers = append(providers, provider)
|
providers = append(providers, provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// order by priority
|
||||||
|
sort.SliceStable(providers, func(i, j int) bool {
|
||||||
|
return providers[i].GetPriority() < providers[j].GetPriority()
|
||||||
|
})
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProviders return registered providers
|
// GetProvidersForType return registered providers for the given type.
|
||||||
|
// Returned providers are ordered by provider priority.
|
||||||
func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) {
|
func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) {
|
||||||
for _, provider := range auth.providers {
|
for _, provider := range auth.providers {
|
||||||
if provider.GetType() == typ {
|
if provider.GetType() == typ {
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/authentication"
|
"github.com/h44z/wg-portal/internal/authentication"
|
||||||
"github.com/h44z/wg-portal/internal/users"
|
"github.com/h44z/wg-portal/internal/users"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) GetLogin(c *gin.Context) {
|
func (s *Server) GetLogin(c *gin.Context) {
|
||||||
@ -85,10 +84,6 @@ func (s *Server) PostLogin(c *gin.Context) {
|
|||||||
loginProvider = provider
|
loginProvider = provider
|
||||||
|
|
||||||
// create new user in the database (or reactivate him)
|
// create new user in the database (or reactivate him)
|
||||||
if user, err = s.users.GetOrCreateUserUnscoped(email); err != nil {
|
|
||||||
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to create new user")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userData, err := loginProvider.GetUserModel(&authentication.AuthContext{
|
userData, err := loginProvider.GetUserModel(&authentication.AuthContext{
|
||||||
Username: email,
|
Username: email,
|
||||||
})
|
})
|
||||||
@ -96,23 +91,25 @@ func (s *Server) PostLogin(c *gin.Context) {
|
|||||||
s.GetHandleError(c, http.StatusInternalServerError, "login error", err.Error())
|
s.GetHandleError(c, http.StatusInternalServerError, "login error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Firstname = userData.Firstname
|
if err := s.CreateUser(users.User{
|
||||||
user.Lastname = userData.Lastname
|
Email: userData.Email,
|
||||||
user.Email = userData.Email
|
Source: users.UserSource(loginProvider.GetName()),
|
||||||
user.Phone = userData.Phone
|
IsAdmin: userData.IsAdmin,
|
||||||
user.IsAdmin = userData.IsAdmin
|
Firstname: userData.Firstname,
|
||||||
user.Source = users.UserSource(loginProvider.GetName())
|
Lastname: userData.Lastname,
|
||||||
user.DeletedAt = gorm.DeletedAt{} // reset deleted flag
|
Phone: userData.Phone,
|
||||||
if err = s.users.UpdateUser(user); err != nil {
|
}); err != nil {
|
||||||
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data")
|
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user = s.users.GetUser(username)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user is authenticated
|
// Check if user is authenticated
|
||||||
if email == "" || loginProvider == nil {
|
if email == "" || loginProvider == nil || user == nil {
|
||||||
c.Redirect(http.StatusSeeOther, "/auth/login?err=authfail")
|
c.Redirect(http.StatusSeeOther, "/auth/login?err=authfail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -126,17 +123,9 @@ func (s *Server) PostLogin(c *gin.Context) {
|
|||||||
sessionData.Lastname = user.Lastname
|
sessionData.Lastname = user.Lastname
|
||||||
|
|
||||||
// Check if user already has a peer setup, if not create one
|
// Check if user already has a peer setup, if not create one
|
||||||
if s.config.Core.CreateDefaultPeer {
|
if err := s.CreateUserDefaultPeer(user.Email); err != nil {
|
||||||
peers := s.peers.GetPeersByMail(sessionData.Email)
|
// Not a fatal error, just log it...
|
||||||
if len(peers) == 0 { // Create vpn peer
|
logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err)
|
||||||
err := s.CreatePeer(Peer{
|
|
||||||
Identifier: sessionData.Firstname + " " + sessionData.Lastname + " (Default)",
|
|
||||||
Email: sessionData.Email,
|
|
||||||
CreatedBy: sessionData.Email,
|
|
||||||
UpdatedBy: sessionData.Email,
|
|
||||||
})
|
|
||||||
logrus.Errorf("Failed to automatically create vpn peer for %s: %v", sessionData.Email, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := UpdateSessionData(c, sessionData); err != nil {
|
if err := UpdateSessionData(c, sessionData); err != nil {
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -145,7 +147,7 @@ func (s *Server) updateFormInSession(c *gin.Context, formData interface{}) error
|
|||||||
currentSession.FormData = formData
|
currentSession.FormData = formData
|
||||||
|
|
||||||
if err := UpdateSessionData(c, currentSession); err != nil {
|
if err := UpdateSessionData(c, currentSession); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to update form in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -158,13 +160,13 @@ func (s *Server) setNewPeerFormInSession(c *gin.Context) (SessionData, error) {
|
|||||||
if currentSession.FormData == nil || c.Query("formerr") == "" {
|
if currentSession.FormData == nil || c.Query("formerr") == "" {
|
||||||
user, err := s.PrepareNewPeer()
|
user, err := s.PrepareNewPeer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return currentSession, err
|
return currentSession, errors.WithMessage(err, "failed to prepare new peer")
|
||||||
}
|
}
|
||||||
currentSession.FormData = user
|
currentSession.FormData = user
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := UpdateSessionData(c, currentSession); err != nil {
|
if err := UpdateSessionData(c, currentSession); err != nil {
|
||||||
return currentSession, err
|
return currentSession, errors.WithMessage(err, "failed to update peer form in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
return currentSession, nil
|
return currentSession, nil
|
||||||
@ -179,7 +181,7 @@ func (s *Server) setFormInSession(c *gin.Context, formData interface{}) (Session
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := UpdateSessionData(c, currentSession); err != nil {
|
if err := UpdateSessionData(c, currentSession); err != nil {
|
||||||
return currentSession, err
|
return currentSession, errors.WithMessage(err, "failed to set form in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
return currentSession, nil
|
return currentSession, nil
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/h44z/wg-portal/internal/users"
|
"github.com/h44z/wg-portal/internal/users"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -141,31 +140,7 @@ func (s *Server) PostAdminUsersEdit(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
formUser.IsAdmin = c.PostForm("isadmin") == "true"
|
formUser.IsAdmin = c.PostForm("isadmin") == "true"
|
||||||
|
|
||||||
// Update peers
|
if err := s.UpdateUser(formUser); err != nil {
|
||||||
if disabled != currentUser.DeletedAt.Valid {
|
|
||||||
if disabled {
|
|
||||||
// disable all peers for the given user
|
|
||||||
for _, peer := range s.peers.GetPeersByMail(currentUser.Email) {
|
|
||||||
now := time.Now()
|
|
||||||
peer.DeactivatedAt = &now
|
|
||||||
if err := s.UpdatePeer(peer, now); err != nil {
|
|
||||||
logrus.Errorf("failed to update deactivated peer %s: %v", peer.PublicKey, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// enable all peers for the given user
|
|
||||||
for _, peer := range s.peers.GetPeersByMail(currentUser.Email) {
|
|
||||||
now := time.Now()
|
|
||||||
peer.DeactivatedAt = nil
|
|
||||||
if err := s.UpdatePeer(peer, now); err != nil {
|
|
||||||
logrus.Errorf("failed to update activated peer %s: %v", peer.PublicKey, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update in database
|
|
||||||
if err := s.users.UpdateUser(&formUser); err != nil {
|
|
||||||
_ = s.updateFormInSession(c, formUser)
|
_ = s.updateFormInSession(c, formUser)
|
||||||
SetFlashMessage(c, "failed to update user: "+err.Error(), "danger")
|
SetFlashMessage(c, "failed to update user: "+err.Error(), "danger")
|
||||||
c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=update")
|
c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=update")
|
||||||
@ -242,28 +217,14 @@ func (s *Server) PostAdminUsersCreate(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
formUser.IsAdmin = c.PostForm("isadmin") == "true"
|
formUser.IsAdmin = c.PostForm("isadmin") == "true"
|
||||||
formUser.Source = users.UserSourceDatabase
|
formUser.Source = users.UserSourceDatabase
|
||||||
if err := s.users.CreateUser(&formUser); err != nil {
|
|
||||||
formUser.CreatedAt = time.Time{} // reset created time
|
if err := s.CreateUser(formUser); err != nil {
|
||||||
_ = s.updateFormInSession(c, formUser)
|
_ = s.updateFormInSession(c, formUser)
|
||||||
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
|
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
|
||||||
c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create")
|
c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user already has a peer setup, if not create one
|
|
||||||
if s.config.Core.CreateDefaultPeer {
|
|
||||||
peers := s.peers.GetPeersByMail(formUser.Email)
|
|
||||||
if len(peers) == 0 { // Create vpn peer
|
|
||||||
err := s.CreatePeer(Peer{
|
|
||||||
Identifier: formUser.Firstname + " " + formUser.Lastname + " (Default)",
|
|
||||||
Email: formUser.Email,
|
|
||||||
CreatedBy: formUser.Email,
|
|
||||||
UpdatedBy: formUser.Email,
|
|
||||||
})
|
|
||||||
logrus.Errorf("Failed to automatically create vpn peer for %s: %v", formUser.Email, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SetFlashMessage(c, "user created successfully", "success")
|
SetFlashMessage(c, "user created successfully", "success")
|
||||||
c.Redirect(http.StatusSeeOther, "/admin/users/")
|
c.Redirect(http.StatusSeeOther, "/admin/users/")
|
||||||
}
|
}
|
||||||
|
@ -124,7 +124,7 @@ func (p Peer) GetConfig() wgtypes.PeerConfig {
|
|||||||
func (p Peer) GetConfigFile(device Device) ([]byte, error) {
|
func (p Peer) GetConfigFile(device Device) ([]byte, error) {
|
||||||
tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl)
|
tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to parse client template")
|
||||||
}
|
}
|
||||||
|
|
||||||
var tplBuff bytes.Buffer
|
var tplBuff bytes.Buffer
|
||||||
@ -137,7 +137,7 @@ func (p Peer) GetConfigFile(device Device) ([]byte, error) {
|
|||||||
Server: device,
|
Server: device,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to execute client template")
|
||||||
}
|
}
|
||||||
|
|
||||||
return tplBuff.Bytes(), nil
|
return tplBuff.Bytes(), nil
|
||||||
@ -149,7 +149,7 @@ func (p Peer) GetQRCode() ([]byte, error) {
|
|||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to create qrcode")
|
}).Error("failed to create qrcode")
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to encode qrcode")
|
||||||
}
|
}
|
||||||
return png, nil
|
return png, nil
|
||||||
}
|
}
|
||||||
@ -247,7 +247,7 @@ func (d Device) GetConfig() wgtypes.Config {
|
|||||||
func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
|
func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
|
||||||
tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl)
|
tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to parse server template")
|
||||||
}
|
}
|
||||||
|
|
||||||
var tplBuff bytes.Buffer
|
var tplBuff bytes.Buffer
|
||||||
@ -260,7 +260,7 @@ func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
|
|||||||
Server: d,
|
Server: d,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to execute server template")
|
||||||
}
|
}
|
||||||
|
|
||||||
return tplBuff.Bytes(), nil
|
return tplBuff.Bytes(), nil
|
||||||
@ -582,7 +582,7 @@ func (u *PeerManager) CreatePeer(peer Peer) error {
|
|||||||
res := u.db.Create(&peer)
|
res := u.db.Create(&peer)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logrus.Errorf("failed to create peer: %v", res.Error)
|
logrus.Errorf("failed to create peer: %v", res.Error)
|
||||||
return res.Error
|
return errors.Wrap(res.Error, "failed to create peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -596,7 +596,7 @@ func (u *PeerManager) UpdatePeer(peer Peer) error {
|
|||||||
res := u.db.Save(&peer)
|
res := u.db.Save(&peer)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logrus.Errorf("failed to update peer: %v", res.Error)
|
logrus.Errorf("failed to update peer: %v", res.Error)
|
||||||
return res.Error
|
return errors.Wrap(res.Error, "failed to update peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -606,7 +606,7 @@ func (u *PeerManager) DeletePeer(peer Peer) error {
|
|||||||
res := u.db.Delete(&peer)
|
res := u.db.Delete(&peer)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logrus.Errorf("failed to delete peer: %v", res.Error)
|
logrus.Errorf("failed to delete peer: %v", res.Error)
|
||||||
return res.Error
|
return errors.Wrap(res.Error, "failed to delete peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -621,7 +621,7 @@ func (u *PeerManager) UpdateDevice(device Device) error {
|
|||||||
res := u.db.Save(&device)
|
res := u.db.Save(&device)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logrus.Errorf("failed to update device: %v", res.Error)
|
logrus.Errorf("failed to update device: %v", res.Error)
|
||||||
return res.Error
|
return errors.Wrap(res.Error, "failed to update device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -637,7 +637,7 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
|
|||||||
}
|
}
|
||||||
ip, _, err := net.ParseCIDR(cidr)
|
ip, _, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to parse cidr")
|
||||||
}
|
}
|
||||||
reservedIps = append(reservedIps, ip.String())
|
reservedIps = append(reservedIps, ip.String())
|
||||||
}
|
}
|
||||||
@ -650,7 +650,7 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
|
|||||||
}
|
}
|
||||||
ip, _, err := net.ParseCIDR(cidr)
|
ip, _, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "failed to parse cidr")
|
||||||
}
|
}
|
||||||
|
|
||||||
reservedIps = append(reservedIps, ip.String())
|
reservedIps = append(reservedIps, ip.String())
|
||||||
@ -691,11 +691,11 @@ func (u *PeerManager) IsIPReserved(cidr string) bool {
|
|||||||
func (u *PeerManager) GetAvailableIp(cidr string) (string, error) {
|
func (u *PeerManager) GetAvailableIp(cidr string) (string, error) {
|
||||||
reserved, err := u.GetAllReservedIps()
|
reserved, err := u.GetAllReservedIps()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", errors.WithMessage(err, "failed to get all reserved IP addresses")
|
||||||
}
|
}
|
||||||
ip, ipnet, err := net.ParseCIDR(cidr)
|
ip, ipnet, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", errors.Wrap(err, "failed to parse cidr")
|
||||||
}
|
}
|
||||||
|
|
||||||
// this two addresses are not usable
|
// this two addresses are not usable
|
||||||
|
@ -84,8 +84,8 @@ func (s *Server) Setup(ctx context.Context) error {
|
|||||||
|
|
||||||
dir := s.getExecutableDirectory()
|
dir := s.getExecutableDirectory()
|
||||||
rDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
|
rDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
|
||||||
logrus.Infof("Real working directory: %s", rDir)
|
logrus.Infof("real working directory: %s", rDir)
|
||||||
logrus.Infof("Current working directory: %s", dir)
|
logrus.Infof("current working directory: %s", dir)
|
||||||
|
|
||||||
// Init rand
|
// Init rand
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
@ -166,7 +166,7 @@ func (s *Server) Setup(ctx context.Context) error {
|
|||||||
return errors.Wrap(err, "unable to pare mail template")
|
return errors.Wrap(err, "unable to pare mail template")
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.Infof("Setup of service completed!")
|
logrus.Infof("setup of service completed!")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ func (s *Server) Run() {
|
|||||||
func (s *Server) getExecutableDirectory() string {
|
func (s *Server) getExecutableDirectory() string {
|
||||||
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
|
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("Failed to get executable directory: %v", err)
|
logrus.Errorf("failed to get executable directory: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) {
|
if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) {
|
||||||
@ -240,7 +240,7 @@ func GetSessionData(c *gin.Context) SessionData {
|
|||||||
}
|
}
|
||||||
session.Set(SessionIdentifier, sessionData)
|
session.Set(SessionIdentifier, sessionData)
|
||||||
if err := session.Save(); err != nil {
|
if err := session.Save(); err != nil {
|
||||||
logrus.Errorf("Failed to store session: %v", err)
|
logrus.Errorf("failed to store session: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,7 +251,7 @@ func GetFlashes(c *gin.Context) []FlashData {
|
|||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
flashes := session.Flashes()
|
flashes := session.Flashes()
|
||||||
if err := session.Save(); err != nil {
|
if err := session.Save(); err != nil {
|
||||||
logrus.Errorf("Failed to store session after setting flash: %v", err)
|
logrus.Errorf("failed to store session after setting flash: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
flashData := make([]FlashData, len(flashes))
|
flashData := make([]FlashData, len(flashes))
|
||||||
@ -266,8 +266,8 @@ func UpdateSessionData(c *gin.Context, data SessionData) error {
|
|||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
session.Set(SessionIdentifier, data)
|
session.Set(SessionIdentifier, data)
|
||||||
if err := session.Save(); err != nil {
|
if err := session.Save(); err != nil {
|
||||||
logrus.Errorf("Failed to store session: %v", err)
|
logrus.Errorf("failed to store session: %v", err)
|
||||||
return err
|
return errors.Wrap(err, "failed to store session")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -276,8 +276,8 @@ func DestroySessionData(c *gin.Context) error {
|
|||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
session.Delete(SessionIdentifier)
|
session.Delete(SessionIdentifier)
|
||||||
if err := session.Save(); err != nil {
|
if err := session.Save(); err != nil {
|
||||||
logrus.Errorf("Failed to destroy session: %v", err)
|
logrus.Errorf("failed to destroy session: %v", err)
|
||||||
return err
|
return errors.Wrap(err, "failed to destroy session")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -289,7 +289,7 @@ func SetFlashMessage(c *gin.Context, message, typ string) {
|
|||||||
Type: typ,
|
Type: typ,
|
||||||
})
|
})
|
||||||
if err := session.Save(); err != nil {
|
if err := session.Save(); err != nil {
|
||||||
logrus.Errorf("Failed to store session after setting flash: %v", err)
|
logrus.Errorf("failed to store session after setting flash: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -8,8 +8,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/common"
|
"github.com/h44z/wg-portal/internal/common"
|
||||||
|
"github.com/h44z/wg-portal/internal/users"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) PrepareNewPeer() (Peer, error) {
|
func (s *Server) PrepareNewPeer() (Peer, error) {
|
||||||
@ -22,18 +25,18 @@ func (s *Server) PrepareNewPeer() (Peer, error) {
|
|||||||
for i := range device.IPs {
|
for i := range device.IPs {
|
||||||
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Peer{}, err
|
return Peer{}, errors.WithMessage(err, "failed to get available IP addresses")
|
||||||
}
|
}
|
||||||
peer.IPs[i] = freeIP
|
peer.IPs[i] = freeIP
|
||||||
}
|
}
|
||||||
peer.IPsStr = common.ListToString(peer.IPs)
|
peer.IPsStr = common.ListToString(peer.IPs)
|
||||||
psk, err := wgtypes.GenerateKey()
|
psk, err := wgtypes.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Peer{}, err
|
return Peer{}, errors.Wrap(err, "failed to generate key")
|
||||||
}
|
}
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Peer{}, err
|
return Peer{}, errors.Wrap(err, "failed to generate private key")
|
||||||
}
|
}
|
||||||
peer.PresharedKey = psk.String()
|
peer.PresharedKey = psk.String()
|
||||||
peer.PrivateKey = key.String()
|
peer.PrivateKey = key.String()
|
||||||
@ -57,18 +60,18 @@ func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool
|
|||||||
for i := range device.IPs {
|
for i := range device.IPs {
|
||||||
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to get available IP addresses")
|
||||||
}
|
}
|
||||||
peer.IPs[i] = freeIP
|
peer.IPs[i] = freeIP
|
||||||
}
|
}
|
||||||
peer.IPsStr = common.ListToString(peer.IPs)
|
peer.IPsStr = common.ListToString(peer.IPs)
|
||||||
psk, err := wgtypes.GenerateKey()
|
psk, err := wgtypes.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to generate key")
|
||||||
}
|
}
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to generate private key")
|
||||||
}
|
}
|
||||||
peer.PresharedKey = psk.String()
|
peer.PresharedKey = psk.String()
|
||||||
peer.PrivateKey = key.String()
|
peer.PrivateKey = key.String()
|
||||||
@ -92,7 +95,7 @@ func (s *Server) CreatePeer(peer Peer) error {
|
|||||||
for i := range device.IPs {
|
for i := range device.IPs {
|
||||||
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to get available IP addresses")
|
||||||
}
|
}
|
||||||
peer.IPs[i] = freeIP
|
peer.IPs[i] = freeIP
|
||||||
}
|
}
|
||||||
@ -101,11 +104,11 @@ func (s *Server) CreatePeer(peer Peer) error {
|
|||||||
if peer.PrivateKey == "" { // if private key is empty create a new one
|
if peer.PrivateKey == "" { // if private key is empty create a new one
|
||||||
psk, err := wgtypes.GenerateKey()
|
psk, err := wgtypes.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to generate key")
|
||||||
}
|
}
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to generate private key")
|
||||||
}
|
}
|
||||||
peer.PresharedKey = psk.String()
|
peer.PresharedKey = psk.String()
|
||||||
peer.PrivateKey = key.String()
|
peer.PrivateKey = key.String()
|
||||||
@ -116,13 +119,13 @@ func (s *Server) CreatePeer(peer Peer) error {
|
|||||||
// Create WireGuard interface
|
// Create WireGuard interface
|
||||||
if peer.DeactivatedAt == nil {
|
if peer.DeactivatedAt == nil {
|
||||||
if err := s.wg.AddPeer(peer.GetConfig()); err != nil {
|
if err := s.wg.AddPeer(peer.GetConfig()); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to add WireGuard peer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create in database
|
// Create in database
|
||||||
if err := s.peers.CreatePeer(peer); err != nil {
|
if err := s.peers.CreatePeer(peer); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to create peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.WriteWireGuardConfigFile()
|
return s.WriteWireGuardConfigFile()
|
||||||
@ -142,12 +145,12 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
|
|||||||
err = s.wg.AddPeer(peer.GetConfig())
|
err = s.wg.AddPeer(peer.GetConfig())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to update WireGuard peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update in database
|
// Update in database
|
||||||
if err := s.peers.UpdatePeer(peer); err != nil {
|
if err := s.peers.UpdatePeer(peer); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to update peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.WriteWireGuardConfigFile()
|
return s.WriteWireGuardConfigFile()
|
||||||
@ -156,12 +159,12 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
|
|||||||
func (s *Server) DeletePeer(peer Peer) error {
|
func (s *Server) DeletePeer(peer Peer) error {
|
||||||
// Delete WireGuard peer
|
// Delete WireGuard peer
|
||||||
if err := s.wg.RemovePeer(peer.PublicKey); err != nil {
|
if err := s.wg.RemovePeer(peer.PublicKey); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to remove WireGuard peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete in database
|
// Delete in database
|
||||||
if err := s.peers.DeletePeer(peer); err != nil {
|
if err := s.peers.DeletePeer(peer); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to remove peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.WriteWireGuardConfigFile()
|
return s.WriteWireGuardConfigFile()
|
||||||
@ -173,7 +176,7 @@ func (s *Server) RestoreWireGuardInterface() error {
|
|||||||
for i := range activePeers {
|
for i := range activePeers {
|
||||||
if activePeers[i].Peer == nil {
|
if activePeers[i].Peer == nil {
|
||||||
if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil {
|
if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to add WireGuard peer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -186,16 +189,109 @@ func (s *Server) WriteWireGuardConfigFile() error {
|
|||||||
return nil // writing disabled
|
return nil // writing disabled
|
||||||
}
|
}
|
||||||
if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil {
|
if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to check WireGuard config access rights")
|
||||||
}
|
}
|
||||||
|
|
||||||
device := s.peers.GetDevice()
|
device := s.peers.GetDevice()
|
||||||
cfg, err := device.GetConfigFile(s.peers.GetActivePeers())
|
cfg, err := device.GetConfigFile(s.peers.GetActivePeers())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithMessage(err, "failed to get config file")
|
||||||
}
|
}
|
||||||
if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil {
|
if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil {
|
||||||
return err
|
return errors.Wrap(err, "failed to write WireGuard config file")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) CreateUser(user users.User) error {
|
||||||
|
if user.Email == "" {
|
||||||
|
return errors.New("cannot create user with empty email address")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user already exists, if so re-enable
|
||||||
|
if existingUser := s.users.GetUserUnscoped(user.Email); existingUser != nil {
|
||||||
|
user.DeletedAt = gorm.DeletedAt{} // reset deleted flag to enable that user again
|
||||||
|
return s.UpdateUser(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create user in database
|
||||||
|
if err := s.users.CreateUser(&user); err != nil {
|
||||||
|
return errors.WithMessage(err, "failed to create user in manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user already has a peer setup, if not, create one
|
||||||
|
return s.CreateUserDefaultPeer(user.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) UpdateUser(user users.User) error {
|
||||||
|
if user.DeletedAt.Valid {
|
||||||
|
return s.DeleteUser(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentUser := s.users.GetUserUnscoped(user.Email)
|
||||||
|
|
||||||
|
// Update in database
|
||||||
|
if err := s.users.UpdateUser(&user); err != nil {
|
||||||
|
return errors.WithMessage(err, "failed to update user in manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If user was deleted (disabled), reactivate it's peers
|
||||||
|
if currentUser.DeletedAt.Valid {
|
||||||
|
for _, peer := range s.peers.GetPeersByMail(user.Email) {
|
||||||
|
now := time.Now()
|
||||||
|
peer.DeactivatedAt = nil
|
||||||
|
if err := s.UpdatePeer(peer, now); err != nil {
|
||||||
|
logrus.Errorf("failed to update (re)activated peer %s for %s: %v", peer.PublicKey, user.Email, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) DeleteUser(user users.User) error {
|
||||||
|
currentUser := s.users.GetUserUnscoped(user.Email)
|
||||||
|
|
||||||
|
// Update in database
|
||||||
|
if err := s.users.DeleteUser(&user); err != nil {
|
||||||
|
return errors.WithMessage(err, "failed to delete user in manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If user was active, disable it's peers
|
||||||
|
if !currentUser.DeletedAt.Valid {
|
||||||
|
for _, peer := range s.peers.GetPeersByMail(user.Email) {
|
||||||
|
now := time.Now()
|
||||||
|
peer.DeactivatedAt = &now
|
||||||
|
if err := s.UpdatePeer(peer, now); err != nil {
|
||||||
|
logrus.Errorf("failed to update deactivated peer %s for %s: %v", peer.PublicKey, user.Email, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) CreateUserDefaultPeer(email string) error {
|
||||||
|
// Check if user is active, if not, quit
|
||||||
|
var existingUser *users.User
|
||||||
|
if existingUser = s.users.GetUser(email); existingUser == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user already has a peer setup, if not, create one
|
||||||
|
if s.config.Core.CreateDefaultPeer {
|
||||||
|
peers := s.peers.GetPeersByMail(email)
|
||||||
|
if len(peers) == 0 { // Create default vpn peer
|
||||||
|
if err := s.CreatePeer(Peer{
|
||||||
|
Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)",
|
||||||
|
Email: existingUser.Email,
|
||||||
|
CreatedBy: existingUser.Email,
|
||||||
|
UpdatedBy: existingUser.Email,
|
||||||
|
}); err != nil {
|
||||||
|
return errors.WithMessagef(err, "failed to automatically create vpn peer for %s", email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,9 +1,10 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@ -18,7 +19,7 @@ func (m *Manager) Init() error {
|
|||||||
var err error
|
var err error
|
||||||
m.wg, err = wgctrl.New()
|
m.wg, err = wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create WireGuard client: %w", err)
|
return errors.Wrap(err, "could not create WireGuard client")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -27,7 +28,7 @@ func (m *Manager) Init() error {
|
|||||||
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
|
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
|
||||||
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
|
return nil, errors.Wrap(err, "could not get WireGuard device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return dev, nil
|
return dev, nil
|
||||||
@ -39,7 +40,7 @@ func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
|
|||||||
|
|
||||||
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
|
return nil, errors.Wrap(err, "could not get WireGuard device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return dev.Peers, nil
|
return dev.Peers, nil
|
||||||
@ -51,12 +52,12 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
|
|||||||
|
|
||||||
publicKey, err := wgtypes.ParseKey(pubKey)
|
publicKey, err := wgtypes.ParseKey(pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid public key: %w", err)
|
return nil, errors.Wrap(err, "invalid public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := m.GetPeerList()
|
peers, err := m.GetPeerList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not get WireGuard peers: %w", err)
|
return nil, errors.Wrap(err, "could not get WireGuard peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
@ -65,7 +66,7 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("could not find WireGuard peer: %s", pubKey)
|
return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
|
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
|
||||||
@ -74,7 +75,7 @@ func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
|
|||||||
|
|
||||||
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
|
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not configure WireGuard device: %w", err)
|
return errors.Wrap(err, "could not configure WireGuard device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -87,7 +88,7 @@ func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error {
|
|||||||
cfg.UpdateOnly = true
|
cfg.UpdateOnly = true
|
||||||
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
|
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not configure WireGuard device: %w", err)
|
return errors.Wrap(err, "could not configure WireGuard device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -99,7 +100,7 @@ func (m *Manager) RemovePeer(pubKey string) error {
|
|||||||
|
|
||||||
publicKey, err := wgtypes.ParseKey(pubKey)
|
publicKey, err := wgtypes.ParseKey(pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid public key: %w", err)
|
return errors.Wrap(err, "invalid public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := wgtypes.PeerConfig{
|
peer := wgtypes.PeerConfig{
|
||||||
@ -109,7 +110,7 @@ func (m *Manager) RemovePeer(pubKey string) error {
|
|||||||
|
|
||||||
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
|
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not configure WireGuard device: %w", err)
|
return errors.Wrap(err, "could not configure WireGuard device")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/milosgajdos/tenus"
|
"github.com/milosgajdos/tenus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,18 +14,18 @@ const DefaultMTU = 1420
|
|||||||
func (m *Manager) GetIPAddress() ([]string, error) {
|
func (m *Manager) GetIPAddress() ([]string, error) {
|
||||||
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err)
|
return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get golang net.interface
|
// Get golang net.interface
|
||||||
iface := wgInterface.NetInterface()
|
iface := wgInterface.NetInterface()
|
||||||
if iface == nil { // Not sure if this check is really necessary
|
if iface == nil { // Not sure if this check is really necessary
|
||||||
return nil, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err)
|
return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not retrieve WireGuard ip addresses: %w", err)
|
return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
ipAddresses := make([]string, 0, len(addrs))
|
ipAddresses := make([]string, 0, len(addrs))
|
||||||
@ -53,22 +55,22 @@ func (m *Manager) GetIPAddress() ([]string, error) {
|
|||||||
func (m *Manager) SetIPAddress(cidrs []string) error {
|
func (m *Manager) SetIPAddress(cidrs []string) error {
|
||||||
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err)
|
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// First remove existing IP addresses
|
// First remove existing IP addresses
|
||||||
existingIPs, err := m.GetIPAddress()
|
existingIPs, err := m.GetIPAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "could not retrieve IP addresses")
|
||||||
}
|
}
|
||||||
for _, cidr := range existingIPs {
|
for _, cidr := range existingIPs {
|
||||||
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
|
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse cidr %s: %w", cidr, err)
|
return errors.Wrapf(err, "unable to parse cidr %s", cidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil {
|
if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil {
|
||||||
return fmt.Errorf("failed to unset ip %s: %w", cidr, err)
|
return errors.Wrapf(err, "failed to unset ip %s", cidr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,11 +78,11 @@ func (m *Manager) SetIPAddress(cidrs []string) error {
|
|||||||
for _, cidr := range cidrs {
|
for _, cidr := range cidrs {
|
||||||
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
|
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse cidr %s: %w", cidr, err)
|
return errors.Wrapf(err, "unable to parse cidr %s", cidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil {
|
if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil {
|
||||||
return fmt.Errorf("failed to set ip %s: %w", cidr, err)
|
return errors.Wrapf(err, "failed to set ip %s", cidr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,13 +92,13 @@ func (m *Manager) SetIPAddress(cidrs []string) error {
|
|||||||
func (m *Manager) GetMTU() (int, error) {
|
func (m *Manager) GetMTU() (int, error) {
|
||||||
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err)
|
return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get golang net.interface
|
// Get golang net.interface
|
||||||
iface := wgInterface.NetInterface()
|
iface := wgInterface.NetInterface()
|
||||||
if iface == nil { // Not sure if this check is really necessary
|
if iface == nil { // Not sure if this check is really necessary
|
||||||
return 0, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err)
|
return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iface.MTU, nil
|
return iface.MTU, nil
|
||||||
@ -105,7 +107,7 @@ func (m *Manager) GetMTU() (int, error) {
|
|||||||
func (m *Manager) SetMTU(mtu int) error {
|
func (m *Manager) SetMTU(mtu int) error {
|
||||||
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err)
|
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mtu == 0 {
|
if mtu == 0 {
|
||||||
@ -113,7 +115,7 @@ func (m *Manager) SetMTU(mtu int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := wgInterface.SetLinkMTU(mtu); err != nil {
|
if err := wgInterface.SetLinkMTU(mtu); err != nil {
|
||||||
return fmt.Errorf("could not set MTU on interface %s: %w", m.Cfg.DeviceName, err)
|
return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
Reference in New Issue
Block a user