Skip to content

Commit

Permalink
Merge pull request #391 from Syquel/bugfix/390_token
Browse files Browse the repository at this point in the history
#390 Restrict usage of mvnd daemons to the current user by utilizing a token check
  • Loading branch information
gnodet authored Apr 28, 2021
2 parents 76eab72 + 784264c commit bbbd3a0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -415,7 +416,7 @@ private DaemonClientConnection connectToDaemon(DaemonInfo daemon,
throws DaemonException.ConnectException {
LOGGER.debug("Connecting to Daemon");
try {
DaemonConnection connection = connect(daemon.getAddress());
DaemonConnection connection = connect(daemon.getAddress(), daemon.getToken());
return new DaemonClientConnection(connection, daemon, staleAddressDetector, newDaemon, parameters);
} catch (DaemonException.ConnectException e) {
staleAddressDetector.maybeStaleAddress(e);
Expand Down Expand Up @@ -444,7 +445,7 @@ public boolean maybeStaleAddress(Exception failure) {
}
}

public DaemonConnection connect(int port) throws DaemonException.ConnectException {
public DaemonConnection connect(int port, byte[] token) throws DaemonException.ConnectException {
InetSocketAddress address = new InetSocketAddress(InetAddress.getLoopbackAddress(), port);
try {
LOGGER.debug("Trying to connect to address {}.", address);
Expand All @@ -456,6 +457,13 @@ public DaemonConnection connect(int port) throws DaemonException.ConnectExceptio
throw new DaemonException.ConnectException(String.format("Socket connected to itself on %s.", address));
}
LOGGER.debug("Connected to address {}.", socket.getRemoteSocketAddress());

ByteBuffer tokenBuffer = ByteBuffer.wrap(token);
do {
socketChannel.write(tokenBuffer);
} while (tokenBuffer.remaining() > 0);
LOGGER.debug("Exchanged token successfully");

return new DaemonConnection(socketChannel);
} catch (DaemonException.ConnectException e) {
throw e;
Expand Down
12 changes: 10 additions & 2 deletions common/src/main/java/org/mvndaemon/mvnd/common/DaemonInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,30 @@
*/
public class DaemonInfo {

public static final int TOKEN_SIZE = 16;

private final String id;
private final String javaHome;
private final String mvndHome;
private final int pid;
private final int address;
private final byte[] token;
private final String locale;
private final List<String> options;
private final DaemonState state;
private final long lastIdle;
private final long lastBusy;

public DaemonInfo(String id, String javaHome, String mavenHome,
int pid, int address,
int pid, int address, byte[] token,
String locale, List<String> options,
DaemonState state, long lastIdle, long lastBusy) {
this.id = id;
this.javaHome = javaHome;
this.mvndHome = mavenHome;
this.pid = pid;
this.address = address;
this.token = token;
this.locale = locale;
this.options = options;
this.state = state;
Expand Down Expand Up @@ -73,6 +77,10 @@ public int getAddress() {
return address;
}

public byte[] getToken() {
return token;
}

public String getLocale() {
return locale;
}
Expand Down Expand Up @@ -106,7 +114,7 @@ public DaemonInfo withState(DaemonState state) {
lb = lastBusy;
}
return new DaemonInfo(id, javaHome, mvndHome, pid, address,
locale, options, state, li, lb);
token, locale, options, state, li, lb);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ private void doUpdate(Runnable updater) {
String mavenHome = readString();
int pid = buffer.getInt();
int address = buffer.getInt();

byte[] token = new byte[DaemonInfo.TOKEN_SIZE];
buffer.get(token);

String locale = readString();
List<String> opts = new ArrayList<>();
int nbOpts = buffer.getInt();
Expand All @@ -190,8 +194,8 @@ private void doUpdate(Runnable updater) {
DaemonState state = DaemonState.values()[buffer.get()];
long lastIdle = buffer.getLong();
long lastBusy = buffer.getLong();
DaemonInfo di = new DaemonInfo(daemonId, javaHome, mavenHome, pid, address, locale, opts, state,
lastIdle, lastBusy);
DaemonInfo di = new DaemonInfo(daemonId, javaHome, mavenHome, pid, address, token, locale,
opts, state, lastIdle, lastBusy);
infosMap.putIfAbsent(di.getId(), di);
}
stopEvents.clear();
Expand All @@ -216,6 +220,7 @@ private void doUpdate(Runnable updater) {
writeString(di.getMvndHome());
buffer.putInt(di.getPid());
buffer.putInt(di.getAddress());
buffer.put(di.getToken());
writeString(di.getLocale());
buffer.putInt(di.getOptions().size());
for (String opt : di.getOptions()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void testReadWrite() throws IOException {
byte[] token = new byte[16];
new Random().nextBytes(token);
reg1.store(new DaemonInfo("12345678", "/java/home/",
"/data/reg/", 0x12345678, 7502,
"/data/reg/", 0x12345678, 7502, token,
Locale.getDefault().toLanguageTag(), Arrays.asList("-Xmx"),
DaemonState.Idle, System.currentTimeMillis(), System.currentTimeMillis()));

Expand Down
32 changes: 32 additions & 0 deletions daemon/src/main/java/org/mvndaemon/mvnd/daemon/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import java.lang.reflect.Field;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -116,6 +119,10 @@ public Server() throws IOException {
strategy = DaemonExpiration.master();
memoryStatus = new DaemonMemoryStatus(executor);

SecureRandom secureRandom = new SecureRandom();
byte[] token = new byte[DaemonInfo.TOKEN_SIZE];
secureRandom.nextBytes(token);

List<String> opts = new ArrayList<>();
Arrays.stream(Environment.values())
.filter(Environment::isDiscriminating)
Expand All @@ -130,6 +137,7 @@ public Server() throws IOException {
Environment.MVND_HOME.asString(),
DaemonRegistry.getProcessId(),
socket.socket().getLocalPort(),
token,
Locale.getDefault().toLanguageTag(),
opts,
Busy, cur, cur);
Expand Down Expand Up @@ -224,6 +232,13 @@ private void accept() {

private void client(SocketChannel socket) {
LOGGER.info("Client connected");
if (!checkToken(socket)) {
LOGGER.error("Received invalid token, dropping connection");
updateState(DaemonState.Idle);

return;
}

try (DaemonConnection connection = new DaemonConnection(socket)) {
LOGGER.info("Waiting for request");
SynchronousQueue<Message> request = new SynchronousQueue<>();
Expand All @@ -246,6 +261,23 @@ private void client(SocketChannel socket) {
}
}

private boolean checkToken(SocketChannel socket) {
byte[] token = new byte[info.getToken().length];
ByteBuffer tokenBuffer = ByteBuffer.wrap(token);

try {
do {
if (socket.read(tokenBuffer) == -1) {
break;
}
} while (tokenBuffer.remaining() > 0);
} catch (final IOException e) {
LOGGER.debug("Discarding EOFException: {}", e.toString(), e);
}

return MessageDigest.isEqual(info.getToken(), token);
}

private void expirationCheck() {
if (expirationLock.tryLock()) {
try {
Expand Down

0 comments on commit bbbd3a0

Please sign in to comment.