65 Commits

Author SHA1 Message Date
Jason A. Donenfeld
f87e87af0d version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 23:27:13 -06:00
Jason A. Donenfeld
ba9e364dab wintun: remove memmod option for dll loading
Only wireguard-windows used this, and it's moving to wgnt exclusively.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 22:49:38 -06:00
Jason A. Donenfeld
dfd688b6aa global: remove old-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 12:02:10 -06:00
Jason A. Donenfeld
c01d52b66a global: add newer-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 11:46:53 -06:00
Jason A. Donenfeld
82d2aa87aa wintun: use new swdevice-based API for upcoming Wintun 0.14
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 00:26:46 -06:00
Jason A. Donenfeld
982d5d2e84 conn,wintun: use unsafe.Slice instead of unsafeSlice
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:57:53 -06:00
Jason A. Donenfeld
642a56e165 memmod: import from wireguard-windows
We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:53:36 -06:00
Jason A. Donenfeld
bb745b2ea3 rwcancel: use unix.Poll again but bump x/sys so it uses ppoll under the hood
This reverts commit fcc601dbf0 but then
bumps go.mod.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-27 14:19:15 -06:00
Jason A. Donenfeld
fcc601dbf0 rwcancel: use ppoll on Linux for Android
This is a temporary measure while we wait for
https://go-review.googlesource.com/c/sys/+/352310 to land.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-26 17:16:38 -06:00
Tobias Klauser
217ac1016b tun: make operateonfd.go build tags more specific
(*NativeTun).operateOnFd is only used on darwin and freebsd. Adjust the
build tags accordingly.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:54:01 -06:00
Tobias Klauser
eae5e0f3a3 tun: avoid leaking sock fd in CreateTUN error cases
At these points, the socket file descriptor is not yet wrapped in an
*os.File, so it needs to be closed explicitly on error.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:53:49 -06:00
Jason A. Donenfeld
2ef39d4754 global: add new go 1.17 build comments
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-05 16:00:43 +02:00
Jason A. Donenfeld
3957e9b9dd memmod: register exception handler tables
Otherwise recent WDK binaries fail on ARM64, where an exception handler
is used for trapping an illegal instruction when ARMv8.1 atomics are
being tested for functionality.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-08-05 14:56:48 +02:00
Jason A. Donenfeld
bad6caeb82 memmod: fix protected delayed load the right way
The reason this was failing before is that dloadsup.h's
DloadObtainSection was doing a linear search of sections to find which
header corresponds with the IMAGE_DELAYLOAD_DESCRIPTOR section, and we
were stupidly overwriting the VirtualSize field, so the linear search
wound up matching the .text section, which then it found to not be
marked writable and failed with FAST_FAIL_DLOAD_PROTECTION_FAILURE.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:27:40 +02:00
Jason A. Donenfeld
c89f5ca665 memmod: disable protected delayed load for now
Probably a bad idea, but we don't currently support it, and those huge
windows.NewCallback trampolines make juicer targets anyway.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:13:03 +02:00
Jason A. Donenfeld
15b24b6179 ipc: allow admins but require high integrity label
Might be more reasonable.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-24 17:01:02 +02:00
Jason A. Donenfeld
f9b48a961c device: zero out allowedip node pointers when removing
This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-04 16:33:28 +02:00
Jason A. Donenfeld
d0cf96114f device: limit allowedip fuzzer a to 4 times through
Trying this for every peer winds up being very slow and precludes it
from acceptable runtime in the CI, so reduce this to 4.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 18:22:50 +02:00
Jason A. Donenfeld
841756e328 device: simplify allowedips lookup signature
The inliner should handle this for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
c382222eab device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
b41f4cc768 device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 15:08:42 +02:00
Jason A. Donenfeld
4a57024b94 device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 13:51:03 +02:00
Josh Bleecher Snyder
64cb82f2b3 go.mod: bump golang.org/x/sys again
To pick up https://go-review.googlesource.com/c/sys/+/307129.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-25 16:34:54 +02:00
Jason A. Donenfeld
c27ff9b9f6 device: allow reducing queue constants on iOS
Heavier network extensions might require the wireguard-go component to
use less ram, so let users of this reduce these as needed.

At some point we'll put this behind a configuration method of sorts, but
for now, just expose the consts as vars.

Requested-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-22 01:00:51 +02:00
Jason A. Donenfeld
99e8b4ba60 tun: linux: account for interface removal from outside
On Linux we can run `ip link del wg0`, in which case the fd becomes
stale, and we should exit. Since this is an intentional action, don't
treat it as an error.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:26:01 +02:00
Jason A. Donenfeld
bd83f0ac99 conn: linux: protect read fds
The -1 protection was removed and the wrong error was returned, causing
us to read from a bogus fd. As well, remove the useless closures that
aren't doing anything, since this is all synchronized anyway.

Fixes: 10533c3 ("all: make conn.Bind.Open return a slice of receive functions")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:09:55 +02:00
Jason A. Donenfeld
50d779833e rwcancel: use ordinary os.ErrClosed instead of custom error
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:56:36 +02:00
Jason A. Donenfeld
a9b377e9e1 rwcancel: use poll instead of select
Suggested-by: Lennart Poettering <lennart@poettering.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:42:34 +02:00
Jason A. Donenfeld
9087e444e6 device: optimize Peer.String even more
This reduces the allocation, branches, and amount of base64 encoding.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-18 17:43:53 +02:00
Josh Bleecher Snyder
25ad08a591 device: optimize Peer.String
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-05-14 00:37:30 +02:00
Jason A. Donenfeld
5846b62283 conn: windows: set count=0 on retry
When retrying, if count is not 0, we forget to dequeue another request,
and so the ring fills up and errors out.

Reported-by: Sascha Dierberg <dierberg@dresearch-fe.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-11 16:47:17 +02:00
Jason A. Donenfeld
9844c74f67 main: replace crlf on windows in fmt test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 22:23:32 +02:00
Jason A. Donenfeld
4e9e5dad09 main: check that code is formatted in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 17:48:26 +02:00
Jason A. Donenfeld
39e0b6dade tun: format
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:27 +02:00
Jason A. Donenfeld
7121927b87 device: add ID to repeated routines
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:21 +02:00
Jason A. Donenfeld
326aec10af device: remove unusual ... in messages
We dont use ... in any other present progressive messages except these.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:17:41 +02:00
Jason A. Donenfeld
efb8818550 device: avoid verbose log line during ordinary shutdown sequence
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:39:06 +02:00
Jason A. Donenfeld
69b39db0b4 tun: windows: set event before waiting
In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:26:24 +02:00
Jason A. Donenfeld
db733ccd65 tun: windows: rearrange struct to avoid alignment trap on 32bit
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:19:00 +02:00
Jason A. Donenfeld
a7aec4449f tun: windows: check alignment in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:15:50 +02:00
Josh Bleecher Snyder
60a26371f4 device: log all errors received by RoutineReceiveIncoming
When debugging, it's useful to know why a receive func exited.

We were already logging that, but only in the "death spiral" case.
Move the logging up, to capture it always.
Reduce the verbosity, since it is not an error case any more.
Put the receive func name in the log line.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-06 11:22:13 +02:00
Jason A. Donenfeld
a544776d70 tun/netstack: update go mod and remove GSO argument
Reported-by: John Xiong <xiaoyang1258@yeah.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-06 11:07:26 +02:00
Jason A. Donenfeld
69a42a4eef tun: windows: send MTU update when forced MTU changes
Otherwise the padding doesn't get updated.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-05 11:42:45 +02:00
Jason A. Donenfeld
097af6e135 tun: windows: protect reads from closing
The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:22:45 -04:00
Jason A. Donenfeld
8246d251ea conn: windows: do not error out when receiving UDP jumbogram
If we receive a large UDP packet, don't return an error to receive.go,
which then terminates the receive loop. Instead, simply retry.

Considering Winsock's general finickiness, we might consider other
places where an attacker on the wire can generate error conditions like
this.

Reported-by: Sascha Dierberg <sascha.dierberg@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:07:03 -04:00
Jason A. Donenfeld
c9db4b7aaa version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-24 13:07:27 -04:00
Jason A. Donenfeld
3625f8d284 tun: freebsd: avoid OOB writes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:10:23 -06:00
Jason A. Donenfeld
0687dc06c8 tun: freebsd: become controlling process when reopening tun FD
When we pass the TUN FD to the child, we have to call TUNSIFPID;
otherwise when we close the device, we get a splat in dmesg.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:02:44 -06:00
Jason A. Donenfeld
71aefa374d tun: freebsd: restructure and cleanup
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 14:54:59 -06:00
Jason A. Donenfeld
3d3e30beb8 tun: freebsd: remove horrific hack for getting tunnel name
As of FreeBSD 12.1, there's TUNGIFNAME.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 12:03:16 -06:00
Jason A. Donenfeld
b0e5b19969 tun: freebsd: set IFF_MULTICAST for routing daemons
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-18 20:09:04 -06:00
Jason A. Donenfeld
3988821442 main: print kernel warning on OpenBSD and FreeBSD too
More kernels!

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-15 23:35:45 -06:00
Jason A. Donenfeld
c7cd2c9eab device: don't defer unlocking from loop
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 16:19:35 -06:00
Jason A. Donenfeld
54dbe2471f conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 15:35:32 -06:00
Kristupas Antanavičius
d2fd0c0cc0 device: allocate new buffer in receive death spiral
Note: this bug is "hidden" by avoiding "death spiral" code path by
6228659 ("device: handle broader range of errors in RoutineReceiveIncoming").

If the code reached "death spiral" mechanism, there would be multiple
double frees happening. This results in a deadlock on iOS, because the
pools are fixed size and goroutine might stop until somebody makes
space in the pool.

This was almost 100% repro on the new ARM Macbooks:

- Build with 'ios' tag for Mac. This will enable bounded pools.
- Somehow call device.IpcSet at least couple of times (update config)
- device.BindUpdate() would be triggered
- RoutineReceiveIncoming would enter "death spiral".
- RoutineReceiveIncoming would stall on double free (pool is already
  full)
- The stuck routine would deadlock 'device.closeBindLocked()' function
  on line 'netc.stopping.Wait()'

Signed-off-by: Kristupas Antanavičius <kristupas.antanavicius@nordsec.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 11:14:53 -06:00
Jason A. Donenfeld
5f6bbe4ae8 conn: windows: reset ring to starting position after free
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 18:09:41 -06:00
Jason A. Donenfeld
75526d6071 conn: windows: compare head and tail properly
By not comparing these with the modulo, the ring became nearly never
full, resulting in completion queue buffers filling up prematurely.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Jason A. Donenfeld
fbf97502cf winrio: test that IOCP-based RIO is supported
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Josh Bleecher Snyder
10533c3e73 all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.

Beneficial consequences:

* If there's no IPv6 support on a system,
  conn.Bind.Open can choose not to return a receive function for it,
  which is simpler than tracking that state in the bind.
  This simplification removes existing data races from both
  conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
  the conn.Bind no longer needs to add a separate muxing layer.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-04-02 11:07:08 -06:00
Jason A. Donenfeld
8ed83e0427 conn: winrio: pass key parameter into struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-02 10:36:41 -06:00
Josh Bleecher Snyder
6228659a91 device: handle broader range of errors in RoutineReceiveIncoming
RoutineReceiveIncoming exits immediately on net.ErrClosed,
but not on other errors. However, for errors that are known
to be permanent, such as syscall.EAFNOSUPPORT,
we may as well exit immediately instead of retrying.

This considerably speeds up the package device tests right now,
because the Bind sometimes (incorrectly) returns syscall.EAFNOSUPPORT
instead of net.ErrClosed.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:41:43 -07:00
Josh Bleecher Snyder
517f0703f5 conn: document retry loop in StdNetBind.Open
It's not obvious on a first read what the loop is doing.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
204140016a conn: use local ipvN vars in StdNetBind.Open
This makes it clearer that they are fresh on each attempt,
and avoids the bookkeeping required to clearing them on failure.

Also, remove an unnecessary err != nil.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
822f5a6d70 conn: unify code in StdNetBind.Send
The sending code is identical for ipv4 and ipv6;
select the conn, then use it.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:32 -07:00
Josh Bleecher Snyder
02e419ed8a device: rename unsafeCloseBind to closeBindLocked
And document a bit.
This name is more idiomatic.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:07:12 -07:00
67 changed files with 1026 additions and 2258 deletions

View File

@@ -10,7 +10,7 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\nconst Version = "%s"\n' "$$tag")" && \ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \ echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true git update-index --assume-unchanged version.go || true

View File

@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct { type LinuxSocketBind struct {
sock4 int // mu guards sock4 and sock6 and the associated fds.
sock6 int // As long as someone holds mu (read or write), the associated fds are valid.
lastMark uint32 mu sync.RWMutex
closing sync.RWMutex sock4 int
sock6 int
} }
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
return nil, errors.New("invalid IP address") return nil, errors.New("invalid IP address")
} }
func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error var err error
var newPort uint16 var newPort uint16
var tries int var tries int
if bind.sock4 != -1 || bind.sock6 != -1 { if bind.sock4 != -1 || bind.sock6 != -1 {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
originalPort := port originalPort := port
again: again:
port = originalPort port = originalPort
var sock4, sock6 int
// Attempt ipv6 bind, update port if successful. // Attempt ipv6 bind, update port if successful.
bind.sock6, newPort, err = create6(port) sock6, newPort, err = create6(port)
if err != nil { if err != nil {
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
return 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
// Attempt ipv4 bind, update port if successful. // Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port) sock4, newPort, err = create4(port)
if err != nil { if err != nil {
if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
unix.Close(bind.sock6) unix.Close(sock6)
tries++ tries++
goto again goto again
} }
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
unix.Close(bind.sock6) unix.Close(sock6)
return 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
if bind.sock4 == -1 && bind.sock6 == -1 { var fns []ReceiveFunc
return 0, syscall.EAFNOSUPPORT if sock4 != -1 {
bind.sock4 = sock4
fns = append(fns, bind.receiveIPv4)
} }
return port, nil if sock6 != -1 {
bind.sock6 = sock6
fns = append(fns, bind.receiveIPv6)
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, port, nil
} }
func (bind *LinuxSocketBind) SetMark(value uint32) error { func (bind *LinuxSocketBind) SetMark(value uint32) error {
bind.closing.RLock() bind.mu.RLock()
defer bind.closing.RUnlock() defer bind.mu.RUnlock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
} }
} }
bind.lastMark = value
return nil return nil
} }
func (bind *LinuxSocketBind) Close() error { func (bind *LinuxSocketBind) Close() error {
var err1, err2 error // Take a readlock to shut down the sockets...
bind.closing.RLock() bind.mu.RLock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR) unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
} }
if bind.sock4 != -1 { if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR) unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
} }
bind.closing.RUnlock() bind.mu.RUnlock()
bind.closing.Lock() // ...and a write lock to close the fd.
// This ensures that no one else is using the fd.
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6) err1 = unix.Close(bind.sock6)
bind.sock6 = -1 bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
err2 = unix.Close(bind.sock4) err2 = unix.Close(bind.sock4)
bind.sock4 = -1 bind.sock4 = -1
} }
bind.closing.Unlock()
if err1 != nil { if err1 != nil {
return err1 return err1
@@ -208,46 +224,35 @@ func (bind *LinuxSocketBind) Close() error {
return err2 return err2
} }
func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.closing.RLock() bind.mu.RLock()
defer bind.closing.RUnlock() defer bind.mu.RUnlock()
var end LinuxSocketEndpoint
if bind.sock6 == -1 {
return 0, nil, net.ErrClosed
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
bind.closing.RLock()
defer bind.closing.RUnlock()
var end LinuxSocketEndpoint
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
} }
n, err := receive4( var end LinuxSocketEndpoint
bind.sock4, n, err := receive4(bind.sock4, buf, &end)
buff, return n, &end, err
&end, }
)
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 == -1 {
return 0, nil, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive6(bind.sock6, buf, &end)
return n, &end, err return n, &end, err
} }
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
bind.closing.RLock()
defer bind.closing.RUnlock()
nend, ok := end.(*LinuxSocketEndpoint) nend, ok := end.(*LinuxSocketEndpoint)
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
bind.mu.RLock()
defer bind.mu.RUnlock()
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 { if bind.sock4 == -1 {
return net.ErrClosed return net.ErrClosed

View File

@@ -8,6 +8,7 @@ package conn
import ( import (
"errors" "errors"
"net" "net"
"sync"
"syscall" "syscall"
) )
@@ -16,6 +17,7 @@ import (
// It uses the Go's net package to implement networking. // It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform. // See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool blackhole4 bool
@@ -81,44 +83,58 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil return conn, uaddr.Port, nil
} }
func (bind *StdNetBind) Open(uport uint16) (uint16, error) { func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error var err error
var tries int var tries int
if bind.ipv4 != nil || bind.ipv6 != nil { if bind.ipv4 != nil || bind.ipv6 != nil {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again: again:
port := int(uport) port := int(uport)
var ipv4, ipv6 *net.UDPConn
bind.ipv4, port, err = listenNet("udp4", port) ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
bind.ipv4 = nil return nil, 0, err
return 0, err
} }
bind.ipv6, port, err = listenNet("udp6", port) // Listen on the same port as we're using for ipv4.
if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { ipv6, port, err = listenNet("udp6", port)
bind.ipv4.Close() if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
bind.ipv4 = nil ipv4.Close()
bind.ipv6 = nil
tries++ tries++
goto again goto again
} }
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
bind.ipv4.Close() ipv4.Close()
bind.ipv4 = nil return nil, 0, err
bind.ipv6 = nil
return 0, err
} }
if bind.ipv4 == nil && bind.ipv6 == nil { var fns []ReceiveFunc
return 0, syscall.EAFNOSUPPORT if ipv4 != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
} }
return uint16(port), nil if ipv6 != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
} }
func (bind *StdNetBind) Close() error { func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error var err1, err2 error
if bind.ipv4 != nil { if bind.ipv4 != nil {
err1 = bind.ipv4.Close() err1 = bind.ipv4.Close()
@@ -136,23 +152,21 @@ func (bind *StdNetBind) Close() error {
return err2 return err2
} }
func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
if bind.ipv4 == nil { return func(buff []byte) (int, Endpoint, error) {
return 0, nil, syscall.EAFNOSUPPORT n, endpoint, err := conn.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
} }
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
} }
func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
if bind.ipv6 == nil { return func(buff []byte) (int, Endpoint, error) {
return 0, nil, syscall.EAFNOSUPPORT n, endpoint, err := conn.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err
} }
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err
} }
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -161,22 +175,22 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
if nend.IP.To4() != nil {
if bind.ipv4 == nil { bind.mu.Lock()
return syscall.EAFNOSUPPORT blackhole := bind.blackhole4
} conn := bind.ipv4
if bind.blackhole4 { if nend.IP.To4() == nil {
return nil blackhole = bind.blackhole6
} conn = bind.ipv6
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} }
bind.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
return err return err
} }

View File

@@ -47,7 +47,7 @@ func (rb *ringBuffer) Push() *ringPacket {
} }
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1 rb.tail += 1
if rb.tail == rb.head { if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true rb.isFull = true
} }
return ret return ret
@@ -121,10 +121,8 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS return nil, windows.ERROR_INVALID_ADDRESS
} }
var src []byte
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen)) copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
copy(dst[:], src)
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
} }
@@ -197,6 +195,9 @@ func (ring *ringBuffer) CloseAndZero() {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0 ring.packets = 0
} }
ring.head = 0
ring.tail = 0
ring.isFull = false
} }
func (bind *afWinRingBind) CloseAndZero() { func (bind *afWinRingBind) CloseAndZero() {
@@ -266,7 +267,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
return sa, nil return sa, nil
} }
func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock() bind.mu.Lock()
defer bind.mu.Unlock() defer bind.mu.Unlock()
defer func() { defer func() {
@@ -275,30 +276,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
} }
}() }()
if atomic.LoadUint32(&bind.isOpen) != 0 { if atomic.LoadUint32(&bind.isOpen) != 0 {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
var sa windows.Sockaddr var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ { for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest() err = bind.v4.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
err = bind.v6.InsertReceiveRequest() err = bind.v6.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
} }
atomic.StoreUint32(&bind.isOpen, 1) atomic.StoreUint32(&bind.isOpen, 1)
return return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
} }
func (bind *WinRingBind) Close() error { func (bind *WinRingBind) Close() error {
@@ -349,8 +350,12 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
} }
bind.rx.mu.Lock() bind.rx.mu.Lock()
defer bind.rx.mu.Unlock() defer bind.rx.mu.Unlock()
var err error
var count uint32 var count uint32
var results [1]winrio.Result var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ { for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 { if tries > 0 {
if atomic.LoadUint32(isOpen) != 1 { if atomic.LoadUint32(isOpen) != 1 {
@@ -361,7 +366,7 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
count = winrio.DequeueCompletion(bind.rx.cq, results[:]) count = winrio.DequeueCompletion(bind.rx.cq, results[:])
} }
if count == 0 { if count == 0 {
err := winrio.Notify(bind.rx.cq) err = winrio.Notify(bind.rx.cq)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
@@ -382,10 +387,19 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
} }
} }
bind.rx.Return(1) bind.rx.Return(1)
err := bind.InsertReceiveRequest() err = bind.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if atomic.LoadUint32(isOpen) != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 { if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status) return 0, nil, windows.Errno(results[0].Status)
} }
@@ -395,13 +409,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
return n, &ep, nil return n, &ep, nil
} }
func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen) return bind.v4.Receive(buf, &bind.isOpen)
} }
func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen) return bind.v6.Receive(buf, &bind.isOpen)
@@ -482,6 +496,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
} }
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := bind.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -500,6 +516,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
} }
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := bind.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -561,21 +579,3 @@ func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error
const IPV6_UNICAST_IF = 31 const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
} }
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil } func (c ChannelEndpoint) SrcIP() net.IP { return nil }
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) { func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool) c.closeSignal = make(chan bool)
fns = append(fns, c.makeReceiveFunc(*c.rx4))
fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 { if rand.Uint32()&1 == 0 {
return uint16(c.source4), nil return fns, uint16(c.source4), nil
} else { } else {
return uint16(c.source6), nil return fns, uint16(c.source6), nil
} }
} }
@@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error {
func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) { func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
select { return func(b []byte) (n int, ep conn.Endpoint, err error) {
case <-c.closeSignal: select {
return 0, nil, net.ErrClosed case <-c.closeSignal:
case rx := <-*c.rx6: return 0, nil, net.ErrClosed
return copy(b, rx), c.target6, nil case rx := <-ch:
} return copy(b, rx), c.target6, nil
} }
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case <-c.closeSignal:
return 0, nil, net.ErrClosed
case rx := <-*c.rx4:
return copy(b, rx), c.target4, nil
} }
} }

View File

@@ -8,10 +8,18 @@ package conn
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"reflect"
"runtime"
"strings" "strings"
) )
// A ReceiveFunc receives a single inbound packet from the network.
// It writes the data into b. n is the length of the packet.
// ep is the remote endpoint.
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
// //
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
@@ -19,23 +27,17 @@ import (
type Bind interface { type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual // Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection. // port that it bound to. Passing zero results in a random selection.
Open(port uint16) (actualPort uint16, err error) // fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener. // Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error Close() error
// SetMark sets the mark for each packet sent through this Bind. // SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK. // This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error SetMark(mark uint32) error
// ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read,
// n, the packet source address ep, and any error.
ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
// ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
// n, the packet source address ep, and any error.
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
// Send writes a packet b to address ep. // Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error Send(b []byte, ep Endpoint) error
@@ -70,6 +72,54 @@ type Endpoint interface {
SrcIP() net.IP SrcIP() net.IP
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
@@ -99,8 +149,3 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
} }
return addr, err return addr, err
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)

View File

@@ -1,4 +1,4 @@
// +build !linux,!windows //go:build !linux && !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build !linux,!openbsd,!freebsd //go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build linux openbsd freebsd //go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -118,9 +118,17 @@ func Initialize() bool {
if err != nil { if err != nil {
return return
} }
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported. // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
cq, err = CreatePolledCompletionQueue(2) var iocp windows.Handle
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return
}
defer windows.CloseHandle(iocp)
var overlapped windows.Overlapped
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
if err != nil { if err != nil {
return return
} }
@@ -161,6 +169,7 @@ func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintpt
notificationCompletion := &iocpNotificationCompletion{ notificationCompletion := &iocpNotificationCompletion{
completionType: iocpCompletion, completionType: iocpCompletion,
iocp: iocp, iocp: iocp,
key: key,
overlapped: overlapped, overlapped: overlapped,
} }
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)

View File

@@ -14,14 +14,20 @@ import (
"unsafe" "unsafe"
) )
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct { type trieEntry struct {
child [2]*trieEntry peer *Peer
peer *Peer child [2]*trieEntry
bits net.IP parent parentIndirection
cidr uint cidr uint8
bit_at_byte uint bitAtByte uint8
bit_at_shift uint bitAtShift uint8
perPeerElem *list.Element bits net.IP
perPeerElem *list.Element
} }
func isLittleEndian() bool { func isLittleEndian() bool {
@@ -45,24 +51,24 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i) return bits.ReverseBytes64(i)
} }
func commonBits(ip1 net.IP, ip2 net.IP) uint { func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0])) a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0])) b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b x := *a ^ *b
return uint(bits.LeadingZeros32(swapU32(x))) return uint8(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len { } else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0])) a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0])) b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b x := *a ^ *b
if x != 0 { if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x))) return uint8(bits.LeadingZeros64(swapU64(x)))
} }
a = (*uint64)(unsafe.Pointer(&ip1[8])) a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = (*uint64)(unsafe.Pointer(&ip2[8])) b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = *a ^ *b x = *a ^ *b
return 64 + uint(bits.LeadingZeros64(swapU64(x))) return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
} else { } else {
panic("Wrong size bit string") panic("Wrong size bit string")
} }
@@ -79,32 +85,8 @@ func (node *trieEntry) removeFromPeerEntries() {
} }
} }
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte { func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
} }
func (node *trieEntry) maskSelf() { func (node *trieEntry) maskSelf() {
@@ -114,86 +96,125 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
// at leaf func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
if node == nil { func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{ node := &trieEntry{
bits: ip, peer: peer,
peer: peer, parent: trie,
cidr: cidr, bits: ip,
bit_at_byte: cidr / 8, cidr: cidr,
bit_at_shift: 7 - (cidr % 8), bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
} }
node.maskSelf() node.maskSelf()
node.addToPeerEntries() node.addToPeerEntries()
return node *trie.parentBit = node
return
} }
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
// traverse deeper if exact {
node.removeFromPeerEntries()
common := commonBits(node.bits, ip) node.peer = peer
if node.cidr <= cidr && common >= node.cidr { node.addToPeerEntries()
if node.cidr == cidr { return
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
} }
// split node
newNode := &trieEntry{ newNode := &trieEntry{
bits: ip, peer: peer,
peer: peer, bits: ip,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
cidr = min(cidr, common) var down *trieEntry
if node == nil {
// check for shorter prefix down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(node.bits) bit := newNode.choose(down.bits)
newNode.child[bit] = node down.parent = parentIndirection{&newNode.child[bit], bit}
return newNode newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
} }
// create new parent for node & newNode node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
parent := &trieEntry{ cidr: cidr,
bits: append([]byte{}, ip...), bitAtByte: cidr / 8,
peer: nil, bitAtShift: 7 - (cidr % 8),
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
} }
parent.maskSelf() node.maskSelf()
bit := parent.choose(ip) bit := node.choose(down.bits)
parent.child[bit] = newNode down.parent = parentIndirection{&node.child[bit], bit}
parent.child[bit^1] = node node.child[bit] = down
bit = node.choose(newNode.bits)
return parent newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
} }
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer var found *Peer
size := uint(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bit_at_byte == size { if node.bitAtByte == size {
break break
} }
bit := node.choose(ip) bit := node.choose(ip)
@@ -208,7 +229,7 @@ type AllowedIPs struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
@@ -224,32 +245,67 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer) var next *list.Element
table.IPv6 = table.IPv6.removeByPeer(peer) for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
switch len(ip) { switch len(ip) {
case net.IPv6len: case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer) parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len: case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer) parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default: default:
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
} }
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { func (table *AllowedIPs) Lookup(address []byte) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.IPv4.lookup(address) switch len(address) {
} case net.IPv6len:
return table.IPv6.lookup(address)
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { case net.IPv4len:
table.mutex.RLock() return table.IPv4.lookup(address)
defer table.mutex.RUnlock() default:
return table.IPv6.lookup(address) panic(errors.New("looking up unknown address type"))
}
} }

View File

@@ -7,19 +7,21 @@ package device
import ( import (
"math/rand" "math/rand"
"net"
"sort" "sort"
"testing" "testing"
) )
const ( const (
NumberOfPeers = 100 NumberOfPeers = 100
NumberOfAddresses = 250 NumberOfPeerRemovals = 4
NumberOfTests = 10000 NumberOfAddresses = 250
NumberOfTests = 10000
) )
type SlowNode struct { type SlowNode struct {
peer *Peer peer *Peer
cidr uint cidr uint8
bits []byte bits []byte
} }
@@ -37,7 +39,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
@@ -64,68 +66,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil return nil
} }
func TestTrieRandomIPv4(t *testing.T) { func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
var trie *trieEntry n := 0
var slow SlowRouter for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n++ { for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr4 [4]byte
rand.Read(addr[:]) rand.Read(addr4[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Intn(32) + 1)
index := rand.Int() % NumberOfPeers index := rand.Intn(NumberOfPeers)
trie = trie.insert(addr[:], cidr, peers[index]) allowedIPs.Insert(addr4[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }
for n := 0; n < NumberOfTests; n++ { var p int
var addr [AddressLength]byte for p = 0; ; p++ {
rand.Read(addr[:]) for n := 0; n < NumberOfTests; n++ {
peer1 := slow.Lookup(addr[:]) var addr4 [4]byte
peer2 := trie.lookup(addr[:]) rand.Read(addr4[:])
if peer1 != peer2 { peer1 := slow4.Lookup(addr4[:])
t.Error("Trie did not match naive implementation, for:", addr) peer2 := allowedIPs.Lookup(addr4[:])
} if peer1 != peer2 {
} t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
} }
func TestTrieRandomIPv6(t *testing.T) { var addr6 [16]byte
var trie *trieEntry rand.Read(addr6[:])
var slow SlowRouter peer1 = slow6.Lookup(addr6[:])
var peers []*Peer peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
rand.Seed(1) t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
const AddressLength = 16
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
} }
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
} }
} }

View File

@@ -11,13 +11,10 @@ import (
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint match uint8
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
@@ -45,6 +42,7 @@ func TestCommonBits(t *testing.T) {
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
@@ -57,9 +55,9 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) root.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@@ -97,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@@ -153,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@@ -161,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
trie = nil allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@@ -184,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@@ -195,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer) allowedIPs.Insert(addr, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -210,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := trie.lookup(addr) p := allowedIPs.Lookup(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }

View File

@@ -35,7 +35,6 @@ const (
/* Implementation constants */ /* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers MaxPeers = 1 << 16 // maximum number of configured peers
) )

View File

@@ -11,9 +11,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
@@ -213,7 +210,7 @@ func (device *Device) Down() error {
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad { if underLoad {
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
return true return true
@@ -307,9 +304,9 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.state.stopping.Wait() device.state.stopping.Wait()
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
for i := 0; i < cpus; i++ { for i := 0; i < cpus; i++ {
go device.RoutineEncryption() go device.RoutineEncryption(i + 1)
go device.RoutineDecryption() go device.RoutineDecryption(i + 1)
go device.RoutineHandshake() go device.RoutineHandshake(i + 1)
} }
device.state.stopping.Add(1) // RoutineReadFromTUN device.state.stopping.Add(1) // RoutineReadFromTUN
@@ -400,7 +397,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RUnlock() device.peers.RUnlock()
} }
func unsafeCloseBind(device *Device) error { // closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
var err error var err error
netc := &device.net netc := &device.net
if netc.netlinkCancel != nil { if netc.netlinkCancel != nil {
@@ -455,7 +454,7 @@ func (device *Device) BindUpdate() error {
defer device.net.Unlock() defer device.net.Unlock()
// close existing sockets // close existing sockets
if err := unsafeCloseBind(device); err != nil { if err := closeBindLocked(device); err != nil {
return err return err
} }
@@ -466,8 +465,9 @@ func (device *Device) BindUpdate() error {
// bind to new port // bind to new port
var err error var err error
var recvFns []conn.ReceiveFunc
netc := &device.net netc := &device.net
netc.port, err = netc.bind.Open(netc.port) recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil { if err != nil {
netc.port = 0 netc.port = 0
return err return err
@@ -499,11 +499,12 @@ func (device *Device) BindUpdate() error {
device.peers.RUnlock() device.peers.RUnlock()
// start receiving routines // start receiving routines
device.net.stopping.Add(2) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) for _, fn := range recvFns {
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(fn)
}
device.log.Verbosef("UDP bind has been updated") device.log.Verbosef("UDP bind has been updated")
return nil return nil
@@ -511,7 +512,7 @@ func (device *Device) BindUpdate() error {
func (device *Device) BindClose() error { func (device *Device) BindClose() error {
device.net.Lock() device.net.Lock()
err := unsafeCloseBind(device) err := closeBindLocked(device)
device.net.Unlock() device.net.Unlock()
return err return err
} }

View File

@@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) {
} }
atomic.StoreInt32(&a.int32, flag) atomic.StoreInt32(&a.int32, flag)
} }
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

View File

@@ -9,8 +9,8 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.Lock()
defer peer.Unlock()
peer.disableRoaming = peer.endpoint != nil peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
} }
device.peers.RUnlock() device.peers.RUnlock()
} }

View File

@@ -7,9 +7,7 @@ package device
import ( import (
"container/list" "container/list"
"encoding/base64"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -144,12 +142,29 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
} }
func (peer *Peer) String() string { func (peer *Peer) String() string {
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) // The awful goo that follows is identical to:
abbreviatedKey := "invalid" //
if len(base64Key) == 44 { // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
//
// except that it is considerably more efficient.
src := peer.handshake.remoteStatic
b64 := func(input byte) byte {
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
} }
return fmt.Sprintf("peer(%s)", abbreviatedKey) b := []byte("peer(____…____)")
const first = len("peer(")
const second = len("peer(____…")
b[first+0] = b64((src[0] >> 2) & 63)
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
b[first+3] = b64(src[2] & 63)
b[second+0] = b64(src[29] & 63)
b[second+1] = b64((src[30] >> 2) & 63)
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
b[second+3] = b64((src[31] << 2) & 63)
return string(b)
} }
func (peer *Peer) Start() { func (peer *Peer) Start() {
@@ -167,7 +182,7 @@ func (peer *Peer) Start() {
} }
device := peer.device device := peer.device
device.log.Verbosef("%v - Starting...", peer) device.log.Verbosef("%v - Starting", peer)
// reset routine state // reset routine state
peer.stopping.Wait() peer.stopping.Wait()
@@ -243,7 +258,7 @@ func (peer *Peer) Stop() {
return return
} }
peer.device.log.Verbosef("%v - Stopping...", peer) peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop() peer.timersStop()
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.

View File

@@ -1,4 +1,4 @@
// +build !android,!ios,!windows //go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build ios //go:build ios
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
@@ -7,13 +7,15 @@
package device package device
/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ // Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
const ( // them further.
QueueStagedSize = 128 var (
QueueOutboundSize = 1024 QueueStagedSize = 128
QueueInboundSize = 1024 QueueOutboundSize = 1024
QueueHandshakeSize = 1024 QueueInboundSize = 1024
MaxSegmentSize = 1700 QueueHandshakeSize = 1024
PreallocatedBuffersPerPool = 1024 PreallocatedBuffersPerPool uint32 = 1024
) )
const MaxSegmentSize = 1700

View File

@@ -1,4 +1,4 @@
//+build !race //go:build !race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
//+build race //go:build race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -68,15 +68,16 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done() device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed // receive datagrams until conn is closed
@@ -90,24 +91,21 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
) )
for { for {
switch IP { size, endpoint, err = recv(buffer[:])
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
if err != nil { if err != nil {
device.PutMessageBuffer(buffer) device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
} }
device.log.Errorf("Failed to receive packet: %v", err) device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return
}
if deathSpiral < 10 { if deathSpiral < 10 {
deathSpiral++ deathSpiral++
time.Sleep(time.Second / 3) time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue continue
} }
return return
@@ -205,11 +203,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
} }
} }
func (device *Device) RoutineDecryption() { func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker - stopped") defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker - started") device.log.Verbosef("Routine: decryption worker %d - started", id)
for elem := range device.queue.decryption.c { for elem := range device.queue.decryption.c {
// split message into fields // split message into fields
@@ -236,12 +234,12 @@ func (device *Device) RoutineDecryption() {
/* Handles incoming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake() { func (device *Device) RoutineHandshake(id int) {
defer func() { defer func() {
device.log.Verbosef("Routine: handshake worker - stopped") device.log.Verbosef("Routine: handshake worker %d - stopped", id)
device.queue.encryption.wg.Done() device.queue.encryption.wg.Done()
}() }()
device.log.Verbosef("Routine: handshake worker - started") device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
@@ -449,7 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip goto skip
} }
@@ -466,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip goto skip
} }

View File

@@ -8,7 +8,9 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"os"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -227,7 +229,9 @@ func (device *Device) RoutineReadFromTUN() {
if err != nil { if err != nil {
if !device.isClosed() { if !device.isClosed() {
device.log.Errorf("Failed to read packet from TUN device: %v", err) if !errors.Is(err, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", err)
}
go device.Close() go device.Close()
} }
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
@@ -250,14 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
continue continue
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst) peer = device.allowedips.Lookup(dst)
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst) peer = device.allowedips.Lookup(dst)
default: default:
device.log.Verbosef("Received packet with unknown IP version") device.log.Verbosef("Received packet with unknown IP version")
@@ -362,12 +366,12 @@ func calculatePaddingSize(packetSize, mtu int) int {
* *
* Obs. One instance per core * Obs. One instance per core
*/ */
func (device *Device) RoutineEncryption() { func (device *Device) RoutineEncryption(id int) {
var paddingZeros [PaddingMultiple]byte var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: encryption worker - stopped") defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker - started") device.log.Verbosef("Routine: encryption worker %d - started", id)
for elem := range device.queue.encryption.c { for elem := range device.queue.encryption.c {
// populate header fields // populate header fields

View File

@@ -1,4 +1,4 @@
// +build !linux //go:build !linux
package device package device

View File

@@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr) sendf("allowed_ip=%s/%d", ip.String(), cidr)
return true return true
}) })
@@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return nil return nil
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer.Peer) device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {

51
format_test.go Normal file
View File

@@ -0,0 +1,51 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"go/format"
"io/fs"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestFormatting(t *testing.T) {
var wg sync.WaitGroup
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Errorf("unable to walk %s: %v", path, err)
return nil
}
if d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
wg.Add(1)
go func(path string) {
defer wg.Done()
src, err := os.ReadFile(path)
if err != nil {
t.Errorf("unable to read %s: %v", path, err)
return
}
if runtime.GOOS == "windows" {
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
}
formatted, err := format.Source(src)
if err != nil {
t.Errorf("unable to format %s: %v", path, err)
return
}
if !bytes.Equal(src, formatted) {
t.Errorf("unformatted code: %s", path)
}
}(path)
return nil
})
wg.Wait()
}

8
go.mod
View File

@@ -1,9 +1,9 @@
module golang.zx2c4.com/wireguard module golang.zx2c4.com/wireguard
go 1.16 go 1.17
require ( require (
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/net v0.0.0-20210927181540-4e4d966f7476
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6
) )

20
go.sum
View File

@@ -1,16 +1,14 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/net v0.0.0-20210927181540-4e4d966f7476 h1:s5hu7bTnLKswvidgtqc4GwsW83m9LZu8UAqzmWOZtI4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 h1:fpeMGRM6A+XFcw4RPCO8s8hH7ppgrGR22pSIjwM7YUI= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -1,4 +1,4 @@
// +build darwin freebsd openbsd //go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build linux darwin freebsd openbsd //go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -54,8 +54,7 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() { func init() {
var err error var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */ UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -1,4 +1,4 @@
// +build windows //go:build windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build windows //go:build windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build windows //go:build windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

31
main.go
View File

@@ -1,4 +1,4 @@
// +build !windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
@@ -33,25 +33,28 @@ const (
) )
func printUsage() { func printUsage() {
fmt.Printf("usage:\n") fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
} }
func warning() { func warning() {
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { switch runtime.GOOS {
case "linux", "freebsd", "openbsd":
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return
}
default:
return return
} }
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐") fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │") fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │") fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │") fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
} }
func main() { func main() {

View File

@@ -1,24 +0,0 @@
// +build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
type fdSet struct {
unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.Bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
}

View File

@@ -1,4 +1,4 @@
// +build !windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
@@ -17,13 +17,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func max(a, b int) int {
if a > b {
return a
}
return b
}
type RWCancel struct { type RWCancel struct {
fd int fd int
closingReader *os.File closingReader *os.File
@@ -50,13 +43,12 @@ func RetryAfterError(err error) bool {
} }
func (rw *RWCancel) ReadyRead() bool { func (rw *RWCancel) ReadyRead() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{}
fdset.set(rw.fd) pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@@ -64,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) { if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) ReadyWrite() bool { func (rw *RWCancel) ReadyWrite() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{} pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
fdset.set(rw.fd)
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@@ -85,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) {
if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) Read(p []byte) (n int, err error) { func (rw *RWCancel) Read(p []byte) (n int, err error) {
@@ -98,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyRead() { if !rw.ReadyRead() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }
@@ -110,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyWrite() { if !rw.ReadyWrite() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }

View File

@@ -1,15 +0,0 @@
// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
_, err := unix.Select(nfd, r, w, e, timeout)
return err
}

View File

@@ -1,13 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) {
_, err = unix.Select(nfd, r, w, e, timeout)
return
}

View File

@@ -0,0 +1,67 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"reflect"
"testing"
"unsafe"
)
func checkAlignment(t *testing.T, name string, offset uintptr) {
t.Helper()
if offset%8 != 0 {
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
}
}
// TestRateJugglerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestRateJugglerAlignment(t *testing.T) {
var r rateJuggler
typ := reflect.TypeOf(&r).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}
checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
}
// TestNativeTunAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestNativeTunAlignment(t *testing.T) {
var tun NativeTun
typ := reflect.TypeOf(&tun).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}
checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
}

View File

@@ -1,4 +1,4 @@
// +build ignore //go:build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -1,4 +1,4 @@
// +build ignore //go:build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -3,9 +3,9 @@ module golang.zx2c4.com/wireguard/tun/netstack
go 1.16 go 1.16
require ( require (
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b // indirect golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.zx2c4.com/wireguard v0.0.0-20210306154438-593658d9755b golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20210306005637-a64c3a1b5a9f gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0
) )

View File

@@ -112,7 +112,7 @@ github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1
github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/flock v0.8.0/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
@@ -344,9 +344,9 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -409,9 +409,10 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b h1:ggRgirZABFolTmi3sn6Ivd9SipZwLedQ5wR0aAKnFxU= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -421,6 +422,7 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -477,8 +479,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20210306154438-593658d9755b h1:DaVhJ+qTYKBJ4onaAbkkMnAW7xS8BD6NuG7lbloDZ0Y= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210306154438-593658d9755b/go.mod h1:39ZQQ95hUxDxT7opsWy/rtfgvXXc8s30qfZ02df69Fo= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
@@ -576,8 +578,8 @@ gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gvisor.dev/gvisor v0.0.0-20210306005637-a64c3a1b5a9f h1:pROeAG1eh1j7CxnDx6Lft4cPoXtNLhjePTlgrssJSM4= gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0 h1:Ny53d6x76f7EgvYe7pi8SQcIzdRM69y1ckQ8rRtT6ww=
gvisor.dev/gvisor v0.0.0-20210306005637-a64c3a1b5a9f/go.mod h1:TMDExUbPoyZXybG8K1+vm5ElIxI5qXUM5fwK7vI8YIg= gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0/go.mod h1:ucHEMlckp+S/YzKEpwwAyGBhAh807Wxq/8Erc6gFxCE=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@@ -74,12 +74,12 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress {
func (*endpoint) Wait() {} func (*endpoint) Wait() {}
func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views()) e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
return nil return nil
} }
func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented") panic("not implemented")
} }

View File

@@ -1,4 +1,4 @@
// +build !windows //go:build darwin || freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@@ -108,7 +108,6 @@ func CreateTUN(name string, mtu int) (Device, error) {
} }
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -117,6 +116,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
copy(ctlInfo.Name[:], []byte(utunControlName)) copy(ctlInfo.Name[:], []byte(utunControlName))
err = unix.IoctlCtlInfo(fd, ctlInfo) err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil { if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
} }
@@ -127,11 +127,13 @@ func CreateTUN(name string, mtu int) (Device, error) {
err = unix.Connect(fd, sc) err = unix.Connect(fd, sc)
if err != nil { if err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
err = syscall.SetNonblock(fd, true) err = unix.SetNonblock(fd, true)
if err != nil { if err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)

View File

@@ -6,62 +6,52 @@
package tun package tun
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const ( const (
_TUNSIFHEAD = 0x80047460 _TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e _TUNSIFMODE = 0x8004745e
_TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f _TUNSIFPID = 0x2000745f
_SIOCGIFINFO_IN6 = 0xc048696c
_SIOCSIFINFO_IN6 = 0xc048696d
_ND6_IFF_AUTO_LINKLOCAL = 0x20
_ND6_IFF_NO_DAD = 0x100
) )
// TODO: move into x/sys/unix // Iface requests with just the name
const ( type ifreqName struct {
SIOCGIFINFO_IN6 = 0xc048696c Name [unix.IFNAMSIZ]byte
SIOCSIFINFO_IN6 = 0xc048696d _ [16]byte
ND6_IFF_AUTO_LINKLOCAL = 0x20 }
ND6_IFF_NO_DAD = 0x100
)
// Iface status string max len // Iface requests with a pointer
const _IFSTATMAX = 800 type ifreqPtr struct {
const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
// structure for iface requests with a pointer
type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
Pad0 [16 - SIZEOF_UINTPTR]byte _ [16 - unsafe.Sizeof(uintptr(0))]byte
} }
// Structure for iface mtu get/set ioctls // Iface requests with MTU
type ifreq_mtu struct { type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
MTU uint32 MTU uint32
Pad0 [12]byte _ [12]byte
} }
// Structure for interface status request ioctl // ND6 flag manipulation
type ifstat struct { type nd6Req struct {
IfsName [unix.IFNAMSIZ]byte
Ascii [_IFSTATMAX]byte
}
// Structures for nd6 flag manipulation
type in6_ndireq struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Linkmtu uint32 Linkmtu uint32
Maxmtu uint32 Maxmtu uint32
@@ -99,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry: retry:
n, err := unix.Read(tun.routeSocket, data) n, err := unix.Read(tun.routeSocket, data)
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { if errors.Is(err, syscall.EINTR) {
goto retry goto retry
} }
tun.errors <- err tun.errors <- err
@@ -143,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
} }
func tunName(fd uintptr) (string, error) { func tunName(fd uintptr) (string, error) {
//Terrible hack to make up for freebsd not having a TUNGIFNAME var ifreq ifreqName
_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
//First, make sure the tun pid matches this proc's pid if err != 0 {
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(_TUNSIFPID),
uintptr(0),
)
if errno != 0 {
return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
}
// Open iface control socket
confd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return "", err return "", err
} }
return unix.ByteSliceToString(ifreq.Name[:]), nil
defer unix.Close(confd)
procPid := os.Getpid()
//Try to find interface with matching PID
for i := 1; ; i++ {
iface, _ := net.InterfaceByIndex(i)
if err != nil || iface == nil {
break
}
// Structs for getting data in and out of SIOCGIFSTATUS ioctl
var ifstatus ifstat
copy(ifstatus.IfsName[:], iface.Name)
// Make the syscall to get the status string
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCGIFSTATUS),
uintptr(unsafe.Pointer(&ifstatus)),
)
if errno != 0 {
continue
}
nullStr := ifstatus.Ascii[:]
i := bytes.IndexByte(nullStr, 0)
if i < 1 {
continue
}
statStr := string(nullStr[:i])
var pidNum int = 0
// Finally get the owning PID
// Format string taken from sys/net/if_tun.c
_, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
if err != nil {
continue
}
if pidNum == procPid {
return iface.Name, nil
}
}
return "", nil
} }
// Destroy a named system interface // Destroy a named system interface
func tunDestroy(name string) error { func tunDestroy(name string) error {
// Open control socket. fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
@@ -235,14 +151,9 @@ func tunDestroy(name string) error {
var ifr [32]byte var ifr [32]byte
copy(ifr[:], name) copy(ifr[:], name)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
} }
return nil return nil
@@ -278,104 +189,68 @@ func CreateTUN(name string, mtu int) (Device, error) {
ifheadmode := 1 ifheadmode := 1
var errno syscall.Errno var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) { tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
unix.SYS_IOCTL,
fd,
uintptr(_TUNSIFHEAD),
uintptr(unsafe.Pointer(&ifheadmode)),
)
}) })
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno) return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
} }
// Get out of PPP mode. // Get out of PTP mode.
ifflags := syscall.IFF_BROADCAST ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
tun.operateOnFd(func(fd uintptr) { tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
unix.SYS_IOCTL,
fd,
uintptr(_TUNSIFMODE),
uintptr(unsafe.Pointer(&ifflags)),
)
}) })
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFF_BROADCAST mode: %w", errno) return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
} }
// Open control sockets // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
confd, err := unix.Socket( // also because there are serious races with attaching and detaching LLv6 addresses
unix.AF_INET, // in relation to interface lifetime within the FreeBSD kernel.
unix.SOCK_DGRAM, confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, 0)
0,
)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd)
confd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, err return nil, err
} }
defer unix.Close(confd6) defer unix.Close(confd6)
var ndireq nd6Req
// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
var ndireq in6_ndireq
copy(ndireq.Name[:], assignedName) copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCGIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
} }
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCSIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
} }
if name != "" { if name != "" {
// Rename the interface confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd)
var newnp [unix.IFNAMSIZ]byte var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name) copy(newnp[:], name)
var ifr ifreq_ptr var ifr ifreqPtr
copy(ifr.Name[:], assignedName) copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCSIFNAME),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
@@ -387,13 +262,21 @@ func CreateTUN(name string, mtu int) (Device, error) {
} }
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
} }
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
})
if errno != 0 {
tun.tunFile.Close()
return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
}
name, err := tun.Name() name, err := tun.Name()
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
@@ -464,27 +347,26 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
if offset < 4 {
// reserve space for header return 0, io.ErrShortBuffer
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
} }
buf = buf[offset-4:]
// write if len(buf) < 5 {
return 0, io.ErrShortBuffer
return tun.tunFile.Write(buff) }
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return 0, unix.EAFNOSUPPORT
}
return tun.tunFile.Write(buf)
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -515,70 +397,34 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *NativeTun) setMTU(n int) error {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n) ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name) return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
} }
return nil return nil
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name) return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
} }
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
} }

View File

@@ -10,6 +10,7 @@ package tun
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@@ -329,32 +330,30 @@ func (tun *NativeTun) nameSlow() (string, error) {
return string(name), nil return string(name), nil
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
if tun.nopi { if tun.nopi {
buff = buff[offset:] buf = buf[offset:]
} else { } else {
// reserve space for header // reserve space for header
buf = buf[offset-4:]
buff = buff[offset-4:]
// add packet information header // add packet information header
buf[0] = 0x00
buff[0] = 0x00 buf[1] = 0x00
buff[1] = 0x00 if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
if buff[4]>>4 == ipv6.Version { buf[3] = 0xdd
buff[2] = 0x86
buff[3] = 0xdd
} else { } else {
buff[2] = 0x08 buf[2] = 0x08
buff[3] = 0x00 buf[3] = 0x00
} }
} }
// write n, err := tun.tunFile.Write(buf)
if errors.Is(err, syscall.EBADFD) {
return tun.tunFile.Write(buff) err = os.ErrClosed
}
return n, err
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -362,22 +361,26 @@ func (tun *NativeTun) Flush() error {
return nil return nil
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
select { select {
case err := <-tun.errors: case err = <-tun.errors:
return 0, err
default: default:
if tun.nopi { if tun.nopi {
return tun.tunFile.Read(buff[offset:]) n, err = tun.tunFile.Read(buf[offset:])
} else { } else {
buff := buff[offset-4:] buff := buf[offset-4:]
n, err := tun.tunFile.Read(buff[:]) n, err = tun.tunFile.Read(buff[:])
if n < 4 { if errors.Is(err, syscall.EBADFD) {
return 0, err err = os.ErrClosed
}
if n < 4 {
n = 0
} else {
n -= 4
} }
return n - 4, err
} }
} }
return
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() chan Event {
@@ -416,6 +419,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name) nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ { if len(nameBytes) >= unix.IFNAMSIZ {
unix.Close(nfd)
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
} }
copy(ifr[:], nameBytes) copy(ifr[:], nameBytes)
@@ -428,17 +432,19 @@ func CreateTUN(name string, mtu int) (Device, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
unix.Close(nfd)
return nil, errno return nil, errno
} }
err = unix.SetNonblock(nfd, true) err = unix.SetNonblock(nfd, true)
if err != nil {
unix.Close(nfd)
return nil, err
}
// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
fd := os.NewFile(uintptr(nfd), cloneDevicePath) fd := os.NewFile(uintptr(nfd), cloneDevicePath)
if err != nil {
return nil, err
}
return CreateTUNFromFile(fd, mtu) return CreateTUNFromFile(fd, mtu)
} }

View File

@@ -8,7 +8,6 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"os" "os"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -35,18 +34,19 @@ type rateJuggler struct {
type NativeTun struct { type NativeTun struct {
wt *wintun.Adapter wt *wintun.Adapter
name string
handle windows.Handle handle windows.Handle
close bool
events chan Event
errors chan error
forcedMTU int
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
readWait windows.Handle readWait windows.Handle
events chan Event
running sync.WaitGroup
closeOnce sync.Once closeOnce sync.Once
close int32
forcedMTU int
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var WintunTunnelType = "WireGuard"
var WintunStaticRequestedGUID *windows.GUID var WintunStaticRequestedGUID *windows.GUID
//go:linkname procyield runtime.procyield //go:linkname procyield runtime.procyield
@@ -68,25 +68,10 @@ func CreateTUN(ifname string, mtu int) (Device, error) {
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// //
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.Delete(true)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err) return nil, fmt.Errorf("Error creating interface: %w", err)
} }
if rebootRequired {
log.Println("Windows indicated a reboot is required.")
}
forcedMTU := 1420 forcedMTU := 1420
if mtu > 0 { if mtu > 0 {
@@ -95,15 +80,15 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{ tun := &NativeTun{
wt: wt, wt: wt,
name: ifname,
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: forcedMTU, forcedMTU: forcedMTU,
} }
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil { if err != nil {
tun.wt.Delete(false) tun.wt.Close()
close(tun.events) close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err) return nil, fmt.Errorf("Error starting session: %w", err)
} }
@@ -112,7 +97,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
return tun.wt.Name() return tun.name, nil
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@@ -126,10 +111,12 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err error var err error
tun.closeOnce.Do(func() { tun.closeOnce.Do(func() {
tun.close = true atomic.StoreInt32(&tun.close, 1)
windows.SetEvent(tun.readWait)
tun.running.Wait()
tun.session.End() tun.session.End()
if tun.wt != nil { if tun.wt != nil {
_, err = tun.wt.Delete(false) tun.wt.Close()
} }
close(tun.events) close(tun.events)
}) })
@@ -142,22 +129,26 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) { func (tun *NativeTun) ForceMTU(mtu int) {
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu tun.forcedMTU = mtu
if update {
tun.events <- EventMTUUpdate
}
} }
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry: retry:
select { if atomic.LoadInt32(&tun.close) == 1 {
case err := <-tun.errors: return 0, os.ErrClosed
return 0, err
default:
} }
start := nanotime() start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
for { for {
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
packet, err := tun.session.ReceivePacket() packet, err := tun.session.ReceivePacket()
@@ -189,7 +180,9 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
if tun.close { tun.running.Add(1)
defer tun.running.Done()
if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
@@ -213,6 +206,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// LUID returns Windows interface instance ID. // LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 { func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
if atomic.LoadInt32(&tun.close) == 1 {
return 0
}
return tun.wt.LUID() return tun.wt.LUID()
} }

View File

@@ -1,54 +0,0 @@
// +build !load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@@ -1,61 +0,0 @@
// +build load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := windows.LoadResourceData(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

View File

@@ -10,6 +10,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
@@ -57,3 +59,40 @@ func (p *lazyProc) Addr() uintptr {
} }
return p.addr return p.addr
} }
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@@ -1,620 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type addressList struct {
next *addressList
address uintptr
}
func (head *addressList) free() {
for node := head; node != nil; node = node.next {
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
}
}
type Module struct {
headers *IMAGE_NT_HEADERS
codeBase uintptr
modules []windows.Handle
initialized bool
isDLL bool
isRelocated bool
nameExports map[string]uint16
entry uintptr
blockedMemory *addressList
}
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
return &module.headers.OptionalHeader.DataDirectory[idx]
}
func (module *Module) copySections(address uintptr, size uintptr, old_headers *IMAGE_NT_HEADERS) error {
sections := module.headers.Sections()
for i := range sections {
if sections[i].SizeOfRawData == 0 {
// Section doesn't contain data in the dll itself, but may define uninitialized data.
sectionSize := old_headers.OptionalHeader.SectionAlignment
if sectionSize == 0 {
continue
}
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sectionSize),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating section: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
var dst []byte
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
for j := range dst {
dst[j] = 0
}
continue
}
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
return errors.New("Incomplete section")
}
// Commit memory block and copy data from dll.
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sections[i].SizeOfRawData),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
memcpy(
module.codeBase+uintptr(sections[i].VirtualAddress),
address+uintptr(sections[i].PointerToRawData),
uintptr(sections[i].SizeOfRawData))
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
}
return nil
}
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
size := section.SizeOfRawData
if size != 0 {
return uintptr(size)
}
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
}
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
}
return 0
}
type sectionFinalizeData struct {
address uintptr
alignedAddress uintptr
size uintptr
characteristics uint32
last bool
}
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
if sectionData.size == 0 {
return nil
}
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
// Section is not needed any more and can safely be freed.
if sectionData.address == sectionData.alignedAddress &&
(sectionData.last ||
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
// Only allowed to decommit whole pages.
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
}
return nil
}
// determine protection flags based on characteristics
var ProtectionFlags = [8]uint32{
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
windows.PAGE_EXECUTE, // not writeable, not readable, executable
windows.PAGE_READONLY, // not writeable, readable, not executable
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
windows.PAGE_READWRITE, // writeable, readable, not executable
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
}
protect := ProtectionFlags[sectionData.characteristics>>29]
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
protect |= windows.PAGE_NOCACHE
}
// Change memory access flags.
var oldProtect uint32
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
if err != nil {
return fmt.Errorf("Error protecting memory page: %w", err)
}
return nil
}
func (module *Module) finalizeSections() error {
sections := module.headers.Sections()
imageOffset := module.headers.OptionalHeader.imageOffset()
sectionData := sectionFinalizeData{}
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionData.size = module.realSectionSize(&sections[0])
sectionData.characteristics = sections[0].Characteristics
// Loop through all sections and change access flags.
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionSize := module.realSectionSize(&sections[i])
// Combine access flags of all sections that share a page.
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
// Section shares page with previous.
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
} else {
sectionData.characteristics |= sections[i].Characteristics
}
sectionData.size = sectionAddress + sectionSize - sectionData.address
continue
}
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
sectionData.address = sectionAddress
sectionData.alignedAddress = alignedAddress
sectionData.size = sectionSize
sectionData.characteristics = sections[i].Characteristics
}
sectionData.last = true
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
return nil
}
func (module *Module) executeTLS() {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
if directory.VirtualAddress == 0 {
return
}
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
callback := tls.AddressOfCallbacks
if callback != 0 {
for {
f := *(*uintptr)(a2p(callback))
if f == 0 {
break
}
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
callback += unsafe.Sizeof(f)
}
}
}
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
if directory.Size == 0 {
return delta == 0, nil
}
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
var relInfos []uint16
unsafeSlice(
unsafe.Pointer(&relInfos),
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation.
relType := relInfo >> 12
// The lower 12 bits define the offset.
relOffset := uintptr(relInfo & 0xfff)
switch relType {
case IMAGE_REL_BASED_ABSOLUTE:
// Skip relocation.
case IMAGE_REL_BASED_LOW:
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
break
case IMAGE_REL_BASED_HIGH:
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
break
case IMAGE_REL_BASED_HIGHLOW:
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
case IMAGE_REL_BASED_DIR64:
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
case IMAGE_REL_BASED_THUMB_MOV32:
inst := *(*uint32)(a2p(dest + relOffset))
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f240 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
}
imm16 += uint32(delta) & 0xffff
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
if hiDelta != 0 {
inst = *(*uint32)(a2p(dest + relOffset + 4))
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f2c0 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
}
imm16 += hiDelta
if imm16 > 0xffff {
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
}
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
}
default:
return false, fmt.Errorf("Unsupported relocation: %v", relType)
}
}
// Advance to next relocation block.
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
}
return true, nil
}
func (module *Module) buildImportTable() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
if directory.Size == 0 {
return nil
}
module.modules = make([]windows.Handle, 0, 16)
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for importDesc.Name != 0 {
handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Error loading module: %w", err)
}
var thunkRef, funcRef *uintptr
if importDesc.OriginalFirstThunk() != 0 {
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
} else {
// No hint table.
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
}
for *thunkRef != 0 {
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
*funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef))
} else {
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
*funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0]))
}
if err != nil {
windows.FreeLibrary(handle)
return fmt.Errorf("Error getting function address: %w", err)
}
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
}
module.modules = append(module.modules, handle)
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
}
return nil
}
func (module *Module) buildNameExports() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
return errors.New("No functions exported")
}
if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name")
}
var nameRefs []uint32
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
var ordinals []uint16
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
module.nameExports = make(map[string]uint16)
for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
module.nameExports[nameArray] = ordinals[i]
}
return nil
}
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
size := uintptr(len(data))
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
}
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
}
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
}
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
}
if oldHeader.FileHeader.Machine != imageFileProcess {
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
}
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
return nil, errors.New("Unaligned section")
}
lastSectionEnd := uintptr(0)
sections := oldHeader.Sections()
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
for i := range sections {
var endOfSection uintptr
if sections[i].SizeOfRawData == 0 {
// Section without data in the DLL
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
} else {
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
}
if endOfSection > lastSectionEnd {
lastSectionEnd = endOfSection
}
}
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
return nil, errors.New("Section is not page-aligned")
}
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
defer func() {
if err != nil {
module.Free()
module = nil
}
}()
// Reserve memory for image of library.
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
// Try to allocate memory at arbitrary position.
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating code: %w", err)
return
}
}
err = module.check4GBBoundaries(alignedImageSize)
if err != nil {
err = fmt.Errorf("Error reallocating code: %w", err)
return
}
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
err = errors.New("Incomplete headers")
return
}
// Commit memory for headers.
headers, err := windows.VirtualAlloc(module.codeBase,
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating headers: %w", err)
return
}
// Copy PE header to code.
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
// Update position.
module.headers.OptionalHeader.ImageBase = module.codeBase
// Copy sections from DLL file block to new memory location.
err = module.copySections(addr, size, oldHeader)
if err != nil {
err = fmt.Errorf("Error copying sections: %w", err)
return
}
// Adjust base address of imported data.
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
if locationDelta != 0 {
module.isRelocated, err = module.performBaseRelocation(locationDelta)
if err != nil {
err = fmt.Errorf("Error relocating module: %w", err)
return
}
} else {
module.isRelocated = true
}
// Load required dlls and adjust function table of imports.
err = module.buildImportTable()
if err != nil {
err = fmt.Errorf("Error building import table: %w", err)
return
}
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
err = module.finalizeSections()
if err != nil {
err = fmt.Errorf("Error finalizing sections: %w", err)
return
}
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
// Get entry point of loaded module.
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
if module.isDLL {
// Notify library about attaching to process.
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
successful := r0 != 0
if !successful {
err = windows.ERROR_DLL_INIT_FAILED
return
}
module.initialized = true
}
}
module.buildNameExports()
return
}
// Free releases module resources and unloads it.
func (module *Module) Free() {
if module.initialized {
// Notify library about detaching from process.
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
module.initialized = false
}
if module.modules != nil {
// Free previously opened libraries.
for _, handle := range module.modules {
windows.FreeLibrary(handle)
}
module.modules = nil
}
if module.codeBase != 0 {
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
module.codeBase = 0
}
if module.blockedMemory != nil {
module.blockedMemory.free()
module.blockedMemory = nil
}
}
// ProcAddressByName returns function address by exported name.
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if module.nameExports == nil {
return 0, errors.New("No functions exported by name")
}
if idx, ok := module.nameExports[name]; ok {
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
return 0, errors.New("Function not found by name")
}
// ProcAddressByOrdinal returns function address by exported ordinal.
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if uint32(ordinal) < exports.Base {
return 0, errors.New("Ordinal number too low")
}
idx := ordinal - uint16(exports.Base)
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
func alignDown(value, alignment uintptr) uintptr {
return value & ^(alignment - 1)
}
func alignUp(value, alignment uintptr) uintptr {
return (value + alignment - 1) & ^(alignment - 1)
}
func a2p(addr uintptr) unsafe.Pointer {
return unsafe.Pointer(addr)
}
func memcpy(dst, src, size uintptr) {
var d, s []byte
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
copy(d, s)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@@ -1,16 +0,0 @@
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return 0
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
return
}

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_I386

View File

@@ -1,36 +0,0 @@
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"fmt"
"golang.org/x/sys/windows"
)
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
node := &addressList{
next: module.blockedMemory,
address: module.codeBase,
}
module.blockedMemory = node
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
}
return
}

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64

View File

@@ -1,339 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import "unsafe"
const (
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
IMAGE_OS2_SIGNATURE = 0x454E // NE
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
IMAGE_VXD_SIGNATURE = 0x454C // LE
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
)
// DOS .EXE header
type IMAGE_DOS_HEADER struct {
E_magic uint16 // Magic number
E_cblp uint16 // Bytes on last page of file
E_cp uint16 // Pages in file
E_crlc uint16 // Relocations
E_cparhdr uint16 // Size of header in paragraphs
E_minalloc uint16 // Minimum extra paragraphs needed
E_maxalloc uint16 // Maximum extra paragraphs needed
E_ss uint16 // Initial (relative) SS value
E_sp uint16 // Initial SP value
E_csum uint16 // Checksum
E_ip uint16 // Initial IP value
E_cs uint16 // Initial (relative) CS value
E_lfarlc uint16 // File address of relocation table
E_ovno uint16 // Overlay number
E_res [4]uint16 // Reserved words
E_oemid uint16 // OEM identifier (for e_oeminfo)
E_oeminfo uint16 // OEM information; e_oemid specific
E_res2 [10]uint16 // Reserved words
E_lfanew int32 // File address of new exe header
}
// File header format
type IMAGE_FILE_HEADER struct {
Machine uint16
NumberOfSections uint16
TimeDateStamp uint32
PointerToSymbolTable uint32
NumberOfSymbols uint32
SizeOfOptionalHeader uint16
Characteristics uint16
}
const (
IMAGE_SIZEOF_FILE_HEADER = 20
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
IMAGE_FILE_SYSTEM = 0x1000 // System File.
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
IMAGE_FILE_MACHINE_UNKNOWN = 0
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_AM33 = 0x01d3
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
IMAGE_FILE_MACHINE_CEF = 0x0CEF
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
IMAGE_FILE_MACHINE_CEE = 0xC0EE
)
// Directory format
type IMAGE_DATA_DIRECTORY struct {
VirtualAddress uint32
Size uint32
}
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
type IMAGE_NT_HEADERS struct {
Signature uint32
FileHeader IMAGE_FILE_HEADER
OptionalHeader IMAGE_OPTIONAL_HEADER
}
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
(uintptr)(unsafe.Pointer(ntheader)) +
unsafe.Offsetof(ntheader.OptionalHeader) +
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
}
const (
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
)
const IMAGE_SIZEOF_SHORT_NAME = 8
// Section header format
type IMAGE_SECTION_HEADER struct {
Name [IMAGE_SIZEOF_SHORT_NAME]byte
physicalAddressOrVirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLinenumbers uint32
NumberOfRelocations uint16
NumberOfLinenumbers uint16
Characteristics uint32
}
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
const (
// Section characteristics.
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
IMAGE_SCN_MEM_FARDATA = 0x00008000
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
IMAGE_SCN_MEM_16BIT = 0x00020000
IMAGE_SCN_MEM_LOCKED = 0x00040000
IMAGE_SCN_MEM_PRELOAD = 0x00080000
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
IMAGE_SCN_ALIGN_MASK = 0x00F00000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
// TLS Characteristic Flags
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
)
// Based relocation format
type IMAGE_BASE_RELOCATION struct {
VirtualAddress uint32
SizeOfBlock uint32
}
const (
IMAGE_REL_BASED_ABSOLUTE = 0
IMAGE_REL_BASED_HIGH = 1
IMAGE_REL_BASED_LOW = 2
IMAGE_REL_BASED_HIGHLOW = 3
IMAGE_REL_BASED_HIGHADJ = 4
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
IMAGE_REL_BASED_RESERVED = 6
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
IMAGE_REL_BASED_DIR64 = 10
IMAGE_REL_BASED_IA64_IMM64 = 9
IMAGE_REL_BASED_MIPS_JMPADDR = 5
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
IMAGE_REL_BASED_ARM_MOV32 = 5
IMAGE_REL_BASED_THUMB_MOV32 = 7
)
// Export Format
type IMAGE_EXPORT_DIRECTORY struct {
Characteristics uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
Name uint32
Base uint32
NumberOfFunctions uint32
NumberOfNames uint32
AddressOfFunctions uint32 // RVA from base of image
AddressOfNames uint32 // RVA from base of image
AddressOfNameOrdinals uint32 // RVA from base of image
}
type IMAGE_IMPORT_BY_NAME struct {
Hint uint16
Name [1]byte
}
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
return ordinal & 0xffff
}
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
}
// Thread Local Storage
type IMAGE_TLS_DIRECTORY struct {
StartAddressOfRawData uintptr
EndAddressOfRawData uintptr
AddressOfIndex uintptr // PDWORD
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
SizeOfZeroFill uint32
Characteristics uint32
}
type IMAGE_IMPORT_DESCRIPTOR struct {
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
TimeDateStamp uint32 // 0 if not bound,
// -1 if bound, and real date\time stamp
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
// O.W. date/time stamp of DLL bound to (Old BIND)
ForwarderChain uint32 // -1 if no forwarders
Name uint32
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
const (
DLL_PROCESS_ATTACH = 1
DLL_THREAD_ATTACH = 2
DLL_THREAD_DETACH = 3
DLL_PROCESS_DETACH = 0
)
type SYSTEM_INFO struct {
ProcessorArchitecture uint16
Reserved uint16
PageSize uint32
MinimumApplicationAddress uintptr
MaximumApplicationAddress uintptr
ActiveProcessorMask uintptr
NumberOfProcessors uint32
ProcessorType uint32
AllocationGranularity uint32
ProcessorLevel uint16
ProcessorRevision uint16
}

View File

@@ -1,45 +0,0 @@
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
BaseOfData uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000

View File

@@ -1,44 +0,0 @@
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000

View File

@@ -67,7 +67,7 @@ func (session Session) ReceivePacket() (packet []byte, err error) {
err = e1 err = e1
return return
} }
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return return
} }
@@ -81,28 +81,10 @@ func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err er
err = e1 err = e1
return return
} }
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return return
} }
func (session Session) SendPacket(packet []byte) { func (session Session) SendPacket(packet []byte) {
syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
} }
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@@ -6,7 +6,6 @@
package wintun package wintun
import ( import (
"errors"
"log" "log"
"runtime" "runtime"
"syscall" "syscall"
@@ -23,175 +22,107 @@ const (
logErr logErr
) )
const ( const AdapterNameMax = 128
PoolNameMax = 256
AdapterNameMax = 128
)
type Pool [PoolNameMax]uint16
type Adapter struct { type Adapter struct {
handle uintptr handle uintptr
} }
var ( var (
modwintun = newLazyDLL("wintun.dll", setupLogger) modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver")
procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters")
procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter")
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter")
procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver")
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
) )
type TimestampedWriter interface {
WriteWithTimestamp(p []byte, ts int64) (n int, err error)
}
func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
} else {
log.Println(windows.UTF16PtrToString(msg))
}
return 0
}
func setupLogger(dll *lazyDLL) { func setupLogger(dll *lazyDLL) {
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int { var callback uintptr
log.Println("[Wintun]", windows.UTF16PtrToString(msg)) if runtime.GOARCH == "386" || runtime.GOARCH == "arm" {
return 0 callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
}), 0, 0) return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
callback = windows.NewCallback(logMessage)
}
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
} }
func MakePool(poolName string) (pool *Pool, err error) { func closeAdapter(wintun *Adapter) {
poolName16, err := windows.UTF16FromString(poolName) syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
}
// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
// the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil { if err != nil {
return return
} }
if len(poolName16) > PoolNameMax { var tunnelType16 *uint16
err = errors.New("Pool name too long") tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
if err != nil {
return return
} }
pool = &Pool{} r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
copy(pool[:], poolName16)
return
}
func (pool *Pool) String() string {
return windows.UTF16ToString(pool[:])
}
func freeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
}
// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
// released after use.
func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
ifname16, err := windows.UTF16PtrFromString(ifname)
if err != nil {
return nil, err
}
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
} }
wintun = &Adapter{r0} wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, freeAdapter) runtime.SetFinalizer(wintun, closeAdapter)
return return
} }
// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while // OpenAdapter opens an existing Wintun adapter by name.
// requestedGUID is the GUID of the created network adapter, which then influences NLA generation func OpenAdapter(name string) (wintun *Adapter, err error) {
// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a var name16 *uint16
// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it name16, err = windows.UTF16PtrFromString(name)
// uses is completely undocumented, and so there could be minor interesting complications with its
// usage. This function returns the network adapter ID and a flag if reboot is required.
func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
var ifname16 *uint16
ifname16, err = windows.UTF16PtrFromString(ifname)
if err != nil { if err != nil {
return return
} }
var _p0 uint32 r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
rebootRequired = _p0 != 0
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
} }
wintun = &Adapter{r0} wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, freeAdapter) runtime.SetFinalizer(wintun, closeAdapter)
return return
} }
// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns // Close closes a Wintun adapter.
// a bool indicating whether a reboot is required. func (wintun *Adapter) Close() (err error) {
func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) { runtime.SetFinalizer(wintun, nil)
var _p0 uint32 r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
if forceCloseSessions {
_p0 = 1
}
var _p1 uint32
r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
rebootRequired = _p1 != 0
if r1 == 0 { if r1 == 0 {
err = e1 err = e1
} }
return return
} }
// DeleteMatchingAdapters deletes all Wintun adapters, which match // Uninstall removes the driver from the system if no drivers are currently in use.
// given criteria, and returns which ones it deleted, whether a reboot func Uninstall() (err error) {
// is required after, and which errors occurred during the process. r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
cb := func(handle uintptr, _ uintptr) int {
adapter := &Adapter{handle}
if !matches(adapter) {
return 1
}
rebootRequired2, err := adapter.Delete(forceCloseSessions)
if err != nil {
errors = append(errors, err)
return 1
}
rebootRequired = rebootRequired || rebootRequired2
return 1
}
r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
if r1 == 0 {
errors = append(errors, e1)
}
return
}
// Name returns the name of the Wintun adapter.
func (wintun *Adapter) Name() (ifname string, err error) {
var ifname16 [AdapterNameMax]uint16
r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 {
err = e1
return
}
ifname = windows.UTF16ToString(ifname16[:])
return
}
// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
// pools, also removes Wintun from the driver store, usually called by uninstallers.
func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
var _p0 uint32
r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
rebootRequired = _p0 != 0
if r1 == 0 {
err = e1
}
return
}
// SetName sets name of the Wintun adapter.
func (wintun *Adapter) SetName(ifname string) (err error) {
ifname16, err := windows.UTF16FromString(ifname)
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 { if r1 == 0 {
err = e1 err = e1
} }

View File

@@ -1,3 +1,3 @@
package main package main
const Version = "0.0.20210323" const Version = "0.0.20211016"