Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import { describe, expect, it, vi } from 'vitest';
import type { SshClientProxy } from '@main/core/ssh/lifecycle/ssh-client-proxy';
import { PortForwardService } from './port-forward-service';

function fakeProxy(): Pick<SshClientProxy, 'client' | 'isConnected'> {
return {
isConnected: true,
get client() {
return {} as SshClientProxy['client'];
},
};
}

describe('PortForwardService', () => {
it('deduplicates opens by id and closes the tunnel once', async () => {
const close = vi.fn();
const service = new PortForwardService({
openTunnel: vi.fn(async () => ({ localPort: 6100, close })),
});

const first = await service.open({
id: 'forward-1',
projectId: 'project-1',
workspaceId: 'workspace-1',
connectionId: 'ssh-1',
proxy: fakeProxy(),
remotePort: 5173,
});
const second = await service.open({
id: 'forward-1',
projectId: 'project-1',
workspaceId: 'workspace-1',
connectionId: 'ssh-1',
proxy: fakeProxy(),
remotePort: 5173,
});

expect(second).toEqual(first);

await service.stop('forward-1');
await service.stop('forward-1');

expect(close).toHaveBeenCalledTimes(1);
});

it('stops only tunnels owned by the requested workspace', async () => {
const closeFirst = vi.fn();
const closeSecond = vi.fn();
const service = new PortForwardService({
openTunnel: vi
.fn()
.mockResolvedValueOnce({ localPort: 6100, close: closeFirst })
.mockResolvedValueOnce({ localPort: 6101, close: closeSecond }),
});

await service.open({
id: 'forward-1',
projectId: 'project-1',
workspaceId: 'workspace-1',
connectionId: 'ssh-1',
proxy: fakeProxy(),
remotePort: 5173,
});
await service.open({
id: 'forward-2',
projectId: 'project-1',
workspaceId: 'workspace-2',
connectionId: 'ssh-1',
proxy: fakeProxy(),
remotePort: 5174,
});

await service.stopForWorkspace('project-1', 'workspace-1');

expect(closeFirst).toHaveBeenCalledTimes(1);
expect(closeSecond).not.toHaveBeenCalled();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import type { SshClientProxy } from '@main/core/ssh/lifecycle/ssh-client-proxy';
import {
openPortForwardTunnel,
type OpenPortForwardTunnelOptions,
type PortForwardTunnel,
} from './port-forward-tunnel';

export type OpenPortForwardRequest = {
id: string;
projectId: string;
workspaceId: string;
connectionId: string;
proxy: Pick<SshClientProxy, 'client' | 'isConnected'>;
remotePort: number;
preferredLocalPort?: number;
};

export type PortForwardRecord = {
id: string;
projectId: string;
workspaceId: string;
connectionId: string;
remotePort: number;
localPort: number;
};

type PortForwardEntry = PortForwardRecord & {
tunnel: PortForwardTunnel;
};

export type PortForwardConnectionErrorHandler = (id: string, error: Error) => void;

export class PortForwardService {
private readonly tunnels = new Map<string, PortForwardEntry>();
private readonly openTunnel: (
request: OpenPortForwardTunnelOptions
) => Promise<PortForwardTunnel>;
private readonly onTunnelClosed?: (id: string) => void;
private readonly connectionErrorHandlers = new Set<PortForwardConnectionErrorHandler>();

constructor(
options: {
openTunnel?: (request: OpenPortForwardTunnelOptions) => Promise<PortForwardTunnel>;
onTunnelClosed?: (id: string) => void;
onConnectionError?: PortForwardConnectionErrorHandler;
} = {}
) {
this.openTunnel = options.openTunnel ?? openPortForwardTunnel;
this.onTunnelClosed = options.onTunnelClosed;
if (options.onConnectionError) {
this.connectionErrorHandlers.add(options.onConnectionError);
}
}

onConnectionError(handler: PortForwardConnectionErrorHandler): () => void {
this.connectionErrorHandlers.add(handler);
return () => this.connectionErrorHandlers.delete(handler);
}

async open(request: OpenPortForwardRequest): Promise<PortForwardRecord> {
const existing = this.tunnels.get(request.id);
if (existing) return toRecord(existing);

const tunnel = await this.openTunnel({
proxy: request.proxy,
remotePort: request.remotePort,
preferredLocalPort: request.preferredLocalPort,
onConnectionError: (error) => this.emitConnectionError(request.id, error),
});
const entry: PortForwardEntry = {
id: request.id,
projectId: request.projectId,
workspaceId: request.workspaceId,
connectionId: request.connectionId,
remotePort: request.remotePort,
localPort: tunnel.localPort,
tunnel,
};
this.tunnels.set(request.id, entry);
return toRecord(entry);
}

async stop(id: string): Promise<void> {
const entry = this.tunnels.get(id);
if (!entry) return;
this.tunnels.delete(id);
await entry.tunnel.close();
this.onTunnelClosed?.(id);
}

async stopForWorkspace(projectId: string, workspaceId: string): Promise<void> {
const ids = Array.from(this.tunnels.values())
.filter((entry) => entry.projectId === projectId && entry.workspaceId === workspaceId)
.map((entry) => entry.id);
await Promise.all(ids.map((id) => this.stop(id)));
}

async stopForProject(projectId: string): Promise<void> {
const ids = Array.from(this.tunnels.values())
.filter((entry) => entry.projectId === projectId)
.map((entry) => entry.id);
await Promise.all(ids.map((id) => this.stop(id)));
}

private emitConnectionError(id: string, error: Error): void {
for (const handler of this.connectionErrorHandlers) {
handler(id, error);
}
}
}

function toRecord(entry: PortForwardEntry): PortForwardRecord {
return {
id: entry.id,
projectId: entry.projectId,
workspaceId: entry.workspaceId,
connectionId: entry.connectionId,
remotePort: entry.remotePort,
localPort: entry.localPort,
};
}
Loading
Loading