WIP: support for multiple WireGuard devices (#2)

This commit is contained in:
Christoph Haas 2021-03-21 12:36:11 +01:00
parent 5f4c041ee7
commit 6ab00ef567
22 changed files with 538 additions and 519 deletions

View File

@ -25,6 +25,11 @@
} }
}); });
}); });
$(function() {
$('select.device-selector').change(function() {
this.form.submit();
});
});
})(jQuery); // End of use strict })(jQuery); // End of use strict

View File

@ -106,6 +106,7 @@
</thead> </thead>
<tbody> <tbody>
{{range $i, $p :=.Peers}} {{range $i, $p :=.Peers}}
{{$peerUser:=(userForEmail $.Users $p.Email)}}
<tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}> <tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}>
<th scope="row" class="list-image-cell"> <th scope="row" class="list-image-cell">
<a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a> <a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a>
@ -142,14 +143,14 @@
<div class="tab-content" id="tabContent{{$p.UID}}"> <div class="tab-content" id="tabContent{{$p.UID}}">
<div id="t1{{$p.UID}}" class="tab-pane fade active show"> <div id="t1{{$p.UID}}" class="tab-pane fade active show">
<h4>User details</h4> <h4>User details</h4>
{{if not $p.User}} {{if not $peerUser}}
<p>No user information available...</p> <p>No user information available...</p>
{{else}} {{else}}
<ul> <ul>
<li>Firstname: {{$p.User.Firstname}}</li> <li>Firstname: {{$peerUser.Firstname}}</li>
<li>Lastname: {{$p.User.Lastname}}</li> <li>Lastname: {{$peerUser.Lastname}}</li>
<li>Phone: {{$p.User.Phone}}</li> <li>Phone: {{$peerUser.Phone}}</li>
<li>Mail: {{$p.User.Email}}</li> <li>Mail: {{$peerUser.Email}}</li>
</ul> </ul>
{{end}} {{end}}
<h4>Connection / Traffic</h4> <h4>Connection / Traffic</h4>

View File

@ -22,6 +22,19 @@
{{end}} {{end}}
{{end}}{{end}} {{end}}{{end}}
</ul> </ul>
{{with eq $.Session.LoggedIn true}}{{with eq $.Session.IsAdmin true}}
{{with startsWith $.Route "/admin/"}}
<form class="form-inline my-2 my-lg-0" method="get">
<div class="form-group mr-sm-2">
<select name="device" id="inputDevice" class="form-control device-selector">
{{range $i, $d :=$.DeviceNames}}
<option value="{{$d}}" {{if eq $d $.Session.DeviceName}}selected{{end}}>{{$d}}</option>
{{end}}
</select>
</div>
</form>
{{end}}
{{end}}{{end}}
{{if eq $.Session.LoggedIn true}} {{if eq $.Session.LoggedIn true}}
<div class="nav-item dropdown"> <div class="nav-item dropdown">
<a href="#" class="navbar-text dropdown-toggle" data-toggle="dropdown">{{$.Session.Firstname}} {{$.Session.Lastname}} <span class="caret"></span></a> <a href="#" class="navbar-text dropdown-toggle" data-toggle="dropdown">{{$.Session.Firstname}} {{$.Session.Lastname}} <span class="caret"></span></a>
@ -43,6 +56,6 @@
</nav> </nav>
{{if not $.Device.IsValid}} {{if not $.Device.IsValid}}
<div class="container"> <div class="container">
<div class="alert alert-danger">Warning: WireGuard Interface is not fully configured! Configurations may be incomplete and non functional!</div> <div class="alert alert-danger">Warning: WireGuard Interface {{$.Device.DeviceName}} is not fully configured! Configurations may be incomplete and non functional!</div>
</div> </div>
{{end}} {{end}}

View File

@ -30,6 +30,7 @@
</thead> </thead>
<tbody> <tbody>
{{range $i, $p :=.Peers}} {{range $i, $p :=.Peers}}
{{$peerUser:=(userForEmail $.Users $p.Email)}}
<tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}> <tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}>
<th scope="row" class="list-image-cell"> <th scope="row" class="list-image-cell">
<a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a> <a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a>
@ -58,14 +59,14 @@
<div class="tab-content" id="tabContent{{$p.UID}}"> <div class="tab-content" id="tabContent{{$p.UID}}">
<div id="t1{{$p.UID}}" class="tab-pane fade active show"> <div id="t1{{$p.UID}}" class="tab-pane fade active show">
<h4>User details</h4> <h4>User details</h4>
{{if not $p.User}} {{if not $peerUser}}
<p>No user information available...</p> <p>No user information available...</p>
{{else}} {{else}}
<ul> <ul>
<li>Firstname: {{$p.User.Firstname}}</li> <li>Firstname: {{$peerUser.Firstname}}</li>
<li>Lastname: {{$p.User.Lastname}}</li> <li>Lastname: {{$peerUser.Lastname}}</li>
<li>Phone: {{$p.User.Phone}}</li> <li>Phone: {{$peerUser.Phone}}</li>
<li>Mail: {{$p.User.Email}}</li> <li>Mail: {{$peerUser.Email}}</li>
</ul> </ul>
{{end}} {{end}}
<h4>Traffic</h4> <h4>Traffic</h4>

View File

@ -7,6 +7,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/h44z/wg-portal/internal/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/authentication"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
@ -22,11 +24,11 @@ type Provider struct {
db *gorm.DB db *gorm.DB
} }
func New(cfg *users.Config) (*Provider, error) { func New(cfg *common.DatabaseConfig) (*Provider, error) {
p := &Provider{} p := &Provider{}
var err error var err error
p.db, err = users.GetDatabaseForConfig(cfg) p.db, err = common.GetDatabaseForConfig(cfg)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database) return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database)
} }

76
internal/common/db.go Normal file
View File

@ -0,0 +1,76 @@
package common
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
type SupportedDatabase string
const (
SupportedDatabaseMySQL SupportedDatabase = "mysql"
SupportedDatabaseSQLite SupportedDatabase = "sqlite"
)
type DatabaseConfig struct {
Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
Host string `yaml:"host" envconfig:"DATABASE_HOST"`
Port int `yaml:"port" envconfig:"DATABASE_PORT"`
Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
}
func GetDatabaseForConfig(cfg *DatabaseConfig) (db *gorm.DB, err error) {
switch cfg.Typ {
case SupportedDatabaseSQLite:
if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
return
}
}
db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
if err != nil {
return
}
case SupportedDatabaseMySQL:
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
if err != nil {
return
}
sqlDB, _ := db.DB()
sqlDB.SetConnMaxLifetime(time.Minute * 5)
sqlDB.SetMaxIdleConns(2)
sqlDB.SetMaxOpenConns(10)
err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
if err != nil {
return nil, errors.Wrap(err, "failed to ping mysql authentication database")
}
}
// Enable Logger (logrus)
logCfg := logger.Config{
SlowThreshold: time.Second, // all slower than one second
Colorful: false,
LogLevel: logger.Silent, // default: log nothing
}
if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
logCfg.LogLevel = logger.Info
logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
}
db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
return
}

View File

@ -60,6 +60,16 @@ func ListToString(lst []string) string {
return strings.Join(lst, ", ") return strings.Join(lst, ", ")
} }
// ListContains checks if a needle exists in the given list.
func ListContains(lst []string, needle string) bool {
for _, entry := range lst {
if entry == needle {
return true
}
}
return false
}
// https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/ // https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/
func ByteCountSI(b int64) string { func ByteCountSI(b int64) string {
const unit = 1000 const unit = 1000

View File

@ -1,12 +1,12 @@
package common package server
import ( import (
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/ldap" "github.com/h44z/wg-portal/internal/ldap"
"github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard" "github.com/h44z/wg-portal/internal/wireguard"
"github.com/kelseyhightower/envconfig" "github.com/kelseyhightower/envconfig"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -66,10 +66,10 @@ type Config struct {
CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"` CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"`
LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"` LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"`
} `yaml:"core"` } `yaml:"core"`
Database users.Config `yaml:"database"` Database common.DatabaseConfig `yaml:"database"`
Email MailConfig `yaml:"email"` Email common.MailConfig `yaml:"email"`
LDAP ldap.Config `yaml:"ldap"` LDAP ldap.Config `yaml:"ldap"`
WG wireguard.Config `yaml:"wg"` WG wireguard.Config `yaml:"wg"`
} }
func NewConfig() *Config { func NewConfig() *Config {
@ -103,8 +103,9 @@ func NewConfig() *Config {
cfg.LDAP.DisabledAttribute = "userAccountControl" cfg.LDAP.DisabledAttribute = "userAccountControl"
cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL" cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
cfg.WG.DeviceName = "wg0" cfg.WG.DeviceNames = []string{"wg0"}
cfg.WG.WireGuardConfig = "/etc/wireguard/wg0.conf" cfg.WG.DefaultDeviceName = "wg0"
cfg.WG.ConfigDirectoryPath = "/etc/wireguard"
cfg.WG.ManageIPAddresses = true cfg.WG.ManageIPAddresses = true
cfg.Email.Host = "127.0.0.1" cfg.Email.Host = "127.0.0.1"
cfg.Email.Port = 25 cfg.Email.Port = 25

View File

@ -98,7 +98,7 @@ func (s *Server) PostLogin(c *gin.Context) {
Firstname: userData.Firstname, Firstname: userData.Firstname,
Lastname: userData.Lastname, Lastname: userData.Lastname,
Phone: userData.Phone, Phone: userData.Phone,
}); err != nil { }, s.wg.Cfg.DefaultDeviceName); err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data") s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data")
return return
} }
@ -121,9 +121,10 @@ func (s *Server) PostLogin(c *gin.Context) {
sessionData.Email = user.Email sessionData.Email = user.Email
sessionData.Firstname = user.Firstname sessionData.Firstname = user.Firstname
sessionData.Lastname = user.Lastname sessionData.Lastname = user.Lastname
sessionData.DeviceName = s.wg.Cfg.DeviceNames[0]
// Check if user already has a peer setup, if not create one // Check if user already has a peer setup, if not create one
if err := s.CreateUserDefaultPeer(user.Email); err != nil { if err := s.CreateUserDefaultPeer(user.Email, s.wg.Cfg.DefaultDeviceName); err != nil {
// Not a fatal error, just log it... // Not a fatal error, just log it...
logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err) logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err)
} }

View File

@ -4,37 +4,42 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/common"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) { func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) {
currentSession := GetSessionData(c)
c.HTML(code, "error.html", gin.H{ c.HTML(code, "error.html", gin.H{
"Data": gin.H{ "Data": gin.H{
"Code": strconv.Itoa(code), "Code": strconv.Itoa(code),
"Message": message, "Message": message,
"Details": details, "Details": details,
}, },
"Route": c.Request.URL.Path, "Route": c.Request.URL.Path,
"Session": GetSessionData(c), "Session": GetSessionData(c),
"Static": s.getStaticData(), "Static": s.getStaticData(),
"Device": s.peers.GetDevice(currentSession.DeviceName),
"DeviceNames": s.wg.Cfg.DeviceNames,
}) })
} }
func (s *Server) GetIndex(c *gin.Context) { func (s *Server) GetIndex(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", struct { currentSession := GetSessionData(c)
Route string
Alerts []FlashData c.HTML(http.StatusOK, "index.html", gin.H{
Session SessionData "Route": c.Request.URL.Path,
Static StaticData "Alerts": GetFlashes(c),
Device Device "Session": currentSession,
}{ "Static": s.getStaticData(),
Route: c.Request.URL.Path, "Device": s.peers.GetDevice(currentSession.DeviceName),
Alerts: GetFlashes(c), "DeviceNames": s.wg.Cfg.DeviceNames,
Session: GetSessionData(c),
Static: s.getStaticData(),
Device: s.peers.GetDevice(),
}) })
} }
@ -74,25 +79,35 @@ func (s *Server) GetAdminIndex(c *gin.Context) {
return return
} }
device := s.peers.GetDevice() deviceName := c.Query("device")
users := s.peers.GetFilteredAndSortedPeers(currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"]) if deviceName != "" {
if !common.ListContains(s.wg.Cfg.DeviceNames, deviceName) {
s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "no such device")
return
}
currentSession.DeviceName = deviceName
c.HTML(http.StatusOK, "admin_index.html", struct { if err := UpdateSessionData(c, currentSession); err != nil {
Route string s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "failed to save session")
Alerts []FlashData return
Session SessionData }
Static StaticData c.Redirect(http.StatusSeeOther, "/admin/")
Peers []Peer return
TotalPeers int }
Device Device
}{ device := s.peers.GetDevice(currentSession.DeviceName)
Route: c.Request.URL.Path, users := s.peers.GetFilteredAndSortedPeers(currentSession.DeviceName, currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"])
Alerts: GetFlashes(c),
Session: currentSession, c.HTML(http.StatusOK, "admin_index.html", gin.H{
Static: s.getStaticData(), "Route": c.Request.URL.Path,
Peers: users, "Alerts": GetFlashes(c),
TotalPeers: len(s.peers.GetAllPeers()), "Session": currentSession,
Device: device, "Static": s.getStaticData(),
"Peers": users,
"TotalPeers": len(s.peers.GetAllPeers(currentSession.DeviceName)),
"Users": s.users.GetUsers(),
"Device": device,
"DeviceNames": s.wg.Cfg.DeviceNames,
}) })
} }
@ -120,25 +135,18 @@ func (s *Server) GetUserIndex(c *gin.Context) {
return return
} }
device := s.peers.GetDevice() peers := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
users := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
c.HTML(http.StatusOK, "user_index.html", struct { c.HTML(http.StatusOK, "user_index.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peers []Peer "Peers": peers,
TotalPeers int "TotalPeers": len(peers),
Device Device "Users": []users.User{*s.users.GetUser(currentSession.Email)},
}{ "Device": s.peers.GetDevice(currentSession.DeviceName),
Route: c.Request.URL.Path, "DeviceNames": s.wg.Cfg.DeviceNames,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peers: users,
TotalPeers: len(users),
Device: device,
}) })
} }
@ -158,7 +166,7 @@ func (s *Server) setNewPeerFormInSession(c *gin.Context) (SessionData, error) {
// If session does not contain a peer form ignore update // If session does not contain a peer form ignore update
// If url contains a formerr parameter reset the form // If url contains a formerr parameter reset the form
if currentSession.FormData == nil || c.Query("formerr") == "" { if currentSession.FormData == nil || c.Query("formerr") == "" {
user, err := s.PrepareNewPeer() user, err := s.PrepareNewPeer(currentSession.DeviceName)
if err != nil { if err != nil {
return currentSession, errors.WithMessage(err, "failed to prepare new peer") return currentSession, errors.WithMessage(err, "failed to prepare new peer")
} }

View File

@ -4,44 +4,37 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
) )
func (s *Server) GetAdminEditInterface(c *gin.Context) { func (s *Server) GetAdminEditInterface(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetAllPeers() device := s.peers.GetDevice(currentSession.DeviceName)
currentSession, err := s.setFormInSession(c, device) currentSession, err := s.setFormInSession(c, device)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return return
} }
c.HTML(http.StatusOK, "admin_edit_interface.html", struct { c.HTML(http.StatusOK, "admin_edit_interface.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peers []Peer "Device": currentSession.FormData.(wireguard.Device),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "DeviceNames": s.wg.Cfg.DeviceNames,
}{
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peers: users,
Device: currentSession.FormData.(Device),
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
func (s *Server) PostAdminEditInterface(c *gin.Context) { func (s *Server) PostAdminEditInterface(c *gin.Context) {
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formDevice Device var formDevice wireguard.Device
if currentSession.FormData != nil { if currentSession.FormData != nil {
formDevice = currentSession.FormData.(Device) formDevice = currentSession.FormData.(wireguard.Device)
} }
if err := c.ShouldBind(&formDevice); err != nil { if err := c.ShouldBind(&formDevice); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
@ -76,7 +69,7 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
} }
// Update WireGuard config file // Update WireGuard config file
err = s.WriteWireGuardConfigFile() err = s.WriteWireGuardConfigFile(currentSession.DeviceName)
if err != nil { if err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger")
@ -86,12 +79,12 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
// Update interface IP address // Update interface IP address
if s.config.WG.ManageIPAddresses { if s.config.WG.ManageIPAddresses {
if err := s.wg.SetIPAddress(formDevice.IPs); err != nil { if err := s.wg.SetIPAddress(currentSession.DeviceName, formDevice.IPs); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
} }
if err := s.wg.SetMTU(formDevice.Mtu); err != nil { if err := s.wg.SetMTU(currentSession.DeviceName, formDevice.Mtu); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
@ -106,9 +99,10 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
} }
func (s *Server) GetInterfaceConfig(c *gin.Context) { func (s *Server) GetInterfaceConfig(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetActivePeers() device := s.peers.GetDevice(currentSession.DeviceName)
cfg, err := device.GetConfigFile(users) peers := s.peers.GetActivePeers(device.DeviceName)
cfg, err := device.GetConfigFile(peers)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
@ -122,13 +116,14 @@ func (s *Server) GetInterfaceConfig(c *gin.Context) {
} }
func (s *Server) GetApplyGlobalConfig(c *gin.Context) { func (s *Server) GetApplyGlobalConfig(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetAllPeers() device := s.peers.GetDevice(currentSession.DeviceName)
peers := s.peers.GetAllPeers(device.DeviceName)
for _, user := range users { for _, peer := range peers {
user.AllowedIPs = device.AllowedIPs peer.AllowedIPs = device.AllowedIPs
user.AllowedIPsStr = device.AllowedIPsStr peer.AllowedIPsStr = device.AllowedIPsStr
if err := s.peers.UpdatePeer(user); err != nil { if err := s.peers.UpdatePeer(peer); err != nil {
SetFlashMessage(c, err.Error(), "danger") SetFlashMessage(c, err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit") c.Redirect(http.StatusSeeOther, "/admin/device/edit")
} }

View File

@ -8,9 +8,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tatsushid/go-fastping" "github.com/tatsushid/go-fastping"
) )
@ -21,7 +22,6 @@ type LdapCreateForm struct {
} }
func (s *Server) GetAdminEditPeer(c *gin.Context) { func (s *Server) GetAdminEditPeer(c *gin.Context) {
device := s.peers.GetDevice()
peer := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession, err := s.setFormInSession(c, peer) currentSession, err := s.setFormInSession(c, peer)
@ -30,22 +30,15 @@ func (s *Server) GetAdminEditPeer(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_client.html", struct { c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peer Peer "Peer": currentSession.FormData.(wireguard.Peer),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peer: currentSession.FormData.(Peer),
Device: device,
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
@ -54,9 +47,9 @@ func (s *Server) PostAdminEditPeer(c *gin.Context) {
urlEncodedKey := url.QueryEscape(c.Query("pkey")) urlEncodedKey := url.QueryEscape(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formPeer Peer var formPeer wireguard.Peer
if currentSession.FormData != nil { if currentSession.FormData != nil {
formPeer = currentSession.FormData.(Peer) formPeer = currentSession.FormData.(wireguard.Peer)
} }
if err := c.ShouldBind(&formPeer); err != nil { if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
@ -92,37 +85,28 @@ func (s *Server) PostAdminEditPeer(c *gin.Context) {
} }
func (s *Server) GetAdminCreatePeer(c *gin.Context) { func (s *Server) GetAdminCreatePeer(c *gin.Context) {
device := s.peers.GetDevice()
currentSession, err := s.setNewPeerFormInSession(c) currentSession, err := s.setNewPeerFormInSession(c)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return return
} }
c.HTML(http.StatusOK, "admin_edit_client.html", struct { c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peer Peer "Peer": currentSession.FormData.(wireguard.Peer),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peer: currentSession.FormData.(Peer),
Device: device,
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
func (s *Server) PostAdminCreatePeer(c *gin.Context) { func (s *Server) PostAdminCreatePeer(c *gin.Context) {
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formPeer Peer var formPeer wireguard.Peer
if currentSession.FormData != nil { if currentSession.FormData != nil {
formPeer = currentSession.FormData.(Peer) formPeer = currentSession.FormData.(wireguard.Peer)
} }
if err := c.ShouldBind(&formPeer); err != nil { if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
@ -143,7 +127,7 @@ func (s *Server) PostAdminCreatePeer(c *gin.Context) {
formPeer.DeactivatedAt = &now formPeer.DeactivatedAt = &now
} }
if err := s.CreatePeer(formPeer); err != nil { if err := s.CreatePeer(currentSession.DeviceName, formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create")
@ -161,22 +145,15 @@ func (s *Server) GetAdminCreateLdapPeers(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_create_clients.html", struct { c.HTML(http.StatusOK, "admin_create_clients.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Users []users.User "Users": s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
FormData LdapCreateForm "FormData": currentSession.FormData.(LdapCreateForm),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Users: s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
FormData: currentSession.FormData.(LdapCreateForm),
Device: s.peers.GetDevice(),
}) })
} }
@ -207,7 +184,7 @@ func (s *Server) PostAdminCreateLdapPeers(c *gin.Context) {
logrus.Infof("creating %d ldap peers", len(emails)) logrus.Infof("creating %d ldap peers", len(emails))
for i := range emails { for i := range emails {
if err := s.CreatePeerByEmail(emails[i], formData.Identifier, false); err != nil { if err := s.CreatePeerByEmail(currentSession.DeviceName, emails[i], formData.Identifier, false); err != nil {
_ = s.updateFormInSession(c, formData) _ = s.updateFormInSession(c, formData)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create")
@ -225,7 +202,7 @@ func (s *Server) GetAdminDeletePeer(c *gin.Context) {
s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error())
return return
} }
SetFlashMessage(c, "user deleted successfully", "success") SetFlashMessage(c, "peer deleted successfully", "success")
c.Redirect(http.StatusSeeOther, "/admin") c.Redirect(http.StatusSeeOther, "/admin")
} }
@ -254,7 +231,7 @@ func (s *Server) GetPeerConfig(c *gin.Context) {
return return
} }
cfg, err := user.GetConfigFile(s.peers.GetDevice()) cfg, err := user.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
@ -273,7 +250,7 @@ func (s *Server) GetPeerConfigMail(c *gin.Context) {
return return
} }
cfg, err := user.GetConfigFile(s.peers.GetDevice()) cfg, err := user.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
@ -286,7 +263,7 @@ func (s *Server) GetPeerConfigMail(c *gin.Context) {
// Apply mail template // Apply mail template
var tplBuff bytes.Buffer var tplBuff bytes.Buffer
if err := s.mailTpl.Execute(&tplBuff, struct { if err := s.mailTpl.Execute(&tplBuff, struct {
Client Peer Client wireguard.Peer
QrcodePngName string QrcodePngName string
PortalUrl string PortalUrl string
}{ }{

View File

@ -49,22 +49,15 @@ func (s *Server) GetAdminUsersIndex(c *gin.Context) {
dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"]) dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"])
c.HTML(http.StatusOK, "admin_user_index.html", struct { c.HTML(http.StatusOK, "admin_user_index.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Users []users.User "Users": dbUsers,
TotalUsers int "TotalUsers": len(s.users.GetUsers()),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Users: dbUsers,
TotalUsers: len(s.users.GetUsers()),
Device: s.peers.GetDevice(),
}) })
} }
@ -77,21 +70,14 @@ func (s *Server) GetAdminUsersEdit(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_user.html", struct { c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
User users.User "User": currentSession.FormData.(users.User),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
Epoch time.Time "DeviceNames": s.wg.Cfg.DeviceNames,
}{
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
User: currentSession.FormData.(users.User),
Device: s.peers.GetDevice(),
}) })
} }
@ -160,21 +146,14 @@ func (s *Server) GetAdminUsersCreate(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_user.html", struct { c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
User users.User "User": currentSession.FormData.(users.User),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
Epoch time.Time "DeviceNames": s.wg.Cfg.DeviceNames,
}{
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
User: currentSession.FormData.(users.User),
Device: s.peers.GetDevice(),
}) })
} }
@ -218,7 +197,7 @@ func (s *Server) PostAdminUsersCreate(c *gin.Context) {
formUser.IsAdmin = c.PostForm("isadmin") == "true" formUser.IsAdmin = c.PostForm("isadmin") == "true"
formUser.Source = users.UserSourceDatabase formUser.Source = users.UserSourceDatabase
if err := s.CreateUser(formUser); err != nil { if err := s.CreateUser(formUser, currentSession.DeviceName); err != nil {
_ = s.updateFormInSession(c, formUser) _ = s.updateFormInSession(c, formUser)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create")

View File

@ -4,14 +4,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
wg_portal "github.com/h44z/wg-portal" wgportal "github.com/h44z/wg-portal"
) )
func SetupRoutes(s *Server) { func SetupRoutes(s *Server) {
// Startpage // Startpage
s.server.GET("/", s.GetIndex) s.server.GET("/", s.GetIndex)
s.server.GET("/favicon.ico", func(c *gin.Context) { s.server.GET("/favicon.ico", func(c *gin.Context) {
file, _ := wg_portal.Statics.ReadFile("assets/img/favicon.ico") file, _ := wgportal.Statics.ReadFile("assets/img/favicon.ico")
c.Data( c.Data(
http.StatusOK, http.StatusOK,
"image/x-icon", "image/x-icon",

View File

@ -11,12 +11,15 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"gorm.io/gorm"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
wg_portal "github.com/h44z/wg-portal" wgportal "github.com/h44z/wg-portal"
ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap" ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap"
passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password" passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
@ -32,18 +35,19 @@ const SessionIdentifier = "wgPortalSession"
func init() { func init() {
gob.Register(SessionData{}) gob.Register(SessionData{})
gob.Register(FlashData{}) gob.Register(FlashData{})
gob.Register(Peer{}) gob.Register(wireguard.Peer{})
gob.Register(Device{}) gob.Register(wireguard.Device{})
gob.Register(LdapCreateForm{}) gob.Register(LdapCreateForm{})
gob.Register(users.User{}) gob.Register(users.User{})
} }
type SessionData struct { type SessionData struct {
LoggedIn bool LoggedIn bool
IsAdmin bool IsAdmin bool
Firstname string Firstname string
Lastname string Lastname string
Email string Email string
DeviceName string
SortedBy map[string]string SortedBy map[string]string
SortDirection map[string]string SortDirection map[string]string
@ -69,14 +73,15 @@ type StaticData struct {
type Server struct { type Server struct {
ctx context.Context ctx context.Context
config *common.Config config *Config
server *gin.Engine server *gin.Engine
mailTpl *template.Template mailTpl *template.Template
auth *AuthManager auth *AuthManager
db *gorm.DB
users *users.Manager users *users.Manager
wg *wireguard.Manager wg *wireguard.Manager
peers *PeerManager peers *wireguard.PeerManager
} }
func (s *Server) Setup(ctx context.Context) error { func (s *Server) Setup(ctx context.Context) error {
@ -90,9 +95,15 @@ func (s *Server) Setup(ctx context.Context) error {
// Init rand // Init rand
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
s.config = common.NewConfig() s.config = NewConfig()
s.ctx = ctx s.ctx = ctx
// Setup database connection
s.db, err = common.GetDatabaseForConfig(&s.config.Database)
if err != nil {
return errors.WithMessage(err, "database setup failed")
}
// Setup http server // Setup http server
gin.SetMode(gin.DebugMode) gin.SetMode(gin.DebugMode)
gin.DefaultWriter = ioutil.Discard gin.DefaultWriter = ioutil.Discard
@ -104,24 +115,33 @@ func (s *Server) Setup(ctx context.Context) error {
s.server.SetFuncMap(template.FuncMap{ s.server.SetFuncMap(template.FuncMap{
"formatBytes": common.ByteCountSI, "formatBytes": common.ByteCountSI,
"urlEncode": url.QueryEscape, "urlEncode": url.QueryEscape,
"startsWith": strings.HasPrefix,
"userForEmail": func(users []users.User, email string) *users.User {
for i := range users {
if users[i].Email == email {
return &users[i]
}
}
return nil
},
}) })
// Setup templates // Setup templates
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wg_portal.Templates, "assets/tpl/*.html")) templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wgportal.Templates, "assets/tpl/*.html"))
s.server.SetHTMLTemplate(templates) s.server.SetHTMLTemplate(templates)
s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key? s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key?
// Serve static files // Serve static files
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/css")))) s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/css"))))
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/js")))) s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/js"))))
s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/img")))) s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/img"))))
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/fonts")))) s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/fonts"))))
// Setup all routes // Setup all routes
SetupRoutes(s) SetupRoutes(s)
// Setup user database (also needed for database authentication) // Setup user database (also needed for database authentication)
s.users, err = users.NewManager(&s.config.Database) s.users, err = users.NewManager(s.db)
if err != nil { if err != nil {
return errors.WithMessage(err, "user-manager initialization failed") return errors.WithMessage(err, "user-manager initialization failed")
} }
@ -153,18 +173,21 @@ func (s *Server) Setup(ctx context.Context) error {
} }
// Setup peer manager // Setup peer manager
if s.peers, err = NewPeerManager(s.config, s.wg, s.users); err != nil { if s.peers, err = wireguard.NewPeerManager(s.db, s.wg); err != nil {
return errors.WithMessage(err, "unable to setup peer manager") return errors.WithMessage(err, "unable to setup peer manager")
} }
if err = s.peers.InitFromCurrentInterface(); err != nil { if err = s.peers.InitFromPhysicalInterface(); err != nil {
return errors.WithMessage(err, "unable to initialize peer manager") return errors.WithMessagef(err, "unable to initialize peer manager")
} }
if err = s.RestoreWireGuardInterface(); err != nil {
return errors.WithMessage(err, "unable to restore WireGuard state") for _, deviceName := range s.wg.Cfg.DeviceNames {
if err = s.RestoreWireGuardInterface(deviceName); err != nil {
return errors.WithMessagef(err, "unable to restore WireGuard state for %s", deviceName)
}
} }
// Setup mail template // Setup mail template
s.mailTpl, err = template.New("email.html").ParseFS(wg_portal.Templates, "assets/tpl/email.html") s.mailTpl, err = template.New("email.html").ParseFS(wgportal.Templates, "assets/tpl/email.html")
if err != nil { if err != nil {
return errors.Wrap(err, "unable to pare mail template") return errors.Wrap(err, "unable to pare mail template")
} }
@ -174,6 +197,8 @@ func (s *Server) Setup(ctx context.Context) error {
} }
func (s *Server) Run() { func (s *Server) Run() {
logrus.Infof("starting web service on %s", s.config.Core.ListeningAddress)
// Start ldap sync // Start ldap sync
if s.config.Core.LdapEnabled { if s.config.Core.LdapEnabled {
go s.SyncLdapWithUserDatabase() go s.SyncLdapWithUserDatabase()
@ -238,6 +263,7 @@ func GetSessionData(c *gin.Context) SessionData {
Email: "", Email: "",
Firstname: "", Firstname: "",
Lastname: "", Lastname: "",
DeviceName: "",
IsAdmin: false, IsAdmin: false,
LoggedIn: false, LoggedIn: false,
} }

View File

@ -4,9 +4,12 @@ import (
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"path"
"syscall" "syscall"
"time" "time"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -15,28 +18,29 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func (s *Server) PrepareNewPeer() (Peer, error) { // PrepareNewPeer initiates a new peer for the given WireGuard device.
device := s.peers.GetDevice() func (s *Server) PrepareNewPeer(device string) (wireguard.Peer, error) {
dev := s.peers.GetDevice(device)
peer := Peer{} peer := wireguard.Peer{}
peer.IsNew = true peer.IsNew = true
peer.AllowedIPsStr = device.AllowedIPsStr peer.AllowedIPsStr = dev.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs)) peer.IPs = make([]string, len(dev.IPs))
for i := range device.IPs { for i := range dev.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil { if err != nil {
return Peer{}, errors.WithMessage(err, "failed to get available IP addresses") return wireguard.Peer{}, errors.WithMessage(err, "failed to get available IP addresses")
} }
peer.IPs[i] = freeIP peer.IPs[i] = freeIP
} }
peer.IPsStr = common.ListToString(peer.IPs) peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey() psk, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
return Peer{}, errors.Wrap(err, "failed to generate key") return wireguard.Peer{}, errors.Wrap(err, "failed to generate key")
} }
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return Peer{}, errors.Wrap(err, "failed to generate private key") return wireguard.Peer{}, errors.Wrap(err, "failed to generate private key")
} }
peer.PresharedKey = psk.String() peer.PresharedKey = psk.String()
peer.PrivateKey = key.String() peer.PrivateKey = key.String()
@ -46,54 +50,39 @@ func (s *Server) PrepareNewPeer() (Peer, error) {
return peer, nil return peer, nil
} }
func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error { // CreatePeerByEmail creates a new peer for the given email. If no user with the specified email was found, a new one
// will be created.
func (s *Server) CreatePeerByEmail(device, email, identifierSuffix string, disabled bool) error {
user, err := s.users.GetOrCreateUser(email) user, err := s.users.GetOrCreateUser(email)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to load/create related user %s", email) return errors.WithMessagef(err, "failed to load/create related user %s", email)
} }
device := s.peers.GetDevice() peer, err := s.PrepareNewPeer(device)
peer := Peer{}
peer.User = user
peer.AllowedIPsStr = device.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs))
for i := range device.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
if err != nil {
return errors.WithMessage(err, "failed to get available IP addresses")
}
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to generate key") return errors.WithMessage(err, "failed to prepare new peer")
} }
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return errors.Wrap(err, "failed to generate private key")
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.Email = email peer.Email = email
peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix) peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix)
now := time.Now() now := time.Now()
if disabled { if disabled {
peer.DeactivatedAt = &now peer.DeactivatedAt = &now
} }
return s.CreatePeer(peer) return s.CreatePeer(device, peer)
} }
func (s *Server) CreatePeer(peer Peer) error { // CreatePeer creates the new peer in the database. If the peer has no assigned ip addresses, a new one will be assigned
device := s.peers.GetDevice() // automatically. Also, if the private key is empty, a new key-pair will be generated.
peer.AllowedIPsStr = device.AllowedIPsStr // This function also configures the new peer on the physical WireGuard interface if the peer is not deactivated.
func (s *Server) CreatePeer(device string, peer wireguard.Peer) error {
dev := s.peers.GetDevice(device)
peer.AllowedIPsStr = dev.AllowedIPsStr
if peer.IPs == nil || len(peer.IPs) == 0 { if peer.IPs == nil || len(peer.IPs) == 0 {
peer.IPs = make([]string, len(device.IPs)) peer.IPs = make([]string, len(dev.IPs))
for i := range device.IPs { for i := range dev.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to get available IP addresses") return errors.WithMessage(err, "failed to get available IP addresses")
} }
@ -114,11 +103,12 @@ func (s *Server) CreatePeer(peer Peer) error {
peer.PrivateKey = key.String() peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String() peer.PublicKey = key.PublicKey().String()
} }
peer.DeviceName = dev.DeviceName
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
// Create WireGuard interface // Create WireGuard interface
if peer.DeactivatedAt == nil { if peer.DeactivatedAt == nil {
if err := s.wg.AddPeer(peer.GetConfig()); err != nil { if err := s.wg.AddPeer(device, peer.GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer") return errors.WithMessage(err, "failed to add WireGuard peer")
} }
} }
@ -128,21 +118,22 @@ func (s *Server) CreatePeer(peer Peer) error {
return errors.WithMessage(err, "failed to create peer") return errors.WithMessage(err, "failed to create peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(device)
} }
func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { // UpdatePeer updates the physical WireGuard interface and the database.
func (s *Server) UpdatePeer(peer wireguard.Peer, updateTime time.Time) error {
currentPeer := s.peers.GetPeerByKey(peer.PublicKey) currentPeer := s.peers.GetPeerByKey(peer.PublicKey)
// Update WireGuard device // Update WireGuard device
var err error var err error
switch { switch {
case peer.DeactivatedAt == &updateTime: case peer.DeactivatedAt == &updateTime:
err = s.wg.RemovePeer(peer.PublicKey) err = s.wg.RemovePeer(peer.DeviceName, peer.PublicKey)
case peer.DeactivatedAt == nil && currentPeer.Peer != nil: case peer.DeactivatedAt == nil && currentPeer.Peer != nil:
err = s.wg.UpdatePeer(peer.GetConfig()) err = s.wg.UpdatePeer(peer.DeviceName, peer.GetConfig())
case peer.DeactivatedAt == nil && currentPeer.Peer == nil: case peer.DeactivatedAt == nil && currentPeer.Peer == nil:
err = s.wg.AddPeer(peer.GetConfig()) err = s.wg.AddPeer(peer.DeviceName, peer.GetConfig())
} }
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to update WireGuard peer") return errors.WithMessage(err, "failed to update WireGuard peer")
@ -153,12 +144,13 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
return errors.WithMessage(err, "failed to update peer") return errors.WithMessage(err, "failed to update peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(peer.DeviceName)
} }
func (s *Server) DeletePeer(peer Peer) error { // DeletePeer removes the peer from the physical WireGuard interface and the database.
func (s *Server) DeletePeer(peer wireguard.Peer) error {
// Delete WireGuard peer // Delete WireGuard peer
if err := s.wg.RemovePeer(peer.PublicKey); err != nil { if err := s.wg.RemovePeer(peer.DeviceName, peer.PublicKey); err != nil {
return errors.WithMessage(err, "failed to remove WireGuard peer") return errors.WithMessage(err, "failed to remove WireGuard peer")
} }
@ -167,15 +159,16 @@ func (s *Server) DeletePeer(peer Peer) error {
return errors.WithMessage(err, "failed to remove peer") return errors.WithMessage(err, "failed to remove peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(peer.DeviceName)
} }
func (s *Server) RestoreWireGuardInterface() error { // RestoreWireGuardInterface restores the state of the physical WireGuard interface from the database.
activePeers := s.peers.GetActivePeers() func (s *Server) RestoreWireGuardInterface(device string) error {
activePeers := s.peers.GetActivePeers(device)
for i := range activePeers { for i := range activePeers {
if activePeers[i].Peer == nil { if activePeers[i].Peer == nil {
if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil { if err := s.wg.AddPeer(device, activePeers[i].GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer") return errors.WithMessage(err, "failed to add WireGuard peer")
} }
} }
@ -184,26 +177,29 @@ func (s *Server) RestoreWireGuardInterface() error {
return nil return nil
} }
func (s *Server) WriteWireGuardConfigFile() error { // WriteWireGuardConfigFile writes the configuration file for the physical WireGuard interface.
if s.config.WG.WireGuardConfig == "" { func (s *Server) WriteWireGuardConfigFile(device string) error {
if s.config.WG.ConfigDirectoryPath == "" {
return nil // writing disabled return nil // writing disabled
} }
if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil { if err := syscall.Access(s.config.WG.ConfigDirectoryPath, syscall.O_RDWR); err != nil {
return errors.Wrap(err, "failed to check WireGuard config access rights") return errors.Wrap(err, "failed to check WireGuard config access rights")
} }
device := s.peers.GetDevice() dev := s.peers.GetDevice(device)
cfg, err := device.GetConfigFile(s.peers.GetActivePeers()) cfg, err := dev.GetConfigFile(s.peers.GetActivePeers(device))
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to get config file") return errors.WithMessage(err, "failed to get config file")
} }
if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil { filePath := path.Join(s.config.WG.ConfigDirectoryPath, dev.DeviceName+".conf")
if err := ioutil.WriteFile(filePath, cfg, 0644); err != nil {
return errors.Wrap(err, "failed to write WireGuard config file") return errors.Wrap(err, "failed to write WireGuard config file")
} }
return nil return nil
} }
func (s *Server) CreateUser(user users.User) error { // CreateUser creates the user in the database and optionally adds a default WireGuard peer for the user.
func (s *Server) CreateUser(user users.User, device string) error {
if user.Email == "" { if user.Email == "" {
return errors.New("cannot create user with empty email address") return errors.New("cannot create user with empty email address")
} }
@ -220,9 +216,11 @@ func (s *Server) CreateUser(user users.User) error {
} }
// Check if user already has a peer setup, if not, create one // Check if user already has a peer setup, if not, create one
return s.CreateUserDefaultPeer(user.Email) return s.CreateUserDefaultPeer(user.Email, device)
} }
// UpdateUser updates the user in the database. If the user is marked as deleted, it will get remove from the database.
// Also, if the user is re-enabled, all it's linked WireGuard peers will be activated again.
func (s *Server) UpdateUser(user users.User) error { func (s *Server) UpdateUser(user users.User) error {
if user.DeletedAt.Valid { if user.DeletedAt.Valid {
return s.DeleteUser(user) return s.DeleteUser(user)
@ -249,6 +247,8 @@ func (s *Server) UpdateUser(user users.User) error {
return nil return nil
} }
// DeleteUser removes the user from the database.
// Also, if the user has linked WireGuard peers, they will be deactivated.
func (s *Server) DeleteUser(user users.User) error { func (s *Server) DeleteUser(user users.User) error {
currentUser := s.users.GetUserUnscoped(user.Email) currentUser := s.users.GetUserUnscoped(user.Email)
@ -271,7 +271,7 @@ func (s *Server) DeleteUser(user users.User) error {
return nil return nil
} }
func (s *Server) CreateUserDefaultPeer(email string) error { func (s *Server) CreateUserDefaultPeer(email, device string) error {
// Check if user is active, if not, quit // Check if user is active, if not, quit
var existingUser *users.User var existingUser *users.User
if existingUser = s.users.GetUser(email); existingUser == nil { if existingUser = s.users.GetUser(email); existingUser == nil {
@ -282,7 +282,7 @@ func (s *Server) CreateUserDefaultPeer(email string) error {
if s.config.Core.CreateDefaultPeer { if s.config.Core.CreateDefaultPeer {
peers := s.peers.GetPeersByMail(email) peers := s.peers.GetPeersByMail(email)
if len(peers) == 0 { // Create default vpn peer if len(peers) == 0 { // Create default vpn peer
if err := s.CreatePeer(Peer{ if err := s.CreatePeer(device, wireguard.Peer{
Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)", Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)",
Email: existingUser.Email, Email: existingUser.Email,
CreatedBy: existingUser.Email, CreatedBy: existingUser.Email,

View File

@ -1,17 +0,0 @@
package users
type SupportedDatabase string
const (
SupportedDatabaseMySQL SupportedDatabase = "mysql"
SupportedDatabaseSQLite SupportedDatabase = "sqlite"
)
type Config struct {
Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
Host string `yaml:"host" envconfig:"DATABASE_HOST"`
Port int `yaml:"port" envconfig:"DATABASE_PORT"`
Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
}

View File

@ -1,9 +1,6 @@
package users package users
import ( import (
"fmt"
"os"
"path/filepath"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -11,69 +8,15 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
) )
func GetDatabaseForConfig(cfg *Config) (db *gorm.DB, err error) {
switch cfg.Typ {
case SupportedDatabaseSQLite:
if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
return
}
}
db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
if err != nil {
return
}
case SupportedDatabaseMySQL:
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
if err != nil {
return
}
sqlDB, _ := db.DB()
sqlDB.SetConnMaxLifetime(time.Minute * 5)
sqlDB.SetMaxIdleConns(2)
sqlDB.SetMaxOpenConns(10)
err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
if err != nil {
return nil, errors.Wrap(err, "failed to ping mysql authentication database")
}
}
// Enable Logger (logrus)
logCfg := logger.Config{
SlowThreshold: time.Second, // all slower than one second
Colorful: false,
LogLevel: logger.Silent, // default: log nothing
}
if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
logCfg.LogLevel = logger.Info
logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
}
db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
return
}
type Manager struct { type Manager struct {
db *gorm.DB db *gorm.DB
} }
func NewManager(cfg *Config) (*Manager, error) { func NewManager(db *gorm.DB) (*Manager, error) {
m := &Manager{} m := &Manager{db: db}
var err error
m.db, err = GetDatabaseForConfig(cfg)
if err != nil {
return nil, errors.Wrapf(err, "failed to setup user database %s", cfg.Database)
}
// check if old user table exists (from version <= 1.0.2), if so rename it to peers. // check if old user table exists (from version <= 1.0.2), if so rename it to peers.
if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") { if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") {
@ -84,14 +27,11 @@ func NewManager(cfg *Config) (*Manager, error) {
} }
} }
return m, m.MigrateUserDB()
}
func (m Manager) MigrateUserDB() error {
if err := m.db.AutoMigrate(&User{}); err != nil { if err := m.db.AutoMigrate(&User{}); err != nil {
return errors.Wrap(err, "failed to migrate user database") return nil, errors.Wrap(err, "failed to migrate user database")
} }
return nil
return m, nil
} }
func (m Manager) GetUsers() []User { func (m Manager) GetUsers() []User {

View File

@ -1,7 +1,8 @@
package wireguard package wireguard
type Config struct { type Config struct {
DeviceName string `yaml:"device" envconfig:"WG_DEVICE"` DeviceNames []string `yaml:"devices" envconfig:"WG_DEVICES"` // managed devices
WireGuardConfig string `yaml:"configFile" envconfig:"WG_CONFIG_FILE"` // optional, if set, updates will be written to this file DefaultDeviceName string `yaml:"devices" envconfig:"WG_DEFAULT_DEVICE"` // this device is used for auto-created peers
ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface ConfigDirectoryPath string `yaml:"configDirectory" envconfig:"WG_CONFIG_PATH"` // optional, if set, updates will be written to this path, filename: <devicename>.conf
ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface
} }

View File

@ -9,6 +9,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Manager offers a synchronized management interface to the real WireGuard interface.
type Manager struct { type Manager struct {
Cfg *Config Cfg *Config
wg *wgctrl.Client wg *wgctrl.Client
@ -25,8 +26,8 @@ func (m *Manager) Init() error {
return nil return nil
} }
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) { func (m *Manager) GetDeviceInfo(device string) (*wgtypes.Device, error) {
dev, err := m.wg.Device(m.Cfg.DeviceName) dev, err := m.wg.Device(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device") return nil, errors.Wrap(err, "could not get WireGuard device")
} }
@ -34,11 +35,11 @@ func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
return dev, nil return dev, nil
} }
func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) { func (m *Manager) GetPeerList(device string) ([]wgtypes.Peer, error) {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
dev, err := m.wg.Device(m.Cfg.DeviceName) dev, err := m.wg.Device(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device") return nil, errors.Wrap(err, "could not get WireGuard device")
} }
@ -46,7 +47,7 @@ func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
return dev.Peers, nil return dev.Peers, nil
} }
func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) { func (m *Manager) GetPeer(device string, pubKey string) (*wgtypes.Peer, error) {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
@ -55,7 +56,7 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
return nil, errors.Wrap(err, "invalid public key") return nil, errors.Wrap(err, "invalid public key")
} }
peers, err := m.GetPeerList() peers, err := m.GetPeerList(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard peers") return nil, errors.Wrap(err, "could not get WireGuard peers")
} }
@ -69,11 +70,11 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey) return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey)
} }
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { func (m *Manager) AddPeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@ -81,12 +82,12 @@ func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
return nil return nil
} }
func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error { func (m *Manager) UpdatePeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
cfg.UpdateOnly = true cfg.UpdateOnly = true
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@ -94,7 +95,7 @@ func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error {
return nil return nil
} }
func (m *Manager) RemovePeer(pubKey string) error { func (m *Manager) RemovePeer(device string, pubKey string) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -108,7 +109,7 @@ func (m *Manager) RemovePeer(pubKey string) error {
Remove: true, Remove: true,
} }
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}) err = m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@ -116,6 +117,6 @@ func (m *Manager) RemovePeer(pubKey string) error {
return nil return nil
} }
func (m *Manager) UpdateDevice(name string, cfg wgtypes.Config) error { func (m *Manager) UpdateDevice(device string, cfg wgtypes.Config) error {
return m.wg.ConfigureDevice(name, cfg) return m.wg.ConfigureDevice(device, cfg)
} }

View File

@ -11,10 +11,10 @@ import (
const DefaultMTU = 1420 const DefaultMTU = 1420
func (m *Manager) GetIPAddress() ([]string, error) { func (m *Manager) GetIPAddress(device string) ([]string, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// Get golang net.interface // Get golang net.interface
@ -52,14 +52,14 @@ func (m *Manager) GetIPAddress() ([]string, error) {
return ipAddresses, nil return ipAddresses, nil
} }
func (m *Manager) SetIPAddress(cidrs []string) error { func (m *Manager) SetIPAddress(device string, cidrs []string) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// First remove existing IP addresses // First remove existing IP addresses
existingIPs, err := m.GetIPAddress() existingIPs, err := m.GetIPAddress(device)
if err != nil { if err != nil {
return errors.Wrap(err, "could not retrieve IP addresses") return errors.Wrap(err, "could not retrieve IP addresses")
} }
@ -89,10 +89,10 @@ func (m *Manager) SetIPAddress(cidrs []string) error {
return nil return nil
} }
func (m *Manager) GetMTU() (int, error) { func (m *Manager) GetMTU(device string) (int, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// Get golang net.interface // Get golang net.interface
@ -104,10 +104,10 @@ func (m *Manager) GetMTU() (int, error) {
return iface.MTU, nil return iface.MTU, nil
} }
func (m *Manager) SetMTU(mtu int) error { func (m *Manager) SetMTU(device string, mtu int) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
if mtu == 0 { if mtu == 0 {
@ -115,7 +115,7 @@ func (m *Manager) SetMTU(mtu int) error {
} }
if err := wgInterface.SetLinkMTU(mtu); err != nil { if err := wgInterface.SetLinkMTU(mtu); err != nil {
return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not set MTU on interface %s", device)
} }
return nil return nil

View File

@ -1,4 +1,4 @@
package server package wireguard
import ( import (
"bytes" "bytes"
@ -15,8 +15,6 @@ import (
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
@ -66,7 +64,6 @@ func init() {
type Peer struct { type Peer struct {
Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer
User *users.User `gorm:"-"` // user reference for the peer
Config string `gorm:"-"` Config string `gorm:"-"`
UID string `form:"uid" binding:"alphanum"` // uid for html identification UID string `form:"uid" binding:"alphanum"` // uid for html identification
@ -85,6 +82,7 @@ type Peer struct {
IPs []string `gorm:"-"` // The IPs of the client IPs []string `gorm:"-"` // The IPs of the client
PrivateKey string `form:"privkey" binding:"omitempty,base64"` PrivateKey string `form:"privkey" binding:"omitempty,base64"`
PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"` PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"`
DeviceName string `gorm:"index"`
DeactivatedAt *time.Time DeactivatedAt *time.Time
CreatedBy string CreatedBy string
@ -122,7 +120,7 @@ func (p Peer) GetConfig() wgtypes.PeerConfig {
} }
func (p Peer) GetConfigFile(device Device) ([]byte, error) { func (p Peer) GetConfigFile(device Device) ([]byte, error) {
tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(ClientCfgTpl)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to parse client template") return nil, errors.Wrap(err, "failed to parse client template")
} }
@ -245,7 +243,7 @@ func (d Device) GetConfig() wgtypes.Config {
} }
func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl) tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(DeviceCfgTpl)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to parse server template") return nil, errors.Wrap(err, "failed to parse server template")
} }
@ -271,63 +269,61 @@ func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
// //
type PeerManager struct { type PeerManager struct {
db *gorm.DB db *gorm.DB
wg *wireguard.Manager wg *Manager
users *users.Manager
} }
func NewPeerManager(cfg *common.Config, wg *wireguard.Manager, userDB *users.Manager) (*PeerManager, error) { func NewPeerManager(db *gorm.DB, wg *Manager) (*PeerManager, error) {
um := &PeerManager{wg: wg, users: userDB} um := &PeerManager{db: db, wg: wg}
var err error
um.db, err = users.GetDatabaseForConfig(&cfg.Database)
if err != nil {
return nil, errors.WithMessage(err, "failed to open peer database")
}
err = um.db.AutoMigrate(&Peer{}, &Device{}) if err := um.db.AutoMigrate(&Peer{}, &Device{}); err != nil {
if err != nil {
return nil, errors.WithMessage(err, "failed to migrate peer database") return nil, errors.WithMessage(err, "failed to migrate peer database")
} }
return um, nil return um, nil
} }
func (u *PeerManager) InitFromCurrentInterface() error { // InitFromPhysicalInterface read all WireGuard peers from the WireGuard interface configuration. If a peer does not
peers, err := u.wg.GetPeerList() // exist in the local database, it gets created.
if err != nil { func (m *PeerManager) InitFromPhysicalInterface() error {
return errors.Wrapf(err, "failed to get peer list") for _, deviceName := range m.wg.Cfg.DeviceNames {
} peers, err := m.wg.GetPeerList(deviceName)
device, err := u.wg.GetDeviceInfo() if err != nil {
if err != nil { return errors.Wrapf(err, "failed to get peer list for device %s", deviceName)
return errors.Wrapf(err, "failed to get device info")
}
var ipAddresses []string
var mtu int
if u.wg.Cfg.ManageIPAddresses {
if ipAddresses, err = u.wg.GetIPAddress(); err != nil {
return errors.Wrapf(err, "failed to get ip address")
} }
if mtu, err = u.wg.GetMTU(); err != nil { device, err := m.wg.GetDeviceInfo(deviceName)
return errors.Wrapf(err, "failed to get MTU") if err != nil {
return errors.Wrapf(err, "failed to get device info for device %s", deviceName)
}
var ipAddresses []string
var mtu int
if m.wg.Cfg.ManageIPAddresses {
if ipAddresses, err = m.wg.GetIPAddress(deviceName); err != nil {
return errors.Wrapf(err, "failed to get ip address for device %s", deviceName)
}
if mtu, err = m.wg.GetMTU(deviceName); err != nil {
return errors.Wrapf(err, "failed to get MTU for device %s", deviceName)
}
} }
}
// Check if entries already exist in database, if not create them // Check if entries already exist in database, if not create them
for _, peer := range peers { for _, peer := range peers {
if err := u.validateOrCreatePeer(peer); err != nil { if err := m.validateOrCreatePeer(deviceName, peer); err != nil {
return errors.WithMessagef(err, "failed to validate peer %s", peer.PublicKey) return errors.WithMessagef(err, "failed to validate peer %s for device %s", peer.PublicKey, deviceName)
}
}
if err := m.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil {
return errors.WithMessagef(err, "failed to validate device %s", device.Name)
} }
}
if err := u.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil {
return errors.WithMessagef(err, "failed to validate device %s", device.Name)
} }
return nil return nil
} }
func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error { // validateOrCreatePeer checks if the given WireGuard peer already exists in the database, if not, the peer entry will be created
func (m *PeerManager) validateOrCreatePeer(device string, wgPeer wgtypes.Peer) error {
peer := Peer{} peer := Peer{}
u.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer) m.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer)
if peer.PublicKey == "" { // peer not found, create if peer.PublicKey == "" { // peer not found, create
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String()))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String())))
@ -347,8 +343,9 @@ func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error {
} }
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
peer.DeviceName = device
res := u.db.Create(&peer) res := m.db.Create(&peer)
if res.Error != nil { if res.Error != nil {
return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey) return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey)
} }
@ -357,9 +354,10 @@ func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error {
return nil return nil
} }
func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error { // validateOrCreateDevice checks if the given WireGuard device already exists in the database, if not, the peer entry will be created
func (m *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error {
device := Device{} device := Device{}
u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device) m.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
if device.PublicKey == "" { // device not found, create if device.PublicKey == "" { // device not found, create
device.PublicKey = dev.PublicKey.String() device.PublicKey = dev.PublicKey.String()
@ -369,12 +367,12 @@ func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []s
device.Mtu = 0 device.Mtu = 0
device.PersistentKeepalive = 16 // Default device.PersistentKeepalive = 16 // Default
device.IPsStr = strings.Join(ipAddresses, ", ") device.IPsStr = strings.Join(ipAddresses, ", ")
if mtu == wireguard.DefaultMTU { if mtu == DefaultMTU {
mtu = 0 mtu = 0
} }
device.Mtu = mtu device.Mtu = mtu
res := u.db.Create(&device) res := m.db.Create(&device)
if res.Error != nil { if res.Error != nil {
return errors.Wrapf(res.Error, "failed to create autodetected device") return errors.Wrapf(res.Error, "failed to create autodetected device")
} }
@ -383,21 +381,22 @@ func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []s
return nil return nil
} }
func (u *PeerManager) populatePeerData(peer *Peer) { // populatePeerData enriches the peer struct with WireGuard live data like last handshake, ...
func (m *PeerManager) populatePeerData(peer *Peer) {
peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ") peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ")
peer.IPs = strings.Split(peer.IPsStr, ", ") peer.IPs = strings.Split(peer.IPsStr, ", ")
// Set config file // Set config file
tmpCfg, _ := peer.GetConfigFile(u.GetDevice()) tmpCfg, _ := peer.GetConfigFile(m.GetDevice(peer.DeviceName))
peer.Config = string(tmpCfg) peer.Config = string(tmpCfg)
// set data from WireGuard interface // set data from WireGuard interface
peer.Peer, _ = u.wg.GetPeer(peer.PublicKey) peer.Peer, _ = m.wg.GetPeer(peer.DeviceName, peer.PublicKey)
peer.LastHandshake = "never" peer.LastHandshake = "never"
peer.LastHandshakeTime = "Never connected, or user is disabled." peer.LastHandshakeTime = "Never connected, or user is disabled."
if peer.Peer != nil { if peer.Peer != nil {
since := time.Since(peer.Peer.LastHandshakeTime) since := time.Since(peer.Peer.LastHandshakeTime)
sinceSeconds := int(since.Round(time.Second).Seconds()) sinceSeconds := int(since.Round(time.Second).Seconds())
sinceMinutes := int(sinceSeconds / 60) sinceMinutes := sinceSeconds / 60
sinceSeconds -= sinceMinutes * 60 sinceSeconds -= sinceMinutes * 60
if sinceMinutes > 2*10080 { // 2 weeks if sinceMinutes > 2*10080 { // 2 weeks
@ -410,49 +409,47 @@ func (u *PeerManager) populatePeerData(peer *Peer) {
peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate) peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate)
} }
peer.IsOnline = false peer.IsOnline = false
// set user data
peer.User = u.users.GetUser(peer.Email)
} }
func (u *PeerManager) populateDeviceData(device *Device) { // populateDeviceData enriches the device struct with WireGuard live data like interface information
func (m *PeerManager) populateDeviceData(device *Device) {
device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ") device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
device.IPs = strings.Split(device.IPsStr, ", ") device.IPs = strings.Split(device.IPsStr, ", ")
device.DNS = strings.Split(device.DNSStr, ", ") device.DNS = strings.Split(device.DNSStr, ", ")
// set data from WireGuard interface // set data from WireGuard interface
device.Interface, _ = u.wg.GetDeviceInfo() device.Interface, _ = m.wg.GetDeviceInfo(device.DeviceName)
} }
func (u *PeerManager) GetAllPeers() []Peer { func (m *PeerManager) GetAllPeers(device string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Find(&peers) m.db.Where("device_name = ?", device).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) GetActivePeers() []Peer { func (m *PeerManager) GetActivePeers(device string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Where("deactivated_at IS NULL").Find(&peers) m.db.Where("device_name = ? AND deactivated_at IS NULL", device).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) GetFilteredAndSortedPeers(sortKey, sortDirection, search string) []Peer { func (m *PeerManager) GetFilteredAndSortedPeers(device, sortKey, sortDirection, search string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Find(&peers) m.db.Where("device_name = ?", device).Find(&peers)
filteredPeers := make([]Peer, 0, len(peers)) filteredPeers := make([]Peer, 0, len(peers))
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
if search == "" || if search == "" ||
strings.Contains(peers[i].Email, search) || strings.Contains(peers[i].Email, search) ||
@ -499,12 +496,12 @@ func (u *PeerManager) GetFilteredAndSortedPeers(sortKey, sortDirection, search s
return filteredPeers return filteredPeers
} }
func (u *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer { func (m *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Where("email = ?", email).Find(&peers) m.db.Where("email = ?", email).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
sort.Slice(peers, func(i, j int) bool { sort.Slice(peers, func(i, j int) bool {
@ -544,42 +541,42 @@ func (u *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email strin
return peers return peers
} }
func (u *PeerManager) GetDevice() Device { func (m *PeerManager) GetDevice(device string) Device {
devices := make([]Device, 0, 1) dev := Device{}
u.db.Find(&devices)
for i := range devices { m.db.Where("device_name = ?", device).First(&dev)
u.populateDeviceData(&devices[i]) m.populateDeviceData(&dev)
}
return devices[0] // use first device for now... more to come? return dev
} }
func (u *PeerManager) GetPeerByKey(publicKey string) Peer { func (m *PeerManager) GetPeerByKey(publicKey string) Peer {
peer := Peer{} peer := Peer{}
u.db.Where("public_key = ?", publicKey).FirstOrInit(&peer) m.db.Where("public_key = ?", publicKey).FirstOrInit(&peer)
u.populatePeerData(&peer) m.populatePeerData(&peer)
return peer return peer
} }
func (u *PeerManager) GetPeersByMail(mail string) []Peer { func (m *PeerManager) GetPeersByMail(mail string) []Peer {
var peers []Peer var peers []Peer
u.db.Where("email = ?", mail).Find(&peers) m.db.Where("email = ?", mail).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) CreatePeer(peer Peer) error { // ---- Database helpers -----
func (m *PeerManager) CreatePeer(peer Peer) error {
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.UpdatedAt = time.Now() peer.UpdatedAt = time.Now()
peer.CreatedAt = time.Now() peer.CreatedAt = time.Now()
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
res := u.db.Create(&peer) res := m.db.Create(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to create peer: %v", res.Error) logrus.Errorf("failed to create peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to create peer") return errors.Wrap(res.Error, "failed to create peer")
@ -588,12 +585,12 @@ func (u *PeerManager) CreatePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) UpdatePeer(peer Peer) error { func (m *PeerManager) UpdatePeer(peer Peer) error {
peer.UpdatedAt = time.Now() peer.UpdatedAt = time.Now()
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
res := u.db.Save(&peer) res := m.db.Save(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to update peer: %v", res.Error) logrus.Errorf("failed to update peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to update peer") return errors.Wrap(res.Error, "failed to update peer")
@ -602,8 +599,8 @@ func (u *PeerManager) UpdatePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) DeletePeer(peer Peer) error { func (m *PeerManager) DeletePeer(peer Peer) error {
res := u.db.Delete(&peer) res := m.db.Delete(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to delete peer: %v", res.Error) logrus.Errorf("failed to delete peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to delete peer") return errors.Wrap(res.Error, "failed to delete peer")
@ -612,13 +609,13 @@ func (u *PeerManager) DeletePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) UpdateDevice(device Device) error { func (m *PeerManager) UpdateDevice(device Device) error {
device.UpdatedAt = time.Now() device.UpdatedAt = time.Now()
device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ") device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ")
device.IPsStr = strings.Join(device.IPs, ", ") device.IPsStr = strings.Join(device.IPs, ", ")
device.DNSStr = strings.Join(device.DNS, ", ") device.DNSStr = strings.Join(device.DNS, ", ")
res := u.db.Save(&device) res := m.db.Save(&device)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to update device: %v", res.Error) logrus.Errorf("failed to update device: %v", res.Error)
return errors.Wrap(res.Error, "failed to update device") return errors.Wrap(res.Error, "failed to update device")
@ -627,9 +624,11 @@ func (u *PeerManager) UpdateDevice(device Device) error {
return nil return nil
} }
func (u *PeerManager) GetAllReservedIps() ([]string, error) { // ---- IP helpers ----
func (m *PeerManager) GetAllReservedIps(device string) ([]string, error) {
reservedIps := make([]string, 0) reservedIps := make([]string, 0)
peers := u.GetAllPeers() peers := m.GetAllPeers(device)
for _, user := range peers { for _, user := range peers {
for _, cidr := range user.IPs { for _, cidr := range user.IPs {
if cidr == "" { if cidr == "" {
@ -643,8 +642,8 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
} }
} }
device := u.GetDevice() dev := m.GetDevice(device)
for _, cidr := range device.IPs { for _, cidr := range dev.IPs {
if cidr == "" { if cidr == "" {
continue continue
} }
@ -659,8 +658,8 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
return reservedIps, nil return reservedIps, nil
} }
func (u *PeerManager) IsIPReserved(cidr string) bool { func (m *PeerManager) IsIPReserved(device string, cidr string) bool {
reserved, err := u.GetAllReservedIps() reserved, err := m.GetAllReservedIps(device)
if err != nil { if err != nil {
return true // in case something failed, assume the ip is reserved return true // in case something failed, assume the ip is reserved
} }
@ -688,10 +687,10 @@ func (u *PeerManager) IsIPReserved(cidr string) bool {
} }
// GetAvailableIp search for an available ip in cidr against a list of reserved ips // GetAvailableIp search for an available ip in cidr against a list of reserved ips
func (u *PeerManager) GetAvailableIp(cidr string) (string, error) { func (m *PeerManager) GetAvailableIp(device string, cidr string) (string, error) {
reserved, err := u.GetAllReservedIps() reserved, err := m.GetAllReservedIps(device)
if err != nil { if err != nil {
return "", errors.WithMessage(err, "failed to get all reserved IP addresses") return "", errors.WithMessagef(err, "failed to get all reserved IP addresses for %s", device)
} }
ip, ipnet, err := net.ParseCIDR(cidr) ip, ipnet, err := net.ParseCIDR(cidr)
if err != nil { if err != nil {