Skip to content

Commit

Permalink
fix: replaced connection map locking with LockBox
Browse files Browse the repository at this point in the history
  • Loading branch information
tegefaulkes committed Oct 5, 2022
1 parent b47c7c0 commit d12169b
Showing 1 changed file with 91 additions and 159 deletions.
250 changes: 91 additions & 159 deletions src/network/Proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import type { PromiseCancellable } from '@matrixai/async-cancellable';
import type {
Host,
Port,
Address,
ConnectionInfo,
TLSConfig,
ConnectionEstablishedCallback,
Expand All @@ -16,8 +15,7 @@ import type { ContextTimed } from '../contexts/types';
import http from 'http';
import UTP from 'utp-native';
import Logger from '@matrixai/logger';
import { Lock } from '@matrixai/async-locks';
import { withF } from '@matrixai/resources';
import { Lock, LockBox } from '@matrixai/async-locks';
import { StartStop, ready } from '@matrixai/async-init/dist/StartStop';
import { Timer } from '@matrixai/timer';
import ConnectionReverse from './ConnectionReverse';
Expand Down Expand Up @@ -48,13 +46,12 @@ class Proxy {
protected server: http.Server;
protected utpSocket: UTP;
protected tlsConfig: TLSConfig;
// TODO: replace connection lock maps with `LockBox`
protected connectionLocksForward: Map<Address, Lock> = new Map();
protected connectionLocksForward: LockBox<Lock> = new LockBox();
protected connectionsForward: ConnectionsForward = {
proxy: new Map(),
client: new Map(),
};
protected connectionLocksReverse: Map<Address, Lock> = new Map();
protected connectionLocksReverse: LockBox<Lock> = new LockBox();
protected connectionsReverse: ConnectionsReverse = {
proxy: new Map(),
reverse: new Map(),
Expand Down Expand Up @@ -333,22 +330,8 @@ class Proxy {
@context ctx: ContextTimed,
): Promise<void> {
const proxyAddress = networkUtils.buildAddress(proxyHost, proxyPort);
let lock = this.connectionLocksForward.get(proxyAddress);
if (lock == null) {
lock = new Lock();
this.connectionLocksForward.set(proxyAddress, lock);
}
await withF([lock.lock()], async () => {
try {
await this.establishConnectionForward(
nodeId,
proxyHost,
proxyPort,
ctx,
);
} finally {
this.connectionLocksForward.delete(proxyAddress);
}
await this.connectionLocksForward.withF([proxyAddress, Lock], async () => {
await this.establishConnectionForward(nodeId, proxyHost, proxyPort, ctx);
});
}

Expand All @@ -358,21 +341,12 @@ class Proxy {
proxyPort: Port,
): Promise<void> {
const proxyAddress = networkUtils.buildAddress(proxyHost, proxyPort);
let lock = this.connectionLocksForward.get(proxyAddress);
if (lock == null) {
lock = new Lock();
this.connectionLocksForward.set(proxyAddress, lock);
}
await withF([lock.lock()], async () => {
try {
const conn = this.connectionsForward.proxy.get(proxyAddress);
if (conn == null) {
return;
}
await conn.stop();
} finally {
this.connectionLocksForward.delete(proxyAddress);
await this.connectionLocksForward.withF([proxyAddress, Lock], async () => {
const conn = this.connectionsForward.proxy.get(proxyAddress);
if (conn == null) {
return;
}
await conn.stop();
});
}

Expand Down Expand Up @@ -433,92 +407,77 @@ class Proxy {
clientSocket.destroy(new networkErrors.ErrorProxyConnectAuth());
return;
}
let lock = this.connectionLocksForward.get(proxyAddress as Address);
if (lock == null) {
lock = new Lock();
this.connectionLocksForward.set(proxyAddress as Address, lock);
}
await withF([lock.lock()], async () => {
await this.connectionLocksForward.withF([proxyAddress, Lock], async () => {
const timer = new Timer({ delay: this.connConnectTime });
try {
try {
await this.connectForward(
nodeId,
proxyHost,
proxyPort,
clientSocket,
{ timer },
);
} catch (e) {
if (e instanceof networkErrors.ErrorProxyConnectInvalidUrl) {
if (!clientSocket.destroyed) {
await clientSocketEnd('HTTP/1.1 400 Bad Request\r\n' + '\r\n');
clientSocket.destroy(e);
}
return;
}
if (e instanceof networkErrors.ErrorConnectionStartTimeout) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 504 Gateway Timeout\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
await this.connectForward(nodeId, proxyHost, proxyPort, clientSocket, {
timer,
});
} catch (e) {
if (e instanceof networkErrors.ErrorProxyConnectInvalidUrl) {
if (!clientSocket.destroyed) {
await clientSocketEnd('HTTP/1.1 400 Bad Request\r\n' + '\r\n');
clientSocket.destroy(e);
}
if (e instanceof networkErrors.ErrorConnectionStart) {
if (!clientSocket.destroyed) {
await clientSocketEnd('HTTP/1.1 502 Bad Gateway\r\n' + '\r\n');
clientSocket.destroy(e);
}
return;
return;
}
if (e instanceof networkErrors.ErrorConnectionStartTimeout) {
if (!clientSocket.destroyed) {
await clientSocketEnd('HTTP/1.1 504 Gateway Timeout\r\n' + '\r\n');
clientSocket.destroy(e);
}
if (e instanceof networkErrors.ErrorCertChain) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 526 Invalid SSL Certificate\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
return;
}
if (e instanceof networkErrors.ErrorConnectionStart) {
if (!clientSocket.destroyed) {
await clientSocketEnd('HTTP/1.1 502 Bad Gateway\r\n' + '\r\n');
clientSocket.destroy(e);
}
if (e instanceof networkErrors.ErrorConnectionTimeout) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 524 A Timeout Occurred\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
return;
}
if (e instanceof networkErrors.ErrorCertChain) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 526 Invalid SSL Certificate\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
if (e instanceof networkErrors.ErrorConnection) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 500 Internal Server Error\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
return;
}
if (e instanceof networkErrors.ErrorConnectionTimeout) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 524 A Timeout Occurred\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
}
if (e instanceof networkErrors.ErrorConnection) {
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 500 Internal Server Error\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
} finally {
timer.cancel();
}
// After composing, switch off this error handler
clientSocket.off('error', handleConnectError);
await clientSocketWrite(
'HTTP/1.1 200 Connection Established\r\n' + '\r\n',
);
this.logger.info(`Handled CONNECT to ${proxyAddress}`);
if (!clientSocket.destroyed) {
await clientSocketEnd(
'HTTP/1.1 500 Internal Server Error\r\n' + '\r\n',
);
clientSocket.destroy(e);
}
return;
} finally {
this.connectionLocksForward.delete(proxyAddress as Address);
timer.cancel();
}
// After composing, switch off this error handler
clientSocket.off('error', handleConnectError);
await clientSocketWrite(
'HTTP/1.1 200 Connection Established\r\n' + '\r\n',
);
this.logger.info(`Handled CONNECT to ${proxyAddress}`);
});
};

Expand Down Expand Up @@ -623,17 +582,8 @@ class Proxy {
@context ctx: ContextTimed,
): Promise<void> {
const proxyAddress = networkUtils.buildAddress(proxyHost, proxyPort);
let lock = this.connectionLocksReverse.get(proxyAddress);
if (lock == null) {
lock = new Lock();
this.connectionLocksReverse.set(proxyAddress, lock);
}
await withF([lock.lock()], async () => {
try {
await this.establishConnectionReverse(proxyHost, proxyPort, ctx);
} finally {
this.connectionLocksReverse.delete(proxyAddress);
}
await this.connectionLocksReverse.withF([proxyAddress, Lock], async () => {
await this.establishConnectionReverse(proxyHost, proxyPort, ctx);
});
}

Expand All @@ -643,21 +593,12 @@ class Proxy {
proxyPort: Port,
): Promise<void> {
const proxyAddress = networkUtils.buildAddress(proxyHost, proxyPort);
let lock = this.connectionLocksReverse.get(proxyAddress);
if (lock == null) {
lock = new Lock();
this.connectionLocksReverse.set(proxyAddress, lock);
}
await withF([lock.lock()], async () => {
try {
const conn = this.connectionsReverse.proxy.get(proxyAddress);
if (conn == null) {
return;
}
await conn.stop();
} finally {
this.connectionLocksReverse.delete(proxyAddress);
await this.connectionLocksReverse.withF([proxyAddress, Lock], async () => {
const conn = this.connectionsReverse.proxy.get(proxyAddress);
if (conn == null) {
return;
}
await conn.stop();
});
}

Expand All @@ -668,39 +609,30 @@ class Proxy {
utpConn.remoteAddress,
utpConn.remotePort,
);
let lock = this.connectionLocksReverse.get(proxyAddress);
if (lock == null) {
lock = new Lock();
this.connectionLocksReverse.set(proxyAddress, lock);
}
await withF([lock.lock()], async () => {
await this.connectionLocksReverse.withF([proxyAddress, Lock], async () => {
this.logger.info(`Handling connection from ${proxyAddress}`);
const timer = new Timer({ delay: this.connConnectTime });
try {
this.logger.info(`Handling connection from ${proxyAddress}`);
const timer = new Timer({ delay: this.connConnectTime });
try {
await this.connectReverse(
utpConn.remoteAddress,
utpConn.remotePort,
utpConn,
{ timer },
);
} catch (e) {
if (!(e instanceof networkErrors.ErrorNetwork)) {
throw e;
}
if (!utpConn.destroyed) {
utpConn.destroy();
}
this.logger.warn(
`Failed connection from ${proxyAddress} - ${e.toString()}`,
);
} finally {
timer.cancel();
await this.connectReverse(
utpConn.remoteAddress,
utpConn.remotePort,
utpConn,
{ timer },
);
} catch (e) {
if (!(e instanceof networkErrors.ErrorNetwork)) {
throw e;
}
if (!utpConn.destroyed) {
utpConn.destroy();
}
this.logger.info(`Handled connection from ${proxyAddress}`);
this.logger.warn(
`Failed connection from ${proxyAddress} - ${e.toString()}`,
);
} finally {
this.connectionLocksReverse.delete(proxyAddress);
timer.cancel();
}
this.logger.info(`Handled connection from ${proxyAddress}`);
});
};

Expand Down

0 comments on commit d12169b

Please sign in to comment.