diff --git a/packages/libp2p-daemon-client/src/index.ts b/packages/libp2p-daemon-client/src/index.ts index 200a0447..bf0963be 100644 --- a/packages/libp2p-daemon-client/src/index.ts +++ b/packages/libp2p-daemon-client/src/index.ts @@ -1,6 +1,6 @@ import { type PSMessage, Request, Response, StreamInfo } from '@libp2p/daemon-protocol' import { StreamHandler } from '@libp2p/daemon-protocol/stream-handler' -import { passThroughUpgrader } from '@libp2p/daemon-protocol/upgrader' +import { PassThroughUpgrader } from '@libp2p/daemon-protocol/upgrader' import { InvalidParametersError, isPeerId } from '@libp2p/interface' import { defaultLogger, logger } from '@libp2p/logger' import { peerIdFromMultihash } from '@libp2p/peer-id' @@ -10,7 +10,7 @@ import { pbStream, type ProtobufStream } from 'it-protobuf-stream' import * as Digest from 'multiformats/hashes/digest' import { DHT } from './dht.js' import { Pubsub } from './pubsub.js' -import type { Stream, PeerId, MultiaddrConnection, PeerInfo, Transport } from '@libp2p/interface' +import type { Stream, PeerId, MultiaddrConnection, PeerInfo, Transport, Listener } from '@libp2p/interface' import type { Multiaddr } from '@multiformats/multiaddr' import type { CID } from 'multiformats/cid' @@ -49,7 +49,7 @@ class Client implements DaemonClient { // @ts-expect-error because we use a passthrough upgrader, // this is actually a MultiaddrConnection and not a Connection return this.tcp.dial(this.multiaddr, { - upgrader: passThroughUpgrader + upgrader: new PassThroughUpgrader() }) } @@ -196,43 +196,9 @@ class Client implements DaemonClient { // open a tcp port, pipe any data from it to the handler function const listener = this.tcp.createListener({ - upgrader: passThroughUpgrader, - // @ts-expect-error because we are using a passthrough upgrader, this is a MultiaddrConnection - handler: (connection: MultiaddrConnection) => { - Promise.resolve() - .then(async () => { - const sh = new StreamHandler({ - stream: connection - }) - const message = await sh.read() - - if (message == null) { - throw new OperationFailedError('Could not read open stream response') - } - - const response = StreamInfo.decode(message) - - if (response.proto !== protocol) { - throw new OperationFailedError('Incorrect protocol') - } - - // @ts-expect-error because we are using a passthrough upgrader, this is a MultiaddrConnection - await handler(sh.rest()) - }) - .catch(err => { - connection.abort(err) - }) - .finally(() => { - connection.close() - .catch(err => { - log.error(err) - }) - listener.close() - .catch(err => { - log.error(err) - }) - }) - } + upgrader: new PassThroughUpgrader((maConn) => { + this.onConnection(protocol, listener, maConn) + }) }) await listener.listen(multiaddr('/ip4/127.0.0.1/tcp/0')) const address = listener.getAddrs()[0] @@ -257,6 +223,42 @@ class Client implements DaemonClient { throw new OperationFailedError(response.error?.msg ?? 'Register stream handler failed') } } + + private onConnection (protocol: string, listener: Listener, connection: MultiaddrConnection): void { + Promise.resolve() + .then(async () => { + const sh = new StreamHandler({ + stream: connection + }) + const message = await sh.read() + + if (message == null) { + throw new OperationFailedError('Could not read open stream response') + } + + const response = StreamInfo.decode(message) + + if (response.proto !== protocol) { + throw new OperationFailedError('Incorrect protocol') + } + + // @ts-expect-error because we are using a passthrough upgrader, this is a MultiaddrConnection + await handler(sh.rest()) + }) + .catch(err => { + connection.abort(err) + }) + .finally(() => { + connection.close() + .catch(err => { + log.error(err) + }) + listener.close() + .catch(err => { + log.error(err) + }) + }) + } } export interface IdentifyResult { diff --git a/packages/libp2p-daemon-protocol/src/upgrader.ts b/packages/libp2p-daemon-protocol/src/upgrader.ts index f04ef03b..b0396d3b 100644 --- a/packages/libp2p-daemon-protocol/src/upgrader.ts +++ b/packages/libp2p-daemon-protocol/src/upgrader.ts @@ -1,8 +1,24 @@ -import type { Upgrader } from '@libp2p/interface' +import type { Connection, MultiaddrConnection, Upgrader } from '@libp2p/interface' -export const passThroughUpgrader: Upgrader = { - // @ts-expect-error should return a connection - upgradeInbound: async maConn => maConn, - // @ts-expect-error should return a connection - upgradeOutbound: async maConn => maConn +interface OnConnection { + (conn: MultiaddrConnection): void +} + +export class PassThroughUpgrader implements Upgrader { + private readonly onConnection?: OnConnection + + constructor (handler?: OnConnection) { + this.onConnection = handler + } + + async upgradeInbound (maConn: MultiaddrConnection): Promise { + this.onConnection?.(maConn) + // @ts-expect-error should return a connection + return maConn + } + + async upgradeOutbound (maConn: MultiaddrConnection): Promise { + // @ts-expect-error should return a connection + return maConn + } } diff --git a/packages/libp2p-daemon-server/src/index.ts b/packages/libp2p-daemon-server/src/index.ts index 177afb8c..c5b6d2fc 100644 --- a/packages/libp2p-daemon-server/src/index.ts +++ b/packages/libp2p-daemon-server/src/index.ts @@ -8,7 +8,7 @@ import { PSRequest, StreamInfo } from '@libp2p/daemon-protocol' -import { passThroughUpgrader } from '@libp2p/daemon-protocol/upgrader' +import { PassThroughUpgrader } from '@libp2p/daemon-protocol/upgrader' import { defaultLogger, logger } from '@libp2p/logger' import { peerIdFromMultihash } from '@libp2p/peer-id' import { tcp } from '@libp2p/tcp' @@ -63,9 +63,7 @@ export class Server implements Libp2pServer { logger: defaultLogger() }) this.listener = this.tcp.createListener({ - // @ts-expect-error connection may actually be a maconn? - handler: this.handleConnection.bind(this), - upgrader: passThroughUpgrader + upgrader: new PassThroughUpgrader(this.handleConnection.bind(this)) }) this._onExit = this._onExit.bind(this) @@ -150,7 +148,7 @@ export class Server implements Libp2pServer { // @ts-expect-error because we use a passthrough upgrader, // this is actually a MultiaddrConnection and not a Connection conn = await this.tcp.dial(addr, { - upgrader: passThroughUpgrader + upgrader: new PassThroughUpgrader() }) const message = StreamInfo.encode({