Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions invokeai/app/invocations/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,115 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
)


PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]


PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}


class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""

# fmt: off
type: Literal["img_resize"] = "img_resize"

# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)

resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)

image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)

return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)


class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""

# fmt: off
type: Literal["img_scale"] = "img_scale"

# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)

resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)

resize_image = image.resize(
(width, height),
resample=resample_mode,
)

image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)

return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)


class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""

Expand Down
6 changes: 4 additions & 2 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored."
"deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",
Expand Down Expand Up @@ -524,7 +526,7 @@
},
"settings": {
"models": "Models",
"displayInProgress": "Display In-Progress Images",
"displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ export const actionsDenylist = [
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
];
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress';
import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete';
import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete';
import { addInvocationErrorListener } from './listeners/socketio/invocationError';
import { addInvocationStartedListener } from './listeners/socketio/invocationStarted';
import { addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import {
addImageMetadataReceivedFulfilledListener,
Expand Down Expand Up @@ -68,6 +68,8 @@ import {
addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedPageOfImages';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';

export const listenerMiddleware = createListenerMiddleware();

Expand Down Expand Up @@ -125,8 +127,21 @@ addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
addStagingAreaImageSavedListener();

// socketio
addCommitStagingAreaImageListener();

/**
* Socket.IO Events - these handle SIO events directly and pass on internal application actions.
* We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
* actually be handled at all.
*
* For example, we don't want to respond to progress events for canceled sessions. To avoid
* duplicating the logic to determine if an event should be responded to, we handle all of that
* "is this session canceled?" logic in these listeners.
*
* The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
* action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
* action that is handled by reducers in slices.
*/
addGeneratorProgressListener();
addGraphExecutionStateCompleteListener();
addInvocationCompleteListener();
Expand All @@ -152,6 +167,9 @@ addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();

// Images
// Fetching images
addReceivedPageOfImagesFulfilledListener();
addReceivedPageOfImagesRejectedListener();

// Gallery
addImageCategoriesChangedListener();
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
import { sessionCanceled } from 'services/thunks/session';

const moduleLog = log.child({ namespace: 'canvas' });

export const addCommitStagingAreaImageListener = () => {
startAppListening({
actionCreator: commitStagingAreaImage,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const { sessionId, isProcessing } = state.system;
const canvasSessionId = action.payload;

if (!isProcessing) {
// Only need to cancel if we are processing
return;
}

if (!canvasSessionId) {
moduleLog.debug('No canvas session, skipping cancel');
return;
}

if (canvasSessionId !== sessionId) {
moduleLog.debug(
{
data: {
canvasSessionId,
sessionId,
},
},
'Canvas session does not match global session, skipping cancel'
);
return;
}

dispatch(sessionCanceled({ sessionId }));
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const addCanvasSavedToGalleryListener = () => {
effect: async (action, { dispatch, getState, take }) => {
const state = getState();

const blob = await getBaseLayerBlob(state);
const blob = await getBaseLayerBlob(state, true);

if (!blob) {
moduleLog.error('Problem getting base layer blob');
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/thunks/image';
import {
imageCategoriesChanged,
selectFilteredImagesAsArray,
} from 'features/gallery/store/imagesSlice';

const moduleLog = log.child({ namespace: 'gallery' });

export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
const filteredImagesCount = selectFilteredImagesAsArray(
getState()
).length;

if (!filteredImagesCount) {
dispatch(receivedPageOfImages());
}
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export const addReceivedPageOfImagesFulfilledListener = () => {
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
{ data: { payload: action.payload } },
`Received ${page.items.length} images`
);
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketConnected } from 'services/events/actions';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';

const moduleLog = log.child({ namespace: 'socketio' });

export const addSocketConnectedListener = () => {
export const addSocketConnectedEventListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => {
Expand All @@ -30,6 +30,9 @@ export const addSocketConnectedListener = () => {
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}

// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
},
});
};
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketDisconnected } from 'services/events/actions';
import {
socketDisconnected,
appSocketDisconnected,
} from 'services/events/actions';

const moduleLog = log.child({ namespace: 'socketio' });

export const addSocketDisconnectedListener = () => {
export const addSocketDisconnectedEventListener = () => {
startAppListening({
actionCreator: socketDisconnected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(action.payload, 'Disconnected');
// pass along the socket event as an application action
dispatch(appSocketDisconnected(action.payload));
},
});
};
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { generatorProgress } from 'services/events/actions';
import {
appSocketGeneratorProgress,
socketGeneratorProgress,
} from 'services/events/actions';

const moduleLog = log.child({ namespace: 'socketio' });

export const addGeneratorProgressListener = () => {
export const addGeneratorProgressEventListener = () => {
startAppListening({
actionCreator: generatorProgress,
actionCreator: socketGeneratorProgress,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
Expand All @@ -23,6 +26,9 @@ export const addGeneratorProgressListener = () => {
action.payload,
`Generator progress (${action.payload.data.node.type})`
);

// pass along the socket event as an application action
dispatch(appSocketGeneratorProgress(action.payload));
},
});
};
Loading