diff --git a/assets/css/custom.css b/assets/css/custom.css index 8e0baa9..0782b69 100644 --- a/assets/css/custom.css +++ b/assets/css/custom.css @@ -47,4 +47,8 @@ .navbar { padding: 0.5rem 1rem; +} + +.disabled-peer { + color: #d03131; } \ No newline at end of file diff --git a/assets/tpl/admin_edit_client.html b/assets/tpl/admin_edit_client.html new file mode 100644 index 0000000..9dc09d5 --- /dev/null +++ b/assets/tpl/admin_edit_client.html @@ -0,0 +1,97 @@ + + + + + + {{ .Static.WebsiteTitle }} - Admin + + + + + + + + + {{template "prt_nav.html" .}} +
+ {{if eq .Peer.UID ""}} +

Create a new client

+ {{else}} +

Edit client {{.Peer.Identifier}}

+ {{end}} + + {{if $.Alerts.HasAlert}} +
+
+
+ +
+
+
+ {{end}} + +
+ +
+
+ + +
+
+
+
+ + +
+
+
+
+ + +
+
+
+
+ + +
+
+
+
+ + +
+
+ +
+
+
+ + +
+
+ + +
+
+
+ + + + Cancel +
+
+ {{template "prt_footer.html"}} + + + + + + + \ No newline at end of file diff --git a/assets/tpl/admin_edit_interface.html b/assets/tpl/admin_edit_interface.html new file mode 100644 index 0000000..55328f8 --- /dev/null +++ b/assets/tpl/admin_edit_interface.html @@ -0,0 +1,114 @@ + + + + + + {{ .Static.WebsiteTitle }} - Admin + + + + + + + + + {{template "prt_nav.html" .}} +
+

Edit interface {{.Device.DeviceName}}

+ + {{if $.Alerts.HasAlert}} +
+
+
+ +
+
+
+ {{end}} + +
+ +

Server's interface configuration

+
+
+ + +
+
+
+
+ + +
+
+ + +
+
+

Client's global configuration

+
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+

Interface configuration hooks

+
+
+ + +
+
+
+
+ + +
+
+
+
+ + +
+
+
+
+ + +
+
+ + + Cancel +
+
+ {{template "prt_footer.html"}} + + + + + + + \ No newline at end of file diff --git a/assets/tpl/admin_index.html b/assets/tpl/admin_index.html index fdd84bd..97b14ce 100644 --- a/assets/tpl/admin_index.html +++ b/assets/tpl/admin_index.html @@ -95,7 +95,7 @@ {{range $i, $p :=.Peers}} - + @@ -104,11 +104,11 @@ {{$p.PublicKey}} {{$p.Email}} {{$p.IPsStr}} - {{$p.Peer.ReceiveBytes}} / {{$p.Peer.TransmitBytes}} - {{$p.Peer.LastHandshakeTime}} + {{if $p.DeactivatedAt}}-{{else}}{{$p.Peer.ReceiveBytes}} / {{$p.Peer.TransmitBytes}}{{end}} + {{if $p.DeactivatedAt}}-{{else}}{{$p.Peer.LastHandshakeTime}}{{end}} {{if eq $.Session.IsAdmin true}} - + {{end}} @@ -119,7 +119,7 @@
-
    -
  • 0
  • -
+ {{if not $p.LdapUser}} +

No LDAP user-information available...

+ {{else}} +
    +
  • Firstname: {{$p.LdapUser.Firstname}}
  • +
  • Lastname: {{$p.LdapUser.Lastname}}
  • +
  • Phone: {{$p.UID}}
  • +
  • Mail: {{$p.LdapUser.Mail}}
  • +
  • Department: {{$p.UID}}
  • +
+ {{end}}
{{$p.Config}}
diff --git a/internal/server/core.go b/internal/server/core.go index 5417e48..9b7063a 100644 --- a/internal/server/core.go +++ b/internal/server/core.go @@ -90,12 +90,12 @@ func (s *Server) Setup() error { } // Setup user manager - s.users = NewUserManager() - if s.users == nil { + if s.users = NewUserManager(s.wg, s.ldapUsers); s.users == nil { return errors.New("unable to setup user manager") } - s.users.InitWithDevice(s.wg.GetDeviceInfo()) - s.users.InitWithPeers(s.wg.GetPeerList()) + if err := s.users.InitFromCurrentInterface(); err != nil { + return errors.New("unable to initialize user manager") + } dir := s.getExecutableDirectory() rDir, _ := filepath.Abs(filepath.Dir(os.Args[0])) diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 36efa7e..7ae946c 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -2,7 +2,10 @@ package server import ( "net/http" + "net/url" "strconv" + "strings" + "time" "github.com/gin-gonic/gin" ) @@ -32,24 +35,9 @@ func (s *Server) HandleError(c *gin.Context, code int, message, details string) } func (s *Server) GetAdminIndex(c *gin.Context) { - dev, err := s.wg.GetDeviceInfo() - if err != nil { - s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error()) - return - } - peers, err := s.wg.GetPeerList() - if err != nil { - s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error()) - return - } - device := s.users.GetDevice() - device.Interface = dev + users := s.users.GetAllUsers() - users := make([]User, len(peers)) - for i, peer := range peers { - users[i] = s.users.GetOrCreateUserForPeer(peer) - } c.HTML(http.StatusOK, "admin_index.html", struct { Route string Session SessionData @@ -65,8 +53,332 @@ func (s *Server) GetAdminIndex(c *gin.Context) { }) } +func (s *Server) GetAdminEditInterface(c *gin.Context) { + device := s.users.GetDevice() + users := s.users.GetAllUsers() + + c.HTML(http.StatusOK, "admin_edit_interface.html", struct { + Route string + Alerts AlertData + Session SessionData + Static StaticData + Peers []User + Device Device + }{ + Route: c.Request.URL.Path, + Alerts: s.getAlertData(c), + Session: s.getSessionData(c), + Static: s.getStaticData(), + Peers: users, + Device: device, + }) +} + +func (s *Server) PostAdminEditInterface(c *gin.Context) { + device := s.users.GetDevice() + var err error + + device.ListenPort, err = strconv.Atoi(c.PostForm("port")) + if err != nil { + s.setAlert(c, "invalid port: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + ipField := c.PostForm("ip") + ips := strings.Split(ipField, ",") + validatedIPs := make([]string, 0, len(ips)) + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + if ips[i] != "" { + validatedIPs = append(validatedIPs, ips[i]) + } + } + if len(validatedIPs) == 0 { + s.setAlert(c, "invalid ip address", "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + device.IPs = validatedIPs + + device.Endpoint = c.PostForm("endpoint") + + dnsField := c.PostForm("dns") + dns := strings.Split(dnsField, ",") + validatedDNS := make([]string, 0, len(dns)) + for i := range dns { + dns[i] = strings.TrimSpace(dns[i]) + if dns[i] != "" { + validatedDNS = append(validatedDNS, dns[i]) + } + } + device.DNS = validatedDNS + + allowedIPField := c.PostForm("allowedip") + allowedIP := strings.Split(allowedIPField, ",") + validatedAllowedIP := make([]string, 0, len(allowedIP)) + for i := range allowedIP { + allowedIP[i] = strings.TrimSpace(allowedIP[i]) + if allowedIP[i] != "" { + validatedAllowedIP = append(validatedAllowedIP, allowedIP[i]) + } + } + device.AllowedIPs = validatedAllowedIP + + device.Mtu, err = strconv.Atoi(c.PostForm("mtu")) + if err != nil { + s.setAlert(c, "invalid MTU: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + device.PersistentKeepalive, err = strconv.Atoi(c.PostForm("keepalive")) + if err != nil { + s.setAlert(c, "invalid PersistentKeepalive: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + // Update WireGuard device + err = s.wg.UpdateDevice(device.DeviceName, device.GetDeviceConfig()) + if err != nil { + s.setAlert(c, "failed to update device in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + // Update in database + err = s.users.UpdateDevice(device) + if err != nil { + s.setAlert(c, "failed to update device in database: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + s.setAlert(c, "changes applied successfully", "success") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") +} + +func (s *Server) GetAdminEditPeer(c *gin.Context) { + device := s.users.GetDevice() + user := s.users.GetUserByKey(c.Query("pkey")) + + c.HTML(http.StatusOK, "admin_edit_client.html", struct { + Route string + Alerts AlertData + Session SessionData + Static StaticData + Peer User + Device Device + }{ + Route: c.Request.URL.Path, + Alerts: s.getAlertData(c), + Session: s.getSessionData(c), + Static: s.getStaticData(), + Peer: user, + Device: device, + }) +} + +func (s *Server) PostAdminEditPeer(c *gin.Context) { + user := s.users.GetUserByKey(c.Query("pkey")) + urlEncodedKey := url.QueryEscape(c.Query("pkey")) + var err error + + user.Identifier = c.PostForm("identifier") + if user.Identifier == "" { + s.setAlert(c, "invalid identifier, must not be empty", "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + + user.Email = c.PostForm("mail") + if user.Email == "" { + s.setAlert(c, "invalid email, must not be empty", "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + + ipField := c.PostForm("ip") + ips := strings.Split(ipField, ",") + validatedIPs := make([]string, 0, len(ips)) + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + if ips[i] != "" { + validatedIPs = append(validatedIPs, ips[i]) + } + } + if len(validatedIPs) == 0 { + s.setAlert(c, "invalid ip address", "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + user.IPs = validatedIPs + + allowedIPField := c.PostForm("allowedip") + allowedIP := strings.Split(allowedIPField, ",") + validatedAllowedIP := make([]string, 0, len(allowedIP)) + for i := range allowedIP { + allowedIP[i] = strings.TrimSpace(allowedIP[i]) + if allowedIP[i] != "" { + validatedAllowedIP = append(validatedAllowedIP, allowedIP[i]) + } + } + user.AllowedIPs = validatedAllowedIP + + user.IgnorePersistentKeepalive = c.PostForm("ignorekeepalive") != "" + disabled := c.PostForm("isdisabled") != "" + now := time.Now() + if disabled && user.DeactivatedAt == nil { + user.DeactivatedAt = &now + } else if !disabled { + user.DeactivatedAt = nil + } + + // Update WireGuard device + if user.DeactivatedAt == &now { + err = s.wg.RemovePeer(user.PublicKey) + if err != nil { + s.setAlert(c, "failed to remove peer in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + } else if user.DeactivatedAt == nil && user.Peer != nil { + err = s.wg.UpdatePeer(user.GetPeerConfig()) + if err != nil { + s.setAlert(c, "failed to update peer in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + } else if user.DeactivatedAt == nil && user.Peer == nil { + err = s.wg.AddPeer(user.GetPeerConfig()) + if err != nil { + s.setAlert(c, "failed to add peer in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + } + + // Update in database + err = s.users.UpdateUser(user) + if err != nil { + s.setAlert(c, "failed to update user in database: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) + return + } + + s.setAlert(c, "changes applied successfully", "success") + c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) +} + +func (s *Server) GetAdminCreatePeer(c *gin.Context) { + device := s.users.GetDevice() + user := s.users.GetUserByKey(c.Query("pkey")) + + c.HTML(http.StatusOK, "admin_edit_client.html", struct { + Route string + Alerts AlertData + Session SessionData + Static StaticData + Peer User + Device Device + }{ + Route: c.Request.URL.Path, + Alerts: s.getAlertData(c), + Session: s.getSessionData(c), + Static: s.getStaticData(), + Peer: user, + Device: device, + }) +} + +func (s *Server) PostAdminCreatePeer(c *gin.Context) { + device := s.users.GetDevice() + var err error + + device.ListenPort, err = strconv.Atoi(c.PostForm("port")) + if err != nil { + s.setAlert(c, "invalid port: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + ipField := c.PostForm("ip") + ips := strings.Split(ipField, ",") + validatedIPs := make([]string, 0, len(ips)) + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + if ips[i] != "" { + validatedIPs = append(validatedIPs, ips[i]) + } + } + if len(validatedIPs) == 0 { + s.setAlert(c, "invalid ip address", "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + device.IPs = validatedIPs + + device.Endpoint = c.PostForm("endpoint") + + dnsField := c.PostForm("dns") + dns := strings.Split(dnsField, ",") + validatedDNS := make([]string, 0, len(dns)) + for i := range dns { + dns[i] = strings.TrimSpace(dns[i]) + if dns[i] != "" { + validatedDNS = append(validatedDNS, dns[i]) + } + } + device.DNS = validatedDNS + + allowedIPField := c.PostForm("allowedip") + allowedIP := strings.Split(allowedIPField, ",") + validatedAllowedIP := make([]string, 0, len(allowedIP)) + for i := range allowedIP { + allowedIP[i] = strings.TrimSpace(allowedIP[i]) + if allowedIP[i] != "" { + validatedAllowedIP = append(validatedAllowedIP, allowedIP[i]) + } + } + device.AllowedIPs = validatedAllowedIP + + device.Mtu, err = strconv.Atoi(c.PostForm("mtu")) + if err != nil { + s.setAlert(c, "invalid MTU: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + device.PersistentKeepalive, err = strconv.Atoi(c.PostForm("keepalive")) + if err != nil { + s.setAlert(c, "invalid PersistentKeepalive: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + // Update WireGuard device + err = s.wg.UpdateDevice(device.DeviceName, device.GetDeviceConfig()) + if err != nil { + s.setAlert(c, "failed to update device in WireGuard: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + // Update in database + err = s.users.UpdateDevice(device) + if err != nil { + s.setAlert(c, "failed to update device in database: "+err.Error(), "danger") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") + return + } + + s.setAlert(c, "changes applied successfully", "success") + c.Redirect(http.StatusSeeOther, "/admin/device/edit") +} + func (s *Server) GetUserQRCode(c *gin.Context) { - user := s.users.GetUser(c.Param("pkey")) + user := s.users.GetUserByKey(c.Query("pkey")) png, err := user.GetQRCode() if err != nil { s.HandleError(c, http.StatusInternalServerError, "QRCode error", err.Error()) diff --git a/internal/server/routes.go b/internal/server/routes.go index 1fe28f3..9dc523e 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -20,6 +20,12 @@ func SetupRoutes(s *Server) { admin := s.server.Group("/admin") admin.Use(s.RequireAuthentication(s.config.AdminLdapGroup)) admin.GET("/", s.GetAdminIndex) + admin.GET("/device/edit", s.GetAdminEditInterface) + admin.POST("/device/edit", s.PostAdminEditInterface) + admin.GET("/peer/edit", s.GetAdminEditPeer) + admin.POST("/peer/edit", s.PostAdminEditPeer) + admin.GET("/peer/create", s.GetAdminCreatePeer) + admin.POST("/peer/create", s.PostAdminCreatePeer) // User routes user := s.server.Group("/user") diff --git a/internal/server/usermanager.go b/internal/server/usermanager.go index 7112c85..5308e6f 100644 --- a/internal/server/usermanager.go +++ b/internal/server/usermanager.go @@ -22,10 +22,14 @@ import ( "gorm.io/gorm" ) +// +// USER ---------------------------------------------------------------------------------------- +// + type User struct { - Peer wgtypes.Peer `gorm:"-"` - User *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap - Config string `gorm:"-"` + Peer *wgtypes.Peer `gorm:"-"` + LdapUser *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap + Config string `gorm:"-"` UID string // uid for html identification IsOnline bool `gorm:"-"` @@ -48,6 +52,28 @@ type User struct { UpdatedAt time.Time } +func (u User) GetClientConfigFile(device Device) ([]byte, error) { + tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) + if err != nil { + return nil, err + } + + var tplBuff bytes.Buffer + + err = tpl.Execute(&tplBuff, struct { + Client User + Server Device + }{ + Client: u, + Server: device, + }) + if err != nil { + return nil, err + } + + return tplBuff.Bytes(), nil +} + func (u User) GetPeerConfig() wgtypes.PeerConfig { publicKey, _ := wgtypes.ParseKey(u.PublicKey) var presharedKey *wgtypes.Key @@ -87,6 +113,18 @@ func (u User) GetQRCode() ([]byte, error) { return png, nil } +func (u User) IsValid() bool { + if u.PublicKey == "" { + return false + } + + return true +} + +// +// DEVICE -------------------------------------------------------------------------------------- +// + type Device struct { Interface *wgtypes.Device `gorm:"-"` @@ -112,6 +150,9 @@ type Device struct { } func (d Device) IsValid() bool { + if d.PublicKey == "" { + return false + } if len(d.IPs) == 0 { return false } @@ -122,12 +163,33 @@ func (d Device) IsValid() bool { return true } -type UserManager struct { - db *gorm.DB +func (d Device) GetDeviceConfig() wgtypes.Config { + var privateKey *wgtypes.Key + if d.PrivateKey != "" { + pKey, _ := wgtypes.ParseKey(d.PrivateKey) + privateKey = &pKey + } + + cfg := wgtypes.Config{ + PrivateKey: privateKey, + ListenPort: &d.ListenPort, + } + + return cfg } -func NewUserManager() *UserManager { - um := &UserManager{} +// +// USER-MANAGER -------------------------------------------------------------------------------- +// + +type UserManager struct { + db *gorm.DB + wg *wireguard.Manager + ldapUsers *ldap.SynchronizedUserCacheHolder +} + +func NewUserManager(wg *wireguard.Manager, ldapUsers *ldap.SynchronizedUserCacheHolder) *UserManager { + um := &UserManager{wg: wg, ldapUsers: ldapUsers} var err error um.db, err = gorm.Open(sqlite.Open("wg_portal.db"), &gorm.Config{}) if err != nil { @@ -144,52 +206,32 @@ func NewUserManager() *UserManager { return um } -func (u *UserManager) InitWithPeers(peers []wgtypes.Peer, err error) { +func (u *UserManager) InitFromCurrentInterface() error { + peers, err := u.wg.GetPeerList() if err != nil { log.Errorf("failed to init user-manager from peers: %v", err) - return + return err } - for _, peer := range peers { - u.GetOrCreateUserForPeer(peer) - } -} - -func (u *UserManager) InitWithDevice(dev *wgtypes.Device, err error) { + device, err := u.wg.GetDeviceInfo() if err != nil { log.Errorf("failed to init user-manager from device: %v", err) - return - } - u.GetOrCreateDevice(*dev) -} - -func (u *UserManager) GetAllUsers() []User { - users := make([]User, 0) - u.db.Find(&users) - - for i := range users { - users[i].AllowedIPs = strings.Split(users[i].AllowedIPsStr, ", ") - users[i].IPs = strings.Split(users[i].IPsStr, ", ") - tmpCfg, _ := u.GetPeerConfigFile(users[i]) - users[i].Config = string(tmpCfg) + return err } - return users -} - -func (u *UserManager) GetDevice() Device { - devices := make([]Device, 0, 1) - u.db.Find(&devices) - - for i := range devices { - devices[i].AllowedIPs = strings.Split(devices[i].AllowedIPsStr, ", ") - devices[i].IPs = strings.Split(devices[i].IPsStr, ", ") - devices[i].DNS = strings.Split(devices[i].DNSStr, ", ") + // Check if entries already exist in database, if not create them + for _, peer := range peers { + if err := u.validateOrCreateUserForPeer(peer); err != nil { + return err + } + } + if err := u.validateOrCreateDevice(*device); err != nil { + return err } - return devices[0] + return nil } -func (u *UserManager) GetOrCreateUserForPeer(peer wgtypes.Peer) User { +func (u *UserManager) validateOrCreateUserForPeer(peer wgtypes.Peer) error { user := User{} u.db.Where("public_key = ?", peer.PublicKey.String()).FirstOrInit(&user) @@ -215,25 +257,92 @@ func (u *UserManager) GetOrCreateUserForPeer(peer wgtypes.Peer) User { res := u.db.Create(&user) if res.Error != nil { log.Errorf("failed to create autodetected peer: %v", res.Error) + return res.Error } } - user.IPs = strings.Split(user.IPsStr, ", ") + return nil +} + +func (u *UserManager) validateOrCreateDevice(dev wgtypes.Device) error { + device := Device{} + u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device) + + if device.PublicKey == "" { // device not found, create + device.PublicKey = dev.PublicKey.String() + device.PrivateKey = dev.PrivateKey.String() + device.DeviceName = dev.Name + device.ListenPort = dev.ListenPort + device.Mtu = 0 + device.PersistentKeepalive = 16 // Default + + res := u.db.Create(&device) + if res.Error != nil { + log.Errorf("failed to create autodetected device: %v", res.Error) + return res.Error + } + } + + return nil +} + +func (u *UserManager) populateUserData(user *User) { user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ") - tmpCfg, _ := u.GetPeerConfigFile(user) + user.IPs = strings.Split(user.IPsStr, ", ") + // Set config file + tmpCfg, _ := user.GetClientConfigFile(u.GetDevice()) user.Config = string(tmpCfg) + // set data from WireGuard interface + user.Peer, _ = u.wg.GetPeer(user.PublicKey) + user.IsOnline = false // todo: calculate online status + + // set ldap data + user.LdapUser = u.ldapUsers.GetUserData(u.ldapUsers.GetUserDNByMail(user.Email)) +} + +func (u *UserManager) populateDeviceData(device *Device) { + device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ") + device.IPs = strings.Split(device.IPsStr, ", ") + device.DNS = strings.Split(device.DNSStr, ", ") + + // set data from WireGuard interface + device.Interface, _ = u.wg.GetDeviceInfo() +} + +func (u *UserManager) GetAllUsers() []User { + users := make([]User, 0) + u.db.Find(&users) + + for i := range users { + u.populateUserData(&users[i]) + } + + return users +} + +func (u *UserManager) GetDevice() Device { + devices := make([]Device, 0, 1) + u.db.Find(&devices) + + for i := range devices { + u.populateDeviceData(&devices[i]) + } + + return devices[0] // use first device for now... more to come? +} + +func (u *UserManager) GetUserByKey(publicKey string) User { + user := User{} + u.db.Where("public_key = ?", publicKey).FirstOrInit(&user) + u.populateUserData(&user) return user } -func (u *UserManager) GetUser(publicKey string) User { +func (u *UserManager) GetUserByMail(mail string) User { user := User{} - u.db.Where("public_key = ?", publicKey).FirstOrInit(&user) - - user.IPs = strings.Split(user.IPsStr, ", ") - user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ") - tmpCfg, _ := u.GetPeerConfigFile(user) - user.Config = string(tmpCfg) + u.db.Where("email = ?", mail).FirstOrInit(&user) + u.populateUserData(&user) return user } @@ -268,6 +377,21 @@ func (u *UserManager) UpdateUser(user User) error { return nil } +func (u *UserManager) UpdateDevice(device Device) error { + device.UpdatedAt = time.Now() + device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ") + device.IPsStr = strings.Join(device.IPs, ", ") + device.DNSStr = strings.Join(device.DNS, ", ") + + res := u.db.Save(&device) + if res.Error != nil { + log.Errorf("failed to update device: %v", res.Error) + return res.Error + } + + return nil +} + func (u *UserManager) GetAllReservedIps() ([]string, error) { reservedIps := make([]string, 0) users := u.GetAllUsers() @@ -328,50 +452,3 @@ func (u *UserManager) GetAvailableIp(cidr string, reserved []string) (string, er return "", errors.New("no more available address from cidr") } - -func (u *UserManager) GetOrCreateDevice(dev wgtypes.Device) Device { - device := Device{} - u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device) - - if device.PublicKey == "" { // device not found, create - device.PublicKey = dev.PublicKey.String() - device.PrivateKey = dev.PrivateKey.String() - device.DeviceName = dev.Name - device.ListenPort = dev.ListenPort - device.Mtu = 0 - device.PersistentKeepalive = 16 // Default - - res := u.db.Create(&device) - if res.Error != nil { - log.Errorf("failed to create autodetected device: %v", res.Error) - } - } - - device.IPs = strings.Split(device.IPsStr, ", ") - device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ") - device.DNS = strings.Split(device.DNSStr, ", ") - - return device -} - -func (u *UserManager) GetPeerConfigFile(user User) ([]byte, error) { - tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) - if err != nil { - return nil, err - } - - var tplBuff bytes.Buffer - - err = tpl.Execute(&tplBuff, struct { - Client User - Server Device - }{ - Client: user, - Server: u.GetDevice(), - }) - if err != nil { - return nil, err - } - - return tplBuff.Bytes(), nil -} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 0f22b3c..976b841 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -80,6 +80,19 @@ func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { return nil } +func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error { + m.mux.Lock() + defer m.mux.Unlock() + + cfg.UpdateOnly = true + err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) + if err != nil { + return fmt.Errorf("could not configure WireGuard device: %w", err) + } + + return nil +} + func (m *Manager) RemovePeer(pubKey string) error { m.mux.Lock() defer m.mux.Unlock() @@ -101,3 +114,7 @@ func (m *Manager) RemovePeer(pubKey string) error { return nil } + +func (m *Manager) UpdateDevice(name string, cfg wgtypes.Config) error { + return m.wg.ConfigureDevice(name, cfg) +}