diff --git a/assets/tpl/admin_create_clients.html b/assets/tpl/admin_create_clients.html index 9d11549..12dd8a9 100644 --- a/assets/tpl/admin_create_clients.html +++ b/assets/tpl/admin_create_clients.html @@ -41,7 +41,7 @@
- +
diff --git a/assets/tpl/admin_index.html b/assets/tpl/admin_index.html index ec944c8..b76694d 100644 --- a/assets/tpl/admin_index.html +++ b/assets/tpl/admin_index.html @@ -111,8 +111,13 @@ {{$p.PublicKey}} {{$p.Email}} {{$p.IPsStr}} + {{if not $p.Peer}} + ? / ? + ? + {{else}} {{if $p.DeactivatedAt}}-{{else}}{{$p.Peer.ReceiveBytes}} / {{$p.Peer.TransmitBytes}}{{end}} {{if $p.DeactivatedAt}}-{{else}}{{$p.Peer.LastHandshakeTime}}{{end}} + {{end}} {{if eq $.Session.IsAdmin true}} diff --git a/internal/common/iputil.go b/internal/common/iputil.go index 0241b02..240c620 100644 --- a/internal/common/iputil.go +++ b/internal/common/iputil.go @@ -39,19 +39,19 @@ func IsIPv6(address string) bool { return ip.To4() == nil } -func ParseIPList(lst string) []string { - ips := strings.Split(lst, ",") - validatedIPs := make([]string, 0, len(ips)) - for i := range ips { - ips[i] = strings.TrimSpace(ips[i]) - if ips[i] != "" { - validatedIPs = append(validatedIPs, ips[i]) +func ParseStringList(lst string) []string { + tokens := strings.Split(lst, ",") + validatedTokens := make([]string, 0, len(tokens)) + for i := range tokens { + tokens[i] = strings.TrimSpace(tokens[i]) + if tokens[i] != "" { + validatedTokens = append(validatedTokens, tokens[i]) } } - return validatedIPs + return validatedTokens } -func IPListToString(lst []string) string { +func ListToString(lst []string) string { return strings.Join(lst, ", ") } diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 8def6a9..1a9ad14 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -6,8 +6,11 @@ import ( "net/http" "net/url" "strconv" + "strings" "time" + log "github.com/sirupsen/logrus" + "github.com/h44z/wg-portal/internal/ldap" "github.com/h44z/wg-portal/internal/common" @@ -89,12 +92,12 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) { return } // Clean list input - formDevice.IPs = common.ParseIPList(formDevice.IPsStr) - formDevice.AllowedIPs = common.ParseIPList(formDevice.AllowedIPsStr) - formDevice.DNS = common.ParseIPList(formDevice.DNSStr) - formDevice.IPsStr = common.IPListToString(formDevice.IPs) - formDevice.AllowedIPsStr = common.IPListToString(formDevice.AllowedIPs) - formDevice.DNSStr = common.IPListToString(formDevice.DNS) + formDevice.IPs = common.ParseStringList(formDevice.IPsStr) + formDevice.AllowedIPs = common.ParseStringList(formDevice.AllowedIPsStr) + formDevice.DNS = common.ParseStringList(formDevice.DNSStr) + formDevice.IPsStr = common.ListToString(formDevice.IPs) + formDevice.AllowedIPsStr = common.ListToString(formDevice.AllowedIPs) + formDevice.DNSStr = common.ListToString(formDevice.DNS) // Update WireGuard device err := s.wg.UpdateDevice(formDevice.DeviceName, formDevice.GetDeviceConfig()) @@ -149,10 +152,10 @@ func (s *Server) PostAdminEditPeer(c *gin.Context) { } // Clean list input - formUser.IPs = common.ParseIPList(formUser.IPsStr) - formUser.AllowedIPs = common.ParseIPList(formUser.AllowedIPsStr) - formUser.IPsStr = common.IPListToString(formUser.IPs) - formUser.AllowedIPsStr = common.IPListToString(formUser.AllowedIPs) + formUser.IPs = common.ParseStringList(formUser.IPsStr) + formUser.AllowedIPs = common.ParseStringList(formUser.AllowedIPsStr) + formUser.IPsStr = common.ListToString(formUser.IPs) + formUser.AllowedIPsStr = common.ListToString(formUser.AllowedIPs) disabled := c.PostForm("isdisabled") != "" now := time.Now() @@ -244,10 +247,10 @@ func (s *Server) PostAdminCreatePeer(c *gin.Context) { } // Clean list input - formUser.IPs = common.ParseIPList(formUser.IPsStr) - formUser.AllowedIPs = common.ParseIPList(formUser.AllowedIPsStr) - formUser.IPsStr = common.IPListToString(formUser.IPs) - formUser.AllowedIPsStr = common.IPListToString(formUser.AllowedIPs) + formUser.IPs = common.ParseStringList(formUser.IPsStr) + formUser.AllowedIPs = common.ParseStringList(formUser.AllowedIPsStr) + formUser.IPsStr = common.ListToString(formUser.IPs) + formUser.AllowedIPsStr = common.ListToString(formUser.AllowedIPs) disabled := c.PostForm("isdisabled") != "" now := time.Now() @@ -265,7 +268,7 @@ func (s *Server) PostAdminCreatePeer(c *gin.Context) { } } - // Update in database + // Create in database err := s.users.CreateUser(formUser) if err != nil { s.setAlert(c, "failed to add user in database: "+err.Error(), "danger") @@ -297,6 +300,73 @@ func (s *Server) GetAdminCreateLdapPeers(c *gin.Context) { }) } +func (s *Server) PostAdminCreateLdapPeers(c *gin.Context) { + email := c.PostForm("email") + identifier := c.PostForm("identifier") + if identifier == "" { + identifier = "Default" + } + if email == "" { + s.setAlert(c, "missing email address", "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") + return + } + emails := common.ParseStringList(email) + for i := range emails { + // TODO: also check email addr for validity? + if !strings.ContainsRune(emails[i], '@') || s.ldapUsers.GetUserDNByMail(emails[i]) == "" { + s.setAlert(c, "invalid email address: "+emails[i], "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") + return + } + } + + log.Infof("creating %d ldap peers", len(emails)) + device := s.users.GetDevice() + + for i := range emails { + ldapUser := s.ldapUsers.GetUserData(s.ldapUsers.GetUserDNByMail(emails[i])) + user := User{} + user.AllowedIPsStr = device.AllowedIPsStr + user.IPsStr = "" // TODO: add a valid ip here + psk, err := wgtypes.GenerateKey() + if err != nil { + s.HandleError(c, http.StatusInternalServerError, "Preshared key generation error", err.Error()) + return + } + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + s.HandleError(c, http.StatusInternalServerError, "Private key generation error", err.Error()) + return + } + user.PresharedKey = psk.String() + user.PrivateKey = key.String() + user.PublicKey = key.PublicKey().String() + user.UID = fmt.Sprintf("u%x", md5.Sum([]byte(user.PublicKey))) + user.Email = emails[i] + user.Identifier = fmt.Sprintf("%s %s (%s)", ldapUser.Firstname, ldapUser.Lastname, identifier) + + // Create wireguard interface + err = s.wg.AddPeer(user.GetPeerConfig()) + if err != nil { + s.setAlert(c, "failed to add peer in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") + return + } + + // Create in database + err = s.users.CreateUser(user) + if err != nil { + s.setAlert(c, "failed to add user in database: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") + return + } + } + + s.setAlert(c, "client(s) created successfully", "success") + c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") +} + func (s *Server) GetUserQRCode(c *gin.Context) { user := s.users.GetUserByKey(c.Query("pkey")) png, err := user.GetQRCode() diff --git a/internal/server/routes.go b/internal/server/routes.go index 9d2d99e..8c66fdf 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -27,6 +27,7 @@ func SetupRoutes(s *Server) { admin.GET("/peer/create", s.GetAdminCreatePeer) admin.POST("/peer/create", s.PostAdminCreatePeer) admin.GET("/peer/createldap", s.GetAdminCreateLdapPeers) + admin.POST("/peer/createldap", s.PostAdminCreateLdapPeers) // User routes user := s.server.Group("/user") diff --git a/internal/server/usermanager.go b/internal/server/usermanager.go index cef6878..a9316e9 100644 --- a/internal/server/usermanager.go +++ b/internal/server/usermanager.go @@ -32,7 +32,7 @@ import ( var cidrList validator.Func = func(fl validator.FieldLevel) bool { cidrListStr := fl.Field().String() - cidrList := common.ParseIPList(cidrListStr) + cidrList := common.ParseStringList(cidrListStr) for i := range cidrList { _, _, err := net.ParseCIDR(cidrList[i]) if err != nil { @@ -45,7 +45,7 @@ var cidrList validator.Func = func(fl validator.FieldLevel) bool { var ipList validator.Func = func(fl validator.FieldLevel) bool { ipListStr := fl.Field().String() - ipList := common.ParseIPList(ipListStr) + ipList := common.ParseStringList(ipListStr) for i := range ipList { ip := net.ParseIP(ipList[i]) if ip == nil {