Skip to content

Commit

Permalink
refactor pageRPC for better typing
Browse files Browse the repository at this point in the history
  • Loading branch information
mondaychen committed Nov 21, 2023
1 parent 4464369 commit d5d227b
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 92 deletions.
22 changes: 10 additions & 12 deletions src/helpers/domActions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ export class DomActions {
}

private async getTaxySelector(originalId: number) {
const uniqueId = await callRPCWithTab(this.tabId, {
type: 'getUniqueElementSelectorId',
payload: [originalId],
});
const uniqueId = await callRPCWithTab(
this.tabId,
'getUniqueElementSelectorId',
[originalId]
);
return `[${TAXY_ELEMENT_SELECTOR}="${uniqueId}"]`;
}

Expand Down Expand Up @@ -75,10 +76,7 @@ export class DomActions {
y: number,
clickCount = 1
): Promise<void> {
callRPCWithTab(this.tabId, {
type: 'ripple',
payload: [x, y],
});
callRPCWithTab(this.tabId, 'ripple', [x, y]);
await this.sendCommand('Input.dispatchMouseEvent', {
type: 'mousePressed',
x,
Expand Down Expand Up @@ -221,10 +219,10 @@ export class DomActions {
}

public async attachFile(payload: { data: string; selector?: string }) {
return callRPCWithTab(this.tabId, {
type: 'attachFile',
payload: [payload.data, payload.selector || 'input[type="file"]'],
});
return callRPCWithTab(this.tabId, 'attachFile', [
payload.data,
payload.selector || 'input[type="file"]',
]);
}

public async scrollUp() {
Expand Down
41 changes: 29 additions & 12 deletions src/helpers/pageRPC.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import { sleep } from './utils';
import type { RPCDefinition } from '../pages/Content';
import type { RPCMethods } from '../pages/Content';

// Call these functions to execute code in the content script

export const callRPC = async (
message: RPCDefinition['Message'],
function sendMessage<K extends keyof RPCMethods>(
tabId: number,
method: K,
payload: Parameters<RPCMethods[K]>
): Promise<ReturnType<RPCMethods[K]>> {
// Send a message to the other world
// Ensure that the method and arguments are correct according to RpcMethods
return new Promise((resolve, reject) => {
chrome.tabs.sendMessage(tabId, { method, payload }, (response) => {
if (chrome.runtime.lastError) {
reject(chrome.runtime.lastError);
} else {
resolve(response);
}
});
});
}

export const callRPC = async <K extends keyof RPCMethods>(
method: K,
payload: Parameters<RPCMethods[K]>,
maxTries = 1
): Promise<RPCDefinition['ReturnType']> => {
): Promise<ReturnType<RPCMethods[K]>> => {
let queryOptions = { active: true, currentWindow: true };
let activeTab = (await chrome.tabs.query(queryOptions))[0];

Expand All @@ -17,21 +36,19 @@ export const callRPC = async (
}

if (!activeTab?.id) throw new Error('No active tab found');
return callRPCWithTab(activeTab.id, message, maxTries);
return callRPCWithTab(activeTab.id, method, payload, maxTries);
};

export const callRPCWithTab = async (
export const callRPCWithTab = async <K extends keyof RPCMethods>(
tabId: number,
message: RPCDefinition['Message'],
method: K,
payload: Parameters<RPCMethods[K]>,
maxTries = 1
): Promise<RPCDefinition['ReturnType']> => {
): Promise<ReturnType<RPCMethods[K]>> => {
let err: any;
for (let i = 0; i < maxTries; i++) {
try {
const response = await chrome.tabs.sendMessage(tabId, {
type: message.type,
payload: message.payload || [],
});
const response = await sendMessage(tabId, method, payload);
return response;
} catch (e) {
if (i === maxTries - 1) {
Expand Down
8 changes: 1 addition & 7 deletions src/helpers/simplifyDom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@ import { callRPC } from './pageRPC';
import { truthyFilter } from './utils';

export async function getSimplifiedDom() {
const fullDom = await callRPC(
{
type: 'getAnnotatedDOM',
payload: [],
},
3
);
const fullDom = await callRPC('getAnnotatedDOM', [], 3);
if (!fullDom || typeof fullDom !== 'string') return null;

const dom = new DOMParser().parseFromString(fullDom, 'text/html');
Expand Down
22 changes: 8 additions & 14 deletions src/pages/Background/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,15 @@ async function findActiveTab() {

async function takeScreenshot(tabId: number): Promise<string | null> {
await attachDebugger(tabId);
await callRPCWithTab(tabId, {
type: 'drawLabels',
payload: [],
});
await callRPCWithTab(tabId, 'drawLabels', []);
const screenshotData = (await chrome.debugger.sendCommand(
{ tabId: tabId },
'Page.captureScreenshot',
{
format: 'png', // or 'jpeg'
}
)) as any;
await callRPCWithTab(tabId, {
type: 'removeLabels',
payload: [],
});
await callRPCWithTab(tabId, 'removeLabels', []);
return screenshotData.data;
}

Expand Down Expand Up @@ -123,12 +117,12 @@ chrome.runtime.onMessage.addListener(async (request, sender) => {
await domActions.waitTillElementRendered(
`document.querySelector('${FINAL_MESSAGE_SELECTOR}')`
);
const message = await callRPCWithTab(chatGPTTabId, {
type: 'getDataFromRenderedMarkdown',
payload: [FINAL_MESSAGE_SELECTOR],
});
// TODO: make this more robust
if (message && typeof message === 'object' && message.codeBlocks) {
const message = await callRPCWithTab(
chatGPTTabId,
'getDataFromRenderedMarkdown',
[FINAL_MESSAGE_SELECTOR]
);
if (message && message.codeBlocks) {
const codeBlock = message.codeBlocks[0] || '{}';
try {
const action = JSON.parse(codeBlock);
Expand Down
50 changes: 19 additions & 31 deletions src/pages/Content/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { drawLabels, removeLabels } from './drawLabels';
import ripple from './ripple';
import { getDataFromRenderedMarkdown } from './reverseMarkdown';

const rpcMethods = {
export const rpcMethods = {
getAnnotatedDOM,
getUniqueElementSelectorId,
ripple,
Expand All @@ -20,43 +20,31 @@ const rpcMethods = {
} as const;

export type RPCMethods = typeof rpcMethods;
export type MethodName = keyof RPCMethods;
export type Payload<T extends MethodName> = Parameters<RPCMethods[T]>;
// export type MethodRT<T extends MethodName> = ReturnType<RPCMethods[T]>;
export type RPCDefinition = {
type MethodName = keyof RPCMethods;

type RPCMessage = {
[K in MethodName]: {
ReturnType: ReturnType<RPCMethods[K]>;
Message: {
type: K;
payload: Parameters<RPCMethods[K]>;
};
method: K;
payload: Parameters<RPCMethods[K]>;
};
}[MethodName];

const isKnownMethodName = (type: string) => {
return type in rpcMethods;
};

// This function should run in the content script
const watchForRPCRequests = () => {
chrome.runtime.onMessage.addListener(
(
message: RPCDefinition['Message'],
sender,
sendResponse
): true | undefined => {
if (!isKnownMethodName(message.type)) {
return;
}
// @ts-expect-error - we know that the payload type is valid
const resp = rpcMethods[message.type](...message.payload);
if (resp instanceof Promise) {
resp.then((resolvedResp) => {
sendResponse(resolvedResp);
});
return true;
} else {
sendResponse(resp);
(message: RPCMessage, sender, sendResponse): true | undefined => {
const { method, payload } = message;
if (method in rpcMethods) {
// @ts-expect-error - we know this is valid (see pageRPC)
const resp = rpcMethods[method as keyof RPCMethods](...payload);
if (resp instanceof Promise) {
resp.then((resolvedResp) => {
sendResponse(resolvedResp);
});
return true;
} else {
sendResponse(resp);
}
}
}
);
Expand Down
20 changes: 4 additions & 16 deletions src/state/currentTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,13 @@ export const createCurrentTaskSlice: MyStateCreator<CurrentTaskSlice> = (
useAppState.getState().settings.selectedModel ===
'gpt-4-vision-preview'
) {
await callRPCWithTab(tabId, {
type: 'drawLabels',
payload: [],
});
await callRPCWithTab(tabId, 'drawLabels', []);
const imgData = await chrome.tabs.captureVisibleTab({
format: 'jpeg',
quality: 85,
});
if (wasStopped()) break;
await callRPCWithTab(tabId, {
type: 'removeLabels',
payload: [],
});
await callRPCWithTab(tabId, 'removeLabels', []);
query = await determineNextActionWithVision(
instructions,
previousActions.filter(
Expand Down Expand Up @@ -257,15 +251,9 @@ export const createCurrentTaskSlice: MyStateCreator<CurrentTaskSlice> = (
},
prepareLabels: async () => {
const tabId = get().currentTask.tabId;
await callRPCWithTab(tabId, {
type: 'drawLabels',
payload: [],
});
await callRPCWithTab(tabId, 'drawLabels', []);
await sleep(800);
await callRPCWithTab(tabId, {
type: 'removeLabels',
payload: [],
});
await callRPCWithTab(tabId, 'removeLabels', []);
},
performActionString: async (actionString: string) => {
const action = parseResponse(actionString);
Expand Down

0 comments on commit d5d227b

Please sign in to comment.