wg-portal/internal/wireguard/net.go

123 lines
3.0 KiB
Go
Raw Normal View History

2020-12-18 15:54:57 -05:00
package wireguard
import (
"fmt"
"net"
2021-02-26 16:17:04 -05:00
"github.com/pkg/errors"
2020-12-18 15:54:57 -05:00
"github.com/milosgajdos/tenus"
)
2021-02-08 16:56:02 -05:00
// DefaultMTU is the MTU that SetMTU applies to the interface when it is
// called with mtu == 0.
const DefaultMTU = 1420
2020-12-18 16:07:55 -05:00
2020-12-18 15:54:57 -05:00
func (m *Manager) GetIPAddress() ([]string, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
if err != nil {
2021-02-26 16:17:04 -05:00
return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
2020-12-18 15:54:57 -05:00
}
// Get golang net.interface
iface := wgInterface.NetInterface()
if iface == nil { // Not sure if this check is really necessary
2021-02-26 16:17:04 -05:00
return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface")
2020-12-18 15:54:57 -05:00
}
addrs, err := iface.Addrs()
if err != nil {
2021-02-26 16:17:04 -05:00
return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses")
2020-12-18 15:54:57 -05:00
}
ipAddresses := make([]string, 0, len(addrs))
for _, addr := range addrs {
var ip net.IP
var mask net.IPMask
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
mask = v.Mask
case *net.IPAddr:
ip = v.IP
mask = ip.DefaultMask()
}
2021-02-08 16:56:02 -05:00
if ip == nil || mask == nil {
2020-12-18 15:54:57 -05:00
continue // something is wrong?
}
maskSize, _ := mask.Size()
cidr := fmt.Sprintf("%s/%d", ip.String(), maskSize)
ipAddresses = append(ipAddresses, cidr)
}
return ipAddresses, nil
}
func (m *Manager) SetIPAddress(cidrs []string) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
if err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
2020-12-18 15:54:57 -05:00
}
// First remove existing IP addresses
existingIPs, err := m.GetIPAddress()
if err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrap(err, "could not retrieve IP addresses")
2020-12-18 15:54:57 -05:00
}
for _, cidr := range existingIPs {
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
if err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "unable to parse cidr %s", cidr)
2020-12-18 15:54:57 -05:00
}
if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "failed to unset ip %s", cidr)
2020-12-18 15:54:57 -05:00
}
}
2021-02-08 16:56:02 -05:00
// Next set new IP addresses
2020-12-18 15:54:57 -05:00
for _, cidr := range cidrs {
wgIp, wgIpNet, err := net.ParseCIDR(cidr)
if err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "unable to parse cidr %s", cidr)
2020-12-18 15:54:57 -05:00
}
if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "failed to set ip %s", cidr)
2020-12-18 15:54:57 -05:00
}
}
return nil
}
func (m *Manager) GetMTU() (int, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
if err != nil {
2021-02-26 16:17:04 -05:00
return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
2020-12-18 15:54:57 -05:00
}
// Get golang net.interface
iface := wgInterface.NetInterface()
if iface == nil { // Not sure if this check is really necessary
2021-02-26 16:17:04 -05:00
return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface")
2020-12-18 15:54:57 -05:00
}
return iface.MTU, nil
}
func (m *Manager) SetMTU(mtu int) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
if err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
2020-12-18 15:54:57 -05:00
}
if mtu == 0 {
2021-02-08 16:56:02 -05:00
mtu = DefaultMTU
2020-12-18 15:54:57 -05:00
}
if err := wgInterface.SetLinkMTU(mtu); err != nil {
2021-02-26 16:17:04 -05:00
return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName)
2020-12-18 15:54:57 -05:00
}
return nil
}