Allows passing UAPI fd to service

This commit is contained in:
Mathias Hall-Andersen
2017-11-17 14:36:08 +01:00
parent 88801529fd
commit e1227d3af4
3 changed files with 111 additions and 67 deletions
+69 -48
View File
@@ -10,12 +10,12 @@ import (
)
const (
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock"
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock"
)
type UAPIListener struct {
@@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil
}
func connectUnixSocket(path string) (net.Listener, error) {
func UAPIListen(name string, file *os.File) (net.Listener, error) {
// attempt inital connection
// wrap file in listener
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// check if active
_, err = net.Dial("unix", path)
if err == nil {
return nil, errors.New("Unix socket in use")
}
// attempt cleanup
err = os.Remove(path)
if err != nil {
return nil, err
}
return net.Listen("unix", path)
}
func NewUAPIListener(name string) (net.Listener, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 077)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
listener, err := connectUnixSocket(socketPath)
listener, err := net.FileListener(file)
if err != nil {
return nil, err
}
@@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
return nil, err
@@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) {
var buff [4096]byte
for {
unix.Read(uapi.inotifyFd, buff[:])
// start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}
unix.Read(uapi.inotifyFd, buff[:])
}
}(uapi)
@@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil
}
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0600)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
if err != nil {
return nil, err
}
return listener.File()
}