wg-portal/internal/server/helper.go

202 lines
4.7 KiB
Go

package server
import (
"crypto/md5"
"fmt"
"io/ioutil"
"syscall"
"time"
"github.com/h44z/wg-portal/internal/common"
"github.com/pkg/errors"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func (s *Server) PrepareNewPeer() (Peer, error) {
device := s.peers.GetDevice()
peer := Peer{}
peer.IsNew = true
peer.AllowedIPsStr = device.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs))
for i := range device.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
if err != nil {
return Peer{}, err
}
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey()
if err != nil {
return Peer{}, err
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return Peer{}, err
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
return peer, nil
}
func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error {
user, err := s.users.GetOrCreateUser(email)
if err != nil {
return errors.WithMessagef(err, "failed to load/create related user %s", email)
}
device := s.peers.GetDevice()
peer := Peer{}
peer.User = user
peer.AllowedIPsStr = device.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs))
for i := range device.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
if err != nil {
return err
}
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey()
if err != nil {
return err
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.Email = email
peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix)
now := time.Now()
if disabled {
peer.DeactivatedAt = &now
}
return s.CreatePeer(peer)
}
func (s *Server) CreatePeer(peer Peer) error {
device := s.peers.GetDevice()
peer.AllowedIPsStr = device.AllowedIPsStr
if peer.IPs == nil || len(peer.IPs) == 0 {
peer.IPs = make([]string, len(device.IPs))
for i := range device.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
if err != nil {
return err
}
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
}
if peer.PrivateKey == "" { // if private key is empty create a new one
psk, err := wgtypes.GenerateKey()
if err != nil {
return err
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
}
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
// Create WireGuard interface
if peer.DeactivatedAt == nil {
if err := s.wg.AddPeer(peer.GetConfig()); err != nil {
return err
}
}
// Create in database
if err := s.peers.CreatePeer(peer); err != nil {
return err
}
return s.WriteWireGuardConfigFile()
}
func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
currentPeer := s.peers.GetPeerByKey(peer.PublicKey)
// Update WireGuard device
var err error
switch {
case peer.DeactivatedAt == &updateTime:
err = s.wg.RemovePeer(peer.PublicKey)
case peer.DeactivatedAt == nil && currentPeer.Peer != nil:
err = s.wg.UpdatePeer(peer.GetConfig())
case peer.DeactivatedAt == nil && currentPeer.Peer == nil:
err = s.wg.AddPeer(peer.GetConfig())
}
if err != nil {
return err
}
// Update in database
if err := s.peers.UpdatePeer(peer); err != nil {
return err
}
return s.WriteWireGuardConfigFile()
}
func (s *Server) DeletePeer(peer Peer) error {
// Delete WireGuard peer
if err := s.wg.RemovePeer(peer.PublicKey); err != nil {
return err
}
// Delete in database
if err := s.peers.DeletePeer(peer); err != nil {
return err
}
return s.WriteWireGuardConfigFile()
}
func (s *Server) RestoreWireGuardInterface() error {
activePeers := s.peers.GetActivePeers()
for i := range activePeers {
if activePeers[i].Peer == nil {
if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil {
return err
}
}
}
return nil
}
func (s *Server) WriteWireGuardConfigFile() error {
if s.config.WG.WireGuardConfig == "" {
return nil // writing disabled
}
if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil {
return err
}
device := s.peers.GetDevice()
cfg, err := device.GetConfigFile(s.peers.GetActivePeers())
if err != nil {
return err
}
if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil {
return err
}
return nil
}