Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any valid node with no incoming connections and with side effects to run automatically #2944

Merged
merged 8 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
"editor.formatOnSave": true,
},
"[typescript]": {
"editor.defaultFormatter": "esbenp.prettier-vscode",
"editor.defaultFormatter": "dbaeumer.vscode-eslint",
joeyballentine marked this conversation as resolved.
Show resolved Hide resolved
"editor.formatOnSave": true,
},
"[javascript]": {
Expand Down
1 change: 1 addition & 0 deletions backend/src/packages/chaiNNer_ncnn/ncnn/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
see_also=[
"chainner:ncnn:load_models",
],
side_effects=True,
)
def load_model_node(
param_path: Path, bin_path: Path
Expand Down
1 change: 1 addition & 0 deletions backend/src/packages/chaiNNer_onnx/onnx/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
see_also=[
"chainner:onnx:load_models",
],
side_effects=True,
)
def load_model_node(path: Path) -> tuple[OnnxModel, Path, str]:
assert os.path.exists(path), f"Model file at location {path} does not exist"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def parse_ckpt_state_dict(checkpoint: dict):
see_also=[
"chainner:pytorch:load_models",
],
side_effects=True,
)
def load_model_node(
context: NodeContext, path: Path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _for_ext(ext: str | Iterable[str], decoder: _Decoder) -> _Decoder:
DirectoryOutput("Directory", of_input=0),
FileNameOutput("Name", of_input=0),
],
side_effects=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joeyballentine Making these nodes side-effect nodes is major change. It completely messes up our current system of "only nodes with side effects run" and is going to cause these nodes to be executed unnecessarily.

Why did you make this change in the first place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually doesn't mess that up in practice, at least not that I found in testing. I think it's because they already run.

Anyway, I did that because they technically do have side effects. I explained this. We've even talked about similar things before: when I gave every image output a preview, you said they all had side effects. When we talked about giving every node the ability to toggle an image preview inline in the output, you said we'd need a way to toggle side effects on those nodes.

Why is that suddenly not the case now?

)
def load_image_node(path: Path) -> tuple[np.ndarray, Path, str]:
logger.debug(f"Reading image from path: {path}")
Expand Down
4 changes: 0 additions & 4 deletions src/common/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ export const topologicalSort = <T>(
return result.reverse();
};

export const isStartingNode = (schema: NodeSchema) => {
return !schema.inputs.some((i) => i.hasHandle) && schema.outputs.length > 0;
};

export const isEndingNode = (schema: NodeSchema) => {
return !schema.outputs.some((i) => i.hasHandle) && schema.inputs.length > 0;
};
Expand Down
22 changes: 6 additions & 16 deletions src/renderer/components/node/Node.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ import { useReactFlow } from 'reactflow';
import { useContext, useContextSelector } from 'use-context-selector';
import { Input, NodeData } from '../../../common/common-types';
import { DisabledStatus } from '../../../common/nodes/disabled';
import {
EMPTY_ARRAY,
getInputValue,
isStartingNode,
parseSourceHandle,
} from '../../../common/util';
import { EMPTY_ARRAY, getInputValue, parseSourceHandle } from '../../../common/util';
import { Validity } from '../../../common/Validity';
import { AlertBoxContext } from '../../contexts/AlertBoxContext';
import { BackendContext } from '../../contexts/BackendContext';
Expand All @@ -30,6 +25,7 @@ import {
} from '../../contexts/ExecutionContext';
import { GlobalContext, GlobalVolatileContext } from '../../contexts/GlobalNodeState';
import { getCategoryAccentColor, getTypeAccentColors } from '../../helpers/accentColors';

import { getSingleFileWithExtension } from '../../helpers/dataTransfer';
import { NodeState, useNodeStateFromData } from '../../helpers/nodeState';
import { NO_DISABLED, UseDisabled, useDisabled } from '../../hooks/useDisabled';
Expand Down Expand Up @@ -256,15 +252,9 @@ const NodeInner = memo(({ data, selected }: NodeProps) => {
}
};

const startingNode = isStartingNode(schema);
const isNewIterator = schema.kind === 'generator';
const hasStaticValueInput = schema.inputs.some((i) => i.kind === 'static');
const reload = useRunNode(
data,
validity.isValid && startingNode && !isNewIterator && !hasStaticValueInput
);
const { reload, isLive } = useRunNode(data, validity.isValid);
const filesToWatch = useMemo(() => {
if (!startingNode) return EMPTY_ARRAY;
if (!isLive) return EMPTY_ARRAY;

const files: string[] = [];
for (const input of schema.inputs) {
Expand All @@ -278,15 +268,15 @@ const NodeInner = memo(({ data, selected }: NodeProps) => {

if (files.length === 0) return EMPTY_ARRAY;
return files;
}, [startingNode, data.inputData, schema]);
}, [isLive, data.inputData, schema]);
useWatchFiles(filesToWatch, reload);

const disabled = useDisabled(data);
const passthrough = usePassthrough(data);
const menu = useNodeMenu(data, {
disabled,
passthrough,
reload: startingNode ? reload : undefined,
reload: isLive ? reload : undefined,
});

const toggleCollapse = useCallback(() => {
Expand Down
8 changes: 5 additions & 3 deletions src/renderer/components/node/NodeOutputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import { OutputId, OutputKind, Size } from '../../../common/common-types';
import { log } from '../../../common/log';
import { getChainnerScope } from '../../../common/types/chainner-scope';
import { ExpressionJson, fromJson } from '../../../common/types/json';
import { isStartingNode } from '../../../common/util';
import { BackendContext } from '../../contexts/BackendContext';
import { GlobalContext, GlobalVolatileContext } from '../../contexts/GlobalNodeState';
import { NodeState } from '../../helpers/nodeState';
import { useAutomaticFeatures } from '../../hooks/useAutomaticFeatures';
import { useIsCollapsedNode } from '../../hooks/useIsCollapsedNode';
import { GenericOutput } from '../outputs/GenericOutput';
import { LargeImageOutput } from '../outputs/LargeImageOutput';
Expand Down Expand Up @@ -81,14 +81,16 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => {

const currentTypes = stale ? undefined : outputDataEntry?.types;

const { isAutomatic } = useAutomaticFeatures(id, schemaId);

useEffect(() => {
if (isStartingNode(schema)) {
if (isAutomatic) {
for (const output of schema.outputs) {
const type = evalExpression(currentTypes?.[output.id]);
setManualOutputType(id, output.id, type);
}
}
}, [id, currentTypes, schema, setManualOutputType]);
}, [id, currentTypes, schema, setManualOutputType, isAutomatic]);

const isCollapsed = useIsCollapsedNode();
if (isCollapsed) {
Expand Down
32 changes: 32 additions & 0 deletions src/renderer/hooks/useAutomaticFeatures.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { getIncomers, useReactFlow } from 'reactflow';
import { useContext } from 'use-context-selector';
import { EdgeData, NodeData, SchemaId } from '../../common/common-types';
import { BackendContext } from '../contexts/BackendContext';

/**
* Determines whether a node should use automatic ahead-of-time features, such as individually running the node or determining certain type features automatically.
*/
export const useAutomaticFeatures = (id: string, schemaId: SchemaId) => {
const { schemata } = useContext(BackendContext);
const schema = schemata.get(schemaId);

const { getEdges, getNodes, getNode } = useReactFlow<NodeData, EdgeData>();
const thisNode = getNode(id);

// A node should not use automatic features if it has incoming connections
const hasIncomingConnections =
thisNode && getIncomers(thisNode, getNodes(), getEdges()).length > 0;

// If the node is a generator, it should not use automatic features
const isGenerator = schema.kind === 'generator';
// Same if it has any static input values
const hasStaticValueInput = schema.inputs.some((i) => i.kind === 'static');
// We should only use automatic features if the node has side effects
const { hasSideEffects } = schema;

return {
isAutomatic:
hasSideEffects && !hasIncomingConnections && !isGenerator && !hasStaticValueInput,
hasIncomingConnections,
};
};
17 changes: 14 additions & 3 deletions src/renderer/hooks/useRunNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { AlertBoxContext } from '../contexts/AlertBoxContext';
import { BackendContext } from '../contexts/BackendContext';
import { GlobalContext } from '../contexts/GlobalNodeState';
import { useAsyncEffect } from './useAsyncEffect';
import { useAutomaticFeatures } from './useAutomaticFeatures';
import { useSettings } from './useSettings';

/**
Expand All @@ -16,8 +17,8 @@ import { useSettings } from './useSettings';
*/
export const useRunNode = (
{ inputData, id, schemaId }: NodeData,
shouldRun: boolean
): (() => void) => {
isValid: boolean
): { reload: () => void; isLive: boolean } => {
const { sendToast } = useContext(AlertBoxContext);
const { addIndividuallyRunning, removeIndividuallyRunning } = useContext(GlobalContext);
const { schemata, backend } = useContext(BackendContext);
Expand All @@ -39,6 +40,10 @@ export const useRunNode = (
[reloadCounter, inputs]
);
const lastInputHash = useRef<string>();

const { isAutomatic, hasIncomingConnections } = useAutomaticFeatures(id, schemaId);
const shouldRun = isValid && isAutomatic;

useAsyncEffect(
() => async (token) => {
if (inputHash === lastInputHash.current) {
Expand Down Expand Up @@ -85,5 +90,11 @@ export const useRunNode = (
};
}, [backend, id]);

return reload;
useEffect(() => {
if (hasIncomingConnections && didEverRun.current) {
backend.clearNodeCacheIndividual(id).catch(log.error);
}
}, [backend, hasIncomingConnections, id]);

return { reload, isLive: shouldRun };
};
Loading