diff --git a/LICENSE.txt b/LICENSE.txt
index 145775438..211d32e75 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,663 +1,663 @@
- GNU AFFERO GENERAL PUBLIC LICENSE
- Version 3, 19 November 2007
-
- Copyright (c) 2023 AUTOMATIC1111
-
- Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
- Everyone is permitted to copy and distribute verbatim copies
- of this license document, but changing it is not allowed.
-
- Preamble
-
- The GNU Affero General Public License is a free, copyleft license for
-software and other kinds of works, specifically designed to ensure
-cooperation with the community in the case of network server software.
-
- The licenses for most software and other practical works are designed
-to take away your freedom to share and change the works. By contrast,
-our General Public Licenses are intended to guarantee your freedom to
-share and change all versions of a program--to make sure it remains free
-software for all its users.
-
- When we speak of free software, we are referring to freedom, not
-price. Our General Public Licenses are designed to make sure that you
-have the freedom to distribute copies of free software (and charge for
-them if you wish), that you receive source code or can get it if you
-want it, that you can change the software or use pieces of it in new
-free programs, and that you know you can do these things.
-
- Developers that use our General Public Licenses protect your rights
-with two steps: (1) assert copyright on the software, and (2) offer
-you this License which gives you legal permission to copy, distribute
-and/or modify the software.
-
- A secondary benefit of defending all users' freedom is that
-improvements made in alternate versions of the program, if they
-receive widespread use, become available for other developers to
-incorporate. Many developers of free software are heartened and
-encouraged by the resulting cooperation. However, in the case of
-software used on network servers, this result may fail to come about.
-The GNU General Public License permits making a modified version and
-letting the public access it on a server without ever releasing its
-source code to the public.
-
- The GNU Affero General Public License is designed specifically to
-ensure that, in such cases, the modified source code becomes available
-to the community. It requires the operator of a network server to
-provide the source code of the modified version running there to the
-users of that server. Therefore, public use of a modified version, on
-a publicly accessible server, gives the public access to the source
-code of the modified version.
-
- An older license, called the Affero General Public License and
-published by Affero, was designed to accomplish similar goals. This is
-a different license, not a version of the Affero GPL, but Affero has
-released a new version of the Affero GPL which permits relicensing under
-this license.
-
- The precise terms and conditions for copying, distribution and
-modification follow.
-
- TERMS AND CONDITIONS
-
- 0. Definitions.
-
- "This License" refers to version 3 of the GNU Affero General Public License.
-
- "Copyright" also means copyright-like laws that apply to other kinds of
-works, such as semiconductor masks.
-
- "The Program" refers to any copyrightable work licensed under this
-License. Each licensee is addressed as "you". "Licensees" and
-"recipients" may be individuals or organizations.
-
- To "modify" a work means to copy from or adapt all or part of the work
-in a fashion requiring copyright permission, other than the making of an
-exact copy. The resulting work is called a "modified version" of the
-earlier work or a work "based on" the earlier work.
-
- A "covered work" means either the unmodified Program or a work based
-on the Program.
-
- To "propagate" a work means to do anything with it that, without
-permission, would make you directly or secondarily liable for
-infringement under applicable copyright law, except executing it on a
-computer or modifying a private copy. Propagation includes copying,
-distribution (with or without modification), making available to the
-public, and in some countries other activities as well.
-
- To "convey" a work means any kind of propagation that enables other
-parties to make or receive copies. Mere interaction with a user through
-a computer network, with no transfer of a copy, is not conveying.
-
- An interactive user interface displays "Appropriate Legal Notices"
-to the extent that it includes a convenient and prominently visible
-feature that (1) displays an appropriate copyright notice, and (2)
-tells the user that there is no warranty for the work (except to the
-extent that warranties are provided), that licensees may convey the
-work under this License, and how to view a copy of this License. If
-the interface presents a list of user commands or options, such as a
-menu, a prominent item in the list meets this criterion.
-
- 1. Source Code.
-
- The "source code" for a work means the preferred form of the work
-for making modifications to it. "Object code" means any non-source
-form of a work.
-
- A "Standard Interface" means an interface that either is an official
-standard defined by a recognized standards body, or, in the case of
-interfaces specified for a particular programming language, one that
-is widely used among developers working in that language.
-
- The "System Libraries" of an executable work include anything, other
-than the work as a whole, that (a) is included in the normal form of
-packaging a Major Component, but which is not part of that Major
-Component, and (b) serves only to enable use of the work with that
-Major Component, or to implement a Standard Interface for which an
-implementation is available to the public in source code form. A
-"Major Component", in this context, means a major essential component
-(kernel, window system, and so on) of the specific operating system
-(if any) on which the executable work runs, or a compiler used to
-produce the work, or an object code interpreter used to run it.
-
- The "Corresponding Source" for a work in object code form means all
-the source code needed to generate, install, and (for an executable
-work) run the object code and to modify the work, including scripts to
-control those activities. However, it does not include the work's
-System Libraries, or general-purpose tools or generally available free
-programs which are used unmodified in performing those activities but
-which are not part of the work. For example, Corresponding Source
-includes interface definition files associated with source files for
-the work, and the source code for shared libraries and dynamically
-linked subprograms that the work is specifically designed to require,
-such as by intimate data communication or control flow between those
-subprograms and other parts of the work.
-
- The Corresponding Source need not include anything that users
-can regenerate automatically from other parts of the Corresponding
-Source.
-
- The Corresponding Source for a work in source code form is that
-same work.
-
- 2. Basic Permissions.
-
- All rights granted under this License are granted for the term of
-copyright on the Program, and are irrevocable provided the stated
-conditions are met. This License explicitly affirms your unlimited
-permission to run the unmodified Program. The output from running a
-covered work is covered by this License only if the output, given its
-content, constitutes a covered work. This License acknowledges your
-rights of fair use or other equivalent, as provided by copyright law.
-
- You may make, run and propagate covered works that you do not
-convey, without conditions so long as your license otherwise remains
-in force. You may convey covered works to others for the sole purpose
-of having them make modifications exclusively for you, or provide you
-with facilities for running those works, provided that you comply with
-the terms of this License in conveying all material for which you do
-not control copyright. Those thus making or running the covered works
-for you must do so exclusively on your behalf, under your direction
-and control, on terms that prohibit them from making any copies of
-your copyrighted material outside their relationship with you.
-
- Conveying under any other circumstances is permitted solely under
-the conditions stated below. Sublicensing is not allowed; section 10
-makes it unnecessary.
-
- 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
-
- No covered work shall be deemed part of an effective technological
-measure under any applicable law fulfilling obligations under article
-11 of the WIPO copyright treaty adopted on 20 December 1996, or
-similar laws prohibiting or restricting circumvention of such
-measures.
-
- When you convey a covered work, you waive any legal power to forbid
-circumvention of technological measures to the extent such circumvention
-is effected by exercising rights under this License with respect to
-the covered work, and you disclaim any intention to limit operation or
-modification of the work as a means of enforcing, against the work's
-users, your or third parties' legal rights to forbid circumvention of
-technological measures.
-
- 4. Conveying Verbatim Copies.
-
- You may convey verbatim copies of the Program's source code as you
-receive it, in any medium, provided that you conspicuously and
-appropriately publish on each copy an appropriate copyright notice;
-keep intact all notices stating that this License and any
-non-permissive terms added in accord with section 7 apply to the code;
-keep intact all notices of the absence of any warranty; and give all
-recipients a copy of this License along with the Program.
-
- You may charge any price or no price for each copy that you convey,
-and you may offer support or warranty protection for a fee.
-
- 5. Conveying Modified Source Versions.
-
- You may convey a work based on the Program, or the modifications to
-produce it from the Program, in the form of source code under the
-terms of section 4, provided that you also meet all of these conditions:
-
- a) The work must carry prominent notices stating that you modified
- it, and giving a relevant date.
-
- b) The work must carry prominent notices stating that it is
- released under this License and any conditions added under section
- 7. This requirement modifies the requirement in section 4 to
- "keep intact all notices".
-
- c) You must license the entire work, as a whole, under this
- License to anyone who comes into possession of a copy. This
- License will therefore apply, along with any applicable section 7
- additional terms, to the whole of the work, and all its parts,
- regardless of how they are packaged. This License gives no
- permission to license the work in any other way, but it does not
- invalidate such permission if you have separately received it.
-
- d) If the work has interactive user interfaces, each must display
- Appropriate Legal Notices; however, if the Program has interactive
- interfaces that do not display Appropriate Legal Notices, your
- work need not make them do so.
-
- A compilation of a covered work with other separate and independent
-works, which are not by their nature extensions of the covered work,
-and which are not combined with it such as to form a larger program,
-in or on a volume of a storage or distribution medium, is called an
-"aggregate" if the compilation and its resulting copyright are not
-used to limit the access or legal rights of the compilation's users
-beyond what the individual works permit. Inclusion of a covered work
-in an aggregate does not cause this License to apply to the other
-parts of the aggregate.
-
- 6. Conveying Non-Source Forms.
-
- You may convey a covered work in object code form under the terms
-of sections 4 and 5, provided that you also convey the
-machine-readable Corresponding Source under the terms of this License,
-in one of these ways:
-
- a) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by the
- Corresponding Source fixed on a durable physical medium
- customarily used for software interchange.
-
- b) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by a
- written offer, valid for at least three years and valid for as
- long as you offer spare parts or customer support for that product
- model, to give anyone who possesses the object code either (1) a
- copy of the Corresponding Source for all the software in the
- product that is covered by this License, on a durable physical
- medium customarily used for software interchange, for a price no
- more than your reasonable cost of physically performing this
- conveying of source, or (2) access to copy the
- Corresponding Source from a network server at no charge.
-
- c) Convey individual copies of the object code with a copy of the
- written offer to provide the Corresponding Source. This
- alternative is allowed only occasionally and noncommercially, and
- only if you received the object code with such an offer, in accord
- with subsection 6b.
-
- d) Convey the object code by offering access from a designated
- place (gratis or for a charge), and offer equivalent access to the
- Corresponding Source in the same way through the same place at no
- further charge. You need not require recipients to copy the
- Corresponding Source along with the object code. If the place to
- copy the object code is a network server, the Corresponding Source
- may be on a different server (operated by you or a third party)
- that supports equivalent copying facilities, provided you maintain
- clear directions next to the object code saying where to find the
- Corresponding Source. Regardless of what server hosts the
- Corresponding Source, you remain obligated to ensure that it is
- available for as long as needed to satisfy these requirements.
-
- e) Convey the object code using peer-to-peer transmission, provided
- you inform other peers where the object code and Corresponding
- Source of the work are being offered to the general public at no
- charge under subsection 6d.
-
- A separable portion of the object code, whose source code is excluded
-from the Corresponding Source as a System Library, need not be
-included in conveying the object code work.
-
- A "User Product" is either (1) a "consumer product", which means any
-tangible personal property which is normally used for personal, family,
-or household purposes, or (2) anything designed or sold for incorporation
-into a dwelling. In determining whether a product is a consumer product,
-doubtful cases shall be resolved in favor of coverage. For a particular
-product received by a particular user, "normally used" refers to a
-typical or common use of that class of product, regardless of the status
-of the particular user or of the way in which the particular user
-actually uses, or expects or is expected to use, the product. A product
-is a consumer product regardless of whether the product has substantial
-commercial, industrial or non-consumer uses, unless such uses represent
-the only significant mode of use of the product.
-
- "Installation Information" for a User Product means any methods,
-procedures, authorization keys, or other information required to install
-and execute modified versions of a covered work in that User Product from
-a modified version of its Corresponding Source. The information must
-suffice to ensure that the continued functioning of the modified object
-code is in no case prevented or interfered with solely because
-modification has been made.
-
- If you convey an object code work under this section in, or with, or
-specifically for use in, a User Product, and the conveying occurs as
-part of a transaction in which the right of possession and use of the
-User Product is transferred to the recipient in perpetuity or for a
-fixed term (regardless of how the transaction is characterized), the
-Corresponding Source conveyed under this section must be accompanied
-by the Installation Information. But this requirement does not apply
-if neither you nor any third party retains the ability to install
-modified object code on the User Product (for example, the work has
-been installed in ROM).
-
- The requirement to provide Installation Information does not include a
-requirement to continue to provide support service, warranty, or updates
-for a work that has been modified or installed by the recipient, or for
-the User Product in which it has been modified or installed. Access to a
-network may be denied when the modification itself materially and
-adversely affects the operation of the network or violates the rules and
-protocols for communication across the network.
-
- Corresponding Source conveyed, and Installation Information provided,
-in accord with this section must be in a format that is publicly
-documented (and with an implementation available to the public in
-source code form), and must require no special password or key for
-unpacking, reading or copying.
-
- 7. Additional Terms.
-
- "Additional permissions" are terms that supplement the terms of this
-License by making exceptions from one or more of its conditions.
-Additional permissions that are applicable to the entire Program shall
-be treated as though they were included in this License, to the extent
-that they are valid under applicable law. If additional permissions
-apply only to part of the Program, that part may be used separately
-under those permissions, but the entire Program remains governed by
-this License without regard to the additional permissions.
-
- When you convey a copy of a covered work, you may at your option
-remove any additional permissions from that copy, or from any part of
-it. (Additional permissions may be written to require their own
-removal in certain cases when you modify the work.) You may place
-additional permissions on material, added by you to a covered work,
-for which you have or can give appropriate copyright permission.
-
- Notwithstanding any other provision of this License, for material you
-add to a covered work, you may (if authorized by the copyright holders of
-that material) supplement the terms of this License with terms:
-
- a) Disclaiming warranty or limiting liability differently from the
- terms of sections 15 and 16 of this License; or
-
- b) Requiring preservation of specified reasonable legal notices or
- author attributions in that material or in the Appropriate Legal
- Notices displayed by works containing it; or
-
- c) Prohibiting misrepresentation of the origin of that material, or
- requiring that modified versions of such material be marked in
- reasonable ways as different from the original version; or
-
- d) Limiting the use for publicity purposes of names of licensors or
- authors of the material; or
-
- e) Declining to grant rights under trademark law for use of some
- trade names, trademarks, or service marks; or
-
- f) Requiring indemnification of licensors and authors of that
- material by anyone who conveys the material (or modified versions of
- it) with contractual assumptions of liability to the recipient, for
- any liability that these contractual assumptions directly impose on
- those licensors and authors.
-
- All other non-permissive additional terms are considered "further
-restrictions" within the meaning of section 10. If the Program as you
-received it, or any part of it, contains a notice stating that it is
-governed by this License along with a term that is a further
-restriction, you may remove that term. If a license document contains
-a further restriction but permits relicensing or conveying under this
-License, you may add to a covered work material governed by the terms
-of that license document, provided that the further restriction does
-not survive such relicensing or conveying.
-
- If you add terms to a covered work in accord with this section, you
-must place, in the relevant source files, a statement of the
-additional terms that apply to those files, or a notice indicating
-where to find the applicable terms.
-
- Additional terms, permissive or non-permissive, may be stated in the
-form of a separately written license, or stated as exceptions;
-the above requirements apply either way.
-
- 8. Termination.
-
- You may not propagate or modify a covered work except as expressly
-provided under this License. Any attempt otherwise to propagate or
-modify it is void, and will automatically terminate your rights under
-this License (including any patent licenses granted under the third
-paragraph of section 11).
-
- However, if you cease all violation of this License, then your
-license from a particular copyright holder is reinstated (a)
-provisionally, unless and until the copyright holder explicitly and
-finally terminates your license, and (b) permanently, if the copyright
-holder fails to notify you of the violation by some reasonable means
-prior to 60 days after the cessation.
-
- Moreover, your license from a particular copyright holder is
-reinstated permanently if the copyright holder notifies you of the
-violation by some reasonable means, this is the first time you have
-received notice of violation of this License (for any work) from that
-copyright holder, and you cure the violation prior to 30 days after
-your receipt of the notice.
-
- Termination of your rights under this section does not terminate the
-licenses of parties who have received copies or rights from you under
-this License. If your rights have been terminated and not permanently
-reinstated, you do not qualify to receive new licenses for the same
-material under section 10.
-
- 9. Acceptance Not Required for Having Copies.
-
- You are not required to accept this License in order to receive or
-run a copy of the Program. Ancillary propagation of a covered work
-occurring solely as a consequence of using peer-to-peer transmission
-to receive a copy likewise does not require acceptance. However,
-nothing other than this License grants you permission to propagate or
-modify any covered work. These actions infringe copyright if you do
-not accept this License. Therefore, by modifying or propagating a
-covered work, you indicate your acceptance of this License to do so.
-
- 10. Automatic Licensing of Downstream Recipients.
-
- Each time you convey a covered work, the recipient automatically
-receives a license from the original licensors, to run, modify and
-propagate that work, subject to this License. You are not responsible
-for enforcing compliance by third parties with this License.
-
- An "entity transaction" is a transaction transferring control of an
-organization, or substantially all assets of one, or subdividing an
-organization, or merging organizations. If propagation of a covered
-work results from an entity transaction, each party to that
-transaction who receives a copy of the work also receives whatever
-licenses to the work the party's predecessor in interest had or could
-give under the previous paragraph, plus a right to possession of the
-Corresponding Source of the work from the predecessor in interest, if
-the predecessor has it or can get it with reasonable efforts.
-
- You may not impose any further restrictions on the exercise of the
-rights granted or affirmed under this License. For example, you may
-not impose a license fee, royalty, or other charge for exercise of
-rights granted under this License, and you may not initiate litigation
-(including a cross-claim or counterclaim in a lawsuit) alleging that
-any patent claim is infringed by making, using, selling, offering for
-sale, or importing the Program or any portion of it.
-
- 11. Patents.
-
- A "contributor" is a copyright holder who authorizes use under this
-License of the Program or a work on which the Program is based. The
-work thus licensed is called the contributor's "contributor version".
-
- A contributor's "essential patent claims" are all patent claims
-owned or controlled by the contributor, whether already acquired or
-hereafter acquired, that would be infringed by some manner, permitted
-by this License, of making, using, or selling its contributor version,
-but do not include claims that would be infringed only as a
-consequence of further modification of the contributor version. For
-purposes of this definition, "control" includes the right to grant
-patent sublicenses in a manner consistent with the requirements of
-this License.
-
- Each contributor grants you a non-exclusive, worldwide, royalty-free
-patent license under the contributor's essential patent claims, to
-make, use, sell, offer for sale, import and otherwise run, modify and
-propagate the contents of its contributor version.
-
- In the following three paragraphs, a "patent license" is any express
-agreement or commitment, however denominated, not to enforce a patent
-(such as an express permission to practice a patent or covenant not to
-sue for patent infringement). To "grant" such a patent license to a
-party means to make such an agreement or commitment not to enforce a
-patent against the party.
-
- If you convey a covered work, knowingly relying on a patent license,
-and the Corresponding Source of the work is not available for anyone
-to copy, free of charge and under the terms of this License, through a
-publicly available network server or other readily accessible means,
-then you must either (1) cause the Corresponding Source to be so
-available, or (2) arrange to deprive yourself of the benefit of the
-patent license for this particular work, or (3) arrange, in a manner
-consistent with the requirements of this License, to extend the patent
-license to downstream recipients. "Knowingly relying" means you have
-actual knowledge that, but for the patent license, your conveying the
-covered work in a country, or your recipient's use of the covered work
-in a country, would infringe one or more identifiable patents in that
-country that you have reason to believe are valid.
-
- If, pursuant to or in connection with a single transaction or
-arrangement, you convey, or propagate by procuring conveyance of, a
-covered work, and grant a patent license to some of the parties
-receiving the covered work authorizing them to use, propagate, modify
-or convey a specific copy of the covered work, then the patent license
-you grant is automatically extended to all recipients of the covered
-work and works based on it.
-
- A patent license is "discriminatory" if it does not include within
-the scope of its coverage, prohibits the exercise of, or is
-conditioned on the non-exercise of one or more of the rights that are
-specifically granted under this License. You may not convey a covered
-work if you are a party to an arrangement with a third party that is
-in the business of distributing software, under which you make payment
-to the third party based on the extent of your activity of conveying
-the work, and under which the third party grants, to any of the
-parties who would receive the covered work from you, a discriminatory
-patent license (a) in connection with copies of the covered work
-conveyed by you (or copies made from those copies), or (b) primarily
-for and in connection with specific products or compilations that
-contain the covered work, unless you entered into that arrangement,
-or that patent license was granted, prior to 28 March 2007.
-
- Nothing in this License shall be construed as excluding or limiting
-any implied license or other defenses to infringement that may
-otherwise be available to you under applicable patent law.
-
- 12. No Surrender of Others' Freedom.
-
- If conditions are imposed on you (whether by court order, agreement or
-otherwise) that contradict the conditions of this License, they do not
-excuse you from the conditions of this License. If you cannot convey a
-covered work so as to satisfy simultaneously your obligations under this
-License and any other pertinent obligations, then as a consequence you may
-not convey it at all. For example, if you agree to terms that obligate you
-to collect a royalty for further conveying from those to whom you convey
-the Program, the only way you could satisfy both those terms and this
-License would be to refrain entirely from conveying the Program.
-
- 13. Remote Network Interaction; Use with the GNU General Public License.
-
- Notwithstanding any other provision of this License, if you modify the
-Program, your modified version must prominently offer all users
-interacting with it remotely through a computer network (if your version
-supports such interaction) an opportunity to receive the Corresponding
-Source of your version by providing access to the Corresponding Source
-from a network server at no charge, through some standard or customary
-means of facilitating copying of software. This Corresponding Source
-shall include the Corresponding Source for any work covered by version 3
-of the GNU General Public License that is incorporated pursuant to the
-following paragraph.
-
- Notwithstanding any other provision of this License, you have
-permission to link or combine any covered work with a work licensed
-under version 3 of the GNU General Public License into a single
-combined work, and to convey the resulting work. The terms of this
-License will continue to apply to the part which is the covered work,
-but the work with which it is combined will remain governed by version
-3 of the GNU General Public License.
-
- 14. Revised Versions of this License.
-
- The Free Software Foundation may publish revised and/or new versions of
-the GNU Affero General Public License from time to time. Such new versions
-will be similar in spirit to the present version, but may differ in detail to
-address new problems or concerns.
-
- Each version is given a distinguishing version number. If the
-Program specifies that a certain numbered version of the GNU Affero General
-Public License "or any later version" applies to it, you have the
-option of following the terms and conditions either of that numbered
-version or of any later version published by the Free Software
-Foundation. If the Program does not specify a version number of the
-GNU Affero General Public License, you may choose any version ever published
-by the Free Software Foundation.
-
- If the Program specifies that a proxy can decide which future
-versions of the GNU Affero General Public License can be used, that proxy's
-public statement of acceptance of a version permanently authorizes you
-to choose that version for the Program.
-
- Later license versions may give you additional or different
-permissions. However, no additional obligations are imposed on any
-author or copyright holder as a result of your choosing to follow a
-later version.
-
- 15. Disclaimer of Warranty.
-
- THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
-APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
-HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
-OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
-THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
-IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
-ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
-
- 16. Limitation of Liability.
-
- IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
-WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
-THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
-USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
-DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
-PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
-EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
-SUCH DAMAGES.
-
- 17. Interpretation of Sections 15 and 16.
-
- If the disclaimer of warranty and limitation of liability provided
-above cannot be given local legal effect according to their terms,
-reviewing courts shall apply local law that most closely approximates
-an absolute waiver of all civil liability in connection with the
-Program, unless a warranty or assumption of liability accompanies a
-copy of the Program in return for a fee.
-
- END OF TERMS AND CONDITIONS
-
- How to Apply These Terms to Your New Programs
-
- If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
- To do so, attach the following notices to the program. It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
- <one line to give the program's name and a brief idea of what it does.>
- Copyright (C) <year>  <name of author>
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-Also add information on how to contact you by electronic and paper mail.
-
- If your software can interact with users remotely through a computer
-network, you should also make sure that it provides a way for users to
-get its source. For example, if your program is a web application, its
-interface could display a "Source" link that leads users to an archive
-of the code. There are many ways you could offer source, and different
-solutions will be better for different programs; see section 13 for the
-specific requirements.
-
- You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU AGPL, see
- <https://www.gnu.org/licenses/>.
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (c) 2023 AUTOMATIC1111
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/README.md b/README.md
index df2a01493..1b3679de2 100644
--- a/README.md
+++ b/README.md
@@ -1,210 +1,210 @@
-
-
-# SD.Next
-
-**Stable Diffusion implementation with advanced features**
-
-[![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic)
-![Last Commit](https://img.shields.io/github/last-commit/vladmandic/automatic?svg=true)
-![License](https://img.shields.io/github/license/vladmandic/automatic?svg=true)
-[![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
-
-[Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
-
-
-
-
-## Notable features
-
-Not all individual features are listed here; check the [ChangeLog](CHANGELOG.md) for a full list of changes
-- Multiple backends!
- ▹ **Original | Diffusers**
-- Multiple diffusion models!
- ▹ **Stable Diffusion 1.5/2.1 | SD-XL | LCM | Segmind | Kandinsky | Pixart-α | Würstchen | aMUSEd | DeepFloyd IF | UniDiffusion | SD-Distilled | BLiP Diffusion | etc.**
-- Built-in Control for Text, Image, Batch and video processing!
- ▹ **ControlNet | ControlNet XS | Control LLLite | T2I Adapters | IP Adapters**
-- Multiplatform!
- ▹ **Windows | Linux | MacOS with CPU | nVidia | AMD | IntelArc | DirectML | OpenVINO | ONNX+Olive**
-- Platform specific autodetection and tuning performed on install
-- Optimized processing with latest `torch` developments with built-in support for `torch.compile` and multiple compile backends
-- Improved prompt parser
-- Enhanced *Lora*/*LoCon*/*Lyco* code supporting latest trends in training
-- Built-in queue management
-- Enterprise level logging and hardened API
-- Modern localization and hints engine
-- Broad compatibility with existing extensions ecosystem and new extensions manager
-- Built in installer with automatic updates and dependency management
-- Modernized UI with theme support and a number of built-in themes *(dark and light)*
-
-
-
-![Screenshot-Dark](html/xmas-default.jpg)
-![Screenshot-Control](html/xmas-control.jpg)
-![Screenshot-Light](html/light-teal.jpg)
-
-
-
-## Backend support
-
-**SD.Next** supports two main backends: *Original* and *Diffusers*:
-
-- **Original**: Based on [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation and significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- This is the default backend and it is fully compatible with all existing functionality and extensions
- Supports **SD 1.x** and **SD 2.x** models
- All other model types such as *SD-XL, LCM, PixArt, Segmind, Kandinsky, etc.* require backend **Diffusers**
-- **Diffusers**: Based on new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation
- Supports *original* SD models as well as *all* models listed below
- See [wiki article](https://github.com/vladmandic/automatic/wiki/Diffusers) for more information
-
-## Model support
-
-Additional models will be added as they become available and there is public interest in them
-
-- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
-- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
-- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base and XT
-- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
-- [aMUSEd 256](https://huggingface.co/amused/amused-256) 256 and 512
-- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega)
-- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
-- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0*
-- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*
-- [Warp Wuerstchen](https://huggingface.co/blog/wuertschen)
-- [Playground](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024*
-- [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser)
-- [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large*
-- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b)
-- [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
-- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/)
-
-
-Also supported are modifiers such as:
-- **LCM** and **Turbo** (Adversarial Diffusion Distillation) networks
-- All **LoRA** types such as LoCon, LyCORIS, HADA, IA3, Lokr, OFT
-- **AnimateDiff** for SD 1.5
-- **IP-Adapters** for SD 1.5 and SD-XL
-
-> [!IMPORTANT]
-> - Loading any model other than standard SD 1.x / SD 2.x requires use of backend **Diffusers**
-> - Loading any other models using **Original** backend is not supported
-> - Loading manually downloaded model `.safetensors` files is supported for SD 1.x / SD 2.x / SD-XL models only
-> - For all other model types, use backend **Diffusers** and use the built-in Model downloader or
-  select a model from Networks -> Models -> Reference list, in which case it will be auto-downloaded and loaded
-
-## Platform support
-
-- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
-- *AMD* GPUs using **ROCm** libraries on *Linux*
- Support will be extended to *Windows* once AMD releases ROCm for Windows
-- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
-- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
- This includes support for AMD GPUs that are not supported by native ROCm libraries
-- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
-- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
-- *ONNX/Olive* (experimental)
-
-## Install
-
-- [Step-by-step install guide](https://github.com/vladmandic/automatic/wiki/Installation)
-- [Advanced install notes](https://github.com/vladmandic/automatic/wiki/Advanced-Install)
-- [Common installation errors](https://github.com/vladmandic/automatic/discussions/1627)
-- [FAQ](https://github.com/vladmandic/automatic/discussions/1011)
-- If you can't run us locally, try our friends at [RunDiffusion](https://rundiffusion.com?utm_source=github&utm_medium=referral&utm_campaign=SDNext)!
-
-> [!TIP]
-> - Server can run without a virtual environment,
-  but using a `VENV` is recommended to avoid library version conflicts with other applications
-> - **nVidia/CUDA** / **AMD/ROCm** / **Intel/OneAPI** are auto-detected if present and available,
- For any other use case such as **DirectML**, **ONNX/Olive**, **OpenVINO** specify required parameter explicitly
- or wrong packages may be installed as installer will assume CPU-only environment
-> - Full startup sequence is logged in `sdnext.log`,
- so if you encounter any issues, please check it first
-
-### Run
-
-Once SD.Next is installed, simply run `webui.ps1` or `webui.bat` (*Windows*) or `webui.sh` (*Linux or MacOS*)
-
-Below is a partial list of available parameters; run `webui --help` for the full list:
-
- Server options:
- --config CONFIG Use specific server configuration file, default: config.json
- --ui-config UI_CONFIG Use specific UI configuration file, default: ui-config.json
- --medvram Split model stages and keep only active part in VRAM, default: False
- --lowvram Split model components and keep only active part in VRAM, default: False
- --ckpt CKPT Path to model checkpoint to load immediately, default: None
- --vae VAE Path to VAE checkpoint to load immediately, default: None
- --data-dir DATA_DIR Base path where all user data is stored, default:
- --models-dir MODELS_DIR Base path where all models are stored, default: models
- --share Enable UI accessible through Gradio site, default: False
- --insecure Enable extensions tab regardless of other options, default: False
- --listen Launch web server using public IP address, default: False
-  --auth AUTH Set access authentication like "user:pwd,user:pwd"
- --autolaunch Open the UI URL in the system's default browser upon launch
- --docs Mount Gradio docs at /docs, default: False
- --no-hashing Disable hashing of checkpoints, default: False
- --no-metadata Disable reading of metadata from models, default: False
- --no-download Disable download of default model, default: False
- --backend {original,diffusers} force model pipeline type
-
- Setup options:
- --debug Run installer with debug logging, default: False
- --reset Reset main repository to latest version, default: False
- --upgrade Upgrade main repository to latest version, default: False
- --requirements Force re-check of requirements, default: False
- --quick Run with startup sequence only, default: False
- --use-directml Use DirectML if no compatible GPU is detected, default: False
- --use-openvino Use Intel OpenVINO backend, default: False
- --use-ipex Force use Intel OneAPI XPU backend, default: False
- --use-cuda Force use nVidia CUDA backend, default: False
- --use-rocm Force use AMD ROCm backend, default: False
- --use-xformers Force use xFormers cross-optimization, default: False
- --skip-requirements Skips checking and installing requirements, default: False
- --skip-extensions Skips running individual extension installers, default: False
- --skip-git Skips running all GIT operations, default: False
- --skip-torch Skips running Torch checks, default: False
- --skip-all Skips running all checks, default: False
- --experimental Allow unsupported versions of libraries, default: False
- --reinstall Force reinstallation of all requirements, default: False
- --safe Run in safe mode with no user extensions
-
-
-## Notes
-
-### **Extensions**
-
-SD.Next comes with several extensions pre-installed:
-
-- [ControlNet](https://github.com/Mikubill/sd-webui-controlnet)
-- [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler)
-- [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser)
-
-### **Collab**
-
-- We'd love to have additional maintainers with full admin rights. If you're interested, ping us!
-- In addition to general cross-platform code, the goal is to have a lead for each of the main platforms.
-This should be fully cross-platform, but we'd really love additional contributors and/or maintainers to join and help lead the efforts on different platforms.
-
-## Credits
-
-- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
-- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
-- Licenses for modules are listed in [Licenses](html/licenses.html)
-
-### **Docs**
-
-If you're unsure how to use a feature, the best place to start is the [Wiki](https://github.com/vladmandic/automatic/wiki), and if it's not there,
-check the [ChangeLog](CHANGELOG.md) for when the feature was first introduced, as it will always have a short note on how to use it
-
-- [Wiki](https://github.com/vladmandic/automatic/wiki)
-- [ReadMe](README.md)
-- [ToDo](TODO.md)
-- [ChangeLog](CHANGELOG.md)
-- [CLI Tools](cli/README.md)
-
-### **Sponsors**
-
-
-
-
+
+
+# SD.Next
+
+**Stable Diffusion implementation with advanced features**
+
+[![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic)
+![Last Commit](https://img.shields.io/github/last-commit/vladmandic/automatic?svg=true)
+![License](https://img.shields.io/github/license/vladmandic/automatic?svg=true)
+[![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
+
+[Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
+
+
+
+
+## Notable features
+
+Not all individual features are listed here; check the [ChangeLog](CHANGELOG.md) for a full list of changes
+- Multiple backends!
+ ▹ **Original | Diffusers**
+- Multiple diffusion models!
+ ▹ **Stable Diffusion 1.5/2.1 | SD-XL | LCM | Segmind | Kandinsky | Pixart-α | Würstchen | aMUSEd | DeepFloyd IF | UniDiffusion | SD-Distilled | BLiP Diffusion | etc.**
+- Built-in Control for Text, Image, Batch and video processing!
+ ▹ **ControlNet | ControlNet XS | Control LLLite | T2I Adapters | IP Adapters**
+- Multiplatform!
+ ▹ **Windows | Linux | MacOS with CPU | nVidia | AMD | IntelArc | DirectML | OpenVINO | ONNX+Olive**
+- Platform specific autodetection and tuning performed on install
+- Optimized processing with latest `torch` developments with built-in support for `torch.compile` and multiple compile backends
+- Improved prompt parser
+- Enhanced *Lora*/*LoCon*/*Lyco* code supporting latest trends in training
+- Built-in queue management
+- Enterprise level logging and hardened API
+- Modern localization and hints engine
+- Broad compatibility with existing extensions ecosystem and new extensions manager
+- Built in installer with automatic updates and dependency management
+- Modernized UI with theme support and a number of built-in themes *(dark and light)*
+
+
+
+![Screenshot-Dark](html/xmas-default.jpg)
+![Screenshot-Control](html/xmas-control.jpg)
+![Screenshot-Light](html/light-teal.jpg)
+
+
+
+## Backend support
+
+**SD.Next** supports two main backends: *Original* and *Diffusers*:
+
+- **Original**: Based on [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation and significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+ This is the default backend and it is fully compatible with all existing functionality and extensions
+ Supports **SD 1.x** and **SD 2.x** models
+ All other model types such as *SD-XL, LCM, PixArt, Segmind, Kandinsky, etc.* require backend **Diffusers**
+- **Diffusers**: Based on new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation
+ Supports *original* SD models as well as *all* models listed below
+ See [wiki article](https://github.com/vladmandic/automatic/wiki/Diffusers) for more information
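+
+For example (a minimal sketch, assuming the default launcher scripts described under Run below), the backend can be forced at startup with the `--backend` flag:
+
+    ./webui.sh --backend diffusers   # needed for SD-XL, LCM, Kandinsky and other newer models
+    ./webui.sh --backend original    # stay on the LDM/A1111-compatible backend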
+
+## Model support
+
+Additional models will be added as they become available and there is public interest in them
+
+- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
+- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
+- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base and XT
+- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
+- [aMUSEd 256](https://huggingface.co/amused/amused-256) 256 and 512
+- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega)
+- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
+- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0*
+- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*
+- [Warp Wuerstchen](https://huggingface.co/blog/wuertschen)
+- [Playground](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024*
+- [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser)
+- [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large*
+- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b)
+- [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
+- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/)
+
+
+Also supported are modifiers such as:
+- **LCM** and **Turbo** (Adversarial Diffusion Distillation) networks
+- All **LoRA** types such as LoCon, LyCORIS, HADA, IA3, Lokr, OFT
+- **AnimateDiff** for SD 1.5
+- **IP-Adapters** for SD 1.5 and SD-XL
+
+> [!IMPORTANT]
+> - Loading any model other than standard SD 1.x / SD 2.x requires use of backend **Diffusers**
+> - Loading any other models using **Original** backend is not supported
+> - Loading manually downloaded model `.safetensors` files is supported for SD 1.x / SD 2.x / SD-XL models only
+> - For all other model types, use backend **Diffusers** and use the built-in Model downloader or
+  select a model from Networks -> Models -> Reference list, in which case it will be auto-downloaded and loaded
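+
+As an illustration (the path below is a placeholder, not part of this repo), a manually downloaded SD-XL checkpoint can be loaded at startup together with the Diffusers backend:
+
+    ./webui.sh --backend diffusers --ckpt models/Stable-diffusion/sd_xl_base_1.0.safetensors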
+
+## Platform support
+
+- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
+- *AMD* GPUs using **ROCm** libraries on *Linux*
+ Support will be extended to *Windows* once AMD releases ROCm for Windows
+- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
+- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
+ This includes support for AMD GPUs that are not supported by native ROCm libraries
+- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
+- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
+- *ONNX/Olive* (experimental)
+
+## Install
+
+- [Step-by-step install guide](https://github.com/vladmandic/automatic/wiki/Installation)
+- [Advanced install notes](https://github.com/vladmandic/automatic/wiki/Advanced-Install)
+- [Common installation errors](https://github.com/vladmandic/automatic/discussions/1627)
+- [FAQ](https://github.com/vladmandic/automatic/discussions/1011)
+- If you can't run us locally, try our friends at [RunDiffusion](https://rundiffusion.com?utm_source=github&utm_medium=referral&utm_campaign=SDNext)!
+
+> [!TIP]
+> - Server can run without a virtual environment,
+  but using a `VENV` is recommended to avoid library version conflicts with other applications
+> - **nVidia/CUDA** / **AMD/ROCm** / **Intel/OneAPI** are auto-detected if present and available,
+ For any other use case such as **DirectML**, **ONNX/Olive**, **OpenVINO** specify required parameter explicitly
+ or wrong packages may be installed as installer will assume CPU-only environment
+> - Full startup sequence is logged in `sdnext.log`,
+ so if you encounter any issues, please check it first
+
+### Run
+
+Once SD.Next is installed, simply run `webui.ps1` or `webui.bat` (*Windows*) or `webui.sh` (*Linux or MacOS*)
+
+Below is a partial list of available parameters; run `webui --help` for the full list:
+
+ Server options:
+ --config CONFIG Use specific server configuration file, default: config.json
+ --ui-config UI_CONFIG Use specific UI configuration file, default: ui-config.json
+ --medvram Split model stages and keep only active part in VRAM, default: False
+ --lowvram Split model components and keep only active part in VRAM, default: False
+ --ckpt CKPT Path to model checkpoint to load immediately, default: None
+ --vae VAE Path to VAE checkpoint to load immediately, default: None
+ --data-dir DATA_DIR Base path where all user data is stored, default:
+ --models-dir MODELS_DIR Base path where all models are stored, default: models
+ --share Enable UI accessible through Gradio site, default: False
+ --insecure Enable extensions tab regardless of other options, default: False
+ --listen Launch web server using public IP address, default: False
+  --auth AUTH Set access authentication like "user:pwd,user:pwd"
+ --autolaunch Open the UI URL in the system's default browser upon launch
+ --docs Mount Gradio docs at /docs, default: False
+ --no-hashing Disable hashing of checkpoints, default: False
+ --no-metadata Disable reading of metadata from models, default: False
+ --no-download Disable download of default model, default: False
+ --backend {original,diffusers} force model pipeline type
+
+ Setup options:
+ --debug Run installer with debug logging, default: False
+ --reset Reset main repository to latest version, default: False
+ --upgrade Upgrade main repository to latest version, default: False
+ --requirements Force re-check of requirements, default: False
+ --quick Run with startup sequence only, default: False
+ --use-directml Use DirectML if no compatible GPU is detected, default: False
+ --use-openvino Use Intel OpenVINO backend, default: False
+ --use-ipex Force use Intel OneAPI XPU backend, default: False
+ --use-cuda Force use nVidia CUDA backend, default: False
+ --use-rocm Force use AMD ROCm backend, default: False
+ --use-xformers Force use xFormers cross-optimization, default: False
+ --skip-requirements Skips checking and installing requirements, default: False
+ --skip-extensions Skips running individual extension installers, default: False
+ --skip-git Skips running all GIT operations, default: False
+ --skip-torch Skips running Torch checks, default: False
+ --skip-all Skips running all checks, default: False
+ --experimental Allow unsupported versions of libraries, default: False
+ --reinstall Force reinstallation of all requirements, default: False
+ --safe Run in safe mode with no user extensions
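+
+As a rough example of a typical first launch (flags taken from the lists above; adjust to your hardware), the following starts the server with reduced VRAM usage and opens the UI in your default browser:
+
+    ./webui.sh --medvram --autolaunch
+
+If your accelerator is not one of the auto-detected ones, add the matching setup flag explicitly, e.g. `--use-directml` or `--use-openvino`, so the installer does not fall back to a CPU-only environment.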
+
+
+## Notes
+
+### **Extensions**
+
+SD.Next comes with several extensions pre-installed:
+
+- [ControlNet](https://github.com/Mikubill/sd-webui-controlnet)
+- [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler)
+- [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser)
+
+### **Collab**
+
+- We'd love to have additional maintainers with full admin rights. If you're interested, ping us!
+- In addition to general cross-platform code, the goal is to have a lead for each of the main platforms.
+This should be fully cross-platform, but we'd really love additional contributors and/or maintainers to join and help lead the efforts on different platforms.
+
+## Credits
+
+- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
+- Licenses for modules are listed in [Licenses](html/licenses.html)
+
+### **Docs**
+
+If you're unsure how to use a feature, the best place to start is the [Wiki](https://github.com/vladmandic/automatic/wiki), and if it's not there,
+check the [ChangeLog](CHANGELOG.md) for when the feature was first introduced, as it will always have a short note on how to use it
+
+- [Wiki](https://github.com/vladmandic/automatic/wiki)
+- [ReadMe](README.md)
+- [ToDo](TODO.md)
+- [ChangeLog](CHANGELOG.md)
+- [CLI Tools](cli/README.md)
+
+### **Sponsors**
+
+
+
+
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 6cdfd03a8..4ed4c4776 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,85 +1,85 @@
-import time
-import networks
-import lora_patches
-from modules import extra_networks, shared
-
-
-class ExtraNetworkLora(extra_networks.ExtraNetwork):
-
- def __init__(self):
- super().__init__('lora')
- self.active = False
- self.errors = {}
- networks.originals = lora_patches.LoraPatches()
-
- """mapping of network names to the number of errors the network had during operation"""
-
- def activate(self, p, params_list):
- t0 = time.time()
- additional = shared.opts.sd_lora
- self.errors.clear()
- if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
-            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
- params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
- if len(params_list) > 0:
- self.active = True
- networks.originals.apply() # apply patches
- if networks.debug:
- shared.log.debug("LoRA activate")
- names = []
- te_multipliers = []
- unet_multipliers = []
- dyn_dims = []
- for params in params_list:
- assert params.items
- names.append(params.positional[0])
- te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
- te_multiplier = float(params.named.get("te", te_multiplier))
- unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3
- unet_multiplier = [float(params.named.get("unet", unet_multiplier[0]))] * 3
- unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0]))
- unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1]))
- unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2]))
- dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
- dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
- te_multipliers.append(te_multiplier)
- unet_multipliers.append(unet_multiplier)
- dyn_dims.append(dyn_dim)
- t1 = time.time()
- networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
- t2 = time.time()
- if shared.opts.lora_add_hashes_to_infotext:
- network_hashes = []
- for item in networks.loaded_networks:
- shorthash = item.network_on_disk.shorthash
- if not shorthash:
- continue
- alias = item.mentioned_name
- if not alias:
- continue
- alias = alias.replace(":", "").replace(",", "")
- network_hashes.append(f"{alias}: {shorthash}")
- if network_hashes:
- p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
- if len(names) > 0:
- shared.log.info(f'LoRA apply: {names} patch={t1-t0:.2f} load={t2-t1:.2f}')
- elif self.active:
- self.active = False
-
- def deactivate(self, p):
- if shared.backend == shared.Backend.DIFFUSERS and hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"):
- if 'CLIP' in shared.sd_model.text_encoder.__class__.__name__ and not (shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
- if shared.opts.lora_fuse_diffusers:
- shared.sd_model.unfuse_lora()
- shared.sd_model.unload_lora_weights()
- if not self.active and getattr(networks, "originals", None ) is not None:
- networks.originals.undo() # remove patches
- if networks.debug:
- shared.log.debug("LoRA deactivate")
- if self.active and networks.debug:
- shared.log.debug(f"LoRA end: load={networks.timer['load']:.2f} apply={networks.timer['apply']:.2f} restore={networks.timer['restore']:.2f}")
- if self.errors:
- p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
- for k, v in self.errors.items():
- shared.log.error(f'LoRA errors: file="{k}" errors={v}')
- self.errors.clear()
+import time
+import networks
+import lora_patches
+from modules import extra_networks, shared
+
+
+class ExtraNetworkLora(extra_networks.ExtraNetwork):
+
+ def __init__(self):
+ super().__init__('lora')
+ self.active = False
+ self.errors = {}
+ networks.originals = lora_patches.LoraPatches()
+
+ """mapping of network names to the number of errors the network had during operation"""
+
+ def activate(self, p, params_list):
+ t0 = time.time()
+ additional = shared.opts.sd_lora
+ self.errors.clear()
+ if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
+            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+ if len(params_list) > 0:
+ self.active = True
+ networks.originals.apply() # apply patches
+ if networks.debug:
+ shared.log.debug("LoRA activate")
+ names = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
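+        # Each entry in params_list comes from an extra-network tag in the prompt; the loop below reads
+        # its positional values as [name, te multiplier, unet multiplier, dyn dim] and lets named
+        # arguments (te=, unet=, in=, mid=, out=, dyn=) override them; an illustrative tag would look
+        # roughly like <lora:my_lora:0.8:unet=0.6:dyn=16>.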
+ for params in params_list:
+ assert params.items
+ names.append(params.positional[0])
+ te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+ te_multiplier = float(params.named.get("te", te_multiplier))
+ unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3
+ unet_multiplier = [float(params.named.get("unet", unet_multiplier[0]))] * 3
+ unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0]))
+ unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1]))
+ unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2]))
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+ t1 = time.time()
+ networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
+ t2 = time.time()
+ if shared.opts.lora_add_hashes_to_infotext:
+ network_hashes = []
+ for item in networks.loaded_networks:
+ shorthash = item.network_on_disk.shorthash
+ if not shorthash:
+ continue
+ alias = item.mentioned_name
+ if not alias:
+ continue
+ alias = alias.replace(":", "").replace(",", "")
+ network_hashes.append(f"{alias}: {shorthash}")
+ if network_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
+ if len(names) > 0:
+ shared.log.info(f'LoRA apply: {names} patch={t1-t0:.2f} load={t2-t1:.2f}')
+ elif self.active:
+ self.active = False
+
+ def deactivate(self, p):
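+        # For the Diffusers backend this unloads (and, if fused, unfuses) LoRA weights via the pipeline
+        # API; when no networks remain active, the torch.nn patches applied in activate() are undone.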
+ if shared.backend == shared.Backend.DIFFUSERS and hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"):
+ if 'CLIP' in shared.sd_model.text_encoder.__class__.__name__ and not (shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
+ if shared.opts.lora_fuse_diffusers:
+ shared.sd_model.unfuse_lora()
+ shared.sd_model.unload_lora_weights()
+ if not self.active and getattr(networks, "originals", None ) is not None:
+ networks.originals.undo() # remove patches
+ if networks.debug:
+ shared.log.debug("LoRA deactivate")
+ if self.active and networks.debug:
+ shared.log.debug(f"LoRA end: load={networks.timer['load']:.2f} apply={networks.timer['apply']:.2f} restore={networks.timer['restore']:.2f}")
+ if self.errors:
+ p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+ for k, v in self.errors.items():
+ shared.log.error(f'LoRA errors: file="{k}" errors={v}')
+ self.errors.clear()
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 742be4c61..45cc5c9df 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,8 +1,8 @@
-import networks
-
-list_available_loras = networks.list_available_networks
-available_loras = networks.available_networks
-available_lora_aliases = networks.available_network_aliases
-available_lora_hash_lookup = networks.available_network_hash_lookup
-forbidden_lora_aliases = networks.forbidden_network_aliases
-loaded_loras = networks.loaded_networks
+import networks
+
+list_available_loras = networks.list_available_networks
+available_loras = networks.available_networks
+available_lora_aliases = networks.available_network_aliases
+available_lora_hash_lookup = networks.available_network_hash_lookup
+forbidden_lora_aliases = networks.forbidden_network_aliases
+loaded_loras = networks.loaded_networks
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
index 9ea6a10c3..ca7c56e2c 100644
--- a/extensions-builtin/Lora/lora_patches.py
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -1,52 +1,52 @@
-import torch
-import networks
-from modules import patches, shared
-
-
-class LoraPatches:
- def __init__(self):
- self.active = False
- self.Linear_forward = None
- self.Linear_load_state_dict = None
- self.Conv2d_forward = None
- self.Conv2d_load_state_dict = None
- self.GroupNorm_forward = None
- self.GroupNorm_load_state_dict = None
- self.LayerNorm_forward = None
- self.LayerNorm_load_state_dict = None
- self.MultiheadAttention_forward = None
- self.MultiheadAttention_load_state_dict = None
-
- def apply(self):
- if self.active or shared.opts.lora_force_diffusers:
- return
- self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
- self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
- self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
- self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
- self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
- self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
- self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
- self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
- self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
- self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
- networks.timer['load'] = 0
- networks.timer['apply'] = 0
- networks.timer['restore'] = 0
- self.active = True
-
- def undo(self):
- if not self.active or shared.opts.lora_force_diffusers:
- return
- self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') # pylint: disable=E1128
- self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') # pylint: disable=E1128
- self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') # pylint: disable=E1128
- self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') # pylint: disable=E1128
- self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') # pylint: disable=E1128
- self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') # pylint: disable=E1128
- self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') # pylint: disable=E1128
- self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') # pylint: disable=E1128
- self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') # pylint: disable=E1128
- self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') # pylint: disable=E1128
- patches.originals.pop(__name__, None)
- self.active = False
+import torch
+import networks
+from modules import patches, shared
+
+
+class LoraPatches:
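+    # apply() monkey-patches forward and _load_from_state_dict on the common torch.nn layers with the
+    # network_* hooks from networks.py so LoRA weights can be injected at runtime, and resets the
+    # load/apply/restore timers; undo() restores the original methods.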
+ def __init__(self):
+ self.active = False
+ self.Linear_forward = None
+ self.Linear_load_state_dict = None
+ self.Conv2d_forward = None
+ self.Conv2d_load_state_dict = None
+ self.GroupNorm_forward = None
+ self.GroupNorm_load_state_dict = None
+ self.LayerNorm_forward = None
+ self.LayerNorm_load_state_dict = None
+ self.MultiheadAttention_forward = None
+ self.MultiheadAttention_load_state_dict = None
+
+ def apply(self):
+ if self.active or shared.opts.lora_force_diffusers:
+ return
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+ networks.timer['load'] = 0
+ networks.timer['apply'] = 0
+ networks.timer['restore'] = 0
+ self.active = True
+
+ def undo(self):
+ if not self.active or shared.opts.lora_force_diffusers:
+ return
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') # pylint: disable=E1128
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') # pylint: disable=E1128
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') # pylint: disable=E1128
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') # pylint: disable=E1128
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') # pylint: disable=E1128
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') # pylint: disable=E1128
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') # pylint: disable=E1128
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') # pylint: disable=E1128
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') # pylint: disable=E1128
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') # pylint: disable=E1128
+ patches.originals.pop(__name__, None)
+ self.active = False
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 1679a0ce6..e0e1afa78 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -1,68 +1,68 @@
-import torch
-
-
-def make_weight_cp(t, wa, wb):
- temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
- return torch.einsum('i j k l, i r -> r j k l', temp, wa)
-
-
-def rebuild_conventional(up, down, shape, dyn_dim=None):
- up = up.reshape(up.size(0), -1)
- down = down.reshape(down.size(0), -1)
- if dyn_dim is not None:
- up = up[:, :dyn_dim]
- down = down[:dyn_dim, :]
- return (up @ down).reshape(shape)
-
-
-def rebuild_cp_decomposition(up, down, mid):
- up = up.reshape(up.size(0), -1)
- down = down.reshape(down.size(0), -1)
- return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
-
-
-# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
-def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
- '''
-    return a tuple of two values of the input dimension decomposed by the number closest to factor
-    second value is greater than or equal to the first value.
-
-    In LoRA with Kronecker Product, first value is a value for weight scale.
-    second value is a value for weight.
-
-    Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
-
- examples)
- factor
- -1 2 4 8 16 ...
- 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
- 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
- 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
- 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
- 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
- 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
- '''
-
- if factor > 0 and (dimension % factor) == 0:
- m = factor
- n = dimension // factor
- if m > n:
- n, m = m, n
- return m, n
- if factor < 0:
- factor = dimension
- m, n = 1, dimension
- length = m + n
-    while m < n:
-        new_m = m + 1
-        while dimension % new_m != 0:
-            new_m += 1
-        new_n = dimension // new_m
-        if new_m + new_n > length or new_m > factor:
-            break
-        else:
-            m, n = new_m, new_n
- if m > n:
- n, m = m, n
- return m, n
-
+import torch
+
+
+def make_weight_cp(t, wa, wb):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+def rebuild_conventional(up, down, shape, dyn_dim=None):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ if dyn_dim is not None:
+ up = up[:, :dyn_dim]
+ down = down[:dyn_dim, :]
+ return (up @ down).reshape(shape)
+
+
+def rebuild_cp_decomposition(up, down, mid):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ '''
+    return a tuple of two values of the input dimension decomposed by the number closest to factor
+    second value is greater than or equal to the first value.
+
+    In LoRA with Kronecker Product, first value is a value for weight scale.
+    second value is a value for weight.
+
+    Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+ examples)
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ '''
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
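+
+# Example (illustrative, value taken from the docstring table above): factorization(128, factor=4)
+# returns (4, 32); LoKr modules use this pair as the sizes of the two Kronecker factors.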
+
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index ea22e9c3e..d0f3ebb7a 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -1,129 +1,129 @@
-from __future__ import annotations
-import os
-from collections import namedtuple
-import enum
-
-from modules import sd_models, hashes, shared
-
-NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
-
-metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
-
-
-class SdVersion(enum.Enum):
- Unknown = 1
- SD1 = 2
- SD2 = 3
- SDXL = 4
-
-
-class NetworkOnDisk:
- def __init__(self, name, filename):
- self.name = name
- self.filename = filename
- self.metadata = {}
- self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
-
- if self.is_safetensors:
- self.metadata = sd_models.read_metadata_from_safetensors(filename)
- if self.metadata:
- m = {}
- for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
- m[k] = v
- self.metadata = m
- self.alias = self.metadata.get('ss_output_name', self.name)
- self.hash = None
- self.shorthash = None
- self.set_hash(self.metadata.get('sshs_model_hash') or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
- self.sd_version = self.detect_version()
-
- def detect_version(self):
- if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
- return SdVersion.SDXL
- elif str(self.metadata.get('ss_v2', "")) == "True":
- return SdVersion.SD2
- elif len(self.metadata):
- return SdVersion.SD1
- return SdVersion.Unknown
-
- def set_hash(self, v):
- self.hash = v
- self.shorthash = self.hash[0:12]
-
- def read_hash(self):
- if not self.hash:
- self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
-
- def get_alias(self):
- import networks
- return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias
-
-
-class Network: # LoraModule
- def __init__(self, name, network_on_disk: NetworkOnDisk):
- self.name = name
- self.network_on_disk = network_on_disk
- self.te_multiplier = 1.0
- self.unet_multiplier = [1.0] * 3
- self.dyn_dim = None
- self.modules = {}
- self.mtime = None
- self.mentioned_name = None
- """the text that was used to add the network to prompt - can be either name or an alias"""
-
-
-class ModuleType:
- def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
- return None
-
-
-class NetworkModule:
- def __init__(self, net: Network, weights: NetworkWeights):
- self.network = net
- self.network_key = weights.network_key
- self.sd_key = weights.sd_key
- self.sd_module = weights.sd_module
- if hasattr(self.sd_module, 'weight'):
- self.shape = self.sd_module.weight.shape
- self.dim = None
- self.bias = weights.w.get("bias")
- self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
- self.scale = weights.w["scale"].item() if "scale" in weights.w else None
-
- def multiplier(self):
- if 'transformer' in self.sd_key[:20]:
- return self.network.te_multiplier
- if "down_blocks" in self.sd_key:
- return self.network.unet_multiplier[0]
- if "mid_block" in self.sd_key:
- return self.network.unet_multiplier[1]
- if "up_blocks" in self.sd_key:
- return self.network.unet_multiplier[2]
- else:
- return self.network.unet_multiplier[0]
-
- def calc_scale(self):
- if self.scale is not None:
- return self.scale
- if self.dim is not None and self.alpha is not None:
- return self.alpha / self.dim
- return 1.0
-
- def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
- if self.bias is not None:
- updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
- updown = updown.reshape(output_shape)
- if len(output_shape) == 4:
- updown = updown.reshape(output_shape)
- if orig_weight.size().numel() == updown.size().numel():
- updown = updown.reshape(orig_weight.shape)
- if ex_bias is not None:
- ex_bias = ex_bias * self.multiplier()
- return updown * self.calc_scale() * self.multiplier(), ex_bias
-
- def calc_updown(self, target):
- raise NotImplementedError()
-
- def forward(self, x, y):
- raise NotImplementedError()
+from __future__ import annotations
+import os
+from collections import namedtuple
+import enum
+
+from modules import sd_models, hashes, shared
+
+NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+
+
+class SdVersion(enum.Enum):
+ Unknown = 1
+ SD1 = 2
+ SD2 = 3
+ SDXL = 4
+
+
+class NetworkOnDisk:
+ def __init__(self, name, filename):
+ self.name = name
+ self.filename = filename
+ self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
+ if self.is_safetensors:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+ self.metadata = m
+ self.alias = self.metadata.get('ss_output_name', self.name)
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(self.metadata.get('sshs_model_hash') or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+ self.sd_version = self.detect_version()
+
+ def detect_version(self):
+ if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
+ return SdVersion.SDXL
+ elif str(self.metadata.get('ss_v2', "")) == "True":
+ return SdVersion.SD2
+ elif len(self.metadata):
+ return SdVersion.SD1
+ return SdVersion.Unknown
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ import networks
+ return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias
+
+
+class Network: # LoraModule
+ def __init__(self, name, network_on_disk: NetworkOnDisk):
+ self.name = name
+ self.network_on_disk = network_on_disk
+ self.te_multiplier = 1.0
+ self.unet_multiplier = [1.0] * 3
+ self.dyn_dim = None
+ self.modules = {}
+ self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add the network to prompt - can be either name or an alias"""
+
+
+class ModuleType:
+ def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
+ return None
+
+
+class NetworkModule:
+ def __init__(self, net: Network, weights: NetworkWeights):
+ self.network = net
+ self.network_key = weights.network_key
+ self.sd_key = weights.sd_key
+ self.sd_module = weights.sd_module
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+ self.dim = None
+ self.bias = weights.w.get("bias")
+ self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
+ self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+
+ def multiplier(self):
+ if 'transformer' in self.sd_key[:20]:
+ return self.network.te_multiplier
+ if "down_blocks" in self.sd_key:
+ return self.network.unet_multiplier[0]
+ if "mid_block" in self.sd_key:
+ return self.network.unet_multiplier[1]
+ if "up_blocks" in self.sd_key:
+ return self.network.unet_multiplier[2]
+ else:
+ return self.network.unet_multiplier[0]
+
+ def calc_scale(self):
+ if self.scale is not None:
+ return self.scale
+ if self.dim is not None and self.alpha is not None:
+ return self.alpha / self.dim
+ return 1.0
+
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
+ if self.bias is not None:
+ updown = updown.reshape(self.bias.shape)
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = updown.reshape(output_shape)
+ if len(output_shape) == 4:
+ updown = updown.reshape(output_shape)
+ if orig_weight.size().numel() == updown.size().numel():
+ updown = updown.reshape(orig_weight.shape)
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+ return updown * self.calc_scale() * self.multiplier(), ex_bias
+
+ def calc_updown(self, target):
+ raise NotImplementedError()
+
+ def forward(self, x, y):
+ raise NotImplementedError()
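Reviewer note: a quick sketch of how the scaling pieces in NetworkModule combine. calc_scale falls back to alpha/dim when the file stores no explicit "scale", and finalize_updown multiplies a concrete module's delta by that scale and by the per-block multiplier before networks.py adds it to the frozen weight. Shapes and values below are purely illustrative, and the bias/ex_bias branches are skipped.

import torch

alpha, dim = 8.0, 16           # values typically read from the network file / weight shapes
multiplier = 0.75              # user strength, as returned by NetworkModule.multiplier()
updown = torch.randn(8, 6)     # delta produced by a concrete module's calc_updown

scale = alpha / dim                    # calc_scale() fallback path
delta = updown * scale * multiplier    # what finalize_updown returns in this simple case

weight = torch.randn(8, 6)
weight += delta                        # applied later by networks.network_apply_weights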
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index 233791712..ba9f2e359 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -1,27 +1,27 @@
-import network
-
-
-class ModuleTypeFull(network.ModuleType):
- def create_module(self, net: network.Network, weights: network.NetworkWeights):
- if all(x in weights.w for x in ["diff"]):
- return NetworkModuleFull(net, weights)
-
- return None
-
-
-class NetworkModuleFull(network.NetworkModule):
- def __init__(self, net: network.Network, weights: network.NetworkWeights):
- super().__init__(net, weights)
-
- self.weight = weights.w.get("diff")
- self.ex_bias = weights.w.get("diff_b")
-
- def calc_updown(self, target):
- output_shape = self.weight.shape
- updown = self.weight.to(target.device, dtype=target.dtype)
- if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(target.device, dtype=target.dtype)
- else:
- ex_bias = None
-
- return self.finalize_updown(updown, target, output_shape, ex_bias)
+import network
+
+
+class ModuleTypeFull(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["diff"]):
+ return NetworkModuleFull(net, weights)
+
+ return None
+
+
+class NetworkModuleFull(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.weight = weights.w.get("diff")
+ self.ex_bias = weights.w.get("diff_b")
+
+ def calc_updown(self, target):
+ output_shape = self.weight.shape
+ updown = self.weight.to(target.device, dtype=target.dtype)
+ if self.ex_bias is not None:
+ ex_bias = self.ex_bias.to(target.device, dtype=target.dtype)
+ else:
+ ex_bias = None
+
+ return self.finalize_updown(updown, target, output_shape, ex_bias)
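Reviewer note: NetworkModuleFull is the simplest module type; the file stores the complete weight difference ("diff", optionally "diff_b" for the bias), so calc_updown only casts it to the target's device and dtype. Roughly, with illustrative shapes:

import torch

weight = torch.randn(8, 6)     # frozen layer weight
bias = torch.zeros(8)
diff = torch.randn(8, 6)       # "diff" tensor from the network file
diff_b = torch.randn(8)        # optional "diff_b"
multiplier = 1.0

weight += diff * multiplier    # updown is the stored diff itself (scale defaults to 1.0)
bias += diff_b * multiplier    # ex_bias path of finalize_updown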
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 0feda761e..4e5924d04 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -1,46 +1,46 @@
-import lyco_helpers
-import network
-
-
-class ModuleTypeHada(network.ModuleType):
- def create_module(self, net: network.Network, weights: network.NetworkWeights):
- if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
- return NetworkModuleHada(net, weights)
- return None
-
-
-class NetworkModuleHada(network.NetworkModule):
- def __init__(self, net: network.Network, weights: network.NetworkWeights):
- super().__init__(net, weights)
- if hasattr(self.sd_module, 'weight'):
- self.shape = self.sd_module.weight.shape
- self.w1a = weights.w["hada_w1_a"]
- self.w1b = weights.w["hada_w1_b"]
- self.dim = self.w1b.shape[0]
- self.w2a = weights.w["hada_w2_a"]
- self.w2b = weights.w["hada_w2_b"]
- self.t1 = weights.w.get("hada_t1")
- self.t2 = weights.w.get("hada_t2")
-
- def calc_updown(self, target):
- w1a = self.w1a.to(target.device, dtype=target.dtype)
- w1b = self.w1b.to(target.device, dtype=target.dtype)
- w2a = self.w2a.to(target.device, dtype=target.dtype)
- w2b = self.w2b.to(target.device, dtype=target.dtype)
- output_shape = [w1a.size(0), w1b.size(1)]
- if self.t1 is not None:
- output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(target.device, dtype=target.dtype)
- updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
- output_shape += t1.shape[2:]
- else:
- if len(w1b.shape) == 4:
- output_shape += w1b.shape[2:]
- updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
- if self.t2 is not None:
- t2 = self.t2.to(target.device, dtype=target.dtype)
- updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
- else:
- updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
- updown = updown1 * updown2
- return self.finalize_updown(updown, target, output_shape)
+import lyco_helpers
+import network
+
+
+class ModuleTypeHada(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
+ return NetworkModuleHada(net, weights)
+ return None
+
+
+class NetworkModuleHada(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+ self.w1a = weights.w["hada_w1_a"]
+ self.w1b = weights.w["hada_w1_b"]
+ self.dim = self.w1b.shape[0]
+ self.w2a = weights.w["hada_w2_a"]
+ self.w2b = weights.w["hada_w2_b"]
+ self.t1 = weights.w.get("hada_t1")
+ self.t2 = weights.w.get("hada_t2")
+
+ def calc_updown(self, target):
+ w1a = self.w1a.to(target.device, dtype=target.dtype)
+ w1b = self.w1b.to(target.device, dtype=target.dtype)
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ output_shape = [w1a.size(0), w1b.size(1)]
+ if self.t1 is not None:
+ output_shape = [w1a.size(1), w1b.size(1)]
+ t1 = self.t1.to(target.device, dtype=target.dtype)
+ updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
+ output_shape += t1.shape[2:]
+ else:
+ if len(w1b.shape) == 4:
+ output_shape += w1b.shape[2:]
+ updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
+ if self.t2 is not None:
+ t2 = self.t2.to(target.device, dtype=target.dtype)
+ updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+ else:
+ updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
+ updown = updown1 * updown2
+ return self.finalize_updown(updown, target, output_shape)
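Reviewer note: LoHa (the "hada" module) rebuilds two low-rank products and combines them element-wise, which is where the Hadamard name comes from. A sketch of the non-CP path for a Linear layer follows; shapes are illustrative, and the alpha/dim scale and multiplier are applied afterwards by finalize_updown.

import torch

out_f, in_f, rank = 8, 6, 4
w1a, w1b = torch.randn(out_f, rank), torch.randn(rank, in_f)   # hada_w1_a / hada_w1_b
w2a, w2b = torch.randn(out_f, rank), torch.randn(rank, in_f)   # hada_w2_a / hada_w2_b

updown = (w1a @ w1b) * (w2a @ w2b)   # rebuild each pair, then take the Hadamard product
assert updown.shape == (out_f, in_f)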
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index cb39df228..75316d97c 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -1,26 +1,26 @@
-import network
-
-
-class ModuleTypeIa3(network.ModuleType):
- def create_module(self, net: network.Network, weights: network.NetworkWeights):
- if all(x in weights.w for x in ["weight"]):
- return NetworkModuleIa3(net, weights)
-
- return None
-
-
-class NetworkModuleIa3(network.NetworkModule):
- def __init__(self, net: network.Network, weights: network.NetworkWeights):
- super().__init__(net, weights)
- self.w = weights.w["weight"]
- self.on_input = weights.w["on_input"].item()
-
- def calc_updown(self, target):
- w = self.w.to(target.device, dtype=target.dtype)
- output_shape = [w.size(0), target.size(1)]
- if self.on_input:
- output_shape.reverse()
- else:
- w = w.reshape(-1, 1)
- updown = target * w
- return self.finalize_updown(updown, target, output_shape)
+import network
+
+
+class ModuleTypeIa3(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["weight"]):
+ return NetworkModuleIa3(net, weights)
+
+ return None
+
+
+class NetworkModuleIa3(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.w = weights.w["weight"]
+ self.on_input = weights.w["on_input"].item()
+
+ def calc_updown(self, target):
+ w = self.w.to(target.device, dtype=target.dtype)
+ output_shape = [w.size(0), target.size(1)]
+ if self.on_input:
+ output_shape.reverse()
+ else:
+ w = w.reshape(-1, 1)
+ updown = target * w
+ return self.finalize_updown(updown, target, output_shape)
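Reviewer note: IA3 stores a single scaling vector per layer. calc_updown multiplies the frozen weight by it, over output rows by default or over input columns when on_input is set; since the result is later added back to the weight, the effective weight is roughly target * (1 + w) at multiplier 1. Illustrative shapes:

import torch

target = torch.randn(8, 6)      # frozen weight: (out_features, in_features)

w_out = torch.randn(8)          # IA3 vector over output features (on_input == False)
updown_rows = target * w_out.reshape(-1, 1)   # scales each output row

w_in = torch.randn(6)           # IA3 vector over input features (on_input == True)
updown_cols = target * w_in                   # broadcasts across input columns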
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 20387efee..cafdb968a 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -1,57 +1,57 @@
-import torch
-import lyco_helpers
-import network
-
-
-class ModuleTypeLokr(network.ModuleType):
- def create_module(self, net: network.Network, weights: network.NetworkWeights):
- has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
- has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
- if has_1 and has_2:
- return NetworkModuleLokr(net, weights)
- return None
-
-
-def make_kron(orig_shape, w1, w2):
- if len(w2.shape) == 4:
- w1 = w1.unsqueeze(2).unsqueeze(2)
- w2 = w2.contiguous()
- return torch.kron(w1, w2).reshape(orig_shape)
-
-
-class NetworkModuleLokr(network.NetworkModule):
- def __init__(self, net: network.Network, weights: network.NetworkWeights):
- super().__init__(net, weights)
- self.w1 = weights.w.get("lokr_w1")
- self.w1a = weights.w.get("lokr_w1_a")
- self.w1b = weights.w.get("lokr_w1_b")
- self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
- self.w2 = weights.w.get("lokr_w2")
- self.w2a = weights.w.get("lokr_w2_a")
- self.w2b = weights.w.get("lokr_w2_b")
- self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
- self.t2 = weights.w.get("lokr_t2")
-
- def calc_updown(self, target):
- if self.w1 is not None:
- w1 = self.w1.to(target.device, dtype=target.dtype)
- else:
- w1a = self.w1a.to(target.device, dtype=target.dtype)
- w1b = self.w1b.to(target.device, dtype=target.dtype)
- w1 = w1a @ w1b
- if self.w2 is not None:
- w2 = self.w2.to(target.device, dtype=target.dtype)
- elif self.t2 is None:
- w2a = self.w2a.to(target.device, dtype=target.dtype)
- w2b = self.w2b.to(target.device, dtype=target.dtype)
- w2 = w2a @ w2b
- else:
- t2 = self.t2.to(target.device, dtype=target.dtype)
- w2a = self.w2a.to(target.device, dtype=target.dtype)
- w2b = self.w2b.to(target.device, dtype=target.dtype)
- w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
- output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
- if len(target.shape) == 4:
- output_shape = target.shape
- updown = make_kron(output_shape, w1, w2)
- return self.finalize_updown(updown, target, output_shape)
+import torch
+import lyco_helpers
+import network
+
+
+class ModuleTypeLokr(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
+ has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
+ if has_1 and has_2:
+ return NetworkModuleLokr(net, weights)
+ return None
+
+
+def make_kron(orig_shape, w1, w2):
+ if len(w2.shape) == 4:
+ w1 = w1.unsqueeze(2).unsqueeze(2)
+ w2 = w2.contiguous()
+ return torch.kron(w1, w2).reshape(orig_shape)
+
+
+class NetworkModuleLokr(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.w1 = weights.w.get("lokr_w1")
+ self.w1a = weights.w.get("lokr_w1_a")
+ self.w1b = weights.w.get("lokr_w1_b")
+ self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
+ self.w2 = weights.w.get("lokr_w2")
+ self.w2a = weights.w.get("lokr_w2_a")
+ self.w2b = weights.w.get("lokr_w2_b")
+ self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
+ self.t2 = weights.w.get("lokr_t2")
+
+ def calc_updown(self, target):
+ if self.w1 is not None:
+ w1 = self.w1.to(target.device, dtype=target.dtype)
+ else:
+ w1a = self.w1a.to(target.device, dtype=target.dtype)
+ w1b = self.w1b.to(target.device, dtype=target.dtype)
+ w1 = w1a @ w1b
+ if self.w2 is not None:
+ w2 = self.w2.to(target.device, dtype=target.dtype)
+ elif self.t2 is None:
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ w2 = w2a @ w2b
+ else:
+ t2 = self.t2.to(target.device, dtype=target.dtype)
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+ output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
+ if len(target.shape) == 4:
+ output_shape = target.shape
+ updown = make_kron(output_shape, w1, w2)
+ return self.finalize_updown(updown, target, output_shape)
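Reviewer note: LoKr expresses the delta as a Kronecker product of two small factors, each of which may itself be a low-rank a @ b pair. make_kron only adds extra handling for conv weights; for a Linear layer the shape math reduces to the sketch below (illustrative sizes).

import torch

w1 = torch.randn(4, 8)          # lokr_w1 (or lokr_w1_a @ lokr_w1_b)
w2 = torch.randn(16, 32)        # lokr_w2 (or rebuilt from its own factors)

output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]   # [64, 256]
updown = torch.kron(w1, w2).reshape(output_shape)                   # what make_kron does for 2D weights
assert list(updown.shape) == output_shape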
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 8c2c4c8a5..8f31388d4 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -1,75 +1,75 @@
-import torch
-
-import diffusers.models.lora as diffusers_lora
-import lyco_helpers
-import network
-from modules import devices
-
-
-class ModuleTypeLora(network.ModuleType):
- def create_module(self, net: network.Network, weights: network.NetworkWeights):
- if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
- return NetworkModuleLora(net, weights)
- return None
-
-
-class NetworkModuleLora(network.NetworkModule):
- def __init__(self, net: network.Network, weights: network.NetworkWeights):
- super().__init__(net, weights)
- self.up_model = self.create_module(weights.w, "lora_up.weight")
- self.down_model = self.create_module(weights.w, "lora_down.weight")
- self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
- self.dim = weights.w["lora_down.weight"].shape[0]
-
- def create_module(self, weights, key, none_ok=False):
- weight = weights.get(key)
- if weight is None and none_ok:
- return None
- is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
- is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv]
- if is_linear:
- weight = weight.reshape(weight.shape[0], -1)
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif is_conv and key == "lora_down.weight" or key == "dyn_up":
- if len(weight.shape) == 2:
- weight = weight.reshape(weight.shape[0], -1, 1, 1)
- if weight.shape[2] != 1 or weight.shape[3] != 1:
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
- else:
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- elif is_conv and key == "lora_mid.weight":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
- elif is_conv and key == "lora_up.weight" or key == "dyn_down":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- else:
- raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
- with torch.no_grad():
- if weight.shape != module.weight.shape:
- weight = weight.reshape(module.weight.shape)
- module.weight.copy_(weight)
- module.to(device=devices.cpu, dtype=devices.dtype)
- module.weight.requires_grad_(False)
- return module
-
- def calc_updown(self, target): # pylint: disable=W0237
- up = self.up_model.weight.to(target.device, dtype=target.dtype)
- down = self.down_model.weight.to(target.device, dtype=target.dtype)
- output_shape = [up.size(0), down.size(1)]
- if self.mid_model is not None:
- # cp-decomposition
- mid = self.mid_model.weight.to(target.device, dtype=target.dtype)
- updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
- output_shape += mid.shape[2:]
- else:
- if len(down.shape) == 4:
- output_shape += down.shape[2:]
- updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
- return self.finalize_updown(updown, target, output_shape)
-
- def forward(self, x, y):
- self.up_model.to(device=devices.device)
- self.down_model.to(device=devices.device)
- if hasattr(y, "scale"):
- return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
-
- return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
+import torch
+
+import diffusers.models.lora as diffusers_lora
+import lyco_helpers
+import network
+from modules import devices
+
+
+class ModuleTypeLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
+ return NetworkModuleLora(net, weights)
+ return None
+
+
+class NetworkModuleLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.up_model = self.create_module(weights.w, "lora_up.weight")
+ self.down_model = self.create_module(weights.w, "lora_down.weight")
+ self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
+ self.dim = weights.w["lora_down.weight"].shape[0]
+
+ def create_module(self, weights, key, none_ok=False):
+ weight = weights.get(key)
+ if weight is None and none_ok:
+ return None
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv]
+ if is_linear:
+ weight = weight.reshape(weight.shape[0], -1)
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif is_conv and key == "lora_down.weight" or key == "dyn_up":
+ if len(weight.shape) == 2:
+ weight = weight.reshape(weight.shape[0], -1, 1, 1)
+ if weight.shape[2] != 1 or weight.shape[3] != 1:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ else:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif is_conv and key == "lora_mid.weight":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ elif is_conv and key == "lora_up.weight" or key == "dyn_down":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ else:
+ raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
+ with torch.no_grad():
+ if weight.shape != module.weight.shape:
+ weight = weight.reshape(module.weight.shape)
+ module.weight.copy_(weight)
+ module.to(device=devices.cpu, dtype=devices.dtype)
+ module.weight.requires_grad_(False)
+ return module
+
+ def calc_updown(self, target): # pylint: disable=W0237
+ up = self.up_model.weight.to(target.device, dtype=target.dtype)
+ down = self.down_model.weight.to(target.device, dtype=target.dtype)
+ output_shape = [up.size(0), down.size(1)]
+ if self.mid_model is not None:
+ # cp-decomposition
+ mid = self.mid_model.weight.to(target.device, dtype=target.dtype)
+ updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
+ output_shape += mid.shape[2:]
+ else:
+ if len(down.shape) == 4:
+ output_shape += down.shape[2:]
+ updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
+ return self.finalize_updown(updown, target, output_shape)
+
+ def forward(self, x, y):
+ self.up_model.to(device=devices.device)
+ self.down_model.to(device=devices.device)
+ if hasattr(y, "scale"):
+ return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
+
+ return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
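Reviewer note: classic LoRA rebuilds the delta from the up/down pair, optionally through a mid tensor for CP-decomposed convolutions. For a plain Linear layer, the conventional path plus calc_scale amounts to the sketch below (illustrative shapes; dyn_dim cropping omitted).

import torch

out_f, in_f, rank = 8, 6, 4
up = torch.randn(out_f, rank)     # lora_up.weight
down = torch.randn(rank, in_f)    # lora_down.weight
alpha = 4.0

updown = (up @ down) * (alpha / rank)   # rebuild_conventional combined with calc_scale
weight = torch.randn(out_f, in_f)
weight += updown                        # applied by network_apply_weights at multiplier 1.0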
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 5518e2527..c61bc5355 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,495 +1,495 @@
-from typing import Union, List
-import os
-import re
-import time
-import concurrent
-import lora_patches
-import network
-import network_lora
-import network_hada
-import network_ia3
-import network_oft
-import network_lokr
-import network_full
-import network_norm
-import network_glora
-import lora_convert
-import torch
-import diffusers.models.lora
-from modules import shared, devices, sd_models, sd_models_compile, errors, scripts
-
-
-debug = os.environ.get('SD_LORA_DEBUG', None) is not None
-originals: lora_patches.LoraPatches = None
-extra_network_lora = None
-available_networks = {}
-available_network_aliases = {}
-loaded_networks: List[network.Network] = []
-timer = { 'load': 0, 'apply': 0, 'restore': 0 }
-# networks_in_memory = {}
-lora_cache = {}
-available_network_hash_lookup = {}
-forbidden_network_aliases = {}
-re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
-module_types = [
- network_lora.ModuleTypeLora(),
- network_hada.ModuleTypeHada(),
- network_ia3.ModuleTypeIa3(),
- network_oft.ModuleTypeOFT(),
- network_lokr.ModuleTypeLokr(),
- network_full.ModuleTypeFull(),
- network_norm.ModuleTypeNorm(),
- network_glora.ModuleTypeGLora(),
-]
-convert_diffusers_name_to_compvis = lora_convert.convert_diffusers_name_to_compvis # supermerger compatibility item
-
-
-def assign_network_names_to_compvis_modules(sd_model):
- network_layer_mapping = {}
- if shared.backend == shared.Backend.DIFFUSERS:
- if not hasattr(shared.sd_model, 'text_encoder') or not hasattr(shared.sd_model, 'unet'):
- return
- for name, module in shared.sd_model.text_encoder.named_modules():
- prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
- network_name = prefix + name.replace(".", "_")
- network_layer_mapping[network_name] = module
- module.network_layer_name = network_name
- if shared.sd_model_type == "sdxl":
- for name, module in shared.sd_model.text_encoder_2.named_modules():
- network_name = "lora_te2_" + name.replace(".", "_")
- network_layer_mapping[network_name] = module
- module.network_layer_name = network_name
- for name, module in shared.sd_model.unet.named_modules():
- network_name = "lora_unet_" + name.replace(".", "_")
- network_layer_mapping[network_name] = module
- module.network_layer_name = network_name
- else:
- if not hasattr(shared.sd_model, 'cond_stage_model'):
- return
- for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
- network_name = name.replace(".", "_")
- network_layer_mapping[network_name] = module
- module.network_layer_name = network_name
- for name, module in shared.sd_model.model.named_modules():
- network_name = name.replace(".", "_")
- network_layer_mapping[network_name] = module
- module.network_layer_name = network_name
- sd_model.network_layer_mapping = network_layer_mapping
-
-
-def load_diffusers(name, network_on_disk, lora_scale=1.0) -> network.Network:
- t0 = time.time()
- cached = lora_cache.get(name, None)
- # if debug:
- shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=diffusers {"cached" if cached else ""} fuse={shared.opts.lora_fuse_diffusers}')
- if cached is not None:
- return cached
- if shared.backend != shared.Backend.DIFFUSERS:
- return None
- shared.sd_model.load_lora_weights(network_on_disk.filename)
- if shared.opts.lora_fuse_diffusers:
- shared.sd_model.fuse_lora(lora_scale=lora_scale)
- net = network.Network(name, network_on_disk)
- net.mtime = os.path.getmtime(network_on_disk.filename)
- lora_cache[name] = net
- t1 = time.time()
- timer['load'] += t1 - t0
- return net
-
-
-def load_network(name, network_on_disk) -> network.Network:
- t0 = time.time()
- cached = lora_cache.get(name, None)
- if debug:
- shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}')
- if cached is not None:
- return cached
- net = network.Network(name, network_on_disk)
- net.mtime = os.path.getmtime(network_on_disk.filename)
- sd = sd_models.read_state_dict(network_on_disk.filename)
- assign_network_names_to_compvis_modules(shared.sd_model) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
- keys_failed_to_match = {}
- matched_networks = {}
- convert = lora_convert.KeyConvert()
- for key_network, weight in sd.items():
- parts = key_network.split('.')
- if len(parts) > 5: # messy handler for diffusers peft lora
- key_network_without_network_parts = '_'.join(parts[:-2])
- if not key_network_without_network_parts.startswith('lora_'):
- key_network_without_network_parts = 'lora_' + key_network_without_network_parts
- network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
- else:
- key_network_without_network_parts, network_part = key_network.split(".", 1)
- # if debug:
- # shared.log.debug(f'LoRA load: name="{name}" full={key_network} network={network_part} key={key_network_without_network_parts}')
- key, sd_module = convert(key_network_without_network_parts)
- if sd_module is None:
- keys_failed_to_match[key_network] = key
- continue
- if key not in matched_networks:
- matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
- matched_networks[key].w[network_part] = weight
- for key, weights in matched_networks.items():
- net_module = None
- for nettype in module_types:
- net_module = nettype.create_module(net, weights)
- if net_module is not None:
- break
- if net_module is None:
- shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
- else:
- net.modules[key] = net_module
- if len(keys_failed_to_match) > 0:
- shared.log.warning(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}")
- if debug:
- shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={keys_failed_to_match}")
- elif debug:
- shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}")
- lora_cache[name] = net
- t1 = time.time()
- timer['load'] += t1 - t0
- return net
-
-
-def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
- networks_on_disk = [available_network_aliases.get(name, None) for name in names]
- if any(x is None for x in networks_on_disk):
- list_available_networks()
- networks_on_disk = [available_network_aliases.get(name, None) for name in names]
- failed_to_load_networks = []
-
- recompile_model = False
- if ((shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx") or
- shared.opts.nncf_compress_weights or shared.opts.nncf_compress_text_encoder_weights):
- if len(names) == len(shared.compiled_model_state.lora_model):
- for i, name in enumerate(names):
- if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else 1.0}":
- recompile_model = True
- shared.compiled_model_state.lora_model = []
- break
- if not recompile_model:
- if len(loaded_networks) > 0 and debug:
- shared.log.debug('OpenVINO: Skipping LoRa loading')
- return
- else:
- recompile_model = True
- shared.compiled_model_state.lora_model = []
- if recompile_model:
- backup_cuda_compile = shared.opts.cuda_compile
- backup_nncf_compress_weights = shared.opts.nncf_compress_weights
- backup_nncf_compress_text_encoder_weights = shared.opts.nncf_compress_text_encoder_weights
- shared.compiled_model_state.lora_compile = True
- sd_models.unload_model_weights(op='model')
- shared.opts.cuda_compile = False
- shared.opts.nncf_compress_weights = False
- shared.opts.nncf_compress_text_encoder_weights = False
- sd_models.reload_model_weights(op='model')
- shared.opts.cuda_compile = backup_cuda_compile
- shared.opts.nncf_compress_weights = backup_nncf_compress_weights
- shared.opts.nncf_compress_text_encoder_weights = backup_nncf_compress_text_encoder_weights
-
- loaded_networks.clear()
- for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
- net = None
- if network_on_disk is not None:
- if debug:
- shared.log.debug(f'LoRA load start: name="{name}" file="{network_on_disk.filename}"')
- try:
- if recompile_model:
- shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else 1.0}")
- if shared.backend == shared.Backend.DIFFUSERS and shared.opts.lora_force_diffusers: # OpenVINO only works with Diffusers LoRa loading.
- # or getattr(network_on_disk, 'shorthash', '').lower() == 'aaebf6360f7d' # sd15-lcm
- # or getattr(network_on_disk, 'shorthash', '').lower() == '3d18b05e4f56' # sdxl-lcm
- # or getattr(network_on_disk, 'shorthash', '').lower() == '813ea5fb1c67' # turbo sdxl-turbo
- net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else 1.0)
- else:
- net = load_network(name, network_on_disk)
- except Exception as e:
- shared.log.error(f"LoRA load failed: file={network_on_disk.filename} {e}")
- if debug:
- errors.display(e, f"LoRA load failed file={network_on_disk.filename}")
- continue
- net.mentioned_name = name
- network_on_disk.read_hash()
- if net is None:
- failed_to_load_networks.append(name)
- shared.log.error(f"LoRA unknown type: network={name}")
- continue
- net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
- net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
- net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
- loaded_networks.append(net)
-
- while len(lora_cache) > shared.opts.lora_in_memory_limit:
- name = next(iter(lora_cache))
- lora_cache.pop(name, None)
- if len(loaded_networks) > 0 and debug:
- shared.log.debug(f'LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}')
- devices.torch_gc()
-
- if recompile_model:
- shared.log.info("LoRA recompiling model")
- sd_models_compile.compile_diffusers(shared.sd_model)
-
-
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
- t0 = time.time()
- weights_backup = getattr(self, "network_weights_backup", None)
- bias_backup = getattr(self, "network_bias_backup", None)
- if weights_backup is None and bias_backup is None:
- return
- # if debug:
- # shared.log.debug('LoRA restore weights')
- if weights_backup is not None:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
- else:
- self.weight.copy_(weights_backup)
- if bias_backup is not None:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.out_proj.bias.copy_(bias_backup)
- else:
- self.bias.copy_(bias_backup)
- else:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.out_proj.bias = None
- else:
- self.bias = None
- t1 = time.time()
- timer['restore'] += t1 - t0
-
-
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
- """
- Applies the currently selected set of networks to the weights of torch layer self.
- If weights already have this particular set of networks applied, does nothing.
- If not, restores orginal weights from backup and alters weights according to networks.
- """
- network_layer_name = getattr(self, 'network_layer_name', None)
- if network_layer_name is None:
- return
- t0 = time.time()
- current_names = getattr(self, "network_current_names", ())
- wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
- weights_backup = getattr(self, "network_weights_backup", None)
- if weights_backup is None and wanted_names != (): # pylint: disable=C1803
- if current_names != ():
- raise RuntimeError("no backup weights found and current weights are not unchanged")
- if isinstance(self, torch.nn.MultiheadAttention):
- weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
- else:
- weights_backup = self.weight.to(devices.cpu, copy=True)
- self.network_weights_backup = weights_backup
- bias_backup = getattr(self, "network_bias_backup", None)
- if bias_backup is None:
- if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
- bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
- elif getattr(self, 'bias', None) is not None:
- bias_backup = self.bias.to(devices.cpu, copy=True)
- else:
- bias_backup = None
- self.network_bias_backup = bias_backup
-
- if current_names != wanted_names:
- network_restore_weights_from_backup(self)
- for net in loaded_networks:
- # default workflow where module is known and has weights
- module = net.modules.get(network_layer_name, None)
- if module is not None and hasattr(self, 'weight'):
- try:
- with devices.inference_context():
- updown, ex_bias = module.calc_updown(self.weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
- # inpainting model. zero pad updown to make channel[1] 4 to 9
- updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
- self.weight += updown
- if ex_bias is not None and hasattr(self, 'bias'):
- if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
- else:
- self.bias += ex_bias
- except RuntimeError as e:
- extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- if debug:
- module_name = net.modules.get(network_layer_name, None)
- shared.log.error(f"LoRA apply weight name={net.name} module={module_name} layer={network_layer_name} {e}")
- errors.display(e, 'LoRA apply weight')
- raise RuntimeError('LoRA apply weight') from e
- continue
- # alternative workflow looking at _*_proj layers
- module_q = net.modules.get(network_layer_name + "_q_proj", None)
- module_k = net.modules.get(network_layer_name + "_k_proj", None)
- module_v = net.modules.get(network_layer_name + "_v_proj", None)
- module_out = net.modules.get(network_layer_name + "_out_proj", None)
- if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- try:
- with devices.inference_context():
- updown_q, _ = module_q.calc_updown(self.in_proj_weight)
- updown_k, _ = module_k.calc_updown(self.in_proj_weight)
- updown_v, _ = module_v.calc_updown(self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += updown_out
- if ex_bias is not None:
- if self.out_proj.bias is None:
- self.out_proj.bias = torch.nn.Parameter(ex_bias)
- else:
- self.out_proj.bias += ex_bias
- except RuntimeError as e:
- if debug:
- shared.log.debug(f"LoRA network={net.name} layer={network_layer_name} {e}")
- extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- continue
- if module is None:
- continue
- shared.log.warning(f"LoRA network={net.name} layer={network_layer_name} unsupported operation")
- extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- self.network_current_names = wanted_names
- t1 = time.time()
- timer['apply'] += t1 - t0
-
-
-def network_forward(module, input, original_forward): # pylint: disable=W0622
- """
- Old way of applying Lora by executing operations during layer's forward.
- Stacking many loras this way results in big performance degradation.
- """
- if len(loaded_networks) == 0:
- return original_forward(module, input)
- input = devices.cond_cast_unet(input)
- network_restore_weights_from_backup(module)
- network_reset_cached_weight(module)
- y = original_forward(module, input)
- network_layer_name = getattr(module, 'network_layer_name', None)
- for lora in loaded_networks:
- module = lora.modules.get(network_layer_name, None)
- if module is None:
- continue
- y = module.forward(input, y)
- return y
-
-
-def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- self.network_current_names = ()
- self.network_weights_backup = None
-
-
-def network_Linear_forward(self, input): # pylint: disable=W0622
- if shared.opts.lora_functional:
- return network_forward(self, input, originals.Linear_forward)
- network_apply_weights(self)
- return originals.Linear_forward(self, input)
-
-
-def network_Linear_load_state_dict(self, *args, **kwargs):
- network_reset_cached_weight(self)
- return originals.Linear_load_state_dict(self, *args, **kwargs)
-
-
-def network_Conv2d_forward(self, input): # pylint: disable=W0622
- if shared.opts.lora_functional:
- return network_forward(self, input, originals.Conv2d_forward)
- network_apply_weights(self)
- return originals.Conv2d_forward(self, input)
-
-
-def network_Conv2d_load_state_dict(self, *args, **kwargs):
- network_reset_cached_weight(self)
- return originals.Conv2d_load_state_dict(self, *args, **kwargs)
-
-
-def network_GroupNorm_forward(self, input): # pylint: disable=W0622
- if shared.opts.lora_functional:
- return network_forward(self, input, originals.GroupNorm_forward)
- network_apply_weights(self)
- return originals.GroupNorm_forward(self, input)
-
-
-def network_GroupNorm_load_state_dict(self, *args, **kwargs):
- network_reset_cached_weight(self)
- return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
-
-
-def network_LayerNorm_forward(self, input): # pylint: disable=W0622
- if shared.opts.lora_functional:
- return network_forward(self, input, originals.LayerNorm_forward)
- network_apply_weights(self)
- return originals.LayerNorm_forward(self, input)
-
-
-def network_LayerNorm_load_state_dict(self, *args, **kwargs):
- network_reset_cached_weight(self)
- return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
-
-
-def network_MultiheadAttention_forward(self, *args, **kwargs):
- network_apply_weights(self)
- return originals.MultiheadAttention_forward(self, *args, **kwargs)
-
-
-def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
- network_reset_cached_weight(self)
- return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
-
-
-def list_available_networks():
- available_networks.clear()
- available_network_aliases.clear()
- forbidden_network_aliases.clear()
- available_network_hash_lookup.clear()
- forbidden_network_aliases.update({"none": 1, "Addams": 1})
- os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
- candidates = []
- if os.path.exists(shared.cmd_opts.lora_dir):
- candidates += list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
- else:
- shared.log.warning('LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')
- if os.path.exists(shared.cmd_opts.lyco_dir):
- candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
-
- def add_network(filename):
- if os.path.isdir(filename):
- return
- name = os.path.splitext(os.path.basename(filename))[0]
- try:
- entry = network.NetworkOnDisk(name, filename)
- available_networks[entry.name] = entry
- if entry.alias in available_network_aliases:
- forbidden_network_aliases[entry.alias.lower()] = 1
- available_network_aliases[entry.name] = entry
- available_network_aliases[entry.alias] = entry
- if entry.shorthash:
- available_network_hash_lookup[entry.shorthash] = entry
- except OSError as e: # should catch FileNotFoundError and PermissionError etc.
- shared.log.error(f"Failed to load network {name} from {filename} {e}")
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
- for fn in candidates:
- executor.submit(add_network, fn)
-
-
-def infotext_pasted(infotext, params): # pylint: disable=W0613
- if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
- return # if the other extension is active, it will handle those fields, no need to do anything
- added = []
- for k in params:
- if not k.startswith("AddNet Model "):
- continue
- num = k[13:]
- if params.get("AddNet Module " + num) != "LoRA":
- continue
- name = params.get("AddNet Model " + num)
- if name is None:
- continue
- m = re_network_name.match(name)
- if m:
- name = m.group(1)
- multiplier = params.get("AddNet Weight A " + num, "1.0")
-        added.append(f"<lora:{name}:{multiplier}>")
- if added:
- params["Prompt"] += "\n" + "".join(added)
-
-
-list_available_networks()
+from typing import Union, List
+import os
+import re
+import time
+import concurrent
+import lora_patches
+import network
+import network_lora
+import network_hada
+import network_ia3
+import network_oft
+import network_lokr
+import network_full
+import network_norm
+import network_glora
+import lora_convert
+import torch
+import diffusers.models.lora
+from modules import shared, devices, sd_models, sd_models_compile, errors, scripts
+
+
+debug = os.environ.get('SD_LORA_DEBUG', None) is not None
+originals: lora_patches.LoraPatches = None
+extra_network_lora = None
+available_networks = {}
+available_network_aliases = {}
+loaded_networks: List[network.Network] = []
+timer = { 'load': 0, 'apply': 0, 'restore': 0 }
+# networks_in_memory = {}
+lora_cache = {}
+available_network_hash_lookup = {}
+forbidden_network_aliases = {}
+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+module_types = [
+ network_lora.ModuleTypeLora(),
+ network_hada.ModuleTypeHada(),
+ network_ia3.ModuleTypeIa3(),
+ network_oft.ModuleTypeOFT(),
+ network_lokr.ModuleTypeLokr(),
+ network_full.ModuleTypeFull(),
+ network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
+]
+convert_diffusers_name_to_compvis = lora_convert.convert_diffusers_name_to_compvis # supermerger compatibility item
+
+
+def assign_network_names_to_compvis_modules(sd_model):
+ network_layer_mapping = {}
+ if shared.backend == shared.Backend.DIFFUSERS:
+ if not hasattr(shared.sd_model, 'text_encoder') or not hasattr(shared.sd_model, 'unet'):
+ return
+ for name, module in shared.sd_model.text_encoder.named_modules():
+ prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
+ network_name = prefix + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ if shared.sd_model_type == "sdxl":
+ for name, module in shared.sd_model.text_encoder_2.named_modules():
+ network_name = "lora_te2_" + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ for name, module in shared.sd_model.unet.named_modules():
+ network_name = "lora_unet_" + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ else:
+ if not hasattr(shared.sd_model, 'cond_stage_model'):
+ return
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ for name, module in shared.sd_model.model.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ sd_model.network_layer_mapping = network_layer_mapping
+
+
+def load_diffusers(name, network_on_disk, lora_scale=1.0) -> network.Network:
+ t0 = time.time()
+ cached = lora_cache.get(name, None)
+ # if debug:
+ shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=diffusers {"cached" if cached else ""} fuse={shared.opts.lora_fuse_diffusers}')
+ if cached is not None:
+ return cached
+ if shared.backend != shared.Backend.DIFFUSERS:
+ return None
+ shared.sd_model.load_lora_weights(network_on_disk.filename)
+ if shared.opts.lora_fuse_diffusers:
+ shared.sd_model.fuse_lora(lora_scale=lora_scale)
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+ lora_cache[name] = net
+ t1 = time.time()
+ timer['load'] += t1 - t0
+ return net
+
+
+def load_network(name, network_on_disk) -> network.Network:
+ t0 = time.time()
+ cached = lora_cache.get(name, None)
+ if debug:
+ shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}')
+ if cached is not None:
+ return cached
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+ sd = sd_models.read_state_dict(network_on_disk.filename)
+ assign_network_names_to_compvis_modules(shared.sd_model) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
+ keys_failed_to_match = {}
+ matched_networks = {}
+ convert = lora_convert.KeyConvert()
+ for key_network, weight in sd.items():
+ parts = key_network.split('.')
+ if len(parts) > 5: # messy handler for diffusers peft lora
+ key_network_without_network_parts = '_'.join(parts[:-2])
+ if not key_network_without_network_parts.startswith('lora_'):
+ key_network_without_network_parts = 'lora_' + key_network_without_network_parts
+ network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
+ else:
+ key_network_without_network_parts, network_part = key_network.split(".", 1)
+ # if debug:
+ # shared.log.debug(f'LoRA load: name="{name}" full={key_network} network={network_part} key={key_network_without_network_parts}')
+ key, sd_module = convert(key_network_without_network_parts)
+ if sd_module is None:
+ keys_failed_to_match[key_network] = key
+ continue
+ if key not in matched_networks:
+ matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+ matched_networks[key].w[network_part] = weight
+ for key, weights in matched_networks.items():
+ net_module = None
+ for nettype in module_types:
+ net_module = nettype.create_module(net, weights)
+ if net_module is not None:
+ break
+ if net_module is None:
+ shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
+ else:
+ net.modules[key] = net_module
+ if len(keys_failed_to_match) > 0:
+ shared.log.warning(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}")
+ if debug:
+ shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={keys_failed_to_match}")
+ elif debug:
+ shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}")
+ lora_cache[name] = net
+ t1 = time.time()
+ timer['load'] += t1 - t0
+ return net
+
+
+def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ if any(x is None for x in networks_on_disk):
+ list_available_networks()
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ failed_to_load_networks = []
+
+ recompile_model = False
+ if ((shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx") or
+ shared.opts.nncf_compress_weights or shared.opts.nncf_compress_text_encoder_weights):
+ if len(names) == len(shared.compiled_model_state.lora_model):
+ for i, name in enumerate(names):
+ if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else 1.0}":
+ recompile_model = True
+ shared.compiled_model_state.lora_model = []
+ break
+ if not recompile_model:
+ if len(loaded_networks) > 0 and debug:
+ shared.log.debug('OpenVINO: Skipping LoRa loading')
+ return
+ else:
+ recompile_model = True
+ shared.compiled_model_state.lora_model = []
+ if recompile_model:
+ backup_cuda_compile = shared.opts.cuda_compile
+ backup_nncf_compress_weights = shared.opts.nncf_compress_weights
+ backup_nncf_compress_text_encoder_weights = shared.opts.nncf_compress_text_encoder_weights
+ shared.compiled_model_state.lora_compile = True
+ sd_models.unload_model_weights(op='model')
+ shared.opts.cuda_compile = False
+ shared.opts.nncf_compress_weights = False
+ shared.opts.nncf_compress_text_encoder_weights = False
+ sd_models.reload_model_weights(op='model')
+ shared.opts.cuda_compile = backup_cuda_compile
+ shared.opts.nncf_compress_weights = backup_nncf_compress_weights
+ shared.opts.nncf_compress_text_encoder_weights = backup_nncf_compress_text_encoder_weights
+
+ loaded_networks.clear()
+ for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
+ net = None
+ if network_on_disk is not None:
+ if debug:
+ shared.log.debug(f'LoRA load start: name="{name}" file="{network_on_disk.filename}"')
+ try:
+ if recompile_model:
+ shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else 1.0}")
+ if shared.backend == shared.Backend.DIFFUSERS and shared.opts.lora_force_diffusers: # OpenVINO only works with Diffusers LoRa loading.
+ # or getattr(network_on_disk, 'shorthash', '').lower() == 'aaebf6360f7d' # sd15-lcm
+ # or getattr(network_on_disk, 'shorthash', '').lower() == '3d18b05e4f56' # sdxl-lcm
+ # or getattr(network_on_disk, 'shorthash', '').lower() == '813ea5fb1c67' # turbo sdxl-turbo
+ net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else 1.0)
+ else:
+ net = load_network(name, network_on_disk)
+ except Exception as e:
+ shared.log.error(f"LoRA load failed: file={network_on_disk.filename} {e}")
+ if debug:
+ errors.display(e, f"LoRA load failed file={network_on_disk.filename}")
+ continue
+ net.mentioned_name = name
+ network_on_disk.read_hash()
+ if net is None:
+ failed_to_load_networks.append(name)
+ shared.log.error(f"LoRA unknown type: network={name}")
+ continue
+ net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+ net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+ net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+ loaded_networks.append(net)
+
+ while len(lora_cache) > shared.opts.lora_in_memory_limit:
+ name = next(iter(lora_cache))
+ lora_cache.pop(name, None)
+ if len(loaded_networks) > 0 and debug:
+ shared.log.debug(f'LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}')
+ devices.torch_gc()
+
+ if recompile_model:
+ shared.log.info("LoRA recompiling model")
+ sd_models_compile.compile_diffusers(shared.sd_model)
+
+
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
+ t0 = time.time()
+ weights_backup = getattr(self, "network_weights_backup", None)
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if weights_backup is None and bias_backup is None:
+ return
+ # if debug:
+ # shared.log.debug('LoRA restore weights')
+ if weights_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+ if bias_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
+ else:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
+ t1 = time.time()
+ timer['restore'] += t1 - t0
+
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv]):
+ """
+ Applies the currently selected set of networks to the weights of torch layer self.
+ If weights already have this particular set of networks applied, does nothing.
+ If not, restores original weights from backup and alters weights according to networks.
+ """
+ network_layer_name = getattr(self, 'network_layer_name', None)
+ if network_layer_name is None:
+ return
+ t0 = time.time()
+ current_names = getattr(self, "network_current_names", ())
+ wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
+ weights_backup = getattr(self, "network_weights_backup", None)
+ if weights_backup is None and wanted_names != (): # pylint: disable=C1803
+ if current_names != ():
+ raise RuntimeError("no backup weights found and current weights are not unchanged")
+ if isinstance(self, torch.nn.MultiheadAttention):
+ weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+ else:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+ self.network_weights_backup = weights_backup
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if bias_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
+ self.network_bias_backup = bias_backup
+
+ if current_names != wanted_names:
+ network_restore_weights_from_backup(self)
+ for net in loaded_networks:
+ # default workflow where module is known and has weights
+ module = net.modules.get(network_layer_name, None)
+ if module is not None and hasattr(self, 'weight'):
+ try:
+ with devices.inference_context():
+ updown, ex_bias = module.calc_updown(self.weight)
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
+ self.weight += updown
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
+ except RuntimeError as e:
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+ if debug:
+ module_name = net.modules.get(network_layer_name, None)
+ shared.log.error(f"LoRA apply weight name={net.name} module={module_name} layer={network_layer_name} {e}")
+ errors.display(e, 'LoRA apply weight')
+ raise RuntimeError('LoRA apply weight') from e
+ continue
+ # alternative workflow looking at _*_proj layers
+ module_q = net.modules.get(network_layer_name + "_q_proj", None)
+ module_k = net.modules.get(network_layer_name + "_k_proj", None)
+ module_v = net.modules.get(network_layer_name + "_v_proj", None)
+ module_out = net.modules.get(network_layer_name + "_out_proj", None)
+ if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+ try:
+ with devices.inference_context():
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
+ except RuntimeError as e:
+ if debug:
+ shared.log.debug(f"LoRA network={net.name} layer={network_layer_name} {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+ continue
+ if module is None:
+ continue
+ shared.log.warning(f"LoRA network={net.name} layer={network_layer_name} unsupported operation")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+ self.network_current_names = wanted_names
+ t1 = time.time()
+ timer['apply'] += t1 - t0
+
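
The per-module `calc_updown` implementations live elsewhere in the extension, but for a plain rank-decomposed LoRA the delta is essentially `up @ down` scaled by `alpha / rank` and the user multiplier. A hedged sketch of that merge (shapes and names are illustrative, not the extension's API):

```python
import torch

n_out, n_in, rank, alpha, multiplier = 320, 320, 4, 4.0, 1.0

weight = torch.nn.Parameter(torch.randn(n_out, n_in), requires_grad=False)
lora_down = torch.randn(rank, n_in) * 0.01  # "down" projection
lora_up = torch.randn(n_out, rank) * 0.01   # "up" projection

with torch.no_grad():
    updown = (lora_up @ lora_down) * (alpha / rank) * multiplier
    weight += updown  # merged in place, as self.weight += updown above
```
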
+
+def network_forward(module, input, original_forward): # pylint: disable=W0622
+ """
+ Legacy way of applying a LoRA by executing extra operations during the layer's forward pass.
+ Stacking many LoRAs this way causes significant performance degradation.
+ """
+ if len(loaded_networks) == 0:
+ return original_forward(module, input)
+ input = devices.cond_cast_unet(input)
+ network_restore_weights_from_backup(module)
+ network_reset_cached_weight(module)
+ y = original_forward(module, input)
+ network_layer_name = getattr(module, 'network_layer_name', None)
+ for lora in loaded_networks:
+ module = lora.modules.get(network_layer_name, None)
+ if module is None:
+ continue
+ y = module.forward(input, y)
+ return y
+
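
The docstring's point can be seen directly: with merged weights the cost is paid once at load time, while the functional path adds a low-rank matmul to every forward call of every layer, for every active LoRA. A small illustration with plain tensors (not the extension's module types):

```python
import torch

x = torch.randn(1, 64)
weight = torch.randn(128, 64)
down, up = torch.randn(8, 64), torch.randn(128, 8)

# Merged path: pay the cost once, every forward stays a single matmul.
merged = weight + up @ down
y_merged = x @ merged.t()

# Functional path: every forward re-applies each LoRA on top of the base output.
y_functional = x @ weight.t() + (x @ down.t()) @ up.t()

assert torch.allclose(y_merged, y_functional, atol=1e-5)
```
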
+
+def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
+ self.network_current_names = ()
+ self.network_weights_backup = None
+
+
+def network_Linear_forward(self, input): # pylint: disable=W0622
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.Linear_forward)
+ network_apply_weights(self)
+ return originals.Linear_forward(self, input)
+
+
+def network_Linear_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+ return originals.Linear_load_state_dict(self, *args, **kwargs)
+
+
+def network_Conv2d_forward(self, input): # pylint: disable=W0622
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.Conv2d_forward)
+ network_apply_weights(self)
+ return originals.Conv2d_forward(self, input)
+
+
+def network_Conv2d_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+ return originals.Conv2d_load_state_dict(self, *args, **kwargs)
+
+
+def network_GroupNorm_forward(self, input): # pylint: disable=W0622
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.GroupNorm_forward)
+ network_apply_weights(self)
+ return originals.GroupNorm_forward(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+ return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input): # pylint: disable=W0622
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.LayerNorm_forward)
+ network_apply_weights(self)
+ return originals.LayerNorm_forward(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+ return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_forward(self, *args, **kwargs):
+ network_apply_weights(self)
+ return originals.MultiheadAttention_forward(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+ return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
+
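
These wrappers assume that `originals.*` holds references to the stock torch implementations captured before patching; the actual hook-up lives in the extension's script file, which is not part of this hunk. The general pattern, sketched on `torch.nn.Linear` alone:

```python
import torch

# Capture the stock implementation before replacing it, as the wrappers above expect.
original_linear_forward = torch.nn.Linear.forward

def patched_linear_forward(self, x):
    # network_apply_weights(self) would run here in the real extension.
    return original_linear_forward(self, x)

torch.nn.Linear.forward = patched_linear_forward
try:
    out = torch.nn.Linear(4, 2)(torch.randn(1, 4))  # routed through the wrapper
finally:
    torch.nn.Linear.forward = original_linear_forward  # undo the patch
```
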
+
+def list_available_networks():
+ available_networks.clear()
+ available_network_aliases.clear()
+ forbidden_network_aliases.clear()
+ available_network_hash_lookup.clear()
+ forbidden_network_aliases.update({"none": 1, "Addams": 1})
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+ candidates = []
+ if os.path.exists(shared.cmd_opts.lora_dir):
+ candidates += list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
+ else:
+ shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')
+ if os.path.exists(shared.cmd_opts.lyco_dir):
+ candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
+
+ def add_network(filename):
+ if os.path.isdir(filename):
+ return
+ name = os.path.splitext(os.path.basename(filename))[0]
+ try:
+ entry = network.NetworkOnDisk(name, filename)
+ available_networks[entry.name] = entry
+ if entry.alias in available_network_aliases:
+ forbidden_network_aliases[entry.alias.lower()] = 1
+ available_network_aliases[entry.name] = entry
+ available_network_aliases[entry.alias] = entry
+ if entry.shorthash:
+ available_network_hash_lookup[entry.shorthash] = entry
+ except OSError as e: # should catch FileNotFoundError and PermissionError etc.
+ shared.log.error(f"Failed to load network {name} from {filename} {e}")
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ for fn in candidates:
+ executor.submit(add_network, fn)
+
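
One caveat with the scan above: the futures returned by `executor.submit` are never consumed, so any exception other than the `OSError` caught inside `add_network` disappears silently. If that ever needs debugging, collecting the futures surfaces worker errors; a small generic sketch, not a change to the extension:

```python
import concurrent.futures

def scan(filename):  # stand-in for add_network
    raise ValueError(filename)

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(scan, fn) for fn in ("a.safetensors", "b.safetensors")]
    for future in concurrent.futures.as_completed(futures):
        try:
            future.result()  # re-raises whatever the worker raised
        except Exception as e:  # pylint: disable=broad-except
            print(f"scan failed: {e}")
```
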
+
+def infotext_pasted(infotext, params): # pylint: disable=W0613
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+ added = []
+ for k in params:
+ if not k.startswith("AddNet Model "):
+ continue
+ num = k[13:]
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+ m = re_network_name.match(name)
+ if m:
+ name = m.group(1)
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
+ added.append(f"<lora:{name}:{multiplier}>")
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
+
+
+list_available_networks()
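
As a concrete example of what `infotext_pasted` produces: given pasted parameters naming an AddNet LoRA slot, it appends the corresponding `<lora:...>` tag to the prompt. A self-contained sketch, assuming `re_network_name` strips a trailing parenthesised hash (the real pattern is defined earlier in this file):

```python
import re

re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")  # assumed pattern, for illustration

params = {
    "Prompt": "a photo of a cat",
    "AddNet Module 1": "LoRA",
    "AddNet Model 1": "myStyle(aabbccdd)",
    "AddNet Weight A 1": "0.8",
}

name = params["AddNet Model 1"]
m = re_network_name.match(name)
if m:
    name = m.group(1)
multiplier = params.get("AddNet Weight A 1", "1.0")
params["Prompt"] += "\n" + f"<lora:{name}:{multiplier}>"
print(params["Prompt"])  # a photo of a cat\n<lora:myStyle:0.8>
```
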
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 8130d2b5f..4bda68638 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,62 +1,62 @@
-import re
-import networks
-import lora # noqa:F401 # pylint: disable=unused-import
-from network import NetworkOnDisk
-from ui_extra_networks_lora import ExtraNetworksPageLora
-from extra_networks_lora import ExtraNetworkLora
-from modules import script_callbacks, ui_extra_networks, extra_networks
-
-
-re_lora = re.compile(""),
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
- "mtime": os.path.getmtime(l.filename),
- "size": os.path.getsize(l.filename),
- }
- info = self.find_info(l.filename)
-
- tags = {}
- possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {} # tags from model metedata
- if isinstance(possible_tags, str):
- possible_tags = {}
- for k, v in possible_tags.items():
- words = k.split('_', 1) if '_' in k else [v, k]
- words = [str(w).replace('.json', '') for w in words]
- if words[0] == '{}':
- words[0] = 0
- tag = ' '.join(words[1:])
- tags[tag] = words[0]
- versions = info.get('modelVersions', []) # trigger words from info json
- for v in versions:
- possible_tags = v.get('trainedWords', [])
- if isinstance(possible_tags, list):
- for tag in possible_tags:
- if tag not in tags:
- tags[tag] = 0
- search = {}
- possible_tags = info.get('tags', []) # tags from info json
- if not isinstance(possible_tags, list):
- possible_tags = [v for v in possible_tags.values()]
- for v in possible_tags:
- search[v] = 0
- if len(list(tags)) == 0:
- tags = search
-
- bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"']
- clean_tags = {}
- for k, v in tags.items():
- tag = ''.join(i for i in k if not i in bad_chars)
- clean_tags[tag] = v
-
- item["info"] = info
- item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read
- item["tags"] = clean_tags
- item["search_term"] = f'{self.search_terms_from_path(l.filename)} {" ".join(tags.keys())} {" ".join(search.keys())}'
-
- return item
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=lora file={name} {e}")
- return None
-
- def list_items(self):
- with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
- future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
- for future in concurrent.futures.as_completed(future_items):
- item = future.result()
- if item is not None:
- yield item
-
- def allowed_directories_for_previews(self):
- return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir]
+import os
+import json
+import concurrent
+import network
+import networks
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Lora')
+ self.list_time = 0
+
+ def refresh(self):
+ networks.list_available_networks()
+
+ def create_item(self, name):
+ l = networks.available_networks.get(name)
+ try:
+ path, _ext = os.path.splitext(l.filename)
+ name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
+ if shared.backend == shared.Backend.ORIGINAL:
+ if l.sd_version == network.SdVersion.SDXL:
+ return None
+ elif shared.backend == shared.Backend.DIFFUSERS:
+ if shared.sd_model_type == 'none': # return all when model is not loaded
+ pass
+ elif shared.sd_model_type == 'sdxl':
+ if l.sd_version == network.SdVersion.SD1 or l.sd_version == network.SdVersion.SD2:
+ return None
+ elif shared.sd_model_type == 'sd':
+ if l.sd_version == network.SdVersion.SDXL:
+ return None
+
+ item = {
+ "type": 'Lora',
+ "name": name,
+ "filename": l.filename,
+ "hash": l.shorthash,
+ "preview": self.find_preview(l.filename),
+ "prompt": json.dumps(f" <lora:{l.get_alias()}:{shared.opts.extra_networks_default_multiplier}>"),
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
+ "mtime": os.path.getmtime(l.filename),
+ "size": os.path.getsize(l.filename),
+ }
+ info = self.find_info(l.filename)
+
+ tags = {}
+ possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {} # tags from model metadata
+ if isinstance(possible_tags, str):
+ possible_tags = {}
+ for k, v in possible_tags.items():
+ words = k.split('_', 1) if '_' in k else [v, k]
+ words = [str(w).replace('.json', '') for w in words]
+ if words[0] == '{}':
+ words[0] = 0
+ tag = ' '.join(words[1:])
+ tags[tag] = words[0]
+ versions = info.get('modelVersions', []) # trigger words from info json
+ for v in versions:
+ possible_tags = v.get('trainedWords', [])
+ if isinstance(possible_tags, list):
+ for tag in possible_tags:
+ if tag not in tags:
+ tags[tag] = 0
+ search = {}
+ possible_tags = info.get('tags', []) # tags from info json
+ if not isinstance(possible_tags, list):
+ possible_tags = [v for v in possible_tags.values()]
+ for v in possible_tags:
+ search[v] = 0
+ if len(list(tags)) == 0:
+ tags = search
+
+ bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"']
+ clean_tags = {}
+ for k, v in tags.items():
+ tag = ''.join(i for i in k if i not in bad_chars)
+ clean_tags[tag] = v
+
+ item["info"] = info
+ item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read
+ item["tags"] = clean_tags
+ item["search_term"] = f'{self.search_terms_from_path(l.filename)} {" ".join(tags.keys())} {" ".join(search.keys())}'
+
+ return item
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=lora file={name} {e}")
+ return None
+
+ def list_items(self):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
+ for future in concurrent.futures.as_completed(future_items):
+ item = future.result()
+ if item is not None:
+ yield item
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir]
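
To make the tag extraction above concrete: `ss_tag_frequency` metadata written by common LoRA trainers keys entries by dataset-folder names such as `10_mycharacter`, and `create_item` splits off the repeat prefix to use the remainder as the tag. A standalone run of the same logic (the sample payload is hypothetical and simplified):

```python
# Hypothetical ss_tag_frequency payload; the loop mirrors create_item above.
possible_tags = {"10_mycharacter": 42, "regularization": 3}

tags = {}
for k, v in possible_tags.items():
    words = k.split('_', 1) if '_' in k else [v, k]
    words = [str(w).replace('.json', '') for w in words]
    if words[0] == '{}':
        words[0] = 0
    tags[' '.join(words[1:])] = words[0]

print(tags)  # {'mycharacter': '10', 'regularization': '3'}
```
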
diff --git a/extensions-builtin/sd-webui-controlnet b/extensions-builtin/sd-webui-controlnet
index bb9483d46..f8f43a809 160000
--- a/extensions-builtin/sd-webui-controlnet
+++ b/extensions-builtin/sd-webui-controlnet
@@ -1 +1 @@
-Subproject commit bb9483d46f5a932fd35e8b4d04a3fdcc02dd9ff1
+Subproject commit f8f43a809fd0ce0ccb36d1abbe56fae3b8e18b60
diff --git a/html/licenses.html b/html/licenses.html
index 2ad803052..66b9c0bc6 100644
--- a/html/licenses.html
+++ b/html/licenses.html
@@ -1,690 +1,690 @@
-
-
-
-Parts of CodeFormer code had to be copied to be compatible with GFPGAN.
-
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
-
-
-
-
-Code for architecture and reading models copied.
-
-MIT License
-
-Copyright (c) 2021 victorca25
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Some code is copied to support ESRGAN models.
-
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived from
- this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-
-Some code for compatibility with OSX is taken from lstein's repository.
-
-MIT License
-
-Copyright (c) 2022 InvokeAI Team
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Code added by contirubtors, most likely copied from this repository.
-
-MIT License
-
-Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Some small amounts of code borrowed and reworked.
-
-MIT License
-
-Copyright (c) 2022 pharmapsychotic
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Code added by contributors, most likely copied from this repository.
-
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [2021] [SwinIR Authors]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
-
-
-The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.
-
-MIT License
-
-Copyright (c) 2023 Alex Birch
-Copyright (c) 2023 Amin Rezaei
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Some small amounts of code borrowed and reworked.
-
- Copyright 2023 The HuggingFace Team. All rights reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
-
-
-The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers
-
-The MIT License (MIT)
-
-Copyright (C) 2021 ExplosionAI GmbH
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-
-
-
-Tiny AutoEncoder for Stable Diffusion option for live previews
-
-MIT License
-
-Copyright (c) 2023 Ollin Boer Bohan
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
+
+
+
+Parts of CodeFormer code had to be copied to be compatible with GFPGAN.
+
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
+
+
+
+
+Code for architecture and reading models copied.
+
+MIT License
+
+Copyright (c) 2021 victorca25
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+
+Some code is copied to support ESRGAN models.
+
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+Some code for compatibility with OSX is taken from lstein's repository.
+
+MIT License
+
+Copyright (c) 2022 InvokeAI Team
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+
+Code added by contributors, most likely copied from this repository.
+
+MIT License
+
+Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+
+Some small amounts of code borrowed and reworked.
+
+MIT License
+
+Copyright (c) 2022 pharmapsychotic
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+
+Code added by contributors, most likely copied from this repository.
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2021] [SwinIR Authors]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+
+The sub-quadratic cross-attention optimization uses modified code from the Memory Efficient Attention package, which Alex Birch optimized for 3D tensors; the license below reflects both copyright holders.
+
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+
+Some small amounts of code borrowed and reworked.
+
+ Copyright 2023 The HuggingFace Team. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+
+The MPS workaround for nn.Linear on macOS 13.2.X is based on the workaround created by danieldk for Curated Transformers.
+
+The MIT License (MIT)
+
+Copyright (C) 2021 ExplosionAI GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+
+
+Tiny AutoEncoder for Stable Diffusion, used as an option for live previews.
+
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
diff --git a/javascript/emerald-paradise.css b/javascript/emerald-paradise.css
index db96216f1..1fe9a7424 100644
--- a/javascript/emerald-paradise.css
+++ b/javascript/emerald-paradise.css
@@ -1,297 +1,297 @@
-/* generic html tags */
-:root, .light, .dark {
- --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
- --font-mono: 'ui-monospace', 'Consolas', monospace;
- --font-size: 16px;
- --primary-100: #1e2223; /* bg color*/
- --primary-200: #242a2c; /* drop down menu/ prompt window fill*/
- --primary-300: #0a0c0e; /* black */
- --primary-400: #2a302c; /* small buttons*/
- --primary-500: #4b695d; /* main accent color green*/
- --primary-700: #273538; /* extension box fill*/
- --primary-800: #d15e84; /* pink(hover accent)*/
- --highlight-color: var(--primary-500);
- --inactive-color: var(--primary--800);
- --body-text-color: var(--neutral-100);
- --body-text-color-subdued: var(--neutral-300);
- --background-color: var(--primary-100);
- --background-fill-primary: var(--input-background-fill);
- --input-padding: 8px;
- --input-background-fill: var(--primary-200);
- --input-shadow: none;
- --button-secondary-text-color: white;
- --button-secondary-background-fill: var(--primary-400);
- --button-secondary-background-fill-hover: var(--primary-700);
- --block-title-text-color: var(--neutral-300);
- --radius-sm: 1px;
- --radius-lg: 6px;
- --spacing-md: 4px;
- --spacing-xxl: 8px;
- --line-sm: 1.2em;
- --line-md: 1.4em;
-}
-
-html { font-size: var(--font-size); }
-body, button, input, select, textarea { font-family: var(--font);}
-button { max-width: 400px; }
-img { background-color: var(--background-color); }
-input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
-input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
-input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
-input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
-::-webkit-scrollbar-track { background: #333333; }
-::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
-div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
-div.compact { gap: 1em; }
-
-/* gradio style classes */
-fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
-.border-2 { border-width: 0; }
-.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
-.bg-white { color: lightyellow; background-color: var(--inactive-color); }
-.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
-.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
-.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
-.gr-check-radio:checked { background-color: var(--highlight-color); }
-.gr-compact { background-color: var(--background-color); }
-.gr-form { border-width: 0; }
-.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
-.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
-.gr-panel { background-color: var(--background-color); }
-.eta-bar { display: none !important }
-svg.feather.feather-image, .feather .feather-image { display: none }
-.gap-2 { padding-top: 8px; }
-.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
-.output-html { line-height: 1.2rem; overflow-x: hidden; }
-.output-html > div { margin-bottom: 8px; }
-.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
-.p-2 { padding: 0; }
-.px-4 { padding-lefT: 1rem; padding-right: 1rem; }
-.py-6 { padding-bottom: 0; }
-.tabs { background-color: var(--background-color); }
-.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
-.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
-div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
-#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
-.label-wrap { background-color: #191919; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
-.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
-.small-accordion .label-wrap .icon { margin-right: 1em; }
-.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
-button.selected {background: var(--button-primary-background-fill);}
-.center.boundedheight.flex {background-color: var(--input-background-fill);}
-.compact {border-radius: var(--border-radius-lg);}
-#logMonitorData {background-color: var(--input-background-fill);}
-#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
-#tab_extensions table, #tab_config table { width: 96vw; }
-#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
-#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
-
-/* automatic style classes */
-.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
-.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
-.gallery-item { box-shadow: none !important; }
-.performance { color: #888; }
-.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
-.image-buttons { gap: 10px !important; justify-content: center; }
-.image-buttons > button { max-width: 160px; }
-.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
-#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
-
-/* gradio elements overrides */
-#div.gradio-container { overflow-x: hidden; }
-#img2img_label_copy_to_img2img { font-weight: normal; }
-#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
-#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
-#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
-#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
-#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
-#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
-#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
-#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
-#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
-#settings > div.flex-wrap { width: 15em; }
-#txt2img_cfg_scale { min-width: 200px; }
-#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
-#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
-#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
-#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
-#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
-
-#extras_upscale { margin-top: 10px }
-#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
-#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
-#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
-#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
-
-/* based on gradio built-in dark theme */
-:root, .light, .dark {
- --body-background-fill: var(--background-color);
- --color-accent-soft: var(--neutral-700);
- --background-fill-secondary: none;
- --border-color-accent: var(--background-color);
- --border-color-primary: var(--background-color);
- --link-text-color-active: var(--primary-500);
- --link-text-color: var(--secondary-500);
- --link-text-color-hover: var(--secondary-400);
- --link-text-color-visited: var(--secondary-600);
- --shadow-spread: 1px;
- --block-background-fill: None;
- --block-border-color: var(--border-color-primary);
- --block_border_width: None;
- --block-info-text-color: var(--body-text-color-subdued);
- --block-label-background-fill: var(--background-fill-secondary);
- --block-label-border-color: var(--border-color-primary);
- --block_label_border_width: None;
- --block-label-text-color: var(--neutral-200);
- --block_shadow: None;
- --block_title_background_fill: None;
- --block_title_border_color: None;
- --block_title_border_width: None;
- --panel-background-fill: var(--background-fill-secondary);
- --panel-border-color: var(--border-color-primary);
- --panel_border_width: None;
- --checkbox-background-color: var(--primary-200);
- --checkbox-background-color-focus: var(--primary-700);
- --checkbox-background-color-hover: var(--primary-700);
- --checkbox-background-color-selected: var(--primary-500);
- --checkbox-border-color: transparent;
- --checkbox-border-color-focus: var(--primary-800);
- --checkbox-border-color-hover: var(--primary-800);
- --checkbox-border-color-selected: var(--primary-800);
- --checkbox-border-width: var(--input-border-width);
- --checkbox-label-background-fill: None;
- --checkbox-label-background-fill-hover: None;
- --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
- --checkbox-label-border-color: var(--border-color-primary);
- --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
- --checkbox-label-border-width: var(--input-border-width);
- --checkbox-label-text-color: var(--body-text-color);
- --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
- --error-background-fill: var(--background-fill-primary);
- --error-border-color: var(--border-color-primary);
- --error-text-color: #f768b7; /*was ef4444*/
- --input-background-fill-focus: var(--secondary-600);
- --input-background-fill-hover: var(--input-background-fill);
- --input-border-color: var(--background-color);
- --input-border-color-focus: var(--primary-800);
- --input-placeholder-color: var(--neutral-500);
- --input-shadow-focus: None;
- --loader_color: None;
- --slider_color: None;
- --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
- --table-border-color: var(--neutral-700);
- --table-even-background-fill: var(--primary-300);
- --table-odd-background-fill: var(--primary-200);
- --table-row-focus: var(--color-accent-soft);
- --button-border-width: var(--input-border-width);
- --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
- --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
- --button-cancel-border-color: #dc2626;
- --button-cancel-border-color-hover: var(--button-cancel-border-color);
- --button-cancel-text-color: white;
- --button-cancel-text-color-hover: var(--button-cancel-text-color);
- --button-primary-background-fill: var(--primary-500);
- --button-primary-background-fill-hover: var(--primary-800);
- --button-primary-border-color: var(--primary-500);
- --button-primary-border-color-hover: var(--button-primary-border-color);
- --button-primary-text-color: white;
- --button-primary-text-color-hover: var(--button-primary-text-color);
- --button-secondary-border-color: var(--neutral-600);
- --button-secondary-border-color-hover: var(--button-secondary-border-color);
- --button-secondary-text-color-hover: var(--button-secondary-text-color);
- --secondary-50: #eff6ff;
- --secondary-100: #dbeafe;
- --secondary-200: #bfdbfe;
- --secondary-300: #93c5fd;
- --secondary-400: #60a5fa;
- --secondary-500: #3b82f6;
- --secondary-600: #2563eb;
- --secondary-700: #1d4ed8;
- --secondary-800: #1e40af;
- --secondary-900: #1e3a8a;
- --secondary-950: #1d3660;
- --neutral-50: #f0f0f0; /* */
- --neutral-100: #e8e8e3;/* majority of text (neutral gray yellow) */
- --neutral-200: #d0d0d0;
- --neutral-300: #b3b5ac; /* top tab /sub text (light accent) */
- --neutral-400: #ffba85;/* tab title (bright orange) */
- --neutral-500: #48665b; /* prompt text (desat accent)*/
- --neutral-600: #373f39; /* tab outline color (accent color)*/
- --neutral-700: #2b373b; /* small settings tab accent */
- --neutral-800: #f379c2; /* bright pink accent */
- --neutral-900: #111827;
- --neutral-950: #0b0f19;
- --radius-xxs: 0;
- --radius-xs: 0;
- --radius-md: 0;
- --radius-xl: 0;
- --radius-xxl: 0;
- --body-text-size: var(--text-md);
- --body-text-weight: 400;
- --embed-radius: var(--radius-lg);
- --color-accent: var(--primary-500);
- --shadow-drop: 0;
- --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
- --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
- --block-border-width: 1px;
- --block-info-text-size: var(--text-sm);
- --block-info-text-weight: 400;
- --block-label-border-width: 1px;
- --block-label-margin: 0;
- --block-label-padding: var(--spacing-sm) var(--spacing-lg);
- --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
- --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
- --block-label-text-size: var(--text-sm);
- --block-label-text-weight: 400;
- --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
- --block-radius: var(--radius-lg);
- --block-shadow: var(--shadow-drop);
- --block-title-background-fill: none;
- --block-title-border-color: none;
- --block-title-border-width: 0;
- --block-title-padding: 0;
- --block-title-radius: none;
- --block-title-text-size: var(--text-md);
- --block-title-text-weight: 400;
- --container-radius: var(--radius-lg);
- --form-gap-width: 1px;
- --layout-gap: var(--spacing-xxl);
- --panel-border-width: 0;
- --section-header-text-size: var(--text-md);
- --section-header-text-weight: 400;
- --checkbox-border-radius: var(--radius-sm);
- --checkbox-label-gap: 2px;
- --checkbox-label-padding: var(--spacing-md);
- --checkbox-label-shadow: var(--shadow-drop);
- --checkbox-label-text-size: var(--text-md);
- --checkbox-label-text-weight: 400;
- --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
- --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
- --checkbox-shadow: var(--input-shadow);
- --error-border-width: 1px;
- --input-border-width: 1px;
- --input-radius: var(--radius-lg);
- --input-text-size: var(--text-md);
- --input-text-weight: 400;
- --loader-color: var(--color-accent);
- --prose-text-size: var(--text-md);
- --prose-text-weight: 400;
- --prose-header-text-weight: 600;
- --slider-color: ;
- --table-radius: var(--radius-lg);
- --button-large-padding: 2px 6px;
- --button-large-radius: var(--radius-lg);
- --button-large-text-size: var(--text-lg);
- --button-large-text-weight: 400;
- --button-shadow: none;
- --button-shadow-active: none;
- --button-shadow-hover: none;
- --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
- --button-small-radius: var(--radius-lg);
- --button-small-text-size: var(--text-md);
- --button-small-text-weight: 400;
- --button-transition: none;
- --size-9: 64px;
- --size-14: 64px;
-}
+/* generic html tags */
+:root, .light, .dark {
+ --font: system-ui, ui-sans-serif, "Roboto", sans-serif;
+ --font-mono: ui-monospace, 'Consolas', monospace;
+ --font-size: 16px;
+ --primary-100: #1e2223; /* bg color*/
+ --primary-200: #242a2c; /* drop down menu/ prompt window fill*/
+ --primary-300: #0a0c0e; /* black */
+ --primary-400: #2a302c; /* small buttons*/
+ --primary-500: #4b695d; /* main accent color green*/
+ --primary-700: #273538; /* extension box fill*/
+ --primary-800: #d15e84; /* pink(hover accent)*/
+ --highlight-color: var(--primary-500);
+ --inactive-color: var(--primary-800);
+ --body-text-color: var(--neutral-100);
+ --body-text-color-subdued: var(--neutral-300);
+ --background-color: var(--primary-100);
+ --background-fill-primary: var(--input-background-fill);
+ --input-padding: 8px;
+ --input-background-fill: var(--primary-200);
+ --input-shadow: none;
+ --button-secondary-text-color: white;
+ --button-secondary-background-fill: var(--primary-400);
+ --button-secondary-background-fill-hover: var(--primary-700);
+ --block-title-text-color: var(--neutral-300);
+ --radius-sm: 1px;
+ --radius-lg: 6px;
+ --spacing-md: 4px;
+ --spacing-xxl: 8px;
+ --line-sm: 1.2em;
+ --line-md: 1.4em;
+}
+
+html { font-size: var(--font-size); }
+body, button, input, select, textarea { font-family: var(--font);}
+button { max-width: 400px; }
+img { background-color: var(--background-color); }
+input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
+input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
+input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
+input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
+::-webkit-scrollbar-track { background: #333333; }
+::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
+div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
+div.compact { gap: 1em; }
+
+/* gradio style classes */
+fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
+.border-2 { border-width: 0; }
+.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
+.bg-white { color: lightyellow; background-color: var(--inactive-color); }
+.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
+.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
+.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
+.gr-check-radio:checked { background-color: var(--highlight-color); }
+.gr-compact { background-color: var(--background-color); }
+.gr-form { border-width: 0; }
+.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
+.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
+.gr-panel { background-color: var(--background-color); }
+.eta-bar { display: none !important }
+svg.feather.feather-image, .feather .feather-image { display: none }
+.gap-2 { padding-top: 8px; }
+.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
+.output-html { line-height: 1.2rem; overflow-x: hidden; }
+.output-html > div { margin-bottom: 8px; }
+.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
+.p-2 { padding: 0; }
+.px-4 { padding-left: 1rem; padding-right: 1rem; }
+.py-6 { padding-bottom: 0; }
+.tabs { background-color: var(--background-color); }
+.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
+.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
+div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
+#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
+.label-wrap { background-color: #191919; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
+.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
+.small-accordion .label-wrap .icon { margin-right: 1em; }
+.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
+button.selected {background: var(--button-primary-background-fill);}
+.center.boundedheight.flex {background-color: var(--input-background-fill);}
+.compact {border-radius: var(--border-radius-lg);}
+#logMonitorData {background-color: var(--input-background-fill);}
+#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
+#tab_extensions table, #tab_config table { width: 96vw; }
+#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
+#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
+
+/* automatic style classes */
+.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
+.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
+.gallery-item { box-shadow: none !important; }
+.performance { color: #888; }
+.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
+.image-buttons { gap: 10px !important; justify-content: center; }
+.image-buttons > button { max-width: 160px; }
+.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
+#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
+
+/* gradio elements overrides */
+div.gradio-container { overflow-x: hidden; }
+#img2img_label_copy_to_img2img { font-weight: normal; }
+#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
+#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
+#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
+#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
+#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
+#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
+#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
+#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
+#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
+#settings > div.flex-wrap { width: 15em; }
+#txt2img_cfg_scale { min-width: 200px; }
+#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
+#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
+#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
+#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
+#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
+
+#extras_upscale { margin-top: 10px }
+#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
+#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
+#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
+#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
+#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
+
+/* based on gradio built-in dark theme */
+:root, .light, .dark {
+ --body-background-fill: var(--background-color);
+ --color-accent-soft: var(--neutral-700);
+ --background-fill-secondary: none;
+ --border-color-accent: var(--background-color);
+ --border-color-primary: var(--background-color);
+ --link-text-color-active: var(--primary-500);
+ --link-text-color: var(--secondary-500);
+ --link-text-color-hover: var(--secondary-400);
+ --link-text-color-visited: var(--secondary-600);
+ --shadow-spread: 1px;
+ --block-background-fill: None;
+ --block-border-color: var(--border-color-primary);
+ --block_border_width: None;
+ --block-info-text-color: var(--body-text-color-subdued);
+ --block-label-background-fill: var(--background-fill-secondary);
+ --block-label-border-color: var(--border-color-primary);
+ --block_label_border_width: None;
+ --block-label-text-color: var(--neutral-200);
+ --block_shadow: None;
+ --block_title_background_fill: None;
+ --block_title_border_color: None;
+ --block_title_border_width: None;
+ --panel-background-fill: var(--background-fill-secondary);
+ --panel-border-color: var(--border-color-primary);
+ --panel_border_width: None;
+ --checkbox-background-color: var(--primary-200);
+ --checkbox-background-color-focus: var(--primary-700);
+ --checkbox-background-color-hover: var(--primary-700);
+ --checkbox-background-color-selected: var(--primary-500);
+ --checkbox-border-color: transparent;
+ --checkbox-border-color-focus: var(--primary-800);
+ --checkbox-border-color-hover: var(--primary-800);
+ --checkbox-border-color-selected: var(--primary-800);
+ --checkbox-border-width: var(--input-border-width);
+ --checkbox-label-background-fill: None;
+ --checkbox-label-background-fill-hover: None;
+ --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
+ --checkbox-label-border-color: var(--border-color-primary);
+ --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
+ --checkbox-label-border-width: var(--input-border-width);
+ --checkbox-label-text-color: var(--body-text-color);
+ --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
+ --error-background-fill: var(--background-fill-primary);
+ --error-border-color: var(--border-color-primary);
+ --error-text-color: #f768b7; /*was ef4444*/
+ --input-background-fill-focus: var(--secondary-600);
+ --input-background-fill-hover: var(--input-background-fill);
+ --input-border-color: var(--background-color);
+ --input-border-color-focus: var(--primary-800);
+ --input-placeholder-color: var(--neutral-500);
+ --input-shadow-focus: None;
+ --loader_color: None;
+ --slider_color: None;
+ --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
+ --table-border-color: var(--neutral-700);
+ --table-even-background-fill: var(--primary-300);
+ --table-odd-background-fill: var(--primary-200);
+ --table-row-focus: var(--color-accent-soft);
+ --button-border-width: var(--input-border-width);
+ --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
+ --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
+ --button-cancel-border-color: #dc2626;
+ --button-cancel-border-color-hover: var(--button-cancel-border-color);
+ --button-cancel-text-color: white;
+ --button-cancel-text-color-hover: var(--button-cancel-text-color);
+ --button-primary-background-fill: var(--primary-500);
+ --button-primary-background-fill-hover: var(--primary-800);
+ --button-primary-border-color: var(--primary-500);
+ --button-primary-border-color-hover: var(--button-primary-border-color);
+ --button-primary-text-color: white;
+ --button-primary-text-color-hover: var(--button-primary-text-color);
+ --button-secondary-border-color: var(--neutral-600);
+ --button-secondary-border-color-hover: var(--button-secondary-border-color);
+ --button-secondary-text-color-hover: var(--button-secondary-text-color);
+ --secondary-50: #eff6ff;
+ --secondary-100: #dbeafe;
+ --secondary-200: #bfdbfe;
+ --secondary-300: #93c5fd;
+ --secondary-400: #60a5fa;
+ --secondary-500: #3b82f6;
+ --secondary-600: #2563eb;
+ --secondary-700: #1d4ed8;
+ --secondary-800: #1e40af;
+ --secondary-900: #1e3a8a;
+ --secondary-950: #1d3660;
+ --neutral-50: #f0f0f0; /* */
+ --neutral-100: #e8e8e3;/* majority of text (neutral gray yellow) */
+ --neutral-200: #d0d0d0;
+ --neutral-300: #b3b5ac; /* top tab /sub text (light accent) */
+ --neutral-400: #ffba85;/* tab title (bright orange) */
+ --neutral-500: #48665b; /* prompt text (desat accent)*/
+ --neutral-600: #373f39; /* tab outline color (accent color)*/
+ --neutral-700: #2b373b; /* small settings tab accent */
+ --neutral-800: #f379c2; /* bright pink accent */
+ --neutral-900: #111827;
+ --neutral-950: #0b0f19;
+ --radius-xxs: 0;
+ --radius-xs: 0;
+ --radius-md: 0;
+ --radius-xl: 0;
+ --radius-xxl: 0;
+ --body-text-size: var(--text-md);
+ --body-text-weight: 400;
+ --embed-radius: var(--radius-lg);
+ --color-accent: var(--primary-500);
+ --shadow-drop: 0;
+ --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
+ --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
+ --block-border-width: 1px;
+ --block-info-text-size: var(--text-sm);
+ --block-info-text-weight: 400;
+ --block-label-border-width: 1px;
+ --block-label-margin: 0;
+ --block-label-padding: var(--spacing-sm) var(--spacing-lg);
+ --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
+ --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
+ --block-label-text-size: var(--text-sm);
+ --block-label-text-weight: 400;
+ --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
+ --block-radius: var(--radius-lg);
+ --block-shadow: var(--shadow-drop);
+ --block-title-background-fill: none;
+ --block-title-border-color: none;
+ --block-title-border-width: 0;
+ --block-title-padding: 0;
+ --block-title-radius: none;
+ --block-title-text-size: var(--text-md);
+ --block-title-text-weight: 400;
+ --container-radius: var(--radius-lg);
+ --form-gap-width: 1px;
+ --layout-gap: var(--spacing-xxl);
+ --panel-border-width: 0;
+ --section-header-text-size: var(--text-md);
+ --section-header-text-weight: 400;
+ --checkbox-border-radius: var(--radius-sm);
+ --checkbox-label-gap: 2px;
+ --checkbox-label-padding: var(--spacing-md);
+ --checkbox-label-shadow: var(--shadow-drop);
+ --checkbox-label-text-size: var(--text-md);
+ --checkbox-label-text-weight: 400;
+ --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
+ --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
+ --checkbox-shadow: var(--input-shadow);
+ --error-border-width: 1px;
+ --input-border-width: 1px;
+ --input-radius: var(--radius-lg);
+ --input-text-size: var(--text-md);
+ --input-text-weight: 400;
+ --loader-color: var(--color-accent);
+ --prose-text-size: var(--text-md);
+ --prose-text-weight: 400;
+ --prose-header-text-weight: 600;
+ --slider-color: ;
+ --table-radius: var(--radius-lg);
+ --button-large-padding: 2px 6px;
+ --button-large-radius: var(--radius-lg);
+ --button-large-text-size: var(--text-lg);
+ --button-large-text-weight: 400;
+ --button-shadow: none;
+ --button-shadow-active: none;
+ --button-shadow-hover: none;
+ --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
+ --button-small-radius: var(--radius-lg);
+ --button-small-text-size: var(--text-md);
+ --button-small-text-weight: 400;
+ --button-transition: none;
+ --size-9: 64px;
+ --size-14: 64px;
+}
diff --git a/javascript/orchid-dreams.css b/javascript/orchid-dreams.css
index ef40d8e3a..05ff97a3b 100644
--- a/javascript/orchid-dreams.css
+++ b/javascript/orchid-dreams.css
@@ -1,297 +1,297 @@
-/* generic html tags */
-:root, .light, .dark {
- --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
- --font-mono: 'ui-monospace', 'Consolas', monospace;
- --font-size: 16px;
- --primary-100: #2a2a34; /* bg color*/
- --primary-200: #1f2028; /* drop down menu/ prompt*/
- --primary-300: #0a0c0e; /* black */
- --primary-400: #40435c; /* small buttons*/
- --primary-500: #4c48b5; /* main accent color purple*/
- --primary-700: #1f2028; /* darker hover accent*/
- --primary-800: #e95ee3; /* pink accent*/
- --highlight-color: var(--primary-500);
- --inactive-color: var(--primary--800);
- --body-text-color: var(--neutral-100);
- --body-text-color-subdued: var(--neutral-300);
- --background-color: var(--primary-100);
- --background-fill-primary: var(--input-background-fill);
- --input-padding: 8px;
- --input-background-fill: var(--primary-200);
- --input-shadow: none;
- --button-secondary-text-color: white;
- --button-secondary-background-fill: var(--primary-400);
- --button-secondary-background-fill-hover: var(--primary-700);
- --block-title-text-color: var(--neutral-300);
- --radius-sm: 1px;
- --radius-lg: 6px;
- --spacing-md: 4px;
- --spacing-xxl: 8px;
- --line-sm: 1.2em;
- --line-md: 1.4em;
-}
-
-html { font-size: var(--font-size); }
-body, button, input, select, textarea { font-family: var(--font);}
-button { max-width: 400px; }
-img { background-color: var(--background-color); }
-input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
-input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
-input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
-input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
-::-webkit-scrollbar-track { background: #333333; }
-::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
-div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
-div.compact { gap: 1em; }
-
-/* gradio style classes */
-fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
-.border-2 { border-width: 0; }
-.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
-.bg-white { color: lightyellow; background-color: var(--inactive-color); }
-.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
-.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
-.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
-.gr-check-radio:checked { background-color: var(--highlight-color); }
-.gr-compact { background-color: var(--background-color); }
-.gr-form { border-width: 0; }
-.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
-.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
-.gr-panel { background-color: var(--background-color); }
-.eta-bar { display: none !important }
-svg.feather.feather-image, .feather .feather-image { display: none }
-.gap-2 { padding-top: 8px; }
-.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
-.output-html { line-height: 1.2rem; overflow-x: hidden; }
-.output-html > div { margin-bottom: 8px; }
-.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
-.p-2 { padding: 0; }
-.px-4 { padding-lefT: 1rem; padding-right: 1rem; }
-.py-6 { padding-bottom: 0; }
-.tabs { background-color: var(--background-color); }
-.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
-.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
-div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
-#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
-.label-wrap { background-color: #18181e; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
-.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
-.small-accordion .label-wrap .icon { margin-right: 1em; }
-.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
-button.selected {background: var(--button-primary-background-fill);}
-.center.boundedheight.flex {background-color: var(--input-background-fill);}
-.compact {border-radius: var(--border-radius-lg);}
-#logMonitorData {background-color: var(--input-background-fill);}
-#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
-#tab_extensions table, #tab_config table { width: 96vw; }
-#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
-#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
-
-/* automatic style classes */
-.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
-.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
-.gallery-item { box-shadow: none !important; }
-.performance { color: #888; }
-.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
-.image-buttons { gap: 10px !important; justify-content: center; }
-.image-buttons > button { max-width: 160px; }
-.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
-#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
-
-/* gradio elements overrides */
-#div.gradio-container { overflow-x: hidden; }
-#img2img_label_copy_to_img2img { font-weight: normal; }
-#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
-#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
-#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
-#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
-#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
-#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
-#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
-#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
-#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
-#settings > div.flex-wrap { width: 15em; }
-#txt2img_cfg_scale { min-width: 200px; }
-#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
-#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
-#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
-#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
-#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
-
-#extras_upscale { margin-top: 10px }
-#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
-#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
-#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
-#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
-
-/* based on gradio built-in dark theme */
-:root, .light, .dark {
- --body-background-fill: var(--background-color);
- --color-accent-soft: var(--neutral-700);
- --background-fill-secondary: none;
- --border-color-accent: var(--background-color);
- --border-color-primary: var(--background-color);
- --link-text-color-active: var(--primary-500);
- --link-text-color: var(--secondary-500);
- --link-text-color-hover: var(--secondary-400);
- --link-text-color-visited: var(--secondary-600);
- --shadow-spread: 1px;
- --block-background-fill: None;
- --block-border-color: var(--border-color-primary);
- --block_border_width: None;
- --block-info-text-color: var(--body-text-color-subdued);
- --block-label-background-fill: var(--background-fill-secondary);
- --block-label-border-color: var(--border-color-primary);
- --block_label_border_width: None;
- --block-label-text-color: var(--neutral-200);
- --block_shadow: None;
- --block_title_background_fill: None;
- --block_title_border_color: None;
- --block_title_border_width: None;
- --panel-background-fill: var(--background-fill-secondary);
- --panel-border-color: var(--border-color-primary);
- --panel_border_width: None;
- --checkbox-background-color: var(--primary-200);
- --checkbox-background-color-focus: var(--primary-400);
- --checkbox-background-color-hover: var(--primary-200);
- --checkbox-background-color-selected: var(--primary-400);
- --checkbox-border-color: transparent;
- --checkbox-border-color-focus: var(--primary-800);
- --checkbox-border-color-hover: var(--primary-800);
- --checkbox-border-color-selected: var(--primary-800);
- --checkbox-border-width: var(--input-border-width);
- --checkbox-label-background-fill: None;
- --checkbox-label-background-fill-hover: None;
- --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
- --checkbox-label-border-color: var(--border-color-primary);
- --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
- --checkbox-label-border-width: var(--input-border-width);
- --checkbox-label-text-color: var(--body-text-color);
- --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
- --error-background-fill: var(--background-fill-primary);
- --error-border-color: var(--border-color-primary);
- --error-text-color: #f768b7; /*was ef4444*/
- --input-background-fill-focus: var(--secondary-600);
- --input-background-fill-hover: var(--input-background-fill);
- --input-border-color: var(--background-color);
- --input-border-color-focus: var(--primary-800);
- --input-placeholder-color: var(--neutral-500);
- --input-shadow-focus: None;
- --loader_color: None;
- --slider_color: None;
- --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
- --table-border-color: var(--neutral-700);
- --table-even-background-fill: var(--primary-300);
- --table-odd-background-fill: var(--primary-200);
- --table-row-focus: var(--color-accent-soft);
- --button-border-width: var(--input-border-width);
- --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
- --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
- --button-cancel-border-color: #dc2626;
- --button-cancel-border-color-hover: var(--button-cancel-border-color);
- --button-cancel-text-color: white;
- --button-cancel-text-color-hover: var(--button-cancel-text-color);
- --button-primary-background-fill: var(--primary-500);
- --button-primary-background-fill-hover: var(--primary-800);
- --button-primary-border-color: var(--primary-500);
- --button-primary-border-color-hover: var(--button-primary-border-color);
- --button-primary-text-color: white;
- --button-primary-text-color-hover: var(--button-primary-text-color);
- --button-secondary-border-color: var(--neutral-600);
- --button-secondary-border-color-hover: var(--button-secondary-border-color);
- --button-secondary-text-color-hover: var(--button-secondary-text-color);
- --secondary-50: #eff6ff;
- --secondary-100: #dbeafe;
- --secondary-200: #bfdbfe;
- --secondary-300: #93c5fd;
- --secondary-400: #60a5fa;
- --secondary-500: #3b82f6;
- --secondary-600: #2563eb;
- --secondary-700: #1d4ed8;
- --secondary-800: #1e40af;
- --secondary-900: #1e3a8a;
- --secondary-950: #1d3660;
- --neutral-50: #f0f0f0; /* */
- --neutral-100: #ddd5e8;/* majority of text (neutral gray purple) */
- --neutral-200: #d0d0d0;
- --neutral-300: #bfbad6; /* top tab text (light accent) */
- --neutral-400: #ffba85;/* tab title (bright orange) */
- --neutral-500: #545b94; /* prompt text (desat accent)*/
- --neutral-600: #1f2028; /* tab outline color (accent color)*/
- --neutral-700: #20212c; /* unchanged settings tab accent (dark)*/
- --neutral-800: #e055dc; /* bright pink accent */
- --neutral-900: #111827;
- --neutral-950: #0b0f19;
- --radius-xxs: 0;
- --radius-xs: 0;
- --radius-md: 0;
- --radius-xl: 0;
- --radius-xxl: 0;
- --body-text-size: var(--text-md);
- --body-text-weight: 400;
- --embed-radius: var(--radius-lg);
- --color-accent: var(--primary-500);
- --shadow-drop: 0;
- --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
- --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
- --block-border-width: 1px;
- --block-info-text-size: var(--text-sm);
- --block-info-text-weight: 400;
- --block-label-border-width: 1px;
- --block-label-margin: 0;
- --block-label-padding: var(--spacing-sm) var(--spacing-lg);
- --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
- --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
- --block-label-text-size: var(--text-sm);
- --block-label-text-weight: 400;
- --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
- --block-radius: var(--radius-lg);
- --block-shadow: var(--shadow-drop);
- --block-title-background-fill: none;
- --block-title-border-color: none;
- --block-title-border-width: 0;
- --block-title-padding: 0;
- --block-title-radius: none;
- --block-title-text-size: var(--text-md);
- --block-title-text-weight: 400;
- --container-radius: var(--radius-lg);
- --form-gap-width: 1px;
- --layout-gap: var(--spacing-xxl);
- --panel-border-width: 0;
- --section-header-text-size: var(--text-md);
- --section-header-text-weight: 400;
- --checkbox-border-radius: var(--radius-sm);
- --checkbox-label-gap: 2px;
- --checkbox-label-padding: var(--spacing-md);
- --checkbox-label-shadow: var(--shadow-drop);
- --checkbox-label-text-size: var(--text-md);
- --checkbox-label-text-weight: 400;
- --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
- --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
- --checkbox-shadow: var(--input-shadow);
- --error-border-width: 1px;
- --input-border-width: 1px;
- --input-radius: var(--radius-lg);
- --input-text-size: var(--text-md);
- --input-text-weight: 400;
- --loader-color: var(--color-accent);
- --prose-text-size: var(--text-md);
- --prose-text-weight: 400;
- --prose-header-text-weight: 600;
- --slider-color: ;
- --table-radius: var(--radius-lg);
- --button-large-padding: 2px 6px;
- --button-large-radius: var(--radius-lg);
- --button-large-text-size: var(--text-lg);
- --button-large-text-weight: 400;
- --button-shadow: none;
- --button-shadow-active: none;
- --button-shadow-hover: none;
- --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
- --button-small-radius: var(--radius-lg);
- --button-small-text-size: var(--text-md);
- --button-small-text-weight: 400;
- --button-transition: none;
- --size-9: 64px;
- --size-14: 64px;
-}
+/* generic html tags */
+:root, .light, .dark {
+ --font: system-ui, ui-sans-serif, "Roboto", sans-serif;
+ --font-mono: ui-monospace, 'Consolas', monospace;
+ --font-size: 16px;
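+ /* core palette: the component rules below and the gradio theme overrides later in this file reference these via var() */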
+ --primary-100: #2a2a34; /* bg color*/
+ --primary-200: #1f2028; /* drop down menu/ prompt*/
+ --primary-300: #0a0c0e; /* black */
+ --primary-400: #40435c; /* small buttons*/
+ --primary-500: #4c48b5; /* main accent color purple*/
+ --primary-700: #1f2028; /* darker hover accent*/
+ --primary-800: #e95ee3; /* pink accent*/
+ --highlight-color: var(--primary-500);
+ --inactive-color: var(--primary-800);
+ --body-text-color: var(--neutral-100);
+ --body-text-color-subdued: var(--neutral-300);
+ --background-color: var(--primary-100);
+ --background-fill-primary: var(--input-background-fill);
+ --input-padding: 8px;
+ --input-background-fill: var(--primary-200);
+ --input-shadow: none;
+ --button-secondary-text-color: white;
+ --button-secondary-background-fill: var(--primary-400);
+ --button-secondary-background-fill-hover: var(--primary-700);
+ --block-title-text-color: var(--neutral-300);
+ --radius-sm: 1px;
+ --radius-lg: 6px;
+ --spacing-md: 4px;
+ --spacing-xxl: 8px;
+ --line-sm: 1.2em;
+ --line-md: 1.4em;
+}
+
+html { font-size: var(--font-size); }
+body, button, input, select, textarea { font-family: var(--font);}
+button { max-width: 400px; }
+img { background-color: var(--background-color); }
+input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
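+/* -webkit- and -moz- slider pseudo-elements are styled in separate rules below, since an unrecognized vendor pseudo-element can invalidate an entire selector list */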
+input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
+input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
+input[type=range]::-webkit-slider-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
+input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
+input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
+::-webkit-scrollbar-track { background: #333333; }
+::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
+div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
+div.compact { gap: 1em; }
+
+/* gradio style classes */
+fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
+.border-2 { border-width: 0; }
+.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
+.bg-white { color: lightyellow; background-color: var(--inactive-color); }
+.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
+.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
+.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
+.gr-check-radio:checked { background-color: var(--highlight-color); }
+.gr-compact { background-color: var(--background-color); }
+.gr-form { border-width: 0; }
+.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
+.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
+.gr-panel { background-color: var(--background-color); }
+.eta-bar { display: none !important }
+svg.feather.feather-image, .feather .feather-image { display: none }
+.gap-2 { padding-top: 8px; }
+.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
+.output-html { line-height: 1.2rem; overflow-x: hidden; }
+.output-html > div { margin-bottom: 8px; }
+.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
+.p-2 { padding: 0; }
+.px-4 { padding-left: 1rem; padding-right: 1rem; }
+.py-6 { padding-bottom: 0; }
+.tabs { background-color: var(--background-color); }
+.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
+.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
+div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
+#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
+.label-wrap { background-color: #18181e; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
+.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
+.small-accordion .label-wrap .icon { margin-right: 1em; }
+.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
+button.selected {background: var(--button-primary-background-fill);}
+.center.boundedheight.flex {background-color: var(--input-background-fill);}
+.compact {border-radius: var(--radius-lg);}
+#logMonitorData {background-color: var(--input-background-fill);}
+#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
+#tab_extensions table, #tab_config table { width: 96vw; }
+#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
+#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
+
+/* automatic style classes */
+.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
+.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
+.gallery-item { box-shadow: none !important; }
+.performance { color: #888; }
+.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
+.image-buttons { gap: 10px !important; justify-content: center; }
+.image-buttons > button { max-width: 160px; }
+.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
+#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
+
+/* gradio elements overrides */
+div.gradio-container { overflow-x: hidden; }
+#img2img_label_copy_to_img2img { font-weight: normal; }
+#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
+#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
+#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
+#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
+#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
+#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
+#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
+#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
+#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
+#settings > div.flex-wrap { width: 15em; }
+#txt2img_cfg_scale { min-width: 200px; }
+#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
+#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
+#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
+#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
+#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
+
+#extras_upscale { margin-top: 10px }
+#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
+#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
+#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
+#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
+#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
+
+/* based on gradio built-in dark theme */
+:root, .light, .dark {
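+ /* remaps the palette defined at the top of this file onto gradio's built-in theme variables; snake_case entries with a value of None appear to be inert leftovers from a Gradio theme export */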
+ --body-background-fill: var(--background-color);
+ --color-accent-soft: var(--neutral-700);
+ --background-fill-secondary: none;
+ --border-color-accent: var(--background-color);
+ --border-color-primary: var(--background-color);
+ --link-text-color-active: var(--primary-500);
+ --link-text-color: var(--secondary-500);
+ --link-text-color-hover: var(--secondary-400);
+ --link-text-color-visited: var(--secondary-600);
+ --shadow-spread: 1px;
+ --block-background-fill: None;
+ --block-border-color: var(--border-color-primary);
+ --block_border_width: None;
+ --block-info-text-color: var(--body-text-color-subdued);
+ --block-label-background-fill: var(--background-fill-secondary);
+ --block-label-border-color: var(--border-color-primary);
+ --block_label_border_width: None;
+ --block-label-text-color: var(--neutral-200);
+ --block_shadow: None;
+ --block_title_background_fill: None;
+ --block_title_border_color: None;
+ --block_title_border_width: None;
+ --panel-background-fill: var(--background-fill-secondary);
+ --panel-border-color: var(--border-color-primary);
+ --panel_border_width: None;
+ --checkbox-background-color: var(--primary-200);
+ --checkbox-background-color-focus: var(--primary-400);
+ --checkbox-background-color-hover: var(--primary-200);
+ --checkbox-background-color-selected: var(--primary-400);
+ --checkbox-border-color: transparent;
+ --checkbox-border-color-focus: var(--primary-800);
+ --checkbox-border-color-hover: var(--primary-800);
+ --checkbox-border-color-selected: var(--primary-800);
+ --checkbox-border-width: var(--input-border-width);
+ --checkbox-label-background-fill: None;
+ --checkbox-label-background-fill-hover: None;
+ --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
+ --checkbox-label-border-color: var(--border-color-primary);
+ --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
+ --checkbox-label-border-width: var(--input-border-width);
+ --checkbox-label-text-color: var(--body-text-color);
+ --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
+ --error-background-fill: var(--background-fill-primary);
+ --error-border-color: var(--border-color-primary);
+ --error-text-color: #f768b7; /*was ef4444*/
+ --input-background-fill-focus: var(--secondary-600);
+ --input-background-fill-hover: var(--input-background-fill);
+ --input-border-color: var(--background-color);
+ --input-border-color-focus: var(--primary-800);
+ --input-placeholder-color: var(--neutral-500);
+ --input-shadow-focus: None;
+ --loader_color: None;
+ --slider_color: None;
+ --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
+ --table-border-color: var(--neutral-700);
+ --table-even-background-fill: var(--primary-300);
+ --table-odd-background-fill: var(--primary-200);
+ --table-row-focus: var(--color-accent-soft);
+ --button-border-width: var(--input-border-width);
+ --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
+ --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
+ --button-cancel-border-color: #dc2626;
+ --button-cancel-border-color-hover: var(--button-cancel-border-color);
+ --button-cancel-text-color: white;
+ --button-cancel-text-color-hover: var(--button-cancel-text-color);
+ --button-primary-background-fill: var(--primary-500);
+ --button-primary-background-fill-hover: var(--primary-800);
+ --button-primary-border-color: var(--primary-500);
+ --button-primary-border-color-hover: var(--button-primary-border-color);
+ --button-primary-text-color: white;
+ --button-primary-text-color-hover: var(--button-primary-text-color);
+ --button-secondary-border-color: var(--neutral-600);
+ --button-secondary-border-color-hover: var(--button-secondary-border-color);
+ --button-secondary-text-color-hover: var(--button-secondary-text-color);
+ --secondary-50: #eff6ff;
+ --secondary-100: #dbeafe;
+ --secondary-200: #bfdbfe;
+ --secondary-300: #93c5fd;
+ --secondary-400: #60a5fa;
+ --secondary-500: #3b82f6;
+ --secondary-600: #2563eb;
+ --secondary-700: #1d4ed8;
+ --secondary-800: #1e40af;
+ --secondary-900: #1e3a8a;
+ --secondary-950: #1d3660;
+ --neutral-50: #f0f0f0; /* */
+ --neutral-100: #ddd5e8;/* majority of text (neutral gray purple) */
+ --neutral-200: #d0d0d0;
+ --neutral-300: #bfbad6; /* top tab text (light accent) */
+ --neutral-400: #ffba85;/* tab title (bright orange) */
+ --neutral-500: #545b94; /* prompt text (desat accent)*/
+ --neutral-600: #1f2028; /* tab outline color (accent color)*/
+ --neutral-700: #20212c; /* unchanged settings tab accent (dark)*/
+ --neutral-800: #e055dc; /* bright pink accent */
+ --neutral-900: #111827;
+ --neutral-950: #0b0f19;
+ --radius-xxs: 0;
+ --radius-xs: 0;
+ --radius-md: 0;
+ --radius-xl: 0;
+ --radius-xxl: 0;
+ --body-text-size: var(--text-md);
+ --body-text-weight: 400;
+ --embed-radius: var(--radius-lg);
+ --color-accent: var(--primary-500);
+ --shadow-drop: 0;
+ --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
+ --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
+ --block-border-width: 1px;
+ --block-info-text-size: var(--text-sm);
+ --block-info-text-weight: 400;
+ --block-label-border-width: 1px;
+ --block-label-margin: 0;
+ --block-label-padding: var(--spacing-sm) var(--spacing-lg);
+ --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
+ --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
+ --block-label-text-size: var(--text-sm);
+ --block-label-text-weight: 400;
+ --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
+ --block-radius: var(--radius-lg);
+ --block-shadow: var(--shadow-drop);
+ --block-title-background-fill: none;
+ --block-title-border-color: none;
+ --block-title-border-width: 0;
+ --block-title-padding: 0;
+ --block-title-radius: none;
+ --block-title-text-size: var(--text-md);
+ --block-title-text-weight: 400;
+ --container-radius: var(--radius-lg);
+ --form-gap-width: 1px;
+ --layout-gap: var(--spacing-xxl);
+ --panel-border-width: 0;
+ --section-header-text-size: var(--text-md);
+ --section-header-text-weight: 400;
+ --checkbox-border-radius: var(--radius-sm);
+ --checkbox-label-gap: 2px;
+ --checkbox-label-padding: var(--spacing-md);
+ --checkbox-label-shadow: var(--shadow-drop);
+ --checkbox-label-text-size: var(--text-md);
+ --checkbox-label-text-weight: 400;
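+ /* inline SVG data URIs used for the checkbox check mark and the radio dot */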
+ --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
+ --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
+ --checkbox-shadow: var(--input-shadow);
+ --error-border-width: 1px;
+ --input-border-width: 1px;
+ --input-radius: var(--radius-lg);
+ --input-text-size: var(--text-md);
+ --input-text-weight: 400;
+ --loader-color: var(--color-accent);
+ --prose-text-size: var(--text-md);
+ --prose-text-weight: 400;
+ --prose-header-text-weight: 600;
+ --slider-color: ;
+ --table-radius: var(--radius-lg);
+ --button-large-padding: 2px 6px;
+ --button-large-radius: var(--radius-lg);
+ --button-large-text-size: var(--text-lg);
+ --button-large-text-weight: 400;
+ --button-shadow: none;
+ --button-shadow-active: none;
+ --button-shadow-hover: none;
+ --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
+ --button-small-radius: var(--radius-lg);
+ --button-small-text-size: var(--text-md);
+ --button-small-text-weight: 400;
+ --button-transition: none;
+ --size-9: 64px;
+ --size-14: 64px;
+}
diff --git a/javascript/sdnext.css b/javascript/sdnext.css
index f710f742a..2ec0f35ad 100644
--- a/javascript/sdnext.css
+++ b/javascript/sdnext.css
@@ -1,343 +1,343 @@
-@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
-:root { --left-column: 500px; }
-a { font-weight: bold; cursor: pointer; }
-h2 { margin-top: 1em !important; font-size: var(--text-xxl) !important; }
-footer { display: none; }
-table { overflow-x: auto !important; overflow-y: auto !important; }
-td { border-bottom: none !important; padding: 0.1em 0.5em !important; }
-tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }
-textarea { overflow-y: auto !important; }
-span { font-size: var(--text-md) !important; }
-button { font-size: var(--text-lg) !important; }
-
-/* gradio elements */
-.block .padded:not(.gradio-accordion) { padding: 0 !important; margin-right: 0; min-width: 90px !important; }
-.compact { gap: 1em 0.2em; background: transparent !important; padding: 0 !important; }
-.flex-break { flex-basis: 100% !important; }
-.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; gap: 0.5em 1em; flex-grow: 1 !important; }
-.form-compact { margin-bottom: 0 !important; gap: 0.2em 1em !important; }
-.gap .compact{ padding: 0; gap: 0.2em 0; }
-.hidden { display: none; }
-.tabitem { padding: 0 !important; }
-
-.gradio-dropdown, .block.gradio-slider, .block.gradio-checkbox, .block.gradio-textbox, .block.gradio-radio, .block.gradio-checkboxgroup, .block.gradio-number, .block.gradio-colorpicker { border-width: 0 !important; box-shadow: none !important;}
-.gradio-accordion { padding-top: var(--spacing-md) !important; padding-right: 0 !important; padding-bottom: 0 !important; color: var(--body-text-color); }
-.gradio-accordion .label-wrap .icon { color: var(--button-primary-border-color); }
-.gradio-button { border-radius: var(--radius-lg) !important; }
-.gradio-button.secondary-down { background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); }
-.gradio-button.secondary-down, .gradio-button.secondary-down:hover { box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; }
-.gradio-button.secondary-down:hover { background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); }
-.gradio-button.tool { max-width: min-content; min-width: min-content !important; align-self: end; font-size: 20px !important; color: var(--body-text-color) !important; margin-top: auto; margin-bottom: var(--spacing-md); align-self: center; }
-.gradio-checkbox { margin: 0.75em 1.5em 0 0; align-self: center; }
-.gradio-column { min-width: min(160px, 100%) !important; }
-.gradio-container { max-width: unset !important; padding: var(--block-label-padding) !important; }
-.gradio-container .prose a, .gradio-container .prose a:visited{ color: unset; text-decoration: none; }
-.gradio-dropdown { margin-right: var(--spacing-sm) !important; min-width:160px; max-width:fit-content }
-.gradio-dropdown ul.options { z-index: 1000; min-width: fit-content; max-height: 33vh !important; white-space: nowrap; }
-.gradio-dropdown ul.options li.item { padding: var(--spacing-xs); }
-.gradio-dropdown ul.options li.item:not(:has(.hide)) { background-color: var(--primary-500); }
-.gradio-dropdown .token { padding: var(--spacing-xs); }
-.gradio-dropdown span { margin-bottom: 0 !important; font-size: var(--text-sm); }
-.gradio-dropdown .reference { margin-bottom: var(--spacing-sm) !important; }
-.gradio-html { color: var(--body-text-color); }
-.gradio-html .min { min-height: 0; }
-.gradio-html div.wrap { height: 100%; }
-.gradio-number { min-width: unset !important; max-width: 5em !important; }
-.gradio-textbox { overflow: visible !important; }
-.gradio-radio { padding: 0 !important; width: max-content !important; }
-.gradio-slider { margin-right: var(--spacing-sm) !important; width: max-content !important }
-.gradio-slider input[type="number"] { width: 6em; font-size: var(--text-xs); height: 16px; text-align: right; }
-
-/* custom gradio elements */
-.accordion-compact { padding: 8px 0px 4px 0px !important; }
-.settings-accordion>div { flex-flow: wrap; }
-.small-accordion .form { min-width: var(--left-column) !important; max-width: max-content; }
-.small-accordion .label-wrap .icon { margin-right: 1.6em; margin-left: 0.6em; color: var(--button-primary-border-color); }
-.small-accordion .label-wrap { padding: 16px 0px 8px 0px; margin: 0; border-top: 2px solid var(--button-secondary-border-color); }
-.small-accordion { width: fit-content !important; min-width: fit-content !important; padding-left: 0 !important; }
-.extension-script { max-width: 48vw; }
-button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); font-weight: var(--button-large-text-weight); border: var(--button-border-width) solid var(--button-secondary-border-color);
- background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); font-size: var(--text-lg);
- display: inline-flex; justify-content: center; align-items: center; transition: var(--button-transition); box-shadow: var(--button-shadow); text-align: center; }
-
-/* themes */
-.theme-preview { display: none; position: fixed; border: var(--spacing-sm) solid var(--neutral-600); box-shadow: 2px 2px 2px 2px var(--neutral-700); top: 0; bottom: 0; left: 0; right: 0; margin: auto; max-width: 75vw; z-index: 999; }
-
-/* txt2img/img2img specific */
-.block.token-counter{ position: absolute; display: inline-block; right: 1em; min-width: 0 !important; width: auto; z-index: 100; top: -0.5em; }
-.block.token-counter span{ background: var(--input-background-fill) !important; box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075); border: 2px solid rgba(192,192,192,0.4) !important; }
-.block.token-counter.error span{ box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); border: 2px solid rgba(255,0,0,0.4) !important; }
-.block.token-counter div{ display: inline; }
-.block.token-counter span{ padding: 0.1em 0.75em; }
-.performance { font-size: var(--text-xs); color: #444; }
-.performance p { display: inline-block; color: var(--body-text-color-subdued) !important }
-.performance .time { margin-right: 0; }
-.thumbnails { background: var(--body-background-fill); }
-#control_gallery { height: 564px; }
-#control-result { padding: 0.5em; }
-#control-inputs { margin-top: 1em; }
-#txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { margin-right: var(--layout-gap) }
-#txt2img_footer, #img2img_footer, #control_footer { height: fit-content; display: none; }
-#txt2img_generate_box, #img2img_generate_box, #control_general_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; }
-#txt2img_actions_column, #img2img_actions_column, #control_actions_column { gap: 0.3em; height: fit-content; }
-#txt2img_generate_box>button, #img2img_generate_box>button, #control_generate_box>button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; }
-#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools, #control_generate_line2, #control_tools { display: flex; }
-#txt2img_generate_line2>button, #img2img_generate_line2>button, #extras_generate_box>button, #control_generate_line2>button, #txt2img_tools>button, #img2img_tools>button, #control_tools>button { height: 2em; line-height: 0; font-size: var(--text-md);
- min-width: unset; display: block !important; }
-#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt, #control_prompt, #control_neg_prompt { display: contents; }
-#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
-#control_generate_box { min-width: unset; width: 100%; }
-#txt2img_actions_column, #img2img_actions_column, #control_actions { flex-flow: wrap; justify-content: space-between; }
-#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper, #control_enqueue_wrapper { min-width: unset !important; width: 48%; }
-.interrogate-clip { position: absolute; right: 3em; top: -2.7em; max-width: fit-content; }
-.interrogate-blip { position: absolute; right: 1em; top: -2.7em; max-width: fit-content; }
-.interrogate-col{ min-width: 0 !important; max-width: fit-content; margin-right: var(--spacing-xxl); }
-.interrogate-col>button{ flex: 1; width: 7em; max-height: 84px; }
-#sampler_selection_img2img { margin-top: 1em; }
-#txtimg_hr_finalres{ min-height: 0 !important; }
-#img2img_scale_resolution_preview.block{ display: flex; align-items: end; }
-#txtimg_hr_finalres .resolution, #img2img_scale_resolution_preview .resolution{ font-weight: bold; }
-div#extras_scale_to_tab div.form{ flex-direction: row; }
-#img2img_unused_scale_by_slider { visibility: hidden; width: 0.5em; max-width: 0.5em; min-width: 0.5em; }
-.inactive{ opacity: 0.5; }
-div#extras_scale_to_tab div.form{ flex-direction: row; }
-#mode_img2img .gradio-image>div.fixed-height, #mode_img2img .gradio-image>div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; }
-#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; }
-.image-buttons button{ min-width: auto; }
-.infotext { overflow-wrap: break-word; line-height: 1.5em; }
-.infotext>p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; }
-.tooltip { display: block; position: fixed; top: 1em; right: 1em; padding: 0.5em; background: var(--input-background-fill); color: var(--body-text-color); border: 1pt solid var(--button-primary-border-color);
- width: 22em; min-height: 1.3em; font-size: var(--text-xs); transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
-.tooltip-show { opacity: 0.9; }
-.toolbutton-selected { background: var(--background-fill-primary) !important; }
-
-/* settings */
-#si-sparkline-memo, #si-sparkline-load { background-color: #111; }
-#quicksettings { width: fit-content; }
-#quicksettings>button { padding: 0 1em 0 0; align-self: end; margin-bottom: var(--text-sm); }
-#settings { display: flex; gap: var(--layout-gap); }
-#settings div { border: none; gap: 0; margin: 0 0 var(--layout-gap) 0px; padding: 0; }
-#settings>div.tab-content { flex: 10 0 75%; display: grid; }
-#settings>div.tab-content>div { border: none; padding: 0; }
-#settings>div.tab-content>div>div>div>div>div { flex-direction: unset; }
-#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); }
-#settings>div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; }
-#settings>div.tab-nav>#settings_show_all_pages { padding: var(--size-2) var(--size-4); }
-#settings .block.gradio-checkbox { margin: 0; width: auto; }
-#settings .dirtyable { gap: .5em; }
-#settings .dirtyable.hidden { display: none; }
-#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; }
-#settings .modification-indicator:disabled { visibility: hidden; }
-#settings .modification-indicator.saved { background: var(--color-accent-soft); width: var(--spacing-sm); }
-#settings .modification-indicator.changed { background: var(--color-accent); width: var(--spacing-sm); }
-#settings .modification-indicator.changed.unsaved { background-image: linear-gradient(var(--color-accent) 25%, var(--color-accent-soft) 75%); width: var(--spacing-sm); }
-#settings_result { margin: 0 1.2em; }
-.licenses { display: block !important; }
-
-/* live preview */
-.progressDiv{ position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; }
-.dark .progressDiv{ background: #424c5b; }
-.progressDiv .progress{ width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; }
-.livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; }
-.livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; }
-.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); }
-.popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: var(--text-xxs); }
-.global-popup{ display: flex; position: fixed; z-index: 10001; left: 0; top: 0; width: 100%; height: 100%; overflow: auto; background-color: rgba(20, 20, 20, 0.95);}
-.global-popup-close:before { content: "×"; }
-.global-popup-close{ position: fixed; right: 0.5em; top: 0; cursor: pointer; color: white; font-size: 32pt; }
-.global-popup-inner{ display: inline-block; margin: auto; padding: 2em; }
-
-/* fullpage image viewer */
-#lightboxModal{ display: none; position: fixed; z-index: 1001; left: 0; top: 0; width: 100%; height: 100%; overflow: auto; background-color: rgba(20, 20, 20, 0.75); backdrop-filter: blur(6px);
- user-select: none; -webkit-user-select: none; flex-direction: row; }
-.modalControls { display: flex; justify-content: space-evenly; background-color: transparent; position: absolute; width: 99%; z-index: 1; }
-.modalControls:hover { background-color: #50505050; }
-.modalControls span { color: white; font-size: 2em; font-weight: bold; cursor: pointer; filter: grayscale(100%); }
-.modalControls span:hover, .modalControls span:focus { color: var(--highlight-color); filter: none; }
-.lightboxModalPreviewZone { display: flex; width: 100%; height: 100%; }
-.lightboxModalPreviewZone:focus-visible { outline: none; }
-.lightboxModalPreviewZone>img { display: block; margin: auto; width: auto; }
-.lightboxModalPreviewZone>img.modalImageFullscreen{ object-fit: contain; height: 100%; width: 100%; min-height: 0; background: transparent; }
-table.settings-value-table { background: white; border-collapse: collapse; margin: 1em; border: var(--spacing-sm) solid white; }
-table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-width: 36em; }
-.modalPrev, .modalNext { cursor: pointer; position: relative; z-index: 1; top: 0; width: auto; height: 100vh; line-height: 100vh; text-align: center; padding: 16px;
- margin-top: -50px; color: white; font-weight: bold; font-size: 20px; transition: 0.6s ease; user-select: none; -webkit-user-select: none; }
-.modalNext { right: 0; }
-.modalPrev:hover, .modalNext:hover { background-color: rgba(0, 0, 0, 0.8); }
-#imageARPreview { position: absolute; top: 0px; left: 0px; border: 2px solid red; background: rgba(255, 0, 0, 0.3); z-index: 900; pointer-events: none; display: none; }
-
-/* context menu (ie for the generate button) */
-#context-menu { z-index: 9999; position: absolute; display: block; padding: var(--spacing-md); border: 2px solid var(--highlight-color); background: var(--background-fill-primary); color: var(--body-text-color); }
-.context-menu-items { list-style: none; margin: 0; padding: 0; font-size: var(--text-sm); }
-.context-menu-items a { display: block; padding: var(--spacing-md); cursor: pointer; font-weight: normal; }
-.context-menu-items a:hover { background: var(--highlight-color) }
-
-/* extensions */
-#tab_extensions table, #tab_config table{ border-collapse: collapse; }
-#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: 1px solid #ccc; padding: 0.25em 0.5em; }
-#tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
-#tab_extensions table input[type="checkbox"] { margin-right: 0.5em; appearance: checkbox; }
-#tab_extensions button{ max-width: 16em; }
-#tab_extensions input[disabled="disabled"]{ opacity: 0.5; }
-.extension-tag{ font-weight: bold; font-size: var(--text-sm); }
-.extension-button { font-size: var(--text-sm) !important; width: 6em; }
-#extensions .name{ font-size: var(--text-lg) }
-#extensions .type{ opacity: 0.5; font-size: var(--text-sm); text-align: center; }
-#extensions .version{ opacity: 0.7; }
-#extensions .info{ margin: 0; }
-#extensions .date{ opacity: 0.85; font-size: var(--text-sm); }
-
-/* extra networks */
-.extra-networks>div { margin: 0; border-bottom: none !important; gap: 0.3em 0; }
-.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); }
-.extra-networks .search { flex: 1; }
-.extra-networks .description { flex: 3; }
-.extra-networks .tab-nav>button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; }
-.extra-networks .buttons { position: absolute; right: 0; margin: -4px; background: var(--background-color); }
-.extra-networks .buttons>button { margin-left: -0.2em; height: 1.4em; color: var(--primary-300) !important; font-size: 20px !important; }
-.extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 3px 3px 3px 12px; text-indent: -6px; box-shadow: none; line-break: auto; }
-.extra-networks .custom-button:hover { background: var(--button-primary-background-fill) }
-.extra-networks-tab { padding: 0 !important; }
-.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; }
-.extra-networks-page { display: flex }
-.extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
-.extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
-.extra-network-cards .card .overlay { position: absolute; bottom: 0; padding: 0.2em; z-index: 10; width: 100%; background: none; }
-.extra-network-cards .card .overlay .name { font-size: var(--text-lg); font-weight: bold; text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
-.extra-network-cards .card .preview { box-shadow: var(--button-shadow); min-height: 30px; }
-.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
-.extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }
-.extra-network-cards .card .overlay .tags { display: none; overflow-wrap: break-word; }
-.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; }
-.extra-network-cards .card .actions>span { padding: 4px; font-size: 34px !important; }
-.extra-network-cards .card .actions>span:hover { color: var(--highlight-color); }
-.extra-network-cards .card:hover .actions { display: block; }
-.extra-network-cards .card:hover .overlay .tags { display: block; }
-.extra-network-cards .card .actions { font-size: 3em; display: none; text-align-last: right; cursor: pointer; font-variant: unicase; position: absolute; z-index: 80; right: 0; height: 0.7em; width: 100%; background: rgba(0, 0, 0, 0.40); }
-.extra-network-cards .card-list { display: flex; margin: 0.3em; padding: 0.3em; background: var(--input-background-fill); cursor: pointer; border-radius: var(--button-large-radius); }
-.extra-network-cards .card-list .tag { color: var(--primary-500); margin-left: 0.8em; }
-.extra-details-close { position: fixed; top: 0.2em; right: 0.2em; z-index: 99; background: var(--button-secondary-background-fill) !important; }
-#txt2img_description, #img2img_description, #control_description { max-height: 63px; overflow-y: auto !important; }
-#txt2img_description>label>textarea, #img2img_description>label>textarea, #control_description>label>textarea { font-size: var(--text-sm) }
-
-#txt2img_extra_details>div, #img2img_extra_details>div { overflow-y: auto; min-height: 40vh; max-height: 80vh; align-self: flex-start; }
-#txt2img_extra_details, #img2img_extra_details { position: fixed; bottom: 50%; left: 50%; transform: translate(-50%, 50%); padding: 0.8em; border: var(--block-border-width) solid var(--highlight-color) !important;
- z-index: 100; box-shadow: var(--button-shadow); }
-#txt2img_extra_details td:first-child, #img2img_extra_details td:first-child { font-weight: bold; vertical-align: top; }
-#txt2img_extra_details .gradio-image, #img2img_extra_details .gradio-image { max-height: 70vh; }
-
-
-/* specific elements */
-#modelmerger_interp_description { margin-top: 1em; margin-bottom: 1em; }
-#scripts_alwayson_txt2img, #scripts_alwayson_img2img { padding: 0 }
-#scripts_alwayson_txt2img>.label-wrap, #scripts_alwayson_img2img>.label-wrap { background: var(--input-background-fill); padding: 0; margin: 0; border-radius: var(--radius-lg); }
-#scripts_alwayson_txt2img>.label-wrap>span, #scripts_alwayson_img2img>.label-wrap>span { padding: var(--spacing-xxl); }
-#scripts_alwayson_txt2img div { max-width: var(--left-column); }
-#script_txt2img_agent_scheduler { display: none; }
-#refresh_tac_refreshTempFiles { display: none; }
-#train_tab { flex-flow: row-reverse; }
-#models_tab { flex-flow: row-reverse; }
-#swap_axes>button { min-width: 100px; font-size: var(--text-md); }
-#ui_defaults_review { margin: 1em; }
-
-/* extras */
-.extras { gap: 0.2em 1em !important }
-#extras_generate, #extras_interrupt, #extras_skip { display: block !important; position: relative; height: 36px; }
-#extras_upscale { margin-top: 10px }
-#pnginfo_html_info .gradio-html > div { margin: 0.5em; }
-
-/* log monitor */
-.log-monitor { display: none; justify-content: unset !important; overflow: hidden; padding: 0; margin-top: auto; font-family: monospace; font-size: var(--text-xs); }
-.log-monitor td, .log-monitor th { padding-left: 1em; }
-
-/* changelog */
-.md h2 { background-color: var(--background-fill-primary); padding: 0.5em; }
-.md ul { list-style-type: square !important; text-indent: 1em; margin-left: 4em; }
-.md li { list-style-position: outside !important; text-indent: 0; }
-.md p { margin-left: 2em; }
-
-/* custom component */
-.folder-selector textarea { height: 2em !important; padding: 6px !important; }
-.nvml { position: fixed; bottom: 10px; right: 10px; background: var(--background-fill-primary); border: 1px solid var(--button-primary-border-color); padding: 6px; color: var(--button-primary-text-color);
- font-size: var(--text-xxs); z-index: 50; font-family: monospace; display: none; }
-
-/* control */
-#control_input_type { max-width: 18em }
-#control_settings .small-accordion .form { min-width: 350px !important }
-.control-button { min-height: 42px; max-height: 42px; line-height: 1em; }
-.control-tabs > .tab-nav { margin-bottom: 0; margin-top: 0; }
-.control-unit { max-width: 1200px; padding: 0 !important; margin-top: -10px !important; }
-.control-unit > .label-wrap { margin-bottom: 0 !important; }
-.processor-settings { padding: 0 !important; max-width: 300px; }
-.processor-group>div { flex-flow: wrap;gap: 1em; }
-
-/* main info */
-.main-info { font-weight: var(--section-header-text-weight); color: var(--body-text-color-subdued); padding: 1em !important; margin-top: 2em !important; line-height: var(--line-lg) !important; }
-
-/* loader */
-.splash { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; z-index: 1000; display: block; text-align: center; }
-.motd { margin-top: 2em; color: var(--body-text-color-subdued); font-family: monospace; font-variant: all-petite-caps; }
-.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; }
-.loading { color: white; position: absolute; top: 20%; left: 50%; transform: translateX(-50%); }
-.loader { width: 300px; height: 300px; border: var(--spacing-md) solid transparent; border-radius: 50%; border-top: var(--spacing-md) solid var(--primary-600); animation: spin 4s linear infinite; position: relative; }
-.loader::before, .loader::after { content: ""; position: absolute; top: 6px; bottom: 6px; left: 6px; right: 6px; border-radius: 50%; border: var(--spacing-md) solid transparent; }
-.loader::before { border-top-color: var(--primary-900); animation: 3s spin linear infinite; }
-.loader::after { border-top-color: var(--primary-300); animation: spin 1.5s linear infinite; }
-@keyframes move { from { background-position-x: 0, -40px; } to { background-position-x: 0, 40px; } }
-@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
-@keyframes color { from { filter: hue-rotate(0deg) } to { filter: hue-rotate(360deg) } }
-
-:root, .light, .dark {
- --text-xxs: 9px;
- --text-xs: 10px;
- --text-sm: 12px;
- --text-md: 14px;
- --text-lg: 15px;
- --text-xl: 16px;
- --text-xxl: 17px;
- --spacing-xxs: 1px;
- --spacing-xs: 2px;
- --spacing-sm: 3px;
- --spacing-lg: 4px;
- --spacing-xl: 5px;
- --spacing-xxl: 6px;
-}
-
-@media (hover: none) and (pointer: coarse) { /* Apply different styles for devices with coarse pointers, dependent on screen resolution */
- @media (max-width: 1024px) { /* Do not affect displays larger than 1024px wide. */
- @media (max-width: 399px) { /* Screens smaller than 400px wide */
- :root, .light, .dark { --left-column: 100%; }
-        #txt2img_results, #img2img_results, #extras_results { min-width: calc(min(320px, 100%)) !important;} /* keep results single-column on the smallest mobile displays */
- #txt2img_footer p { text-wrap: wrap; }
- }
- @media (min-width: 400px) { /* Screens larger than 400px wide */
- :root, .light, .dark {--left-column: 50%;}
-        #txt2img_results, #extras_results, #txt2img_footer p {text-wrap: wrap; max-width: 100% !important; } /* maintain the side-by-side split for the txt2img tab on larger mobile displays */
- }
- #scripts_alwayson_txt2img div, #scripts_alwayson_img2img div { max-width: 100%; }
- #txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { resize:vertical !important; }
- #txt2img_generate_box, #txt2img_enqueue_wrapper { min-width: 100% !important;} /* make generate and enqueue buttons take up the entire width of their rows. */
- #img2img_toprow>div.gradio-column {flex-grow: 1 !important;} /*make interrogate buttons take up appropriate space. */
- #img2img_actions_column {display: flex; min-width: fit-content !important; flex-direction: row;justify-content: space-evenly; align-items: center;}
- #txt2img_generate_box, #img2img_generate_box, #txt2img_enqueue_wrapper,#img2img_enqueue_wrapper {display: flex;flex-direction: column;height: 4em !important;align-items: stretch;justify-content: space-evenly;}
- #img2img_interface, #img2img_results, #img2img_footer p {text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} /* maintain single column for from image operations on larger mobile devices */
- #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; overflow: auto !important; resize: none !important; } /* fix inpaint image display being too large for mobile displays */
- #img2maskimg canvas { width: auto !important; max-height: 100% !important; height: auto !important; }
- #txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } /* fix from text/image UI elements to prevent them from moving around within the UI */
- #img2img_resize_group .gradio-radio>div { display: flex; flex-direction: column; width: unset !important; }
- #inpaint_controls div {display:flex;flex-direction: row;}
- #inpaint_controls .gradio-radio>div { display: flex; flex-direction: column !important; }
- #models_tab { flex-direction: column-reverse !important; } /* move image preview/output on models page to bottom of page */
- #enqueue_keyboard_shortcut_modifiers, #enqueue_keyboard_shortcut_key div { max-width: 40% !important;} /* fix settings for agent scheduler */
- #settings { display: flex; flex-direction: row; flex-wrap: wrap; max-width: 100% !important; } /* adjust width of certain settings item to allow aligning as row, but not have it go off the screen */
- #settings div.tab-content>div>div>div { max-width: 80% !important;}
- #settings div .gradio-radio { width: unset !important; }
- #tab_extensions table { border-collapse: collapse; display: block; overflow-x:auto !important;} /* enable scrolling on extensions tab */
- ::-webkit-scrollbar { width: 25px !important; height:25px; } /* increase scrollbar size to make it finger friendly */
- .gradio-dropdown ul.options {max-height: 41vh !important; } /* adjust dropdown size to make them easier to select individual items on mobile. */
- .gradio-dropdown ul.options li.item {height: 40px !important; display: flex; align-items: center;}
- .gradio-slider input[type="number"] { width: 4em; font-size: var(--text-xs); height: 16px; text-align: center; } /* adjust slider input fields as they were too large for mobile devices. */
- #txt2img_settings .block .padded:not(.gradio-accordion) {padding: 0 !important;margin-right: 0; min-width: 100% !important; width:100% !important;}
- }
-}
+@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
+:root { --left-column: 500px; }
+a { font-weight: bold; cursor: pointer; }
+h2 { margin-top: 1em !important; font-size: var(--text-xxl) !important; }
+footer { display: none; }
+table { overflow-x: auto !important; overflow-y: auto !important; }
+td { border-bottom: none !important; padding: 0.1em 0.5em !important; }
+tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }
+textarea { overflow-y: auto !important; }
+span { font-size: var(--text-md) !important; }
+button { font-size: var(--text-lg) !important; }
+
+/* gradio elements */
+.block .padded:not(.gradio-accordion) { padding: 0 !important; margin-right: 0; min-width: 90px !important; }
+.compact { gap: 1em 0.2em; background: transparent !important; padding: 0 !important; }
+.flex-break { flex-basis: 100% !important; }
+.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; gap: 0.5em 1em; flex-grow: 1 !important; }
+.form-compact { margin-bottom: 0 !important; gap: 0.2em 1em !important; }
+.gap .compact{ padding: 0; gap: 0.2em 0; }
+.hidden { display: none; }
+.tabitem { padding: 0 !important; }
+
+.gradio-dropdown, .block.gradio-slider, .block.gradio-checkbox, .block.gradio-textbox, .block.gradio-radio, .block.gradio-checkboxgroup, .block.gradio-number, .block.gradio-colorpicker { border-width: 0 !important; box-shadow: none !important;}
+.gradio-accordion { padding-top: var(--spacing-md) !important; padding-right: 0 !important; padding-bottom: 0 !important; color: var(--body-text-color); }
+.gradio-accordion .label-wrap .icon { color: var(--button-primary-border-color); }
+.gradio-button { border-radius: var(--radius-lg) !important; }
+.gradio-button.secondary-down { background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); }
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover { box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; }
+.gradio-button.secondary-down:hover { background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); }
+.gradio-button.tool { max-width: min-content; min-width: min-content !important; align-self: end; font-size: 20px !important; color: var(--body-text-color) !important; margin-top: auto; margin-bottom: var(--spacing-md); align-self: center; }
+.gradio-checkbox { margin: 0.75em 1.5em 0 0; align-self: center; }
+.gradio-column { min-width: min(160px, 100%) !important; }
+.gradio-container { max-width: unset !important; padding: var(--block-label-padding) !important; }
+.gradio-container .prose a, .gradio-container .prose a:visited{ color: unset; text-decoration: none; }
+.gradio-dropdown { margin-right: var(--spacing-sm) !important; min-width:160px; max-width:fit-content }
+.gradio-dropdown ul.options { z-index: 1000; min-width: fit-content; max-height: 33vh !important; white-space: nowrap; }
+.gradio-dropdown ul.options li.item { padding: var(--spacing-xs); }
+.gradio-dropdown ul.options li.item:not(:has(.hide)) { background-color: var(--primary-500); }
+.gradio-dropdown .token { padding: var(--spacing-xs); }
+.gradio-dropdown span { margin-bottom: 0 !important; font-size: var(--text-sm); }
+.gradio-dropdown .reference { margin-bottom: var(--spacing-sm) !important; }
+.gradio-html { color: var(--body-text-color); }
+.gradio-html .min { min-height: 0; }
+.gradio-html div.wrap { height: 100%; }
+.gradio-number { min-width: unset !important; max-width: 5em !important; }
+.gradio-textbox { overflow: visible !important; }
+.gradio-radio { padding: 0 !important; width: max-content !important; }
+.gradio-slider { margin-right: var(--spacing-sm) !important; width: max-content !important }
+.gradio-slider input[type="number"] { width: 6em; font-size: var(--text-xs); height: 16px; text-align: right; }
+
+/* custom gradio elements */
+.accordion-compact { padding: 8px 0px 4px 0px !important; }
+.settings-accordion>div { flex-flow: wrap; }
+.small-accordion .form { min-width: var(--left-column) !important; max-width: max-content; }
+.small-accordion .label-wrap .icon { margin-right: 1.6em; margin-left: 0.6em; color: var(--button-primary-border-color); }
+.small-accordion .label-wrap { padding: 16px 0px 8px 0px; margin: 0; border-top: 2px solid var(--button-secondary-border-color); }
+.small-accordion { width: fit-content !important; min-width: fit-content !important; padding-left: 0 !important; }
+.extension-script { max-width: 48vw; }
+button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); font-weight: var(--button-large-text-weight); border: var(--button-border-width) solid var(--button-secondary-border-color);
+ background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); font-size: var(--text-lg);
+ display: inline-flex; justify-content: center; align-items: center; transition: var(--button-transition); box-shadow: var(--button-shadow); text-align: center; }
+
+/* themes */
+.theme-preview { display: none; position: fixed; border: var(--spacing-sm) solid var(--neutral-600); box-shadow: 2px 2px 2px 2px var(--neutral-700); top: 0; bottom: 0; left: 0; right: 0; margin: auto; max-width: 75vw; z-index: 999; }
+
+/* txt2img/img2img specific */
+.block.token-counter{ position: absolute; display: inline-block; right: 1em; min-width: 0 !important; width: auto; z-index: 100; top: -0.5em; }
+.block.token-counter span{ background: var(--input-background-fill) !important; box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075); border: 2px solid rgba(192,192,192,0.4) !important; }
+.block.token-counter.error span{ box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); border: 2px solid rgba(255,0,0,0.4) !important; }
+.block.token-counter div{ display: inline; }
+.block.token-counter span{ padding: 0.1em 0.75em; }
+.performance { font-size: var(--text-xs); color: #444; }
+.performance p { display: inline-block; color: var(--body-text-color-subdued) !important }
+.performance .time { margin-right: 0; }
+.thumbnails { background: var(--body-background-fill); }
+#control_gallery { height: 564px; }
+#control-result { padding: 0.5em; }
+#control-inputs { margin-top: 1em; }
+#txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { margin-right: var(--layout-gap) }
+#txt2img_footer, #img2img_footer, #control_footer { height: fit-content; display: none; }
+#txt2img_generate_box, #img2img_generate_box, #control_general_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; }
+#txt2img_actions_column, #img2img_actions_column, #control_actions_column { gap: 0.3em; height: fit-content; }
+#txt2img_generate_box>button, #img2img_generate_box>button, #control_generate_box>button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; }
+#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools, #control_generate_line2, #control_tools { display: flex; }
+#txt2img_generate_line2>button, #img2img_generate_line2>button, #extras_generate_box>button, #control_generate_line2>button, #txt2img_tools>button, #img2img_tools>button, #control_tools>button { height: 2em; line-height: 0; font-size: var(--text-md);
+ min-width: unset; display: block !important; }
+#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt, #control_prompt, #control_neg_prompt { display: contents; }
+#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
+#control_generate_box { min-width: unset; width: 100%; }
+#txt2img_actions_column, #img2img_actions_column, #control_actions { flex-flow: wrap; justify-content: space-between; }
+#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper, #control_enqueue_wrapper { min-width: unset !important; width: 48%; }
+.interrogate-clip { position: absolute; right: 3em; top: -2.7em; max-width: fit-content; }
+.interrogate-blip { position: absolute; right: 1em; top: -2.7em; max-width: fit-content; }
+.interrogate-col{ min-width: 0 !important; max-width: fit-content; margin-right: var(--spacing-xxl); }
+.interrogate-col>button{ flex: 1; width: 7em; max-height: 84px; }
+#sampler_selection_img2img { margin-top: 1em; }
+#txtimg_hr_finalres{ min-height: 0 !important; }
+#img2img_scale_resolution_preview.block{ display: flex; align-items: end; }
+#txtimg_hr_finalres .resolution, #img2img_scale_resolution_preview .resolution{ font-weight: bold; }
+div#extras_scale_to_tab div.form{ flex-direction: row; }
+#img2img_unused_scale_by_slider { visibility: hidden; width: 0.5em; max-width: 0.5em; min-width: 0.5em; }
+.inactive{ opacity: 0.5; }
+#mode_img2img .gradio-image>div.fixed-height, #mode_img2img .gradio-image>div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; }
+#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; }
+.image-buttons button{ min-width: auto; }
+.infotext { overflow-wrap: break-word; line-height: 1.5em; }
+.infotext>p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; }
+.tooltip { display: block; position: fixed; top: 1em; right: 1em; padding: 0.5em; background: var(--input-background-fill); color: var(--body-text-color); border: 1pt solid var(--button-primary-border-color);
+ width: 22em; min-height: 1.3em; font-size: var(--text-xs); transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
+.tooltip-show { opacity: 0.9; }
+.toolbutton-selected { background: var(--background-fill-primary) !important; }
+
+/* settings */
+#si-sparkline-memo, #si-sparkline-load { background-color: #111; }
+#quicksettings { width: fit-content; }
+#quicksettings>button { padding: 0 1em 0 0; align-self: end; margin-bottom: var(--text-sm); }
+#settings { display: flex; gap: var(--layout-gap); }
+#settings div { border: none; gap: 0; margin: 0 0 var(--layout-gap) 0px; padding: 0; }
+#settings>div.tab-content { flex: 10 0 75%; display: grid; }
+#settings>div.tab-content>div { border: none; padding: 0; }
+#settings>div.tab-content>div>div>div>div>div { flex-direction: unset; }
+#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); }
+#settings>div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; }
+#settings>div.tab-nav>#settings_show_all_pages { padding: var(--size-2) var(--size-4); }
+#settings .block.gradio-checkbox { margin: 0; width: auto; }
+#settings .dirtyable { gap: .5em; }
+#settings .dirtyable.hidden { display: none; }
+#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; }
+#settings .modification-indicator:disabled { visibility: hidden; }
+#settings .modification-indicator.saved { background: var(--color-accent-soft); width: var(--spacing-sm); }
+#settings .modification-indicator.changed { background: var(--color-accent); width: var(--spacing-sm); }
+#settings .modification-indicator.changed.unsaved { background-image: linear-gradient(var(--color-accent) 25%, var(--color-accent-soft) 75%); width: var(--spacing-sm); }
+#settings_result { margin: 0 1.2em; }
+.licenses { display: block !important; }
+
+/* live preview */
+.progressDiv{ position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; }
+.dark .progressDiv{ background: #424c5b; }
+.progressDiv .progress{ width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; }
+.livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; }
+.livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; }
+.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); }
+.popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: var(--text-xxs); }
+.global-popup{ display: flex; position: fixed; z-index: 10001; left: 0; top: 0; width: 100%; height: 100%; overflow: auto; background-color: rgba(20, 20, 20, 0.95);}
+.global-popup-close:before { content: "×"; }
+.global-popup-close{ position: fixed; right: 0.5em; top: 0; cursor: pointer; color: white; font-size: 32pt; }
+.global-popup-inner{ display: inline-block; margin: auto; padding: 2em; }
+
+/* fullpage image viewer */
+#lightboxModal{ display: none; position: fixed; z-index: 1001; left: 0; top: 0; width: 100%; height: 100%; overflow: auto; background-color: rgba(20, 20, 20, 0.75); backdrop-filter: blur(6px);
+ user-select: none; -webkit-user-select: none; flex-direction: row; }
+.modalControls { display: flex; justify-content: space-evenly; background-color: transparent; position: absolute; width: 99%; z-index: 1; }
+.modalControls:hover { background-color: #50505050; }
+.modalControls span { color: white; font-size: 2em; font-weight: bold; cursor: pointer; filter: grayscale(100%); }
+.modalControls span:hover, .modalControls span:focus { color: var(--highlight-color); filter: none; }
+.lightboxModalPreviewZone { display: flex; width: 100%; height: 100%; }
+.lightboxModalPreviewZone:focus-visible { outline: none; }
+.lightboxModalPreviewZone>img { display: block; margin: auto; width: auto; }
+.lightboxModalPreviewZone>img.modalImageFullscreen{ object-fit: contain; height: 100%; width: 100%; min-height: 0; background: transparent; }
+table.settings-value-table { background: white; border-collapse: collapse; margin: 1em; border: var(--spacing-sm) solid white; }
+table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-width: 36em; }
+.modalPrev, .modalNext { cursor: pointer; position: relative; z-index: 1; top: 0; width: auto; height: 100vh; line-height: 100vh; text-align: center; padding: 16px;
+ margin-top: -50px; color: white; font-weight: bold; font-size: 20px; transition: 0.6s ease; user-select: none; -webkit-user-select: none; }
+.modalNext { right: 0; }
+.modalPrev:hover, .modalNext:hover { background-color: rgba(0, 0, 0, 0.8); }
+#imageARPreview { position: absolute; top: 0px; left: 0px; border: 2px solid red; background: rgba(255, 0, 0, 0.3); z-index: 900; pointer-events: none; display: none; }
+
+/* context menu (ie for the generate button) */
+#context-menu { z-index: 9999; position: absolute; display: block; padding: var(--spacing-md); border: 2px solid var(--highlight-color); background: var(--background-fill-primary); color: var(--body-text-color); }
+.context-menu-items { list-style: none; margin: 0; padding: 0; font-size: var(--text-sm); }
+.context-menu-items a { display: block; padding: var(--spacing-md); cursor: pointer; font-weight: normal; }
+.context-menu-items a:hover { background: var(--highlight-color) }
+
+/* extensions */
+#tab_extensions table, #tab_config table{ border-collapse: collapse; }
+#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: 1px solid #ccc; padding: 0.25em 0.5em; }
+#tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
+#tab_extensions table input[type="checkbox"] { margin-right: 0.5em; appearance: checkbox; }
+#tab_extensions button{ max-width: 16em; }
+#tab_extensions input[disabled="disabled"]{ opacity: 0.5; }
+.extension-tag{ font-weight: bold; font-size: var(--text-sm); }
+.extension-button { font-size: var(--text-sm) !important; width: 6em; }
+#extensions .name{ font-size: var(--text-lg) }
+#extensions .type{ opacity: 0.5; font-size: var(--text-sm); text-align: center; }
+#extensions .version{ opacity: 0.7; }
+#extensions .info{ margin: 0; }
+#extensions .date{ opacity: 0.85; font-size: var(--text-sm); }
+
+/* extra networks */
+.extra-networks>div { margin: 0; border-bottom: none !important; gap: 0.3em 0; }
+.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); }
+.extra-networks .search { flex: 1; }
+.extra-networks .description { flex: 3; }
+.extra-networks .tab-nav>button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; }
+.extra-networks .buttons { position: absolute; right: 0; margin: -4px; background: var(--background-color); }
+.extra-networks .buttons>button { margin-left: -0.2em; height: 1.4em; color: var(--primary-300) !important; font-size: 20px !important; }
+.extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 3px 3px 3px 12px; text-indent: -6px; box-shadow: none; line-break: auto; }
+.extra-networks .custom-button:hover { background: var(--button-primary-background-fill) }
+.extra-networks-tab { padding: 0 !important; }
+.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; }
+.extra-networks-page { display: flex }
+.extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
+.extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
+.extra-network-cards .card .overlay { position: absolute; bottom: 0; padding: 0.2em; z-index: 10; width: 100%; background: none; }
+.extra-network-cards .card .overlay .name { font-size: var(--text-lg); font-weight: bold; text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
+.extra-network-cards .card .preview { box-shadow: var(--button-shadow); min-height: 30px; }
+.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
+.extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }
+.extra-network-cards .card .overlay .tags { display: none; overflow-wrap: break-word; }
+.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; }
+.extra-network-cards .card .actions>span { padding: 4px; font-size: 34px !important; }
+.extra-network-cards .card .actions>span:hover { color: var(--highlight-color); }
+.extra-network-cards .card:hover .actions { display: block; }
+.extra-network-cards .card:hover .overlay .tags { display: block; }
+.extra-network-cards .card .actions { font-size: 3em; display: none; text-align-last: right; cursor: pointer; font-variant: unicase; position: absolute; z-index: 80; right: 0; height: 0.7em; width: 100%; background: rgba(0, 0, 0, 0.40); }
+.extra-network-cards .card-list { display: flex; margin: 0.3em; padding: 0.3em; background: var(--input-background-fill); cursor: pointer; border-radius: var(--button-large-radius); }
+.extra-network-cards .card-list .tag { color: var(--primary-500); margin-left: 0.8em; }
+.extra-details-close { position: fixed; top: 0.2em; right: 0.2em; z-index: 99; background: var(--button-secondary-background-fill) !important; }
+#txt2img_description, #img2img_description, #control_description { max-height: 63px; overflow-y: auto !important; }
+#txt2img_description>label>textarea, #img2img_description>label>textarea, #control_description>label>textarea { font-size: var(--text-sm) }
+
+#txt2img_extra_details>div, #img2img_extra_details>div { overflow-y: auto; min-height: 40vh; max-height: 80vh; align-self: flex-start; }
+#txt2img_extra_details, #img2img_extra_details { position: fixed; bottom: 50%; left: 50%; transform: translate(-50%, 50%); padding: 0.8em; border: var(--block-border-width) solid var(--highlight-color) !important;
+ z-index: 100; box-shadow: var(--button-shadow); }
+#txt2img_extra_details td:first-child, #img2img_extra_details td:first-child { font-weight: bold; vertical-align: top; }
+#txt2img_extra_details .gradio-image, #img2img_extra_details .gradio-image { max-height: 70vh; }
+
+
+/* specific elements */
+#modelmerger_interp_description { margin-top: 1em; margin-bottom: 1em; }
+#scripts_alwayson_txt2img, #scripts_alwayson_img2img { padding: 0 }
+#scripts_alwayson_txt2img>.label-wrap, #scripts_alwayson_img2img>.label-wrap { background: var(--input-background-fill); padding: 0; margin: 0; border-radius: var(--radius-lg); }
+#scripts_alwayson_txt2img>.label-wrap>span, #scripts_alwayson_img2img>.label-wrap>span { padding: var(--spacing-xxl); }
+#scripts_alwayson_txt2img div { max-width: var(--left-column); }
+#script_txt2img_agent_scheduler { display: none; }
+#refresh_tac_refreshTempFiles { display: none; }
+#train_tab { flex-flow: row-reverse; }
+#models_tab { flex-flow: row-reverse; }
+#swap_axes>button { min-width: 100px; font-size: var(--text-md); }
+#ui_defaults_review { margin: 1em; }
+
+/* extras */
+.extras { gap: 0.2em 1em !important }
+#extras_generate, #extras_interrupt, #extras_skip { display: block !important; position: relative; height: 36px; }
+#extras_upscale { margin-top: 10px }
+#pnginfo_html_info .gradio-html > div { margin: 0.5em; }
+
+/* log monitor */
+.log-monitor { display: none; justify-content: unset !important; overflow: hidden; padding: 0; margin-top: auto; font-family: monospace; font-size: var(--text-xs); }
+.log-monitor td, .log-monitor th { padding-left: 1em; }
+
+/* changelog */
+.md h2 { background-color: var(--background-fill-primary); padding: 0.5em; }
+.md ul { list-style-type: square !important; text-indent: 1em; margin-left: 4em; }
+.md li { list-style-position: outside !important; text-indent: 0; }
+.md p { margin-left: 2em; }
+
+/* custom component */
+.folder-selector textarea { height: 2em !important; padding: 6px !important; }
+.nvml { position: fixed; bottom: 10px; right: 10px; background: var(--background-fill-primary); border: 1px solid var(--button-primary-border-color); padding: 6px; color: var(--button-primary-text-color);
+ font-size: var(--text-xxs); z-index: 50; font-family: monospace; display: none; }
+
+/* control */
+#control_input_type { max-width: 18em }
+#control_settings .small-accordion .form { min-width: 350px !important }
+.control-button { min-height: 42px; max-height: 42px; line-height: 1em; }
+.control-tabs > .tab-nav { margin-bottom: 0; margin-top: 0; }
+.control-unit { max-width: 1200px; padding: 0 !important; margin-top: -10px !important; }
+.control-unit > .label-wrap { margin-bottom: 0 !important; }
+.processor-settings { padding: 0 !important; max-width: 300px; }
+.processor-group>div { flex-flow: wrap;gap: 1em; }
+
+/* main info */
+.main-info { font-weight: var(--section-header-text-weight); color: var(--body-text-color-subdued); padding: 1em !important; margin-top: 2em !important; line-height: var(--line-lg) !important; }
+
+/* loader */
+.splash { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; z-index: 1000; display: block; text-align: center; }
+.motd { margin-top: 2em; color: var(--body-text-color-subdued); font-family: monospace; font-variant: all-petite-caps; }
+.splash-img { margin: 10% auto 0 auto; width: 512px; background-repeat: no-repeat; height: 512px; animation: color 10s infinite alternate; }
+.loading { color: white; position: absolute; top: 20%; left: 50%; transform: translateX(-50%); }
+.loader { width: 300px; height: 300px; border: var(--spacing-md) solid transparent; border-radius: 50%; border-top: var(--spacing-md) solid var(--primary-600); animation: spin 4s linear infinite; position: relative; }
+.loader::before, .loader::after { content: ""; position: absolute; top: 6px; bottom: 6px; left: 6px; right: 6px; border-radius: 50%; border: var(--spacing-md) solid transparent; }
+.loader::before { border-top-color: var(--primary-900); animation: 3s spin linear infinite; }
+.loader::after { border-top-color: var(--primary-300); animation: spin 1.5s linear infinite; }
+@keyframes move { from { background-position-x: 0, -40px; } to { background-position-x: 0, 40px; } }
+@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
+@keyframes color { from { filter: hue-rotate(0deg) } to { filter: hue-rotate(360deg) } }
+
+:root, .light, .dark {
+ --text-xxs: 9px;
+ --text-xs: 10px;
+ --text-sm: 12px;
+ --text-md: 14px;
+ --text-lg: 15px;
+ --text-xl: 16px;
+ --text-xxl: 17px;
+ --spacing-xxs: 1px;
+ --spacing-xs: 2px;
+ --spacing-sm: 3px;
+ --spacing-lg: 4px;
+ --spacing-xl: 5px;
+ --spacing-xxl: 6px;
+}
+
+@media (hover: none) and (pointer: coarse) { /* Apply different styles for devices with coarse pointers, dependent on screen resolution */
+ @media (max-width: 1024px) { /* Do not affect displays larger than 1024px wide. */
+ @media (max-width: 399px) { /* Screens smaller than 400px wide */
+ :root, .light, .dark { --left-column: 100%; }
+        #txt2img_results, #img2img_results, #extras_results { min-width: calc(min(320px, 100%)) !important;} /* keep results single-column on the smallest mobile displays */
+ #txt2img_footer p { text-wrap: wrap; }
+ }
+ @media (min-width: 400px) { /* Screens larger than 400px wide */
+ :root, .light, .dark {--left-column: 50%;}
+        #txt2img_results, #extras_results, #txt2img_footer p {text-wrap: wrap; max-width: 100% !important; } /* maintain the side-by-side split for the txt2img tab on larger mobile displays */
+ }
+ #scripts_alwayson_txt2img div, #scripts_alwayson_img2img div { max-width: 100%; }
+ #txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { resize:vertical !important; }
+ #txt2img_generate_box, #txt2img_enqueue_wrapper { min-width: 100% !important;} /* make generate and enqueue buttons take up the entire width of their rows. */
+ #img2img_toprow>div.gradio-column {flex-grow: 1 !important;} /*make interrogate buttons take up appropriate space. */
+ #img2img_actions_column {display: flex; min-width: fit-content !important; flex-direction: row;justify-content: space-evenly; align-items: center;}
+ #txt2img_generate_box, #img2img_generate_box, #txt2img_enqueue_wrapper,#img2img_enqueue_wrapper {display: flex;flex-direction: column;height: 4em !important;align-items: stretch;justify-content: space-evenly;}
+ #img2img_interface, #img2img_results, #img2img_footer p {text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} /* maintain single column for from image operations on larger mobile devices */
+ #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; overflow: auto !important; resize: none !important; } /* fix inpaint image display being too large for mobile displays */
+ #img2maskimg canvas { width: auto !important; max-height: 100% !important; height: auto !important; }
+ #txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } /* fix from text/image UI elements to prevent them from moving around within the UI */
+ #img2img_resize_group .gradio-radio>div { display: flex; flex-direction: column; width: unset !important; }
+ #inpaint_controls div {display:flex;flex-direction: row;}
+ #inpaint_controls .gradio-radio>div { display: flex; flex-direction: column !important; }
+ #models_tab { flex-direction: column-reverse !important; } /* move image preview/output on models page to bottom of page */
+ #enqueue_keyboard_shortcut_modifiers, #enqueue_keyboard_shortcut_key div { max-width: 40% !important;} /* fix settings for agent scheduler */
+ #settings { display: flex; flex-direction: row; flex-wrap: wrap; max-width: 100% !important; } /* adjust width of certain settings item to allow aligning as row, but not have it go off the screen */
+ #settings div.tab-content>div>div>div { max-width: 80% !important;}
+ #settings div .gradio-radio { width: unset !important; }
+ #tab_extensions table { border-collapse: collapse; display: block; overflow-x:auto !important;} /* enable scrolling on extensions tab */
+ ::-webkit-scrollbar { width: 25px !important; height:25px; } /* increase scrollbar size to make it finger friendly */
+ .gradio-dropdown ul.options {max-height: 41vh !important; } /* adjust dropdown size to make them easier to select individual items on mobile. */
+ .gradio-dropdown ul.options li.item {height: 40px !important; display: flex; align-items: center;}
+ .gradio-slider input[type="number"] { width: 4em; font-size: var(--text-xs); height: 16px; text-align: center; } /* adjust slider input fields as they were too large for mobile devices. */
+ #txt2img_settings .block .padded:not(.gradio-accordion) {padding: 0 !important;margin-right: 0; min-width: 100% !important; width:100% !important;}
+ }
+}
diff --git a/javascript/timeless-beige.css b/javascript/timeless-beige.css
index d9a17bd00..d0142bdce 100644
--- a/javascript/timeless-beige.css
+++ b/javascript/timeless-beige.css
@@ -1,297 +1,297 @@
-/* generic html tags */
-:root, .light, .dark {
- --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
- --font-mono: 'ui-monospace', 'Consolas', monospace;
- --font-size: 16px;
- --primary-100: #212226; /* bg color*/
- --primary-200: #17181b; /* drop down menu/ prompt window fill*/
- --primary-300: #0a0c0e; /* black */
- --primary-400: #2f3034; /* small buttons*/
- --primary-500: #434242; /* main accent color retro beige*/
- --primary-700: #e75d5d; /* light blue gray*/
- --primary-800: #e75d5d; /* sat orange(hover accent)*/
- --highlight-color: var(--primary-500);
- --inactive-color: var(--primary-800);
- --body-text-color: var(--neutral-100);
- --body-text-color-subdued: var(--neutral-300);
- --background-color: var(--primary-100);
- --background-fill-primary: var(--input-background-fill);
- --input-padding: 8px;
- --input-background-fill: var(--primary-200);
- --input-shadow: none;
- --button-secondary-text-color: white;
- --button-secondary-background-fill: var(--primary-400);
- --button-secondary-background-fill-hover: var(--primary-700);
- --block-title-text-color: var(--neutral-300);
- --radius-sm: 1px;
- --radius-lg: 6px;
- --spacing-md: 4px;
- --spacing-xxl: 8px;
- --line-sm: 1.2em;
- --line-md: 1.4em;
-}
-
-html { font-size: var(--font-size); }
-body, button, input, select, textarea { font-family: var(--font);}
-button { max-width: 400px; }
-img { background-color: var(--background-color); }
-input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
-input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
-input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
-input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
-::-webkit-scrollbar-track { background: #333333; }
-::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
-div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
-div.compact { gap: 1em; }
-
-/* gradio style classes */
-fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
-.border-2 { border-width: 0; }
-.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
-.bg-white { color: lightyellow; background-color: var(--inactive-color); }
-.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
-.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
-.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
-.gr-check-radio:checked { background-color: var(--highlight-color); }
-.gr-compact { background-color: var(--background-color); }
-.gr-form { border-width: 0; }
-.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
-.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
-.gr-panel { background-color: var(--background-color); }
-.eta-bar { display: none !important }
-svg.feather.feather-image, .feather .feather-image { display: none }
-.gap-2 { padding-top: 8px; }
-.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
-.output-html { line-height: 1.2rem; overflow-x: hidden; }
-.output-html > div { margin-bottom: 8px; }
-.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
-.p-2 { padding: 0; }
-.px-4 { padding-left: 1rem; padding-right: 1rem; }
-.py-6 { padding-bottom: 0; }
-.tabs { background-color: var(--background-color); }
-.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
-.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
-div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
-#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
-.label-wrap { background-color: #292b30; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
-.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
-.small-accordion .label-wrap .icon { margin-right: 1em; }
-.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
-button.selected {background: var(--button-primary-background-fill);}
-.center.boundedheight.flex {background-color: var(--input-background-fill);}
-.compact {border-radius: var(--border-radius-lg);}
-#logMonitorData {background-color: var(--input-background-fill);}
-#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
-#tab_extensions table, #tab_config table { width: 96vw; }
-#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
-#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
-
-/* automatic style classes */
-.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
-.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
-.gallery-item { box-shadow: none !important; }
-.performance { color: #888; }
-.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
-.image-buttons { gap: 10px !important; justify-content: center; }
-.image-buttons > button { max-width: 160px; }
-.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
-#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
-
-/* gradio elements overrides */
-#div.gradio-container { overflow-x: hidden; }
-#img2img_label_copy_to_img2img { font-weight: normal; }
-#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
-#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
-#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
-#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
-#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
-#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
-#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
-#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
-#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
-#settings > div.flex-wrap { width: 15em; }
-#txt2img_cfg_scale { min-width: 200px; }
-#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
-#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
-#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
-#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
-#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
-
-#extras_upscale { margin-top: 10px }
-#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
-#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
-#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
-#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
-
-/* based on gradio built-in dark theme */
-:root, .light, .dark {
- --body-background-fill: var(--background-color);
- --color-accent-soft: var(--neutral-700);
- --background-fill-secondary: none;
- --border-color-accent: var(--background-color);
- --border-color-primary: var(--background-color);
- --link-text-color-active: var(--primary-500);
- --link-text-color: var(--secondary-500);
- --link-text-color-hover: var(--secondary-400);
- --link-text-color-visited: var(--secondary-600);
- --shadow-spread: 1px;
- --block-background-fill: None;
- --block-border-color: var(--border-color-primary);
- --block_border_width: None;
- --block-info-text-color: var(--body-text-color-subdued);
- --block-label-background-fill: var(--background-fill-secondary);
- --block-label-border-color: var(--border-color-primary);
- --block_label_border_width: None;
- --block-label-text-color: var(--neutral-200);
- --block_shadow: None;
- --block_title_background_fill: None;
- --block_title_border_color: None;
- --block_title_border_width: None;
- --panel-background-fill: var(--background-fill-secondary);
- --panel-border-color: var(--border-color-primary);
- --panel_border_width: None;
- --checkbox-background-color: var(--primary-400);
- --checkbox-background-color-focus: var(--primary-700);
- --checkbox-background-color-hover: var(--primary-700);
- --checkbox-background-color-selected: var(--primary-500);
- --checkbox-border-color: transparent;
- --checkbox-border-color-focus: var(--primary-800);
- --checkbox-border-color-hover: var(--primary-800);
- --checkbox-border-color-selected: var(--primary-800);
- --checkbox-border-width: var(--input-border-width);
- --checkbox-label-background-fill: None;
- --checkbox-label-background-fill-hover: None;
- --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
- --checkbox-label-border-color: var(--border-color-primary);
- --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
- --checkbox-label-border-width: var(--input-border-width);
- --checkbox-label-text-color: var(--body-text-color);
- --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
- --error-background-fill: var(--background-fill-primary);
- --error-border-color: var(--border-color-primary);
- --error-text-color: #f768b7; /*was ef4444*/
- --input-background-fill-focus: var(--secondary-600);
- --input-background-fill-hover: var(--input-background-fill);
- --input-border-color: var(--background-color);
- --input-border-color-focus: var(--primary-800);
- --input-placeholder-color: var(--neutral-500);
- --input-shadow-focus: None;
- --loader_color: None;
- --slider_color: None;
- --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
- --table-border-color: var(--neutral-700);
- --table-even-background-fill: var(--primary-300);
- --table-odd-background-fill: var(--primary-200);
- --table-row-focus: var(--color-accent-soft);
- --button-border-width: var(--input-border-width);
- --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
- --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
- --button-cancel-border-color: #dc2626;
- --button-cancel-border-color-hover: var(--button-cancel-border-color);
- --button-cancel-text-color: white;
- --button-cancel-text-color-hover: var(--button-cancel-text-color);
- --button-primary-background-fill: var(--primary-500);
- --button-primary-background-fill-hover: var(--primary-800);
- --button-primary-border-color: var(--primary-500);
- --button-primary-border-color-hover: var(--button-primary-border-color);
- --button-primary-text-color: white;
- --button-primary-text-color-hover: var(--button-primary-text-color);
- --button-secondary-border-color: var(--neutral-600);
- --button-secondary-border-color-hover: var(--button-secondary-border-color);
- --button-secondary-text-color-hover: var(--button-secondary-text-color);
- --secondary-50: #eff6ff;
- --secondary-100: #dbeafe;
- --secondary-200: #bfdbfe;
- --secondary-300: #93c5fd;
- --secondary-400: #60a5fa;
- --secondary-500: #3b82f6;
- --secondary-600: #2563eb;
- --secondary-700: #1d4ed8;
- --secondary-800: #1e40af;
- --secondary-900: #1e3a8a;
- --secondary-950: #1d3660;
- --neutral-50: #f0f0f0; /* */
- --neutral-100: #e0dedc;/* majority of text (neutral gray yellow) */
- --neutral-200: #d0d0d0;
- --neutral-300: #9d9dab; /* top tab text (light accent) */
- --neutral-400: #ffba85;/* tab title (light beige) */
- --neutral-500: #484746; /* prompt text (desat accent)*/
- --neutral-600: #605a54; /* tab outline color (accent color)*/
- --neutral-700: #1b1c1e; /* small settings tab accent (dark)*/
- --neutral-800: #e75d5d; /* bright orange accent */
- --neutral-900: #111827;
- --neutral-950: #0b0f19;
- --radius-xxs: 0;
- --radius-xs: 0;
- --radius-md: 0;
- --radius-xl: 0;
- --radius-xxl: 0;
- --body-text-size: var(--text-md);
- --body-text-weight: 400;
- --embed-radius: var(--radius-lg);
- --color-accent: var(--primary-500);
- --shadow-drop: 0;
- --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
- --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
- --block-border-width: 1px;
- --block-info-text-size: var(--text-sm);
- --block-info-text-weight: 400;
- --block-label-border-width: 1px;
- --block-label-margin: 0;
- --block-label-padding: var(--spacing-sm) var(--spacing-lg);
- --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
- --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
- --block-label-text-size: var(--text-sm);
- --block-label-text-weight: 400;
- --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
- --block-radius: var(--radius-lg);
- --block-shadow: var(--shadow-drop);
- --block-title-background-fill: none;
- --block-title-border-color: none;
- --block-title-border-width: 0;
- --block-title-padding: 0;
- --block-title-radius: none;
- --block-title-text-size: var(--text-md);
- --block-title-text-weight: 400;
- --container-radius: var(--radius-lg);
- --form-gap-width: 1px;
- --layout-gap: var(--spacing-xxl);
- --panel-border-width: 0;
- --section-header-text-size: var(--text-md);
- --section-header-text-weight: 400;
- --checkbox-border-radius: var(--radius-sm);
- --checkbox-label-gap: 2px;
- --checkbox-label-padding: var(--spacing-md);
- --checkbox-label-shadow: var(--shadow-drop);
- --checkbox-label-text-size: var(--text-md);
- --checkbox-label-text-weight: 400;
- --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
- --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
- --checkbox-shadow: var(--input-shadow);
- --error-border-width: 1px;
- --input-border-width: 1px;
- --input-radius: var(--radius-lg);
- --input-text-size: var(--text-md);
- --input-text-weight: 400;
- --loader-color: var(--color-accent);
- --prose-text-size: var(--text-md);
- --prose-text-weight: 400;
- --prose-header-text-weight: 600;
- --slider-color: ;
- --table-radius: var(--radius-lg);
- --button-large-padding: 2px 6px;
- --button-large-radius: var(--radius-lg);
- --button-large-text-size: var(--text-lg);
- --button-large-text-weight: 400;
- --button-shadow: none;
- --button-shadow-active: none;
- --button-shadow-hover: none;
- --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
- --button-small-radius: var(--radius-lg);
- --button-small-text-size: var(--text-md);
- --button-small-text-weight: 400;
- --button-transition: none;
- --size-9: 64px;
- --size-14: 64px;
-}
+/* generic html tags */
+:root, .light, .dark {
+ --font: system-ui, ui-sans-serif, "Roboto", sans-serif;
+ --font-mono: ui-monospace, 'Consolas', monospace;
+ --font-size: 16px;
+ --primary-100: #212226; /* bg color*/
+ --primary-200: #17181b; /* drop down menu/ prompt window fill*/
+ --primary-300: #0a0c0e; /* black */
+ --primary-400: #2f3034; /* small buttons*/
+ --primary-500: #434242; /* main accent color retro beige*/
+ --primary-700: #e75d5d; /* light blue gray*/
+ --primary-800: #e75d5d; /* sat orange(hover accent)*/
+ --highlight-color: var(--primary-500);
+ --inactive-color: var(--primary-800);
+ --body-text-color: var(--neutral-100);
+ --body-text-color-subdued: var(--neutral-300);
+ --background-color: var(--primary-100);
+ --background-fill-primary: var(--input-background-fill);
+ --input-padding: 8px;
+ --input-background-fill: var(--primary-200);
+ --input-shadow: none;
+ --button-secondary-text-color: white;
+ --button-secondary-background-fill: var(--primary-400);
+ --button-secondary-background-fill-hover: var(--primary-700);
+ --block-title-text-color: var(--neutral-300);
+ --radius-sm: 1px;
+ --radius-lg: 6px;
+ --spacing-md: 4px;
+ --spacing-xxl: 8px;
+ --line-sm: 1.2em;
+ --line-md: 1.4em;
+}
+
+html { font-size: var(--font-size); }
+body, button, input, select, textarea { font-family: var(--font);}
+button { max-width: 400px; }
+img { background-color: var(--background-color); }
+input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; }
+input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
+input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; }
+input[type=range]::-webkit-slider-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
+input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; }
+input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); }
+::-webkit-scrollbar-track { background: #333333; }
+::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; }
+div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; }
+div.compact { gap: 1em; }
+
+/* gradio style classes */
+fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; }
+.border-2 { border-width: 0; }
+.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; }
+.bg-white { color: lightyellow; background-color: var(--inactive-color); }
+.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px }
+.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; }
+.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; }
+.gr-check-radio:checked { background-color: var(--highlight-color); }
+.gr-compact { background-color: var(--background-color); }
+.gr-form { border-width: 0; }
+.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; }
+.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; }
+.gr-panel { background-color: var(--background-color); }
+.eta-bar { display: none !important }
+svg.feather.feather-image, .feather .feather-image { display: none }
+.gap-2 { padding-top: 8px; }
+.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; }
+.output-html { line-height: 1.2rem; overflow-x: hidden; }
+.output-html > div { margin-bottom: 8px; }
+.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */
+.p-2 { padding: 0; }
+.px-4 { padding-left: 1rem; padding-right: 1rem; }
+.py-6 { padding-bottom: 0; }
+.tabs { background-color: var(--background-color); }
+.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; }
+.tab-nav { zoom: 110%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
+div.tab-nav button.selected {background-color: var(--button-primary-background-fill);}
+#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;}
+.label-wrap { background-color: #292b30; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; }
+.small-accordion .label-wrap { padding: 8px 0px 8px 0px; }
+.small-accordion .label-wrap .icon { margin-right: 1em; }
+.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);}
+button.selected {background: var(--button-primary-background-fill);}
+.center.boundedheight.flex {background-color: var(--input-background-fill);}
+.compact {border-radius: var(--border-radius-lg);}
+#logMonitorData {background-color: var(--input-background-fill);}
+#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); }
+#tab_extensions table, #tab_config table { width: 96vw; }
+#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;}
+#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);}
+
+/* automatic style classes */
+.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); }
+.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; }
+.gallery-item { box-shadow: none !important; }
+.performance { color: #888; }
+.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
+.image-buttons { gap: 10px !important; justify-content: center; }
+.image-buttons > button { max-width: 160px; }
+.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) }
+#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
+
+/* gradio elements overrides */
+div.gradio-container { overflow-x: hidden; }
+#img2img_label_copy_to_img2img { font-weight: normal; }
+#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; }
+#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; }
+#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; }
+#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; }
+#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; }
+#quicksettings button {padding: 0 0.5em 0.1em 0.5em;}
+#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; }
+#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; }
+#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; }
+#settings > div.flex-wrap { width: 15em; }
+#txt2img_cfg_scale { min-width: 200px; }
+#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; }
+#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; }
+#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; }
+#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; }
+#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; }
+
+#extras_upscale { margin-top: 10px }
+#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); }
+#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; }
+#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
+#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
+#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
+
+/* based on gradio built-in dark theme */
+:root, .light, .dark {
+ --body-background-fill: var(--background-color);
+ --color-accent-soft: var(--neutral-700);
+ --background-fill-secondary: none;
+ --border-color-accent: var(--background-color);
+ --border-color-primary: var(--background-color);
+ --link-text-color-active: var(--primary-500);
+ --link-text-color: var(--secondary-500);
+ --link-text-color-hover: var(--secondary-400);
+ --link-text-color-visited: var(--secondary-600);
+ --shadow-spread: 1px;
+ --block-background-fill: None;
+ --block-border-color: var(--border-color-primary);
+ --block_border_width: None;
+ --block-info-text-color: var(--body-text-color-subdued);
+ --block-label-background-fill: var(--background-fill-secondary);
+ --block-label-border-color: var(--border-color-primary);
+ --block_label_border_width: None;
+ --block-label-text-color: var(--neutral-200);
+ --block_shadow: None;
+ --block_title_background_fill: None;
+ --block_title_border_color: None;
+ --block_title_border_width: None;
+ --panel-background-fill: var(--background-fill-secondary);
+ --panel-border-color: var(--border-color-primary);
+ --panel_border_width: None;
+ --checkbox-background-color: var(--primary-400);
+ --checkbox-background-color-focus: var(--primary-700);
+ --checkbox-background-color-hover: var(--primary-700);
+ --checkbox-background-color-selected: var(--primary-500);
+ --checkbox-border-color: transparent;
+ --checkbox-border-color-focus: var(--primary-800);
+ --checkbox-border-color-hover: var(--primary-800);
+ --checkbox-border-color-selected: var(--primary-800);
+ --checkbox-border-width: var(--input-border-width);
+ --checkbox-label-background-fill: None;
+ --checkbox-label-background-fill-hover: None;
+ --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
+ --checkbox-label-border-color: var(--border-color-primary);
+ --checkbox-label-border-color-hover: var(--checkbox-label-border-color);
+ --checkbox-label-border-width: var(--input-border-width);
+ --checkbox-label-text-color: var(--body-text-color);
+ --checkbox-label-text-color-selected: var(--checkbox-label-text-color);
+ --error-background-fill: var(--background-fill-primary);
+ --error-border-color: var(--border-color-primary);
+ --error-text-color: #f768b7; /*was ef4444*/
+ --input-background-fill-focus: var(--secondary-600);
+ --input-background-fill-hover: var(--input-background-fill);
+ --input-border-color: var(--background-color);
+ --input-border-color-focus: var(--primary-800);
+ --input-placeholder-color: var(--neutral-500);
+ --input-shadow-focus: None;
+ --loader_color: None;
+ --slider_color: None;
+ --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800));
+ --table-border-color: var(--neutral-700);
+ --table-even-background-fill: var(--primary-300);
+ --table-odd-background-fill: var(--primary-200);
+ --table-row-focus: var(--color-accent-soft);
+ --button-border-width: var(--input-border-width);
+ --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
+ --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
+ --button-cancel-border-color: #dc2626;
+ --button-cancel-border-color-hover: var(--button-cancel-border-color);
+ --button-cancel-text-color: white;
+ --button-cancel-text-color-hover: var(--button-cancel-text-color);
+ --button-primary-background-fill: var(--primary-500);
+ --button-primary-background-fill-hover: var(--primary-800);
+ --button-primary-border-color: var(--primary-500);
+ --button-primary-border-color-hover: var(--button-primary-border-color);
+ --button-primary-text-color: white;
+ --button-primary-text-color-hover: var(--button-primary-text-color);
+ --button-secondary-border-color: var(--neutral-600);
+ --button-secondary-border-color-hover: var(--button-secondary-border-color);
+ --button-secondary-text-color-hover: var(--button-secondary-text-color);
+ --secondary-50: #eff6ff;
+ --secondary-100: #dbeafe;
+ --secondary-200: #bfdbfe;
+ --secondary-300: #93c5fd;
+ --secondary-400: #60a5fa;
+ --secondary-500: #3b82f6;
+ --secondary-600: #2563eb;
+ --secondary-700: #1d4ed8;
+ --secondary-800: #1e40af;
+ --secondary-900: #1e3a8a;
+ --secondary-950: #1d3660;
+ --neutral-50: #f0f0f0; /* */
+ --neutral-100: #e0dedc;/* majority of text (neutral gray yellow) */
+ --neutral-200: #d0d0d0;
+ --neutral-300: #9d9dab; /* top tab text (light accent) */
+ --neutral-400: #ffba85;/* tab title (light beige) */
+ --neutral-500: #484746; /* prompt text (desat accent)*/
+ --neutral-600: #605a54; /* tab outline color (accent color)*/
+ --neutral-700: #1b1c1e; /* small settings tab accent (dark)*/
+ --neutral-800: #e75d5d; /* bright orange accent */
+ --neutral-900: #111827;
+ --neutral-950: #0b0f19;
+ --radius-xxs: 0;
+ --radius-xs: 0;
+ --radius-md: 0;
+ --radius-xl: 0;
+ --radius-xxl: 0;
+ --body-text-size: var(--text-md);
+ --body-text-weight: 400;
+ --embed-radius: var(--radius-lg);
+ --color-accent: var(--primary-500);
+ --shadow-drop: 0;
+ --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
+ --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset;
+ --block-border-width: 1px;
+ --block-info-text-size: var(--text-sm);
+ --block-info-text-weight: 400;
+ --block-label-border-width: 1px;
+ --block-label-margin: 0;
+ --block-label-padding: var(--spacing-sm) var(--spacing-lg);
+ --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0;
+ --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px);
+ --block-label-text-size: var(--text-sm);
+ --block-label-text-weight: 400;
+ --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px);
+ --block-radius: var(--radius-lg);
+ --block-shadow: var(--shadow-drop);
+ --block-title-background-fill: none;
+ --block-title-border-color: none;
+ --block-title-border-width: 0;
+ --block-title-padding: 0;
+ --block-title-radius: none;
+ --block-title-text-size: var(--text-md);
+ --block-title-text-weight: 400;
+ --container-radius: var(--radius-lg);
+ --form-gap-width: 1px;
+ --layout-gap: var(--spacing-xxl);
+ --panel-border-width: 0;
+ --section-header-text-size: var(--text-md);
+ --section-header-text-weight: 400;
+ --checkbox-border-radius: var(--radius-sm);
+ --checkbox-label-gap: 2px;
+ --checkbox-label-padding: var(--spacing-md);
+ --checkbox-label-shadow: var(--shadow-drop);
+ --checkbox-label-text-size: var(--text-md);
+ --checkbox-label-text-weight: 400;
+ --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
+ --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
+ --checkbox-shadow: var(--input-shadow);
+ --error-border-width: 1px;
+ --input-border-width: 1px;
+ --input-radius: var(--radius-lg);
+ --input-text-size: var(--text-md);
+ --input-text-weight: 400;
+ --loader-color: var(--color-accent);
+ --prose-text-size: var(--text-md);
+ --prose-text-weight: 400;
+ --prose-header-text-weight: 600;
+ --slider-color: ;
+ --table-radius: var(--radius-lg);
+ --button-large-padding: 2px 6px;
+ --button-large-radius: var(--radius-lg);
+ --button-large-text-size: var(--text-lg);
+ --button-large-text-weight: 400;
+ --button-shadow: none;
+ --button-shadow-active: none;
+ --button-shadow-hover: none;
+ --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
+ --button-small-radius: var(--radius-lg);
+ --button-small-text-size: var(--text-md);
+ --button-small-text-weight: 400;
+ --button-transition: none;
+ --size-9: 64px;
+ --size-14: 64px;
+}
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 482a87e82..04dbd81d0 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,86 +1,86 @@
-import html
-import threading
-import time
-import cProfile
-from modules import shared, progress, errors
-
-queue_lock = threading.Lock()
-
-
-def wrap_queued_call(func):
- def f(*args, **kwargs):
- with queue_lock:
- res = func(*args, **kwargs)
- return res
- return f
-
-
-def wrap_gradio_gpu_call(func, extra_outputs=None):
- name = func.__name__
- def f(*args, **kwargs):
- # if the first argument is a string that says "task(...)", it is treated as a job id
- if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
- id_task = args[0]
- progress.add_task_to_queue(id_task)
- else:
- id_task = None
- with queue_lock:
- progress.start_task(id_task)
- res = [None, '', '', '']
- try:
- res = func(*args, **kwargs)
- progress.record_results(id_task, res)
- except Exception as e:
- shared.log.error(f"Exception: {e}")
- shared.log.error(f"Arguments: args={str(args)[:10240]} kwargs={str(kwargs)[:10240]}")
- errors.display(e, 'gradio call')
- res[-1] = f"<div class='error'>{html.escape(str(e))}</div>"
- finally:
- progress.finish_task(id_task)
- return res
- return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, name=name)
-
-
-def wrap_gradio_call(func, extra_outputs=None, add_stats=False, name=None):
- job_name = name if name is not None else func.__name__
- def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- t = time.perf_counter()
- shared.mem_mon.reset()
- shared.state.begin(job_name)
- try:
- if shared.cmd_opts.profile:
- pr = cProfile.Profile()
- pr.enable()
- res = func(*args, **kwargs)
- if res is None:
- msg = "No result returned from function"
- shared.log.warning(msg)
- res = [None, '', '', f"<div class='error'>{html.escape(msg)}</div>"]
- else:
- res = list(res)
- if shared.cmd_opts.profile:
- errors.profile(pr, 'Wrap')
- except Exception as e:
- errors.display(e, 'gradio call')
- if extra_outputs_array is None:
- extra_outputs_array = [None, '']
- res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
- shared.state.end()
- if not add_stats:
- return tuple(res)
- elapsed = time.perf_counter() - t
- elapsed_m = int(elapsed // 60)
- elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_m}m {elapsed_s:.2f}s" if elapsed_m > 0 else f"{elapsed_s:.2f}s"
- vram_html = ''
- if not shared.mem_mon.disabled:
- vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()}
- if vram.get('active_peak', 0) > 0:
- vram_html = " | <p class='vram'>"
- vram_html += f"GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB"
- vram_html += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
- vram_html += "</p>"
- if isinstance(res, list):
- res[-1] += f"<div class='performance'><p class='time'>Time: {elapsed_text}</p>{vram_html}</div>"
- return tuple(res)
- return f
+import html
+import threading
+import time
+import cProfile
+from modules import shared, progress, errors
+
+queue_lock = threading.Lock()
+
+
+def wrap_queued_call(func):
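+ # serialize the wrapped call: acquire the shared queue_lock so only one queued gradio call runs at a time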
+ def f(*args, **kwargs):
+ with queue_lock:
+ res = func(*args, **kwargs)
+ return res
+ return f
+
+
+def wrap_gradio_gpu_call(func, extra_outputs=None):
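+ # GPU-bound variant: optionally registers a "task(...)" job id with the progress queue, runs the call under queue_lock, and records its result or error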
+ name = func.__name__
+ def f(*args, **kwargs):
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
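+ # e.g. the UI may pass args[0] == "task(<id>)"; any other first argument leaves id_task as None and the call is not tracked by the progress module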
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
+ with queue_lock:
+ progress.start_task(id_task)
+ res = [None, '', '', '']
+ try:
+ res = func(*args, **kwargs)
+ progress.record_results(id_task, res)
+ except Exception as e:
+ shared.log.error(f"Exception: {e}")
+ shared.log.error(f"Arguments: args={str(args)[:10240]} kwargs={str(kwargs)[:10240]}")
+ errors.display(e, 'gradio call')
+ res[-1] = f"<div class='error'>{html.escape(str(e))}</div>"
+ finally:
+ progress.finish_task(id_task)
+ return res
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, name=name)
+
+
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False, name=None):
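+ # generic wrapper: tracks the job via shared.state, catches and reports exceptions, optionally profiles the call, and (when add_stats is set) appends elapsed-time/VRAM info to the last output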
+ job_name = name if name is not None else func.__name__
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
+ t = time.perf_counter()
+ shared.mem_mon.reset()
+ shared.state.begin(job_name)
+ try:
+ if shared.cmd_opts.profile:
+ pr = cProfile.Profile()
+ pr.enable()
+ res = func(*args, **kwargs)
+ if res is None:
+ msg = "No result returned from function"
+ shared.log.warning(msg)
+ res = [None, '', '', f"<div class='error'>{html.escape(msg)}</div>"]
+ else:
+ res = list(res)
+ if shared.cmd_opts.profile:
+ errors.profile(pr, 'Wrap')
+ except Exception as e:
+ errors.display(e, 'gradio call')
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+ shared.state.end()
+ if not add_stats:
+ return tuple(res)
+ elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"{elapsed_m}m {elapsed_s:.2f}s" if elapsed_m > 0 else f"{elapsed_s:.2f}s"
+ vram_html = ''
+ if not shared.mem_mon.disabled:
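+ # -(v // -(1024*1024)) is ceiling division: converts the memory monitor's byte counts to whole megabytes, rounding up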
+ vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()}
+ if vram.get('active_peak', 0) > 0:
+ vram_html = " | <p class='vram'>"
+ vram_html += f"GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB"
+ vram_html += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
+ vram_html += "</p>"
+ if isinstance(res, list):
+ res[-1] += f"<div class='performance'><p class='time'>Time: {elapsed_text}</p>{vram_html}</div>"
+ return tuple(res)
+ return f
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 40aeeebce..b86dfd154 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,127 +1,127 @@
-import os
-import argparse
-from modules.paths import data_path
-
-parser = argparse.ArgumentParser(description="SD.Next", conflict_handler='resolve', epilog='For other options see UI Settings page', prog='', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200))
-parser._optionals = parser.add_argument_group('Other options') # pylint: disable=protected-access
-group = parser.add_argument_group('Server options')
-
-# main server args
-group.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")
-group.add_argument("--ui-config", type=str, default=os.environ.get("SD_UICONFIG", os.path.join(data_path, 'ui-config.json')), help="Use specific UI configuration file, default: %(default)s")
-group.add_argument("--medvram", default=os.environ.get("SD_MEDVRAM", False), action='store_true', help="Split model stages and keep only active part in VRAM, default: %(default)s")
-group.add_argument("--lowvram", default=os.environ.get("SD_LOWVRAM", False), action='store_true', help="Split model components and keep only active part in VRAM, default: %(default)s")
-group.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
-group.add_argument('--vae', type=str, default=os.environ.get("SD_VAE", None), help='Path to VAE checkpoint to load immediately, default: %(default)s')
-group.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
-group.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
-group.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help="Allow custom script execution, default: %(default)s")
-group.add_argument("--share", default=os.environ.get("SD_SHARE", False), action='store_true', help="Enable UI accessible through Gradio site, default: %(default)s")
-group.add_argument("--insecure", default=os.environ.get("SD_INSECURE", False), action='store_true', help="Enable extensions tab regardless of other options, default: %(default)s")
-group.add_argument("--use-cpu", nargs='+', default=[], type=str.lower, help="Force use CPU for specified modules, default: %(default)s")
-group.add_argument("--listen", default=os.environ.get("SD_LISTEN", False), action='store_true', help="Launch web server using public IP address, default: %(default)s")
-group.add_argument("--port", type=int, default=os.environ.get("SD_PORT", 7860), help="Launch web server with given server port, default: %(default)s")
-group.add_argument("--freeze", default=os.environ.get("SD_FREEZE", False), action='store_true', help="Disable editing settings")
-group.add_argument("--auth", type=str, default=os.environ.get("SD_AUTH", None), help='Set access authentication like "user:pwd,user:pwd"')
-group.add_argument("--auth-file", type=str, default=os.environ.get("SD_AUTHFILE", None), help='Set access authentication using file, default: %(default)s')
-group.add_argument("--autolaunch", default=os.environ.get("SD_AUTOLAUNCH", False), action='store_true', help="Open the UI URL in the system's default browser upon launch")
-group.add_argument('--docs', default=os.environ.get("SD_DOCS", False), action='store_true', help = "Mount Gradio docs at /docs, default: %(default)s")
-group.add_argument('--api-only', default=os.environ.get("SD_APIONLY", False), action='store_true', help = "Run in API only mode without starting UI")
-group.add_argument("--api-log", default=os.environ.get("SD_APILOG", False), action='store_true', help="Enable logging of all API requests, default: %(default)s")
-group.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default CUDA device to use, default: %(default)s")
-group.add_argument("--cors-origins", type=str, default=os.environ.get("SD_CORSORIGINS", None), help="Allowed CORS origins as comma-separated list, default: %(default)s")
-group.add_argument("--cors-regex", type=str, default=os.environ.get("SD_CORSREGEX", None), help="Allowed CORS origins as regular expression, default: %(default)s")
-group.add_argument("--tls-keyfile", type=str, default=os.environ.get("SD_TLSKEYFILE", None), help="Enable TLS and specify key file, default: %(default)s")
-group.add_argument("--tls-certfile", type=str, default=os.environ.get("SD_TLSCERTFILE", None), help="Enable TLS and specify cert file, default: %(default)s")
-group.add_argument("--tls-selfsign", action="store_true", default=os.environ.get("SD_TLSSELFSIGN", False), help="Enable TLS with self-signed certificates, default: %(default)s")
-group.add_argument("--server-name", type=str, default=os.environ.get("SD_SERVERNAME", None), help="Sets hostname of server, default: %(default)s")
-group.add_argument("--no-hashing", default=os.environ.get("SD_NOHASHING", False), action='store_true', help="Disable hashing of checkpoints, default: %(default)s")
-group.add_argument("--no-metadata", default=os.environ.get("SD_NOMETADATA", False), action='store_true', help="Disable reading of metadata from models, default: %(default)s")
-group.add_argument("--no-download", default=os.environ.get("SD_DOWNLOAD", False), action='store_true', help="Disable download of default model, default: %(default)s")
-group.add_argument("--profile", default=os.environ.get("SD_PROFILE", False), action='store_true', help="Run profiler, default: %(default)s")
-group.add_argument("--disable-queue", default=os.environ.get("SD_DISABLEQUEUE", False), action='store_true', help="Disable queues, default: %(default)s")
-group.add_argument('--debug', default=os.environ.get("SD_DEBUG", False), action='store_true', help = "Run installer with debug logging, default: %(default)s")
-group.add_argument('--use-directml', default=os.environ.get("SD_USEDIRECTML", False), action='store_true', help = "Use DirectML if no compatible GPU is detected, default: %(default)s")
-group.add_argument("--use-openvino", default=os.environ.get("SD_USEOPENVINO", False), action='store_true', help="Use Intel OpenVINO backend, default: %(default)s")
-group.add_argument("--use-ipex", default=os.environ.get("SD_USEIPX", False), action='store_true', help="Force use Intel OneAPI XPU backend, default: %(default)s")
-group.add_argument("--use-cuda", default=os.environ.get("SD_USECUDA", False), action='store_true', help="Force use nVidia CUDA backend, default: %(default)s")
-group.add_argument("--use-rocm", default=os.environ.get("SD_USEROCM", False), action='store_true', help="Force use AMD ROCm backend, default: %(default)s")
-group.add_argument('--subpath', type=str, default=os.environ.get("SD_SUBPATH", None), help='Customize the URL subpath for usage with reverse proxy')
-group.add_argument('--backend', type=str, default=os.environ.get("SD_BACKEND", None), choices=['original', 'diffusers'], required=False, help='force model pipeline type')
-
-
-# removed args are added here as hidden in fixed format for compatibility reasons
-group.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
-group.add_argument("--ui-settings-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'config.json'))
-group.add_argument("--ui-config-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'ui-config.json'))
-group.add_argument("--hide-ui-dir-config", action='store_true', help=argparse.SUPPRESS, default=False)
-group.add_argument("--theme", type=str, help=argparse.SUPPRESS, default=None)
-group.add_argument("--disable-console-progressbars", action='store_true', help=argparse.SUPPRESS, default=True)
-group.add_argument("--disable-safe-unpickle", action='store_true', help=argparse.SUPPRESS, default=True)
-group.add_argument("--lowram", action='store_true', help=argparse.SUPPRESS)
-group.add_argument("--disable-extension-access", default=False, action='store_true', help=argparse.SUPPRESS)
-group.add_argument("--api", help=argparse.SUPPRESS, default=True)
-group.add_argument("--api-auth", type=str, help=argparse.SUPPRESS, default=None)
-
-
-def compatibility_args(opts, args):
- # removed args that have been moved to opts are added here as hidden with default values as defined in opts
- group.add_argument("--ckpt-dir", type=str, help=argparse.SUPPRESS, default=opts.ckpt_dir)
- group.add_argument("--vae-dir", type=str, help=argparse.SUPPRESS, default=opts.vae_dir)
- group.add_argument("--embeddings-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_dir)
- group.add_argument("--embeddings-templates-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_templates_dir)
- group.add_argument("--hypernetwork-dir", type=str, help=argparse.SUPPRESS, default=opts.hypernetwork_dir)
- group.add_argument("--codeformer-models-path", type=str, help=argparse.SUPPRESS, default=opts.codeformer_models_path)
- group.add_argument("--gfpgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.gfpgan_models_path)
- group.add_argument("--esrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.esrgan_models_path)
- group.add_argument("--bsrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.bsrgan_models_path)
- group.add_argument("--realesrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.realesrgan_models_path)
- group.add_argument("--scunet-models-path", help=argparse.SUPPRESS, default=opts.scunet_models_path)
- group.add_argument("--swinir-models-path", help=argparse.SUPPRESS, default=opts.swinir_models_path)
- group.add_argument("--ldsr-models-path", help=argparse.SUPPRESS, default=opts.ldsr_models_path)
- group.add_argument("--clip-models-path", type=str, help=argparse.SUPPRESS, default=opts.clip_models_path)
- group.add_argument("--opt-channelslast", help=argparse.SUPPRESS, action='store_true', default=opts.opt_channelslast)
- group.add_argument("--xformers", default=(opts.cross_attention_optimization == "xFormers"), action='store_true', help=argparse.SUPPRESS)
- group.add_argument("--disable-nan-check", help=argparse.SUPPRESS, action='store_true', default=opts.disable_nan_check)
- group.add_argument("--rollback-vae", help=argparse.SUPPRESS, default=opts.rollback_vae)
- group.add_argument("--no-half", help=argparse.SUPPRESS, action='store_true', default=opts.no_half)
- group.add_argument("--no-half-vae", help=argparse.SUPPRESS, action='store_true', default=opts.no_half_vae)
- group.add_argument("--precision", help=argparse.SUPPRESS, default=opts.precision)
- group.add_argument("--sub-quad-q-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_q_chunk_size)
- group.add_argument("--sub-quad-kv-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_kv_chunk_size)
- group.add_argument("--sub-quad-chunk-threshold", help=argparse.SUPPRESS, default=opts.sub_quad_chunk_threshold)
- group.add_argument("--lora-dir", help=argparse.SUPPRESS, default=opts.lora_dir)
- group.add_argument("--lyco-dir", help=argparse.SUPPRESS, default=opts.lyco_dir)
- group.add_argument("--embeddings-dir", help=argparse.SUPPRESS, default=opts.embeddings_dir)
- group.add_argument("--hypernetwork-dir", help=argparse.SUPPRESS, default=opts.hypernetwork_dir)
- group.add_argument("--lyco-patch-lora", help=argparse.SUPPRESS, action='store_true', default=False)
- group.add_argument("--lyco-debug", help=argparse.SUPPRESS, action='store_true', default=False)
- group.add_argument("--enable-console-prompts", help=argparse.SUPPRESS, action='store_true', default=False)
- group.add_argument("--safe", help=argparse.SUPPRESS, action='store_true', default=False)
- group.add_argument("--use-xformers", help=argparse.SUPPRESS, action='store_true', default=False)
-
- # removed opts are added here with fixed values for compatibility reasons
- opts.use_old_emphasis_implementation = False
- opts.use_old_karras_scheduler_sigmas = False
- opts.no_dpmpp_sde_batch_determinism = False
- opts.lora_apply_to_outputs = False
- opts.do_not_show_images = False
- opts.add_model_hash_to_info = True
- opts.add_model_name_to_info = True
- opts.js_modal_lightbox = True
- opts.js_modal_lightbox_initially_zoomed = True
- opts.show_progress_in_title = False
- opts.sd_vae_as_default = True
- opts.enable_emphasis = True
- opts.enable_batch_seeds = True
- # opts.multiple_tqdm = False
- opts.print_hypernet_extra = False
- opts.dimensions_and_batch_together = True
- opts.enable_pnginfo = True
- opts.data['clip_skip'] = 1
-
- opts.onchange("lora_dir", lambda: setattr(args, "lora_dir", opts.lora_dir))
- opts.onchange("lyco_dir", lambda: setattr(args, "lyco_dir", opts.lyco_dir))
-
- args = parser.parse_args()
- return args
+import os
+import argparse
+from modules.paths import data_path
+
+parser = argparse.ArgumentParser(description="SD.Next", conflict_handler='resolve', epilog='For other options see UI Settings page', prog='', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200))
+parser._optionals = parser.add_argument_group('Other options') # pylint: disable=protected-access
+group = parser.add_argument_group('Server options')
+
+# main server args
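+# most options below read their default from a matching SD_* environment variable; an explicit command-line flag overrides it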
+group.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")
+group.add_argument("--ui-config", type=str, default=os.environ.get("SD_UICONFIG", os.path.join(data_path, 'ui-config.json')), help="Use specific UI configuration file, default: %(default)s")
+group.add_argument("--medvram", default=os.environ.get("SD_MEDVRAM", False), action='store_true', help="Split model stages and keep only active part in VRAM, default: %(default)s")
+group.add_argument("--lowvram", default=os.environ.get("SD_LOWVRAM", False), action='store_true', help="Split model components and keep only active part in VRAM, default: %(default)s")
+group.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
+group.add_argument('--vae', type=str, default=os.environ.get("SD_VAE", None), help='Path to VAE checkpoint to load immediately, default: %(default)s')
+group.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
+group.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
+group.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help="Allow custom script execution, default: %(default)s")
+group.add_argument("--share", default=os.environ.get("SD_SHARE", False), action='store_true', help="Enable UI accessible through Gradio site, default: %(default)s")
+group.add_argument("--insecure", default=os.environ.get("SD_INSECURE", False), action='store_true', help="Enable extensions tab regardless of other options, default: %(default)s")
+group.add_argument("--use-cpu", nargs='+', default=[], type=str.lower, help="Force use CPU for specified modules, default: %(default)s")
+group.add_argument("--listen", default=os.environ.get("SD_LISTEN", False), action='store_true', help="Launch web server using public IP address, default: %(default)s")
+group.add_argument("--port", type=int, default=os.environ.get("SD_PORT", 7860), help="Launch web server with given server port, default: %(default)s")
+group.add_argument("--freeze", default=os.environ.get("SD_FREEZE", False), action='store_true', help="Disable editing settings")
+group.add_argument("--auth", type=str, default=os.environ.get("SD_AUTH", None), help='Set access authentication like "user:pwd,user:pwd"')
+group.add_argument("--auth-file", type=str, default=os.environ.get("SD_AUTHFILE", None), help='Set access authentication using file, default: %(default)s')
+group.add_argument("--autolaunch", default=os.environ.get("SD_AUTOLAUNCH", False), action='store_true', help="Open the UI URL in the system's default browser upon launch")
+group.add_argument('--docs', default=os.environ.get("SD_DOCS", False), action='store_true', help = "Mount Gradio docs at /docs, default: %(default)s")
+group.add_argument('--api-only', default=os.environ.get("SD_APIONLY", False), action='store_true', help = "Run in API only mode without starting UI")
+group.add_argument("--api-log", default=os.environ.get("SD_APILOG", False), action='store_true', help="Enable logging of all API requests, default: %(default)s")
+group.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default CUDA device to use, default: %(default)s")
+group.add_argument("--cors-origins", type=str, default=os.environ.get("SD_CORSORIGINS", None), help="Allowed CORS origins as comma-separated list, default: %(default)s")
+group.add_argument("--cors-regex", type=str, default=os.environ.get("SD_CORSREGEX", None), help="Allowed CORS origins as regular expression, default: %(default)s")
+group.add_argument("--tls-keyfile", type=str, default=os.environ.get("SD_TLSKEYFILE", None), help="Enable TLS and specify key file, default: %(default)s")
+group.add_argument("--tls-certfile", type=str, default=os.environ.get("SD_TLSCERTFILE", None), help="Enable TLS and specify cert file, default: %(default)s")
+group.add_argument("--tls-selfsign", action="store_true", default=os.environ.get("SD_TLSSELFSIGN", False), help="Enable TLS with self-signed certificates, default: %(default)s")
+group.add_argument("--server-name", type=str, default=os.environ.get("SD_SERVERNAME", None), help="Sets hostname of server, default: %(default)s")
+group.add_argument("--no-hashing", default=os.environ.get("SD_NOHASHING", False), action='store_true', help="Disable hashing of checkpoints, default: %(default)s")
+group.add_argument("--no-metadata", default=os.environ.get("SD_NOMETADATA", False), action='store_true', help="Disable reading of metadata from models, default: %(default)s")
+group.add_argument("--no-download", default=os.environ.get("SD_DOWNLOAD", False), action='store_true', help="Disable download of default model, default: %(default)s")
+group.add_argument("--profile", default=os.environ.get("SD_PROFILE", False), action='store_true', help="Run profiler, default: %(default)s")
+group.add_argument("--disable-queue", default=os.environ.get("SD_DISABLEQUEUE", False), action='store_true', help="Disable queues, default: %(default)s")
+group.add_argument('--debug', default=os.environ.get("SD_DEBUG", False), action='store_true', help = "Run installer with debug logging, default: %(default)s")
+group.add_argument('--use-directml', default=os.environ.get("SD_USEDIRECTML", False), action='store_true', help = "Use DirectML if no compatible GPU is detected, default: %(default)s")
+group.add_argument("--use-openvino", default=os.environ.get("SD_USEOPENVINO", False), action='store_true', help="Use Intel OpenVINO backend, default: %(default)s")
+group.add_argument("--use-ipex", default=os.environ.get("SD_USEIPX", False), action='store_true', help="Force use Intel OneAPI XPU backend, default: %(default)s")
+group.add_argument("--use-cuda", default=os.environ.get("SD_USECUDA", False), action='store_true', help="Force use nVidia CUDA backend, default: %(default)s")
+group.add_argument("--use-rocm", default=os.environ.get("SD_USEROCM", False), action='store_true', help="Force use AMD ROCm backend, default: %(default)s")
+group.add_argument('--subpath', type=str, default=os.environ.get("SD_SUBPATH", None), help='Customize the URL subpath for usage with reverse proxy')
+group.add_argument('--backend', type=str, default=os.environ.get("SD_BACKEND", None), choices=['original', 'diffusers'], required=False, help='force model pipeline type')
+
+
+# removed args are added here as hidden in fixed format for compatibility reasons
+group.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
+group.add_argument("--ui-settings-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'config.json'))
+group.add_argument("--ui-config-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'ui-config.json'))
+group.add_argument("--hide-ui-dir-config", action='store_true', help=argparse.SUPPRESS, default=False)
+group.add_argument("--theme", type=str, help=argparse.SUPPRESS, default=None)
+group.add_argument("--disable-console-progressbars", action='store_true', help=argparse.SUPPRESS, default=True)
+group.add_argument("--disable-safe-unpickle", action='store_true', help=argparse.SUPPRESS, default=True)
+group.add_argument("--lowram", action='store_true', help=argparse.SUPPRESS)
+group.add_argument("--disable-extension-access", default=False, action='store_true', help=argparse.SUPPRESS)
+group.add_argument("--api", help=argparse.SUPPRESS, default=True)
+group.add_argument("--api-auth", type=str, help=argparse.SUPPRESS, default=None)
+
+
+def compatibility_args(opts, args):
+ # removed args that have been moved to opts are added here as hidden with default values as defined in opts
+ group.add_argument("--ckpt-dir", type=str, help=argparse.SUPPRESS, default=opts.ckpt_dir)
+ group.add_argument("--vae-dir", type=str, help=argparse.SUPPRESS, default=opts.vae_dir)
+ group.add_argument("--embeddings-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_dir)
+ group.add_argument("--embeddings-templates-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_templates_dir)
+ group.add_argument("--hypernetwork-dir", type=str, help=argparse.SUPPRESS, default=opts.hypernetwork_dir)
+ group.add_argument("--codeformer-models-path", type=str, help=argparse.SUPPRESS, default=opts.codeformer_models_path)
+ group.add_argument("--gfpgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.gfpgan_models_path)
+ group.add_argument("--esrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.esrgan_models_path)
+ group.add_argument("--bsrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.bsrgan_models_path)
+ group.add_argument("--realesrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.realesrgan_models_path)
+ group.add_argument("--scunet-models-path", help=argparse.SUPPRESS, default=opts.scunet_models_path)
+ group.add_argument("--swinir-models-path", help=argparse.SUPPRESS, default=opts.swinir_models_path)
+ group.add_argument("--ldsr-models-path", help=argparse.SUPPRESS, default=opts.ldsr_models_path)
+ group.add_argument("--clip-models-path", type=str, help=argparse.SUPPRESS, default=opts.clip_models_path)
+ group.add_argument("--opt-channelslast", help=argparse.SUPPRESS, action='store_true', default=opts.opt_channelslast)
+ group.add_argument("--xformers", default=(opts.cross_attention_optimization == "xFormers"), action='store_true', help=argparse.SUPPRESS)
+ group.add_argument("--disable-nan-check", help=argparse.SUPPRESS, action='store_true', default=opts.disable_nan_check)
+ group.add_argument("--rollback-vae", help=argparse.SUPPRESS, default=opts.rollback_vae)
+ group.add_argument("--no-half", help=argparse.SUPPRESS, action='store_true', default=opts.no_half)
+ group.add_argument("--no-half-vae", help=argparse.SUPPRESS, action='store_true', default=opts.no_half_vae)
+ group.add_argument("--precision", help=argparse.SUPPRESS, default=opts.precision)
+ group.add_argument("--sub-quad-q-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_q_chunk_size)
+ group.add_argument("--sub-quad-kv-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_kv_chunk_size)
+ group.add_argument("--sub-quad-chunk-threshold", help=argparse.SUPPRESS, default=opts.sub_quad_chunk_threshold)
+ group.add_argument("--lora-dir", help=argparse.SUPPRESS, default=opts.lora_dir)
+ group.add_argument("--lyco-dir", help=argparse.SUPPRESS, default=opts.lyco_dir)
+ group.add_argument("--embeddings-dir", help=argparse.SUPPRESS, default=opts.embeddings_dir)
+ group.add_argument("--hypernetwork-dir", help=argparse.SUPPRESS, default=opts.hypernetwork_dir)
+ group.add_argument("--lyco-patch-lora", help=argparse.SUPPRESS, action='store_true', default=False)
+ group.add_argument("--lyco-debug", help=argparse.SUPPRESS, action='store_true', default=False)
+ group.add_argument("--enable-console-prompts", help=argparse.SUPPRESS, action='store_true', default=False)
+ group.add_argument("--safe", help=argparse.SUPPRESS, action='store_true', default=False)
+ group.add_argument("--use-xformers", help=argparse.SUPPRESS, action='store_true', default=False)
+
+ # removed opts are added here with fixed values for compatibility reasons
+ opts.use_old_emphasis_implementation = False
+ opts.use_old_karras_scheduler_sigmas = False
+ opts.no_dpmpp_sde_batch_determinism = False
+ opts.lora_apply_to_outputs = False
+ opts.do_not_show_images = False
+ opts.add_model_hash_to_info = True
+ opts.add_model_name_to_info = True
+ opts.js_modal_lightbox = True
+ opts.js_modal_lightbox_initially_zoomed = True
+ opts.show_progress_in_title = False
+ opts.sd_vae_as_default = True
+ opts.enable_emphasis = True
+ opts.enable_batch_seeds = True
+ # opts.multiple_tqdm = False
+ opts.print_hypernet_extra = False
+ opts.dimensions_and_batch_together = True
+ opts.enable_pnginfo = True
+ opts.data['clip_skip'] = 1
+
+ opts.onchange("lora_dir", lambda: setattr(args, "lora_dir", opts.lora_dir))
+ opts.onchange("lyco_dir", lambda: setattr(args, "lyco_dir", opts.lyco_dir))
+
+ args = parser.parse_args()
+ return args
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
index edeb81866..2963385c3 100644
--- a/modules/deepbooru_model.py
+++ b/modules/deepbooru_model.py
@@ -1,674 +1,674 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from modules import devices
-
-# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
-
-
-class DeepDanbooruModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.tags = []
- self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
- self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
- self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
- self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
- self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
- self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
- self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
- self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
- self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
- self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
- self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
- self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
- self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
- self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
- self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
- self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
- self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
- self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
- self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
- self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
- self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
- self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
- self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
- self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
- self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
- self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
- self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
- self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
-
- def forward(self, *inputs):
- t_358, = inputs
- t_359 = t_358.permute(*[0, 3, 1, 2])
- t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
- t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
- t_361 = F.relu(t_360)
- t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
- t_362 = self.n_MaxPool_0(t_361)
- t_363 = self.n_Conv_1(t_362)
- t_364 = self.n_Conv_2(t_362)
- t_365 = F.relu(t_364)
- t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
- t_366 = self.n_Conv_3(t_365_padded)
- t_367 = F.relu(t_366)
- t_368 = self.n_Conv_4(t_367)
- t_369 = torch.add(t_368, t_363)
- t_370 = F.relu(t_369)
- t_371 = self.n_Conv_5(t_370)
- t_372 = F.relu(t_371)
- t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
- t_373 = self.n_Conv_6(t_372_padded)
- t_374 = F.relu(t_373)
- t_375 = self.n_Conv_7(t_374)
- t_376 = torch.add(t_375, t_370)
- t_377 = F.relu(t_376)
- t_378 = self.n_Conv_8(t_377)
- t_379 = F.relu(t_378)
- t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
- t_380 = self.n_Conv_9(t_379_padded)
- t_381 = F.relu(t_380)
- t_382 = self.n_Conv_10(t_381)
- t_383 = torch.add(t_382, t_377)
- t_384 = F.relu(t_383)
- t_385 = self.n_Conv_11(t_384)
- t_386 = self.n_Conv_12(t_384)
- t_387 = F.relu(t_386)
- t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
- t_388 = self.n_Conv_13(t_387_padded)
- t_389 = F.relu(t_388)
- t_390 = self.n_Conv_14(t_389)
- t_391 = torch.add(t_390, t_385)
- t_392 = F.relu(t_391)
- t_393 = self.n_Conv_15(t_392)
- t_394 = F.relu(t_393)
- t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
- t_395 = self.n_Conv_16(t_394_padded)
- t_396 = F.relu(t_395)
- t_397 = self.n_Conv_17(t_396)
- t_398 = torch.add(t_397, t_392)
- t_399 = F.relu(t_398)
- t_400 = self.n_Conv_18(t_399)
- t_401 = F.relu(t_400)
- t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
- t_402 = self.n_Conv_19(t_401_padded)
- t_403 = F.relu(t_402)
- t_404 = self.n_Conv_20(t_403)
- t_405 = torch.add(t_404, t_399)
- t_406 = F.relu(t_405)
- t_407 = self.n_Conv_21(t_406)
- t_408 = F.relu(t_407)
- t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
- t_409 = self.n_Conv_22(t_408_padded)
- t_410 = F.relu(t_409)
- t_411 = self.n_Conv_23(t_410)
- t_412 = torch.add(t_411, t_406)
- t_413 = F.relu(t_412)
- t_414 = self.n_Conv_24(t_413)
- t_415 = F.relu(t_414)
- t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
- t_416 = self.n_Conv_25(t_415_padded)
- t_417 = F.relu(t_416)
- t_418 = self.n_Conv_26(t_417)
- t_419 = torch.add(t_418, t_413)
- t_420 = F.relu(t_419)
- t_421 = self.n_Conv_27(t_420)
- t_422 = F.relu(t_421)
- t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
- t_423 = self.n_Conv_28(t_422_padded)
- t_424 = F.relu(t_423)
- t_425 = self.n_Conv_29(t_424)
- t_426 = torch.add(t_425, t_420)
- t_427 = F.relu(t_426)
- t_428 = self.n_Conv_30(t_427)
- t_429 = F.relu(t_428)
- t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
- t_430 = self.n_Conv_31(t_429_padded)
- t_431 = F.relu(t_430)
- t_432 = self.n_Conv_32(t_431)
- t_433 = torch.add(t_432, t_427)
- t_434 = F.relu(t_433)
- t_435 = self.n_Conv_33(t_434)
- t_436 = F.relu(t_435)
- t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
- t_437 = self.n_Conv_34(t_436_padded)
- t_438 = F.relu(t_437)
- t_439 = self.n_Conv_35(t_438)
- t_440 = torch.add(t_439, t_434)
- t_441 = F.relu(t_440)
- t_442 = self.n_Conv_36(t_441)
- t_443 = self.n_Conv_37(t_441)
- t_444 = F.relu(t_443)
- t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
- t_445 = self.n_Conv_38(t_444_padded)
- t_446 = F.relu(t_445)
- t_447 = self.n_Conv_39(t_446)
- t_448 = torch.add(t_447, t_442)
- t_449 = F.relu(t_448)
- t_450 = self.n_Conv_40(t_449)
- t_451 = F.relu(t_450)
- t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
- t_452 = self.n_Conv_41(t_451_padded)
- t_453 = F.relu(t_452)
- t_454 = self.n_Conv_42(t_453)
- t_455 = torch.add(t_454, t_449)
- t_456 = F.relu(t_455)
- t_457 = self.n_Conv_43(t_456)
- t_458 = F.relu(t_457)
- t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
- t_459 = self.n_Conv_44(t_458_padded)
- t_460 = F.relu(t_459)
- t_461 = self.n_Conv_45(t_460)
- t_462 = torch.add(t_461, t_456)
- t_463 = F.relu(t_462)
- t_464 = self.n_Conv_46(t_463)
- t_465 = F.relu(t_464)
- t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
- t_466 = self.n_Conv_47(t_465_padded)
- t_467 = F.relu(t_466)
- t_468 = self.n_Conv_48(t_467)
- t_469 = torch.add(t_468, t_463)
- t_470 = F.relu(t_469)
- t_471 = self.n_Conv_49(t_470)
- t_472 = F.relu(t_471)
- t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
- t_473 = self.n_Conv_50(t_472_padded)
- t_474 = F.relu(t_473)
- t_475 = self.n_Conv_51(t_474)
- t_476 = torch.add(t_475, t_470)
- t_477 = F.relu(t_476)
- t_478 = self.n_Conv_52(t_477)
- t_479 = F.relu(t_478)
- t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
- t_480 = self.n_Conv_53(t_479_padded)
- t_481 = F.relu(t_480)
- t_482 = self.n_Conv_54(t_481)
- t_483 = torch.add(t_482, t_477)
- t_484 = F.relu(t_483)
- t_485 = self.n_Conv_55(t_484)
- t_486 = F.relu(t_485)
- t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
- t_487 = self.n_Conv_56(t_486_padded)
- t_488 = F.relu(t_487)
- t_489 = self.n_Conv_57(t_488)
- t_490 = torch.add(t_489, t_484)
- t_491 = F.relu(t_490)
- t_492 = self.n_Conv_58(t_491)
- t_493 = F.relu(t_492)
- t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
- t_494 = self.n_Conv_59(t_493_padded)
- t_495 = F.relu(t_494)
- t_496 = self.n_Conv_60(t_495)
- t_497 = torch.add(t_496, t_491)
- t_498 = F.relu(t_497)
- t_499 = self.n_Conv_61(t_498)
- t_500 = F.relu(t_499)
- t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
- t_501 = self.n_Conv_62(t_500_padded)
- t_502 = F.relu(t_501)
- t_503 = self.n_Conv_63(t_502)
- t_504 = torch.add(t_503, t_498)
- t_505 = F.relu(t_504)
- t_506 = self.n_Conv_64(t_505)
- t_507 = F.relu(t_506)
- t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
- t_508 = self.n_Conv_65(t_507_padded)
- t_509 = F.relu(t_508)
- t_510 = self.n_Conv_66(t_509)
- t_511 = torch.add(t_510, t_505)
- t_512 = F.relu(t_511)
- t_513 = self.n_Conv_67(t_512)
- t_514 = F.relu(t_513)
- t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
- t_515 = self.n_Conv_68(t_514_padded)
- t_516 = F.relu(t_515)
- t_517 = self.n_Conv_69(t_516)
- t_518 = torch.add(t_517, t_512)
- t_519 = F.relu(t_518)
- t_520 = self.n_Conv_70(t_519)
- t_521 = F.relu(t_520)
- t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
- t_522 = self.n_Conv_71(t_521_padded)
- t_523 = F.relu(t_522)
- t_524 = self.n_Conv_72(t_523)
- t_525 = torch.add(t_524, t_519)
- t_526 = F.relu(t_525)
- t_527 = self.n_Conv_73(t_526)
- t_528 = F.relu(t_527)
- t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
- t_529 = self.n_Conv_74(t_528_padded)
- t_530 = F.relu(t_529)
- t_531 = self.n_Conv_75(t_530)
- t_532 = torch.add(t_531, t_526)
- t_533 = F.relu(t_532)
- t_534 = self.n_Conv_76(t_533)
- t_535 = F.relu(t_534)
- t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
- t_536 = self.n_Conv_77(t_535_padded)
- t_537 = F.relu(t_536)
- t_538 = self.n_Conv_78(t_537)
- t_539 = torch.add(t_538, t_533)
- t_540 = F.relu(t_539)
- t_541 = self.n_Conv_79(t_540)
- t_542 = F.relu(t_541)
- t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
- t_543 = self.n_Conv_80(t_542_padded)
- t_544 = F.relu(t_543)
- t_545 = self.n_Conv_81(t_544)
- t_546 = torch.add(t_545, t_540)
- t_547 = F.relu(t_546)
- t_548 = self.n_Conv_82(t_547)
- t_549 = F.relu(t_548)
- t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
- t_550 = self.n_Conv_83(t_549_padded)
- t_551 = F.relu(t_550)
- t_552 = self.n_Conv_84(t_551)
- t_553 = torch.add(t_552, t_547)
- t_554 = F.relu(t_553)
- t_555 = self.n_Conv_85(t_554)
- t_556 = F.relu(t_555)
- t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
- t_557 = self.n_Conv_86(t_556_padded)
- t_558 = F.relu(t_557)
- t_559 = self.n_Conv_87(t_558)
- t_560 = torch.add(t_559, t_554)
- t_561 = F.relu(t_560)
- t_562 = self.n_Conv_88(t_561)
- t_563 = F.relu(t_562)
- t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
- t_564 = self.n_Conv_89(t_563_padded)
- t_565 = F.relu(t_564)
- t_566 = self.n_Conv_90(t_565)
- t_567 = torch.add(t_566, t_561)
- t_568 = F.relu(t_567)
- t_569 = self.n_Conv_91(t_568)
- t_570 = F.relu(t_569)
- t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
- t_571 = self.n_Conv_92(t_570_padded)
- t_572 = F.relu(t_571)
- t_573 = self.n_Conv_93(t_572)
- t_574 = torch.add(t_573, t_568)
- t_575 = F.relu(t_574)
- t_576 = self.n_Conv_94(t_575)
- t_577 = F.relu(t_576)
- t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
- t_578 = self.n_Conv_95(t_577_padded)
- t_579 = F.relu(t_578)
- t_580 = self.n_Conv_96(t_579)
- t_581 = torch.add(t_580, t_575)
- t_582 = F.relu(t_581)
- t_583 = self.n_Conv_97(t_582)
- t_584 = F.relu(t_583)
- t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
- t_585 = self.n_Conv_98(t_584_padded)
- t_586 = F.relu(t_585)
- t_587 = self.n_Conv_99(t_586)
- t_588 = self.n_Conv_100(t_582)
- t_589 = torch.add(t_587, t_588)
- t_590 = F.relu(t_589)
- t_591 = self.n_Conv_101(t_590)
- t_592 = F.relu(t_591)
- t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
- t_593 = self.n_Conv_102(t_592_padded)
- t_594 = F.relu(t_593)
- t_595 = self.n_Conv_103(t_594)
- t_596 = torch.add(t_595, t_590)
- t_597 = F.relu(t_596)
- t_598 = self.n_Conv_104(t_597)
- t_599 = F.relu(t_598)
- t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
- t_600 = self.n_Conv_105(t_599_padded)
- t_601 = F.relu(t_600)
- t_602 = self.n_Conv_106(t_601)
- t_603 = torch.add(t_602, t_597)
- t_604 = F.relu(t_603)
- t_605 = self.n_Conv_107(t_604)
- t_606 = F.relu(t_605)
- t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
- t_607 = self.n_Conv_108(t_606_padded)
- t_608 = F.relu(t_607)
- t_609 = self.n_Conv_109(t_608)
- t_610 = torch.add(t_609, t_604)
- t_611 = F.relu(t_610)
- t_612 = self.n_Conv_110(t_611)
- t_613 = F.relu(t_612)
- t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
- t_614 = self.n_Conv_111(t_613_padded)
- t_615 = F.relu(t_614)
- t_616 = self.n_Conv_112(t_615)
- t_617 = torch.add(t_616, t_611)
- t_618 = F.relu(t_617)
- t_619 = self.n_Conv_113(t_618)
- t_620 = F.relu(t_619)
- t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
- t_621 = self.n_Conv_114(t_620_padded)
- t_622 = F.relu(t_621)
- t_623 = self.n_Conv_115(t_622)
- t_624 = torch.add(t_623, t_618)
- t_625 = F.relu(t_624)
- t_626 = self.n_Conv_116(t_625)
- t_627 = F.relu(t_626)
- t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
- t_628 = self.n_Conv_117(t_627_padded)
- t_629 = F.relu(t_628)
- t_630 = self.n_Conv_118(t_629)
- t_631 = torch.add(t_630, t_625)
- t_632 = F.relu(t_631)
- t_633 = self.n_Conv_119(t_632)
- t_634 = F.relu(t_633)
- t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
- t_635 = self.n_Conv_120(t_634_padded)
- t_636 = F.relu(t_635)
- t_637 = self.n_Conv_121(t_636)
- t_638 = torch.add(t_637, t_632)
- t_639 = F.relu(t_638)
- t_640 = self.n_Conv_122(t_639)
- t_641 = F.relu(t_640)
- t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
- t_642 = self.n_Conv_123(t_641_padded)
- t_643 = F.relu(t_642)
- t_644 = self.n_Conv_124(t_643)
- t_645 = torch.add(t_644, t_639)
- t_646 = F.relu(t_645)
- t_647 = self.n_Conv_125(t_646)
- t_648 = F.relu(t_647)
- t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
- t_649 = self.n_Conv_126(t_648_padded)
- t_650 = F.relu(t_649)
- t_651 = self.n_Conv_127(t_650)
- t_652 = torch.add(t_651, t_646)
- t_653 = F.relu(t_652)
- t_654 = self.n_Conv_128(t_653)
- t_655 = F.relu(t_654)
- t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
- t_656 = self.n_Conv_129(t_655_padded)
- t_657 = F.relu(t_656)
- t_658 = self.n_Conv_130(t_657)
- t_659 = torch.add(t_658, t_653)
- t_660 = F.relu(t_659)
- t_661 = self.n_Conv_131(t_660)
- t_662 = F.relu(t_661)
- t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
- t_663 = self.n_Conv_132(t_662_padded)
- t_664 = F.relu(t_663)
- t_665 = self.n_Conv_133(t_664)
- t_666 = torch.add(t_665, t_660)
- t_667 = F.relu(t_666)
- t_668 = self.n_Conv_134(t_667)
- t_669 = F.relu(t_668)
- t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
- t_670 = self.n_Conv_135(t_669_padded)
- t_671 = F.relu(t_670)
- t_672 = self.n_Conv_136(t_671)
- t_673 = torch.add(t_672, t_667)
- t_674 = F.relu(t_673)
- t_675 = self.n_Conv_137(t_674)
- t_676 = F.relu(t_675)
- t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
- t_677 = self.n_Conv_138(t_676_padded)
- t_678 = F.relu(t_677)
- t_679 = self.n_Conv_139(t_678)
- t_680 = torch.add(t_679, t_674)
- t_681 = F.relu(t_680)
- t_682 = self.n_Conv_140(t_681)
- t_683 = F.relu(t_682)
- t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
- t_684 = self.n_Conv_141(t_683_padded)
- t_685 = F.relu(t_684)
- t_686 = self.n_Conv_142(t_685)
- t_687 = torch.add(t_686, t_681)
- t_688 = F.relu(t_687)
- t_689 = self.n_Conv_143(t_688)
- t_690 = F.relu(t_689)
- t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
- t_691 = self.n_Conv_144(t_690_padded)
- t_692 = F.relu(t_691)
- t_693 = self.n_Conv_145(t_692)
- t_694 = torch.add(t_693, t_688)
- t_695 = F.relu(t_694)
- t_696 = self.n_Conv_146(t_695)
- t_697 = F.relu(t_696)
- t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
- t_698 = self.n_Conv_147(t_697_padded)
- t_699 = F.relu(t_698)
- t_700 = self.n_Conv_148(t_699)
- t_701 = torch.add(t_700, t_695)
- t_702 = F.relu(t_701)
- t_703 = self.n_Conv_149(t_702)
- t_704 = F.relu(t_703)
- t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
- t_705 = self.n_Conv_150(t_704_padded)
- t_706 = F.relu(t_705)
- t_707 = self.n_Conv_151(t_706)
- t_708 = torch.add(t_707, t_702)
- t_709 = F.relu(t_708)
- t_710 = self.n_Conv_152(t_709)
- t_711 = F.relu(t_710)
- t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
- t_712 = self.n_Conv_153(t_711_padded)
- t_713 = F.relu(t_712)
- t_714 = self.n_Conv_154(t_713)
- t_715 = torch.add(t_714, t_709)
- t_716 = F.relu(t_715)
- t_717 = self.n_Conv_155(t_716)
- t_718 = F.relu(t_717)
- t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
- t_719 = self.n_Conv_156(t_718_padded)
- t_720 = F.relu(t_719)
- t_721 = self.n_Conv_157(t_720)
- t_722 = torch.add(t_721, t_716)
- t_723 = F.relu(t_722)
- t_724 = self.n_Conv_158(t_723)
- t_725 = self.n_Conv_159(t_723)
- t_726 = F.relu(t_725)
- t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
- t_727 = self.n_Conv_160(t_726_padded)
- t_728 = F.relu(t_727)
- t_729 = self.n_Conv_161(t_728)
- t_730 = torch.add(t_729, t_724)
- t_731 = F.relu(t_730)
- t_732 = self.n_Conv_162(t_731)
- t_733 = F.relu(t_732)
- t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
- t_734 = self.n_Conv_163(t_733_padded)
- t_735 = F.relu(t_734)
- t_736 = self.n_Conv_164(t_735)
- t_737 = torch.add(t_736, t_731)
- t_738 = F.relu(t_737)
- t_739 = self.n_Conv_165(t_738)
- t_740 = F.relu(t_739)
- t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
- t_741 = self.n_Conv_166(t_740_padded)
- t_742 = F.relu(t_741)
- t_743 = self.n_Conv_167(t_742)
- t_744 = torch.add(t_743, t_738)
- t_745 = F.relu(t_744)
- t_746 = self.n_Conv_168(t_745)
- t_747 = self.n_Conv_169(t_745)
- t_748 = F.relu(t_747)
- t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
- t_749 = self.n_Conv_170(t_748_padded)
- t_750 = F.relu(t_749)
- t_751 = self.n_Conv_171(t_750)
- t_752 = torch.add(t_751, t_746)
- t_753 = F.relu(t_752)
- t_754 = self.n_Conv_172(t_753)
- t_755 = F.relu(t_754)
- t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
- t_756 = self.n_Conv_173(t_755_padded)
- t_757 = F.relu(t_756)
- t_758 = self.n_Conv_174(t_757)
- t_759 = torch.add(t_758, t_753)
- t_760 = F.relu(t_759)
- t_761 = self.n_Conv_175(t_760)
- t_762 = F.relu(t_761)
- t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
- t_763 = self.n_Conv_176(t_762_padded)
- t_764 = F.relu(t_763)
- t_765 = self.n_Conv_177(t_764)
- t_766 = torch.add(t_765, t_760)
- t_767 = F.relu(t_766)
- t_768 = self.n_Conv_178(t_767)
- t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
- t_770 = torch.squeeze(t_769, 3)
- t_770 = torch.squeeze(t_770, 2)
- t_771 = torch.sigmoid(t_770)
- return t_771
-
- def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument
- self.tags = state_dict.get('tags', [])
- super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules import devices
+
+# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
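+# Layer and tensor names below (n_Conv_*, t_*) appear to follow the node numbering of the
+# converted graph, so the file reads like a flattened export rather than hand-written PyTorch:
+# a ResNet-style stack of 1x1/3x3/1x1 bottleneck blocks with residual adds, finished by
+# global average pooling and a per-tag sigmoid.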
+
+
+class DeepDanbooruModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tags = []
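+ # stem: 7x7 stride-2 convolution on the 3-channel input, followed by a 3x3 stride-2 max-pool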
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
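+ # bottleneck residual blocks (1x1 reduce -> 3x3 -> 1x1 expand); the skip additions are applied in forward()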
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
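+ # 1024-channel stage: n_Conv_36 is the stride-2 projection shortcut, n_Conv_37..39 form the first stride-2 bottleneck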
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
+
+ def forward(self, *inputs):
+ t_358, = inputs
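+ # channels-last input (NHWC) -> NCHW; the asymmetric pads below likely mirror TensorFlow-style 'SAME' padding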
+ t_359 = t_358.permute(*[0, 3, 1, 2])
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
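+ # upcast the padded input to the conv's parameter dtype when the active backend requires it (devices.unet_needs_upcast)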
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
+ t_361 = F.relu(t_360)
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
+ t_362 = self.n_MaxPool_0(t_361)
+ t_363 = self.n_Conv_1(t_362)
+ t_364 = self.n_Conv_2(t_362)
+ t_365 = F.relu(t_364)
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
+ t_366 = self.n_Conv_3(t_365_padded)
+ t_367 = F.relu(t_366)
+ t_368 = self.n_Conv_4(t_367)
+ t_369 = torch.add(t_368, t_363)
+ t_370 = F.relu(t_369)
+ t_371 = self.n_Conv_5(t_370)
+ t_372 = F.relu(t_371)
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
+ t_373 = self.n_Conv_6(t_372_padded)
+ t_374 = F.relu(t_373)
+ t_375 = self.n_Conv_7(t_374)
+ t_376 = torch.add(t_375, t_370)
+ t_377 = F.relu(t_376)
+ t_378 = self.n_Conv_8(t_377)
+ t_379 = F.relu(t_378)
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
+ t_380 = self.n_Conv_9(t_379_padded)
+ t_381 = F.relu(t_380)
+ t_382 = self.n_Conv_10(t_381)
+ t_383 = torch.add(t_382, t_377)
+ t_384 = F.relu(t_383)
+ t_385 = self.n_Conv_11(t_384)
+ t_386 = self.n_Conv_12(t_384)
+ t_387 = F.relu(t_386)
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
+ t_388 = self.n_Conv_13(t_387_padded)
+ t_389 = F.relu(t_388)
+ t_390 = self.n_Conv_14(t_389)
+ t_391 = torch.add(t_390, t_385)
+ t_392 = F.relu(t_391)
+ t_393 = self.n_Conv_15(t_392)
+ t_394 = F.relu(t_393)
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
+ t_395 = self.n_Conv_16(t_394_padded)
+ t_396 = F.relu(t_395)
+ t_397 = self.n_Conv_17(t_396)
+ t_398 = torch.add(t_397, t_392)
+ t_399 = F.relu(t_398)
+ t_400 = self.n_Conv_18(t_399)
+ t_401 = F.relu(t_400)
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
+ t_402 = self.n_Conv_19(t_401_padded)
+ t_403 = F.relu(t_402)
+ t_404 = self.n_Conv_20(t_403)
+ t_405 = torch.add(t_404, t_399)
+ t_406 = F.relu(t_405)
+ t_407 = self.n_Conv_21(t_406)
+ t_408 = F.relu(t_407)
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
+ t_409 = self.n_Conv_22(t_408_padded)
+ t_410 = F.relu(t_409)
+ t_411 = self.n_Conv_23(t_410)
+ t_412 = torch.add(t_411, t_406)
+ t_413 = F.relu(t_412)
+ t_414 = self.n_Conv_24(t_413)
+ t_415 = F.relu(t_414)
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
+ t_416 = self.n_Conv_25(t_415_padded)
+ t_417 = F.relu(t_416)
+ t_418 = self.n_Conv_26(t_417)
+ t_419 = torch.add(t_418, t_413)
+ t_420 = F.relu(t_419)
+ t_421 = self.n_Conv_27(t_420)
+ t_422 = F.relu(t_421)
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
+ t_423 = self.n_Conv_28(t_422_padded)
+ t_424 = F.relu(t_423)
+ t_425 = self.n_Conv_29(t_424)
+ t_426 = torch.add(t_425, t_420)
+ t_427 = F.relu(t_426)
+ t_428 = self.n_Conv_30(t_427)
+ t_429 = F.relu(t_428)
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
+ t_430 = self.n_Conv_31(t_429_padded)
+ t_431 = F.relu(t_430)
+ t_432 = self.n_Conv_32(t_431)
+ t_433 = torch.add(t_432, t_427)
+ t_434 = F.relu(t_433)
+ t_435 = self.n_Conv_33(t_434)
+ t_436 = F.relu(t_435)
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
+ t_437 = self.n_Conv_34(t_436_padded)
+ t_438 = F.relu(t_437)
+ t_439 = self.n_Conv_35(t_438)
+ t_440 = torch.add(t_439, t_434)
+ t_441 = F.relu(t_440)
+ t_442 = self.n_Conv_36(t_441)
+ t_443 = self.n_Conv_37(t_441)
+ t_444 = F.relu(t_443)
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
+ t_445 = self.n_Conv_38(t_444_padded)
+ t_446 = F.relu(t_445)
+ t_447 = self.n_Conv_39(t_446)
+ t_448 = torch.add(t_447, t_442)
+ t_449 = F.relu(t_448)
+ t_450 = self.n_Conv_40(t_449)
+ t_451 = F.relu(t_450)
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
+ t_452 = self.n_Conv_41(t_451_padded)
+ t_453 = F.relu(t_452)
+ t_454 = self.n_Conv_42(t_453)
+ t_455 = torch.add(t_454, t_449)
+ t_456 = F.relu(t_455)
+ t_457 = self.n_Conv_43(t_456)
+ t_458 = F.relu(t_457)
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
+ t_459 = self.n_Conv_44(t_458_padded)
+ t_460 = F.relu(t_459)
+ t_461 = self.n_Conv_45(t_460)
+ t_462 = torch.add(t_461, t_456)
+ t_463 = F.relu(t_462)
+ t_464 = self.n_Conv_46(t_463)
+ t_465 = F.relu(t_464)
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
+ t_466 = self.n_Conv_47(t_465_padded)
+ t_467 = F.relu(t_466)
+ t_468 = self.n_Conv_48(t_467)
+ t_469 = torch.add(t_468, t_463)
+ t_470 = F.relu(t_469)
+ t_471 = self.n_Conv_49(t_470)
+ t_472 = F.relu(t_471)
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
+ t_473 = self.n_Conv_50(t_472_padded)
+ t_474 = F.relu(t_473)
+ t_475 = self.n_Conv_51(t_474)
+ t_476 = torch.add(t_475, t_470)
+ t_477 = F.relu(t_476)
+ t_478 = self.n_Conv_52(t_477)
+ t_479 = F.relu(t_478)
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
+ t_480 = self.n_Conv_53(t_479_padded)
+ t_481 = F.relu(t_480)
+ t_482 = self.n_Conv_54(t_481)
+ t_483 = torch.add(t_482, t_477)
+ t_484 = F.relu(t_483)
+ t_485 = self.n_Conv_55(t_484)
+ t_486 = F.relu(t_485)
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
+ t_487 = self.n_Conv_56(t_486_padded)
+ t_488 = F.relu(t_487)
+ t_489 = self.n_Conv_57(t_488)
+ t_490 = torch.add(t_489, t_484)
+ t_491 = F.relu(t_490)
+ t_492 = self.n_Conv_58(t_491)
+ t_493 = F.relu(t_492)
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
+ t_494 = self.n_Conv_59(t_493_padded)
+ t_495 = F.relu(t_494)
+ t_496 = self.n_Conv_60(t_495)
+ t_497 = torch.add(t_496, t_491)
+ t_498 = F.relu(t_497)
+ t_499 = self.n_Conv_61(t_498)
+ t_500 = F.relu(t_499)
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
+ t_501 = self.n_Conv_62(t_500_padded)
+ t_502 = F.relu(t_501)
+ t_503 = self.n_Conv_63(t_502)
+ t_504 = torch.add(t_503, t_498)
+ t_505 = F.relu(t_504)
+ t_506 = self.n_Conv_64(t_505)
+ t_507 = F.relu(t_506)
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
+ t_508 = self.n_Conv_65(t_507_padded)
+ t_509 = F.relu(t_508)
+ t_510 = self.n_Conv_66(t_509)
+ t_511 = torch.add(t_510, t_505)
+ t_512 = F.relu(t_511)
+ t_513 = self.n_Conv_67(t_512)
+ t_514 = F.relu(t_513)
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
+ t_515 = self.n_Conv_68(t_514_padded)
+ t_516 = F.relu(t_515)
+ t_517 = self.n_Conv_69(t_516)
+ t_518 = torch.add(t_517, t_512)
+ t_519 = F.relu(t_518)
+ t_520 = self.n_Conv_70(t_519)
+ t_521 = F.relu(t_520)
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
+ t_522 = self.n_Conv_71(t_521_padded)
+ t_523 = F.relu(t_522)
+ t_524 = self.n_Conv_72(t_523)
+ t_525 = torch.add(t_524, t_519)
+ t_526 = F.relu(t_525)
+ t_527 = self.n_Conv_73(t_526)
+ t_528 = F.relu(t_527)
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
+ t_529 = self.n_Conv_74(t_528_padded)
+ t_530 = F.relu(t_529)
+ t_531 = self.n_Conv_75(t_530)
+ t_532 = torch.add(t_531, t_526)
+ t_533 = F.relu(t_532)
+ t_534 = self.n_Conv_76(t_533)
+ t_535 = F.relu(t_534)
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
+ t_536 = self.n_Conv_77(t_535_padded)
+ t_537 = F.relu(t_536)
+ t_538 = self.n_Conv_78(t_537)
+ t_539 = torch.add(t_538, t_533)
+ t_540 = F.relu(t_539)
+ t_541 = self.n_Conv_79(t_540)
+ t_542 = F.relu(t_541)
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
+ t_543 = self.n_Conv_80(t_542_padded)
+ t_544 = F.relu(t_543)
+ t_545 = self.n_Conv_81(t_544)
+ t_546 = torch.add(t_545, t_540)
+ t_547 = F.relu(t_546)
+ t_548 = self.n_Conv_82(t_547)
+ t_549 = F.relu(t_548)
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
+ t_550 = self.n_Conv_83(t_549_padded)
+ t_551 = F.relu(t_550)
+ t_552 = self.n_Conv_84(t_551)
+ t_553 = torch.add(t_552, t_547)
+ t_554 = F.relu(t_553)
+ t_555 = self.n_Conv_85(t_554)
+ t_556 = F.relu(t_555)
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
+ t_557 = self.n_Conv_86(t_556_padded)
+ t_558 = F.relu(t_557)
+ t_559 = self.n_Conv_87(t_558)
+ t_560 = torch.add(t_559, t_554)
+ t_561 = F.relu(t_560)
+ t_562 = self.n_Conv_88(t_561)
+ t_563 = F.relu(t_562)
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
+ t_564 = self.n_Conv_89(t_563_padded)
+ t_565 = F.relu(t_564)
+ t_566 = self.n_Conv_90(t_565)
+ t_567 = torch.add(t_566, t_561)
+ t_568 = F.relu(t_567)
+ t_569 = self.n_Conv_91(t_568)
+ t_570 = F.relu(t_569)
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
+ t_571 = self.n_Conv_92(t_570_padded)
+ t_572 = F.relu(t_571)
+ t_573 = self.n_Conv_93(t_572)
+ t_574 = torch.add(t_573, t_568)
+ t_575 = F.relu(t_574)
+ t_576 = self.n_Conv_94(t_575)
+ t_577 = F.relu(t_576)
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
+ t_578 = self.n_Conv_95(t_577_padded)
+ t_579 = F.relu(t_578)
+ t_580 = self.n_Conv_96(t_579)
+ t_581 = torch.add(t_580, t_575)
+ t_582 = F.relu(t_581)
+ t_583 = self.n_Conv_97(t_582)
+ t_584 = F.relu(t_583)
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
+ t_585 = self.n_Conv_98(t_584_padded)
+ t_586 = F.relu(t_585)
+ t_587 = self.n_Conv_99(t_586)
+ t_588 = self.n_Conv_100(t_582)
+ t_589 = torch.add(t_587, t_588)
+ t_590 = F.relu(t_589)
+ t_591 = self.n_Conv_101(t_590)
+ t_592 = F.relu(t_591)
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
+ t_593 = self.n_Conv_102(t_592_padded)
+ t_594 = F.relu(t_593)
+ t_595 = self.n_Conv_103(t_594)
+ t_596 = torch.add(t_595, t_590)
+ t_597 = F.relu(t_596)
+ t_598 = self.n_Conv_104(t_597)
+ t_599 = F.relu(t_598)
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
+ t_600 = self.n_Conv_105(t_599_padded)
+ t_601 = F.relu(t_600)
+ t_602 = self.n_Conv_106(t_601)
+ t_603 = torch.add(t_602, t_597)
+ t_604 = F.relu(t_603)
+ t_605 = self.n_Conv_107(t_604)
+ t_606 = F.relu(t_605)
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
+ t_607 = self.n_Conv_108(t_606_padded)
+ t_608 = F.relu(t_607)
+ t_609 = self.n_Conv_109(t_608)
+ t_610 = torch.add(t_609, t_604)
+ t_611 = F.relu(t_610)
+ t_612 = self.n_Conv_110(t_611)
+ t_613 = F.relu(t_612)
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
+ t_614 = self.n_Conv_111(t_613_padded)
+ t_615 = F.relu(t_614)
+ t_616 = self.n_Conv_112(t_615)
+ t_617 = torch.add(t_616, t_611)
+ t_618 = F.relu(t_617)
+ t_619 = self.n_Conv_113(t_618)
+ t_620 = F.relu(t_619)
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
+ t_621 = self.n_Conv_114(t_620_padded)
+ t_622 = F.relu(t_621)
+ t_623 = self.n_Conv_115(t_622)
+ t_624 = torch.add(t_623, t_618)
+ t_625 = F.relu(t_624)
+ t_626 = self.n_Conv_116(t_625)
+ t_627 = F.relu(t_626)
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
+ t_628 = self.n_Conv_117(t_627_padded)
+ t_629 = F.relu(t_628)
+ t_630 = self.n_Conv_118(t_629)
+ t_631 = torch.add(t_630, t_625)
+ t_632 = F.relu(t_631)
+ t_633 = self.n_Conv_119(t_632)
+ t_634 = F.relu(t_633)
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
+ t_635 = self.n_Conv_120(t_634_padded)
+ t_636 = F.relu(t_635)
+ t_637 = self.n_Conv_121(t_636)
+ t_638 = torch.add(t_637, t_632)
+ t_639 = F.relu(t_638)
+ t_640 = self.n_Conv_122(t_639)
+ t_641 = F.relu(t_640)
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
+ t_642 = self.n_Conv_123(t_641_padded)
+ t_643 = F.relu(t_642)
+ t_644 = self.n_Conv_124(t_643)
+ t_645 = torch.add(t_644, t_639)
+ t_646 = F.relu(t_645)
+ t_647 = self.n_Conv_125(t_646)
+ t_648 = F.relu(t_647)
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
+ t_649 = self.n_Conv_126(t_648_padded)
+ t_650 = F.relu(t_649)
+ t_651 = self.n_Conv_127(t_650)
+ t_652 = torch.add(t_651, t_646)
+ t_653 = F.relu(t_652)
+ t_654 = self.n_Conv_128(t_653)
+ t_655 = F.relu(t_654)
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
+ t_656 = self.n_Conv_129(t_655_padded)
+ t_657 = F.relu(t_656)
+ t_658 = self.n_Conv_130(t_657)
+ t_659 = torch.add(t_658, t_653)
+ t_660 = F.relu(t_659)
+ t_661 = self.n_Conv_131(t_660)
+ t_662 = F.relu(t_661)
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
+ t_663 = self.n_Conv_132(t_662_padded)
+ t_664 = F.relu(t_663)
+ t_665 = self.n_Conv_133(t_664)
+ t_666 = torch.add(t_665, t_660)
+ t_667 = F.relu(t_666)
+ t_668 = self.n_Conv_134(t_667)
+ t_669 = F.relu(t_668)
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
+ t_670 = self.n_Conv_135(t_669_padded)
+ t_671 = F.relu(t_670)
+ t_672 = self.n_Conv_136(t_671)
+ t_673 = torch.add(t_672, t_667)
+ t_674 = F.relu(t_673)
+ t_675 = self.n_Conv_137(t_674)
+ t_676 = F.relu(t_675)
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
+ t_677 = self.n_Conv_138(t_676_padded)
+ t_678 = F.relu(t_677)
+ t_679 = self.n_Conv_139(t_678)
+ t_680 = torch.add(t_679, t_674)
+ t_681 = F.relu(t_680)
+ t_682 = self.n_Conv_140(t_681)
+ t_683 = F.relu(t_682)
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
+ t_684 = self.n_Conv_141(t_683_padded)
+ t_685 = F.relu(t_684)
+ t_686 = self.n_Conv_142(t_685)
+ t_687 = torch.add(t_686, t_681)
+ t_688 = F.relu(t_687)
+ t_689 = self.n_Conv_143(t_688)
+ t_690 = F.relu(t_689)
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
+ t_691 = self.n_Conv_144(t_690_padded)
+ t_692 = F.relu(t_691)
+ t_693 = self.n_Conv_145(t_692)
+ t_694 = torch.add(t_693, t_688)
+ t_695 = F.relu(t_694)
+ t_696 = self.n_Conv_146(t_695)
+ t_697 = F.relu(t_696)
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
+ t_698 = self.n_Conv_147(t_697_padded)
+ t_699 = F.relu(t_698)
+ t_700 = self.n_Conv_148(t_699)
+ t_701 = torch.add(t_700, t_695)
+ t_702 = F.relu(t_701)
+ t_703 = self.n_Conv_149(t_702)
+ t_704 = F.relu(t_703)
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
+ t_705 = self.n_Conv_150(t_704_padded)
+ t_706 = F.relu(t_705)
+ t_707 = self.n_Conv_151(t_706)
+ t_708 = torch.add(t_707, t_702)
+ t_709 = F.relu(t_708)
+ t_710 = self.n_Conv_152(t_709)
+ t_711 = F.relu(t_710)
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
+ t_712 = self.n_Conv_153(t_711_padded)
+ t_713 = F.relu(t_712)
+ t_714 = self.n_Conv_154(t_713)
+ t_715 = torch.add(t_714, t_709)
+ t_716 = F.relu(t_715)
+ t_717 = self.n_Conv_155(t_716)
+ t_718 = F.relu(t_717)
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
+ t_719 = self.n_Conv_156(t_718_padded)
+ t_720 = F.relu(t_719)
+ t_721 = self.n_Conv_157(t_720)
+ t_722 = torch.add(t_721, t_716)
+ t_723 = F.relu(t_722)
+ t_724 = self.n_Conv_158(t_723)
+ t_725 = self.n_Conv_159(t_723)
+ t_726 = F.relu(t_725)
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
+ t_727 = self.n_Conv_160(t_726_padded)
+ t_728 = F.relu(t_727)
+ t_729 = self.n_Conv_161(t_728)
+ t_730 = torch.add(t_729, t_724)
+ t_731 = F.relu(t_730)
+ t_732 = self.n_Conv_162(t_731)
+ t_733 = F.relu(t_732)
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
+ t_734 = self.n_Conv_163(t_733_padded)
+ t_735 = F.relu(t_734)
+ t_736 = self.n_Conv_164(t_735)
+ t_737 = torch.add(t_736, t_731)
+ t_738 = F.relu(t_737)
+ t_739 = self.n_Conv_165(t_738)
+ t_740 = F.relu(t_739)
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
+ t_741 = self.n_Conv_166(t_740_padded)
+ t_742 = F.relu(t_741)
+ t_743 = self.n_Conv_167(t_742)
+ t_744 = torch.add(t_743, t_738)
+ t_745 = F.relu(t_744)
+ t_746 = self.n_Conv_168(t_745)
+ t_747 = self.n_Conv_169(t_745)
+ t_748 = F.relu(t_747)
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
+ t_749 = self.n_Conv_170(t_748_padded)
+ t_750 = F.relu(t_749)
+ t_751 = self.n_Conv_171(t_750)
+ t_752 = torch.add(t_751, t_746)
+ t_753 = F.relu(t_752)
+ t_754 = self.n_Conv_172(t_753)
+ t_755 = F.relu(t_754)
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
+ t_756 = self.n_Conv_173(t_755_padded)
+ t_757 = F.relu(t_756)
+ t_758 = self.n_Conv_174(t_757)
+ t_759 = torch.add(t_758, t_753)
+ t_760 = F.relu(t_759)
+ t_761 = self.n_Conv_175(t_760)
+ t_762 = F.relu(t_761)
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
+ t_763 = self.n_Conv_176(t_762_padded)
+ t_764 = F.relu(t_763)
+ t_765 = self.n_Conv_177(t_764)
+ t_766 = torch.add(t_765, t_760)
+ t_767 = F.relu(t_766)
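+ # head: 1x1 conv to 9176 tag logits, global average pool over the remaining spatial dims, then per-tag sigmoid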
+ t_768 = self.n_Conv_178(t_767)
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
+ t_770 = torch.squeeze(t_769, 3)
+ t_770 = torch.squeeze(t_770, 2)
+ t_771 = torch.sigmoid(t_770)
+ return t_771
+
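+ # checkpoints bundle the tag list with the weights; stash it on the model and strip it before loading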
+ def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument
+ self.tags = state_dict.get('tags', [])
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
diff --git a/modules/errors.py b/modules/errors.py
index 122628bff..8cb2eb599 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,96 +1,96 @@
-import logging
-import warnings
-from rich.console import Console
-from rich.theme import Theme
-from rich.pretty import install as pretty_install
-from rich.traceback import install as traceback_install
-from installer import log as installer_log, setup_logging
-
-
-setup_logging()
-log = installer_log
-console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
- "traceback.border": "black",
- "traceback.border.syntax_error": "black",
- "inspect.value.border": "black",
-}))
-
-pretty_install(console=console)
-traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False)
-already_displayed = {}
-
-
-def install(suppress=[]): # noqa: B006
- warnings.filterwarnings("ignore", category=UserWarning)
- pretty_install(console=console)
- traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=suppress)
- logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
- # for handler in logging.getLogger().handlers:
- # handler.setLevel(logging.INFO)
-
-
-def print_error_explanation(message):
- lines = message.strip().split("\n")
- for line in lines:
- log.error(line)
-
-
-def display(e: Exception, task, suppress=[]): # noqa: B006
- log.error(f"{task or 'error'}: {type(e).__name__}")
- console.print_exception(show_locals=False, max_frames=10, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
-
-
-def display_once(e: Exception, task):
- if task in already_displayed:
- return
- display(e, task)
- already_displayed[task] = 1
-
-
-def run(code, task):
- try:
- code()
- except Exception as e:
- display(e, task)
-
-
-def exception(suppress=[]): # noqa: B006
- console.print_exception(show_locals=False, max_frames=10, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
-
-
-def profile(profiler, msg: str):
- profiler.disable()
- import io
- import pstats
- stream = io.StringIO() # pylint: disable=abstract-class-instantiated
- p = pstats.Stats(profiler, stream=stream)
- p.sort_stats(pstats.SortKey.CUMULATIVE)
- p.print_stats(100)
- # p.print_title()
- # p.print_call_heading(10, 'time')
- # p.print_callees(10)
- # p.print_callers(10)
- profiler = None
- lines = stream.getvalue().split('\n')
-    lines = [x for x in lines if '
-                self.version = f"{self.commit_hash[:8]} {datetime.fromtimestamp(self.commit_date).strftime('%a %b%d %Y %H:%M')}"
- except Exception as ex:
- shared.log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
- self.remote = None
-
- def list_files(self, subdir, extension):
- from modules import scripts
- dirpath = os.path.join(self.path, subdir)
- if not os.path.isdir(dirpath):
- return []
- res = []
- for filename in sorted(os.listdir(dirpath)):
- if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
- continue
- priority = '50'
- if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
- with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
- priority = str(f.read().strip())
- res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
- if priority != '50':
- shared.log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
- res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
- return res
-
- def check_updates(self):
- try:
- repo = git.Repo(self.path)
- except Exception:
- self.can_update = False
- return
- for fetch in repo.remote().fetch(dry_run=True):
- if fetch.flags != fetch.HEAD_UPTODATE:
- self.can_update = True
- self.status = "new commits"
- return
- try:
- origin = repo.rev_parse('origin')
- if repo.head.commit != origin:
- self.can_update = True
- self.status = "behind HEAD"
- return
- except Exception:
- self.can_update = False
- self.status = "unknown (remote error)"
- return
- self.can_update = False
- self.status = "latest"
-
- def git_fetch(self, commit='origin'):
- repo = git.Repo(self.path)
- # Fix: `error: Your local changes to the following files would be overwritten by merge`,
- # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
- repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
- repo.git.reset(commit, hard=True)
- self.have_info_from_repo = False
-
-
-def list_extensions():
- extensions.clear()
- if not os.path.isdir(extensions_dir):
- return
- if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "user":
- shared.log.warning(f"Option set: Disable extensions: {shared.opts.disable_all_extensions}")
- extension_paths = []
- extension_names = []
- extension_folders = [extensions_builtin_dir] if shared.cmd_opts.safe else [extensions_builtin_dir, extensions_dir]
- for dirname in extension_folders:
- if not os.path.isdir(dirname):
- return
- for extension_dirname in sorted(os.listdir(dirname)):
- path = os.path.join(dirname, extension_dirname)
- if not os.path.isdir(path):
- continue
- if extension_dirname in extension_names:
- shared.log.info(f'Skipping conflicting extension: {path}')
- continue
- extension_names.append(extension_dirname)
- extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
- disabled_extensions = shared.opts.disabled_extensions + shared.temp_disable_extensions()
- for dirname, path, is_builtin in extension_paths:
- extension = Extension(name=dirname, path=path, enabled=dirname not in disabled_extensions, is_builtin=is_builtin)
- extensions.append(extension)
- shared.log.info(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}')
+import os
+from datetime import datetime
+import git
+from modules import shared, errors
+from modules.paths import extensions_dir, extensions_builtin_dir
+
+
+extensions = []
+
+
+if not os.path.exists(extensions_dir):
+ os.makedirs(extensions_dir)
+
+
+def active():
+ if shared.opts.disable_all_extensions == "all":
+ return []
+ elif shared.opts.disable_all_extensions == "user":
+ return [x for x in extensions if x.enabled and x.is_builtin]
+ else:
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True, is_builtin=False):
+ self.name = name
+ self.git_name = ''
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+ self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
+ self.version = ''
+ self.description = ''
+ self.branch = None
+ self.remote = None
+ self.have_info_from_repo = False
+ self.mtime = 0
+ self.ctime = 0
+
+ def read_info(self, force=False):
+ if self.have_info_from_repo and not force:
+ return
+ self.have_info_from_repo = True
+ repo = None
+ self.mtime = datetime.fromtimestamp(os.path.getmtime(self.path)).isoformat() + 'Z'
+ self.ctime = datetime.fromtimestamp(os.path.getctime(self.path)).isoformat() + 'Z'
+ try:
+ if os.path.exists(os.path.join(self.path, ".git")):
+ repo = git.Repo(self.path)
+ except Exception as e:
+ errors.display(e, f'github info from {self.path}')
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ try:
+ self.status = 'unknown'
+ if len(repo.remotes) == 0:
+ shared.log.debug(f"Extension: no remotes info repo={self.name}")
+ return
+ self.git_name = repo.remotes.origin.url.split('.git')[0].split('/')[-1]
+ self.description = repo.description
+ if self.description is None or self.description.startswith("Unnamed repository"):
+ self.description = "[No description]"
+ self.remote = next(repo.remote().urls, None)
+ head = repo.head.commit
+ self.commit_date = repo.head.commit.committed_date
+ try:
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ except Exception:
+ pass
+ self.commit_hash = head.hexsha
+                self.version = f"{self.commit_hash[:8]} {datetime.fromtimestamp(self.commit_date).strftime('%a %b%d %Y %H:%M')}"
+ except Exception as ex:
+ shared.log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
+ self.remote = None
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
+ continue
+ priority = '50'
+ if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
+ with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
+ priority = str(f.read().strip())
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
+ if priority != '50':
+ shared.log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+ return res
+
+ def check_updates(self):
+ try:
+ repo = git.Repo(self.path)
+ except Exception:
+ self.can_update = False
+ return
+ for fetch in repo.remote().fetch(dry_run=True):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "new commits"
+ return
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
+ return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
+ self.can_update = False
+ self.status = "latest"
+
+ def git_fetch(self, commit='origin'):
+ repo = git.Repo(self.path)
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+ repo.git.fetch(all=True)
+ repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
+
+
+def list_extensions():
+ extensions.clear()
+ if not os.path.isdir(extensions_dir):
+ return
+ if shared.opts.disable_all_extensions == "all" or shared.opts.disable_all_extensions == "user":
+ shared.log.warning(f"Option set: Disable extensions: {shared.opts.disable_all_extensions}")
+ extension_paths = []
+ extension_names = []
+ extension_folders = [extensions_builtin_dir] if shared.cmd_opts.safe else [extensions_builtin_dir, extensions_dir]
+ for dirname in extension_folders:
+ if not os.path.isdir(dirname):
+ return
+ for extension_dirname in sorted(os.listdir(dirname)):
+ path = os.path.join(dirname, extension_dirname)
+ if not os.path.isdir(path):
+ continue
+ if extension_dirname in extension_names:
+ shared.log.info(f'Skipping conflicting extension: {path}')
+ continue
+ extension_names.append(extension_dirname)
+ extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+ disabled_extensions = shared.opts.disabled_extensions + shared.temp_disable_extensions()
+ for dirname, path, is_builtin in extension_paths:
+ extension = Extension(name=dirname, path=path, enabled=dirname not in disabled_extensions, is_builtin=is_builtin)
+ extensions.append(extension)
+ shared.log.info(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}')
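Reviewer note on list_files(): every script gets the default priority '50' unless the extension ships a '.priority' file in its root, in which case that file's contents become the priority for all of its scripts. A small sketch of how an extension would opt in (the folder name is illustrative, not from this diff):

import os

ext_root = os.path.join('extensions', 'my-extension')  # hypothetical extension folder
os.makedirs(os.path.join(ext_root, 'scripts'), exist_ok=True)
with open(os.path.join(ext_root, '.priority'), 'w', encoding='utf-8') as f:
    f.write('10')  # list_files() reads this and sorts these scripts ahead of the default '50'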
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index d3ca87dd2..b74dfe0e3 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,137 +1,137 @@
-import re
-from collections import defaultdict
-
-from modules import errors
-
-extra_network_registry = {}
-
-
-def initialize():
- extra_network_registry.clear()
-
-
-def register_extra_network(extra_network):
- extra_network_registry[extra_network.name] = extra_network
-
-
-def register_default_extra_networks():
- from modules.extra_networks_hypernet import ExtraNetworkHypernet
- register_extra_network(ExtraNetworkHypernet())
- from modules.ui_extra_networks_styles import ExtraNetworkStyles
- register_extra_network(ExtraNetworkStyles())
-
-
-class ExtraNetworkParams:
- def __init__(self, items=None):
- self.items = items or []
- self.positional = []
- self.named = {}
- for item in self.items:
- parts = item.split('=', 2) if isinstance(item, str) else [item]
- if len(parts) == 2:
- self.named[parts[0]] = parts[1]
- else:
- self.positional.append(item)
-
-
-class ExtraNetwork:
- def __init__(self, name):
- self.name = name
-
- def activate(self, p, params_list):
- """
-        Called by processing on every run. Whatever the extra network is meant to do should be activated here. Passes arguments related to this extra network in params_list. The user passes arguments by specifying this in the prompt:
-        <name:arg1:arg2:arg3>
-        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments separated by colons.
-        Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made with an empty params_list - in this case, all effects of this extra network should be disabled.
-        Can be called multiple times before deactivate() - each new call should override the previous call completely.
-        For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:
-        > "1girl, <hypernet:agm:1.1> <hypernet:ray>"
- params_list will be:
- [
- ExtraNetworkParams(items=["agm", "1.1"]),
- ExtraNetworkParams(items=["ray"])
- ]
- """
- raise NotImplementedError
-
- def deactivate(self, p):
- """
- Called at the end of processing for housekeeping. No need to do anything here.
- """
- raise NotImplementedError
-
-
-def activate(p, extra_network_data):
- """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
- if extra_network_data is None:
- return
- for extra_network_name, extra_network_args in extra_network_data.items():
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- print(f"Skipping unknown extra network: {extra_network_name}")
- continue
- try:
- extra_network.activate(p, extra_network_args)
- except Exception as e:
- errors.display(e, f"activating extra network: name={extra_network_name} args:{extra_network_args}")
-
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
- try:
- extra_network.activate(p, [])
- except Exception as e:
- errors.display(e, f"activating extra network: name={extra_network_name}")
-
-
-def deactivate(p, extra_network_data):
- """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks"""
- if extra_network_data is None:
- return
- for extra_network_name in extra_network_data:
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- continue
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating extra network {extra_network_name}")
-
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
-
-
-re_extra_net = re.compile(r"<(\w+):([^>]+)>")
-
-
-def parse_prompt(prompt):
- res = defaultdict(list)
-
- def found(m):
- name = m.group(1)
- args = m.group(2)
- res[name].append(ExtraNetworkParams(items=args.split(":")))
- return ""
- prompt = re.sub(re_extra_net, found, prompt)
- return prompt, res
-
-
-def parse_prompts(prompts):
- res = []
- extra_data = None
-
- for prompt in prompts:
- updated_prompt, parsed_extra_data = parse_prompt(prompt)
- if extra_data is None:
- extra_data = parsed_extra_data
- res.append(updated_prompt)
-
- return res, extra_data
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+def register_default_extra_networks():
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
+ from modules.ui_extra_networks_styles import ExtraNetworkStyles
+ register_extra_network(ExtraNetworkStyles())
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+ self.positional = []
+ self.named = {}
+ for item in self.items:
+ parts = item.split('=', 2) if isinstance(item, str) else [item]
+ if len(parts) == 2:
+ self.named[parts[0]] = parts[1]
+ else:
+ self.positional.append(item)
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+        Called by processing on every run. Whatever the extra network is meant to do should be activated here. Passes arguments related to this extra network in params_list. The user passes arguments by specifying this in the prompt:
+        <name:arg1:arg2:arg3>
+        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments separated by colons.
+        Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made with an empty params_list - in this case, all effects of this extra network should be disabled.
+        Can be called multiple times before deactivate() - each new call should override the previous call completely.
+        For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:
+        > "1girl, <hypernet:agm:1.1> <hypernet:ray>"
+ params_list will be:
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. No need to do anything here.
+ """
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
+ if extra_network_data is None:
+ return
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network: name={extra_network_name} args:{extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network: name={extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks"""
+ if extra_network_data is None:
+ return
+ for extra_network_name in extra_network_data:
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+ return ""
+ prompt = re.sub(re_extra_net, found, prompt)
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+ if extra_data is None:
+ extra_data = parsed_extra_data
+ res.append(updated_prompt)
+
+ return res, extra_data
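To make the tag syntax concrete: parse_prompt() strips every <name:args> occurrence matched by re_extra_net from the prompt and groups the colon-separated arguments by network name. A quick sketch using the same example values as the activate() docstring above:

prompt, extra = parse_prompt("1girl, <hypernet:agm:1.1> <hypernet:ray>")
print(prompt)  # the two tags are removed, leaving just the plain prompt text
for params in extra['hypernet']:
    print(params.positional)  # ['agm', '1.1'] on the first pass, ['ray'] on the second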
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index aa2a14efd..dce11b68a 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -1,28 +1,28 @@
-from modules import extra_networks, shared
-from modules.hypernetworks import hypernetwork
-
-
-class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
- def __init__(self):
- super().__init__('hypernet')
-
- def activate(self, p, params_list):
- additional = shared.opts.sd_hypernetwork
-
- if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
-            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
- p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
- params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
-
- names = []
- multipliers = []
- for params in params_list:
- assert len(params.items) > 0
-
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
-
- hypernetwork.load_hypernetworks(names, multipliers)
-
- def deactivate(self, p):
- pass
+from modules import extra_networks, shared
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ additional = shared.opts.sd_hypernetwork
+
+ if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
+ p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(self, p):
+ pass
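For reference, this is the shape of the name/multiplier split that activate() performs before calling hypernetwork.load_hypernetworks(), sketched standalone with the two example tags from the extra_networks docstring:

params_list = [
    extra_networks.ExtraNetworkParams(items=["agm", "1.1"]),
    extra_networks.ExtraNetworkParams(items=["ray"]),
]
names = [p.items[0] for p in params_list]  # ['agm', 'ray']
multipliers = [float(p.items[1]) if len(p.items) > 1 else 1.0 for p in params_list]  # [1.1, 1.0]
# hypernetwork.load_hypernetworks(names, multipliers) then receives these two lists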
diff --git a/modules/extras.py b/modules/extras.py
index fa3e07378..8e24635ff 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,348 +1,348 @@
-import os
-import html
-import json
-import time
-import shutil
-
-import torch
-import tqdm
-import gradio as gr
-import safetensors.torch
-from modules.merging.merge import merge_models
-from modules.merging.merge_utils import TRIPLE_METHODS
-
-from modules import shared, images, sd_models, sd_vae, sd_models_config, devices
-
-
-def run_pnginfo(image):
- if image is None:
- return '', '', ''
- geninfo, items = images.read_info_from_image(image)
- items = {**{'parameters': geninfo}, **items}
- info = ''
- for key, text in items.items():
- if key != 'UserComment':
-            info += f"{html.escape(str(key))} : {html.escape(str(text))}"
- return '', geninfo, info
-
-
-def create_config(ckpt_result, config_source, a, b, c):
- def config(x):
- res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
- return res if res != shared.sd_default_config else None
-
- if config_source == 0:
- cfg = config(a) or config(b) or config(c)
- elif config_source == 1:
- cfg = config(b)
- elif config_source == 2:
- cfg = config(c)
- else:
- cfg = None
- if cfg is None:
- return
- filename, _ = os.path.splitext(ckpt_result)
- checkpoint_filename = filename + ".yaml"
-    shared.log.info(f"Copying config: {cfg} -> {checkpoint_filename}")
- shutil.copyfile(cfg, checkpoint_filename)
-
-
-def to_half(tensor, enable):
- if enable and tensor.dtype == torch.float:
- return tensor.half()
- return tensor
-
-
-def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
- shared.state.begin('merge')
- t0 = time.time()
-
- def fail(message):
- shared.state.textinfo = message
- shared.state.end()
- return [*[gr.update() for _ in range(4)], message]
-
- kwargs["models"] = {
- "model_a": sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None)).filename,
- "model_b": sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None)).filename,
- }
-
- if kwargs.get("primary_model_name", None) in [None, 'None']:
- return fail("Failed: Merging requires a primary model.")
- primary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None))
- if kwargs.get("secondary_model_name", None) in [None, 'None']:
- return fail("Failed: Merging requires a secondary model.")
- secondary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None))
- if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in TRIPLE_METHODS:
- return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.")
- tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None
-
- del kwargs["primary_model_name"]
- del kwargs["secondary_model_name"]
- if kwargs.get("tertiary_model_name", None) is not None:
- kwargs["models"] |= {"model_c": sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)).filename}
- del kwargs["tertiary_model_name"]
-
- if hasattr(kwargs, "alpha_base") and hasattr(kwargs, "alpha_in_blocks") and hasattr(kwargs, "alpha_mid_block") and hasattr(kwargs, "alpha_out_blocks"):
- try:
- alpha = [float(x) for x in
- [kwargs["alpha_base"]] + kwargs["alpha_in_blocks"].split(",") + [kwargs["alpha_mid_block"]] + kwargs["alpha_out_blocks"].split(",")]
- assert len(alpha) == 26 or len(alpha) == 20, "Alpha Block Weights are wrong length (26 or 20 for SDXL) falling back"
- kwargs["alpha"] = alpha
- except KeyError as ke:
- shared.log.warning(f"Merge: Malformed manual block weight: {ke}")
- elif hasattr(kwargs, "alpha_preset") or hasattr(kwargs, "alpha"):
- kwargs["alpha"] = kwargs.get("alpha_preset", kwargs["alpha"])
-
- kwargs.pop("alpha_base", None)
- kwargs.pop("alpha_in_blocks", None)
- kwargs.pop("alpha_mid_block", None)
- kwargs.pop("alpha_out_blocks", None)
- kwargs.pop("alpha_preset", None)
-
- if hasattr(kwargs, "beta_base") and hasattr(kwargs, "beta_in_blocks") and hasattr(kwargs, "beta_mid_block") and hasattr(kwargs, "beta_out_blocks"):
- try:
- beta = [float(x) for x in
- [kwargs["beta_base"]] + kwargs["beta_in_blocks"].split(",") + [kwargs["beta_mid_block"]] + kwargs["beta_out_blocks"].split(",")]
- assert len(beta) == 26 or len(beta) == 20, "Beta Block Weights are wrong length (26 or 20 for SDXL) falling back"
- kwargs["beta"] = beta
- except KeyError as ke:
- shared.log.warning(f"Merge: Malformed manual block weight: {ke}")
- elif hasattr(kwargs, "beta_preset") or hasattr(kwargs, "beta"):
- kwargs["beta"] = kwargs.get("beta_preset", kwargs["beta"])
-
- kwargs.pop("beta_base", None)
- kwargs.pop("beta_in_blocks", None)
- kwargs.pop("beta_mid_block", None)
- kwargs.pop("beta_out_blocks", None)
- kwargs.pop("beta_preset", None)
-
- if kwargs["device"] == "gpu":
- kwargs["device"] = devices.device
- elif kwargs["device"] == "shuffle":
- kwargs["device"] = torch.device("cpu")
- kwargs["work_device"] = devices.device
- else:
- kwargs["device"] = torch.device("cpu")
- if kwargs.pop("unload", False):
- sd_models.unload_model_weights()
-
- try:
- theta_0 = merge_models(**kwargs)
- except Exception as e:
- return fail(f"{e}")
-
- try:
- theta_0 = theta_0.to_dict() #TensorDict -> Dict if necessary
- except Exception:
- pass
-
- bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None)
- if bake_in_vae_filename is not None:
- shared.log.info(f"Merge VAE='{bake_in_vae_filename}'")
- shared.state.textinfo = 'Merge VAE'
- vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename)
- for key in vae_dict.keys():
- theta_0_key = 'first_stage_model.' + key
- if theta_0_key in theta_0:
- theta_0[theta_0_key] = to_half(vae_dict[key], kwargs.get("precision", "fp16") == "fp16")
- del vae_dict
-
- ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path
- filename = kwargs.get("custom_name", "Unnamed_Merge")
- filename += "." + kwargs.get("checkpoint_format", None)
- output_modelname = os.path.join(ckpt_dir, filename)
- shared.state.textinfo = "merge saving"
- metadata = None
- if kwargs.get("save_metadata", False):
- metadata = {"format": "pt", "sd_merge_models": {}}
- merge_recipe = {
- "type": "SDNext", # indicate this model was merged with webui's built-in merger
- "primary_model_hash": primary_model_info.sha256,
- "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
- "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
- "merge_mode": kwargs.get('merge_mode', None),
- "alpha": kwargs.get('alpha', None),
- "beta": kwargs.get('beta', None),
- "precision": kwargs.get('precision', None),
- "custom_name": kwargs.get("custom_name", "Unamed_Merge"),
- }
- metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
-
- def add_model_metadata(checkpoint_info):
- checkpoint_info.calculate_shorthash()
- metadata["sd_merge_models"][checkpoint_info.sha256] = {
- "name": checkpoint_info.name,
- "legacy_hash": checkpoint_info.hash,
- "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
- }
- metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
-
- add_model_metadata(primary_model_info)
- if secondary_model_info:
- add_model_metadata(secondary_model_info)
- if tertiary_model_info:
- add_model_metadata(tertiary_model_info)
- metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
-
- _, extension = os.path.splitext(output_modelname)
-
- if os.path.exists(output_modelname) and not kwargs.get("overwrite", False):
-        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model already exists: {output_modelname}"]
- if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
- else:
- torch.save(theta_0, output_modelname)
-
- t1 = time.time()
- shared.log.info(f"Merge complete: saved='{output_modelname}' time={t1-t0:.2f}")
- sd_models.list_models()
- created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
- if created_model:
- created_model.calculate_shorthash()
- devices.torch_gc(force=True)
- shared.state.end()
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model saved to {output_modelname}"]
-
-
-def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
- vae_conv, others_conv, fix_clip):
- # position_ids in clip is int64. model_ema.num_updates is int32
- dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
- dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
-
- def conv_fp16(t: torch.Tensor):
- return t.half() if t.dtype in dtypes_to_fp16 else t
-
- def conv_bf16(t: torch.Tensor):
- return t.bfloat16() if t.dtype in dtypes_to_bf16 else t
-
- def conv_full(t):
- return t
-
- _g_precision_func = {
- "full": conv_full,
- "fp32": conv_full,
- "fp16": conv_fp16,
- "bf16": conv_bf16,
- }
-
- def check_weight_type(k: str) -> str:
- if k.startswith("model.diffusion_model"):
- return "unet"
- elif k.startswith("first_stage_model"):
- return "vae"
- elif k.startswith("cond_stage_model"):
- return "clip"
- return "other"
-
- def load_model(path):
- if path.endswith(".safetensors"):
- m = safetensors.torch.load_file(path, device="cpu")
- else:
- m = torch.load(path, map_location="cpu")
- state_dict = m["state_dict"] if "state_dict" in m else m
- return state_dict
-
- def fix_model(model, fix_clip=False):
- # code from model-toolkit
- nai_keys = {
- 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
- 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
- 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
- }
- for k in list(model.keys()):
- for r in nai_keys:
- if type(k) == str and k.startswith(r):
- new_key = k.replace(r, nai_keys[r])
- model[new_key] = model[k]
- del model[k]
- shared.log.warning(f"Model convert: fixed NovelAI error key: {k}")
- break
- if fix_clip:
- i = "cond_stage_model.transformer.text_model.embeddings.position_ids"
- if i in model:
- correct = torch.Tensor([list(range(77))]).to(torch.int64)
- now = model[i].to(torch.int64)
-
- broken = correct.ne(now)
- broken = [i for i in range(77) if broken[0][i]]
- model[i] = correct
- if len(broken) != 0:
-                    shared.log.warning(f"Model convert: fixed broken CLIP: {broken}")
-
- return model
-
- if model == "":
- return "Error: you must choose a model"
- if len(checkpoint_formats) == 0:
- return "Error: at least choose one model save format"
-
- extra_opt = {
- "unet": unet_conv,
- "clip": text_encoder_conv,
- "vae": vae_conv,
- "other": others_conv
- }
- shared.state.begin('convert')
- model_info = sd_models.checkpoints_list[model]
- shared.state.textinfo = f"Loading {model_info.filename}..."
- shared.log.info(f"Model convert loading: {model_info.filename}")
- state_dict = load_model(model_info.filename)
-
- ok = {} # {"state_dict": {}}
-
- conv_func = _g_precision_func[precision]
-
- def _hf(wk: str, t: torch.Tensor):
- if not isinstance(t, torch.Tensor):
- return
- w_t = check_weight_type(wk)
- conv_t = extra_opt[w_t]
- if conv_t == "convert":
- ok[wk] = conv_func(t)
- elif conv_t == "copy":
- ok[wk] = t
- elif conv_t == "delete":
- return
-
- shared.log.info("Model convert: running")
- if conv_type == "ema-only":
- for k in tqdm.tqdm(state_dict):
- ema_k = "___"
- try:
- ema_k = "model_ema." + k[6:].replace(".", "")
- except Exception:
- pass
- if ema_k in state_dict:
- _hf(k, state_dict[ema_k])
- elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
- _hf(k, state_dict[k])
- elif conv_type == "no-ema":
- for k, v in tqdm.tqdm(state_dict.items()):
- if "model_ema." not in k:
- _hf(k, v)
- else:
- for k, v in tqdm.tqdm(state_dict.items()):
- _hf(k, v)
-
- ok = fix_model(ok, fix_clip=fix_clip)
- output = ""
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- save_name = f"{model_info.model_name}-{precision}"
- if conv_type != "disabled":
- save_name += f"-{conv_type}"
- if custom_name != "":
- save_name = custom_name
- for fmt in checkpoint_formats:
- ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
- _save_name = save_name + ext
- save_path = os.path.join(ckpt_dir, _save_name)
- shared.log.info(f"Model convert saving: {save_path}")
- if fmt == "safetensors":
- safetensors.torch.save_file(ok, save_path)
- else:
- torch.save({"state_dict": ok}, save_path)
- output += f"Checkpoint saved to {save_path} "
- shared.state.end()
- return output
+import os
+import html
+import json
+import time
+import shutil
+
+import torch
+import tqdm
+import gradio as gr
+import safetensors.torch
+from modules.merging.merge import merge_models
+from modules.merging.merge_utils import TRIPLE_METHODS
+
+from modules import shared, images, sd_models, sd_vae, sd_models_config, devices
+
+
+def run_pnginfo(image):
+ if image is None:
+ return '', '', ''
+ geninfo, items = images.read_info_from_image(image)
+ items = {**{'parameters': geninfo}, **items}
+ info = ''
+ for key, text in items.items():
+ if key != 'UserComment':
+            info += f"{html.escape(str(key))} : {html.escape(str(text))}"
+ return '', geninfo, info
+
+
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
+ return res if res != shared.sd_default_config else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+ if cfg is None:
+ return
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+    shared.log.info(f"Copying config: {cfg} -> {checkpoint_filename}")
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+ return tensor
+
+
+def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
+ shared.state.begin('merge')
+ t0 = time.time()
+
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
+ kwargs["models"] = {
+ "model_a": sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None)).filename,
+ "model_b": sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None)).filename,
+ }
+
+ if kwargs.get("primary_model_name", None) in [None, 'None']:
+ return fail("Failed: Merging requires a primary model.")
+ primary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None))
+ if kwargs.get("secondary_model_name", None) in [None, 'None']:
+ return fail("Failed: Merging requires a secondary model.")
+ secondary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None))
+ if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in TRIPLE_METHODS:
+ return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.")
+ tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None
+
+ del kwargs["primary_model_name"]
+ del kwargs["secondary_model_name"]
+ if kwargs.get("tertiary_model_name", None) is not None:
+ kwargs["models"] |= {"model_c": sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)).filename}
+ del kwargs["tertiary_model_name"]
+
+ if hasattr(kwargs, "alpha_base") and hasattr(kwargs, "alpha_in_blocks") and hasattr(kwargs, "alpha_mid_block") and hasattr(kwargs, "alpha_out_blocks"):
+ try:
+ alpha = [float(x) for x in
+ [kwargs["alpha_base"]] + kwargs["alpha_in_blocks"].split(",") + [kwargs["alpha_mid_block"]] + kwargs["alpha_out_blocks"].split(",")]
+ assert len(alpha) == 26 or len(alpha) == 20, "Alpha Block Weights are wrong length (26 or 20 for SDXL) falling back"
+ kwargs["alpha"] = alpha
+ except KeyError as ke:
+ shared.log.warning(f"Merge: Malformed manual block weight: {ke}")
+ elif hasattr(kwargs, "alpha_preset") or hasattr(kwargs, "alpha"):
+ kwargs["alpha"] = kwargs.get("alpha_preset", kwargs["alpha"])
+
+ kwargs.pop("alpha_base", None)
+ kwargs.pop("alpha_in_blocks", None)
+ kwargs.pop("alpha_mid_block", None)
+ kwargs.pop("alpha_out_blocks", None)
+ kwargs.pop("alpha_preset", None)
+
+ if hasattr(kwargs, "beta_base") and hasattr(kwargs, "beta_in_blocks") and hasattr(kwargs, "beta_mid_block") and hasattr(kwargs, "beta_out_blocks"):
+ try:
+ beta = [float(x) for x in
+ [kwargs["beta_base"]] + kwargs["beta_in_blocks"].split(",") + [kwargs["beta_mid_block"]] + kwargs["beta_out_blocks"].split(",")]
+ assert len(beta) == 26 or len(beta) == 20, "Beta Block Weights are wrong length (26 or 20 for SDXL) falling back"
+ kwargs["beta"] = beta
+ except KeyError as ke:
+ shared.log.warning(f"Merge: Malformed manual block weight: {ke}")
+ elif hasattr(kwargs, "beta_preset") or hasattr(kwargs, "beta"):
+ kwargs["beta"] = kwargs.get("beta_preset", kwargs["beta"])
+
+ kwargs.pop("beta_base", None)
+ kwargs.pop("beta_in_blocks", None)
+ kwargs.pop("beta_mid_block", None)
+ kwargs.pop("beta_out_blocks", None)
+ kwargs.pop("beta_preset", None)
+
+ if kwargs["device"] == "gpu":
+ kwargs["device"] = devices.device
+ elif kwargs["device"] == "shuffle":
+ kwargs["device"] = torch.device("cpu")
+ kwargs["work_device"] = devices.device
+ else:
+ kwargs["device"] = torch.device("cpu")
+ if kwargs.pop("unload", False):
+ sd_models.unload_model_weights()
+
+ try:
+ theta_0 = merge_models(**kwargs)
+ except Exception as e:
+ return fail(f"{e}")
+
+ try:
+ theta_0 = theta_0.to_dict() #TensorDict -> Dict if necessary
+ except Exception:
+ pass
+
+ bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None)
+ if bake_in_vae_filename is not None:
+ shared.log.info(f"Merge VAE='{bake_in_vae_filename}'")
+ shared.state.textinfo = 'Merge VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename)
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], kwargs.get("precision", "fp16") == "fp16")
+ del vae_dict
+
+ ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path
+ filename = kwargs.get("custom_name", "Unnamed_Merge")
+ filename += "." + kwargs.get("checkpoint_format", None)
+ output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = "merge saving"
+ metadata = None
+ if kwargs.get("save_metadata", False):
+ metadata = {"format": "pt", "sd_merge_models": {}}
+ merge_recipe = {
+ "type": "SDNext", # indicate this model was merged with webui's built-in merger
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "merge_mode": kwargs.get('merge_mode', None),
+ "alpha": kwargs.get('alpha', None),
+ "beta": kwargs.get('beta', None),
+ "precision": kwargs.get('precision', None),
+ "custom_name": kwargs.get("custom_name", "Unamed_Merge"),
+ }
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ checkpoint_info.calculate_shorthash()
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+ }
+ metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+
+ _, extension = os.path.splitext(output_modelname)
+
+ if os.path.exists(output_modelname) and not kwargs.get("overwrite", False):
+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model already exists: {output_modelname}"]
+ if extension.lower() == ".safetensors":
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+ else:
+ torch.save(theta_0, output_modelname)
+
+ t1 = time.time()
+ shared.log.info(f"Merge complete: saved='{output_modelname}' time={t1-t0:.2f}")
+ sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
+ devices.torch_gc(force=True)
+ shared.state.end()
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model saved to {output_modelname}"]
+
+
+def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
+ vae_conv, others_conv, fix_clip):
+ # position_ids in clip is int64. model_ema.num_updates is int32
+ dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
+ dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
+
+ def conv_fp16(t: torch.Tensor):
+ return t.half() if t.dtype in dtypes_to_fp16 else t
+
+ def conv_bf16(t: torch.Tensor):
+ return t.bfloat16() if t.dtype in dtypes_to_bf16 else t
+
+ def conv_full(t):
+ return t
+
+ _g_precision_func = {
+ "full": conv_full,
+ "fp32": conv_full,
+ "fp16": conv_fp16,
+ "bf16": conv_bf16,
+ }
+
+ def check_weight_type(k: str) -> str:
+ if k.startswith("model.diffusion_model"):
+ return "unet"
+ elif k.startswith("first_stage_model"):
+ return "vae"
+ elif k.startswith("cond_stage_model"):
+ return "clip"
+ return "other"
+
+ def load_model(path):
+ if path.endswith(".safetensors"):
+ m = safetensors.torch.load_file(path, device="cpu")
+ else:
+ m = torch.load(path, map_location="cpu")
+ state_dict = m["state_dict"] if "state_dict" in m else m
+ return state_dict
+
+ def fix_model(model, fix_clip=False):
+ # code from model-toolkit
+ nai_keys = {
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
+ }
+ for k in list(model.keys()):
+ for r in nai_keys:
+ if type(k) == str and k.startswith(r):
+ new_key = k.replace(r, nai_keys[r])
+ model[new_key] = model[k]
+ del model[k]
+ shared.log.warning(f"Model convert: fixed NovelAI error key: {k}")
+ break
+ if fix_clip:
+ i = "cond_stage_model.transformer.text_model.embeddings.position_ids"
+ if i in model:
+ correct = torch.Tensor([list(range(77))]).to(torch.int64)
+ now = model[i].to(torch.int64)
+
+ broken = correct.ne(now)
+ broken = [i for i in range(77) if broken[0][i]]
+ model[i] = correct
+ if len(broken) != 0:
+                    shared.log.warning(f"Model convert: fixed broken CLIP: {broken}")
+
+ return model
+
+ if model == "":
+ return "Error: you must choose a model"
+ if len(checkpoint_formats) == 0:
+ return "Error: at least choose one model save format"
+
+ extra_opt = {
+ "unet": unet_conv,
+ "clip": text_encoder_conv,
+ "vae": vae_conv,
+ "other": others_conv
+ }
+ shared.state.begin('convert')
+ model_info = sd_models.checkpoints_list[model]
+ shared.state.textinfo = f"Loading {model_info.filename}..."
+ shared.log.info(f"Model convert loading: {model_info.filename}")
+ state_dict = load_model(model_info.filename)
+
+ ok = {} # {"state_dict": {}}
+
+ conv_func = _g_precision_func[precision]
+
+ def _hf(wk: str, t: torch.Tensor):
+ if not isinstance(t, torch.Tensor):
+ return
+ w_t = check_weight_type(wk)
+ conv_t = extra_opt[w_t]
+ if conv_t == "convert":
+ ok[wk] = conv_func(t)
+ elif conv_t == "copy":
+ ok[wk] = t
+ elif conv_t == "delete":
+ return
+
+ shared.log.info("Model convert: running")
+ if conv_type == "ema-only":
+ for k in tqdm.tqdm(state_dict):
+ ema_k = "___"
+ try:
+ ema_k = "model_ema." + k[6:].replace(".", "")
+ except Exception:
+ pass
+ if ema_k in state_dict:
+ _hf(k, state_dict[ema_k])
+ elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
+ _hf(k, state_dict[k])
+ elif conv_type == "no-ema":
+ for k, v in tqdm.tqdm(state_dict.items()):
+ if "model_ema." not in k:
+ _hf(k, v)
+ else:
+ for k, v in tqdm.tqdm(state_dict.items()):
+ _hf(k, v)
+
+ ok = fix_model(ok, fix_clip=fix_clip)
+ output = ""
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+ save_name = f"{model_info.model_name}-{precision}"
+ if conv_type != "disabled":
+ save_name += f"-{conv_type}"
+ if custom_name != "":
+ save_name = custom_name
+ for fmt in checkpoint_formats:
+ ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
+ _save_name = save_name + ext
+ save_path = os.path.join(ckpt_dir, _save_name)
+ shared.log.info(f"Model convert saving: {save_path}")
+ if fmt == "safetensors":
+ safetensors.torch.save_file(ok, save_path)
+ else:
+ torch.save({"state_dict": ok}, save_path)
+ output += f"Checkpoint saved to {save_path} "
+ shared.state.end()
+ return output
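Sketch of the per-tensor routing inside run_modelconvert(): the state-dict key prefix decides which component a weight belongs to, and that component's setting ('convert', 'copy' or 'delete') decides what _hf() does with it. A self-contained toy version, with the UI choices assumed for illustration:

import torch

def conv_fp16(t):
    return t.half() if t.dtype in {torch.float32, torch.float64, torch.bfloat16} else t

extra_opt = {"unet": "convert", "vae": "copy", "clip": "convert", "other": "delete"}  # assumed UI choices
state_dict = {"model.diffusion_model.w": torch.zeros(2), "first_stage_model.w": torch.zeros(2)}

converted = {}
for k, t in state_dict.items():
    if k.startswith("model.diffusion_model"):
        component = "unet"
    elif k.startswith("first_stage_model"):
        component = "vae"
    elif k.startswith("cond_stage_model"):
        component = "clip"
    else:
        component = "other"
    if extra_opt[component] == "convert":
        converted[k] = conv_fp16(t)
    elif extra_opt[component] == "copy":
        converted[k] = t
    # 'delete' drops the tensor entirely
print({k: v.dtype for k, v in converted.items()})  # unet weight -> float16, vae weight stays float32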
diff --git a/modules/face_restoration.py b/modules/face_restoration.py
index 55e1033c6..d7fc5d1e9 100644
--- a/modules/face_restoration.py
+++ b/modules/face_restoration.py
@@ -1,17 +1,17 @@
-from modules import shared
-
-
-class FaceRestoration:
- def name(self):
- return "None"
-
- def restore(self, np_image):
- return np_image
-
-
-def restore_faces(np_image):
- face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
- if len(face_restorers) == 0:
- return np_image
- face_restorer = face_restorers[0]
- return face_restorer.restore(np_image)
+from modules import shared
+
+
+class FaceRestoration:
+ def name(self):
+ return "None"
+
+ def restore(self, np_image):
+ return np_image
+
+
+def restore_faces(np_image):
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
+ if len(face_restorers) == 0:
+ return np_image
+ face_restorer = face_restorers[0]
+ return face_restorer.restore(np_image)
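restore_faces() simply picks the first registered restorer whose name() matches shared.opts.face_restoration_model, so a restorer only needs to subclass FaceRestoration and register itself. A minimal sketch, assuming a hypothetical 'MyRestorer' option value:

from modules import face_restoration, shared

class MyRestorer(face_restoration.FaceRestoration):
    def name(self):
        return "MyRestorer"  # must match the value selected in shared.opts.face_restoration_model

    def restore(self, np_image):
        # a real implementation would run its face-restoration model here
        return np_image

shared.face_restorers.append(MyRestorer())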
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 9c5e50063..6ee979a7c 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,383 +1,383 @@
-import base64
-import io
-import os
-import re
-import json
-from PIL import Image
-import gradio as gr
-from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, images
-
-
-re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
-re_param = re.compile(re_param_code)
-re_imagesize = re.compile(r"^(\d+)x(\d+)$")
-re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") # pylint: disable=anomalous-backslash-in-string
-type_of_gr_update = type(gr.update())
-paste_fields = {}
-registered_param_bindings = []
-debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PASTE')
-
-
-class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
- self.paste_button = paste_button
- self.tabname = tabname
- self.source_text_component = source_text_component
- self.source_image_component = source_image_component
- self.source_tabname = source_tabname
- self.override_settings_component = override_settings_component
- self.paste_field_names = paste_field_names or []
- debug(f'ParamBinding: {vars(self)}')
-
-
-def reset():
- paste_fields.clear()
-
-
-def quote(text):
- if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
- return text
- return json.dumps(text, ensure_ascii=False)
-
-
-def unquote(text):
- if len(text) == 0 or text[0] != '"' or text[-1] != '"':
- return text
- try:
- return json.loads(text)
- except Exception:
- return text
-
-
-def image_from_url_text(filedata):
- if filedata is None:
- return None
- if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
- filedata = filedata[0]
- if type(filedata) == dict and filedata.get("is_file", False):
- filename = filedata["name"]
- is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
- if is_in_right_dir:
- filename = filename.rsplit('?', 1)[0]
- if not os.path.exists(filename):
- shared.log.error(f'Image file not found: {filename}')
- image = Image.new('RGB', (512, 512))
- image.info['parameters'] = f'Image file not found: {filename}'
- return image
- image = Image.open(filename)
- geninfo, _items = images.read_info_from_image(image)
- image.info['parameters'] = geninfo
- return image
- else:
- shared.log.warning(f'File access denied: {filename}')
- return None
- if type(filedata) == list:
- if len(filedata) == 0:
- return None
- filedata = filedata[0]
- if type(filedata) == dict:
- shared.log.warning('Incorrect filedata received')
- return None
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
- if filedata.startswith("data:image/webp;base64,"):
- filedata = filedata[len("data:image/webp;base64,"):]
- if filedata.startswith("data:image/jpeg;base64,"):
- filedata = filedata[len("data:image/jpeg;base64,"):]
- filedata = base64.decodebytes(filedata.encode('utf-8'))
- image = Image.open(io.BytesIO(filedata))
- images.read_info_from_image(image)
- return image
-
-
-def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
- paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
- # backwards compatibility for existing extensions
- import modules.ui
- if tabname == 'txt2img':
- modules.ui.txt2img_paste_fields = fields
- elif tabname == 'img2img':
- modules.ui.img2img_paste_fields = fields
-
-
-def create_buttons(tabs_list):
- buttons = {}
- for tab in tabs_list:
- name = tab
- if name == 'txt2img':
- name = 'Text'
- elif name == 'img2img':
- name = 'Image'
- elif name == 'inpaint':
- name = 'Inpaint'
- elif name == 'extras':
- name = 'Process'
- elif name == 'control':
- name = 'Control'
- buttons[tab] = gr.Button(f"➠ {name}", elem_id=f"{tab}_tab")
- return buttons
-
-
-def bind_buttons(buttons, send_image, send_generate_info):
- """old function for backwards compatibility; do not use this, use register_paste_params_button"""
- for tabname, button in buttons.items():
- source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
- source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
- bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)
- register_paste_params_button(bindings)
-
-
-def register_paste_params_button(binding: ParamBinding):
- registered_param_bindings.append(binding)
-
-
-def connect_paste_params_buttons():
- binding: ParamBinding
- for binding in registered_param_bindings:
- if binding.tabname not in paste_fields:
-            debug(f"Not registered: tab={binding.tabname}")
- continue
- destination_image_component = paste_fields[binding.tabname]["init_img"]
- fields = paste_fields[binding.tabname]["fields"]
- override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
- destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
- destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
-
- if binding.source_image_component and destination_image_component:
- if isinstance(binding.source_image_component, gr.Gallery):
- func = send_image_and_dimensions if destination_width_component else image_from_url_text
- jsfunc = "extract_image_from_gallery"
- else:
- func = send_image_and_dimensions if destination_width_component else lambda x: x
- jsfunc = None
- binding.paste_button.click(
- fn=func,
- _js=jsfunc,
- inputs=[binding.source_image_component],
- outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
- show_progress=False,
- )
- if binding.source_text_component is not None and fields is not None:
- connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
- if binding.source_tabname is not None and fields is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
- binding.paste_button.click(
- fn=lambda *x: x,
- inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
- outputs=[field for field, name in fields if name in paste_field_names],
- )
- binding.paste_button.click(
- fn=None,
- _js=f"switch_to_{binding.tabname}",
- inputs=[],
- outputs=[],
- show_progress=False,
- )
-
-
-def send_image_and_dimensions(x):
- img = x if isinstance(x, Image.Image) else image_from_url_text(x)
- if shared.opts.send_size and isinstance(img, Image.Image):
- w = img.width
- h = img.height
- else:
- w = gr.update()
- h = gr.update()
- return img, w, h
-
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
-def parse_generation_parameters(x: str):
- res = {}
- if x is None:
- return res
- remaining = x.replace('\n', ' ').strip()
- if len(remaining) == 0:
- return res
- remaining = x[7:] if x.startswith('Prompt: ') else x
- remaining = x[11:] if x.startswith('parameters: ') else x
- if 'Steps: ' in remaining and 'Negative prompt: ' not in remaining:
- remaining = remaining.replace('Steps: ', 'Negative prompt: Steps: ')
- prompt, remaining = remaining.strip().split('Negative prompt: ', maxsplit=1) if 'Negative prompt: ' in remaining else (remaining, '')
- res["Prompt"] = prompt.strip()
- negative, remaining = remaining.strip().split('Steps: ', maxsplit=1) if 'Steps: ' in remaining else (remaining, None)
- res["Negative prompt"] = negative.strip()
- if remaining is None:
- return res
- remaining = f'Steps: {remaining}'
- for k, v in re_param.findall(remaining.strip()):
- try:
- if v[0] == '"' and v[-1] == '"':
- v = unquote(v)
- m = re_imagesize.match(v)
- if m is not None:
- res[f"{k}-1"] = m.group(1)
- res[f"{k}-2"] = m.group(2)
- else:
- res[k] = v
- except Exception:
- pass
- if res.get('VAE', None) == 'TAESD':
- res["Full quality"] = False
- debug(f"Parse prompt: {res}")
- return res
-
-
-settings_map = {}
-
-
-infotext_to_setting_name_mapping = [
- ('Backend', 'sd_backend'),
- ('Model hash', 'sd_model_checkpoint'),
- ('Refiner', 'sd_model_refiner'),
- ('VAE', 'sd_vae'),
- ('Parser', 'prompt_attention'),
- ('Color correction', 'img2img_color_correction'),
- # Samplers
- ('Sampler Eta', 'scheduler_eta'),
- ('Sampler ENSD', 'eta_noise_seed_delta'),
- ('Sampler order', 'schedulers_solver_order'),
- # Samplers diffusers
- ('Sampler beta schedule', 'schedulers_beta_schedule'),
- ('Sampler beta start', 'schedulers_beta_start'),
- ('Sampler beta end', 'schedulers_beta_end'),
- ('Sampler DPM solver', 'schedulers_dpm_solver'),
- # Samplers original
- ('Sampler brownian', 'schedulers_brownian_noise'),
- ('Sampler discard', 'schedulers_discard_penultimate'),
- ('Sampler dyn threshold', 'schedulers_use_thresholding'),
- ('Sampler karras', 'schedulers_use_karras'),
- ('Sampler low order', 'schedulers_use_loworder'),
- ('Sampler quantization', 'enable_quantization'),
- ('Sampler sigma', 'schedulers_sigma'),
- ('Sampler sigma min', 's_min'),
- ('Sampler sigma max', 's_max'),
- ('Sampler sigma churn', 's_churn'),
- ('Sampler sigma uncond', 's_min_uncond'),
- ('Sampler sigma noise', 's_noise'),
- ('Sampler sigma tmin', 's_tmin'),
- ('Sampler ENSM', 'initial_noise_multiplier'), # img2img only
- ('UniPC skip type', 'uni_pc_skip_type'),
- ('UniPC variant', 'uni_pc_variant'),
- # Token Merging
- ('Mask weight', 'inpainting_mask_weight'),
- ('Token merging ratio', 'token_merging_ratio'),
- ('ToMe', 'token_merging_ratio'),
- ('ToMe hires', 'token_merging_ratio_hr'),
- ('ToMe img2img', 'token_merging_ratio_img2img'),
-]
-
-
-def create_override_settings_dict(text_pairs):
- res = {}
- params = {}
- for pair in text_pairs:
- k, v = pair.split(":", maxsplit=1)
- params[k] = v.strip()
- for param_name, setting_name in infotext_to_setting_name_mapping:
- value = params.get(param_name, None)
- if value is None:
- continue
- res[setting_name] = shared.opts.cast_value(setting_name, value)
- return res
-
-
-def connect_paste(button, local_paste_fields, input_comp, override_settings_component, tabname):
-
- def paste_func(prompt):
- if prompt is None or len(prompt.strip()) == 0 and not shared.cmd_opts.hide_ui_dir_config:
- filename = os.path.join(data_path, "params.txt")
- if os.path.exists(filename):
- with open(filename, "r", encoding="utf8") as file:
- prompt = file.read()
- shared.log.debug(f'Paste prompt: type="params" prompt="{prompt}"')
- else:
- prompt = ''
- else:
- shared.log.debug(f'Paste prompt: type="current" prompt="{prompt}"')
- params = parse_generation_parameters(prompt)
- script_callbacks.infotext_pasted_callback(prompt, params)
- res = []
- applied = {}
- for output, key in local_paste_fields:
- if callable(key):
- v = key(params)
- else:
- v = params.get(key, None)
- if v is None:
- res.append(gr.update())
- elif isinstance(v, type_of_gr_update):
- res.append(v)
- applied[key] = v
- else:
- try:
- valtype = type(output.value)
- if valtype == bool and v == "False":
- val = False
- else:
- val = valtype(v)
- res.append(gr.update(value=val))
- applied[key] = val
- except Exception:
- res.append(gr.update())
- debug(f"Parse apply: {applied}")
- return res
-
- if override_settings_component is not None:
- def paste_settings(params):
- vals = {}
- for param_name, setting_name in infotext_to_setting_name_mapping:
- v = params.get(param_name, None)
- if v is None:
- continue
- if shared.opts.disable_weights_auto_swap:
- if setting_name == "sd_model_checkpoint" or setting_name == 'sd_model_refiner' or setting_name == 'sd_backend' or setting_name == 'sd_vae':
- continue
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
- if v == current_value:
- continue
- if type(current_value) == str and v == os.path.splitext(current_value)[0]:
- continue
- vals[param_name] = v
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
- shared.log.debug(f'Settings overrides: {vals_pairs}')
- return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
- local_paste_fields = local_paste_fields + [(override_settings_component, paste_settings)]
-
- button.click(
- fn=paste_func,
- inputs=[input_comp],
- outputs=[x[0] for x in local_paste_fields],
- show_progress=False,
- )
- button.click(
- fn=None,
- _js=f"recalculate_prompts_{tabname}",
- inputs=[],
- outputs=[],
- show_progress=False,
- )
+import base64
+import io
+import os
+import re
+import json
+from PIL import Image
+import gradio as gr
+from modules.paths import data_path
+from modules import shared, ui_tempdir, script_callbacks, images
+
+
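+# regexes used when parsing an infotext: "Key: value" pairs (values may be quoted), "WxH" image sizes, and a trailing "(hash)" on hypernetwork names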
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
+re_param = re.compile(re_param_code)
+re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
+type_of_gr_update = type(gr.update())
+paste_fields = {}
+registered_param_bindings = []
+debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PASTE')
+
+
+class ParamBinding:
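+ # describes a single paste button: the button itself, its destination tab, and the optional text/image/tab sources and override-settings component it is wired to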
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
+ self.paste_button = paste_button
+ self.tabname = tabname
+ self.source_text_component = source_text_component
+ self.source_image_component = source_image_component
+ self.source_tabname = source_tabname
+ self.override_settings_component = override_settings_component
+ self.paste_field_names = paste_field_names or []
+ debug(f'ParamBinding: {vars(self)}')
+
+
+def reset():
+ paste_fields.clear()
+
+
+def quote(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
+ return text
+ return json.dumps(text, ensure_ascii=False)
+
+
+def unquote(text):
+ if len(text) == 0 or text[0] != '"' or text[-1] != '"':
+ return text
+ try:
+ return json.loads(text)
+ except Exception:
+ return text
+
+
+def image_from_url_text(filedata):
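+ # handles gradio gallery/file payloads (dict, list of dicts) as well as base64 data-URL strings and returns a PIL image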
+ if filedata is None:
+ return None
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+ if type(filedata) == dict and filedata.get("is_file", False):
+ filename = filedata["name"]
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
+ if is_in_right_dir:
+ filename = filename.rsplit('?', 1)[0]
+ if not os.path.exists(filename):
+ shared.log.error(f'Image file not found: {filename}')
+ image = Image.new('RGB', (512, 512))
+ image.info['parameters'] = f'Image file not found: {filename}'
+ return image
+ image = Image.open(filename)
+ geninfo, _items = images.read_info_from_image(image)
+ image.info['parameters'] = geninfo
+ return image
+ else:
+ shared.log.warning(f'File access denied: {filename}')
+ return None
+ if type(filedata) == list:
+ if len(filedata) == 0:
+ return None
+ filedata = filedata[0]
+ if type(filedata) == dict:
+ shared.log.warning('Incorrect filedata received')
+ return None
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+ if filedata.startswith("data:image/webp;base64,"):
+ filedata = filedata[len("data:image/webp;base64,"):]
+ if filedata.startswith("data:image/jpeg;base64,"):
+ filedata = filedata[len("data:image/jpeg;base64,"):]
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
+ image = Image.open(io.BytesIO(filedata))
+ images.read_info_from_image(image)
+ return image
+
+
+def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
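+ # registers the paste targets for a tab: its image component, its infotext fields and the optional override-settings dropdown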
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
+ # backwards compatibility for existing extensions
+ import modules.ui
+ if tabname == 'txt2img':
+ modules.ui.txt2img_paste_fields = fields
+ elif tabname == 'img2img':
+ modules.ui.img2img_paste_fields = fields
+
+
+def create_buttons(tabs_list):
+ buttons = {}
+ for tab in tabs_list:
+ name = tab
+ if name == 'txt2img':
+ name = 'Text'
+ elif name == 'img2img':
+ name = 'Image'
+ elif name == 'inpaint':
+ name = 'Inpaint'
+ elif name == 'extras':
+ name = 'Process'
+ elif name == 'control':
+ name = 'Control'
+ buttons[tab] = gr.Button(f"➠ {name}", elem_id=f"{tab}_tab")
+ return buttons
+
+
+def bind_buttons(buttons, send_image, send_generate_info):
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
+ for tabname, button in buttons.items():
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
+ bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)
+ register_paste_params_button(bindings)
+
+
+def register_paste_params_button(binding: ParamBinding):
+ registered_param_bindings.append(binding)
+
+
+def connect_paste_params_buttons():
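+ # wires every registered paste button: send the image (and optionally its size), paste the infotext fields, then switch to the destination tab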
+ binding: ParamBinding
+ for binding in registered_param_bindings:
+ if binding.tabname not in paste_fields:
+ debug(f"Not not registered: tab={binding.tabname}")
+ continue
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
+ fields = paste_fields[binding.tabname]["fields"]
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if binding.source_image_component and destination_image_component:
+ if isinstance(binding.source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
+ else:
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+ binding.paste_button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[binding.source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ show_progress=False,
+ )
+ if binding.source_text_component is not None and fields is not None:
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
+ if binding.source_tabname is not None and fields is not None:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
+ binding.paste_button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
+ )
+ binding.paste_button.click(
+ fn=None,
+ _js=f"switch_to_{binding.tabname}",
+ inputs=[],
+ outputs=[],
+ show_progress=False,
+ )
+
+
+def send_image_and_dimensions(x):
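+ # returns the image plus width/height updates so the destination tab can pick up the source dimensions when send_size is enabled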
+ img = x if isinstance(x, Image.Image) else image_from_url_text(x)
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+ return img, w, h
+
+
+def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
+ """
+ hypernet_name = hypernet_name.lower()
+ if hypernet_hash is not None:
+ # Try to match the hash in the name
+ for hypernet_key in shared.hypernetworks.keys():
+ result = re_hypernet_hash.search(hypernet_key)
+ if result is not None and result[1] == hypernet_hash:
+ return hypernet_key
+ else:
+ # Fall back to a hypernet with the same name
+ for hypernet_key in shared.hypernetworks.keys():
+ if hypernet_key.lower().startswith(hypernet_name):
+ return hypernet_key
+
+ return None
+
+
+def parse_generation_parameters(x: str):
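+ # splits an infotext into prompt, negative prompt and the individual "Key: value" generation parameters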
+ res = {}
+ if x is None:
+ return res
+ remaining = x.replace('\n', ' ').strip()
+ if len(remaining) == 0:
+ return res
+ remaining = remaining[len('Prompt: '):] if remaining.startswith('Prompt: ') else remaining
+ remaining = remaining[len('parameters: '):] if remaining.startswith('parameters: ') else remaining
+ if 'Steps: ' in remaining and 'Negative prompt: ' not in remaining:
+ remaining = remaining.replace('Steps: ', 'Negative prompt: Steps: ')
+ prompt, remaining = remaining.strip().split('Negative prompt: ', maxsplit=1) if 'Negative prompt: ' in remaining else (remaining, '')
+ res["Prompt"] = prompt.strip()
+ negative, remaining = remaining.strip().split('Steps: ', maxsplit=1) if 'Steps: ' in remaining else (remaining, None)
+ res["Negative prompt"] = negative.strip()
+ if remaining is None:
+ return res
+ remaining = f'Steps: {remaining}'
+ for k, v in re_param.findall(remaining.strip()):
+ try:
+ if v[0] == '"' and v[-1] == '"':
+ v = unquote(v)
+ m = re_imagesize.match(v)
+ if m is not None:
+ res[f"{k}-1"] = m.group(1)
+ res[f"{k}-2"] = m.group(2)
+ else:
+ res[k] = v
+ except Exception:
+ pass
+ if res.get('VAE', None) == 'TAESD':
+ res["Full quality"] = False
+ debug(f"Parse prompt: {res}")
+ return res
+
+
+settings_map = {}
+
+
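+# maps infotext parameter names to the shared.opts setting that is overridden when the infotext is pasted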
+infotext_to_setting_name_mapping = [
+ ('Backend', 'sd_backend'),
+ ('Model hash', 'sd_model_checkpoint'),
+ ('Refiner', 'sd_model_refiner'),
+ ('VAE', 'sd_vae'),
+ ('Parser', 'prompt_attention'),
+ ('Color correction', 'img2img_color_correction'),
+ # Samplers
+ ('Sampler Eta', 'scheduler_eta'),
+ ('Sampler ENSD', 'eta_noise_seed_delta'),
+ ('Sampler order', 'schedulers_solver_order'),
+ # Samplers diffusers
+ ('Sampler beta schedule', 'schedulers_beta_schedule'),
+ ('Sampler beta start', 'schedulers_beta_start'),
+ ('Sampler beta end', 'schedulers_beta_end'),
+ ('Sampler DPM solver', 'schedulers_dpm_solver'),
+ # Samplers original
+ ('Sampler brownian', 'schedulers_brownian_noise'),
+ ('Sampler discard', 'schedulers_discard_penultimate'),
+ ('Sampler dyn threshold', 'schedulers_use_thresholding'),
+ ('Sampler karras', 'schedulers_use_karras'),
+ ('Sampler low order', 'schedulers_use_loworder'),
+ ('Sampler quantization', 'enable_quantization'),
+ ('Sampler sigma', 'schedulers_sigma'),
+ ('Sampler sigma min', 's_min'),
+ ('Sampler sigma max', 's_max'),
+ ('Sampler sigma churn', 's_churn'),
+ ('Sampler sigma uncond', 's_min_uncond'),
+ ('Sampler sigma noise', 's_noise'),
+ ('Sampler sigma tmin', 's_tmin'),
+ ('Sampler ENSM', 'initial_noise_multiplier'), # img2img only
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC variant', 'uni_pc_variant'),
+ # Inpainting and Token Merging
+ ('Mask weight', 'inpainting_mask_weight'),
+ ('Token merging ratio', 'token_merging_ratio'),
+ ('ToMe', 'token_merging_ratio'),
+ ('ToMe hires', 'token_merging_ratio_hr'),
+ ('ToMe img2img', 'token_merging_ratio_img2img'),
+]
+
+
+def create_override_settings_dict(text_pairs):
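+ # converts "Param: value" strings from the override-settings dropdown back into a {setting_name: typed_value} dict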
+ res = {}
+ params = {}
+ for pair in text_pairs:
+ k, v = pair.split(":", maxsplit=1)
+ params[k] = v.strip()
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ value = params.get(param_name, None)
+ if value is None:
+ continue
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
+ return res
+
+
+def connect_paste(button, local_paste_fields, input_comp, override_settings_component, tabname):
+
+ def paste_func(prompt):
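+ # an empty prompt falls back to reading params.txt from the data path, unless hide_ui_dir_config is set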
+ if (prompt is None or len(prompt.strip()) == 0) and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(data_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+ shared.log.debug(f'Paste prompt: type="params" prompt="{prompt}"')
+ else:
+ prompt = ''
+ else:
+ shared.log.debug(f'Paste prompt: type="current" prompt="{prompt}"')
+ params = parse_generation_parameters(prompt)
+ script_callbacks.infotext_pasted_callback(prompt, params)
+ res = []
+ applied = {}
+ for output, key in local_paste_fields:
+ if callable(key):
+ v = key(params)
+ else:
+ v = params.get(key, None)
+ if v is None:
+ res.append(gr.update())
+ elif isinstance(v, type_of_gr_update):
+ res.append(v)
+ applied[key] = v
+ else:
+ try:
+ valtype = type(output.value)
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+ res.append(gr.update(value=val))
+ applied[key] = val
+ except Exception:
+ res.append(gr.update())
+ debug(f"Parse apply: {applied}")
+ return res
+
+ if override_settings_component is not None:
+ def paste_settings(params):
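+ # collects only the settings that differ from the current options, skipping model/refiner/backend/VAE swaps when disable_weights_auto_swap is set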
+ vals = {}
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+ if shared.opts.disable_weights_auto_swap:
+ if setting_name in ("sd_model_checkpoint", "sd_model_refiner", "sd_backend", "sd_vae"):
+ continue
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+ if v == current_value:
+ continue
+ if type(current_value) == str and v == os.path.splitext(current_value)[0]:
+ continue
+ vals[param_name] = v
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+ shared.log.debug(f'Settings overrides: {vals_pairs}')
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+ local_paste_fields = local_paste_fields + [(override_settings_component, paste_settings)]
+
+ button.click(
+ fn=paste_func,
+ inputs=[input_comp],
+ outputs=[x[0] for x in local_paste_fields],
+ show_progress=False,
+ )
+ button.click(
+ fn=None,
+ _js=f"recalculate_prompts_{tabname}",
+ inputs=[],
+ outputs=[],
+ show_progress=False,
+ )
diff --git a/modules/hashes.py b/modules/hashes.py
index 892368d3c..607140fe6 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -1,107 +1,107 @@
-import copy
-import hashlib
-import os.path
-from rich import progress, errors
-from modules import shared
-from modules.paths import data_path
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-progress_ok = True
-
-def dump_cache():
- shared.writefile(cache_data, cache_filename)
-
-
-def cache(subsection):
- global cache_data # pylint: disable=global-statement
- if cache_data is None:
- cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
- return s
-
-
-def calculate_sha256(filename, quiet=False):
- global progress_ok # pylint: disable=global-statement
- hash_sha256 = hashlib.sha256()
- blksize = 1024 * 1024
- if not quiet:
- if progress_ok:
- try:
- with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=shared.console) as f:
- for chunk in iter(lambda: f.read(blksize), b""):
- hash_sha256.update(chunk)
- except errors.LiveError:
- shared.log.warning('Hash: attempting to use function in a thread')
- progress_ok = False
- if not progress_ok:
- with open(filename, 'rb') as f:
- for chunk in iter(lambda: f.read(blksize), b""):
- hash_sha256.update(chunk)
- else:
- with open(filename, 'rb') as f:
- for chunk in iter(lambda: f.read(blksize), b""):
- hash_sha256.update(chunk)
- return hash_sha256.hexdigest()
-
-
-def sha256_from_cache(filename, title, use_addnet_hash=False):
- hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- if title not in hashes:
- return None
- cached_sha256 = hashes[title].get("sha256", None)
- cached_mtime = hashes[title].get("mtime", 0)
- ondisk_mtime = os.path.getmtime(filename) if os.path.isfile(filename) else 0
- if ondisk_mtime > cached_mtime or cached_sha256 is None:
- return None
- return cached_sha256
-
-
-def sha256(filename, title, use_addnet_hash=False):
- global progress_ok # pylint: disable=global-statement
- hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
- if sha256_value is not None:
- return sha256_value
- if shared.cmd_opts.no_hashing:
- return None
- if not os.path.isfile(filename):
- return None
- orig_state = copy.deepcopy(shared.state)
- shared.state.begin("hash")
- if use_addnet_hash:
- if progress_ok:
- try:
- with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=shared.console) as f:
- sha256_value = addnet_hash_safetensors(f)
- except errors.LiveError:
- shared.log.warning('Hash: attempting to use function in a thread')
- progress_ok = False
- if not progress_ok:
- with open(filename, 'rb') as f:
- sha256_value = addnet_hash_safetensors(f)
- else:
- sha256_value = calculate_sha256(filename)
- hashes[title] = {
- "mtime": os.path.getmtime(filename),
- "sha256": sha256_value
- }
- shared.state.end()
- shared.state = orig_state
- dump_cache()
- return sha256_value
-
-
-def addnet_hash_safetensors(b):
- """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
- hash_sha256 = hashlib.sha256()
- blksize = 1024 * 1024
- b.seek(0)
- header = b.read(8)
- n = int.from_bytes(header, "little")
- offset = n + 8
- b.seek(offset)
- for chunk in iter(lambda: b.read(blksize), b""):
- hash_sha256.update(chunk)
- return hash_sha256.hexdigest()
+import copy
+import hashlib
+import os.path
+from rich import progress, errors
+from modules import shared
+from modules.paths import data_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+progress_ok = True
+
+def dump_cache():
+ shared.writefile(cache_data, cache_filename)
+
+
+def cache(subsection):
+ global cache_data # pylint: disable=global-statement
+ if cache_data is None:
+ cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+ return s
+
+
+def calculate_sha256(filename, quiet=False):
+ global progress_ok # pylint: disable=global-statement
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ if not quiet:
+ if progress_ok:
+ try:
+ with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=shared.console) as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+ except errors.LiveError:
+ shared.log.warning('Hash: attempting to use function in a thread')
+ progress_ok = False
+ if not progress_ok:
+ with open(filename, 'rb') as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+ else:
+ with open(filename, 'rb') as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+ return hash_sha256.hexdigest()
+
+
+def sha256_from_cache(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
+ if title not in hashes:
+ return None
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+ ondisk_mtime = os.path.getmtime(filename) if os.path.isfile(filename) else 0
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+ return cached_sha256
+
+
+def sha256(filename, title, use_addnet_hash=False):
+ global progress_ok # pylint: disable=global-statement
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
+ if sha256_value is not None:
+ return sha256_value
+ if shared.cmd_opts.no_hashing:
+ return None
+ if not os.path.isfile(filename):
+ return None
+ orig_state = copy.deepcopy(shared.state)
+ shared.state.begin("hash")
+ if use_addnet_hash:
+ if progress_ok:
+ try:
+ with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=shared.console) as f:
+ sha256_value = addnet_hash_safetensors(f)
+ except errors.LiveError:
+ shared.log.warning('Hash: attempting to use function in a thread')
+ progress_ok = False
+ if not progress_ok:
+ with open(filename, 'rb') as f:
+ sha256_value = addnet_hash_safetensors(f)
+ else:
+ sha256_value = calculate_sha256(filename)
+ hashes[title] = {
+ "mtime": os.path.getmtime(filename),
+ "sha256": sha256_value
+ }
+ shared.state.end()
+ shared.state = orig_state
+ dump_cache()
+ return sha256_value
+
+
+def addnet_hash_safetensors(b):
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+ return hash_sha256.hexdigest()
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 54f82c36a..5afa572be 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,755 +1,755 @@
-import datetime
-import html
-import os
-from collections import deque
-import inspect
-from statistics import stdev, mean
-from rich import progress
-import tqdm
-import torch
-from torch import einsum
-from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-from einops import rearrange, repeat
-from ldm.util import default
-from modules import devices, processing, sd_models, shared, hashes, errors
-import modules.textual_inversion.dataset
-from modules.textual_inversion import textual_inversion, ti_logging
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
-
-
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
-class HypernetworkModule(torch.nn.Module):
- activation_dict = {
- "linear": torch.nn.Identity,
- "relu": torch.nn.ReLU,
- "leakyrelu": torch.nn.LeakyReLU,
- "elu": torch.nn.ELU,
- "swish": torch.nn.Hardswish,
- "tanh": torch.nn.Tanh,
- "sigmoid": torch.nn.Sigmoid,
- }
- activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
-
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
- add_layer_norm=False, activate_output=False, dropout_structure=None):
- super().__init__()
- self.multiplier = 1.0
- assert layer_structure is not None, "layer_structure must not be None"
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- linears = []
- for i in range(len(layer_structure) - 1):
- # Add a fully-connected layer
- linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # Add an activation func except last layer
- if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
- pass
- elif activation_func in self.activation_dict:
- linears.append(self.activation_dict[activation_func]())
- else:
- raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
- # Add layer normalization
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Everything should be now parsed into dropout structure, and applied here.
- # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
- if dropout_structure is not None and dropout_structure[i+1] > 0:
- assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
- linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
- # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
- self.linear = torch.nn.Sequential(*linears)
- if state_dict is not None:
- self.fix_old_state_dict(state_dict)
- self.load_state_dict(state_dict)
- else:
- for layer in self.linear:
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- w, b = layer.weight.data, layer.bias.data
- if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
- normal_(w, mean=0.0, std=0.01)
- normal_(b, mean=0.0, std=0)
- elif weight_init == 'XavierUniform':
- xavier_uniform_(w)
- zeros_(b)
- elif weight_init == 'XavierNormal':
- xavier_normal_(w)
- zeros_(b)
- elif weight_init == 'KaimingUniform':
- kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
- zeros_(b)
- elif weight_init == 'KaimingNormal':
- kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
- zeros_(b)
- else:
- raise KeyError(f"Key {weight_init} is not defined as initialization!")
- self.to(devices.device)
-
- def fix_old_state_dict(self, state_dict):
- changes = {
- 'linear1.bias': 'linear.0.bias',
- 'linear1.weight': 'linear.0.weight',
- 'linear2.bias': 'linear.1.bias',
- 'linear2.weight': 'linear.1.weight',
- }
- for fr, to in changes.items():
- x = state_dict.get(fr, None)
- if x is None:
- continue
- del state_dict[fr]
- state_dict[to] = x
-
- def forward(self, x):
- return x + self.linear(x) * (self.multiplier if not self.training else 1)
-
- def trainables(self):
- layer_structure = []
- for layer in self.linear:
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- layer_structure += [layer.weight, layer.bias]
- return layer_structure
-
-
-#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
-def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
- if layer_structure is None:
- layer_structure = [1, 2, 1]
- if not use_dropout:
- return [0] * len(layer_structure)
- dropout_values = [0]
- dropout_values.extend([0.3] * (len(layer_structure) - 3))
- if last_layer_dropout:
- dropout_values.append(0.3)
- else:
- dropout_values.append(0)
- dropout_values.append(0)
- return dropout_values
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
- self.filename = None
- self.name = name
- self.layers = {}
- self.step = 0
- self.sd_checkpoint = None
- self.sd_checkpoint_name = None
- self.layer_structure = layer_structure
- self.activation_func = activation_func
- self.weight_init = weight_init
- self.add_layer_norm = add_layer_norm
- self.use_dropout = use_dropout
- self.activate_output = activate_output
- self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
- self.dropout_structure = kwargs.get('dropout_structure', None)
- if self.dropout_structure is None:
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- self.optimizer_name = None
- self.optimizer_state_dict = None
- self.optional_info = None
- for size in enable_sizes or []:
- self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
- )
- self.eval()
-
- def weights(self):
- res = []
- for layers in self.layers.values():
- for layer in layers:
- res += layer.parameters()
- return res
-
- def train(self, mode=True):
- for layers in self.layers.values():
- for layer in layers:
- layer.train(mode=mode)
- for param in layer.parameters():
- param.requires_grad = mode
-
- def to(self, device):
- for layers in self.layers.values():
- for layer in layers:
- layer.to(device)
-
- return self
-
- def set_multiplier(self, multiplier):
- for layers in self.layers.values():
- for layer in layers:
- layer.multiplier = multiplier
-
- return self
-
- def eval(self):
- for layers in self.layers.values():
- for layer in layers:
- layer.eval()
- for param in layer.parameters():
- param.requires_grad = False
-
- def save(self, filename):
- state_dict = {}
- optimizer_saved_dict = {}
- for k, v in self.layers.items():
- state_dict[k] = (v[0].state_dict(), v[1].state_dict())
- state_dict['step'] = self.step
- state_dict['name'] = self.name
- state_dict['layer_structure'] = self.layer_structure
- state_dict['activation_func'] = self.activation_func
- state_dict['is_layer_norm'] = self.add_layer_norm
- state_dict['weight_initialization'] = self.weight_init
- state_dict['sd_checkpoint'] = self.sd_checkpoint
- state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
- state_dict['activate_output'] = self.activate_output
- state_dict['use_dropout'] = self.use_dropout
- state_dict['dropout_structure'] = self.dropout_structure
- state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
- state_dict['optional_info'] = self.optional_info if self.optional_info else None
- if self.optimizer_name is not None:
- optimizer_saved_dict['optimizer_name'] = self.optimizer_name
- torch.save(state_dict, filename)
- if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = self.shorthash()
- optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
- torch.save(optimizer_saved_dict, f"{filename}.optim")
-
- def load(self, filename):
- self.filename = filename if os.path.exists(filename) else os.path.join(shared.opts.hypernetwork_dir, filename)
- if self.name is None:
- self.name = os.path.splitext(os.path.basename(self.filename))[0]
- with progress.open(self.filename, 'rb', description=f'Load hypernetwork: [cyan]{self.filename}', auto_refresh=True, console=shared.console) as f:
- state_dict = torch.load(f, map_location='cpu')
- self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- self.optional_info = state_dict.get('optional_info', None)
- self.activation_func = state_dict.get('activation_func', None)
- self.weight_init = state_dict.get('weight_initialization', 'Normal')
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
- self.dropout_structure = state_dict.get('dropout_structure', None)
- self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- self.activate_output = state_dict.get('activate_output', True)
- self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
- # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
- if self.dropout_structure is None:
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- if shared.opts.print_hypernet_extra:
- if self.optional_info is not None:
- print(f" INFO:\n {self.optional_info}\n")
- print(f" Layer structure: {self.layer_structure}")
- print(f" Activation function: {self.activation_func}")
- print(f" Weight initialization: {self.weight_init}")
- print(f" Layer norm: {self.add_layer_norm}")
- print(f" Dropout usage: {self.use_dropout}" )
- print(f" Activate last layer: {self.activate_output}")
- print(f" Dropout structure: {self.dropout_structure}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
- if self.shorthash() == optimizer_saved_dict.get('hash', None):
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
- else:
- self.optimizer_state_dict = None
- if self.optimizer_state_dict:
- self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- if shared.opts.print_hypernet_extra:
- print("Load existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
- else:
- self.optimizer_name = "AdamW"
- if shared.opts.print_hypernet_extra:
- print("No saved optimizer exists in checkpoint")
- for size, sd in state_dict.items():
- if type(size) == int:
- self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, self.dropout_structure),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, self.dropout_structure),
- )
- self.name = state_dict.get('name', self.name)
- self.step = state_dict.get('step', 0)
- self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
- self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
- self.eval()
-
- def shorthash(self):
- sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
- return sha256[0:10] if sha256 else None
-
-
-def list_hypernetworks(path):
- res = {}
- def list_folder(folder):
- for filename in os.listdir(folder):
- fn = os.path.join(folder, filename)
- if os.path.isfile(fn) and fn.lower().endswith(".pt"):
- name = os.path.splitext(os.path.basename(fn))[0]
- res[name] = fn
- elif os.path.isdir(fn) and not fn.startswith('.'):
- list_folder(fn)
-
- list_folder(path)
- return res
-
-
-def load_hypernetwork(name):
- path = shared.hypernetworks.get(name, None)
- if path is None:
- return None
- hypernetwork = Hypernetwork()
- try:
- hypernetwork.load(path)
- except Exception as e:
- errors.display(e, f'hypernetwork load: {path}')
- return None
- return hypernetwork
-
-
-def load_hypernetworks(names, multipliers=None):
- already_loaded = {}
- for hypernetwork in shared.loaded_hypernetworks:
- if hypernetwork.name in names:
- already_loaded[hypernetwork.name] = hypernetwork
- shared.loaded_hypernetworks.clear()
- for i, name in enumerate(names):
- hypernetwork = already_loaded.get(name, None)
- if hypernetwork is None:
- hypernetwork = load_hypernetwork(name)
- if hypernetwork is None:
- continue
- hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
- shared.loaded_hypernetworks.append(hypernetwork)
-
-
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
-def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
- if hypernetwork_layers is None:
- return context_k, context_v
- if layer is not None:
- layer.hyper_k = hypernetwork_layers[0]
- layer.hyper_v = hypernetwork_layers[1]
- context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
- context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
- return context_k, context_v
-
-
-def apply_hypernetworks(hypernetworks, context, layer=None):
- context_k = context
- context_v = context
- for hypernetwork in hypernetworks:
- context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
- return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
- q = self.to_q(x)
- context = default(context, x)
- context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
- q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
-
-
-def stack_conds(conds):
- if len(conds) == 1:
- return torch.stack(conds)
- # same as in reconstruct_multicond_batch
- token_count = max([x.shape[0] for x in conds])
- for i in range(len(conds)):
- if conds[i].shape[0] != token_count:
- last_vector = conds[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
- conds[i] = torch.vstack([conds[i], last_vector_repeated])
- return torch.stack(conds)
-
-
-def statistics(data):
- if len(data) < 2:
- std = 0
- else:
- std = stdev(data)
- total_information = f"loss:{mean(data):.3f}" + "\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
- recent_data = data[-32:]
- if len(recent_data) < 2:
- std = 0
- else:
- std = stdev(recent_data)
- recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + "\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
- return total_information, recent_information
-
-
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
- assert name, "Name cannot be empty!"
- fn = os.path.join(shared.opts.hypernetwork_dir, f"{name}.pt")
- if not overwrite_old:
- assert not os.path.exists(fn), f"file {fn} already exists"
- if type(layer_structure) == str:
- layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
- if use_dropout and dropout_structure and type(dropout_structure) == str:
- dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
- else:
- dropout_structure = [0] * len(layer_structure)
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
- name=name,
- enable_sizes=[int(x) for x in enable_sizes],
- layer_structure=layer_structure,
- activation_func=activation_func,
- weight_init=weight_init,
- add_layer_norm=add_layer_norm,
- use_dropout=use_dropout,
- dropout_structure=dropout_structure
- )
- hypernet.save(fn)
- shared.reload_hypernetworks()
- return name
-
-
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # pylint: disable=unused-argument
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images, sd_hijack_checkpoint
-
- save_hypernetwork_every = save_hypernetwork_every or 0
- create_image_every = create_image_every or 0
- template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, name="hypernetwork")
- template_file = template_file.path
-
- path = shared.hypernetworks.get(hypernetwork_name, None)
- hypernetwork = Hypernetwork()
- hypernetwork.load(path)
- shared.loaded_hypernetworks = [hypernetwork]
-
- shared.state.job = "train"
- shared.state.textinfo = "Initializing hypernetwork training..."
- shared.state.job_count = steps
-
- hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
- filename = os.path.join(shared.opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
- unload = shared.opts.unload_models_when_training
-
- if save_hypernetwork_every > 0:
- hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
- os.makedirs(hypernetwork_dir, exist_ok=True)
- else:
- hypernetwork_dir = None
-
- if create_image_every > 0:
- images_dir = os.path.join(log_directory, "images")
- os.makedirs(images_dir, exist_ok=True)
- else:
- images_dir = None
-
- checkpoint = sd_models.select_checkpoint()
-
- initial_step = hypernetwork.step or 0
- if initial_step >= steps:
- shared.state.textinfo = "Model has already been trained beyond specified max steps"
- return hypernetwork, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
- clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
- if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
-
- if shared.opts.training_enable_tensorboard:
- tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
-
- # dataset loading may take a while, so input validations and early returns should be done before this
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-
- pin_memory = shared.opts.pin_memory
-
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
-
- if shared.opts.save_training_settings_to_txt:
- saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
- **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
- )
- ti_logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
-
- latent_sampling_method = ds.latent_sampling_method
-
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
-
- old_parallel_processing_allowed = shared.parallel_processing_allowed
-
- if unload:
- shared.parallel_processing_allowed = False
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- weights = hypernetwork.weights()
- hypernetwork.train()
-
- # Here we use optimizer from saved HN, or we can specify as UI option.
- if hypernetwork.optimizer_name in optimizer_dict:
- optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
- optimizer_name = hypernetwork.optimizer_name
- else:
- print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
- optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
- optimizer_name = 'AdamW'
-
- if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
- try:
- optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
- except RuntimeError as e:
- print("Cannot resume from saved optimizer!")
- print(e)
-
- scaler = torch.cuda.amp.GradScaler()
-
- batch_size = ds.batch_size
- gradient_step = ds.gradient_step
- # n steps = batch_size * gradient_step * n image processed
- steps_per_epoch = len(ds) // batch_size // gradient_step
- max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
- loss_step = 0
- _loss_step = 0 #internal
- # size = len(ds.indexes)
- # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
- loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
- # losses = torch.zeros((size,))
- # previous_mean_losses = [0]
- # previous_mean_loss = 0
- # print("Mean loss of {} elements".format(size))
-
- _steps_without_grad = 0
-
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
-
- pbar = tqdm.tqdm(total=steps - initial_step)
- try:
- sd_hijack_checkpoint.add()
-
- for _i in range((steps-initial_step) * gradient_step):
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
- for j, batch in enumerate(dl):
- # works as a drop_last=True for gradient accumulation
- if j == max_steps_per_epoch:
- break
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
-
- if clip_grad:
- clip_grad_sched.step(hypernetwork.step)
-
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- if use_weight:
- w = batch.weight.to(devices.device, non_blocking=pin_memory)
- if tag_drop_out != 0 or shuffle_tags:
- shared.sd_model.cond_stage_model.to(devices.device)
- c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
- shared.sd_model.cond_stage_model.to(devices.cpu)
- else:
- c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
- if use_weight:
- loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
- del w
- else:
- loss = shared.sd_model.forward(x, c)[0] / gradient_step
- del x
- del c
- _loss_step += loss.item()
-
- scaler.scale(loss).backward()
- # go back until we reach gradient accumulation steps
- if (j + 1) % gradient_step != 0:
- continue
- loss_logging.append(_loss_step)
- if clip_grad:
- clip_grad(weights, clip_grad_sched.learn_rate)
-
- scaler.step(optimizer)
- scaler.update()
- hypernetwork.step += 1
- pbar.update()
- optimizer.zero_grad(set_to_none=True)
- loss_step = _loss_step
- _loss_step = 0
- steps_done = hypernetwork.step + 1
- epoch_num = hypernetwork.step // steps_per_epoch
- epoch_step = hypernetwork.step % steps_per_epoch
-
- description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
- pbar.set_description(description)
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
-
-
-
- if shared.opts.training_enable_tensorboard:
- epoch_num = hypernetwork.step // len(ds)
- epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
- mean_loss = sum(loss_logging) / len(loss_logging)
- textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
- "loss": f"{loss_step:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- hypernetwork.eval()
- rng_state = torch.get_rng_state()
- cuda_rng_state = None
- cuda_rng_state = torch.cuda.get_rng_state_all()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- p.disable_extra_networks = True
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_name = processing.get_sampler_name(preview_sampler_index)
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = batch.cond_text[0]
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
- torch.set_rng_state(rng_state)
- torch.cuda.set_rng_state_all(cuda_rng_state)
- hypernetwork.train()
- if image is not None:
- shared.state.assign_current_image(image)
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- textual_inversion.tensorboard_add_image(tensorboard_writer,
- f"Validation at epoch {epoch_num}", image,
- hypernetwork.step)
- last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
-
- shared.state.job_no = hypernetwork.step
-
- shared.state.textinfo = f"""
-
-Loss: {loss_step:.7f}
-Step: {steps_done}
-Last prompt: {html.escape(batch.cond_text[0])}
-Last saved hypernetwork: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-
-"""
- except Exception as e:
- errors.display(e, 'hypernetwork train')
- finally:
- pbar.leave = False
- pbar.close()
- hypernetwork.eval()
- #report_statistics(loss_dict)
- sd_hijack_checkpoint.remove()
-
-
-
- filename = os.path.join(shared.opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
- del optimizer
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
- shared.parallel_processing_allowed = old_parallel_processing_allowed
-
- return hypernetwork, filename
-
-def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
- old_hypernetwork_name = hypernetwork.name
- old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
- old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
- try:
- hypernetwork.sd_checkpoint = checkpoint.shorthash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- hypernetwork.name = hypernetwork_name
- hypernetwork.save(filename)
- except Exception:
- hypernetwork.sd_checkpoint = old_sd_checkpoint
- hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
- hypernetwork.name = old_hypernetwork_name
- raise
+import datetime
+import html
+import os
+from collections import deque
+import inspect
+from statistics import stdev, mean
+from rich import progress
+import tqdm
+import torch
+from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
+from einops import rearrange, repeat
+from ldm.util import default
+from modules import devices, processing, sd_models, shared, hashes, errors
+import modules.textual_inversion.dataset
+from modules.textual_inversion import textual_inversion, ti_logging
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+
+
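+# map every optimizer class exposed by torch.optim to its name, skipping the Optimizer base class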
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
+class HypernetworkModule(torch.nn.Module):
+ activation_dict = {
+ "linear": torch.nn.Identity,
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
+ }
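+    # extend the explicit aliases above with every activation class defined in torch.nn.modules.activation, keyed by lowercase class name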
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
+ super().__init__()
+ self.multiplier = 1.0
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+ linears = []
+ for i in range(len(layer_structure) - 1):
+ # Add a fully-connected layer
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            # Add an activation function after every layer except the last, unless activate_output is set
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+ # Add layer normalization
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+            # By this point everything has been parsed into the dropout structure, which is applied here.
+            # Since dropout only ever follows a layer, the dropout structure starts with 0 and ends with 0.
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+            # Example: with layer_structure [1, 2, 1] no dropout is applied when last_layer_dropout is False; with [1, 2, 2, 1] the dropout structure is [0, 0.3, 0, 0], or [0, 0.3, 0.3, 0] when it is True.
+ self.linear = torch.nn.Sequential(*linears)
+ if state_dict is not None:
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
+ else:
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+ normal_(w, mean=0.0, std=0.01)
+ normal_(b, mean=0.0, std=0)
+ elif weight_init == 'XavierUniform':
+ xavier_uniform_(w)
+ zeros_(b)
+ elif weight_init == 'XavierNormal':
+ xavier_normal_(w)
+ zeros_(b)
+ elif weight_init == 'KaimingUniform':
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ elif weight_init == 'KaimingNormal':
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ else:
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
+ self.to(devices.device)
+
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+ del state_dict[fr]
+ state_dict[to] = x
+
+ def forward(self, x):
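+        # residual connection: the learned offset is scaled by the strength multiplier at inference time and used at full strength during training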
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
+
+ def trainables(self):
+ layer_structure = []
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
+ return layer_structure
+
+
+# layer_structure: its length determines the length of the dropout structure; use_dropout: whether dropout is enabled at all; last_layer_dropout: kept for compatibility, controls dropout after the last hidden layer.
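+# Example: parse_dropout_structure([1, 2, 2, 1], True, False) returns [0, 0.3, 0, 0]; entry i+1 is the dropout
+# probability applied after the i-th linear layer, entry 0 is unused and the last entry is always 0.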
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+ if layer_structure is None:
+ layer_structure = [1, 2, 1]
+ if not use_dropout:
+ return [0] * len(layer_structure)
+ dropout_values = [0]
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
+ if last_layer_dropout:
+ dropout_values.append(0.3)
+ else:
+ dropout_values.append(0)
+ dropout_values.append(0)
+ return dropout_values
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+ self.layer_structure = layer_structure
+ self.activation_func = activation_func
+ self.weight_init = weight_init
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
+ self.activate_output = activate_output
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+ self.dropout_structure = kwargs.get('dropout_structure', None)
+ if self.dropout_structure is None:
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
+ self.optional_info = None
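+        # each enabled context size gets a pair of modules: one applied to cross-attention keys, one to values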
+ for size in enable_sizes or []:
+ self.layers[size] = (
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
+ )
+ self.eval()
+
+ def weights(self):
+ res = []
+ for layers in self.layers.values():
+ for layer in layers:
+ res += layer.parameters()
+ return res
+
+ def train(self, mode=True):
+ for layers in self.layers.values():
+ for layer in layers:
+ layer.train(mode=mode)
+ for param in layer.parameters():
+ param.requires_grad = mode
+
+ def to(self, device):
+ for layers in self.layers.values():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for layers in self.layers.values():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
+ def eval(self):
+ for layers in self.layers.values():
+ for layer in layers:
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def save(self, filename):
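+        # the hypernetwork itself is saved to filename; the optimizer state, when enabled, is written to a separate <filename>.optim file tagged with the hypernetwork's short hash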
+ state_dict = {}
+ optimizer_saved_dict = {}
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['weight_initialization'] = self.weight_init
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+ state_dict['activate_output'] = self.activate_output
+ state_dict['use_dropout'] = self.use_dropout
+ state_dict['dropout_structure'] = self.dropout_structure
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
+ if self.optimizer_name is not None:
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+ torch.save(state_dict, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = self.shorthash()
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, f"{filename}.optim")
+
+ def load(self, filename):
+ self.filename = filename if os.path.exists(filename) else os.path.join(shared.opts.hypernetwork_dir, filename)
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(self.filename))[0]
+ with progress.open(self.filename, 'rb', description=f'Load hypernetwork: [cyan]{self.filename}', auto_refresh=True, console=shared.console) as f:
+ state_dict = torch.load(f, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.optional_info = state_dict.get('optional_info', None)
+ self.activation_func = state_dict.get('activation_func', None)
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.dropout_structure = state_dict.get('dropout_structure', None)
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
+ self.activate_output = state_dict.get('activate_output', True)
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+        # The dropout structure must have the same length as the layer structure; every value must be in [0, 1) and the last value must be 0.
+ if self.dropout_structure is None:
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
+ if self.optimizer_state_dict:
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+ if shared.opts.print_hypernet_extra:
+ print("Load existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
+ else:
+ self.optimizer_name = "AdamW"
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
+ )
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+ self.eval()
+
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+ return sha256[0:10] if sha256 else None
+
+
+def list_hypernetworks(path):
+ res = {}
+ def list_folder(folder):
+ for filename in os.listdir(folder):
+ fn = os.path.join(folder, filename)
+ if os.path.isfile(fn) and fn.lower().endswith(".pt"):
+ name = os.path.splitext(os.path.basename(fn))[0]
+ res[name] = fn
+ elif os.path.isdir(fn) and not fn.startswith('.'):
+ list_folder(fn)
+
+ list_folder(path)
+ return res
+
+
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
+ if path is None:
+ return None
+ hypernetwork = Hypernetwork()
+ try:
+ hypernetwork.load(path)
+ except Exception as e:
+ errors.display(e, f'hypernetwork load: {path}')
+ return None
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
+ shared.loaded_hypernetworks.clear()
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+ if hypernetwork is None:
+ continue
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
+
+
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
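+# hypernetwork modules are keyed by the context embedding width (context_k.shape[2]); when a matching pair
+# exists, the first module transforms the cross-attention keys and the second transforms the values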
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
+ if hypernetwork_layers is None:
+ return context_k, context_v
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+ q = self.to_q(x)
+ context = default(context, x)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
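+    # split q, k, v into h attention heads, folding the head dimension into the batch dimension so attention is computed per head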
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
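+    # pad shorter conditioning tensors by repeating their final token embedding so all tensors share the same length and can be stacked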
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+ return torch.stack(conds)
+
+
+def statistics(data):
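+    # summarize loss as mean ± standard error (stdev / sqrt(n)), both over the full history and over the most recent 32 values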
+ if len(data) < 2:
+ std = 0
+ else:
+ std = stdev(data)
+ total_information = f"loss:{mean(data):.3f}" + "\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+ recent_data = data[-32:]
+ if len(recent_data) < 2:
+ std = 0
+ else:
+ std = stdev(recent_data)
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + "\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(list(loss_info[key]))
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ # Remove illegal characters from name.
+    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
+ assert name, "Name cannot be empty!"
+ fn = os.path.join(shared.opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+ else:
+ dropout_structure = [0] * len(layer_structure)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ weight_init=weight_init,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ dropout_structure=dropout_structure
+ )
+ hypernet.save(fn)
+ shared.reload_hypernetworks()
+ return name
+
+
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # pylint: disable=unused-argument
+    # images is imported here so that training previews can carry infotext; importing it at the top of the module causes a circular import
+ from modules import images, sd_hijack_checkpoint
+
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, name="hypernetwork")
+ template_file = template_file.path
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
+
+ shared.state.job = "train"
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
+ filename = os.path.join(shared.opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ checkpoint = sd_models.select_checkpoint()
+
+ initial_step = hypernetwork.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
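+    # optional gradient clipping, either by value or by norm, with its own schedulable threshold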
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+
+ pin_memory = shared.opts.pin_memory
+
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
+
+ if shared.opts.save_training_settings_to_txt:
+ saved_params = dict(
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+ )
+ ti_logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
+
+ latent_sampling_method = ds.latent_sampling_method
+
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
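+    # optionally move the text encoder (cond_stage_model) and VAE (first_stage_model) to CPU to free VRAM while only the hypernetwork weights are trained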
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ weights = hypernetwork.weights()
+ hypernetwork.train()
+
+    # Use the optimizer type stored with the saved hypernetwork; this could also be exposed as a UI option.
+ if hypernetwork.optimizer_name in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ optimizer_name = hypernetwork.optimizer_name
+ else:
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict: # this check must change if the optimizer type is allowed to differ from the one saved with the hypernetwork
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
+
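+    # mixed-precision training: GradScaler scales the loss to avoid gradient underflow in fp16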
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+    # one optimization step processes batch_size * gradient_step images
+ steps_per_epoch = len(ds) // batch_size // gradient_step
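+    # max_steps_per_epoch drops the trailing batches of an epoch that cannot fill a complete gradient-accumulation window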
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+    _loss_step = 0 # internal accumulator for the current gradient-accumulation window
+ # size = len(ds.indexes)
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+    loss_logging = deque(maxlen=len(ds) * 3) # keeps roughly three epochs (3 * dataset size) of loss values; could be made a configurable parameter
+ # losses = torch.zeros((size,))
+ # previous_mean_losses = [0]
+ # previous_mean_loss = 0
+ # print("Mean loss of {} elements".format(size))
+
+ _steps_without_grad = 0
+
+ last_saved_file = ""
+ last_saved_image = ""
+ forced_filename = ""
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ sd_hijack_checkpoint.add()
+
+ for _i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ if clip_grad:
+ clip_grad_sched.step(hypernetwork.step)
+
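+                # forward pass under autocast: prepare the conditioning, compute the diffusion loss, and divide it by the number of accumulation steps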
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
+ if tag_drop_out != 0 or shuffle_tags:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ else:
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
+ del x
+ del c
+ _loss_step += loss.item()
+
+ scaler.scale(loss).backward()
+                # keep accumulating gradients until gradient_step mini-batches have been processed
+ if (j + 1) % gradient_step != 0:
+ continue
+ loss_logging.append(_loss_step)
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
+
+ scaler.step(optimizer)
+ scaler.update()
+ hypernetwork.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+ steps_done = hypernetwork.step + 1
+ epoch_num = hypernetwork.step // steps_per_epoch
+ epoch_step = hypernetwork.step % steps_per_epoch
+
+                description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
+
+
+ if shared.opts.training_enable_tensorboard:
+ epoch_num = hypernetwork.step // len(ds)
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+ mean_loss = sum(loss_logging) / len(loss_logging)
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ hypernetwork.eval()
+ rng_state = torch.get_rng_state()
+                    cuda_rng_state = torch.cuda.get_rng_state_all()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ p.disable_extra_networks = True
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = processing.get_sampler_name(preview_sampler_index)
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+ torch.set_rng_state(rng_state)
+ torch.cuda.set_rng_state_all(cuda_rng_state)
+ hypernetwork.train()
+ if image is not None:
+ shared.state.assign_current_image(image)
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
+ f"Validation at epoch {epoch_num}", image,
+ hypernetwork.step)
+ last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+
+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved hypernetwork: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+ except Exception as e:
+ errors.display(e, 'hypernetwork train')
+ finally:
+ pbar.leave = False
+ pbar.close()
+ hypernetwork.eval()
+ #report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
+
+ filename = os.path.join(shared.opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
+
+ return hypernetwork, filename
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except Exception:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index c8b6cd5c3..42ca9b5b6 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,31 +1,31 @@
-import html
-import gradio as gr
-import modules.hypernetworks.hypernetwork
-from modules import devices, sd_hijack, shared
-
-not_available = ["hardswish", "multiheadattention"]
-keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
-
-
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
- filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
- return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
-
-
-def train_hypernetwork(*args):
- shared.loaded_hypernetworks = []
- assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
- try:
- sd_hijack.undo_optimizations()
- hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
- res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
-Hypernetwork saved to {html.escape(filename)}
-"""
- return res, ""
- except Exception as e:
- raise RuntimeError("Hypernetwork error") from e
- finally:
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
- sd_hijack.apply_optimizations()
+import html
+import gradio as gr
+import modules.hypernetworks.hypernetwork
+from modules import devices, sd_hijack, shared
+
+not_available = ["hardswish", "multiheadattention"]
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
+
+
+def train_hypernetwork(*args):
+ shared.loaded_hypernetworks = []
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+ try:
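+        # undo the optimized cross-attention implementations before training; they are re-applied in the finally block below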
+ sd_hijack.undo_optimizations()
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception as e:
+ raise RuntimeError("Hypernetwork error") from e
+ finally:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
diff --git a/modules/images.py b/modules/images.py
index e117b6860..083b2e5ad 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,867 +1,867 @@
-import io
-import re
-import os
-import sys
-import math
-import json
-import uuid
-import queue
-import string
-import hashlib
-import datetime
-import threading
-from pathlib import Path
-from collections import namedtuple
-import numpy as np
-import piexif
-import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ExifTags
-from modules import sd_samplers, shared, script_callbacks, errors, paths
-
-
-debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
-try:
- from pi_heif import register_heif_opener
- register_heif_opener()
-except Exception:
- pass
-
-
-def check_grid_size(imgs):
- mp = 0
- for img in imgs:
- mp += img.width * img.height
- mp = round(mp / 1000000)
- ok = mp <= shared.opts.img_max_size_mp
- if not ok:
- shared.log.warning(f'Maximum image size exceded: size={mp} maximum={shared.opts.img_max_size_mp} MPixels')
- return ok
-
-
-def image_grid(imgs, batch_size=1, rows=None):
- if rows is None:
- if shared.opts.n_rows > 0:
- rows = shared.opts.n_rows
- elif shared.opts.n_rows == 0:
- rows = batch_size
- else:
- rows = math.floor(math.sqrt(len(imgs)))
- while len(imgs) % rows != 0:
- rows -= 1
- if rows > len(imgs):
- rows = len(imgs)
- cols = math.ceil(len(imgs) / rows)
- params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
- script_callbacks.image_grid_callback(params)
- w, h = imgs[0].size
- grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=shared.opts.grid_background)
- for i, img in enumerate(params.imgs):
- grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
- return grid
-
-
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
-
-
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
- w = image.width
- h = image.height
- non_overlap_width = tile_w - overlap
- non_overlap_height = tile_h - overlap
- cols = math.ceil((w - overlap) / non_overlap_width)
- rows = math.ceil((h - overlap) / non_overlap_height)
- dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
- dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
- grid = Grid([], tile_w, tile_h, w, h, overlap)
- for row in range(rows):
- row_images = []
- y = int(row * dy)
- if y + tile_h >= h:
- y = h - tile_h
- for col in range(cols):
- x = int(col * dx)
- if x + tile_w >= w:
- x = w - tile_w
- tile = image.crop((x, y, x + tile_w, y + tile_h))
- row_images.append([x, tile_w, tile])
- grid.tiles.append([y, tile_h, row_images])
- return grid
-
-
-def combine_grid(grid):
- def make_mask_image(r):
- r = r * 255 / grid.overlap
- r = r.astype(np.uint8)
- return Image.fromarray(r, 'L')
-
- mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
- mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
- combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
- for y, h, row in grid.tiles:
- combined_row = Image.new("RGB", (grid.image_w, h))
- for x, w, tile in row:
- if x == 0:
- combined_row.paste(tile, (0, 0))
- continue
- combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
- combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
- if y == 0:
- combined_image.paste(combined_row, (0, 0))
- continue
- combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
- combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
- return combined_image
-
-
-class GridAnnotation:
- def __init__(self, text='', is_active=True):
- self.text = text
- self.is_active = is_active
- self.size = None
-
-
-def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0, title=None):
- def wrap(drawing, text, font, line_length):
- lines = ['']
- for word in text.split():
- line = f'{lines[-1]} {word}'.strip()
- if drawing.textlength(line, font=font) <= line_length:
- lines[-1] = line
- else:
- lines.append(word)
- return lines
-
- def get_font(fontsize):
- try:
- return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
- except Exception:
- return ImageFont.truetype('javascript/roboto.ttf', fontsize)
-
- def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for line in lines:
- font = initial_fnt
- fontsize = initial_fontsize
- while drawing.multiline_textbbox((0,0), text=line.text, font=font)[0] > line.allowed_width and fontsize > 0:
- fontsize -= 1
- font = get_font(fontsize)
- drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=font, fill=shared.opts.font_color if line.is_active else color_inactive, anchor="mm", align="center")
- if not line.is_active:
- drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
- draw_y += line.size[1] + line_spacing
-
- fontsize = (width + height) // 25
- line_spacing = fontsize // 2
- font = get_font(fontsize)
- color_inactive = (127, 127, 127)
- pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
- cols = im.width // width
- rows = im.height // height
- assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
- assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
- calc_d = ImageDraw.Draw(calc_img)
- title_texts = [title] if title else [[GridAnnotation()]]
- for texts, allowed_width in zip(hor_texts + ver_texts + title_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts) + [(width+margin)*cols]):
- items = [] + texts
- texts.clear()
- for line in items:
- wrapped = wrap(calc_d, line.text, font, allowed_width)
- texts += [GridAnnotation(x, line.is_active) for x in wrapped]
- for line in texts:
- bbox = calc_d.multiline_textbbox((0, 0), line.text, font=font)
- line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
- line.allowed_width = allowed_width
- hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
- pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- title_pad = 0
- if title:
- title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
- title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
- for row in range(rows):
- for col in range(cols):
- cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
- result.paste(cell, (pad_left + (width + margin) * col, pad_top + title_pad + (height + margin) * row))
- d = ImageDraw.Draw(result)
- if title:
- x = pad_left + ((width+margin)*cols) / 2
- y = title_pad / 2 - title_text_heights[0] / 2
- draw_texts(d, x, y, title_texts[0], font, fontsize)
- for col in range(cols):
- x = pad_left + (width + margin) * col + width / 2
- y = (pad_top / 2 - hor_text_heights[col] / 2) + title_pad
- draw_texts(d, x, y, hor_texts[col], font, fontsize)
- for row in range(rows):
- x = pad_left / 2
- y = (pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2) + title_pad
- draw_texts(d, x, y, ver_texts[row], font, fontsize)
- return result
-
-
-def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
- prompts = all_prompts[1:]
- boundary = math.ceil(len(prompts) / 2)
- prompts_horiz = prompts[:boundary]
- prompts_vert = prompts[boundary:]
- hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
- ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
- return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
-
-
-def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image'):
- shared.log.debug(f'Image resize: mode={resize_mode} resolution={width}x{height} upscaler={upscaler_name} function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
- """
- Resizes an image with the specified resize_mode, width, and height.
- Args:
- resize_mode: The mode to use when resizing the image.
- 0: No resize
- 1: Resize the image to the specified width and height.
- 2: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
- 3: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
- im: The image to resize.
- width: The width to resize the image to.
- height: The height to resize the image to.
- upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
- """
- upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
-
- def latent(im, w, h, upscaler):
- from modules.processing_vae import vae_encode, vae_decode
- import torch
- latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO enable full VAE mode
- latents = torch.nn.functional.interpolate(latents, size=(h // 8, w // 8), mode=upscaler["mode"], antialias=upscaler["antialias"])
- im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0]
- return im
-
- def resize(im, w, h):
- if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
- return im.resize((w, h), resample=Image.Resampling.LANCZOS)
- scale = max(w / im.width, h / im.height)
- if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
- if len(upscalers) > 0:
- upscaler = upscalers[0]
- im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
- else:
- upscaler = shared.latent_upscale_modes.get(upscaler_name, None)
- if upscaler is not None:
- im = latent(im, w, h, upscaler)
- else:
- shared.log.warning(f"Could not find upscaler: {upscaler_name or ''} using fallback: {upscaler.name}")
- if im.width != w or im.height != h:
- im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
- return im
-
- if resize_mode == 0 or (im.width == width and im.height == height):
- res = im.copy()
- elif resize_mode == 1:
- res = resize(im, width, height)
- elif resize_mode == 2:
- ratio = width / height
- src_ratio = im.width / im.height
- src_w = width if ratio > src_ratio else im.width * height // im.height
- src_h = height if ratio <= src_ratio else im.height * width // im.width
- resized = resize(im, src_w, src_h)
- res = Image.new(im.mode, (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
- else:
- ratio = width / height
- src_ratio = im.width / im.height
- src_w = width if ratio < src_ratio else im.width * height // im.height
- src_h = height if ratio >= src_ratio else im.height * width // im.width
- resized = resize(im, src_w, src_h)
- res = Image.new(im.mode, (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
- if ratio < src_ratio:
- fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
- elif ratio > src_ratio:
- fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
- if output_type == 'np':
- return np.array(res)
- return res
-
-
-re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
-re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
-re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
-re_attention = re.compile(r'[\(*\[*](\w+)(:\d+(\.\d+))?[\)*\]*]|')
-re_network = re.compile(r'\<\w+:(\w+)(:\d+(\.\d+))?\>|')
-re_brackets = re.compile(r'[\([{})\]]')
-
-NOTHING = object()
-
-
-class FilenameGenerator:
- replacements = {
- 'width': lambda self: self.image.width,
- 'height': lambda self: self.image.height,
- 'batch_number': lambda self: self.batch_number,
- 'iter_number': lambda self: self.iter_number,
- 'num': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
- 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
- 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime]
- 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt..]
- 'hash': lambda self: self.image_hash(),
- 'image_hash': lambda self: self.image_hash(),
- 'timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
- 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
-
- 'model': lambda self: shared.sd_model.sd_checkpoint_info.title,
- 'model_shortname': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
- 'model_name': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
- 'model_hash': lambda self: shared.sd_model.sd_checkpoint_info.shorthash,
-
- 'prompt': lambda self: self.prompt_full(),
- 'prompt_no_styles': lambda self: self.prompt_no_style(),
- 'prompt_words': lambda self: self.prompt_words(),
- 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
-
- 'sampler': lambda self: self.p and self.p.sampler_name,
- 'seed': lambda self: self.seed and str(self.seed) or '',
- 'steps': lambda self: self.p and self.p.steps,
- 'styles': lambda self: self.p and ", ".join([style for style in self.p.styles if not style == "None"]) or "None",
- 'uuid': lambda self: str(uuid.uuid4()),
- }
- default_time_format = '%Y%m%d%H%M%S'
-
- def __init__(self, p, seed, prompt, image, grid=False):
- if p is None:
- debug('Filename generator init skip')
- else:
- debug(f'Filename generator init: {seed} {prompt}')
- self.p = p
- if seed is not None and seed > 0:
- self.seed = seed
- elif hasattr(p, 'all_seeds'):
- self.seed = p.all_seeds[0]
- else:
- self.seed = 0
- self.prompt = prompt
- self.image = image
- if not grid:
- self.batch_number = NOTHING if self.p is None or getattr(self.p, 'batch_size', 1) == 1 else (self.p.batch_index + 1 if hasattr(self.p, 'batch_index') else NOTHING)
- self.iter_number = NOTHING if self.p is None or getattr(self.p, 'n_iter', 1) == 1 else (self.p.iteration + 1 if hasattr(self.p, 'iteration') else NOTHING)
- else:
- self.batch_number = NOTHING
- self.iter_number = NOTHING
-
- def hasprompt(self, *args):
- lower = self.prompt.lower()
- if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
- return None
- outres = ""
- for arg in args:
- if arg != "":
- division = arg.split("|")
- expected = division[0].lower()
- default = division[1] if len(division) > 1 else ""
- if lower.find(expected) >= 0:
- outres = f'{outres}{expected}'
- else:
- outres = outres if default == "" else f'{outres}{default}'
- return outres
-
- def image_hash(self):
- if getattr(self, 'image', None) is None:
- return None
- import base64
- from io import BytesIO
- buffered = BytesIO()
- self.image.save(buffered, format="JPEG")
- img_str = base64.b64encode(buffered.getvalue())
- shorthash = hashlib.sha256(img_str).hexdigest()[0:8]
- return shorthash
-
- def prompt_full(self):
- return self.prompt_sanitize(self.prompt)
-
- def prompt_words(self):
- if getattr(self, 'prompt', None) is None:
- return ''
- no_attention = re_attention.sub(r'\1', self.prompt)
- no_network = re_network.sub(r'\1', no_attention)
- no_brackets = re_brackets.sub('', no_network)
- words = [x for x in re_nonletters.split(no_brackets or "") if len(x) > 0]
- prompt = " ".join(words[0:shared.opts.directories_max_prompt_words])
- return self.prompt_sanitize(prompt)
-
- def prompt_no_style(self):
- if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
- return None
- prompt_no_style = self.prompt
- for style in shared.prompt_styles.get_style_prompts(self.p.styles):
- if len(style) > 0:
- for part in style.split("{prompt}"):
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",")
- prompt_no_style = prompt_no_style.replace(style, "")
- return self.prompt_sanitize(prompt_no_style)
-
- def datetime(self, *args):
- import pytz
- time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
- try:
- time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
- except pytz.exceptions.UnknownTimeZoneError:
- time_zone = None
- time_zone_time = time_datetime.astimezone(time_zone)
- try:
- formatted_time = time_zone_time.strftime(time_format)
- except (ValueError, TypeError):
- formatted_time = time_zone_time.strftime(self.default_time_format)
- return formatted_time
-
- def prompt_sanitize(self, prompt):
- invalid_chars = '#<>:\'"\\|?*\n\t\r'
- sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
- debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
- return sanitized
-
- def sanitize(self, filename):
- invalid_chars = '\'"|?*\n\t\r' #
- invalid_folder = ':'
- invalid_files = ['CON', 'PRN', 'AUX', 'NUL', 'NULL', 'COM0', 'COM1', 'LPT0', 'LPT1']
- invalid_prefix = ', '
- invalid_suffix = '.,_ '
- fn, ext = os.path.splitext(filename)
- parts = Path(fn).parts
- newparts = []
- for i, part in enumerate(parts):
- part = part.translate({ ord(x): '_' for x in invalid_chars })
- if i > 0 or (len(part) >= 2 and part[1] != invalid_folder): # skip drive, otherwise remove
- part = part.translate({ ord(x): '_' for x in invalid_folder })
- part = part.lstrip(invalid_prefix).rstrip(invalid_suffix)
- if part in invalid_files: # reserved names
- [part := part.replace(word, '_') for word in invalid_files] # pylint: disable=expression-not-assigned
- newparts.append(part)
- fn = str(Path(*newparts))
- max_length = max(256 - len(ext), os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 256 - len(ext))
- while len(os.path.abspath(fn)) > max_length:
- fn = fn[:-1]
- fn += ext
- debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
- return fn
-
- def sequence(self, x, dirname, basename):
- if shared.opts.save_images_add_number or '[seq]' in x:
- if '[seq]' not in x:
- x = os.path.join(os.path.dirname(x), f"[seq]-{os.path.basename(x)}")
- basecount = get_next_sequence_number(dirname, basename)
- for i in range(9999):
- seq = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- filename = x.replace('[seq]', seq)
- if not os.path.exists(filename):
- debug(f'Prompt sequence: input="{x}" seq={seq} output="{filename}"')
- x = filename
- break
- return x
-
- def apply(self, x):
- res = ''
- for m in re_pattern.finditer(x):
- text, pattern = m.groups()
- if pattern is None:
- res += text
- continue
- pattern_args = []
- while True:
- m = re_pattern_arg.match(pattern)
- if m is None:
- break
- pattern, arg = m.groups()
- pattern_args.insert(0, arg)
- fun = self.replacements.get(pattern.lower(), None)
- if fun is not None:
- try:
- debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
- replacement = fun(self, *pattern_args)
- except Exception as e:
- replacement = None
- shared.log.error(f'Filename apply pattern: {x} {e}')
- if replacement == NOTHING:
- continue
- if replacement is not None:
- res += text + str(replacement).replace('/', '-').replace('\\', '-')
- continue
- else:
- res += text + f'[{pattern}]' # reinsert unknown pattern
- return res
-
-
-def get_next_sequence_number(path, basename):
- """
- Determines and returns the next sequence number to use when saving an image in the specified directory.
- """
- result = -1
- if basename != '':
- basename = f"{basename}-"
- prefix_length = len(basename)
- if not os.path.isdir(path):
- return 0
- for p in os.listdir(path):
- if p.startswith(basename):
- parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
- try:
- result = max(int(parts[0]), result)
- except ValueError:
- pass
- return result + 1
-
-
-def atomically_save_image():
- Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
- while True:
- image, filename, extension, params, exifinfo, filename_txt = save_queue.get()
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- file.write(exifinfo)
- fn = filename + extension
- filename = filename.strip()
- if extension[0] != '.': # add dot if missing
- extension = '.' + extension
- try:
- image_format = Image.registered_extensions()[extension]
- except Exception:
- shared.log.warning(f'Unknown image format: {extension}')
- image_format = 'JPEG'
- if shared.opts.image_watermark_enabled:
- image = set_watermark(image, shared.opts.image_watermark)
- size = os.path.getsize(fn) if os.path.exists(fn) else 0
- shared.log.debug(f'Saving: image="{fn}" type={image_format} resolution={image.width}x{image.height} size={size}')
- # additional metadata saved in files
- if shared.opts.save_txt and len(exifinfo) > 0:
- try:
- with open(filename_txt, "w", encoding="utf8") as file:
- file.write(f"{exifinfo}\n")
- shared.log.debug(f'Saving: text="{filename_txt}" len={len(exifinfo)}')
- except Exception as e:
- shared.log.warning(f'Image description save failed: {filename_txt} {e}')
- # actual save
- exifinfo = (exifinfo or "") if shared.opts.image_metadata else ""
- if image_format == 'PNG':
- pnginfo_data = PngImagePlugin.PngInfo()
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
- try:
- image.save(fn, format=image_format, compress_level=6, pnginfo=pnginfo_data if shared.opts.image_metadata else None)
- except Exception as e:
- shared.log.error(f'Image save failed: file="{fn}" {e}')
- elif image_format == 'JPEG':
- if image.mode == 'RGBA':
- shared.log.warning('Saving RGBA image as JPEG: Alpha channel will be lost')
- image = image.convert("RGB")
- elif image.mode == 'I;16':
- image = image.point(lambda p: p * 0.0038910505836576).convert("L")
- exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
- try:
- image.save(fn, format=image_format, optimize=True, quality=shared.opts.jpeg_quality, exif=exif_bytes)
- except Exception as e:
- shared.log.error(f'Image save failed: file="{fn}" {e}')
- elif image_format == 'WEBP':
- if image.mode == 'I;16':
- image = image.point(lambda p: p * 0.0038910505836576).convert("RGB")
- exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
- try:
- image.save(fn, format=image_format, quality=shared.opts.jpeg_quality, lossless=shared.opts.webp_lossless, exif=exif_bytes)
- except Exception as e:
- shared.log.error(f'Image save failed: file="{fn}" {e}')
- else:
- # shared.log.warning(f'Unrecognized image format: {extension} attempting save as {image_format}')
- try:
- image.save(fn, format=image_format, quality=shared.opts.jpeg_quality)
- except Exception as e:
- shared.log.error(f'Image save failed: file="{fn}" {e}')
- if shared.opts.save_log_fn != '' and len(exifinfo) > 0:
- fn = os.path.join(paths.data_path, shared.opts.save_log_fn)
- if not fn.endswith('.json'):
- fn += '.json'
- entries = shared.readfile(fn, silent=True)
- idx = len(list(entries))
- if idx == 0:
- entries = []
- entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo }
- entries.append(entry)
- shared.writefile(entries, fn, mode='w', silent=True)
- shared.log.debug(f'Saving: json="{fn}" records={len(entries)}')
- save_queue.task_done()
-
-
-save_queue = queue.Queue()
-save_thread = threading.Thread(target=atomically_save_image, daemon=True)
-save_thread.start()
-
-
-def save_image(image, path, basename='', seed=None, prompt=None, extension=shared.opts.samples_format, info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix='', save_to_dirs=None): # pylint: disable=unused-argument
- debug(f'Save from function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
- if image is None:
- shared.log.warning('Image is none')
- return None, None
- if not check_grid_size([image]):
- return None, None
- if path is None or len(path) == 0: # set default path to avoid errors when functions are triggered manually or via api and param is not set
- path = shared.opts.outdir_save
- namegen = FilenameGenerator(p, seed, prompt, image, grid=grid)
- suffix = suffix if suffix is not None else ''
- basename = basename if basename is not None else ''
- if shared.opts.save_to_dirs:
- dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]")
- path = os.path.join(path, dirname)
- if forced_filename is None:
- if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0:
- file_decoration = shared.opts.samples_filename_pattern
- else:
- file_decoration = "[seq]-[prompt_words]"
- file_decoration = namegen.apply(file_decoration)
- file_decoration += suffix if suffix is not None else ''
- filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}")
- else:
- forced_filename += suffix if suffix is not None else ''
- filename = os.path.join(path, f"{forced_filename}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{forced_filename}.{extension}")
- pnginfo = existing_info or {}
- if info is not None:
- pnginfo[pnginfo_section_name] = info
- params = script_callbacks.ImageSaveParams(image, p, filename, pnginfo)
- params.filename = namegen.sanitize(filename)
- dirname = os.path.dirname(params.filename)
- if dirname is not None and len(dirname) > 0:
- os.makedirs(dirname, exist_ok=True)
- params.filename = namegen.sequence(params.filename, dirname, basename)
- params.filename = namegen.sanitize(params.filename)
- # callbacks
- script_callbacks.before_image_saved_callback(params)
- exifinfo = params.pnginfo.get('UserComment', '')
- exifinfo = (exifinfo + ', ' if len(exifinfo) > 0 else '') + params.pnginfo.get(pnginfo_section_name, '')
- filename, extension = os.path.splitext(params.filename)
- filename_txt = f"{filename}.txt" if shared.opts.save_txt and len(exifinfo) > 0 else None
- save_queue.put((params.image, filename, extension, params, exifinfo, filename_txt)) # actual save is executed in a thread that polls data from queue
- save_queue.join()
- if not hasattr(params.image, 'already_saved_as'):
- debug(f'Image marked: "{params.filename}"')
- params.image.already_saved_as = params.filename
- script_callbacks.image_saved_callback(params)
- return params.filename, filename_txt
-
-
-def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3):
- try:
- import cv2
- except Exception as e:
- shared.log.error(f'Save video: cv2: {e}')
- return
- os.makedirs(os.path.dirname(filename), exist_ok=True)
- if video_type.lower() == 'mp4':
- frames = images
- if interpolate > 0:
- try:
- import modules.rife
- frames = modules.rife.interpolate(images, count=interpolate, scale=scale, pad=pad, change=change)
- except Exception as e:
- shared.log.error(f'RIFE interpolation: {e}')
- errors.display(e, 'RIFE interpolation')
- video_frames = [np.array(frame) for frame in frames]
- fourcc = "mp4v"
- h, w, _c = video_frames[0].shape
- video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h))
- for i in range(len(video_frames)):
- img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
- video_writer.write(img)
- size = os.path.getsize(filename)
- shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}')
- if video_type.lower() == 'gif' or video_type.lower() == 'png':
- append = images.copy()
- image = append.pop(0)
- if loop:
- append += append[::-1]
- frames=len(append) + 1
- image.save(
- filename,
- save_all = True,
- append_images = append,
- optimize = False,
- duration = 1000.0 * duration / frames,
- loop = 0 if loop else 1,
- )
- size = os.path.getsize(filename)
- shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}')
-
-
-def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False):
- if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none':
- return
- image = images[0]
- if p is not None:
- namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image)
- else:
- namegen = FilenameGenerator(None, seed=0, prompt='', image=image)
- if filename is None and p is not None:
- filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]")
- filename = os.path.join(shared.opts.outdir_video, filename)
- filename = namegen.sequence(filename, shared.opts.outdir_video, '')
- else:
- if os.pathsep not in filename:
- filename = os.path.join(shared.opts.outdir_video, filename)
- if not filename.lower().endswith(video_type.lower()):
- filename += f'.{video_type.lower()}'
- filename = namegen.sanitize(filename)
- if not sync:
- threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start()
- else:
- save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change)
- return filename
-
-
-def safe_decode_string(s: bytes):
- remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
- for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
- try:
- s = remove_prefix(s, b'UNICODE')
- s = remove_prefix(s, b'ASCII')
- s = remove_prefix(s, b'\x00')
- val = s.decode(encoding, errors="strict")
- val = re.sub(r'[\x00-\x09]', '', val).strip() # remove remaining special characters
- if len(val) == 0: # remove empty strings
- val = None
- return val
- except Exception:
- pass
- return None
-
-
-def read_info_from_image(image: Image):
- items = image.info or {}
- geninfo = items.pop('parameters', None)
- if geninfo is None:
- geninfo = items.pop('UserComment', None)
- if geninfo is not None and len(geninfo) > 0:
- if 'UserComment' in geninfo:
- geninfo = geninfo['UserComment']
- items['UserComment'] = geninfo
-
- if "exif" in items:
- try:
- exif = piexif.load(items["exif"])
- except Exception as e:
- shared.log.error(f'Error loading EXIF data: {e}')
- exif = {}
- for _key, subkey in exif.items():
- if isinstance(subkey, dict):
- for key, val in subkey.items():
- if isinstance(val, bytes): # decode bytestring
- val = safe_decode_string(val)
- if isinstance(val, tuple) and isinstance(val[0], int) and isinstance(val[1], int) and val[1] > 0: # convert camera ratios
- val = round(val[0] / val[1], 2)
- if val is not None and key in ExifTags.TAGS: # add known tags
- if ExifTags.TAGS[key] == 'UserComment': # add geninfo from UserComment
- geninfo = val
- items['parameters'] = val
- else:
- items[ExifTags.TAGS[key]] = val
- elif val is not None and key in ExifTags.GPSTAGS:
- items[ExifTags.GPSTAGS[key]] = val
- wm = get_watermark(image)
- if wm != '':
- # geninfo += f' Watermark: {wm}'
- items['watermark'] = wm
-
- for key, val in items.items():
- if isinstance(val, bytes): # decode bytestring
- items[key] = safe_decode_string(val)
-
- for key in ['exif', 'ExifOffset', 'JpegIFOffset', 'JpegIFByteCount', 'ExifVersion', 'icc_profile', 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'adobe', 'photoshop', 'loop', 'duration', 'dpi']: # remove unwanted tags
- items.pop(key, None)
-
- if items.get("Software", None) == "NovelAI":
- try:
- json_info = json.loads(items["Comment"])
- sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
- geninfo = f"""{items["Description"]}
-Negative prompt: {json_info["uc"]}
-Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
- except Exception as e:
- errors.display(e, 'novelai image parser')
-
- try:
- items['width'] = image.width
- items['height'] = image.height
- items['mode'] = image.mode
- except Exception:
- pass
-
- return geninfo, items
-
-
-def image_data(data):
- import gradio as gr
- if data is None:
- return gr.update(), None
- err1 = None
- err2 = None
- try:
- image = Image.open(io.BytesIO(data))
- errors.log.debug(f'Decoded object: image={image}')
- textinfo, _ = read_info_from_image(image)
- return textinfo, None
- except Exception as e:
- err1 = e
- try:
- if len(data) > 1024 * 10:
- errors.log.warning(f'Error decoding object: data too long: {len(data)}')
- return gr.update(), None
- text = data.decode('utf8')
- errors.log.debug(f'Decoded object: size={len(text)}')
- return text, None
- except Exception as e:
- err2 = e
- errors.log.error(f'Error decoding object: {err1 or err2}')
- return gr.update(), None
-
-
-def flatten(img, bgcolor):
- """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
- if img.mode == "RGBA":
- background = Image.new('RGBA', img.size, bgcolor)
- background.paste(img, mask=img)
- img = background
- return img.convert('RGB')
-
-
-def set_watermark(image, watermark):
- from imwatermark import WatermarkEncoder
- wm_type = 'bytes'
- wm_method = 'dwtDctSvd'
- wm_length = 32
- length = wm_length // 8
- info = image.info
- data = np.asarray(image)
- encoder = WatermarkEncoder()
- text = f"{watermark:<{length}}"[:length]
- bytearr = text.encode(encoding='ascii', errors='ignore')
- try:
- encoder.set_watermark(wm_type, bytearr)
- encoded = encoder.encode(data, wm_method)
- image = Image.fromarray(encoded)
- image.info = info
- shared.log.debug(f'Set watermark: {watermark} method={wm_method} bits={wm_length}')
- except Exception as e:
- shared.log.warning(f'Set watermark error: {watermark} method={wm_method} bits={wm_length} {e}')
- return image
-
-
-def get_watermark(image):
- from imwatermark import WatermarkDecoder
- wm_type = 'bytes'
- wm_method = 'dwtDctSvd'
- wm_length = 32
- data = np.asarray(image)
- decoder = WatermarkDecoder(wm_type, wm_length)
- try:
- decoded = decoder.decode(data, wm_method)
- wm = decoded.decode(encoding='ascii', errors='ignore')
- except Exception:
- wm = ''
- return wm
+import io
+import re
+import os
+import sys
+import math
+import json
+import uuid
+import queue
+import string
+import hashlib
+import datetime
+import threading
+from pathlib import Path
+from collections import namedtuple
+import numpy as np
+import piexif
+import piexif.helper
+from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ExifTags
+from modules import sd_samplers, shared, script_callbacks, errors, paths
+
+
+debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
+try:
+ from pi_heif import register_heif_opener
+ register_heif_opener()
+except Exception:
+ pass
+
+
+def check_grid_size(imgs):
+ mp = 0
+ for img in imgs:
+ mp += img.width * img.height
+ mp = round(mp / 1000000)
+ ok = mp <= shared.opts.img_max_size_mp
+ if not ok:
+        shared.log.warning(f'Maximum image size exceeded: size={mp} maximum={shared.opts.img_max_size_mp} MPixels')
+ return ok
+
+
+def image_grid(imgs, batch_size=1, rows=None):
+ if rows is None:
+ if shared.opts.n_rows > 0:
+ rows = shared.opts.n_rows
+ elif shared.opts.n_rows == 0:
+ rows = batch_size
+ else:
+ rows = math.floor(math.sqrt(len(imgs)))
+ while len(imgs) % rows != 0:
+ rows -= 1
+ if rows > len(imgs):
+ rows = len(imgs)
+ cols = math.ceil(len(imgs) / rows)
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=shared.opts.grid_background)
+ for i, img in enumerate(params.imgs):
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
+ return grid
+
+
+Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+ w = image.width
+ h = image.height
+ non_overlap_width = tile_w - overlap
+ non_overlap_height = tile_h - overlap
+ cols = math.ceil((w - overlap) / non_overlap_width)
+ rows = math.ceil((h - overlap) / non_overlap_height)
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
+ for row in range(rows):
+ row_images = []
+ y = int(row * dy)
+ if y + tile_h >= h:
+ y = h - tile_h
+ for col in range(cols):
+ x = int(col * dx)
+ if x + tile_w >= w:
+ x = w - tile_w
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
+ row_images.append([x, tile_w, tile])
+ grid.tiles.append([y, tile_h, row_images])
+ return grid
+
+
+def combine_grid(grid):
+ def make_mask_image(r):
+ r = r * 255 / grid.overlap
+ r = r.astype(np.uint8)
+ return Image.fromarray(r, 'L')
+
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+ for y, h, row in grid.tiles:
+ combined_row = Image.new("RGB", (grid.image_w, h))
+ for x, w, tile in row:
+ if x == 0:
+ combined_row.paste(tile, (0, 0))
+ continue
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+ if y == 0:
+ combined_image.paste(combined_row, (0, 0))
+ continue
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
+ return combined_image
+
+
+class GridAnnotation:
+ def __init__(self, text='', is_active=True):
+ self.text = text
+ self.is_active = is_active
+ self.size = None
+
+
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0, title=None):
+ def wrap(drawing, text, font, line_length):
+ lines = ['']
+ for word in text.split():
+ line = f'{lines[-1]} {word}'.strip()
+ if drawing.textlength(line, font=font) <= line_length:
+ lines[-1] = line
+ else:
+ lines.append(word)
+ return lines
+
+ def get_font(fontsize):
+ try:
+ return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
+ except Exception:
+ return ImageFont.truetype('javascript/roboto.ttf', fontsize)
+
+ def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
+ for line in lines:
+ font = initial_fnt
+ fontsize = initial_fontsize
+            while drawing.multiline_textbbox((0, 0), text=line.text, font=font)[2] > line.allowed_width and fontsize > 0: # bbox[2] is the right edge, i.e. the rendered text width
+ fontsize -= 1
+ font = get_font(fontsize)
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=font, fill=shared.opts.font_color if line.is_active else color_inactive, anchor="mm", align="center")
+ if not line.is_active:
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
+ draw_y += line.size[1] + line_spacing
+
+ fontsize = (width + height) // 25
+ line_spacing = fontsize // 2
+ font = get_font(fontsize)
+ color_inactive = (127, 127, 127)
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
+ cols = im.width // width
+ rows = im.height // height
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
+ calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
+ calc_d = ImageDraw.Draw(calc_img)
+ title_texts = [title] if title else [[GridAnnotation()]]
+ for texts, allowed_width in zip(hor_texts + ver_texts + title_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts) + [(width+margin)*cols]):
+ items = [] + texts
+ texts.clear()
+ for line in items:
+ wrapped = wrap(calc_d, line.text, font, allowed_width)
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
+ for line in texts:
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=font)
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+ line.allowed_width = allowed_width
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
+ title_pad = 0
+ if title:
+ title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
+ title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
+ for row in range(rows):
+ for col in range(cols):
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + title_pad + (height + margin) * row))
+ d = ImageDraw.Draw(result)
+ if title:
+ x = pad_left + ((width+margin)*cols) / 2
+ y = title_pad / 2 - title_text_heights[0] / 2
+ draw_texts(d, x, y, title_texts[0], font, fontsize)
+ for col in range(cols):
+ x = pad_left + (width + margin) * col + width / 2
+ y = (pad_top / 2 - hor_text_heights[col] / 2) + title_pad
+ draw_texts(d, x, y, hor_texts[col], font, fontsize)
+ for row in range(rows):
+ x = pad_left / 2
+ y = (pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2) + title_pad
+ draw_texts(d, x, y, ver_texts[row], font, fontsize)
+ return result
+
+
+def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
+ prompts = all_prompts[1:]
+ boundary = math.ceil(len(prompts) / 2)
+ prompts_horiz = prompts[:boundary]
+ prompts_vert = prompts[boundary:]
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
+
+
+def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image'):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: No resize.
+            1: Resize the image to the specified width and height.
+            2: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            3: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data stretched from the image edges.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+    """
+    shared.log.debug(f'Image resize: mode={resize_mode} resolution={width}x{height} upscaler={upscaler_name} function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
+ upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
+
+ def latent(im, w, h, upscaler):
+ from modules.processing_vae import vae_encode, vae_decode
+ import torch
+ latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO enable full VAE mode
+ latents = torch.nn.functional.interpolate(latents, size=(h // 8, w // 8), mode=upscaler["mode"], antialias=upscaler["antialias"])
+ im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0]
+ return im
+
+ def resize(im, w, h):
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
+ return im.resize((w, h), resample=Image.Resampling.LANCZOS)
+ scale = max(w / im.width, h / im.height)
+ if scale > 1.0:
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ if len(upscalers) > 0:
+ upscaler = upscalers[0]
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
+ else:
+ upscaler = shared.latent_upscale_modes.get(upscaler_name, None)
+ if upscaler is not None:
+ im = latent(im, w, h, upscaler)
+ else:
+ shared.log.warning(f"Could not find upscaler: {upscaler_name or ''} using fallback: {upscaler.name}")
+ if im.width != w or im.height != h:
+ im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
+ return im
+
+ if resize_mode == 0 or (im.width == width and im.height == height):
+ res = im.copy()
+ elif resize_mode == 1:
+ res = resize(im, width, height)
+ elif resize_mode == 2:
+ ratio = width / height
+ src_ratio = im.width / im.height
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+ resized = resize(im, src_w, src_h)
+ res = Image.new(im.mode, (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+ resized = resize(im, src_w, src_h)
+ res = Image.new(im.mode, (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ if output_type == 'np':
+ return np.array(res)
+ return res
+
+
+re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
+re_attention = re.compile(r'[\(*\[*](\w+)(:\d+(\.\d+))?[\)*\]*]|')
+re_network = re.compile(r'\<\w+:(\w+)(:\d+(\.\d+))?\>|')
+re_brackets = re.compile(r'[\([{})\]]')
+
+NOTHING = object()
+
+
+class FilenameGenerator:
+ replacements = {
+ 'width': lambda self: self.image.width,
+ 'height': lambda self: self.image.height,
+ 'batch_number': lambda self: self.batch_number,
+ 'iter_number': lambda self: self.iter_number,
+ 'num': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'generation_number': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+        'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt..]
+ 'hash': lambda self: self.image_hash(),
+ 'image_hash': lambda self: self.image_hash(),
+ 'timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+
+ 'model': lambda self: shared.sd_model.sd_checkpoint_info.title,
+ 'model_shortname': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
+ 'model_name': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
+ 'model_hash': lambda self: shared.sd_model.sd_checkpoint_info.shorthash,
+
+ 'prompt': lambda self: self.prompt_full(),
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
+ 'prompt_words': lambda self: self.prompt_words(),
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
+
+ 'sampler': lambda self: self.p and self.p.sampler_name,
+ 'seed': lambda self: self.seed and str(self.seed) or '',
+ 'steps': lambda self: self.p and self.p.steps,
+ 'styles': lambda self: self.p and ", ".join([style for style in self.p.styles if not style == "None"]) or "None",
+ 'uuid': lambda self: str(uuid.uuid4()),
+ }
+ default_time_format = '%Y%m%d%H%M%S'
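+    # Illustrative expansion (actual values depend on the loaded model and the processing object):
+    #   FilenameGenerator(p, seed, prompt, image).apply("[model_name]-[seed]-[prompt_words]")
+    #   could yield something like "sd15-12345-a photograph of a cat"
+    # Unknown tokens are re-inserted verbatim in square brackets; tokens that resolve to NOTHING
+    # are dropped together with the literal text immediately preceding them.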
+
+ def __init__(self, p, seed, prompt, image, grid=False):
+ if p is None:
+ debug('Filename generator init skip')
+ else:
+ debug(f'Filename generator init: {seed} {prompt}')
+ self.p = p
+ if seed is not None and seed > 0:
+ self.seed = seed
+ elif hasattr(p, 'all_seeds'):
+ self.seed = p.all_seeds[0]
+ else:
+ self.seed = 0
+ self.prompt = prompt
+ self.image = image
+ if not grid:
+ self.batch_number = NOTHING if self.p is None or getattr(self.p, 'batch_size', 1) == 1 else (self.p.batch_index + 1 if hasattr(self.p, 'batch_index') else NOTHING)
+ self.iter_number = NOTHING if self.p is None or getattr(self.p, 'n_iter', 1) == 1 else (self.p.iteration + 1 if hasattr(self.p, 'iteration') else NOTHING)
+ else:
+ self.batch_number = NOTHING
+ self.iter_number = NOTHING
+
+ def hasprompt(self, *args):
+        if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
+            return None
+        lower = self.prompt.lower() # only lower-case once the prompt is known to be set
+ outres = ""
+ for arg in args:
+ if arg != "":
+ division = arg.split("|")
+ expected = division[0].lower()
+ default = division[1] if len(division) > 1 else ""
+ if lower.find(expected) >= 0:
+ outres = f'{outres}{expected}'
+ else:
+ outres = outres if default == "" else f'{outres}{default}'
+ return outres
+
+ def image_hash(self):
+ if getattr(self, 'image', None) is None:
+ return None
+ import base64
+ from io import BytesIO
+ buffered = BytesIO()
+ self.image.save(buffered, format="JPEG")
+ img_str = base64.b64encode(buffered.getvalue())
+ shorthash = hashlib.sha256(img_str).hexdigest()[0:8]
+ return shorthash
+
+ def prompt_full(self):
+ return self.prompt_sanitize(self.prompt)
+
+ def prompt_words(self):
+ if getattr(self, 'prompt', None) is None:
+ return ''
+ no_attention = re_attention.sub(r'\1', self.prompt)
+ no_network = re_network.sub(r'\1', no_attention)
+ no_brackets = re_brackets.sub('', no_network)
+ words = [x for x in re_nonletters.split(no_brackets or "") if len(x) > 0]
+ prompt = " ".join(words[0:shared.opts.directories_max_prompt_words])
+ return self.prompt_sanitize(prompt)
+
+ def prompt_no_style(self):
+ if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
+ return None
+ prompt_no_style = self.prompt
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
+ if len(style) > 0:
+ for part in style.split("{prompt}"):
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",")
+ prompt_no_style = prompt_no_style.replace(style, "")
+ return self.prompt_sanitize(prompt_no_style)
+
+ def datetime(self, *args):
+ import pytz
+ time_datetime = datetime.datetime.now()
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ try:
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ except pytz.exceptions.UnknownTimeZoneError:
+ time_zone = None
+ time_zone_time = time_datetime.astimezone(time_zone)
+ try:
+ formatted_time = time_zone_time.strftime(time_format)
+ except (ValueError, TypeError):
+ formatted_time = time_zone_time.strftime(self.default_time_format)
+ return formatted_time
+
+ def prompt_sanitize(self, prompt):
+ invalid_chars = '#<>:\'"\\|?*\n\t\r'
+ sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
+ debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
+ return sanitized
+
+ def sanitize(self, filename):
+ invalid_chars = '\'"|?*\n\t\r' #
+ invalid_folder = ':'
+ invalid_files = ['CON', 'PRN', 'AUX', 'NUL', 'NULL', 'COM0', 'COM1', 'LPT0', 'LPT1']
+ invalid_prefix = ', '
+ invalid_suffix = '.,_ '
+ fn, ext = os.path.splitext(filename)
+ parts = Path(fn).parts
+ newparts = []
+ for i, part in enumerate(parts):
+ part = part.translate({ ord(x): '_' for x in invalid_chars })
+ if i > 0 or (len(part) >= 2 and part[1] != invalid_folder): # skip drive, otherwise remove
+ part = part.translate({ ord(x): '_' for x in invalid_folder })
+ part = part.lstrip(invalid_prefix).rstrip(invalid_suffix)
+ if part in invalid_files: # reserved names
+ [part := part.replace(word, '_') for word in invalid_files] # pylint: disable=expression-not-assigned
+ newparts.append(part)
+ fn = str(Path(*newparts))
+ max_length = max(256 - len(ext), os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 256 - len(ext))
+ while len(os.path.abspath(fn)) > max_length:
+ fn = fn[:-1]
+ fn += ext
+ debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
+ return fn
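+    # Illustrative behaviour of sanitize() above (exact output depends on platform and filesystem limits):
+    #   sanitize('C:/outputs/img?.png') keeps the 'C:' drive prefix, maps the '?' to '_',
+    #   and truncates the result if the absolute path exceeds the allowed length.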
+
+ def sequence(self, x, dirname, basename):
+ if shared.opts.save_images_add_number or '[seq]' in x:
+ if '[seq]' not in x:
+ x = os.path.join(os.path.dirname(x), f"[seq]-{os.path.basename(x)}")
+ basecount = get_next_sequence_number(dirname, basename)
+ for i in range(9999):
+ seq = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
+ filename = x.replace('[seq]', seq)
+ if not os.path.exists(filename):
+ debug(f'Prompt sequence: input="{x}" seq={seq} output="{filename}"')
+ x = filename
+ break
+ return x
+
+ def apply(self, x):
+ res = ''
+ for m in re_pattern.finditer(x):
+ text, pattern = m.groups()
+ if pattern is None:
+ res += text
+ continue
+ pattern_args = []
+ while True:
+ m = re_pattern_arg.match(pattern)
+ if m is None:
+ break
+ pattern, arg = m.groups()
+ pattern_args.insert(0, arg)
+ fun = self.replacements.get(pattern.lower(), None)
+ if fun is not None:
+ try:
+ debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
+ replacement = fun(self, *pattern_args)
+ except Exception as e:
+ replacement = None
+ shared.log.error(f'Filename apply pattern: {x} {e}')
+ if replacement == NOTHING:
+ continue
+ if replacement is not None:
+ res += text + str(replacement).replace('/', '-').replace('\\', '-')
+ continue
+ else:
+ res += text + f'[{pattern}]' # reinsert unknown pattern
+ return res
+
+
+def get_next_sequence_number(path, basename):
+ """
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
+ """
+ result = -1
+ if basename != '':
+ basename = f"{basename}-"
+ prefix_length = len(basename)
+ if not os.path.isdir(path):
+ return 0
+ for p in os.listdir(path):
+ if p.startswith(basename):
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ try:
+ result = max(int(parts[0]), result)
+ except ValueError:
+ pass
+ return result + 1
+
+
+def atomically_save_image():
+ Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
+ while True:
+ image, filename, extension, params, exifinfo, filename_txt = save_queue.get()
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+ file.write(exifinfo)
+ fn = filename + extension
+ filename = filename.strip()
+ if extension[0] != '.': # add dot if missing
+ extension = '.' + extension
+ try:
+ image_format = Image.registered_extensions()[extension]
+ except Exception:
+ shared.log.warning(f'Unknown image format: {extension}')
+ image_format = 'JPEG'
+ if shared.opts.image_watermark_enabled:
+ image = set_watermark(image, shared.opts.image_watermark)
+ size = os.path.getsize(fn) if os.path.exists(fn) else 0
+ shared.log.debug(f'Saving: image="{fn}" type={image_format} resolution={image.width}x{image.height} size={size}')
+ # additional metadata saved in files
+ if shared.opts.save_txt and len(exifinfo) > 0:
+ try:
+ with open(filename_txt, "w", encoding="utf8") as file:
+ file.write(f"{exifinfo}\n")
+ shared.log.debug(f'Saving: text="{filename_txt}" len={len(exifinfo)}')
+ except Exception as e:
+ shared.log.warning(f'Image description save failed: {filename_txt} {e}')
+ # actual save
+ exifinfo = (exifinfo or "") if shared.opts.image_metadata else ""
+ if image_format == 'PNG':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
+ try:
+ image.save(fn, format=image_format, compress_level=6, pnginfo=pnginfo_data if shared.opts.image_metadata else None)
+ except Exception as e:
+ shared.log.error(f'Image save failed: file="{fn}" {e}')
+ elif image_format == 'JPEG':
+ if image.mode == 'RGBA':
+ shared.log.warning('Saving RGBA image as JPEG: Alpha channel will be lost')
+ image = image.convert("RGB")
+ elif image.mode == 'I;16':
+                image = image.point(lambda p: p * 0.0038910505836576).convert("L") # rescale 16-bit values to 8-bit range (factor is 1/257 = 255/65535)
+ exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
+ try:
+ image.save(fn, format=image_format, optimize=True, quality=shared.opts.jpeg_quality, exif=exif_bytes)
+ except Exception as e:
+ shared.log.error(f'Image save failed: file="{fn}" {e}')
+ elif image_format == 'WEBP':
+ if image.mode == 'I;16':
+                image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") # same 1/257 rescale from 16-bit to 8-bit
+ exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
+ try:
+ image.save(fn, format=image_format, quality=shared.opts.jpeg_quality, lossless=shared.opts.webp_lossless, exif=exif_bytes)
+ except Exception as e:
+ shared.log.error(f'Image save failed: file="{fn}" {e}')
+ else:
+ # shared.log.warning(f'Unrecognized image format: {extension} attempting save as {image_format}')
+ try:
+ image.save(fn, format=image_format, quality=shared.opts.jpeg_quality)
+ except Exception as e:
+ shared.log.error(f'Image save failed: file="{fn}" {e}')
+ if shared.opts.save_log_fn != '' and len(exifinfo) > 0:
+ fn = os.path.join(paths.data_path, shared.opts.save_log_fn)
+ if not fn.endswith('.json'):
+ fn += '.json'
+ entries = shared.readfile(fn, silent=True)
+ idx = len(list(entries))
+ if idx == 0:
+ entries = []
+ entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo }
+ entries.append(entry)
+ shared.writefile(entries, fn, mode='w', silent=True)
+ shared.log.debug(f'Saving: json="{fn}" records={len(entries)}')
+ save_queue.task_done()
+
+
+save_queue = queue.Queue()
+save_thread = threading.Thread(target=atomically_save_image, daemon=True)
+save_thread.start()
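+# save_image() below acts as the producer: it queues (image, filename, ...) tuples for the daemon
+# thread above and then waits on save_queue.join(), so the write completes before the saved-image callback fires.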
+
+
+def save_image(image, path, basename='', seed=None, prompt=None, extension=shared.opts.samples_format, info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix='', save_to_dirs=None): # pylint: disable=unused-argument
+ debug(f'Save from function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
+ if image is None:
+ shared.log.warning('Image is none')
+ return None, None
+ if not check_grid_size([image]):
+ return None, None
+ if path is None or len(path) == 0: # set default path to avoid errors when functions are triggered manually or via api and param is not set
+ path = shared.opts.outdir_save
+ namegen = FilenameGenerator(p, seed, prompt, image, grid=grid)
+ suffix = suffix if suffix is not None else ''
+ basename = basename if basename is not None else ''
+ if shared.opts.save_to_dirs:
+ dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]")
+ path = os.path.join(path, dirname)
+ if forced_filename is None:
+ if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0:
+ file_decoration = shared.opts.samples_filename_pattern
+ else:
+ file_decoration = "[seq]-[prompt_words]"
+ file_decoration = namegen.apply(file_decoration)
+ file_decoration += suffix if suffix is not None else ''
+ filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}")
+ else:
+ forced_filename += suffix if suffix is not None else ''
+ filename = os.path.join(path, f"{forced_filename}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{forced_filename}.{extension}")
+ pnginfo = existing_info or {}
+ if info is not None:
+ pnginfo[pnginfo_section_name] = info
+ params = script_callbacks.ImageSaveParams(image, p, filename, pnginfo)
+ params.filename = namegen.sanitize(filename)
+ dirname = os.path.dirname(params.filename)
+ if dirname is not None and len(dirname) > 0:
+ os.makedirs(dirname, exist_ok=True)
+ params.filename = namegen.sequence(params.filename, dirname, basename)
+ params.filename = namegen.sanitize(params.filename)
+ # callbacks
+ script_callbacks.before_image_saved_callback(params)
+ exifinfo = params.pnginfo.get('UserComment', '')
+ exifinfo = (exifinfo + ', ' if len(exifinfo) > 0 else '') + params.pnginfo.get(pnginfo_section_name, '')
+ filename, extension = os.path.splitext(params.filename)
+ filename_txt = f"{filename}.txt" if shared.opts.save_txt and len(exifinfo) > 0 else None
+ save_queue.put((params.image, filename, extension, params, exifinfo, filename_txt)) # actual save is executed in a thread that polls data from queue
+ save_queue.join()
+ if not hasattr(params.image, 'already_saved_as'):
+ debug(f'Image marked: "{params.filename}"')
+ params.image.already_saved_as = params.filename
+ script_callbacks.image_saved_callback(params)
+ return params.filename, filename_txt
+
+
+def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3):
+ try:
+ import cv2
+ except Exception as e:
+ shared.log.error(f'Save video: cv2: {e}')
+ return
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ if video_type.lower() == 'mp4':
+ frames = images
+ if interpolate > 0:
+ try:
+ import modules.rife
+ frames = modules.rife.interpolate(images, count=interpolate, scale=scale, pad=pad, change=change)
+ except Exception as e:
+ shared.log.error(f'RIFE interpolation: {e}')
+ errors.display(e, 'RIFE interpolation')
+ video_frames = [np.array(frame) for frame in frames]
+ fourcc = "mp4v"
+ h, w, _c = video_frames[0].shape
+ video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h))
+        for i in range(len(video_frames)):
+            img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+            video_writer.write(img)
+        video_writer.release() # finalize the file before querying its size
+        size = os.path.getsize(filename)
+ shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}')
+ if video_type.lower() == 'gif' or video_type.lower() == 'png':
+ append = images.copy()
+ image = append.pop(0)
+ if loop:
+ append += append[::-1]
+        frames = len(append) + 1
+ image.save(
+ filename,
+ save_all = True,
+ append_images = append,
+ optimize = False,
+ duration = 1000.0 * duration / frames,
+ loop = 0 if loop else 1,
+ )
+ size = os.path.getsize(filename)
+ shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}')
+
+
+def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False):
+ if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none':
+ return
+ image = images[0]
+ if p is not None:
+ namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image)
+ else:
+ namegen = FilenameGenerator(None, seed=0, prompt='', image=image)
+ if filename is None and p is not None:
+ filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]")
+ filename = os.path.join(shared.opts.outdir_video, filename)
+ filename = namegen.sequence(filename, shared.opts.outdir_video, '')
+ else:
+        if os.path.sep not in filename: # prepend the output directory only when no path component was given
+ filename = os.path.join(shared.opts.outdir_video, filename)
+ if not filename.lower().endswith(video_type.lower()):
+ filename += f'.{video_type.lower()}'
+ filename = namegen.sanitize(filename)
+ if not sync:
+ threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start()
+ else:
+ save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change)
+ return filename
+
+
+def safe_decode_string(s: bytes):
+ remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
+ for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
+ try:
+ s = remove_prefix(s, b'UNICODE')
+ s = remove_prefix(s, b'ASCII')
+ s = remove_prefix(s, b'\x00')
+ val = s.decode(encoding, errors="strict")
+ val = re.sub(r'[\x00-\x09]', '', val).strip() # remove remaining special characters
+ if len(val) == 0: # remove empty strings
+ val = None
+ return val
+ except Exception:
+ pass
+ return None
+
+
+def read_info_from_image(image: Image):
+ items = image.info or {}
+ geninfo = items.pop('parameters', None)
+ if geninfo is None:
+ geninfo = items.pop('UserComment', None)
+ if geninfo is not None and len(geninfo) > 0:
+ if 'UserComment' in geninfo:
+ geninfo = geninfo['UserComment']
+ items['UserComment'] = geninfo
+
+ if "exif" in items:
+ try:
+ exif = piexif.load(items["exif"])
+ except Exception as e:
+ shared.log.error(f'Error loading EXIF data: {e}')
+ exif = {}
+ for _key, subkey in exif.items():
+ if isinstance(subkey, dict):
+ for key, val in subkey.items():
+ if isinstance(val, bytes): # decode bytestring
+ val = safe_decode_string(val)
+ if isinstance(val, tuple) and isinstance(val[0], int) and isinstance(val[1], int) and val[1] > 0: # convert camera ratios
+ val = round(val[0] / val[1], 2)
+ if val is not None and key in ExifTags.TAGS: # add known tags
+ if ExifTags.TAGS[key] == 'UserComment': # add geninfo from UserComment
+ geninfo = val
+ items['parameters'] = val
+ else:
+ items[ExifTags.TAGS[key]] = val
+ elif val is not None and key in ExifTags.GPSTAGS:
+ items[ExifTags.GPSTAGS[key]] = val
+ wm = get_watermark(image)
+ if wm != '':
+ # geninfo += f' Watermark: {wm}'
+ items['watermark'] = wm
+
+ for key, val in items.items():
+ if isinstance(val, bytes): # decode bytestring
+ items[key] = safe_decode_string(val)
+
+ for key in ['exif', 'ExifOffset', 'JpegIFOffset', 'JpegIFByteCount', 'ExifVersion', 'icc_profile', 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'adobe', 'photoshop', 'loop', 'duration', 'dpi']: # remove unwanted tags
+ items.pop(key, None)
+
+ if items.get("Software", None) == "NovelAI":
+ try:
+ json_info = json.loads(items["Comment"])
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
+ geninfo = f"""{items["Description"]}
+Negative prompt: {json_info["uc"]}
+Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
+ except Exception as e:
+ errors.display(e, 'novelai image parser')
+
+ try:
+ items['width'] = image.width
+ items['height'] = image.height
+ items['mode'] = image.mode
+ except Exception:
+ pass
+
+ return geninfo, items
+
+
+def image_data(data):
+ import gradio as gr
+ if data is None:
+ return gr.update(), None
+ err1 = None
+ err2 = None
+ try:
+ image = Image.open(io.BytesIO(data))
+ errors.log.debug(f'Decoded object: image={image}')
+ textinfo, _ = read_info_from_image(image)
+ return textinfo, None
+ except Exception as e:
+ err1 = e
+ try:
+ if len(data) > 1024 * 10:
+ errors.log.warning(f'Error decoding object: data too long: {len(data)}')
+ return gr.update(), None
+ text = data.decode('utf8')
+ errors.log.debug(f'Decoded object: size={len(text)}')
+ return text, None
+ except Exception as e:
+ err2 = e
+ errors.log.error(f'Error decoding object: {err1 or err2}')
+ return gr.update(), None
+
+
+def flatten(img, bgcolor):
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
+ if img.mode == "RGBA":
+ background = Image.new('RGBA', img.size, bgcolor)
+ background.paste(img, mask=img)
+ img = background
+ return img.convert('RGB')
+
+
+def set_watermark(image, watermark):
+ from imwatermark import WatermarkEncoder
+ wm_type = 'bytes'
+ wm_method = 'dwtDctSvd'
+ wm_length = 32
+ length = wm_length // 8
+ info = image.info
+ data = np.asarray(image)
+ encoder = WatermarkEncoder()
+ text = f"{watermark:<{length}}"[:length]
+ bytearr = text.encode(encoding='ascii', errors='ignore')
+ try:
+ encoder.set_watermark(wm_type, bytearr)
+ encoded = encoder.encode(data, wm_method)
+ image = Image.fromarray(encoded)
+ image.info = info
+ shared.log.debug(f'Set watermark: {watermark} method={wm_method} bits={wm_length}')
+ except Exception as e:
+ shared.log.warning(f'Set watermark error: {watermark} method={wm_method} bits={wm_length} {e}')
+ return image
+
+
+def get_watermark(image):
+ from imwatermark import WatermarkDecoder
+ wm_type = 'bytes'
+ wm_method = 'dwtDctSvd'
+ wm_length = 32
+ data = np.asarray(image)
+ decoder = WatermarkDecoder(wm_type, wm_length)
+ try:
+ decoded = decoder.decode(data, wm_method)
+ wm = decoded.decode(encoding='ascii', errors='ignore')
+ except Exception:
+ wm = ''
+ return wm
diff --git a/modules/img2img.py b/modules/img2img.py
index bb51119d3..ffe4aae33 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,264 +1,264 @@
-import os
-import itertools # SBM Batch frames
-import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
-import modules.scripts
-from modules import shared, processing, images
-from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.ui import plaintext_to_html
-from modules.memstats import memory_stats
-
-
-debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PROCESS')
-
-
-def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args):
- shared.log.debug(f'batch: {input_files}|{input_dir}|{output_dir}|{inpaint_mask_dir}')
- processing.fix_seed(p)
- if input_files is not None and len(input_files) > 0:
- image_files = [f.name for f in input_files]
- else:
- if not os.path.isdir(input_dir):
- shared.log.error(f"Process batch: directory not found: {input_dir}")
- return
- image_files = os.listdir(input_dir)
- is_inpaint_batch = False
- if inpaint_mask_dir:
- inpaint_masks = os.listdir(inpaint_mask_dir)
- is_inpaint_batch = len(inpaint_masks) > 0
- if is_inpaint_batch:
- shared.log.info(f"Process batch: inpaint batch masks={len(inpaint_masks)}")
- save_normally = output_dir == ''
- p.do_not_save_grid = True
- p.do_not_save_samples = not save_normally
- shared.state.job_count = len(image_files) * p.n_iter
- if shared.opts.batch_frame_mode: # SBM Frame mode is on, process each image in batch with same seed
- window_size = p.batch_size
- btcrept = 1
- p.seed = [p.seed] * window_size # SBM MONKEYPATCH: Need to change processing to support a fixed seed value.
- p.subseed = [p.subseed] * window_size # SBM MONKEYPATCH
- shared.log.info(f"Process batch: inputs={len(image_files)} parallel={window_size} outputs={p.n_iter} per input ")
- else: # SBM Frame mode is off, standard operation of repeating same images with sequential seed.
- window_size = 1
- btcrept = p.batch_size
- shared.log.info(f"Process batch: inputs={len(image_files)} outputs={p.n_iter * p.batch_size} per input")
- for i in range(0, len(image_files), window_size):
- if shared.state.skipped:
- shared.state.skipped = False
- if shared.state.interrupted:
- break
- batch_image_files = image_files[i:i+window_size]
- batch_images = []
- for image_file in batch_image_files:
- try:
- img = Image.open(image_file)
- if p.scale_by != 1:
- p.width = int(img.width * p.scale_by)
- p.height = int(img.height * p.scale_by)
- except UnidentifiedImageError as e:
- shared.log.error(f"Image error: {e}")
- continue
- img = ImageOps.exif_transpose(img)
- batch_images.append(img)
- batch_images = batch_images * btcrept # Standard mode sends the same image per batchsize.
- p.init_images = batch_images
-
- if is_inpaint_batch:
- # try to find corresponding mask for an image using simple filename matching
- batch_mask_images = []
- for image_file in batch_image_files:
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image_file))
- # if not found use first one ("same mask for all images" use-case)
- if mask_image_path not in inpaint_masks:
- mask_image_path = inpaint_masks[0]
- mask_image = Image.open(mask_image_path)
- batch_mask_images.append(mask_image)
- batch_mask_images = batch_mask_images * btcrept
- p.image_mask = batch_mask_images
-
- batch_image_files = batch_image_files * btcrept # List used for naming later.
-
- proc = modules.scripts.scripts_img2img.run(p, *args)
- if proc is None:
- proc = processing.process_images(p)
- for n, (image, image_file) in enumerate(itertools.zip_longest(proc.images,batch_image_files)):
- basename = ''
- if shared.opts.use_original_name_batch:
- forced_filename, ext = os.path.splitext(os.path.basename(image_file))
- else:
- forced_filename = None
- ext = shared.opts.samples_format
- if len(proc.images) > 1:
- basename = f'{n + i}' if shared.opts.batch_frame_mode else f'{n}'
- else:
- basename = ''
- if output_dir == '':
- output_dir = shared.opts.outdir_img2img_samples
- if not save_normally:
- os.makedirs(output_dir, exist_ok=True)
- geninfo, items = images.read_info_from_image(image)
- for k, v in items.items():
- image.info[k] = v
- images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename)
- shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch')
-
-
-def img2img(id_task: str, mode: int,
- prompt, negative_prompt, prompt_styles,
- init_img,
- sketch,
- init_img_with_mask,
- inpaint_color_sketch,
- inpaint_color_sketch_orig,
- init_img_inpaint,
- init_mask_inpaint,
- steps,
- sampler_index,
- mask_blur, mask_alpha,
- inpainting_fill,
- full_quality, restore_faces, tiling,
- n_iter, batch_size,
- cfg_scale, image_cfg_scale,
- diffusers_guidance_rescale,
- sag_scale,
- refiner_start,
- clip_skip,
- denoising_strength,
- seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
- selected_scale_tab,
- height, width,
- scale_by,
- resize_mode, resize_name,
- inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert,
- img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir,
- hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry,
- override_settings_texts,
- *args): # pylint: disable=unused-argument
-
- if shared.sd_model is None:
- shared.log.warning('Model not loaded')
- return [], '', '', 'Error: model not loaded'
-
- debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}||mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}')
-
- if mode == 5:
- if img2img_batch_files is None or len(img2img_batch_files) == 0:
- shared.log.debug('Init bactch images not set')
- elif init_img:
- shared.log.debug('Init image not set')
-
- if sampler_index is None:
- sampler_index = 0
-
- override_settings = create_override_settings_dict(override_settings_texts)
-
- if mode == 0: # img2img
- if init_img is None:
- return [], '', '', 'Error: init image not provided'
- image = init_img.convert("RGB")
- mask = None
- elif mode == 1: # img2img sketch
- if sketch is None:
- return [], '', '', 'Error: sketch image not provided'
- image = sketch.convert("RGB")
- mask = None
- elif mode == 2: # inpaint
- if init_img_with_mask is None:
- return [], '', '', 'Error: init image with mask not provided'
- image = init_img_with_mask["image"]
- mask = init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- image = image.convert("RGB")
- elif mode == 3: # inpaint sketch
- if inpaint_color_sketch is None:
- return [], '', '', 'Error: color sketch image not provided'
- image = inpaint_color_sketch
- orig = inpaint_color_sketch_orig or inpaint_color_sketch
- pred = np.any(np.array(image) != np.array(orig), axis=-1)
- mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
- mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
- blur = ImageFilter.GaussianBlur(mask_blur)
- image = Image.composite(image.filter(blur), orig, mask.filter(blur))
- image = image.convert("RGB")
- elif mode == 4: # inpaint upload mask
- if init_img_inpaint is None:
- return [], '', '', 'Error: inpaint image not provided'
- image = init_img_inpaint
- mask = init_mask_inpaint
- else:
- shared.log.error(f'Image processing unknown mode: {mode}')
- image = None
- mask = None
- if image is not None:
- image = ImageOps.exif_transpose(image)
- if selected_scale_tab == 1 and resize_mode != 0:
- width = int(image.width * scale_by)
- height = int(image.height * scale_by)
-
- p = processing.StableDiffusionProcessingImg2Img(
- sd_model=shared.sd_model,
- outpath_samples=shared.opts.outdir_samples or shared.opts.outdir_img2img_samples,
- outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_img2img_grids,
- prompt=prompt,
- negative_prompt=negative_prompt,
- styles=prompt_styles,
- seed=seed,
- subseed=subseed,
- subseed_strength=subseed_strength,
- seed_resize_from_h=seed_resize_from_h,
- seed_resize_from_w=seed_resize_from_w,
- seed_enable_extras=True,
- sampler_name = processing.get_sampler_name(sampler_index, img=True),
- batch_size=batch_size,
- n_iter=n_iter,
- steps=steps,
- cfg_scale=cfg_scale,
- clip_skip=clip_skip,
- width=width,
- height=height,
- full_quality=full_quality,
- restore_faces=restore_faces,
- tiling=tiling,
- init_images=[image],
- mask=mask,
- mask_blur=mask_blur,
- inpainting_fill=inpainting_fill,
- resize_mode=resize_mode,
- resize_name=resize_name,
- denoising_strength=denoising_strength,
- image_cfg_scale=image_cfg_scale,
- diffusers_guidance_rescale=diffusers_guidance_rescale,
- sag_scale=sag_scale,
- refiner_start=refiner_start,
- inpaint_full_res=inpaint_full_res != 0,
- inpaint_full_res_padding=inpaint_full_res_padding,
- inpainting_mask_invert=inpainting_mask_invert,
- hdr_clamp=hdr_clamp, hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold,
- hdr_center=hdr_center, hdr_channel_shift=hdr_channel_shift, hdr_full_shift=hdr_full_shift,
- hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry,
- override_settings=override_settings,
- )
- if selected_scale_tab == 1 and resize_mode != 0:
- p.scale_by = scale_by
- p.scripts = modules.scripts.scripts_img2img
- p.script_args = args
- if mask:
- p.extra_generation_params["Mask blur"] = mask_blur
- p.extra_generation_params["Mask alpha"] = mask_alpha
- p.extra_generation_params["Mask invert"] = inpainting_mask_invert
- p.extra_generation_params["Mask content"] = inpainting_fill
- p.extra_generation_params["Mask area"] = inpaint_full_res
- p.extra_generation_params["Mask padding"] = inpaint_full_res_padding
- p.is_batch = mode == 5
- if p.is_batch:
- process_batch(p, img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
- processed = processing.Processed(p, [], p.seed, "")
- else:
- processed = modules.scripts.scripts_img2img.run(p, *args)
- if processed is None:
- processed = processing.process_images(p)
- p.close()
- generation_info_js = processed.js() if processed is not None else ''
- return processed.images, generation_info_js, processed.info, plaintext_to_html(processed.comments)
+import os
+import itertools # SBM Batch frames
+import numpy as np
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import modules.scripts
+from modules import shared, processing, images
+from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules.ui import plaintext_to_html
+from modules.memstats import memory_stats
+
+
+debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PROCESS')
+
+
+def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args):
+ shared.log.debug(f'batch: {input_files}|{input_dir}|{output_dir}|{inpaint_mask_dir}')
+ processing.fix_seed(p)
+ if input_files is not None and len(input_files) > 0:
+ image_files = [f.name for f in input_files]
+ else:
+ if not os.path.isdir(input_dir):
+ shared.log.error(f"Process batch: directory not found: {input_dir}")
+ return
+ image_files = os.listdir(input_dir)
+ is_inpaint_batch = False
+ if inpaint_mask_dir:
+ inpaint_masks = os.listdir(inpaint_mask_dir)
+ is_inpaint_batch = len(inpaint_masks) > 0
+ if is_inpaint_batch:
+ shared.log.info(f"Process batch: inpaint batch masks={len(inpaint_masks)}")
+ save_normally = output_dir == ''
+ p.do_not_save_grid = True
+ p.do_not_save_samples = not save_normally
+ shared.state.job_count = len(image_files) * p.n_iter
+ if shared.opts.batch_frame_mode: # SBM Frame mode is on, process each image in batch with same seed
+ window_size = p.batch_size
+ btcrept = 1
+ p.seed = [p.seed] * window_size # SBM MONKEYPATCH: Need to change processing to support a fixed seed value.
+ p.subseed = [p.subseed] * window_size # SBM MONKEYPATCH
+        shared.log.info(f"Process batch: inputs={len(image_files)} parallel={window_size} outputs={p.n_iter} per input")
+ else: # SBM Frame mode is off, standard operation of repeating same images with sequential seed.
+ window_size = 1
+ btcrept = p.batch_size
+ shared.log.info(f"Process batch: inputs={len(image_files)} outputs={p.n_iter * p.batch_size} per input")
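+    # summary of the two modes above: frame mode processes window_size distinct images per batch
+    # under one fixed seed, while standard mode repeats a single image batch_size times (btcrept)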
+ for i in range(0, len(image_files), window_size):
+ if shared.state.skipped:
+ shared.state.skipped = False
+ if shared.state.interrupted:
+ break
+ batch_image_files = image_files[i:i+window_size]
+ batch_images = []
+ for image_file in batch_image_files:
+ try:
+ img = Image.open(image_file)
+ if p.scale_by != 1:
+ p.width = int(img.width * p.scale_by)
+ p.height = int(img.height * p.scale_by)
+ except UnidentifiedImageError as e:
+ shared.log.error(f"Image error: {e}")
+ continue
+ img = ImageOps.exif_transpose(img)
+ batch_images.append(img)
+ batch_images = batch_images * btcrept # Standard mode sends the same image per batchsize.
+ p.init_images = batch_images
+
+ if is_inpaint_batch:
+ # try to find corresponding mask for an image using simple filename matching
+ batch_mask_images = []
+ for image_file in batch_image_files:
+ mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image_file))
+ # if not found use first one ("same mask for all images" use-case)
+                if os.path.basename(image_file) not in inpaint_masks:
+                    mask_image_path = os.path.join(inpaint_mask_dir, inpaint_masks[0])
+ mask_image = Image.open(mask_image_path)
+ batch_mask_images.append(mask_image)
+ batch_mask_images = batch_mask_images * btcrept
+ p.image_mask = batch_mask_images
+
+ batch_image_files = batch_image_files * btcrept # List used for naming later.
+
+ proc = modules.scripts.scripts_img2img.run(p, *args)
+ if proc is None:
+ proc = processing.process_images(p)
+        for n, (image, image_file) in enumerate(itertools.zip_longest(proc.images, batch_image_files)):
+ basename = ''
+ if shared.opts.use_original_name_batch:
+ forced_filename, ext = os.path.splitext(os.path.basename(image_file))
+ else:
+ forced_filename = None
+ ext = shared.opts.samples_format
+ if len(proc.images) > 1:
+ basename = f'{n + i}' if shared.opts.batch_frame_mode else f'{n}'
+ else:
+ basename = ''
+ if output_dir == '':
+ output_dir = shared.opts.outdir_img2img_samples
+ if not save_normally:
+ os.makedirs(output_dir, exist_ok=True)
+ geninfo, items = images.read_info_from_image(image)
+ for k, v in items.items():
+ image.info[k] = v
+ images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename)
+ shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch')
+
+
+def img2img(id_task: str, mode: int,
+ prompt, negative_prompt, prompt_styles,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
+ init_img_inpaint,
+ init_mask_inpaint,
+ steps,
+ sampler_index,
+ mask_blur, mask_alpha,
+ inpainting_fill,
+ full_quality, restore_faces, tiling,
+ n_iter, batch_size,
+ cfg_scale, image_cfg_scale,
+ diffusers_guidance_rescale,
+ sag_scale,
+ refiner_start,
+ clip_skip,
+ denoising_strength,
+ seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
+ selected_scale_tab,
+ height, width,
+ scale_by,
+ resize_mode, resize_name,
+ inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert,
+ img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir,
+ hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry,
+ override_settings_texts,
+ *args): # pylint: disable=unused-argument
+
+ if shared.sd_model is None:
+ shared.log.warning('Model not loaded')
+ return [], '', '', 'Error: model not loaded'
+
+    debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}|mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}')
+
+ if mode == 5:
+ if img2img_batch_files is None or len(img2img_batch_files) == 0:
+            shared.log.debug('Init batch images not set')
+ elif init_img:
+            shared.log.debug('Init image not used in batch mode')
+
+ if sampler_index is None:
+ sampler_index = 0
+
+ override_settings = create_override_settings_dict(override_settings_texts)
+
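+    # mode values handled below: 0=img2img, 1=img2img sketch, 2=inpaint, 3=inpaint sketch,
+    # 4=inpaint upload mask, 5=batch processing (handled via process_batch further down)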
+ if mode == 0: # img2img
+ if init_img is None:
+ return [], '', '', 'Error: init image not provided'
+ image = init_img.convert("RGB")
+ mask = None
+ elif mode == 1: # img2img sketch
+ if sketch is None:
+ return [], '', '', 'Error: sketch image not provided'
+ image = sketch.convert("RGB")
+ mask = None
+ elif mode == 2: # inpaint
+ if init_img_with_mask is None:
+ return [], '', '', 'Error: init image with mask not provided'
+ image = init_img_with_mask["image"]
+ mask = init_img_with_mask["mask"]
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ image = image.convert("RGB")
+ elif mode == 3: # inpaint sketch
+ if inpaint_color_sketch is None:
+ return [], '', '', 'Error: color sketch image not provided'
+ image = inpaint_color_sketch
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+ image = image.convert("RGB")
+ elif mode == 4: # inpaint upload mask
+ if init_img_inpaint is None:
+ return [], '', '', 'Error: inpaint image not provided'
+ image = init_img_inpaint
+ mask = init_mask_inpaint
+ else:
+ shared.log.error(f'Image processing unknown mode: {mode}')
+ image = None
+ mask = None
+ if image is not None:
+ image = ImageOps.exif_transpose(image)
+ if selected_scale_tab == 1 and resize_mode != 0:
+ width = int(image.width * scale_by)
+ height = int(image.height * scale_by)
+
+ p = processing.StableDiffusionProcessingImg2Img(
+ sd_model=shared.sd_model,
+ outpath_samples=shared.opts.outdir_samples or shared.opts.outdir_img2img_samples,
+ outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_img2img_grids,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ styles=prompt_styles,
+ seed=seed,
+ subseed=subseed,
+ subseed_strength=subseed_strength,
+ seed_resize_from_h=seed_resize_from_h,
+ seed_resize_from_w=seed_resize_from_w,
+ seed_enable_extras=True,
+        sampler_name=processing.get_sampler_name(sampler_index, img=True),
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ clip_skip=clip_skip,
+ width=width,
+ height=height,
+ full_quality=full_quality,
+ restore_faces=restore_faces,
+ tiling=tiling,
+ init_images=[image],
+ mask=mask,
+ mask_blur=mask_blur,
+ inpainting_fill=inpainting_fill,
+ resize_mode=resize_mode,
+ resize_name=resize_name,
+ denoising_strength=denoising_strength,
+ image_cfg_scale=image_cfg_scale,
+ diffusers_guidance_rescale=diffusers_guidance_rescale,
+ sag_scale=sag_scale,
+ refiner_start=refiner_start,
+ inpaint_full_res=inpaint_full_res != 0,
+ inpaint_full_res_padding=inpaint_full_res_padding,
+ inpainting_mask_invert=inpainting_mask_invert,
+ hdr_clamp=hdr_clamp, hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold,
+ hdr_center=hdr_center, hdr_channel_shift=hdr_channel_shift, hdr_full_shift=hdr_full_shift,
+ hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry,
+ override_settings=override_settings,
+ )
+ if selected_scale_tab == 1 and resize_mode != 0:
+ p.scale_by = scale_by
+ p.scripts = modules.scripts.scripts_img2img
+ p.script_args = args
+ if mask:
+ p.extra_generation_params["Mask blur"] = mask_blur
+ p.extra_generation_params["Mask alpha"] = mask_alpha
+ p.extra_generation_params["Mask invert"] = inpainting_mask_invert
+ p.extra_generation_params["Mask content"] = inpainting_fill
+ p.extra_generation_params["Mask area"] = inpaint_full_res
+ p.extra_generation_params["Mask padding"] = inpaint_full_res_padding
+ p.is_batch = mode == 5
+ if p.is_batch:
+ process_batch(p, img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
+ processed = processing.Processed(p, [], p.seed, "")
+ else:
+ processed = modules.scripts.scripts_img2img.run(p, *args)
+ if processed is None:
+ processed = processing.process_images(p)
+ p.close()
+ generation_info_js = processed.js() if processed is not None else ''
+ return processed.images, generation_info_js, processed.info, plaintext_to_html(processed.comments)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 5a7b4ca18..ba216a146 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,195 +1,195 @@
-import os
-import sys
-from collections import namedtuple
-from pathlib import Path
-import re
-import torch
-import torch.hub # pylint: disable=ungrouped-imports
-from PIL import Image
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors
-
-
-blip_image_eval_size = 384
-clip_model_name = 'ViT-L/14'
-Category = namedtuple("Category", ["name", "topn", "items"])
-re_topn = re.compile(r"\.top(\d+)\.")
-
-
-def category_types():
- return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
-
-
-def download_default_clip_interrogate_categories(content_dir):
- shared.log.info("Downloading CLIP categories...")
- tmpdir = f"{content_dir}_tmp"
- cat_types = ["artists", "flavors", "mediums", "movements"]
- try:
- os.makedirs(tmpdir, exist_ok=True)
- for category_type in cat_types:
- torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
- os.rename(tmpdir, content_dir)
- except Exception as e:
- errors.display(e, "downloading default CLIP interrogate categories")
- finally:
- if os.path.exists(tmpdir):
- os.removedirs(tmpdir)
-
-
-class InterrogateModels:
- blip_model = None
- clip_model = None
- clip_preprocess = None
- dtype = None
- running_on_cpu = None
-
- def __init__(self, content_dir):
- self.loaded_categories = None
- self.skip_categories = []
- self.content_dir = content_dir
- self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
-
- def categories(self):
- if not os.path.exists(self.content_dir):
- download_default_clip_interrogate_categories(self.content_dir)
- if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
- return self.loaded_categories
- self.loaded_categories = []
-
- if os.path.exists(self.content_dir):
- self.skip_categories = shared.opts.interrogate_clip_skip_categories
- cat_types = []
- for filename in Path(self.content_dir).glob('*.txt'):
- cat_types.append(filename.stem)
- if filename.stem in self.skip_categories:
- continue
- m = re_topn.search(filename.stem)
- topn = 1 if m is None else int(m.group(1))
- with open(filename, "r", encoding="utf8") as file:
- lines = [x.strip() for x in file.readlines()]
- self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
- return self.loaded_categories
-
- def create_fake_fairscale(self):
- class FakeFairscale:
- def checkpoint_wrapper(self):
- pass
- sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
-
- def load_blip_model(self):
- self.create_fake_fairscale()
- import models.blip # pylint: disable=no-name-in-module
- model_path = os.path.join(paths.models_path, "BLIP")
- download_name='model_base_caption_capfilt_large.pth',
- shared.log.debug(f'Model interrogate load: type=BLiP model={download_name} path={model_path}')
- files = modelloader.load_models(
- model_path=model_path,
- model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
- ext_filter=[".pth"],
- download_name=download_name,
- )
- blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
- blip_model.eval()
-
- return blip_model
-
- def load_clip_model(self):
- shared.log.debug(f'Model interrogate load: type=CLiP model={clip_model_name} path={shared.opts.clip_models_path}')
- import clip
- if self.running_on_cpu:
- model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.opts.clip_models_path)
- else:
- model, preprocess = clip.load(clip_model_name, download_root=shared.opts.clip_models_path)
- model.eval()
- model = model.to(devices.device_interrogate)
- return model, preprocess
-
- def load(self):
- if self.blip_model is None:
- self.blip_model = self.load_blip_model()
- if not shared.opts.no_half and not self.running_on_cpu:
- self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(devices.device_interrogate)
- if self.clip_model is None:
- self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.opts.no_half and not self.running_on_cpu:
- self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = next(self.clip_model.parameters()).dtype
-
- def send_clip_to_ram(self):
- if not shared.opts.interrogate_keep_models_in_memory:
- if self.clip_model is not None:
- self.clip_model = self.clip_model.to(devices.cpu)
-
- def send_blip_to_ram(self):
- if not shared.opts.interrogate_keep_models_in_memory:
- if self.blip_model is not None:
- self.blip_model = self.blip_model.to(devices.cpu)
-
- def unload(self):
- self.send_clip_to_ram()
- self.send_blip_to_ram()
- devices.torch_gc()
-
- def rank(self, image_features, text_array, top_count=1):
- import clip
- devices.torch_gc()
- if shared.opts.interrogate_clip_dict_limit != 0:
- text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
- top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
- text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
- text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
- for i in range(image_features.shape[0]):
- similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
- similarity /= image_features.shape[0]
- top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
- return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
-
- def generate_caption(self, pil_image):
- gpu_image = transforms.Compose([
- transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
- with devices.inference_context():
- caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
- return caption[0]
-
- def interrogate(self, pil_image):
- res = ""
- shared.state.begin('interrogate')
- try:
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- devices.torch_gc()
- self.load()
- if isinstance(pil_image, list):
- pil_image = pil_image[0]
- if isinstance(pil_image, dict) and 'name' in pil_image:
- pil_image = Image.open(pil_image['name'])
- pil_image = pil_image.convert("RGB")
- caption = self.generate_caption(pil_image)
- self.send_blip_to_ram()
- devices.torch_gc()
- res = caption
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
- with devices.inference_context(), devices.autocast():
- image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
- image_features /= image_features.norm(dim=-1, keepdim=True)
- for _name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
- for match, score in matches:
- if shared.opts.interrogate_return_ranks:
- res += f", ({match}:{score/100:.3f})"
- else:
- res += f", {match}"
- except Exception as e:
- errors.display(e, 'interrogate')
- res += ""
- self.unload()
- shared.state.end()
- return res
+import os
+import sys
+from collections import namedtuple
+from pathlib import Path
+import re
+import torch
+import torch.hub # pylint: disable=ungrouped-imports
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from modules import devices, paths, shared, lowvram, modelloader, errors
+
+
+blip_image_eval_size = 384
+clip_model_name = 'ViT-L/14'
+Category = namedtuple("Category", ["name", "topn", "items"])
+re_topn = re.compile(r"\.top(\d+)\.")
+
+
+def category_types():
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
+
+def download_default_clip_interrogate_categories(content_dir):
+ shared.log.info("Downloading CLIP categories...")
+ tmpdir = f"{content_dir}_tmp"
+ cat_types = ["artists", "flavors", "mediums", "movements"]
+ try:
+ os.makedirs(tmpdir, exist_ok=True)
+ for category_type in cat_types:
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
+ os.rename(tmpdir, content_dir)
+ except Exception as e:
+ errors.display(e, "downloading default CLIP interrogate categories")
+ finally:
+ if os.path.exists(tmpdir):
+ os.removedirs(tmpdir)
+
+
+class InterrogateModels:
+ blip_model = None
+ clip_model = None
+ clip_preprocess = None
+ dtype = None
+ running_on_cpu = None
+
+ def __init__(self, content_dir):
+ self.loaded_categories = None
+ self.skip_categories = []
+ self.content_dir = content_dir
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
+
+ def categories(self):
+ if not os.path.exists(self.content_dir):
+ download_default_clip_interrogate_categories(self.content_dir)
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
+ return self.loaded_categories
+ self.loaded_categories = []
+
+ if os.path.exists(self.content_dir):
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
+ cat_types = []
+ for filename in Path(self.content_dir).glob('*.txt'):
+ cat_types.append(filename.stem)
+ if filename.stem in self.skip_categories:
+ continue
+ m = re_topn.search(filename.stem)
+ topn = 1 if m is None else int(m.group(1))
+ with open(filename, "r", encoding="utf8") as file:
+ lines = [x.strip() for x in file.readlines()]
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
+ return self.loaded_categories
+
+ def create_fake_fairscale(self):
+ class FakeFairscale:
+ def checkpoint_wrapper(self):
+ pass
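+        # registering the stub below lets models.blip import checkpoint_wrapper without the real
+        # fairscale package being installed (assumed intent of this shim)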
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
+
+ def load_blip_model(self):
+ self.create_fake_fairscale()
+ import models.blip # pylint: disable=no-name-in-module
+ model_path = os.path.join(paths.models_path, "BLIP")
+        download_name = 'model_base_caption_capfilt_large.pth'
+ shared.log.debug(f'Model interrogate load: type=BLiP model={download_name} path={model_path}')
+ files = modelloader.load_models(
+ model_path=model_path,
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+ ext_filter=[".pth"],
+ download_name=download_name,
+ )
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) # pylint: disable=c-extension-no-member
+ blip_model.eval()
+
+ return blip_model
+
+ def load_clip_model(self):
+ shared.log.debug(f'Model interrogate load: type=CLiP model={clip_model_name} path={shared.opts.clip_models_path}')
+ import clip
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.opts.clip_models_path)
+ else:
+ model, preprocess = clip.load(clip_model_name, download_root=shared.opts.clip_models_path)
+ model.eval()
+ model = model.to(devices.device_interrogate)
+ return model, preprocess
+
+ def load(self):
+ if self.blip_model is None:
+ self.blip_model = self.load_blip_model()
+ if not shared.opts.no_half and not self.running_on_cpu:
+ self.blip_model = self.blip_model.half()
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
+ if self.clip_model is None:
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
+ if not shared.opts.no_half and not self.running_on_cpu:
+ self.clip_model = self.clip_model.half()
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
+ self.dtype = next(self.clip_model.parameters()).dtype
+
+ def send_clip_to_ram(self):
+ if not shared.opts.interrogate_keep_models_in_memory:
+ if self.clip_model is not None:
+ self.clip_model = self.clip_model.to(devices.cpu)
+
+ def send_blip_to_ram(self):
+ if not shared.opts.interrogate_keep_models_in_memory:
+ if self.blip_model is not None:
+ self.blip_model = self.blip_model.to(devices.cpu)
+
+ def unload(self):
+ self.send_clip_to_ram()
+ self.send_blip_to_ram()
+ devices.torch_gc()
+
+ def rank(self, image_features, text_array, top_count=1):
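+        # rank candidate strings by CLIP similarity: encode and normalize the texts, compare them
+        # to the already-normalized image features, and return the top_count (label, score*100) pairs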
+ import clip
+ devices.torch_gc()
+ if shared.opts.interrogate_clip_dict_limit != 0:
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
+ top_count = min(top_count, len(text_array))
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
+ for i in range(image_features.shape[0]):
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
+ similarity /= image_features.shape[0]
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
+
+ def generate_caption(self, pil_image):
+ gpu_image = transforms.Compose([
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
+ with devices.inference_context():
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
+ return caption[0]
+
+ def interrogate(self, pil_image):
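+        # overall flow: generate a BLIP caption first, then rank each loaded category's terms
+        # against the CLIP image embedding and append the best matches to the caption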
+ res = ""
+ shared.state.begin('interrogate')
+ try:
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ devices.torch_gc()
+ self.load()
+ if isinstance(pil_image, list):
+ pil_image = pil_image[0]
+ if isinstance(pil_image, dict) and 'name' in pil_image:
+ pil_image = Image.open(pil_image['name'])
+ pil_image = pil_image.convert("RGB")
+ caption = self.generate_caption(pil_image)
+ self.send_blip_to_ram()
+ devices.torch_gc()
+ res = caption
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
+ with devices.inference_context(), devices.autocast():
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+ for _name, topn, items in self.categories():
+ matches = self.rank(image_features, items, top_count=topn)
+ for match, score in matches:
+ if shared.opts.interrogate_return_ranks:
+ res += f", ({match}:{score/100:.3f})"
+ else:
+ res += f", {match}"
+ except Exception as e:
+ errors.display(e, 'interrogate')
+ res += ""
+ self.unload()
+ shared.state.end()
+ return res
diff --git a/modules/localization.py b/modules/localization.py
index d18d5137b..a385ccbd7 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,38 +1,38 @@
-import json
-import sys
-import modules.errors as errors
-
-
-localizations = {}
-
-
-def list_localizations(dirname): # pylint: disable=unused-argument
- localizations.clear()
- """
- for file in os.listdir(dirname):
- fn, ext = os.path.splitext(file)
- if ext.lower() != ".json":
- continue
-
- localizations[fn] = os.path.join(dirname, file)
-
- from modules import scripts
- for file in scripts.list_scripts("localizations", ".json"):
- fn, ext = os.path.splitext(file.filename)
- localizations[fn] = file.path
- """
- return localizations
-
-
-def localization_js(current_localization_name):
- fn = localizations.get(current_localization_name, None)
- data = {}
- if fn is not None:
- try:
- with open(fn, "r", encoding="utf8") as file:
- data = json.load(file)
- except Exception as e:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- errors.display(e, 'localization')
-
- return f"var localization = {json.dumps(data)}\n"
+import json
+import sys
+import modules.errors as errors
+
+
+localizations = {}
+
+
+def list_localizations(dirname): # pylint: disable=unused-argument
+ localizations.clear()
+ """
+ for file in os.listdir(dirname):
+ fn, ext = os.path.splitext(file)
+ if ext.lower() != ".json":
+ continue
+
+ localizations[fn] = os.path.join(dirname, file)
+
+ from modules import scripts
+ for file in scripts.list_scripts("localizations", ".json"):
+ fn, ext = os.path.splitext(file.filename)
+ localizations[fn] = file.path
+ """
+ return localizations
+
+
+def localization_js(current_localization_name):
+ fn = localizations.get(current_localization_name, None)
+ data = {}
+ if fn is not None:
+ try:
+ with open(fn, "r", encoding="utf8") as file:
+ data = json.load(file)
+ except Exception as e:
+ print(f"Error loading localization from {fn}:", file=sys.stderr)
+ errors.display(e, 'localization')
+
+ return f"var localization = {json.dumps(data)}\n"
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 5afb7e7c8..143893ce6 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,98 +1,98 @@
-import torch
-from modules import devices
-
-module_in_gpu = None
-cpu = torch.device("cpu")
-
-
-def send_everything_to_cpu():
- global module_in_gpu # pylint: disable=global-statement
-
- if module_in_gpu is not None:
- module_in_gpu.to(cpu)
-
- module_in_gpu = None
-
-
-def setup_for_low_vram(sd_model, use_medvram):
- parents = {}
-
- def send_me_to_gpu(module, _):
- """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
- we add this as forward_pre_hook to a lot of modules and this way all but one of them will
- be in CPU
- """
- global module_in_gpu # pylint: disable=global-statement
-
- module = parents.get(module, module)
-
- if module_in_gpu == module:
- return
-
- if module_in_gpu is not None:
- module_in_gpu.to(cpu)
-
- module.to(devices.device)
- module_in_gpu = module
-
- # see below for register_forward_pre_hook;
- # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
- # useless here, and we just replace those methods
-
- first_stage_model = sd_model.first_stage_model
- first_stage_model_encode = sd_model.first_stage_model.encode
- first_stage_model_decode = sd_model.first_stage_model.decode
-
- def first_stage_model_encode_wrap(x):
- send_me_to_gpu(first_stage_model, None)
- return first_stage_model_encode(x)
-
- def first_stage_model_decode_wrap(z):
- send_me_to_gpu(first_stage_model, None)
- return first_stage_model_decode(z)
-
- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
- if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
- # send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
- sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
-
- # register hooks for those the first three models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = first_stage_model_encode_wrap
- sd_model.first_stage_model.decode = first_stage_model_decode_wrap
- if sd_model.depth_model:
- sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
- if sd_model.embedder:
- sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
- parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
-
- if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
-
- if use_medvram:
- sd_model.model.register_forward_pre_hook(send_me_to_gpu)
- else:
- diff_model = sd_model.model.diffusion_model
-
- # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
- # so that only one of them is in GPU at a time
- stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
- diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(devices.device)
- diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
-
- # install hooks for bits of third model
- diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
- for block in diff_model.input_blocks:
- block.register_forward_pre_hook(send_me_to_gpu)
- diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
- for block in diff_model.output_blocks:
- block.register_forward_pre_hook(send_me_to_gpu)
+import torch
+from modules import devices
+
+module_in_gpu = None
+cpu = torch.device("cpu")
+
+
+def send_everything_to_cpu():
+ global module_in_gpu # pylint: disable=global-statement
+
+ if module_in_gpu is not None:
+ module_in_gpu.to(cpu)
+
+ module_in_gpu = None
+
+
+def setup_for_low_vram(sd_model, use_medvram):
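+    # overall approach: keep at most one large sub-module on the GPU; the forward_pre_hooks
+    # registered below move the module about to run onto the GPU and evict the previous one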
+ parents = {}
+
+ def send_me_to_gpu(module, _):
+        """send this module to GPU; send whatever tracked module was previously in GPU back to CPU;
+        we add this as a forward_pre_hook to a lot of modules, and this way all but one of them will
+        be on CPU
+ """
+ global module_in_gpu # pylint: disable=global-statement
+
+ module = parents.get(module, module)
+
+ if module_in_gpu == module:
+ return
+
+ if module_in_gpu is not None:
+ module_in_gpu.to(cpu)
+
+ module.to(devices.device)
+ module_in_gpu = module
+
+ # see below for register_forward_pre_hook;
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
+ # useless here, and we just replace those methods
+
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
+
+    # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip and it's in the model field
+ if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'):
+ sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
+    # send the model to GPU. Then put the modules back. The modules will remain on CPU.
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+ sd_model.to(devices.device)
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+    # register hooks for the first three models
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
+ if sd_model.depth_model:
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
+ if sd_model.embedder:
+ sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
+
+ if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'):
+ sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
+ del sd_model.cond_stage_model.transformer
+
+ if use_medvram:
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ diff_model = sd_model.model.diffusion_model
+
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
+ # so that only one of them is in GPU at a time
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
+ sd_model.model.to(devices.device)
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
+
+ # install hooks for bits of third model
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.input_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.output_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
diff --git a/modules/masking.py b/modules/masking.py
index ce1de4c22..6c19276c8 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -1,85 +1,85 @@
-from PIL import Image, ImageFilter, ImageOps
-
-
-def get_crop_region(mask, pad=0):
- """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
- For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
- h, w = mask.shape
- crop_left = 0
- for i in range(w):
- if not (mask[:, i] == 0).all():
- break
- crop_left += 1
- crop_right = 0
- for i in reversed(range(w)):
- if not (mask[:, i] == 0).all():
- break
- crop_right += 1
- crop_top = 0
- for i in range(h):
- if not (mask[i] == 0).all():
- break
- crop_top += 1
- crop_bottom = 0
- for i in reversed(range(h)):
- if not (mask[i] == 0).all():
- break
- crop_bottom += 1
- return (
- int(max(crop_left-pad, 0)),
- int(max(crop_top-pad, 0)),
- int(min(w - crop_right + pad, w)),
- int(min(h - crop_bottom + pad, h))
- )
-
-
-def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
- """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
- for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
- x1, y1, x2, y2 = crop_region
- ratio_crop_region = (x2 - x1) / (y2 - y1)
- ratio_processing = processing_width / processing_height
-
- if ratio_crop_region > ratio_processing:
- desired_height = (x2 - x1) / ratio_processing
- desired_height_diff = int(desired_height - (y2-y1))
- y1 -= desired_height_diff//2
- y2 += desired_height_diff - desired_height_diff//2
- if y2 >= image_height:
- diff = y2 - image_height
- y2 -= diff
- y1 -= diff
- if y1 < 0:
- y2 -= y1
- y1 -= y1
- if y2 >= image_height:
- y2 = image_height
- else:
- desired_width = (y2 - y1) * ratio_processing
- desired_width_diff = int(desired_width - (x2-x1))
- x1 -= desired_width_diff//2
- x2 += desired_width_diff - desired_width_diff//2
- if x2 >= image_width:
- diff = x2 - image_width
- x2 -= diff
- x1 -= diff
- if x1 < 0:
- x2 -= x1
- x1 -= x1
- if x2 >= image_width:
- x2 = image_width
-
- return x1, y1, x2, y2
-
-
-def fill(image, mask):
- """fills masked regions with colors from image using blur. Not extremely effective."""
- image_mod = Image.new('RGBA', (image.width, image.height))
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
- image_masked = image_masked.convert('RGBa')
- for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
- blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
- for _ in range(repeats):
- image_mod.alpha_composite(blurred)
- return image_mod.convert("RGB")
+from PIL import Image, ImageFilter, ImageOps
+
+
+def get_crop_region(mask, pad=0):
+    """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
+    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
+ h, w = mask.shape
+ crop_left = 0
+ for i in range(w):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_left += 1
+ crop_right = 0
+ for i in reversed(range(w)):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_right += 1
+ crop_top = 0
+ for i in range(h):
+ if not (mask[i] == 0).all():
+ break
+ crop_top += 1
+ crop_bottom = 0
+ for i in reversed(range(h)):
+ if not (mask[i] == 0).all():
+ break
+ crop_bottom += 1
+ return (
+ int(max(crop_left-pad, 0)),
+ int(max(crop_top-pad, 0)),
+ int(min(w - crop_right + pad, w)),
+ int(min(h - crop_bottom + pad, h))
+ )
+
+
+def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
+    """expands the crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region.
+    For example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
+ x1, y1, x2, y2 = crop_region
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
+ ratio_processing = processing_width / processing_height
+
+ if ratio_crop_region > ratio_processing:
+ desired_height = (x2 - x1) / ratio_processing
+ desired_height_diff = int(desired_height - (y2-y1))
+ y1 -= desired_height_diff//2
+ y2 += desired_height_diff - desired_height_diff//2
+ if y2 >= image_height:
+ diff = y2 - image_height
+ y2 -= diff
+ y1 -= diff
+ if y1 < 0:
+ y2 -= y1
+ y1 -= y1
+ if y2 >= image_height:
+ y2 = image_height
+ else:
+ desired_width = (y2 - y1) * ratio_processing
+ desired_width_diff = int(desired_width - (x2-x1))
+ x1 -= desired_width_diff//2
+ x2 += desired_width_diff - desired_width_diff//2
+ if x2 >= image_width:
+ diff = x2 - image_width
+ x2 -= diff
+ x1 -= diff
+ if x1 < 0:
+ x2 -= x1
+ x1 -= x1
+ if x2 >= image_width:
+ x2 = image_width
+
+ return x1, y1, x2, y2
+
+
+def fill(image, mask):
+ """fills masked regions with colors from image using blur. Not extremely effective."""
+ image_mod = Image.new('RGBA', (image.width, image.height))
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
+ image_masked = image_masked.convert('RGBa')
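+    # composite blurs from large to small radius: big radii pull colors across large masked
+    # holes first, then the smaller radii (and their repeats) refine detail near the mask edges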
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
+ for _ in range(repeats):
+ image_mod.alpha_composite(blurred)
+ return image_mod.convert("RGB")
diff --git a/modules/patches.py b/modules/patches.py
index cff6bfd64..348759ab7 100644
--- a/modules/patches.py
+++ b/modules/patches.py
@@ -1,63 +1,63 @@
-from collections import defaultdict
-
-
-def patch(key, obj, field, replacement):
- """Replaces a function in a module or a class.
-
- Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
- If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
-
- Arguments:
- key: identifying information for who is doing the replacement. You can use __name__.
- obj: the module or the class
- field: name of the function as a string
- replacement: the new function
-
- Returns:
- the original function
- """
-
- patch_key = (obj, field)
- if patch_key in originals[key]:
- raise RuntimeError(f"patch for {field} is already applied")
-
- original_func = getattr(obj, field)
- originals[key][patch_key] = original_func
-
- setattr(obj, field, replacement)
-
- return original_func
-
-
-def undo(key, obj, field):
- """Undoes the peplacement by the patch().
-
- If the function is not replaced, raises an exception.
-
- Arguments:
- key: identifying information for who is doing the replacement. You can use __name__.
- obj: the module or the class
- field: name of the function as a string
-
- Returns:
- Always None
- """
-
- patch_key = (obj, field)
-
- if patch_key not in originals[key]:
- raise RuntimeError(f"there is no patch for {field} to undo")
-
- original_func = originals[key].pop(patch_key)
- setattr(obj, field, original_func)
- return None
-
-
-def original(key, obj, field):
- """Returns the original function for the patch created by the patch() function"""
- patch_key = (obj, field)
-
- return originals[key].get(patch_key, None)
-
-
-originals = defaultdict(dict)
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+    Also stores the original function in this module, so it can be retrieved later via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+    """Undoes the replacement made by patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
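+
+# Illustrative usage (not part of this module), assuming a caller identified by __name__:
+#   import json
+#   patch(__name__, json, 'dumps', lambda *a, **kw: '{}')  # swap json.dumps for a stub
+#   assert original(__name__, json, 'dumps') is not None   # the real json.dumps is kept here
+#   undo(__name__, json, 'dumps')                          # restore json.dumps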
diff --git a/modules/paths.py b/modules/paths.py
index 3166005b6..7295f6575 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,130 +1,130 @@
-# this module must not have any dependencies as it is a very first import before webui even starts
-import os
-import sys
-import json
-import argparse
-from modules.errors import log
-
-
-# parse args, parse again after we have the data-dir and early-read the config file
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
-parser.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
-parser.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", None), help="Base path where all models are stored, default: %(default)s",)
-cli = parser.parse_known_args()[0]
-parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s")
-cli = parser.parse_known_args()[0]
-config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
-try:
- with open(config_path, 'r', encoding='utf8') as f:
- config = json.load(f)
-except Exception:
- config = {}
-
-modules_path = os.path.dirname(os.path.realpath(__file__))
-script_path = os.path.dirname(modules_path)
-data_path = cli.data_dir
-models_config = cli.models_dir or config.get('models_dir') or 'models'
-models_path = models_config if os.path.isabs(models_config) else os.path.join(data_path, models_config)
-extensions_dir = os.path.join(data_path, "extensions")
-extensions_builtin_dir = "extensions-builtin"
-sd_configs_path = os.path.join(script_path, "configs")
-sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
-sd_model_file = cli.ckpt or os.path.join(script_path, 'model.ckpt') # not used
-default_sd_model_file = sd_model_file # not used
-debug = log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PATH')
-paths = {}
-
-if os.environ.get('SD_PATH_DEBUG', None) is not None:
- print(f'Paths: script-path="{script_path}" data-dir="{data_path}" models-dir="{models_path}" config="{config_path}"')
-
-
-def register_paths():
- log.debug('Register paths')
- sys.path.insert(0, script_path)
- sd_path = os.path.join(script_path, 'repositories')
- path_dirs = [
- (sd_path, 'ldm', 'ldm', []),
- (sd_path, 'taming', 'Taming Transformers', []),
- (os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
- (os.path.join(modules_path, 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
- ]
- for d, must_exist, what, _options in path_dirs:
- must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
- if not os.path.exists(must_exist_path):
- log.error(f'Required path not found: path={must_exist_path} item={what}')
- else:
- d = os.path.abspath(d)
- sys.path.append(d)
- paths[what] = d
-
-
-def create_path(folder):
- if folder is None or folder == '':
- return
- if os.path.exists(folder):
- return
- try:
- os.makedirs(folder, exist_ok=True)
- log.info(f'Create: folder="{folder}"')
- except Exception as e:
- log.error(f'Create failed: folder="{folder}" {e}')
-
-
-def create_paths(opts):
- def fix_path(folder):
- tgt = opts.data.get(folder, None) or opts.data_labels[folder].default
- if tgt is None or tgt == '':
- return tgt
- fix = tgt
- if not os.path.isabs(tgt) and len(data_path) > 0 and not tgt.startswith(data_path): # path is already relative to data_path
- fix = os.path.join(data_path, fix)
- if fix.startswith('..'):
- fix = os.path.abspath(fix)
- fix = fix if os.path.isabs(fix) else os.path.relpath(fix, script_path)
- opts.data[folder] = fix
- debug(f'Paths: folder="{folder}" original="{tgt}" target="{fix}"')
- return opts.data[folder]
-
- create_path(data_path)
- create_path(script_path)
- create_path(models_path)
- create_path(sd_configs_path)
- create_path(extensions_dir)
- create_path(extensions_builtin_dir)
- create_path(fix_path('temp_dir'))
- create_path(fix_path('ckpt_dir'))
- create_path(fix_path('diffusers_dir'))
- create_path(fix_path('vae_dir'))
- create_path(fix_path('lora_dir'))
- create_path(fix_path('embeddings_dir'))
- create_path(fix_path('hypernetwork_dir'))
- create_path(fix_path('outdir_samples'))
- create_path(fix_path('outdir_txt2img_samples'))
- create_path(fix_path('outdir_img2img_samples'))
- create_path(fix_path('outdir_control_samples'))
- create_path(fix_path('outdir_extras_samples'))
- create_path(fix_path('outdir_init_images'))
- create_path(fix_path('outdir_grids'))
- create_path(fix_path('outdir_txt2img_grids'))
- create_path(fix_path('outdir_img2img_grids'))
- create_path(fix_path('outdir_control_grids'))
- create_path(fix_path('outdir_save'))
- create_path(fix_path('outdir_video'))
- create_path(fix_path('styles_dir'))
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
+# this module must not have any dependencies, as it is the very first import before the webui even starts
+import os
+import sys
+import json
+import argparse
+from modules.errors import log
+
+
+# parse args, then parse again once the data-dir is known, and early-read the config file
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
+parser.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
+parser.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", None), help="Base path where all models are stored, default: %(default)s",)
+cli = parser.parse_known_args()[0]
+parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s")
+cli = parser.parse_known_args()[0]
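+# illustrative example (hypothetical values, assuming no SD_* env overrides): `--data-dir /data` makes
+# --config default to /data/config.json and, with the default relative models dir, models_path to /data/models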
+config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
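+# read the config early; a missing or unreadable file falls back to an empty dict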
+try:
+ with open(config_path, 'r', encoding='utf8') as f:
+ config = json.load(f)
+except Exception:
+ config = {}
+
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)
+data_path = cli.data_dir
+models_config = cli.models_dir or config.get('models_dir') or 'models'
+models_path = models_config if os.path.isabs(models_config) else os.path.join(data_path, models_config)
+extensions_dir = os.path.join(data_path, "extensions")
+extensions_builtin_dir = "extensions-builtin"
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
+sd_model_file = cli.ckpt or os.path.join(script_path, 'model.ckpt') # not used
+default_sd_model_file = sd_model_file # not used
+debug = log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PATH')
+paths = {}
+
+if os.environ.get('SD_PATH_DEBUG', None) is not None:
+ print(f'Paths: script-path="{script_path}" data-dir="{data_path}" models-dir="{models_path}" config="{config_path}"')
+
+
+def register_paths():
+ log.debug('Register paths')
+ sys.path.insert(0, script_path)
+ sd_path = os.path.join(script_path, 'repositories')
+ path_dirs = [
+ (sd_path, 'ldm', 'ldm', []),
+ (sd_path, 'taming', 'Taming Transformers', []),
+ (os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []),
+ (os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
+ (os.path.join(modules_path, 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
+ ]
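+    # each entry: (directory, file that must exist there, name to register under, options)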
+ for d, must_exist, what, _options in path_dirs:
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
+ if not os.path.exists(must_exist_path):
+ log.error(f'Required path not found: path={must_exist_path} item={what}')
+ else:
+ d = os.path.abspath(d)
+ sys.path.append(d)
+ paths[what] = d
+
+
+def create_path(folder):
+ if folder is None or folder == '':
+ return
+ if os.path.exists(folder):
+ return
+ try:
+ os.makedirs(folder, exist_ok=True)
+ log.info(f'Create: folder="{folder}"')
+ except Exception as e:
+ log.error(f'Create failed: folder="{folder}" {e}')
+
+
+def create_paths(opts):
+ def fix_path(folder):
+ tgt = opts.data.get(folder, None) or opts.data_labels[folder].default
+ if tgt is None or tgt == '':
+ return tgt
+ fix = tgt
+ if not os.path.isabs(tgt) and len(data_path) > 0 and not tgt.startswith(data_path): # path is already relative to data_path
+ fix = os.path.join(data_path, fix)
+ if fix.startswith('..'):
+ fix = os.path.abspath(fix)
+ fix = fix if os.path.isabs(fix) else os.path.relpath(fix, script_path)
+ opts.data[folder] = fix
+ debug(f'Paths: folder="{folder}" original="{tgt}" target="{fix}"')
+ return opts.data[folder]
+
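+    # make sure every configured folder exists before the server starts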
+ create_path(data_path)
+ create_path(script_path)
+ create_path(models_path)
+ create_path(sd_configs_path)
+ create_path(extensions_dir)
+ create_path(extensions_builtin_dir)
+ create_path(fix_path('temp_dir'))
+ create_path(fix_path('ckpt_dir'))
+ create_path(fix_path('diffusers_dir'))
+ create_path(fix_path('vae_dir'))
+ create_path(fix_path('lora_dir'))
+ create_path(fix_path('embeddings_dir'))
+ create_path(fix_path('hypernetwork_dir'))
+ create_path(fix_path('outdir_samples'))
+ create_path(fix_path('outdir_txt2img_samples'))
+ create_path(fix_path('outdir_img2img_samples'))
+ create_path(fix_path('outdir_control_samples'))
+ create_path(fix_path('outdir_extras_samples'))
+ create_path(fix_path('outdir_init_images'))
+ create_path(fix_path('outdir_grids'))
+ create_path(fix_path('outdir_txt2img_grids'))
+ create_path(fix_path('outdir_img2img_grids'))
+ create_path(fix_path('outdir_control_grids'))
+ create_path(fix_path('outdir_save'))
+ create_path(fix_path('outdir_video'))
+ create_path(fix_path('styles_dir'))
+
+
+class Prioritize:
+ def __init__(self, name):
+ self.name = name
+ self.path = None
+
+ def __enter__(self):
+ self.path = sys.path.copy()
+ sys.path = [paths[self.name]] + sys.path
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.path = self.path
+ self.path = None
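+
+
+# usage sketch (hypothetical, not called in this module): temporarily put a registered repository first on sys.path
+#   with Prioritize('BLIP'):
+#       import models.blip  # resolved against the path registered for 'BLIP' in register_paths()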
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 49a055caa..3a408329d 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -1,30 +1,30 @@
-# no longer used, all paths are defined in paths.py
-
-from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import
-
-"""
-import argparse
-import os
-
-modules_path = os.path.dirname(os.path.realpath(__file__))
-script_path = os.path.dirname(modules_path)
-sd_configs_path = os.path.join(script_path, "configs")
-sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
-
-# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
-parser_pre.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
-parser_pre.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
-cmd_opts_pre = parser_pre.parse_known_args()[0]
-
-# parser_pre.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")
-
-data_path = cmd_opts_pre.data_dir
-models_path = cmd_opts_pre.models_dir if os.path.isabs(cmd_opts_pre.models_dir) else os.path.join(data_path, cmd_opts_pre.models_dir)
-extensions_dir = os.path.join(data_path, "extensions")
-extensions_builtin_dir = "extensions-builtin"
-
-sd_model_file = cmd_opts_pre.ckpt or os.path.join(script_path, 'model.ckpt') # not used
-default_sd_model_file = sd_model_file # not used
-"""
+# no longer used; all paths are now defined in paths.py
+
+from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import
+
+"""
+import argparse
+import os
+
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser_pre = argparse.ArgumentParser(add_help=False)
+parser_pre.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
+parser_pre.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
+parser_pre.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
+cmd_opts_pre = parser_pre.parse_known_args()[0]
+
+# parser_pre.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")
+
+data_path = cmd_opts_pre.data_dir
+models_path = cmd_opts_pre.models_dir if os.path.isabs(cmd_opts_pre.models_dir) else os.path.join(data_path, cmd_opts_pre.models_dir)
+extensions_dir = os.path.join(data_path, "extensions")
+extensions_builtin_dir = "extensions-builtin"
+
+sd_model_file = cmd_opts_pre.ckpt or os.path.join(script_path, 'model.ckpt') # not used
+default_sd_model_file = sd_model_file # not used
+"""
diff --git a/modules/postprocess/codeformer_model.py b/modules/postprocess/codeformer_model.py
index fcf938081..59a990ce2 100644
--- a/modules/postprocess/codeformer_model.py
+++ b/modules/postprocess/codeformer_model.py
@@ -1,111 +1,111 @@
-import os
-import cv2
-import torch
-import modules.face_restoration
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
-
-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-have_codeformer = False
-codeformer = None
-
-
-def setup_model(dirname):
- if not os.path.exists(model_path):
- os.makedirs(model_path)
- path = modules.paths.paths.get("CodeFormer", None)
- if path is None:
- return
-
- try:
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
- def name(self):
- return "CodeFormer"
-
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
-
- def create_models(self):
- from modules.postprocess.codeformer_arch import CodeFormer
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
- from facelib.detection.retinaface import retinaface
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- shared.log.error(f"Model failed loading: type=CodeFormer model={model_path}")
- return None, None
- net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
- shared.log.info(f"Model loaded: type=CodeFormer model={ckpt_path}")
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
- self.net = net
- self.face_helper = face_helper
- return net, face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device) # pylint: disable=no-member
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- from torchvision.transforms.functional import normalize
- from basicsr.utils import img2tensor, tensor2img
- np_image = np_image[:, :, ::-1]
- original_resolution = np_image.shape[0:2]
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
- self.send_model_to(devices.device_codeformer)
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
- try:
- with devices.inference_context():
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] # pylint: disable=not-callable
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception as e:
- shared.log.error(f'CodeForomer error: {e}')
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
- self.face_helper.get_inverse_affine(None)
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
- self.face_helper.clean_all()
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
- return restored_img
-
- global have_codeformer # pylint: disable=global-statement
- have_codeformer = True
- global codeformer # pylint: disable=global-statement
- codeformer = FaceRestorerCodeFormer(dirname)
- shared.face_restorers.append(codeformer)
-
- except Exception as e:
- errors.display(e, 'codeformer')
+import os
+import cv2
+import torch
+import modules.face_restoration
+from modules import shared, devices, modelloader, errors
+from modules.paths import models_path
+
+# the codeformer authors chose to bundle a modified basicsr library with their project, which makes
+# it impossible to use alongside other libraries that also depend on basicsr, such as GFPGAN.
+# as a workaround, some files from codeformer are included here directly.
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+
+have_codeformer = False
+codeformer = None
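+# module-level state: setup_model() populates these and registers the restorer with shared.face_restorers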
+
+
+def setup_model(dirname):
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ path = modules.paths.paths.get("CodeFormer", None)
+ if path is None:
+ return
+
+ try:
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "CodeFormer"
+
+ def __init__(self, dirname):
+ self.net = None
+ self.face_helper = None
+ self.cmd_dir = dirname
+
+ def create_models(self):
+ from modules.postprocess.codeformer_arch import CodeFormer
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
+ from facelib.detection.retinaface import retinaface
+ if self.net is not None and self.face_helper is not None:
+ self.net.to(devices.device_codeformer)
+ return self.net, self.face_helper
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
+ if len(model_paths) != 0:
+ ckpt_path = model_paths[0]
+ else:
+ shared.log.error(f"Model failed loading: type=CodeFormer model={model_path}")
+ return None, None
+ net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
+ checkpoint = torch.load(ckpt_path)['params_ema']
+ net.load_state_dict(checkpoint)
+ net.eval()
+ shared.log.info(f"Model loaded: type=CodeFormer model={ckpt_path}")
+ if hasattr(retinaface, 'device'):
+ retinaface.device = devices.device_codeformer
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
+ self.net = net
+ self.face_helper = face_helper
+ return net, face_helper
+
+ def send_model_to(self, device):
+ self.net.to(device)
+ self.face_helper.face_det.to(device) # pylint: disable=no-member
+ self.face_helper.face_parse.to(device)
+
+ def restore(self, np_image, w=None):
+ from torchvision.transforms.functional import normalize
+ from basicsr.utils import img2tensor, tensor2img
+ np_image = np_image[:, :, ::-1]
+ original_resolution = np_image.shape[0:2]
+ self.create_models()
+ if self.net is None or self.face_helper is None:
+ return np_image
+ self.send_model_to(devices.device_codeformer)
+ self.face_helper.clean_all()
+ self.face_helper.read_image(np_image)
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ self.face_helper.align_warp_face()
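+                # restore each aligned face crop, then paste the results back into the full image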
+ for cropped_face in self.face_helper.cropped_faces:
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+ try:
+ with devices.inference_context():
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] # pylint: disable=not-callable
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ devices.torch_gc()
+ except Exception as e:
+                        shared.log.error(f'CodeFormer error: {e}')
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+ restored_face = restored_face.astype('uint8')
+ self.face_helper.add_restored_face(restored_face)
+ self.face_helper.get_inverse_affine(None)
+ restored_img = self.face_helper.paste_faces_to_input_image()
+ restored_img = restored_img[:, :, ::-1]
+ if original_resolution != restored_img.shape[0:2]:
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ self.face_helper.clean_all()
+ if shared.opts.face_restoration_unload:
+ self.send_model_to(devices.cpu)
+ return restored_img
+
+ global have_codeformer # pylint: disable=global-statement
+ have_codeformer = True
+ global codeformer # pylint: disable=global-statement
+ codeformer = FaceRestorerCodeFormer(dirname)
+ shared.face_restorers.append(codeformer)
+
+ except Exception as e:
+ errors.display(e, 'codeformer')
diff --git a/modules/postprocess/esrgan_model.py b/modules/postprocess/esrgan_model.py
index 29a4cc8a2..c4d3e4ad4 100644
--- a/modules/postprocess/esrgan_model.py
+++ b/modules/postprocess/esrgan_model.py
@@ -1,217 +1,217 @@
-import numpy as np
-import torch
-from PIL import Image
-from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
-import modules.postprocess.esrgan_model_arch as arch
-from modules import images, devices
-from modules.upscaler import Upscaler, UpscalerData, compile_upscaler
-from modules.shared import opts, log, console
-
-
-def mod2normal(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- if 'conv_first.weight' in state_dict:
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
- state_dict = crt_net
- return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
- # this code is copied from https://github.com/victorca25/iNNfer
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
- re8x = 0
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if "rdb" in k:
- ori_k = k.replace('body.', 'model.1.sub.')
- ori_k = ori_k.replace('.rdb', '.RDB')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
- if 'conv_up3.weight' in state_dict:
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
- re8x = 3
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-
- state_dict = crt_net
- return state_dict
-
-
-def infer_params(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- scale2x = 0
- scalemin = 6
- n_uplayer = 0
- plus = False
-
- for block in list(state_dict):
- parts = block.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- elif n_parts == 3:
- part_num = int(parts[1])
- if (part_num > scalemin
- and parts[0] == "model"
- and parts[2] == "weight"):
- scale2x += 1
- if part_num > n_uplayer:
- n_uplayer = part_num
- out_nc = state_dict[block].shape[0]
- if not plus and "conv1x1" in block:
- plus = True
-
- nf = state_dict["model.0.weight"].shape[0]
- in_nc = state_dict["model.0.weight"].shape[1]
- # out_nc = out_nc
- scale = 2 ** scale2x
-
- return in_nc, out_nc, nf, nb, plus, scale
-
-
-class UpscalerESRGAN(Upscaler):
- def __init__(self, dirname):
- self.name = "ESRGAN"
- self.user_path = dirname
- super().__init__()
- self.scalers = self.find_scalers()
- self.models = {}
-
- def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
- return img
- model.to(devices.device_esrgan)
- img = esrgan_upscale(model, img)
- if opts.upscaler_unload and selected_model in self.models:
- del self.models[selected_model]
- log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
- devices.torch_gc(force=True)
- return img
-
- def load_model(self, path: str):
- info: UpscalerData = self.find_model(path)
- if info is None:
- return
- if self.models.get(info.local_data_path, None) is not None:
- log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
- return self.models[info.local_data_path]
- state_dict = torch.load(info.local_data_path, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
- log.info(f"Upscaler loaded: type={self.name} model={info.local_data_path}")
-
- if "params_ema" in state_dict:
- state_dict = state_dict["params_ema"]
- elif "params" in state_dict:
- state_dict = state_dict["params"]
- num_conv = 16 if "realesr-animevideov3" in info.local_data_path else 32
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
- model.load_state_dict(state_dict)
- model.eval()
- model = compile_upscaler(model, name=self.name)
- self.models[info.local_data_path] = model
- return self.models[info.local_data_path]
-
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in info.local_data_path else 23
- state_dict = resrgan2normal(state_dict, nb)
- elif "conv_first.weight" in state_dict:
- state_dict = mod2normal(state_dict)
- elif "model.0.weight" not in state_dict:
- raise TypeError("The file is not a recognized ESRGAN model.")
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
- model.load_state_dict(state_dict)
- model.eval()
- model = compile_upscaler(model, name=self.name)
- self.models[info.local_data_path] = model
- return self.models[info.local_data_path]
-
-
-def upscale_without_tiling(model, img):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with devices.inference_context():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).detach().numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
-
-
-def esrgan_upscale(model, img):
- if opts.upscaler_tile_size == 0:
- return upscale_without_tiling(model, img)
-
- grid = images.split_grid(img, opts.upscaler_tile_size, opts.upscaler_tile_size, opts.upscaler_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
- total = 0
- for _y, _h, row in grid.tiles:
- total += len(row)
- task = progress.add_task(description="Upscaling", total=total)
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
- newrow.append([x * scale_factor, w * scale_factor, output])
- progress.update(task, advance=1, description="Upscaling")
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
+import numpy as np
+import torch
+from PIL import Image
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
+import modules.postprocess.esrgan_model_arch as arch
+from modules import images, devices
+from modules.upscaler import Upscaler, UpscalerData, compile_upscaler
+from modules.shared import opts, log, console
+
+
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
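+    # remaps a new-style ESRGAN state_dict (conv_first / RRDB_trunk keys) to the legacy 'model.N' layout expected by arch.RRDBNet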
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = list(state_dict)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
+ crt_net = {}
+ items = list(state_dict)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ # out_nc = out_nc
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
+
+
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.user_path = dirname
+ super().__init__()
+ self.scalers = self.find_scalers()
+ self.models = {}
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(devices.device_esrgan)
+ img = esrgan_upscale(model, img)
+ if opts.upscaler_unload and selected_model in self.models:
+ del self.models[selected_model]
+ log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
+ devices.torch_gc(force=True)
+ return img
+
+ def load_model(self, path: str):
+ info: UpscalerData = self.find_model(path)
+ if info is None:
+ return
+ if self.models.get(info.local_data_path, None) is not None:
+ log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
+ return self.models[info.local_data_path]
+ state_dict = torch.load(info.local_data_path, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+ log.info(f"Upscaler loaded: type={self.name} model={info.local_data_path}")
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in info.local_data_path else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ model = compile_upscaler(model, name=self.name)
+ self.models[info.local_data_path] = model
+ return self.models[info.local_data_path]
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in info.local_data_path else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise TypeError("The file is not a recognized ESRGAN model.")
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
+ model = compile_upscaler(model, name=self.name)
+ self.models[info.local_data_path] = model
+ return self.models[info.local_data_path]
+
+
+def upscale_without_tiling(model, img):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(devices.device_esrgan)
+ with devices.inference_context():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).detach().numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def esrgan_upscale(model, img):
+ if opts.upscaler_tile_size == 0:
+ return upscale_without_tiling(model, img)
+
+ grid = images.split_grid(img, opts.upscaler_tile_size, opts.upscaler_tile_size, opts.upscaler_tile_overlap)
+ newtiles = []
+ scale_factor = 1
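+    # upscale each tile separately and recombine, so large images never pass through the model in one piece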
+
+ with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
+ total = 0
+ for _y, _h, row in grid.tiles:
+ total += len(row)
+ task = progress.add_task(description="Upscaling", total=total)
+ for y, h, row in grid.tiles:
+ newrow = []
+ for tiledata in row:
+ x, w, tile = tiledata
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ progress.update(task, advance=1, description="Upscaling")
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
+ return output
diff --git a/modules/postprocess/esrgan_model_arch.py b/modules/postprocess/esrgan_model_arch.py
index fb8ed9741..bf9f0ac6e 100644
--- a/modules/postprocess/esrgan_model_arch.py
+++ b/modules/postprocess/esrgan_model_arch.py
@@ -1,465 +1,465 @@
-# this file is adapted from https://github.com/victorca25/iNNfer
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-####################
-# RRDBNet Generator
-####################
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
- finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
- n_upscale = int(math.log(upscale, 2))
- if upscale == 3:
- n_upscale = 1
-
- self.resrgan_scale = 0
- if in_nc % 16 == 0:
- self.resrgan_scale = 1
- elif in_nc != 4 and in_nc % 4 == 0:
- self.resrgan_scale = 2
-
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
-
- if upsample_mode == 'upconv':
- upsample_block = upconv_block
- elif upsample_mode == 'pixelshuffle':
- upsample_block = pixelshuffle_block
- else:
- raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
- if upscale == 3:
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
- else:
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
- outact = act(finalact) if finalact else None
-
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
- *upsampler, HR_conv0, HR_conv1, outact)
-
- def forward(self, x, outm=None):
- if self.resrgan_scale == 1:
- feat = pixel_unshuffle(x, scale=4)
- elif self.resrgan_scale == 2:
- feat = pixel_unshuffle(x, scale=2)
- else:
- feat = x
-
- return self.model(feat)
-
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
- # This is for backwards compatibility with existing models
- if nr == 3:
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- else:
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
- self.RDBs = nn.Sequential(*RDB_list)
-
- def forward(self, x):
- if hasattr(self, 'RDB1'):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- else:
- out = self.RDBs(x)
- return out * 0.2 + x
-
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
- """
-
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
-
- self.noise = GaussianNoise() if gaussian_noise else None
- self.conv1x1 = conv1x1(nf, gc) if plus else None
-
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- if mode == 'CNA':
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- x2 = x2 + self.conv1x1(x)
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- if self.noise:
- return self.noise(x5.mul(0.2) + x)
- else:
- return x5 * 0.2 + x
-
-
-####################
-# ESRGANplus
-####################
-
-class GaussianNoise(nn.Module):
- def __init__(self, sigma=0.1, is_relative_detach=False):
- super().__init__()
- self.sigma = sigma
- self.is_relative_detach = is_relative_detach
- self.noise = torch.tensor(0, dtype=torch.float)
-
- def forward(self, x):
- if self.training and self.sigma != 0:
- self.noise = self.noise.to(x.device)
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
- x = x + sampled_noise
- return x
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- This class is copied from https://github.com/xinntao/Real-ESRGAN
- """
-
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
- self.num_in_ch = num_in_ch
- self.num_out_ch = num_out_ch
- self.num_feat = num_feat
- self.num_conv = num_conv
- self.upscale = upscale
- self.act_type = act_type
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
- # the first activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the body structure
- for _ in range(num_conv):
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
- # activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the last conv
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
- # upsample
- self.upsampler = nn.PixelShuffle(upscale)
-
- def forward(self, x):
- out = x
- for i in range(0, len(self.body)):
- out = self.body[i](out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
- out += base
- return out
-
-
-####################
-# Upsampler
-####################
-
-class Upsample(nn.Module):
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form
- `minibatch x channels x [optional depth] x [optional height] x width`.
- """
-
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
- if isinstance(scale_factor, tuple):
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
- else:
- self.scale_factor = float(scale_factor) if scale_factor else None
- self.mode = mode
- self.size = size
- self.align_corners = align_corners
-
- def forward(self, x):
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
-
- def extra_repr(self):
- if self.scale_factor is not None:
- info = f'scale_factor={self.scale_factor}'
- else:
- info = f'size={self.size}'
- info += f', mode={self.mode}'
- return info
-
-
-def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
- Args:
- x (Tensor): Input feature with shape (b, c, hh, hw).
- scale (int): Downsample ratio.
- Returns:
- Tensor: the pixel unshuffled feature.
- """
- b, c, hh, hw = x.size()
- out_channel = c * (scale**2)
- assert hh % scale == 0 and hw % scale == 0
- h = hh // scale
- w = hw // scale
- x_view = x.view(b, c, h, scale, w, scale)
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
- """ Upconv layer """
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
- return sequential(upsample, conv)
-
-
-
-
-
-
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
- """Make layers by stacking the same blocks.
- Args:
- basic_block (nn.module): nn.module class for basic block. (block)
- num_basic_block (int): number of blocks. (n_layers)
- Returns:
- nn.Sequential: Stacked blocks in nn.Sequential.
- """
- layers = []
- for _ in range(num_basic_block):
- layers.append(basic_block(**kwarg))
- return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
- """ activation helper """
- act_type = act_type.lower()
- if act_type == 'relu':
- layer = nn.ReLU(inplace)
- elif act_type in ('leakyrelu', 'lrelu'):
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == 'prelu':
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- elif act_type == 'tanh': # [-1, 1] range output
- layer = nn.Tanh()
- elif act_type == 'sigmoid': # [0, 1] range output
- layer = nn.Sigmoid()
- else:
- raise NotImplementedError(f'activation layer [{act_type}] is not found')
- return layer
-
-
-class Identity(nn.Module):
- def __init__(self, *kwargs):
- super(Identity, self).__init__()
-
- def forward(self, x, *kwargs):
- return x
-
-
-def norm(norm_type, nc):
- """ Return a normalization layer """
- norm_type = norm_type.lower()
- if norm_type == 'batch':
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == 'instance':
- layer = nn.InstanceNorm2d(nc, affine=False)
- elif norm_type == 'none':
- def norm_layer(x): return Identity()
- else:
- raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
- return layer
-
-
-def pad(pad_type, padding):
- """ padding layer helper """
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == 'reflect':
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == 'replicate':
- layer = nn.ReplicationPad2d(padding)
- elif pad_type == 'zero':
- layer = nn.ZeroPad2d(padding)
- else:
- raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ShortcutBlock(nn.Module):
- """ Elementwise sum the output of a submodule to its input """
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
-
-
-def sequential(*args):
- """ Flatten Sequential. It unwraps nn.Sequential. """
- if len(args) == 1:
- from collections import OrderedDict
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError('sequential does not support OrderedDict input.')
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
- spectral_norm=False):
- """ Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
- padding = padding if pad_type == 'zero' else 0
-
- if convtype=='PartialConv2D':
- from torchvision.ops import PartialConv2d
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='DeformConv2D':
- from torchvision.ops import DeformConv2d
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='Conv3D':
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- else:
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
-
- if spectral_norm:
- c = nn.utils.spectral_norm(c)
-
- a = act(act_type) if act_type else None
- if 'CNA' in mode:
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == 'NAC':
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+####################
+# RRDBNet Generator
+####################
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
+
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
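+        # channel counts divisible by 16 (or by 4) indicate models trained on pixel-unshuffled input; forward() applies the matching unshuffle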
+
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
+
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
+
+
+class RRDB(nn.Module):
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
+
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(RRDB, self).__init__()
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
+
+ def forward(self, x):
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
+ return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+        - "ICASSP 2020 - ESRGAN+: Further Improving ESRGAN", N. C. Rakotonirina and A. Rasoanaivo
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
+
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
+
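+# Usage note (illustrative, not referenced by the classes above): RRDB and
+# ResidualDenseBlock_5C are shape-preserving, so a tensor of shape (1, nf, H, W)
+# comes back as (1, nf, H, W); that is what lets the generator above stack them
+# freely inside ShortcutBlock. The 0.2 residual scaling follows ESRGAN.
+
+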
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
+
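+# Usage sketch (illustrative only; the shapes below are example values, not taken
+# from any checkpoint):
+#
+#   net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4)
+#   y = net(torch.randn(1, 3, 64, 64))  # -> (1, 3, 256, 256)
+#
+# The last conv emits num_out_ch * upscale**2 channels, nn.PixelShuffle rearranges them
+# into the upscaled image, and the nearest-upsampled input is added as a residual base.
+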
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = f'scale_factor={self.scale_factor}'
+ else:
+ info = f'size={self.size}'
+ info += f', mode={self.mode}'
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0, f'invalid input size ({hh}, {hw}) for scale {scale}'
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
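+# Example (illustrative): pixel_unshuffle trades spatial size for channels and is the
+# inverse of nn.PixelShuffle: for x of shape (1, 3, 64, 64) and scale=2 the result has
+# shape (1, 12, 32, 32). The forward() above relies on this so that 1x and 2x inputs can
+# be fed to the same 4x upsampling backbone.
+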
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
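+# Note (illustrative): pixelshuffle_block and upconv_block are the two upsample_block
+# choices selected via upsample_mode in the generator __init__ above. pixelshuffle_block
+# is a learned sub-pixel convolution (conv to out_nc * upscale_factor**2 channels, then
+# nn.PixelShuffle), while upconv_block interpolates first (nearest by default) and then
+# convolves, which tends to avoid checkerboard artifacts at the cost of a fixed resize.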
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError(f'activation layer [{act_type}] is not found')
+ return layer
+
+
+class Identity(nn.Module):
+ def __init__(self, *kwargs):
+ super(Identity, self).__init__()
+
+ def forward(self, x, *kwargs):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ elif norm_type == 'none':
+        layer = Identity()  # previously defined an unused closure here, leaving `layer` unbound
+ else:
+ raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ from collections import OrderedDict
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+        # note: torchvision.ops does not provide PartialConv2d; this import only succeeds if a
+        # compatible PartialConv2d implementation is available on the import path
+        from torchvision.ops import PartialConv2d
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ from torchvision.ops import DeformConv2d
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
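+
+
+# Minimal smoke test (illustrative sketch, not part of the original module). It assumes the
+# default Conv2D path and the torch/nn imports at the top of this file, and checks that the
+# basic building blocks preserve or upscale spatial size as expected.
+if __name__ == "__main__":
+    _x = torch.randn(1, 64, 16, 16)
+    _conv = conv_block(64, 64, kernel_size=3, act_type='leakyrelu')
+    assert _conv(_x).shape == _x.shape          # conv_block keeps spatial size
+    _up = pixelshuffle_block(64, 64, upscale_factor=2)
+    assert _up(_x).shape == (1, 64, 32, 32)     # sub-pixel conv doubles H and W
+    _rrdb = RRDB(nf=64, gc=32)
+    assert _rrdb(_x).shape == _x.shape          # RRDB is shape-preserving
+    print('esrgan_model_arch smoke test passed')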
diff --git a/modules/postprocess/gfpgan_model.py b/modules/postprocess/gfpgan_model.py
index 989af466f..b85c900e6 100644
--- a/modules/postprocess/gfpgan_model.py
+++ b/modules/postprocess/gfpgan_model.py
@@ -1,107 +1,107 @@
-import os
-
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
-
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- import facexlib
- import gfpgan # pylint: disable=unused-import
- global loaded_gfpgan_model # pylint: disable=global-statement
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
- if gfpgan_constructor is None:
- return None
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
- model_file = models[0]
- elif len(models) != 0:
- latest_file = max(models, key=os.path.getctime)
- model_file = latest_file
- else:
- shared.log.error(f"Model failed loading: type=GFPGAN model={model_file}")
- return None
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
- shared.log.info(f"Model loaded: type=GFPGAN model={model_file}")
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
-
-
-def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- _cropped_faces, _restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
- return np_image
-
-
-gfpgan_constructor = None
-
-
-def setup_model(dirname):
- if not os.path.exists(model_path):
- os.makedirs(model_path)
- try:
- import gfpgan
- import facexlib
-
- global user_path # pylint: disable=global-statement
- global have_gfpgan # pylint: disable=global-statement
- global gfpgan_constructor # pylint: disable=global-statement
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = gfpgan.GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
- except Exception as e:
- errors.display(e, 'gfpgan')
+import os
+
+import modules.face_restoration
+from modules import paths, shared, devices, modelloader, errors
+
+model_dir = "GFPGAN"
+user_path = None
+model_path = os.path.join(paths.models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+have_gfpgan = False
+loaded_gfpgan_model = None
+
+
+def gfpgann():
+ import facexlib
+ import gfpgan # pylint: disable=unused-import
+ global loaded_gfpgan_model # pylint: disable=global-statement
+ if loaded_gfpgan_model is not None:
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
+ return loaded_gfpgan_model
+ if gfpgan_constructor is None:
+ return None
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
+ latest_file = max(models, key=os.path.getctime)
+ model_file = latest_file
+ else:
+        shared.log.error(f"Model failed loading: type=GFPGAN path={model_path}")  # no model file was found
+ return None
+ if hasattr(facexlib.detection.retinaface, 'device'):
+ facexlib.detection.retinaface.device = devices.device_gfpgan
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
+ loaded_gfpgan_model = model
+ shared.log.info(f"Model loaded: type=GFPGAN model={model_file}")
+ return model
+
+
+def send_model_to(model, device):
+ model.gfpgan.to(device)
+ model.face_helper.face_det.to(device)
+ model.face_helper.face_parse.to(device)
+
+
+def gfpgan_fix_faces(np_image):
+ model = gfpgann()
+ if model is None:
+ return np_image
+
+ send_model_to(model, devices.device_gfpgan)
+
+ np_image_bgr = np_image[:, :, ::-1]
+ _cropped_faces, _restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ np_image = gfpgan_output_bgr[:, :, ::-1]
+
+ model.face_helper.clean_all()
+
+ if shared.opts.face_restoration_unload:
+ send_model_to(model, devices.cpu)
+
+ return np_image
+
+
+gfpgan_constructor = None
+
+
+def setup_model(dirname):
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ try:
+ import gfpgan
+ import facexlib
+
+ global user_path # pylint: disable=global-statement
+ global have_gfpgan # pylint: disable=global-statement
+ global gfpgan_constructor # pylint: disable=global-statement
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+ def my_load_file_from_url(**kwargs):
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+ def facex_load_file_from_url(**kwargs):
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ def facex_load_file_from_url2(**kwargs):
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+ user_path = dirname
+ have_gfpgan = True
+ gfpgan_constructor = gfpgan.GFPGANer
+
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def restore(self, np_image):
+ return gfpgan_fix_faces(np_image)
+
+ shared.face_restorers.append(FaceRestorerGFPGAN())
+ except Exception as e:
+ errors.display(e, 'gfpgan')
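+
+
+# Usage sketch (illustrative only; requires the optional gfpgan and facexlib packages):
+#
+#   setup_model(model_path)             # any models directory works; registers FaceRestorerGFPGAN
+#   restored = gfpgan_fix_faces(image)  # image: HxWx3 uint8 RGB numpy array
+#
+# gfpgan_fix_faces converts to BGR for GFPGAN, restores faces, converts back to RGB, and
+# moves the model to CPU afterwards when shared.opts.face_restoration_unload is enabled.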
diff --git a/modules/postprocess/realesrgan_model.py b/modules/postprocess/realesrgan_model.py
index af6bb4705..490148c2b 100644
--- a/modules/postprocess/realesrgan_model.py
+++ b/modules/postprocess/realesrgan_model.py
@@ -1,70 +1,70 @@
-import os
-import numpy as np
-from PIL import Image
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from modules.postprocess.realesrgan_model_arch import SRVGGNetCompact
-from modules.upscaler import Upscaler
-from modules.shared import opts, device, log
-from modules import devices
-
-class UpscalerRealESRGAN(Upscaler):
- def __init__(self, dirname):
- self.name = "RealESRGAN"
- self.user_path = dirname
- super().__init__()
- self.scalers = self.find_scalers()
- self.models = {}
- for scaler in self.scalers:
- if scaler.name == 'RealESRGAN 2x+':
- scaler.model = lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- scaler.scale = 2
- elif scaler.name == 'RealESRGAN 4x+ Anime6B':
- scaler.model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- elif scaler.name == 'RealESRGAN 4x General V3':
- scaler.model = lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- elif scaler.name == 'RealESRGAN 4x General WDN V3':
- scaler.model = lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- elif scaler.name == 'RealESRGAN AnimeVideo V3':
- scaler.model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
- elif scaler.name == 'RealESRGAN 4x+':
- scaler.model = lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
- else:
- log.error(f"Upscaler unrecognized model: type={self.name} model={scaler.name}")
-
- def load_model(self, path): # pylint: disable=unused-argument
- pass
-
- def do_upscale(self, img, selected_model):
- if not self.enable:
- return img
- try:
- from modules.postprocess.realesrgan_model_arch import RealESRGANer
- except Exception:
- log.error("Error importing Real-ESRGAN:")
- return img
- info = self.find_model(selected_model)
- if info is None or not os.path.exists(info.local_data_path):
- return img
- if self.models.get(info.local_data_path, None) is not None:
- log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
- upsampler=self.models[info.local_data_path]
- else:
- upsampler = RealESRGANer(
- name=info.name,
- scale=info.scale,
- model_path=info.local_data_path,
- model=info.model(),
- half=not opts.no_half and not opts.upcast_sampling,
- tile=opts.upscaler_tile_size,
- tile_pad=opts.upscaler_tile_overlap,
- device=device,
- )
- self.models[info.local_data_path] = upsampler
- upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
- if opts.upscaler_unload and info.local_data_path in self.models:
- del self.models[info.local_data_path]
- log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
- devices.torch_gc(force=True)
-
- image = Image.fromarray(upsampled)
- return image
+import os
+import numpy as np
+from PIL import Image
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from modules.postprocess.realesrgan_model_arch import SRVGGNetCompact
+from modules.upscaler import Upscaler
+from modules.shared import opts, device, log
+from modules import devices
+
+class UpscalerRealESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "RealESRGAN"
+ self.user_path = dirname
+ super().__init__()
+ self.scalers = self.find_scalers()
+ self.models = {}
+ for scaler in self.scalers:
+ if scaler.name == 'RealESRGAN 2x+':
+ scaler.model = lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ scaler.scale = 2
+ elif scaler.name == 'RealESRGAN 4x+ Anime6B':
+ scaler.model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ elif scaler.name == 'RealESRGAN 4x General V3':
+ scaler.model = lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ elif scaler.name == 'RealESRGAN 4x General WDN V3':
+ scaler.model = lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ elif scaler.name == 'RealESRGAN AnimeVideo V3':
+ scaler.model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ elif scaler.name == 'RealESRGAN 4x+':
+ scaler.model = lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ else:
+ log.error(f"Upscaler unrecognized model: type={self.name} model={scaler.name}")
+
+ def load_model(self, path): # pylint: disable=unused-argument
+ pass
+
+ def do_upscale(self, img, selected_model):
+ if not self.enable:
+ return img
+ try:
+ from modules.postprocess.realesrgan_model_arch import RealESRGANer
+ except Exception:
+ log.error("Error importing Real-ESRGAN:")
+ return img
+ info = self.find_model(selected_model)
+ if info is None or not os.path.exists(info.local_data_path):
+ return img
+ if self.models.get(info.local_data_path, None) is not None:
+ log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
+ upsampler=self.models[info.local_data_path]
+ else:
+ upsampler = RealESRGANer(
+ name=info.name,
+ scale=info.scale,
+ model_path=info.local_data_path,
+ model=info.model(),
+ half=not opts.no_half and not opts.upcast_sampling,
+ tile=opts.upscaler_tile_size,
+ tile_pad=opts.upscaler_tile_overlap,
+ device=device,
+ )
+ self.models[info.local_data_path] = upsampler
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
+ if opts.upscaler_unload and info.local_data_path in self.models:
+ del self.models[info.local_data_path]
+ log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
+ devices.torch_gc(force=True)
+
+ image = Image.fromarray(upsampled)
+ return image
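+
+
+# Usage sketch (illustrative only):
+#
+#   upscaler = UpscalerRealESRGAN(dirname)             # dirname: folder containing the .pth files
+#   result = upscaler.do_upscale(img, selected_model)  # img: PIL.Image in, PIL.Image out
+#
+# do_upscale caches one RealESRGANer per model path and drops it again (with a forced
+# torch_gc) when opts.upscaler_unload is enabled.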
diff --git a/modules/postprocess/swinir_model_arch_v2.py b/modules/postprocess/swinir_model_arch_v2.py
index 19bcab441..71c246c3f 100644
--- a/modules/postprocess/swinir_model_arch_v2.py
+++ b/modules/postprocess/swinir_model_arch_v2.py
@@ -1,1015 +1,1015 @@
-# -----------------------------------------------------------------------------------
-# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=(0, 0)):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
-class Swin2SR(nn.Module):
- r""" Swin2SR
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- assert self.upscale == 4, 'only support x4 now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
+# -----------------------------------------------------------------------------------
+# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
+# Written by Conde and Choi et al.
+# -----------------------------------------------------------------------------------
+
+import math
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
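+# Hedged sketch (illustrative only; the helper name is ours, not part of the
+# original file): Mlp acts on the last dimension, so token tensors of shape
+# (B, L, C) keep their shape for any sequence length L.
+def _demo_mlp_shapes():
+    mlp = Mlp(in_features=60, hidden_features=120)
+    tokens = torch.randn(2, 64, 60)
+    assert mlp(tokens).shape == tokens.shape
+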
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
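+# Hedged sketch (added for illustration, never called by the model): when H and
+# W are divisible by the window size, window_partition and window_reverse are
+# exact inverses. The shapes below are arbitrary illustrative choices.
+def _demo_window_roundtrip():
+    x = torch.randn(2, 16, 16, 32)                     # B, H, W, C
+    windows = window_partition(x, 8)                   # (2*2*2, 8, 8, 32)
+    restored = window_reverse(windows, 8, 16, 16)      # back to (2, 16, 16, 32)
+    assert torch.equal(x, restored)
+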
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=(0, 0)):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
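+# Hedged usage sketch (shapes and the helper name are our illustrative
+# assumptions): WindowAttention expects tokens that have already been
+# window-partitioned, i.e. (num_windows*B, Wh*Ww, dim), and preserves the shape.
+def _demo_window_attention_shapes():
+    attn = WindowAttention(dim=32, window_size=(8, 8), num_heads=4)
+    tokens = torch.randn(6, 8 * 8, 32)                 # (nW*B, Wh*Ww, C)
+    out = attn(tokens)
+    assert out.shape == tokens.shape
+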
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
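+        # The slices below split the shifted image into a 3x3 grid of index
+        # regions; token pairs that fall in different regions within the same
+        # window receive -100 on their attention logits, which the softmax in
+        # WindowAttention then effectively suppresses.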
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ #assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA (compatible with testing on images whose shapes are multiples of the window size)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1],
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
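+    # forward pipeline: tokens -> residual_group (Swin blocks) -> unembed to
+    # B,C,H,W -> conv -> re-embed to tokens, with a residual connection to x.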
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
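+# Hedged sketch (illustrative only, not called by the model): with patch_size=1,
+# as Swin2SR uses below, PatchEmbed flattens B,C,H,W features into B,(H*W),C
+# tokens and PatchUnEmbed restores the spatial layout. Sizes are arbitrary.
+def _demo_patch_embed_roundtrip():
+    embed = PatchEmbed(img_size=48, patch_size=1, in_chans=60, embed_dim=60)
+    unembed = PatchUnEmbed(img_size=48, patch_size=1, in_chans=60, embed_dim=60)
+    feat = torch.randn(1, 60, 48, 48)
+    tokens = embed(feat)                               # (1, 48*48, 60)
+    restored = unembed(tokens, (48, 48))               # (1, 60, 48, 48)
+    assert restored.shape == feat.shape
+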
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
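+# Hedged sketch (illustrative only): Upsample keeps the channel count at
+# num_feat and multiplies the spatial size by the scale factor.
+def _demo_upsample_shapes():
+    up = Upsample(scale=4, num_feat=64)
+    feat = torch.randn(1, 64, 16, 16)
+    assert up(feat).shape == (1, 64, 64, 64)
+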
+class Upsample_hf(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample_hf, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
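+# Hedged sketch (illustrative only): UpsampleOneStep maps num_feat channels
+# directly to scale**2 * num_out_ch and pixel-shuffles in a single step.
+def _demo_upsample_one_step_shapes():
+    up = UpsampleOneStep(scale=2, num_feat=60, num_out_ch=3)
+    feat = torch.randn(1, 60, 24, 24)
+    assert up(feat).shape == (1, 3, 48, 48)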
+
+
+class Swin2SR(nn.Module):
+ r""" Swin2SR
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+
+ Args:
+        img_size (int | tuple(int)): Input image size. Default: 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+ img_range: Image range. 1. or 255.
+        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(Swin2SR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+
+ if self.upsampler == 'pixelshuffle_hf':
+ self.layers_hf = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers_hf.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffle_aux':
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_after_aux = nn.Sequential(
+ nn.Conv2d(3, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffle_hf':
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ self.conv_before_upsample_hf = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+            assert self.upscale == 4, 'only x4 is supported for nearest+conv.'
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
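+        # Pads H and W up to the next multiple of window_size with reflection
+        # padding, e.g. with window_size=8 a 70x70 input becomes 72x72;
+        # forward() later crops the output back to H*upscale x W*upscale.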
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward_features_hf(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers_hf:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffle_aux':
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
+ bicubic = self.conv_bicubic(bicubic)
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
+ x = self.conv_after_aux(aux)
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
+ x = self.conv_last(x)
+ aux = aux / self.img_range + self.mean
+ elif self.upsampler == 'pixelshuffle_hf':
+ # for classical SR with HF
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x_before = self.conv_before_upsample(x)
+ x_out = self.conv_last(self.upsample(x_before))
+
+ x_hf = self.conv_first_hf(x_before)
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
+ x_hf = self.conv_before_upsample_hf(x_hf)
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
+ x = x_out + x_hf
+ x_hf = x_hf / self.img_range + self.mean
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
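+        # 'pixelshuffle_aux' and 'pixelshuffle_hf' return tuples (the SR image
+        # plus auxiliary / high-frequency outputs); all other upsamplers return
+        # a single tensor cropped to H*upscale x W*upscale.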
+ if self.upsampler == "pixelshuffle_aux":
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
+
+ elif self.upsampler == "pixelshuffle_hf":
+ x_out = x_out / self.img_range + self.mean
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
+
+ else:
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for layer in self.layers:
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = Swin2SR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index dc2ae3d45..153128142 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -1,120 +1,120 @@
-import os
-import tempfile
-from typing import List
-
-from PIL import Image
-
-from modules import shared, images, devices, scripts, scripts_postprocessing, generation_parameters_copypaste
-from modules.shared import opts
-
-
-def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
- devices.torch_gc()
- shared.state.begin('extras')
- image_data = []
- image_names = []
- image_fullnames = []
- image_ext = []
- outputs = []
- params = {}
- if extras_mode == 1:
- for img in image_folder:
- if isinstance(img, Image.Image):
- image = img
- fn = ''
- ext = None
- else:
- try:
- image = Image.open(os.path.abspath(img.name))
- except Exception as e:
- shared.log.error(f'Failed to open image: file="{img.name}" {e}')
- continue
- fn, ext = os.path.splitext(img.orig_name)
- image_fullnames.append(img.name)
- image_data.append(image)
- image_names.append(fn)
- image_ext.append(ext)
- shared.log.debug(f'Process: mode=batch inputs={len(image_folder)} images={len(image_data)}')
- elif extras_mode == 2:
- assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
- assert input_dir, 'input directory not selected'
- image_list = os.listdir(input_dir)
- for filename in image_list:
- try:
- image = Image.open(filename)
- except Exception as e:
- shared.log.error(f'Failed to open image: file="{filename}" {e}')
- continue
- image_fullnames.append(filename)
- image_data.append(image)
- image_names.append(filename)
- image_ext.append(None)
- shared.log.debug(f'Process: mode=folder inputs={input_dir} files={len(image_list)} images={len(image_data)}')
- else:
- image_data.append(image)
- image_names.append(None)
- image_ext.append(None)
- if extras_mode == 2 and output_dir != '':
- outpath = output_dir
- else:
- outpath = opts.outdir_samples or opts.outdir_extras_samples
- processed_images = []
- for image, name, ext in zip(image_data, image_names, image_ext): # pylint: disable=redefined-argument-from-local
- shared.log.debug(f'Process: image={image} {args}')
- infotext = ''
- if shared.state.interrupted:
- shared.log.debug('Postprocess interrupted')
- break
- if image is None:
- continue
- shared.state.textinfo = name
- pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
- scripts.scripts_postproc.run(pp, args)
- if opts.use_original_name_batch and name is not None:
- basename = os.path.splitext(os.path.basename(name))[0]
- else:
- basename = ''
- geninfo, items = images.read_info_from_image(image)
- params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
- for k, v in items.items():
- pp.image.info[k] = v
- if 'parameters' in items:
- infotext = items['parameters'] + ', '
- infotext = infotext + ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
- pp.image.info["postprocessing"] = infotext
- processed_images.append(pp.image)
- if save_output:
- images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=ext or opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=pp.image.info, forced_filename=None)
- if extras_mode != 2 or show_extras_results:
- outputs.append(pp.image)
- image.close()
- scripts.scripts_postproc.postprocess(processed_images, args)
-
- devices.torch_gc()
- return outputs, infotext, params
-
-
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): #pylint: disable=unused-argument
- """old handler for API"""
-
- args = scripts.scripts_postproc.create_args_for_run({
- "Upscale": {
- "upscale_mode": resize_mode,
- "upscale_by": upscaling_resize,
- "upscale_to_width": upscaling_resize_w,
- "upscale_to_height": upscaling_resize_h,
- "upscale_crop": upscaling_crop,
- "upscaler_1_name": extras_upscaler_1,
- "upscaler_2_name": extras_upscaler_2,
- "upscaler_2_visibility": extras_upscaler_2_visibility,
- },
- "GFPGAN": {
- "gfpgan_visibility": gfpgan_visibility,
- },
- "CodeFormer": {
- "codeformer_visibility": codeformer_visibility,
- "codeformer_weight": codeformer_weight,
- },
- })
-
- return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
+import os
+import tempfile
+from typing import List
+
+from PIL import Image
+
+from modules import shared, images, devices, scripts, scripts_postprocessing, generation_parameters_copypaste
+from modules.shared import opts
+
+
+def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+ devices.torch_gc()
+ shared.state.begin('extras')
+ image_data = []
+ image_names = []
+ image_fullnames = []
+ image_ext = []
+ outputs = []
+ params = {}
+ if extras_mode == 1:
+ for img in image_folder:
+ if isinstance(img, Image.Image):
+ image = img
+ fn = ''
+ ext = None
+ else:
+ try:
+ image = Image.open(os.path.abspath(img.name))
+ except Exception as e:
+ shared.log.error(f'Failed to open image: file="{img.name}" {e}')
+ continue
+ fn, ext = os.path.splitext(img.orig_name)
+ image_fullnames.append(img.name)
+ image_data.append(image)
+ image_names.append(fn)
+ image_ext.append(ext)
+ shared.log.debug(f'Process: mode=batch inputs={len(image_folder)} images={len(image_data)}')
+ elif extras_mode == 2:
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+ assert input_dir, 'input directory not selected'
+ image_list = os.listdir(input_dir)
+ for filename in image_list:
+ try:
+                image = Image.open(os.path.join(input_dir, filename))
+ except Exception as e:
+ shared.log.error(f'Failed to open image: file="{filename}" {e}')
+ continue
+ image_fullnames.append(filename)
+ image_data.append(image)
+ image_names.append(filename)
+ image_ext.append(None)
+ shared.log.debug(f'Process: mode=folder inputs={input_dir} files={len(image_list)} images={len(image_data)}')
+ else:
+ image_data.append(image)
+ image_names.append(None)
+ image_ext.append(None)
+ if extras_mode == 2 and output_dir != '':
+ outpath = output_dir
+ else:
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
+ processed_images = []
+ for image, name, ext in zip(image_data, image_names, image_ext): # pylint: disable=redefined-argument-from-local
+ shared.log.debug(f'Process: image={image} {args}')
+ infotext = ''
+ if shared.state.interrupted:
+ shared.log.debug('Postprocess interrupted')
+ break
+ if image is None:
+ continue
+ shared.state.textinfo = name
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
+ scripts.scripts_postproc.run(pp, args)
+ if opts.use_original_name_batch and name is not None:
+ basename = os.path.splitext(os.path.basename(name))[0]
+ else:
+ basename = ''
+ geninfo, items = images.read_info_from_image(image)
+ params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ for k, v in items.items():
+ pp.image.info[k] = v
+ if 'parameters' in items:
+ infotext = items['parameters'] + ', '
+ infotext = infotext + ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+ pp.image.info["postprocessing"] = infotext
+ processed_images.append(pp.image)
+ if save_output:
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=ext or opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=pp.image.info, forced_filename=None)
+ if extras_mode != 2 or show_extras_results:
+ outputs.append(pp.image)
+ image.close()
+ scripts.scripts_postproc.postprocess(processed_images, args)
+
+ devices.torch_gc()
+ return outputs, infotext, params
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): #pylint: disable=unused-argument
+ """old handler for API"""
+
+ args = scripts.scripts_postproc.create_args_for_run({
+ "Upscale": {
+ "upscale_mode": resize_mode,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ },
+ "GFPGAN": {
+ "gfpgan_visibility": gfpgan_visibility,
+ },
+ "CodeFormer": {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ },
+ })
+
+ return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
diff --git a/modules/processing.py b/modules/processing.py
index 4804a37b1..c9eeb69e4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,1406 +1,1406 @@
-import os
-import json
-import math
-import time
-import hashlib
-import random
-import warnings
-from contextlib import nullcontext
-from typing import Any, Dict, List
-from dataclasses import dataclass, field
-import torch
-import numpy as np
-import cv2
-from PIL import Image, ImageOps
-from skimage import exposure
-from einops import repeat, rearrange
-from blendmodes.blend import blendLayers, BlendType
-from installer import git_commit
-from modules import shared, devices, errors
-import modules.memstats
-import modules.lowvram
-import modules.masking
-import modules.paths
-import modules.scripts
-import modules.script_callbacks
-import modules.prompt_parser
-import modules.extra_networks
-import modules.face_restoration
-import modules.images as images
-import modules.styles
-import modules.sd_hijack_freeu
-import modules.sd_samplers
-import modules.sd_samplers_common
-import modules.sd_models
-import modules.sd_vae
-import modules.sd_vae_approx
-import modules.taesd.sd_vae_taesd
-import modules.generation_parameters_copypaste
-from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet, hypertile_set
-
-
-if shared.backend == shared.Backend.ORIGINAL:
- import modules.sd_hijack
-
-opt_C = 4
-opt_f = 8
-debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PROCESS')
-
-
-def setup_color_correction(image):
- debug("Calibrating color correction")
- correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
- return correction_target
-
-
-def apply_color_correction(correction, original_image):
- shared.log.debug(f"Applying color correction: correction={correction} image={original_image}")
- image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB), correction, channel_axis=2), cv2.COLOR_LAB2RGB).astype("uint8"))
- image = blendLayers(image, original_image, BlendType.LUMINOSITY)
- return image
-
-
-def apply_overlay(image: Image, paste_loc, index, overlays):
- debug(f'Apply overlay: image={image} loc={paste_loc} index={index} overlays={overlays}')
- if overlays is None or index >= len(overlays):
- return image
- overlay = overlays[index]
- if paste_loc is not None:
- x, y, w, h = paste_loc
- if image.width != w or image.height != h or x != 0 or y != 0:
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(2, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
- return image
-
-
-def create_binary_mask(image):
- if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
- else:
- image = image.convert('L')
- return image
-
-
-def images_tensor_to_samples(image, approximation=None, model=None): # pylint: disable=unused-argument
- if model is None:
- model = shared.sd_model
- model.first_stage_model.to(devices.dtype_vae)
- image = image.to(shared.device, dtype=devices.dtype_vae)
- image = image * 2 - 1
- if len(image) > 1:
- x_latent = torch.stack([
- model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0]
- for img in image
- ])
- else:
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
- return x_latent
-
-
-def txt2img_image_conditioning(sd_model, x, width, height):
- if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) # pylint: disable=not-callable
- image_conditioning = image_conditioning.to(x.dtype)
- return image_conditioning
- elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
- return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
- else:
- # Dummy zero conditioning if we're not using inpainting or unclip models.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
-
-def get_sampler_name(sampler_index: int, img: bool = False) -> str:
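- # Resolve a sampler index to a sampler name; out-of-range indexes, and PLMS when used for img2img, fall back to UniPC.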
- if len(modules.sd_samplers.samplers) > sampler_index:
- sampler_name = modules.sd_samplers.samplers[sampler_index].name
- else:
- sampler_name = "UniPC"
- shared.log.warning(f'Sampler not found: index={sampler_index} available={[s.name for s in modules.sd_samplers.samplers]} fallback={sampler_name}')
- if img and sampler_name == "PLMS":
- sampler_name = "UniPC"
- shared.log.warning(f'Sampler not compatible: name=PLMS fallback={sampler_name}')
- return sampler_name
-
-
-@dataclass(repr=False)
-class StableDiffusionProcessing:
- """
- The first set of parameters: sd_model -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
- """
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, hr_sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, image_cfg_scale: float = None, clip_skip: int = 1, width: int = 512, height: int = 512, full_quality: bool = True, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, diffusers_guidance_rescale: float = 0.7, sag_scale: float = 0.0, resize_mode: int = 0, resize_name: str = 'None', scale_by: float = 0, selected_scale_tab: int = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 3.5, hdr_center: bool = False, hdr_channel_shift: float = 0.8, hdr_full_shift: float = 0.8, hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): # pylint: disable=unused-argument
- self.outpath_samples: str = outpath_samples
- self.outpath_grids: str = outpath_grids
- self.prompt: str = prompt
- self.prompt_for_display: str = None
- self.negative_prompt: str = (negative_prompt or "")
- self.styles: list = styles or []
- self.seed: int = seed
- self.subseed: int = subseed
- self.subseed_strength: float = subseed_strength
- self.seed_resize_from_h: int = seed_resize_from_h
- self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_name: str = sampler_name
- self.hr_sampler_name: str = hr_sampler_name
- self.batch_size: int = batch_size
- self.n_iter: int = n_iter
- self.steps: int = steps
- self.hr_second_pass_steps = 0
- self.cfg_scale: float = cfg_scale
- self.scale_by: float = scale_by
- self.image_cfg_scale = image_cfg_scale
- self.diffusers_guidance_rescale = diffusers_guidance_rescale
- self.sag_scale = sag_scale
- if devices.backend == "ipex" and width == 1024 and height == 1024 and os.environ.get('DISABLE_IPEX_1024_WA', None) is None:
- width = 1080
- height = 1080
- self.width: int = width
- self.height: int = height
- self.full_quality: bool = full_quality
- self.restore_faces: bool = restore_faces
- self.tiling: bool = tiling
- self.do_not_save_samples: bool = do_not_save_samples
- self.do_not_save_grid: bool = do_not_save_grid
- self.extra_generation_params: dict = extra_generation_params or {}
- self.overlay_images = overlay_images
- self.eta = eta
- self.do_not_reload_embeddings = do_not_reload_embeddings
- self.paste_to = None
- self.color_corrections = None
- self.denoising_strength: float = denoising_strength
- self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
- self.override_settings_restore_afterwards = override_settings_restore_afterwards
- self.is_using_inpainting_conditioning = False
- self.disable_extra_networks = False
- self.token_merging_ratio = 0
- self.token_merging_ratio_hr = 0
- # self.scripts = modules.scripts.ScriptRunner() # set via property
- # self.script_args = script_args or [] # set via property
- self.per_script_args = {}
- self.all_prompts = None
- self.all_negative_prompts = None
- self.all_seeds = None
- self.all_subseeds = None
- self.clip_skip = clip_skip
- self.iteration = 0
- self.is_control = False
- self.is_hr_pass = False
- self.is_refiner_pass = False
- self.hr_force = False
- self.enable_hr = None
- self.hr_scale = None
- self.hr_upscaler = None
- self.hr_resize_x = 0
- self.hr_resize_y = 0
- self.hr_upscale_to_x = 0
- self.hr_upscale_to_y = 0
- self.truncate_x = 0
- self.truncate_y = 0
- self.applied_old_hires_behavior_to = None
- self.refiner_steps = 5
- self.refiner_start = 0
- self.refiner_prompt = ''
- self.refiner_negative = ''
- self.ops = []
- self.resize_mode: int = resize_mode
- self.resize_name: str = resize_name
- self.ddim_discretize = shared.opts.ddim_discretize
- self.s_min_uncond = shared.opts.s_min_uncond
- self.s_churn = shared.opts.s_churn
- self.s_noise = shared.opts.s_noise
- self.s_min = shared.opts.s_min
- self.s_max = shared.opts.s_max
- self.s_tmin = shared.opts.s_tmin
- self.s_tmax = float('inf') # not representable as a standard ui option
- shared.opts.data['clip_skip'] = clip_skip
- self.task_args = {}
- # a1111 compatibility items
- self.refiner_switch_at = 0
- self.hr_prompt = ''
- self.all_hr_prompts = []
- self.hr_negative_prompt = ''
- self.all_hr_negative_prompts = []
- self.comments = {}
- self.is_api = False
- self.scripts_value: modules.scripts.ScriptRunner = None # dataclasses.field() has no effect inside __init__; plain defaults match the intent of the setters below
- self.script_args_value: list = None
- self.scripts_setup_complete: bool = False
- # hdr
- self.hdr_clamp = hdr_clamp
- self.hdr_boundary = hdr_boundary
- self.hdr_threshold = hdr_threshold
- self.hdr_center = hdr_center
- self.hdr_channel_shift = hdr_channel_shift
- self.hdr_full_shift = hdr_full_shift
- self.hdr_maximize = hdr_maximize
- self.hdr_max_center = hdr_max_center
- self.hdr_max_boundry = hdr_max_boundry
- self.scheduled_prompt: bool = False
- self.prompt_embeds = []
- self.positive_pooleds = []
- self.negative_embeds = []
- self.negative_pooleds = []
-
-
- @property
- def sd_model(self):
- return shared.sd_model
-
- @property
- def scripts(self):
- return self.scripts_value
-
- @scripts.setter
- def scripts(self, value):
- self.scripts_value = value
- if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
- self.setup_scripts()
-
- @property
- def script_args(self):
- return self.script_args_value
-
- @script_args.setter
- def script_args(self, value):
- self.script_args_value = value
- if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
- self.setup_scripts()
-
- def setup_scripts(self):
- self.scripts_setup_complete = True
- self.scripts.setup_scrips(self, is_ui=not self.is_api)
-
- def comment(self, text):
- self.comments[text] = 1
-
- def txt2img_image_conditioning(self, x, width=None, height=None):
- self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
-
- def depth2img_image_conditioning(self, source_image):
- # Use the AddMiDaS helper to format our source image to suit the MiDaS model
- from ldm.data.util import AddMiDaS
- transformer = AddMiDaS(model_type="dpt_hybrid")
- transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
- midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
- midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
- conditioning = torch.nn.functional.interpolate(
- self.sd_model.depth_model(midas_in),
- size=conditioning_image.shape[2:],
- mode="bicubic",
- align_corners=False,
- )
- (depth_min, depth_max) = torch.aminmax(conditioning)
- conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
- return conditioning
-
- def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
- return conditioning_image
-
- def unclip_image_conditioning(self, source_image):
- c_adm = self.sd_model.embedder(source_image)
- if self.sd_model.noise_augmentor is not None:
- noise_level = 0
- c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
- c_adm = torch.cat((c_adm, noise_level_emb), 1)
- return c_adm
-
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
- self.is_using_inpainting_conditioning = True
- # Handle the different mask inputs
- if image_mask is not None:
- if torch.is_tensor(image_mask):
- conditioning_mask = image_mask
- else:
- conditioning_mask = np.array(image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
- # Create another latent image, this time with a masked version of the original input.
- # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
- conditioning_image = torch.lerp(
- source_image,
- source_image * (1.0 - conditioning_mask),
- getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
- )
- # Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- image_conditioning = image_conditioning.to(device=shared.device, dtype=source_image.dtype)
- return image_conditioning
-
- def diffusers_image_conditioning(self, _source_image, latent_image, _image_mask=None):
- # shared.log.warning('Diffusers not implemented: img2img_image_conditioning')
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
-
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
- from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
- source_image = devices.cond_cast_float(source_image)
- # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
- # identify itself with a field common to all models. The conditioning_key is also hybrid.
- if shared.backend == shared.Backend.DIFFUSERS:
- return self.diffusers_image_conditioning(source_image, latent_image, image_mask)
- if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
- return self.depth2img_image_conditioning(source_image)
- if hasattr(self.sd_model, 'cond_stage_key') and self.sd_model.cond_stage_key == "edit":
- return self.edit_image_conditioning(source_image)
- if hasattr(self.sampler, 'conditioning_key') and self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
- if hasattr(self.sampler, 'conditioning_key') and self.sampler.conditioning_key == "crossattn-adm":
- return self.unclip_image_conditioning(source_image)
- # Dummy zero conditioning if we're not using inpainting or depth model.
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- pass
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- raise NotImplementedError
-
- def close(self):
- self.sampler = None # pylint: disable=attribute-defined-outside-init
-
- def get_token_merging_ratio(self, for_hr=False):
- if for_hr:
- return self.token_merging_ratio_hr or shared.opts.token_merging_ratio_hr or self.token_merging_ratio or shared.opts.token_merging_ratio
- return self.token_merging_ratio or shared.opts.token_merging_ratio
-
-
-class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
- self.images = images_list
- self.prompt = p.prompt
- self.negative_prompt = p.negative_prompt
- self.seed = seed
- self.subseed = subseed
- self.subseed_strength = p.subseed_strength
- self.info = info
- self.comments = comments
- self.width = p.width if hasattr(p, 'width') else (self.images[0].width if len(self.images) > 0 else 0)
- self.height = p.height if hasattr(p, 'height') else (self.images[0].height if len(self.images) > 0 else 0)
- self.sampler_name = p.sampler_name
- self.cfg_scale = p.cfg_scale
- self.image_cfg_scale = p.image_cfg_scale
- self.steps = p.steps
- self.batch_size = p.batch_size
- self.restore_faces = p.restore_faces
- self.face_restoration_model = shared.opts.face_restoration_model if p.restore_faces else None
- self.sd_model_hash = getattr(shared.sd_model, 'sd_model_hash', '')
- self.seed_resize_from_w = p.seed_resize_from_w
- self.seed_resize_from_h = p.seed_resize_from_h
- self.denoising_strength = p.denoising_strength
- self.extra_generation_params = p.extra_generation_params
- self.index_of_first_image = index_of_first_image
- self.styles = p.styles
- self.job_timestamp = shared.state.job_timestamp
- self.clip_skip = p.clip_skip
- self.eta = p.eta
- self.ddim_discretize = p.ddim_discretize
- self.s_churn = p.s_churn
- self.s_tmin = p.s_tmin
- self.s_tmax = p.s_tmax
- self.s_noise = p.s_noise
- self.s_min_uncond = p.s_min_uncond
- self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
- self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
- self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
- self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
- self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
- self.all_seeds = all_seeds or p.all_seeds or [self.seed]
- self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
- self.token_merging_ratio = p.token_merging_ratio
- self.token_merging_ratio_hr = p.token_merging_ratio_hr
- self.infotexts = infotexts or [info]
-
- def js(self):
- obj = {
- "prompt": self.all_prompts[0],
- "all_prompts": self.all_prompts,
- "negative_prompt": self.all_negative_prompts[0],
- "all_negative_prompts": self.all_negative_prompts,
- "seed": self.seed,
- "all_seeds": self.all_seeds,
- "subseed": self.subseed,
- "all_subseeds": self.all_subseeds,
- "subseed_strength": self.subseed_strength,
- "width": self.width,
- "height": self.height,
- "sampler_name": self.sampler_name,
- "cfg_scale": self.cfg_scale,
- "steps": self.steps,
- "batch_size": self.batch_size,
- "restore_faces": self.restore_faces,
- "face_restoration_model": self.face_restoration_model,
- "sd_model_hash": self.sd_model_hash,
- "seed_resize_from_w": self.seed_resize_from_w,
- "seed_resize_from_h": self.seed_resize_from_h,
- "denoising_strength": self.denoising_strength,
- "extra_generation_params": self.extra_generation_params,
- "index_of_first_image": self.index_of_first_image,
- "infotexts": self.infotexts,
- "styles": self.styles,
- "job_timestamp": self.job_timestamp,
- "clip_skip": self.clip_skip,
- # "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
- }
- return json.dumps(obj)
-
- def infotext(self, p: StableDiffusionProcessing, index):
- return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
-
- def get_token_merging_ratio(self, for_hr=False):
- return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
-
-
-def slerp(val, low, high): # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
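- # Spherical interpolation between two noise tensors; nearly parallel inputs fall back to plain linear interpolation to avoid dividing by a tiny sin(omega).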
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
-
- if dot.mean() > 0.9995:
- return low * (1 - val) + high * val # use the same endpoint convention as the slerp result below (val=0 -> low, val=1 -> high)
-
- omega = torch.acos(dot)
- so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
- return res
-
-
-def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
- eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
- xs = []
- # if we have multiple seeds, we are working with batch size>1; in that case we also
- # pre-generate the additional noise tensors that the sampler will use during its processing.
- # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
- # produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and shared.opts.enable_batch_seeds or eta_noise_seed_delta > 0):
- sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
- else:
- sampler_noises = None
- for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
- subnoise = None
- if subseeds is not None:
- subseed = 0 if i >= len(subseeds) else subseeds[i]
- subnoise = devices.randn(subseed, noise_shape)
- # randn results depend on device; gpu and cpu get different results for the same seed;
- # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
- # but the original script had it like this, so I do not dare change it for now because
- # it will break everyone's seeds.
- noise = devices.randn(seed, noise_shape)
- if subnoise is not None:
- noise = slerp(subseed_strength, noise, subnoise)
- if noise_shape != shape:
- x = devices.randn(seed, shape)
- dx = (shape[2] - noise_shape[2]) // 2
- dy = (shape[1] - noise_shape[1]) // 2
- w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
- h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
- tx = 0 if dx < 0 else dx
- ty = 0 if dy < 0 else dy
- dx = max(-dx, 0)
- dy = max(-dy, 0)
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
- noise = x
- if sampler_noises is not None:
- cnt = p.sampler.number_of_needed_noises(p)
- if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
- for j in range(cnt):
- sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
- xs.append(noise)
- if sampler_noises is not None:
- p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
- x = torch.stack(xs).to(shared.device)
- return x
-
-
-def decode_first_stage(model, x, full_quality=True):
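- # Decode latents to image space: the full VAE when full_quality is set, the lightweight TAESD decoder otherwise; returns zero images when generation was skipped or interrupted and incomplete results are not kept.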
- if not shared.opts.keep_incomplete and (shared.state.skipped or shared.state.interrupted):
- shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
- x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
- return x_sample
- prev_job = shared.state.job
- shared.state.job = 'vae'
- with devices.autocast(disable = x.dtype==devices.dtype_vae):
- try:
- if full_quality:
- if hasattr(model, 'decode_first_stage'):
- x_sample = model.decode_first_stage(x)
- elif hasattr(model, 'vae'):
- x_sample = model.vae(x)
- else:
- x_sample = x
- shared.log.error('Decode VAE unknown model')
- else:
- x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
- for i in range(len(x_sample)):
- x_sample[i] = modules.taesd.sd_vae_taesd.decode(x[i])
- except Exception as e:
- x_sample = x
- shared.log.error(f'Decode VAE: {e}')
- shared.state.job = prev_job
- return x_sample
-
-
-def get_fixed_seed(seed):
- if seed is None or seed == '' or seed == -1:
- return int(random.randrange(4294967294))
- return seed
-
-
-def fix_seed(p):
- p.seed = get_fixed_seed(p.seed)
- p.subseed = get_fixed_seed(p.subseed)
-
-
-def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=None, all_subseeds=None, comments=None, iteration=0, position_in_batch=0, index=None, all_negative_prompts=None):
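- # Build the "infotext" parameters string (prompt, negative prompt and key: value generation settings) that is attached to saved images as their "parameters" metadata.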
- if not hasattr(shared.sd_model, 'sd_checkpoint_info'):
- return ''
- if index is None:
- index = position_in_batch + iteration * p.batch_size
- if all_prompts is None:
- all_prompts = p.all_prompts or [p.prompt]
- if all_negative_prompts is None:
- all_negative_prompts = p.all_negative_prompts or [p.negative_prompt]
- if all_seeds is None:
- all_seeds = p.all_seeds or [p.seed]
- if all_subseeds is None:
- all_subseeds = p.all_subseeds or [p.subseed]
- while len(all_prompts) <= index:
- all_prompts.append(all_prompts[-1])
- while len(all_seeds) <= index:
- all_seeds.append(all_seeds[-1])
- while len(all_subseeds) <= index:
- all_subseeds.append(all_subseeds[-1])
- while len(all_negative_prompts) <= index:
- all_negative_prompts.append(all_negative_prompts[-1])
- comment = ', '.join(comments) if comments is not None and type(comments) is list else None
- ops = list(set(p.ops))
- ops.reverse()
- args = {
- # basic
- "Steps": p.steps,
- "Seed": all_seeds[index],
- "Sampler": p.sampler_name,
- "CFG scale": p.cfg_scale,
- "Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
- "Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None,
- "Index": f'{p.iteration + 1}x{index + 1}' if (p.n_iter > 1 or p.batch_size > 1) and index >= 0 else None,
- "Parser": shared.opts.prompt_attention,
- "Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''),
- "Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash),
- "VAE": (None if not shared.opts.add_model_name_to_info or modules.sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(modules.sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD',
- "Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
- "Clip skip": p.clip_skip if p.clip_skip > 1 else None,
- "Prompt2": p.refiner_prompt if len(p.refiner_prompt) > 0 else None,
- "Negative2": p.refiner_negative if len(p.refiner_negative) > 0 else None,
- "Styles": "; ".join(p.styles) if p.styles is not None and len(p.styles) > 0 else None,
- "Tiling": p.tiling if p.tiling else None,
- # sdnext
- "Backend": 'Diffusers' if shared.backend == shared.Backend.DIFFUSERS else 'Original',
- "App": 'SD.Next',
- "Version": git_commit,
- "Comment": comment,
- "Operations": '; '.join(ops).replace('"', '') if len(p.ops) > 0 else 'none',
- }
- if 'txt2img' in p.ops:
- pass
- if shared.backend == shared.Backend.ORIGINAL:
- args["Variation seed"] = None if p.subseed_strength == 0 else all_subseeds[index],
- args["Variation strength"] = None if p.subseed_strength == 0 else p.subseed_strength,
- if 'hires' in p.ops or 'upscale' in p.ops:
- args["Second pass"] = p.enable_hr
- args["Hires force"] = p.hr_force
- args["Hires steps"] = p.hr_second_pass_steps
- args["Hires upscaler"] = p.hr_upscaler
- args["Hires upscale"] = p.hr_scale
- args["Hires resize"] = f"{p.hr_resize_x}x{p.hr_resize_y}"
- args["Hires size"] = f"{p.hr_upscale_to_x}x{p.hr_upscale_to_y}"
- args["Denoising strength"] = p.denoising_strength
- args["Hires sampler"] = p.hr_sampler_name
- args["Image CFG scale"] = p.image_cfg_scale
- args["CFG rescale"] = p.diffusers_guidance_rescale
- if 'refine' in p.ops:
- args["Second pass"] = p.enable_hr
- args["Refiner"] = None if (not shared.opts.add_model_name_to_info) or (not shared.sd_refiner) or (not shared.sd_refiner.sd_checkpoint_info.model_name) else shared.sd_refiner.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')
- args['Image CFG scale'] = p.image_cfg_scale
- args['Refiner steps'] = p.refiner_steps
- args['Refiner start'] = p.refiner_start
- args["Hires steps"] = p.hr_second_pass_steps
- args["Hires sampler"] = p.hr_sampler_name
- args["CFG rescale"] = p.diffusers_guidance_rescale
- if 'img2img' in p.ops or 'inpaint' in p.ops:
- args["Init image size"] = f"{getattr(p, 'init_img_width', 0)}x{getattr(p, 'init_img_height', 0)}"
- args["Init image hash"] = getattr(p, 'init_img_hash', None)
- args["Mask weight"] = getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None
- args['Resize scale'] = getattr(p, 'scale_by', None)
- args["Mask blur"] = p.mask_blur if getattr(p, 'mask', None) is not None and getattr(p, 'mask_blur', 0) > 0 else None
- args["Denoising strength"] = getattr(p, 'denoising_strength', None)
- if args["Size"] is None:
- args["Size"] = args["Init image size"]
- # lookup by index
- if getattr(p, 'resize_mode', None) is not None:
- args['Resize mode'] = shared.resize_modes[p.resize_mode]
- if 'face' in p.ops:
- args["Face restoration"] = shared.opts.face_restoration_model
- if 'color' in p.ops:
- args["Color correction"] = True
- # embeddings
- if hasattr(modules.sd_hijack.model_hijack, 'embedding_db') and len(modules.sd_hijack.model_hijack.embedding_db.embeddings_used) > 0: # this is for original hijacked models only, diffusers are handled separately
- args["Embeddings"] = ', '.join(modules.sd_hijack.model_hijack.embedding_db.embeddings_used)
- # samplers
- args["Sampler ENSD"] = shared.opts.eta_noise_seed_delta if shared.opts.eta_noise_seed_delta != 0 and modules.sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None
- args["Sampler ENSM"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None
- args['Sampler order'] = shared.opts.schedulers_solver_order if shared.opts.schedulers_solver_order != shared.opts.data_labels.get('schedulers_solver_order').default else None
- if shared.backend == shared.Backend.DIFFUSERS:
- args['Sampler beta schedule'] = shared.opts.schedulers_beta_schedule if shared.opts.schedulers_beta_schedule != shared.opts.data_labels.get('schedulers_beta_schedule').default else None
- args['Sampler beta start'] = shared.opts.schedulers_beta_start if shared.opts.schedulers_beta_start != shared.opts.data_labels.get('schedulers_beta_start').default else None
- args['Sampler beta end'] = shared.opts.schedulers_beta_end if shared.opts.schedulers_beta_end != shared.opts.data_labels.get('schedulers_beta_end').default else None
- args['Sampler DPM solver'] = shared.opts.schedulers_dpm_solver if shared.opts.schedulers_dpm_solver != shared.opts.data_labels.get('schedulers_dpm_solver').default else None
- if shared.backend == shared.Backend.ORIGINAL:
- args['Sampler brownian'] = shared.opts.schedulers_brownian_noise if shared.opts.schedulers_brownian_noise != shared.opts.data_labels.get('schedulers_brownian_noise').default else None
- args['Sampler discard'] = shared.opts.schedulers_discard_penultimate if shared.opts.schedulers_discard_penultimate != shared.opts.data_labels.get('schedulers_discard_penultimate').default else None
- args['Sampler dyn threshold'] = shared.opts.schedulers_use_thresholding if shared.opts.schedulers_use_thresholding != shared.opts.data_labels.get('schedulers_use_thresholding').default else None
- args['Sampler karras'] = shared.opts.schedulers_use_karras if shared.opts.schedulers_use_karras != shared.opts.data_labels.get('schedulers_use_karras').default else None
- args['Sampler low order'] = shared.opts.schedulers_use_loworder if shared.opts.schedulers_use_loworder != shared.opts.data_labels.get('schedulers_use_loworder').default else None
- args['Sampler quantization'] = shared.opts.enable_quantization if shared.opts.enable_quantization != shared.opts.data_labels.get('enable_quantization').default else None
- args['Sampler sigma'] = shared.opts.schedulers_sigma if shared.opts.schedulers_sigma != shared.opts.data_labels.get('schedulers_sigma').default else None
- args['Sampler sigma min'] = shared.opts.s_min if shared.opts.s_min != shared.opts.data_labels.get('s_min').default else None
- args['Sampler sigma max'] = shared.opts.s_max if shared.opts.s_max != shared.opts.data_labels.get('s_max').default else None
- args['Sampler sigma churn'] = shared.opts.s_churn if shared.opts.s_churn != shared.opts.data_labels.get('s_churn').default else None
- args['Sampler sigma uncond'] = shared.opts.s_min_uncond if shared.opts.s_min_uncond != shared.opts.data_labels.get('s_min_uncond').default else None
- args['Sampler sigma noise'] = shared.opts.s_noise if shared.opts.s_noise != shared.opts.data_labels.get('s_noise').default else None
- args['Sampler sigma tmin'] = shared.opts.s_tmin if shared.opts.s_tmin != shared.opts.data_labels.get('s_tmin').default else None
- # tome
- token_merging_ratio = p.get_token_merging_ratio()
- token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True) if p.enable_hr else None
- args['ToMe'] = token_merging_ratio if token_merging_ratio != 0 else None
- args['ToMe hires'] = token_merging_ratio_hr if token_merging_ratio_hr != 0 else None
-
- args.update(p.extra_generation_params)
- params_text = ", ".join([k if k == v else f'{k}: {modules.generation_parameters_copypaste.quote(v)}' for k, v in args.items() if v is not None])
- negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
- infotext = f"{all_prompts[index]}{negative_prompt_text}\n{params_text}".strip()
- return infotext
-
-
-def process_images(p: StableDiffusionProcessing) -> Processed:
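- # Outer wrapper for generation: applies override settings and token merging, optionally runs the python/torch profilers, delegates to process_images_inner, then restores any overridden options.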
- debug(f'Process images: {vars(p)}')
- if not hasattr(p.sd_model, 'sd_checkpoint_info'):
- return None
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.scripts.before_process(p)
- stored_opts = {}
- for k, v in p.override_settings.copy().items():
- if shared.opts.data.get(k, None) is None and shared.opts.data_labels.get(k, None) is None:
- continue
- orig = shared.opts.data.get(k, None) or shared.opts.data_labels[k].default
- if orig == v or (type(orig) == str and os.path.splitext(orig)[0] == v):
- p.override_settings.pop(k, None)
- for k in p.override_settings.keys():
- stored_opts[k] = shared.opts.data.get(k, None) or shared.opts.data_labels[k].default
- res = None
- try:
- # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
- if p.override_settings.get('sd_model_checkpoint', None) is not None and modules.sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
- shared.log.warning(f"Override not found: checkpoint={p.override_settings.get('sd_model_checkpoint', None)}")
- p.override_settings.pop('sd_model_checkpoint', None)
- modules.sd_models.reload_model_weights()
- if p.override_settings.get('sd_model_refiner', None) is not None and modules.sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_refiner')) is None:
- shared.log.warning(f"Override not found: refiner={p.override_settings.get('sd_model_refiner', None)}")
- p.override_settings.pop('sd_model_refiner', None)
- modules.sd_models.reload_model_weights()
- if p.override_settings.get('sd_vae', None) is not None:
- if p.override_settings.get('sd_vae', None) == 'TAESD':
- p.full_quality = False
- p.override_settings.pop('sd_vae', None)
- if p.override_settings.get('Hires upscaler', None) is not None:
- p.enable_hr = True
- if len(p.override_settings.keys()) > 0:
- shared.log.debug(f'Override: {p.override_settings}')
- for k, v in p.override_settings.items():
- setattr(shared.opts, k, v)
- if k == 'sd_model_checkpoint':
- modules.sd_models.reload_model_weights()
- if k == 'sd_vae':
- modules.sd_vae.reload_vae_weights()
-
- shared.prompt_styles.apply_styles_to_extra(p)
- if not shared.opts.cuda_compile:
- modules.sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
- modules.sd_hijack_freeu.apply_freeu(p, shared.backend == shared.Backend.ORIGINAL)
-
- modules.script_callbacks.before_process_callback(p)
-
- if shared.cmd_opts.profile:
- import cProfile
- profile_python = cProfile.Profile()
- profile_python.enable()
- with context_hypertile_vae(p), context_hypertile_unet(p):
- import torch.profiler # pylint: disable=redefined-outer-name
- activities=[torch.profiler.ProfilerActivity.CPU]
- if torch.cuda.is_available():
- activities.append(torch.profiler.ProfilerActivity.CUDA)
- shared.log.debug(f'Torch profile: activities={activities}')
- if shared.profiler is None:
- shared.profiler = torch.profiler.profile(activities=activities, profile_memory=True, with_modules=True)
- shared.profiler.start()
- shared.profiler.step()
- res = process_images_inner(p)
- errors.profile_torch(shared.profiler, 'Process')
- errors.profile(profile_python, 'Process')
- else:
- with context_hypertile_vae(p), context_hypertile_unet(p):
- res = process_images_inner(p)
-
- finally:
- if not shared.opts.cuda_compile:
- modules.sd_models.apply_token_merging(p.sd_model, 0)
- modules.script_callbacks.after_process_callback(p)
- if p.override_settings_restore_afterwards: # restore opts to original state
- for k, v in stored_opts.items():
- setattr(shared.opts, k, v)
- if k == 'sd_model_checkpoint':
- modules.sd_models.reload_model_weights()
- if k == 'sd_model_refiner':
- modules.sd_models.reload_model_weights()
- if k == 'sd_vae':
- modules.sd_vae.reload_vae_weights()
- return res
-
-
-def validate_sample(tensor):
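- # Convert a model output tensor or array to uint8 image data; bf16 is downcast first, and NaN values are sanitized if the uint8 cast raises warnings.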
- if not isinstance(tensor, np.ndarray) and not isinstance(tensor, torch.Tensor):
- return tensor
- if tensor.dtype == torch.bfloat16: # numpy does not support bf16
- tensor = tensor.to(torch.float16)
- if isinstance(tensor, torch.Tensor) and hasattr(tensor, 'detach'):
- sample = tensor.detach().cpu().numpy()
- elif isinstance(tensor, np.ndarray):
- sample = tensor
- else:
- shared.log.warning(f'Unknown sample type: {type(tensor)}')
- sample = 255.0 * np.moveaxis(sample, 0, 2) if shared.backend == shared.Backend.ORIGINAL else 255.0 * sample
- with warnings.catch_warnings(record=True) as w:
- cast = sample.astype(np.uint8)
- if len(w) > 0:
- nans = np.isnan(sample).sum()
- shared.log.error(f'Failed to validate samples: sample={sample.shape} invalid={nans}')
- cast = np.nan_to_num(sample)
- minimum, maximum, mean = np.min(cast), np.max(cast), np.mean(cast)
- cast = cast.astype(np.uint8)
- shared.log.warning(f'Attempted to correct samples: min={minimum:.2f} max={maximum:.2f} mean={mean:.2f}')
- return cast
-
-
-def process_init(p: StableDiffusionProcessing):
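- # Expand prompt, negative prompt, seed and subseed into per-image lists (batch_size * n_iter entries), applying styles along the way.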
- seed = get_fixed_seed(p.seed)
- subseed = get_fixed_seed(p.subseed)
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
- if type(seed) == list:
- p.all_seeds = seed
- else:
- p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
- if type(subseed) == list:
- p.all_subseeds = subseed
- else:
- p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
-
-
-def process_images_inner(p: StableDiffusionProcessing) -> Processed:
- """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
-
- if type(p.prompt) == list:
- assert len(p.prompt) > 0
- else:
- assert p.prompt is not None
-
- if shared.backend == shared.Backend.ORIGINAL:
- modules.sd_hijack.model_hijack.apply_circular(p.tiling)
- modules.sd_hijack.model_hijack.clear_comments()
- comments = {}
- infotexts = []
- output_images = []
- cached_uc = [None, None]
- cached_c = [None, None]
-
- process_init(p)
- if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and shared.backend == shared.Backend.ORIGINAL:
- modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False)
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.scripts.process(p)
-
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
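- # Reuse the previously computed conditioning when the prompts and step count are unchanged; otherwise recompute and store it in the provided cache list.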
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps)
- return cache[1]
-
- def infotext(_index=0): # dummy function, overridden when there are iterations
- return ''
-
- ema_scope_context = p.sd_model.ema_scope if shared.backend == shared.Backend.ORIGINAL else nullcontext
- shared.state.job_count = p.n_iter
- with devices.inference_context(), ema_scope_context():
- t0 = time.time()
- with devices.autocast():
- p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
- extra_network_data = None
- debug(f'Processing inner: args={vars(p)}')
- for n in range(p.n_iter):
- p.iteration = n
- if shared.state.skipped:
- shared.log.debug(f'Process skipped: {n}/{p.n_iter}')
- shared.state.skipped = False
- continue
- if shared.state.interrupted:
- shared.log.debug(f'Process interrupted: {n}/{p.n_iter}')
- break
- p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- if len(p.prompts) == 0:
- break
- p.prompts, extra_network_data = modules.extra_networks.parse_prompts(p.prompts)
- if not p.disable_extra_networks:
- with devices.autocast():
- modules.extra_networks.activate(p, extra_network_data)
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- sampler_config = modules.sd_samplers.find_sampler_config(p.sampler_name)
- step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-
- if shared.backend == shared.Backend.ORIGINAL:
- uc = get_conds_with_caching(modules.prompt_parser.get_learned_conditioning, p.negative_prompts, p.steps * step_multiplier, cached_uc)
- c = get_conds_with_caching(modules.prompt_parser.get_multicond_learned_conditioning, p.prompts, p.steps * step_multiplier, cached_c)
- if len(modules.sd_hijack.model_hijack.comments) > 0:
- for comment in modules.sd_hijack.model_hijack.comments:
- comments[comment] = 1
- with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae), p.full_quality)[0].cpu() for i in range(samples_ddim.size(0))]
- try:
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
- except devices.NansException as e:
- if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae:
- shared.log.warning('Tensor with all NaNs was produced in VAE')
- devices.dtype_vae = torch.bfloat16
- vae_file, vae_source = modules.sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint)
- modules.sd_vae.load_vae(p.sd_model, vae_file, vae_source)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae), p.full_quality)[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
- else:
- raise e
- x_samples_ddim = torch.stack(x_samples_ddim).float()
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- del samples_ddim
-
- elif shared.backend == shared.Backend.DIFFUSERS:
- from modules.processing_diffusers import process_diffusers
- x_samples_ddim = process_diffusers(p)
- else:
- raise ValueError(f"Unknown backend {shared.backend}")
-
- if not shared.opts.keep_incomplete and shared.state.interrupted:
- x_samples_ddim = []
-
- if (shared.cmd_opts.lowvram or shared.cmd_opts.medvram) and shared.backend == shared.Backend.ORIGINAL:
- modules.lowvram.send_everything_to_cpu()
- devices.torch_gc()
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- batch_params = modules.scripts.PostprocessBatchListArgs(list(x_samples_ddim))
- p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
- x_samples_ddim = batch_params.images
-
- def infotext(index): # pylint: disable=function-redefined # noqa: F811
- return create_infotext(p, p.prompts, p.seeds, p.subseeds, index=index, all_negative_prompts=p.negative_prompts)
-
- for i, x_sample in enumerate(x_samples_ddim):
- p.batch_index = i
- if type(x_sample) == Image.Image:
- image = x_sample
- x_sample = np.array(x_sample)
- else:
- x_sample = validate_sample(x_sample)
- image = Image.fromarray(x_sample)
- if p.restore_faces:
- if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_face_restoration:
- orig = p.restore_faces
- p.restore_faces = False
- info = infotext(i)
- p.restore_faces = orig
- images.save_image(Image.fromarray(x_sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix="-before-face-restore")
- p.ops.append('face')
- x_sample = modules.face_restoration.restore_faces(x_sample)
- image = Image.fromarray(x_sample)
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
- pp = modules.scripts.PostprocessImageArgs(image)
- p.scripts.postprocess_image(p, pp)
- image = pp.image
- if p.color_corrections is not None and i < len(p.color_corrections):
- if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:
- orig = p.color_corrections
- p.color_corrections = None
- info = infotext(i)
- p.color_corrections = orig
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix="-before-color-correct")
- p.ops.append('color')
- image = apply_color_correction(p.color_corrections[i], image)
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
- text = infotext(i)
- infotexts.append(text)
- image.info["parameters"] = text
- output_images.append(image)
- if shared.opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p) # main save image
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([shared.opts.save_mask, shared.opts.save_mask_composite, shared.opts.return_mask, shared.opts.return_mask_composite]):
- image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(3, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
- if shared.opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p, suffix="-mask")
- if shared.opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p, suffix="-mask-composite")
- if shared.opts.return_mask:
- output_images.append(image_mask)
- if shared.opts.return_mask_composite:
- output_images.append(image_mask_composite)
- del x_samples_ddim
- devices.torch_gc()
-
- t1 = time.time()
- shared.log.info(f'Processed: images={len(output_images)} time={t1 - t0:.2f} its={(p.steps * len(output_images)) / (t1 - t0):.2f} memory={modules.memstats.memory_stats()}')
-
- p.color_corrections = None
- index_of_first_image = 0
- if (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and len(output_images) > 1:
- if images.check_grid_size(output_images):
- grid = images.image_grid(output_images, p.batch_size)
- if shared.opts.return_grid:
- text = infotext(-1)
- infotexts.insert(0, text)
- grid.info["parameters"] = text
- output_images.insert(0, grid)
- index_of_first_image = 1
- if shared.opts.grid_save:
- images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(-1), p=p, grid=True, suffix="-grid") # main save grid
-
- if not p.disable_extra_networks:
- modules.extra_networks.deactivate(p, extra_network_data)
-
- res = Processed(
- p,
- images_list=output_images,
- seed=p.all_seeds[0],
- info=infotext(0),
- comments="\n".join(comments),
- subseed=p.all_subseeds[0],
- index_of_first_image=index_of_first_image,
- infotexts=infotexts,
- )
- if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped):
- p.scripts.postprocess(p, res)
- return res
-
-
-def old_hires_fix_first_pass_dimensions(width, height):
- """old algorithm for auto-calculating first pass size"""
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- width = math.ceil(scale * width / 64) * 64
- height = math.ceil(scale * height / 64) * 64
- return width, height
-
-
-class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
-
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_force: bool = False, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs):
-
- super().__init__(**kwargs)
- if devices.backend == "ipex" and os.environ.get('DISABLE_IPEX_1024_WA', None) is None:
- width_curse = bool(hr_resize_x == 1024 and self.height * (hr_resize_x / self.width) == 1024)
- height_curse = bool(hr_resize_y == 1024 and self.width * (hr_resize_y / self.height) == 1024)
- if width_curse or height_curse:
- if width_curse:
- hr_resize_x = 1080
- if height_curse:
- hr_resize_y = 1080
- if self.width * hr_scale == 1024 and self.height * hr_scale == 1024:
- hr_scale = 1080 / self.width
- if firstphase_width * hr_scale == 1024 and firstphase_height * hr_scale == 1024:
- hr_scale = 1080 / firstphase_width
- self.enable_hr = enable_hr
- self.denoising_strength = denoising_strength
- self.hr_scale = hr_scale
- self.hr_upscaler = hr_upscaler
- self.hr_force = hr_force
- self.hr_second_pass_steps = hr_second_pass_steps
- self.hr_resize_x = hr_resize_x
- self.hr_resize_y = hr_resize_y
- self.hr_upscale_to_x = hr_resize_x
- self.hr_upscale_to_y = hr_resize_y
- if firstphase_width != 0 or firstphase_height != 0:
- self.hr_upscale_to_x = self.width
- self.hr_upscale_to_y = self.height
- self.width = firstphase_width
- self.height = firstphase_height
- self.truncate_x = 0
- self.truncate_y = 0
- self.applied_old_hires_behavior_to = None
- self.refiner_steps = refiner_steps
- self.refiner_start = refiner_start
- self.refiner_prompt = refiner_prompt
- self.refiner_negative = refiner_negative
- self.sampler = None
- self.scripts = None
- self.script_args = []
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- if shared.backend == shared.Backend.DIFFUSERS:
- shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.TEXT_2_IMAGE)
- self.width = self.width or 512
- self.height = self.height or 512
-
- def init_hr(self):
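- # Compute the hires-pass target resolution from hr_scale or hr_resize_x/hr_resize_y (preserving aspect ratio) and decide whether a hires pass is actually needed.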
- if self.hr_resize_x == 0 and self.hr_resize_y == 0:
- self.hr_upscale_to_x = int(self.width * self.hr_scale)
- self.hr_upscale_to_y = int(self.height * self.hr_scale)
- else:
- if self.hr_resize_y == 0:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- elif self.hr_resize_x == 0:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
- else:
- target_w = self.hr_resize_x
- target_h = self.hr_resize_y
- src_ratio = self.width / self.height
- dst_ratio = self.hr_resize_x / self.hr_resize_y
- if src_ratio < dst_ratio:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- else:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
- self.truncate_x = (self.hr_upscale_to_x - target_w) // 8
- self.truncate_y = (self.hr_upscale_to_y - target_h) // 8
- # special case: the user has chosen to do nothing
- if (self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height) or self.hr_upscaler is None or self.hr_upscaler == 'None':
- self.is_hr_pass = False
- return
- self.is_hr_pass = True
- hypertile_set(self, hr=True)
- shared.state.job_count = 2 * self.n_iter
- shared.log.debug(f'Init hires: upscaler="{self.hr_upscaler}" sampler="{self.hr_sampler_name}" resize={self.hr_resize_x}x{self.hr_resize_y} upscale={self.hr_upscale_to_x}x{self.hr_upscale_to_y}')
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "None")
- if latent_scale_mode is not None:
- self.hr_force = False # no need to force anything
- if self.enable_hr and (latent_scale_mode is None or self.hr_force):
- if len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) == 0:
- shared.log.warning(f"Cannot find upscaler for hires: {self.hr_upscaler}")
- self.enable_hr = False
-
- self.ops.append('txt2img')
- hypertile_set(self)
- self.sampler = modules.sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- if hasattr(self.sampler, "initialize"):
- self.sampler.initialize(self)
- x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
- shared.state.nextjob()
- if not self.enable_hr or shared.state.interrupted or shared.state.skipped:
- return samples
-
- self.init_hr()
- if self.is_hr_pass:
- prev_job = shared.state.job
- target_width = self.hr_upscale_to_x
- target_height = self.hr_upscale_to_y
- decoded_samples = None
- if shared.opts.save and shared.opts.save_images_before_highres_fix and not self.do_not_save_samples:
- decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
- decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- for i, x_sample in enumerate(decoded_samples):
- x_sample = validate_sample(x_sample)
- image = Image.fromarray(x_sample)
- bak_extra_generation_params, bak_restore_faces = self.extra_generation_params, self.restore_faces
- self.extra_generation_params = {}
- self.restore_faces = False
- info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=i)
- self.extra_generation_params, self.restore_faces = bak_extra_generation_params, bak_restore_faces
- images.save_image(image, self.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires")
- if latent_scale_mode is None or self.hr_force: # non-latent upscaling
- shared.state.job = 'upscale'
- if decoded_samples is None:
- decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
- decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- batch_images = []
- for _i, x_sample in enumerate(decoded_samples):
- x_sample = validate_sample(x_sample)
- image = Image.fromarray(x_sample)
- image = images.resize_image(1, image, target_width, target_height, upscaler_name=self.hr_upscaler)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
- resized_samples = torch.from_numpy(np.array(batch_images))
- resized_samples = resized_samples.to(device=shared.device, dtype=devices.dtype_vae)
- resized_samples = 2.0 * resized_samples - 1.0
- if shared.opts.sd_vae_sliced_encode and len(decoded_samples) > 1:
- samples = torch.stack([self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(torch.unsqueeze(resized_sample, 0)))[0] for resized_sample in resized_samples])
- else:
- # TODO add TAESD support
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(resized_samples))
- image_conditioning = self.img2img_image_conditioning(resized_samples, samples)
- else:
- samples = torch.nn.functional.interpolate(samples, size=(target_height // 8, target_width // 8), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
- if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
- image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality), samples)
- else:
- image_conditioning = self.txt2img_image_conditioning(samples.to(dtype=devices.dtype_vae))
- if self.hr_sampler_name == "PLMS":
- self.hr_sampler_name = 'UniPC'
- if self.hr_force or latent_scale_mode is not None:
- shared.state.job = 'hires'
- if self.denoising_strength > 0:
- self.ops.append('hires')
- devices.torch_gc() # GC now before running the next img2img to prevent running out of memory
- self.sampler = modules.sd_samplers.create_sampler(self.hr_sampler_name or self.sampler_name, self.sd_model)
- if hasattr(self.sampler, "initialize"):
- self.sampler.initialize(self)
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
- modules.sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
- hypertile_set(self, hr=True)
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
- modules.sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
- else:
- self.ops.append('upscale')
- x = None
- self.is_hr_pass = False
- shared.state.job = prev_job
- shared.state.nextjob()
-
- return samples
-
-
-class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
-
- def __init__(self, init_images: list = None, resize_mode: int = 0, resize_name: str = 'None', denoising_strength: float = 0.3, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, scale_by: float = 1, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs):
- super().__init__(**kwargs)
- self.init_images = init_images
- self.resize_mode: int = resize_mode
- self.resize_name: str = resize_name
- self.denoising_strength: float = denoising_strength
- self.image_cfg_scale: float = image_cfg_scale
- self.init_latent = None
- self.image_mask = mask
- self.latent_mask = None
- self.mask_for_overlay = None
- self.mask_blur_x = mask_blur # a1111 compatibility item
- self.mask_blur_y = mask_blur # a1111 compatibility item
- self.mask_blur = mask_blur
- self.inpainting_fill = inpainting_fill
- self.inpaint_full_res = inpaint_full_res
- self.inpaint_full_res_padding = inpaint_full_res_padding
- self.inpainting_mask_invert = inpainting_mask_invert
- self.initial_noise_multiplier = shared.opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
- self.mask = None
- self.nmask = None
- self.image_conditioning = None
- self.refiner_steps = refiner_steps
- self.refiner_start = refiner_start
- self.refiner_prompt = refiner_prompt
- self.refiner_negative = refiner_negative
- self.enable_hr = None
- self.is_batch = False
- self.scale_by = scale_by
- self.sampler = None
- self.scripts = None
- self.script_args = []
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- if shared.backend == shared.Backend.DIFFUSERS and self.image_mask is not None and not self.is_control:
- shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.INPAINTING)
- elif shared.backend == shared.Backend.DIFFUSERS and self.image_mask is None and not self.is_control:
- shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
-
- if self.sampler_name == "PLMS":
- self.sampler_name = 'UniPC'
- self.sampler = modules.sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- if hasattr(self.sampler, "initialize"):
- self.sampler.initialize(self)
-
- if self.image_mask is not None:
- self.ops.append('inpaint')
- else:
- self.ops.append('img2img')
- crop_region = None
-
- if self.image_mask is not None:
- if type(self.image_mask) == list:
- self.image_mask = self.image_mask[0]
- self.image_mask = create_binary_mask(self.image_mask)
- if self.inpainting_mask_invert:
- self.image_mask = ImageOps.invert(self.image_mask)
- if self.mask_blur > 0:
- np_mask = np.array(self.image_mask)
- kernel_size = 2 * int(2.5 * self.mask_blur + 0.5) + 1
- np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur)
- np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur)
- self.image_mask = Image.fromarray(np_mask)
- if self.inpaint_full_res:
- self.mask_for_overlay = self.image_mask
- mask = self.image_mask.convert('L')
- crop_region = modules.masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
- crop_region = modules.masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
- x1, y1, x2, y2 = crop_region
- mask = mask.crop(crop_region)
- self.image_mask = images.resize_image(2, mask, self.width, self.height)
- self.paste_to = (x1, y1, x2-x1, y2-y1)
- else:
- self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
- np_mask = np.array(self.image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
- self.overlay_images = []
-
- latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
-
- add_color_corrections = shared.opts.img2img_color_correction and self.color_corrections is None
- if add_color_corrections:
- self.color_corrections = []
- processed = []
- if getattr(self, 'init_images', None) is None:
- return
- if not isinstance(self.init_images, list):
- self.init_images = [self.init_images]
- for img in self.init_images:
- if img is None:
- shared.log.warning(f"Skipping empty image: images={self.init_images}")
- continue
- self.init_img_hash = hashlib.sha256(img.tobytes()).hexdigest()[0:8] # pylint: disable=attribute-defined-outside-init
- self.init_img_width = img.width # pylint: disable=attribute-defined-outside-init
- self.init_img_height = img.height # pylint: disable=attribute-defined-outside-init
- if shared.opts.save_init_img:
- images.save_image(img, path=shared.opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, suffix="-init-image")
- image = images.flatten(img, shared.opts.img2img_background_color)
- if crop_region is None and self.resize_mode != 4 and self.resize_mode > 0:
- if image.width != self.width or image.height != self.height:
- image = images.resize_image(self.resize_mode, image, self.width, self.height, self.resize_name)
- self.width = image.width
- self.height = image.height
- if self.image_mask is not None:
- try:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_to_paste = image.convert("RGBA").convert("RGBa")
- image_to_mask = ImageOps.invert(self.mask_for_overlay.convert('L')) if self.mask_for_overlay is not None else None
- image_to_mask = image_to_mask.resize((image.width, image.height), Image.Resampling.BILINEAR) if image_to_mask is not None else None
- image_masked.paste(image_to_paste, mask=image_to_mask)
- self.overlay_images.append(image_masked.convert('RGBA'))
- except Exception as e:
- shared.log.error(f"Failed to apply mask to image: {e}")
- if crop_region is not None: # crop_region is not None if we are doing inpaint full res
- image = image.crop(crop_region)
- if image.width != self.width or image.height != self.height:
- image = images.resize_image(3, image, self.width, self.height, self.resize_name)
- if self.image_mask is not None and self.inpainting_fill != 1:
- image = modules.masking.fill(image, latent_mask)
- if add_color_corrections:
- self.color_corrections.append(setup_color_correction(image))
- processed.append(image)
- self.init_images = processed
- self.batch_size = len(self.init_images)
- if self.overlay_images is not None:
- self.overlay_images = self.overlay_images * self.batch_size
- if self.color_corrections is not None and len(self.color_corrections) == 1:
- self.color_corrections = self.color_corrections * self.batch_size
- if shared.backend == shared.Backend.DIFFUSERS:
- return # we've already set self.init_images and self.mask and we dont need any more processing
-
- self.init_images = [np.moveaxis((np.array(image).astype(np.float32) / 255.0), 2, 0) for image in self.init_images]
- if len(self.init_images) == 1:
- batch_images = np.expand_dims(self.init_images[0], axis=0).repeat(self.batch_size, axis=0)
- elif len(self.init_images) <= self.batch_size:
- batch_images = np.array(self.init_images)
- image = torch.from_numpy(batch_images)
- image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_vae)
- self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
- if self.resize_mode == 4:
- self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // 8, self.width // 8), mode="bilinear")
- if self.image_mask is not None:
- init_mask = latent_mask
- latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
- latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
- latmask = latmask[0]
- latmask = np.tile(latmask[None], (4, 1, 1))
- latmask = np.around(latmask)
- self.mask = torch.asarray(1.0 - latmask).to(device=shared.device, dtype=self.sd_model.dtype)
- self.nmask = torch.asarray(latmask).to(device=shared.device, dtype=self.sd_model.dtype)
- if self.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
- elif self.inpainting_fill == 3:
- self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- hypertile_set(self)
- x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- x *= self.initial_noise_multiplier
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
- if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
- del x
- devices.torch_gc()
- shared.state.nextjob()
- return samples
-
- def get_token_merging_ratio(self, for_hr=False):
- return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and shared.opts.token_merging_ratio) or shared.opts.token_merging_ratio_img2img or shared.opts.token_merging_ratio
+import os
+import json
+import math
+import time
+import hashlib
+import random
+import warnings
+from contextlib import nullcontext
+from typing import Any, Dict, List
+from dataclasses import dataclass, field
+import torch
+import numpy as np
+import cv2
+from PIL import Image, ImageOps
+from skimage import exposure
+from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
+from installer import git_commit
+from modules import shared, devices, errors
+import modules.memstats
+import modules.lowvram
+import modules.masking
+import modules.paths
+import modules.scripts
+import modules.script_callbacks
+import modules.prompt_parser
+import modules.extra_networks
+import modules.face_restoration
+import modules.images as images
+import modules.styles
+import modules.sd_hijack_freeu
+import modules.sd_samplers
+import modules.sd_samplers_common
+import modules.sd_models
+import modules.sd_vae
+import modules.sd_vae_approx
+import modules.taesd.sd_vae_taesd
+import modules.generation_parameters_copypaste
+from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet, hypertile_set
+
+
+if shared.backend == shared.Backend.ORIGINAL:
+ import modules.sd_hijack
+
+opt_C = 4
+opt_f = 8
+debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PROCESS')
+
+
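+# Color-correction helpers: setup_color_correction() captures a LAB reference from the source
+# image; apply_color_correction() histogram-matches an output against that reference and blends
+# the result back with the original via a LUMINOSITY blend. Roughly how init() and the batch
+# loop below use them (illustrative only):
+#   target = setup_color_correction(init_image)
+#   corrected = apply_color_correction(target, generated_image)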
+def setup_color_correction(image):
+ debug("Calibrating color correction")
+ correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
+ return correction_target
+
+
+def apply_color_correction(correction, original_image):
+ shared.log.debug(f"Applying color correction: correction={correction} image={original_image}")
+ image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB), correction, channel_axis=2), cv2.COLOR_LAB2RGB).astype("uint8"))
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+ return image
+
+
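+# apply_overlay: pastes the (resized) result into a canvas the size of the stored overlay at
+# paste_loc, alpha-composites the overlay (e.g. the preserved region of an inpainting init image)
+# on top, and returns an RGB image. Returns the input unchanged when no overlay exists for index.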
+def apply_overlay(image: Image, paste_loc, index, overlays):
+ debug(f'Apply overlay: image={image} loc={paste_loc} index={index} overlays={overlays}')
+ if overlays is None or index >= len(overlays):
+ return image
+ overlay = overlays[index]
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ if image.width != w or image.height != h or x != 0 or y != 0:
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(2, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+ return image
+
+
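+# create_binary_mask: for RGBA inputs whose alpha channel is not fully opaque, the alpha channel
+# is thresholded at 128 into a 0/255 mask; any other input is simply converted to grayscale.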
+def create_binary_mask(image):
+ if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.convert('L')
+ return image
+
+
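+# images_tensor_to_samples: encodes image tensors in [0, 1] into first-stage (VAE) latents; values
+# are rescaled to [-1, 1] first, and batches larger than one are encoded image-by-image,
+# presumably to bound VAE memory use.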
+def images_tensor_to_samples(image, approximation=None, model=None): # pylint: disable=unused-argument
+ if model is None:
+ model = shared.sd_model
+ model.first_stage_model.to(devices.dtype_vae)
+ image = image.to(shared.device, dtype=devices.dtype_vae)
+ image = image * 2 - 1
+ if len(image) > 1:
+ x_latent = torch.stack([
+ model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0]
+ for img in image
+ ])
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ return x_latent
+
+
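+# txt2img_image_conditioning: builds the auxiliary image conditioning some checkpoints expect even
+# for plain txt2img: an encoded all-zero "masked image" plus an all-ones mask for inpainting
+# ('hybrid'/'concat') models, zero ADM embeddings for UnCLIP, and a minimal 5-channel dummy otherwise.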
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) # pylint: disable=not-callable
+ image_conditioning = image_conditioning.to(x.dtype)
+ return image_conditioning
+ elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
+ return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
+ else:
+ # Dummy zero conditioning if we're not using inpainting or unclip models.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+
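+# get_sampler_name: resolves a UI sampler index to a sampler name, e.g. get_sampler_name(0) returns
+# the first registered sampler. Out-of-range indices fall back to "UniPC", and PLMS is swapped for
+# "UniPC" when img=True since it is logged as not compatible with img2img.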
+def get_sampler_name(sampler_index: int, img: bool = False) -> str:
+ if len(modules.sd_samplers.samplers) > sampler_index:
+ sampler_name = modules.sd_samplers.samplers[sampler_index].name
+ else:
+ sampler_name = "UniPC"
+ shared.log.warning(f'Sampler not found: index={sampler_index} available={[s.name for s in modules.sd_samplers.samplers]} fallback={sampler_name}')
+ if img and sampler_name == "PLMS":
+ sampler_name = "UniPC"
+ shared.log.warning(f'Sampler not compatible: name=PLMS fallback={sampler_name}')
+ return sampler_name
+
+
+@dataclass(repr=False)
+class StableDiffusionProcessing:
+ """
+ The first set of parameters: sd_model -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
+ """
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, hr_sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, image_cfg_scale: float = None, clip_skip: int = 1, width: int = 512, height: int = 512, full_quality: bool = True, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, diffusers_guidance_rescale: float = 0.7, sag_scale: float = 0.0, resize_mode: int = 0, resize_name: str = 'None', scale_by: float = 0, selected_scale_tab: int = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 3.5, hdr_center: bool = False, hdr_channel_shift: float = 0.8, hdr_full_shift: float = 0.8, hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): # pylint: disable=unused-argument
+ self.outpath_samples: str = outpath_samples
+ self.outpath_grids: str = outpath_grids
+ self.prompt: str = prompt
+ self.prompt_for_display: str = None
+ self.negative_prompt: str = (negative_prompt or "")
+ self.styles: list = styles or []
+ self.seed: int = seed
+ self.subseed: int = subseed
+ self.subseed_strength: float = subseed_strength
+ self.seed_resize_from_h: int = seed_resize_from_h
+ self.seed_resize_from_w: int = seed_resize_from_w
+ self.sampler_name: str = sampler_name
+ self.hr_sampler_name: str = hr_sampler_name
+ self.batch_size: int = batch_size
+ self.n_iter: int = n_iter
+ self.steps: int = steps
+ self.hr_second_pass_steps = 0
+ self.cfg_scale: float = cfg_scale
+ self.scale_by: float = scale_by
+ self.image_cfg_scale = image_cfg_scale
+ self.diffusers_guidance_rescale = diffusers_guidance_rescale
+ self.sag_scale = sag_scale
+ if devices.backend == "ipex" and width == 1024 and height == 1024 and os.environ.get('DISABLE_IPEX_1024_WA', None) is None:
+ width = 1080
+ height = 1080
+ self.width: int = width
+ self.height: int = height
+ self.full_quality: bool = full_quality
+ self.restore_faces: bool = restore_faces
+ self.tiling: bool = tiling
+ self.do_not_save_samples: bool = do_not_save_samples
+ self.do_not_save_grid: bool = do_not_save_grid
+ self.extra_generation_params: dict = extra_generation_params or {}
+ self.overlay_images = overlay_images
+ self.eta = eta
+ self.do_not_reload_embeddings = do_not_reload_embeddings
+ self.paste_to = None
+ self.color_corrections = None
+ self.denoising_strength: float = denoising_strength
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
+ self.is_using_inpainting_conditioning = False
+ self.disable_extra_networks = False
+ self.token_merging_ratio = 0
+ self.token_merging_ratio_hr = 0
+ # self.scripts = modules.scripts.ScriptRunner() # set via property
+ # self.script_args = script_args or [] # set via property
+ self.per_script_args = {}
+ self.all_prompts = None
+ self.all_negative_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+ self.clip_skip = clip_skip
+ self.iteration = 0
+ self.is_control = False
+ self.is_hr_pass = False
+ self.is_refiner_pass = False
+ self.hr_force = False
+ self.enable_hr = None
+ self.hr_scale = None
+ self.hr_upscaler = None
+ self.hr_resize_x = 0
+ self.hr_resize_y = 0
+ self.hr_upscale_to_x = 0
+ self.hr_upscale_to_y = 0
+ self.truncate_x = 0
+ self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
+ self.refiner_steps = 5
+ self.refiner_start = 0
+ self.refiner_prompt = ''
+ self.refiner_negative = ''
+ self.ops = []
+ self.resize_mode: int = resize_mode
+ self.resize_name: str = resize_name
+ self.ddim_discretize = shared.opts.ddim_discretize
+ self.s_min_uncond = shared.opts.s_min_uncond
+ self.s_churn = shared.opts.s_churn
+ self.s_noise = shared.opts.s_noise
+ self.s_min = shared.opts.s_min
+ self.s_max = shared.opts.s_max
+ self.s_tmin = shared.opts.s_tmin
+ self.s_tmax = float('inf') # not representable as a standard ui option
+ shared.opts.data['clip_skip'] = clip_skip
+ self.task_args = {}
+ # a1111 compatibility items
+ self.refiner_switch_at = 0
+ self.hr_prompt = ''
+ self.all_hr_prompts = []
+ self.hr_negative_prompt = ''
+ self.all_hr_negative_prompts = []
+ self.comments = {}
+ self.is_api = False
+ self.scripts_value: modules.scripts.ScriptRunner = None # plain defaults; dataclasses.field() has no effect when called inside __init__
+ self.script_args_value: list = None
+ self.scripts_setup_complete: bool = False
+ # hdr
+ self.hdr_clamp = hdr_clamp
+ self.hdr_boundary = hdr_boundary
+ self.hdr_threshold = hdr_threshold
+ self.hdr_center = hdr_center
+ self.hdr_channel_shift = hdr_channel_shift
+ self.hdr_full_shift = hdr_full_shift
+ self.hdr_maximize = hdr_maximize
+ self.hdr_max_center = hdr_max_center
+ self.hdr_max_boundry = hdr_max_boundry
+ self.scheduled_prompt: bool = False
+ self.prompt_embeds = []
+ self.positive_pooleds = []
+ self.negative_embeds = []
+ self.negative_pooleds = []
+
+
+ @property
+ def sd_model(self):
+ return shared.sd_model
+
+ @property
+ def scripts(self):
+ return self.scripts_value
+
+ @scripts.setter
+ def scripts(self, value):
+ self.scripts_value = value
+ if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+ self.setup_scripts()
+
+ @property
+ def script_args(self):
+ return self.script_args_value
+
+ @script_args.setter
+ def script_args(self, value):
+ self.script_args_value = value
+ if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+ self.setup_scripts()
+
+ def setup_scripts(self):
+ self.scripts_setup_complete = True
+ self.scripts.setup_scrips(self, is_ui=not self.is_api)
+
+ def comment(self, text):
+ self.comments[text] = 1
+
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
+
+ def depth2img_image_conditioning(self, source_image):
+ # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
+ from ldm.data.util import AddMiDaS
+ transformer = AddMiDaS(model_type="dpt_hybrid")
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning = torch.nn.functional.interpolate(
+ self.sd_model.depth_model(midas_in),
+ size=conditioning_image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ (depth_min, depth_max) = torch.aminmax(conditioning)
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+ return conditioning
+
+ def edit_image_conditioning(self, source_image):
+ conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+ return conditioning_image
+
+ def unclip_image_conditioning(self, source_image):
+ c_adm = self.sd_model.embedder(source_image)
+ if self.sd_model.noise_augmentor is not None:
+ noise_level = 0
+ c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+ c_adm = torch.cat((c_adm, noise_level_emb), 1)
+ return c_adm
+
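+ # Produces the c_concat input for inpainting-style models: the rounded single-channel mask is
+ # concatenated with the latent of the masked source image (for the usual 4-channel SD latent this
+ # gives 5 channels, matching the 5-channel dummy returned by the non-inpainting paths).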
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ self.is_using_inpainting_conditioning = True
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(device=shared.device, dtype=source_image.dtype)
+ return image_conditioning
+
+ def diffusers_image_conditioning(self, _source_image, latent_image, _image_mask=None):
+ # shared.log.warning('Diffusers not implemented: img2img_image_conditioning')
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+ source_image = devices.cond_cast_float(source_image)
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
+ if shared.backend == shared.Backend.DIFFUSERS:
+ return self.diffusers_image_conditioning(source_image, latent_image, image_mask)
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+ return self.depth2img_image_conditioning(source_image)
+ if hasattr(self.sd_model, 'cond_stage_key') and self.sd_model.cond_stage_key == "edit":
+ return self.edit_image_conditioning(source_image)
+ if hasattr(self.sampler, 'conditioning_key') and self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ if hasattr(self.sampler, 'conditioning_key') and self.sampler.conditioning_key == "crossattn-adm":
+ return self.unclip_image_conditioning(source_image)
+ # Dummy zero conditioning if we're not using inpainting or depth model.
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
+ def init(self, all_prompts, all_seeds, all_subseeds):
+ pass
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ raise NotImplementedError
+
+ def close(self):
+ self.sampler = None # pylint: disable=attribute-defined-outside-init
+
+ def get_token_merging_ratio(self, for_hr=False):
+ if for_hr:
+ return self.token_merging_ratio_hr or shared.opts.token_merging_ratio_hr or self.token_merging_ratio or shared.opts.token_merging_ratio
+ return self.token_merging_ratio or shared.opts.token_merging_ratio
+
+
+class Processed:
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
+ self.images = images_list
+ self.prompt = p.prompt
+ self.negative_prompt = p.negative_prompt
+ self.seed = seed
+ self.subseed = subseed
+ self.subseed_strength = p.subseed_strength
+ self.info = info
+ self.comments = comments
+ self.width = p.width if hasattr(p, 'width') else (self.images[0].width if len(self.images) > 0 else 0)
+ self.height = p.height if hasattr(p, 'height') else (self.images[0].height if len(self.images) > 0 else 0)
+ self.sampler_name = p.sampler_name
+ self.cfg_scale = p.cfg_scale
+ self.image_cfg_scale = p.image_cfg_scale
+ self.steps = p.steps
+ self.batch_size = p.batch_size
+ self.restore_faces = p.restore_faces
+ self.face_restoration_model = shared.opts.face_restoration_model if p.restore_faces else None
+ self.sd_model_hash = getattr(shared.sd_model, 'sd_model_hash', '')
+ self.seed_resize_from_w = p.seed_resize_from_w
+ self.seed_resize_from_h = p.seed_resize_from_h
+ self.denoising_strength = p.denoising_strength
+ self.extra_generation_params = p.extra_generation_params
+ self.index_of_first_image = index_of_first_image
+ self.styles = p.styles
+ self.job_timestamp = shared.state.job_timestamp
+ self.clip_skip = p.clip_skip
+ self.eta = p.eta
+ self.ddim_discretize = p.ddim_discretize
+ self.s_churn = p.s_churn
+ self.s_tmin = p.s_tmin
+ self.s_tmax = p.s_tmax
+ self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
+ self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
+ self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
+ self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
+ self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+ self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+ self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+ self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
+ self.token_merging_ratio = p.token_merging_ratio
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
+ self.infotexts = infotexts or [info]
+
+ def js(self):
+ obj = {
+ "prompt": self.all_prompts[0],
+ "all_prompts": self.all_prompts,
+ "negative_prompt": self.all_negative_prompts[0],
+ "all_negative_prompts": self.all_negative_prompts,
+ "seed": self.seed,
+ "all_seeds": self.all_seeds,
+ "subseed": self.subseed,
+ "all_subseeds": self.all_subseeds,
+ "subseed_strength": self.subseed_strength,
+ "width": self.width,
+ "height": self.height,
+ "sampler_name": self.sampler_name,
+ "cfg_scale": self.cfg_scale,
+ "steps": self.steps,
+ "batch_size": self.batch_size,
+ "restore_faces": self.restore_faces,
+ "face_restoration_model": self.face_restoration_model,
+ "sd_model_hash": self.sd_model_hash,
+ "seed_resize_from_w": self.seed_resize_from_w,
+ "seed_resize_from_h": self.seed_resize_from_h,
+ "denoising_strength": self.denoising_strength,
+ "extra_generation_params": self.extra_generation_params,
+ "index_of_first_image": self.index_of_first_image,
+ "infotexts": self.infotexts,
+ "styles": self.styles,
+ "job_timestamp": self.job_timestamp,
+ "clip_skip": self.clip_skip,
+ # "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
+ }
+ return json.dumps(obj)
+
+ def infotext(self, p: StableDiffusionProcessing, index):
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
+
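+# Spherical interpolation between noise tensors, used to mix seed and subseed noise:
+#   slerp(t, a, b) = sin((1 - t) * omega) / sin(omega) * a + sin(t * omega) / sin(omega) * b
+# where omega = arccos(a_hat . b_hat); nearly parallel inputs fall back to a simple linear mix.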
+def slerp(val, low, high): # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ dot = (low_norm*high_norm).sum(1)
+
+ if dot.mean() > 0.9995:
+ return low * val + high * (1 - val)
+
+ omega = torch.acos(dot)
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
+
+
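+# create_random_tensors: returns a (len(seeds), *shape) noise tensor on shared.device. When
+# seed_resize_from_* is set, noise is generated at the old latent size and pasted into the centre
+# of a full-size tensor; when batch seeds or ENSD are enabled, the extra noises the sampler
+# consumes are pre-generated as well.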
+def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+ eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
+ xs = []
+ # if we have multiple seeds, this means we are working with batch size>1; this then
+ # enables the generation of additional tensors with noise that the sampler will use during its processing.
+ # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
+ # produce the same images as with two batches [100], [101].
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and shared.opts.enable_batch_seeds or eta_noise_seed_delta > 0):
+ sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
+ else:
+ sampler_noises = None
+ for i, seed in enumerate(seeds):
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
+ subnoise = None
+ if subseeds is not None:
+ subseed = 0 if i >= len(subseeds) else subseeds[i]
+ subnoise = devices.randn(subseed, noise_shape)
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this, so I do not dare change it for now because
+ # it will break everyone's seeds.
+ noise = devices.randn(seed, noise_shape)
+ if subnoise is not None:
+ noise = slerp(subseed_strength, noise, subnoise)
+ if noise_shape != shape:
+ x = devices.randn(seed, shape)
+ dx = (shape[2] - noise_shape[2]) // 2
+ dy = (shape[1] - noise_shape[1]) // 2
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
+ tx = 0 if dx < 0 else dx
+ ty = 0 if dy < 0 else dy
+ dx = max(-dx, 0)
+ dy = max(-dy, 0)
+ x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
+ noise = x
+ if sampler_noises is not None:
+ cnt = p.sampler.number_of_needed_noises(p)
+ if eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + eta_noise_seed_delta)
+ for j in range(cnt):
+ sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
+ xs.append(noise)
+ if sampler_noises is not None:
+ p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
+ x = torch.stack(xs).to(shared.device)
+ return x
+
+
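+# decode_first_stage: decodes latents to image space. With full_quality the model VAE is used
+# (decode_first_stage() or .vae()); otherwise each latent is decoded with TAESD. Interrupted jobs
+# (unless keep_incomplete) get a zero tensor, and decode errors fall back to returning the input.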
+def decode_first_stage(model, x, full_quality=True):
+ if not shared.opts.keep_incomplete and (shared.state.skipped or shared.state.interrupted):
+ shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
+ x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
+ return x_sample
+ prev_job = shared.state.job
+ shared.state.job = 'vae'
+ with devices.autocast(disable = x.dtype==devices.dtype_vae):
+ try:
+ if full_quality:
+ if hasattr(model, 'decode_first_stage'):
+ x_sample = model.decode_first_stage(x)
+ elif hasattr(model, 'vae'):
+ x_sample = model.vae(x)
+ else:
+ x_sample = x
+ shared.log.error('Decode VAE unknown model')
+ else:
+ x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
+ for i in range(len(x_sample)):
+ x_sample[i] = modules.taesd.sd_vae_taesd.decode(x[i])
+ except Exception as e:
+ x_sample = x
+ shared.log.error(f'Decode VAE: {e}')
+ shared.state.job = prev_job
+ return x_sample
+
+
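+# get_fixed_seed: turns None, '' or -1 into a random 32-bit seed; any other value, including a
+# list of per-image seeds, is passed through unchanged.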
+def get_fixed_seed(seed):
+ if seed is None or seed == '' or seed == -1:
+ return int(random.randrange(4294967294))
+ return seed
+
+
+def fix_seed(p):
+ p.seed = get_fixed_seed(p.seed)
+ p.subseed = get_fixed_seed(p.subseed)
+
+
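+# create_infotext: builds the metadata string stored with each image: the prompt, an optional
+# "Negative prompt:" line, then one comma-separated line of "Key: value" pairs (entries whose
+# value is None are omitted).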
+def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=None, all_subseeds=None, comments=None, iteration=0, position_in_batch=0, index=None, all_negative_prompts=None):
+ if not hasattr(shared.sd_model, 'sd_checkpoint_info'):
+ return ''
+ if index is None:
+ index = position_in_batch + iteration * p.batch_size
+ if all_prompts is None:
+ all_prompts = p.all_prompts or [p.prompt]
+ if all_negative_prompts is None:
+ all_negative_prompts = p.all_negative_prompts or [p.negative_prompt]
+ if all_seeds is None:
+ all_seeds = p.all_seeds or [p.seed]
+ if all_subseeds is None:
+ all_subseeds = p.all_subseeds or [p.subseed]
+ while len(all_prompts) <= index:
+ all_prompts.append(all_prompts[-1])
+ while len(all_seeds) <= index:
+ all_seeds.append(all_seeds[-1])
+ while len(all_subseeds) <= index:
+ all_subseeds.append(all_subseeds[-1])
+ while len(all_negative_prompts) <= index:
+ all_negative_prompts.append(all_negative_prompts[-1])
+ comment = ', '.join(comments) if comments is not None and type(comments) is list else None
+ ops = list(set(p.ops))
+ ops.reverse()
+ args = {
+ # basic
+ "Steps": p.steps,
+ "Seed": all_seeds[index],
+ "Sampler": p.sampler_name,
+ "CFG scale": p.cfg_scale,
+ "Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
+ "Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None,
+ "Index": f'{p.iteration + 1}x{index + 1}' if (p.n_iter > 1 or p.batch_size > 1) and index >= 0 else None,
+ "Parser": shared.opts.prompt_attention,
+ "Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''),
+ "Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash),
+ "VAE": (None if not shared.opts.add_model_name_to_info or modules.sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(modules.sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD',
+ "Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
+ "Clip skip": p.clip_skip if p.clip_skip > 1 else None,
+ "Prompt2": p.refiner_prompt if len(p.refiner_prompt) > 0 else None,
+ "Negative2": p.refiner_negative if len(p.refiner_negative) > 0 else None,
+ "Styles": "; ".join(p.styles) if p.styles is not None and len(p.styles) > 0 else None,
+ "Tiling": p.tiling if p.tiling else None,
+ # sdnext
+ "Backend": 'Diffusers' if shared.backend == shared.Backend.DIFFUSERS else 'Original',
+ "App": 'SD.Next',
+ "Version": git_commit,
+ "Comment": comment,
+ "Operations": '; '.join(ops).replace('"', '') if len(p.ops) > 0 else 'none',
+ }
+ if 'txt2img' in p.ops:
+ pass
+ if shared.backend == shared.Backend.ORIGINAL:
+ args["Variation seed"] = None if p.subseed_strength == 0 else all_subseeds[index],
+ args["Variation strength"] = None if p.subseed_strength == 0 else p.subseed_strength,
+ if 'hires' in p.ops or 'upscale' in p.ops:
+ args["Second pass"] = p.enable_hr
+ args["Hires force"] = p.hr_force
+ args["Hires steps"] = p.hr_second_pass_steps
+ args["Hires upscaler"] = p.hr_upscaler
+ args["Hires upscale"] = p.hr_scale
+ args["Hires resize"] = f"{p.hr_resize_x}x{p.hr_resize_y}"
+ args["Hires size"] = f"{p.hr_upscale_to_x}x{p.hr_upscale_to_y}"
+ args["Denoising strength"] = p.denoising_strength
+ args["Hires sampler"] = p.hr_sampler_name
+ args["Image CFG scale"] = p.image_cfg_scale
+ args["CFG rescale"] = p.diffusers_guidance_rescale
+ if 'refine' in p.ops:
+ args["Second pass"] = p.enable_hr
+ args["Refiner"] = None if (not shared.opts.add_model_name_to_info) or (not shared.sd_refiner) or (not shared.sd_refiner.sd_checkpoint_info.model_name) else shared.sd_refiner.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')
+ args['Image CFG scale'] = p.image_cfg_scale
+ args['Refiner steps'] = p.refiner_steps
+ args['Refiner start'] = p.refiner_start
+ args["Hires steps"] = p.hr_second_pass_steps
+ args["Hires sampler"] = p.hr_sampler_name
+ args["CFG rescale"] = p.diffusers_guidance_rescale
+ if 'img2img' in p.ops or 'inpaint' in p.ops:
+ args["Init image size"] = f"{getattr(p, 'init_img_width', 0)}x{getattr(p, 'init_img_height', 0)}"
+ args["Init image hash"] = getattr(p, 'init_img_hash', None)
+ args["Mask weight"] = getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None
+ args['Resize scale'] = getattr(p, 'scale_by', None)
+ args["Mask blur"] = p.mask_blur if getattr(p, 'mask', None) is not None and getattr(p, 'mask_blur', 0) > 0 else None
+ args["Denoising strength"] = getattr(p, 'denoising_strength', None)
+ if args["Size"] is None:
+ args["Size"] = args["Init image size"]
+ # lookup by index
+ if getattr(p, 'resize_mode', None) is not None:
+ args['Resize mode'] = shared.resize_modes[p.resize_mode]
+ if 'face' in p.ops:
+ args["Face restoration"] = shared.opts.face_restoration_model
+ if 'color' in p.ops:
+ args["Color correction"] = True
+ # embeddings
+ if hasattr(modules.sd_hijack.model_hijack, 'embedding_db') and len(modules.sd_hijack.model_hijack.embedding_db.embeddings_used) > 0: # this is for original hijacked models only, diffusers are handled separately
+ args["Embeddings"] = ', '.join(modules.sd_hijack.model_hijack.embedding_db.embeddings_used)
+ # samplers
+ args["Sampler ENSD"] = shared.opts.eta_noise_seed_delta if shared.opts.eta_noise_seed_delta != 0 and modules.sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None
+ args["Sampler ENSM"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None
+ args['Sampler order'] = shared.opts.schedulers_solver_order if shared.opts.schedulers_solver_order != shared.opts.data_labels.get('schedulers_solver_order').default else None
+ if shared.backend == shared.Backend.DIFFUSERS:
+ args['Sampler beta schedule'] = shared.opts.schedulers_beta_schedule if shared.opts.schedulers_beta_schedule != shared.opts.data_labels.get('schedulers_beta_schedule').default else None
+ args['Sampler beta start'] = shared.opts.schedulers_beta_start if shared.opts.schedulers_beta_start != shared.opts.data_labels.get('schedulers_beta_start').default else None
+ args['Sampler beta end'] = shared.opts.schedulers_beta_end if shared.opts.schedulers_beta_end != shared.opts.data_labels.get('schedulers_beta_end').default else None
+ args['Sampler DPM solver'] = shared.opts.schedulers_dpm_solver if shared.opts.schedulers_dpm_solver != shared.opts.data_labels.get('schedulers_dpm_solver').default else None
+ if shared.backend == shared.Backend.ORIGINAL:
+ args['Sampler brownian'] = shared.opts.schedulers_brownian_noise if shared.opts.schedulers_brownian_noise != shared.opts.data_labels.get('schedulers_brownian_noise').default else None
+ args['Sampler discard'] = shared.opts.schedulers_discard_penultimate if shared.opts.schedulers_discard_penultimate != shared.opts.data_labels.get('schedulers_discard_penultimate').default else None
+ args['Sampler dyn threshold'] = shared.opts.schedulers_use_thresholding if shared.opts.schedulers_use_thresholding != shared.opts.data_labels.get('schedulers_use_thresholding').default else None
+ args['Sampler karras'] = shared.opts.schedulers_use_karras if shared.opts.schedulers_use_karras != shared.opts.data_labels.get('schedulers_use_karras').default else None
+ args['Sampler low order'] = shared.opts.schedulers_use_loworder if shared.opts.schedulers_use_loworder != shared.opts.data_labels.get('schedulers_use_loworder').default else None
+ args['Sampler quantization'] = shared.opts.enable_quantization if shared.opts.enable_quantization != shared.opts.data_labels.get('enable_quantization').default else None
+ args['Sampler sigma'] = shared.opts.schedulers_sigma if shared.opts.schedulers_sigma != shared.opts.data_labels.get('schedulers_sigma').default else None
+ args['Sampler sigma min'] = shared.opts.s_min if shared.opts.s_min != shared.opts.data_labels.get('s_min').default else None
+ args['Sampler sigma max'] = shared.opts.s_max if shared.opts.s_max != shared.opts.data_labels.get('s_max').default else None
+ args['Sampler sigma churn'] = shared.opts.s_churn if shared.opts.s_churn != shared.opts.data_labels.get('s_churn').default else None
+ args['Sampler sigma uncond'] = shared.opts.s_min_uncond if shared.opts.s_min_uncond != shared.opts.data_labels.get('s_min_uncond').default else None
+ args['Sampler sigma noise'] = shared.opts.s_noise if shared.opts.s_noise != shared.opts.data_labels.get('s_noise').default else None
+ args['Sampler sigma tmin'] = shared.opts.s_tmin if shared.opts.s_tmin != shared.opts.data_labels.get('s_tmin').default else None
+ # tome
+ token_merging_ratio = p.get_token_merging_ratio()
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True) if p.enable_hr else None
+ args['ToMe'] = token_merging_ratio if token_merging_ratio != 0 else None
+ args['ToMe hires'] = token_merging_ratio_hr if token_merging_ratio_hr != 0 else None
+
+ args.update(p.extra_generation_params)
+ params_text = ", ".join([k if k == v else f'{k}: {modules.generation_parameters_copypaste.quote(v)}' for k, v in args.items() if v is not None])
+ negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
+ infotext = f"{all_prompts[index]}{negative_prompt_text}\n{params_text}".strip()
+ return infotext
+
+
+def process_images(p: StableDiffusionProcessing) -> Processed:
+ debug(f'Process images: {vars(p)}')
+ if not hasattr(p.sd_model, 'sd_checkpoint_info'):
+ return None
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.scripts.before_process(p)
+ stored_opts = {}
+ for k, v in p.override_settings.copy().items():
+ if shared.opts.data.get(k, None) is None and shared.opts.data_labels.get(k, None) is None:
+ continue
+ orig = shared.opts.data.get(k, None) or shared.opts.data_labels[k].default
+ if orig == v or (type(orig) == str and os.path.splitext(orig)[0] == v):
+ p.override_settings.pop(k, None)
+ for k in p.override_settings.keys():
+ stored_opts[k] = shared.opts.data.get(k, None) or shared.opts.data_labels[k].default
+ res = None
+ try:
+ # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+ if p.override_settings.get('sd_model_checkpoint', None) is not None and modules.sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ shared.log.warning(f"Override not found: checkpoint={p.override_settings.get('sd_model_checkpoint', None)}")
+ p.override_settings.pop('sd_model_checkpoint', None)
+ modules.sd_models.reload_model_weights()
+ if p.override_settings.get('sd_model_refiner', None) is not None and modules.sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_refiner')) is None:
+ shared.log.warning(f"Override not found: refiner={p.override_settings.get('sd_model_refiner', None)}")
+ p.override_settings.pop('sd_model_refiner', None)
+ modules.sd_models.reload_model_weights()
+ if p.override_settings.get('sd_vae', None) is not None:
+ if p.override_settings.get('sd_vae', None) == 'TAESD':
+ p.full_quality = False
+ p.override_settings.pop('sd_vae', None)
+ if p.override_settings.get('Hires upscaler', None) is not None:
+ p.enable_hr = True
+ if len(p.override_settings.keys()) > 0:
+ shared.log.debug(f'Override: {p.override_settings}')
+ for k, v in p.override_settings.items():
+ setattr(shared.opts, k, v)
+ if k == 'sd_model_checkpoint':
+ modules.sd_models.reload_model_weights()
+ if k == 'sd_vae':
+ modules.sd_vae.reload_vae_weights()
+
+ shared.prompt_styles.apply_styles_to_extra(p)
+ if not shared.opts.cuda_compile:
+ modules.sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+ modules.sd_hijack_freeu.apply_freeu(p, shared.backend == shared.Backend.ORIGINAL)
+
+ modules.script_callbacks.before_process_callback(p)
+
+ if shared.cmd_opts.profile:
+ import cProfile
+ profile_python = cProfile.Profile()
+ profile_python.enable()
+ with context_hypertile_vae(p), context_hypertile_unet(p):
+ import torch.profiler # pylint: disable=redefined-outer-name
+ activities=[torch.profiler.ProfilerActivity.CPU]
+ if torch.cuda.is_available():
+ activities.append(torch.profiler.ProfilerActivity.CUDA)
+ shared.log.debug(f'Torch profile: activities={activities}')
+ if shared.profiler is None:
+ shared.profiler = torch.profiler.profile(activities=activities, profile_memory=True, with_modules=True)
+ shared.profiler.start()
+ shared.profiler.step()
+ res = process_images_inner(p)
+ errors.profile_torch(shared.profiler, 'Process')
+ errors.profile(profile_python, 'Process')
+ else:
+ with context_hypertile_vae(p), context_hypertile_unet(p):
+ res = process_images_inner(p)
+
+ finally:
+ if not shared.opts.cuda_compile:
+ modules.sd_models.apply_token_merging(p.sd_model, 0)
+ modules.script_callbacks.after_process_callback(p)
+ if p.override_settings_restore_afterwards: # restore opts to original state
+ for k, v in stored_opts.items():
+ setattr(shared.opts, k, v)
+ if k == 'sd_model_checkpoint':
+ modules.sd_models.reload_model_weights()
+ if k == 'sd_model_refiner':
+ modules.sd_models.reload_model_weights()
+ if k == 'sd_vae':
+ modules.sd_vae.reload_vae_weights()
+ return res
+
+
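+# validate_sample: converts a decoded sample to a uint8 HWC array: float values in [0, 1] are
+# scaled by 255 (with CHW moved to HWC for the original backend) and cast; if the cast raises a
+# warning, NaNs are replaced and the correction attempt is logged.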
+def validate_sample(tensor):
+ if not isinstance(tensor, np.ndarray) and not isinstance(tensor, torch.Tensor):
+ return tensor
+ if tensor.dtype == torch.bfloat16: # numpy does not support bf16
+ tensor = tensor.to(torch.float16)
+ if isinstance(tensor, torch.Tensor) and hasattr(tensor, 'detach'):
+ sample = tensor.detach().cpu().numpy()
+ elif isinstance(tensor, np.ndarray):
+ sample = tensor
+ else:
+ shared.log.warning(f'Unknown sample type: {type(tensor)}')
+ sample = 255.0 * np.moveaxis(sample, 0, 2) if shared.backend == shared.Backend.ORIGINAL else 255.0 * sample
+ with warnings.catch_warnings(record=True) as w:
+ cast = sample.astype(np.uint8)
+ if len(w) > 0:
+ nans = np.isnan(sample).sum()
+ shared.log.error(f'Failed to validate samples: sample={sample.shape} invalid={nans}')
+ cast = np.nan_to_num(sample)
+ minimum, maximum, mean = np.min(cast), np.max(cast), np.mean(cast)
+ cast = cast.astype(np.uint8)
+ shared.log.warning(f'Attempted to correct samples: min={minimum:.2f} max={maximum:.2f} mean={mean:.2f}')
+ return cast
+
+
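+# process_init: expands prompt, negative prompt, seed and subseed into per-image lists of length
+# batch_size * n_iter. Seeds increment per image unless subseed_strength is non-zero, in which
+# case the base seed repeats and variation comes from the incrementing subseeds.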
+def process_init(p: StableDiffusionProcessing):
+ seed = get_fixed_seed(p.seed)
+ subseed = get_fixed_seed(p.subseed)
+ if type(p.prompt) == list:
+ p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
+ else:
+ p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
+ if type(p.negative_prompt) == list:
+ p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
+ else:
+ p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
+ if type(seed) == list:
+ p.all_seeds = seed
+ else:
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
+ if type(subseed) == list:
+ p.all_subseeds = subseed
+ else:
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+ if type(p.prompt) == list:
+ assert len(p.prompt) > 0
+ else:
+ assert p.prompt is not None
+
+ if shared.backend == shared.Backend.ORIGINAL:
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+ modules.sd_hijack.model_hijack.clear_comments()
+ comments = {}
+ infotexts = []
+ output_images = []
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
+ process_init(p)
+ if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and shared.backend == shared.Backend.ORIGINAL:
+ modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False)
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.scripts.process(p)
+
+
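+ # Caches the last (prompts, steps) -> conditioning result so that unchanged prompts (typically
+ # the shared negative prompt) are not re-encoded on every batch.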
+ def get_conds_with_caching(function, required_prompts, steps, cache):
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
+ def infotext(_index=0): # dummy function overridden if there are iterations
+ return ''
+
+ ema_scope_context = p.sd_model.ema_scope if shared.backend == shared.Backend.ORIGINAL else nullcontext
+ shared.state.job_count = p.n_iter
+ with devices.inference_context(), ema_scope_context():
+ t0 = time.time()
+ with devices.autocast():
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ extra_network_data = None
+ debug(f'Processing inner: args={vars(p)}')
+ for n in range(p.n_iter):
+ p.iteration = n
+ if shared.state.skipped:
+ shared.log.debug(f'Process skipped: {n}/{p.n_iter}')
+ shared.state.skipped = False
+ continue
+ if shared.state.interrupted:
+ shared.log.debug(f'Process interrupted: {n}/{p.n_iter}')
+ break
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+ if len(p.prompts) == 0:
+ break
+ p.prompts, extra_network_data = modules.extra_networks.parse_prompts(p.prompts)
+ if not p.disable_extra_networks:
+ with devices.autocast():
+ modules.extra_networks.activate(p, extra_network_data)
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+ sampler_config = modules.sd_samplers.find_sampler_config(p.sampler_name)
+ step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+
+ if shared.backend == shared.Backend.ORIGINAL:
+ uc = get_conds_with_caching(modules.prompt_parser.get_learned_conditioning, p.negative_prompts, p.steps * step_multiplier, cached_uc)
+ c = get_conds_with_caching(modules.prompt_parser.get_multicond_learned_conditioning, p.prompts, p.steps * step_multiplier, cached_c)
+ if len(modules.sd_hijack.model_hijack.comments) > 0:
+ for comment in modules.sd_hijack.model_hijack.comments:
+ comments[comment] = 1
+ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae), p.full_quality)[0].cpu() for i in range(samples_ddim.size(0))]
+ try:
+ for x in x_samples_ddim:
+ devices.test_for_nans(x, "vae")
+ except devices.NansException as e:
+ if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae:
+ shared.log.warning('Tensor with all NaNs was produced in VAE')
+ devices.dtype_vae = torch.bfloat16
+ vae_file, vae_source = modules.sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint)
+ modules.sd_vae.load_vae(p.sd_model, vae_file, vae_source)
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae), p.full_quality)[0].cpu() for i in range(samples_ddim.size(0))]
+ for x in x_samples_ddim:
+ devices.test_for_nans(x, "vae")
+ else:
+ raise e
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ elif shared.backend == shared.Backend.DIFFUSERS:
+ from modules.processing_diffusers import process_diffusers
+ x_samples_ddim = process_diffusers(p)
+ else:
+ raise ValueError(f"Unknown backend {shared.backend}")
+
+ if not shared.opts.keep_incomplete and shared.state.interrupted:
+ x_samples_ddim = []
+
+ if (shared.cmd_opts.lowvram or shared.cmd_opts.medvram) and shared.backend == shared.Backend.ORIGINAL:
+ modules.lowvram.send_everything_to_cpu()
+ devices.torch_gc()
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ batch_params = modules.scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+ p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+ x_samples_ddim = batch_params.images
+
+ def infotext(index): # pylint: disable=function-redefined # noqa: F811
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds, index=index, all_negative_prompts=p.negative_prompts)
+
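+ # per-image postprocessing: face restoration, script hooks, color correction, overlay paste and saving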
+ for i, x_sample in enumerate(x_samples_ddim):
+ p.batch_index = i
+ if type(x_sample) == Image.Image:
+ image = x_sample
+ x_sample = np.array(x_sample)
+ else:
+ x_sample = validate_sample(x_sample)
+ image = Image.fromarray(x_sample)
+ if p.restore_faces:
+ if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_face_restoration:
+ orig = p.restore_faces
+ p.restore_faces = False
+ info = infotext(i)
+ p.restore_faces = orig
+ images.save_image(Image.fromarray(x_sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix="-before-face-restore")
+ p.ops.append('face')
+ x_sample = modules.face_restoration.restore_faces(x_sample)
+ image = Image.fromarray(x_sample)
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner):
+ pp = modules.scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image(p, pp)
+ image = pp.image
+ if p.color_corrections is not None and i < len(p.color_corrections):
+ if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:
+ orig = p.color_corrections
+ p.color_corrections = None
+ info = infotext(i)
+ p.color_corrections = orig
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix="-before-color-correct")
+ p.ops.append('color')
+ image = apply_color_correction(p.color_corrections[i], image)
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ text = infotext(i)
+ infotexts.append(text)
+ image.info["parameters"] = text
+ output_images.append(image)
+ if shared.opts.samples_save and not p.do_not_save_samples:
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p) # main save image
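+ # for inpainting, optionally save/return the mask and a composite of the result pasted through the mask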
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([shared.opts.save_mask, shared.opts.save_mask_composite, shared.opts.return_mask, shared.opts.return_mask_composite]):
+ image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(3, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ if shared.opts.save_mask:
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p, suffix="-mask")
+ if shared.opts.save_mask_composite:
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=text, p=p, suffix="-mask-composite")
+ if shared.opts.return_mask:
+ output_images.append(image_mask)
+ if shared.opts.return_mask_composite:
+ output_images.append(image_mask_composite)
+ del x_samples_ddim
+ devices.torch_gc()
+
+ t1 = time.time()
+ shared.log.info(f'Processed: images={len(output_images)} time={t1 - t0:.2f} its={(p.steps * len(output_images)) / (t1 - t0):.2f} memory={modules.memstats.memory_stats()}')
+
+ p.color_corrections = None
+ index_of_first_image = 0
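+ # build an image grid when multiple images were produced and grid output is enabled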
+ if (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and len(output_images) > 1:
+ if images.check_grid_size(output_images):
+ grid = images.image_grid(output_images, p.batch_size)
+ if shared.opts.return_grid:
+ text = infotext(-1)
+ infotexts.insert(0, text)
+ grid.info["parameters"] = text
+ output_images.insert(0, grid)
+ index_of_first_image = 1
+ if shared.opts.grid_save:
+ images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(-1), p=p, grid=True, suffix="-grid") # main save grid
+
+ if not p.disable_extra_networks:
+ modules.extra_networks.deactivate(p, extra_network_data)
+
+ res = Processed(
+ p,
+ images_list=output_images,
+ seed=p.all_seeds[0],
+ info=infotext(0),
+ comments="\n".join(comments),
+ subseed=p.all_subseeds[0],
+ index_of_first_image=index_of_first_image,
+ infotexts=infotexts,
+ )
+ if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped):
+ p.scripts.postprocess(p, res)
+ return res
+
+
+def old_hires_fix_first_pass_dimensions(width, height):
+ """old algorithm for auto-calculating first pass size"""
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ width = math.ceil(scale * width / 64) * 64
+ height = math.ceil(scale * height / 64) * 64
+ return width, height
+
+
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_force: bool = False, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs):
+
+ super().__init__(**kwargs)
+ if devices.backend == "ipex" and os.environ.get('DISABLE_IPEX_1024_WA', None) is None:
+ width_curse = bool(hr_resize_x == 1024 and self.height * (hr_resize_x / self.width) == 1024)
+ height_curse = bool(hr_resize_y == 1024 and self.width * (hr_resize_y / self.height) == 1024)
+ if width_curse or height_curse:
+ if width_curse:
+ hr_resize_x = 1080
+ if height_curse:
+ hr_resize_y = 1080
+ if self.width * hr_scale == 1024 and self.height * hr_scale == 1024:
+ hr_scale = 1080 / self.width
+ if firstphase_width * hr_scale == 1024 and firstphase_height * hr_scale == 1024:
+ hr_scale = 1080 / firstphase_width
+ self.enable_hr = enable_hr
+ self.denoising_strength = denoising_strength
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+ self.hr_force = hr_force
+ self.hr_second_pass_steps = hr_second_pass_steps
+ self.hr_resize_x = hr_resize_x
+ self.hr_resize_y = hr_resize_y
+ self.hr_upscale_to_x = hr_resize_x
+ self.hr_upscale_to_y = hr_resize_y
+ if firstphase_width != 0 or firstphase_height != 0:
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+ self.width = firstphase_width
+ self.height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
+ self.refiner_steps = refiner_steps
+ self.refiner_start = refiner_start
+ self.refiner_prompt = refiner_prompt
+ self.refiner_negative = refiner_negative
+ self.sampler = None
+ self.scripts = None
+ self.script_args = []
+
+ def init(self, all_prompts, all_seeds, all_subseeds):
+ if shared.backend == shared.Backend.DIFFUSERS:
+ shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.TEXT_2_IMAGE)
+ self.width = self.width or 512
+ self.height = self.height or 512
+
+ def init_hr(self):
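+ # compute the hires target: either scale the first-pass size by hr_scale or fit to explicit hr_resize dimensions, cropping any excess via truncate_x/y (in units of 8px latent pixels)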
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
+ else:
+ if self.hr_resize_y == 0:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ elif self.hr_resize_x == 0:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ else:
+ target_w = self.hr_resize_x
+ target_h = self.hr_resize_y
+ src_ratio = self.width / self.height
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
+ if src_ratio < dst_ratio:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ else:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // 8
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // 8
+ # special case: the user has chosen to do nothing
+ if (self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height) or self.hr_upscaler is None or self.hr_upscaler == 'None':
+ self.is_hr_pass = False
+ return
+ self.is_hr_pass = True
+ hypertile_set(self, hr=True)
+ shared.state.job_count = 2 * self.n_iter
+ shared.log.debug(f'Init hires: upscaler="{self.hr_upscaler}" sampler="{self.hr_sampler_name}" resize={self.hr_resize_x}x{self.hr_resize_y} upscale={self.hr_upscale_to_x}x{self.hr_upscale_to_y}')
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+
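+ # a latent upscale mode works directly on the latent tensor; other upscalers require decoding to image space, resizing and re-encoding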
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "None")
+ if latent_scale_mode is not None:
+ self.hr_force = False # no need to force anything
+ if self.enable_hr and (latent_scale_mode is None or self.hr_force):
+ if len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) == 0:
+ shared.log.warning(f"Cannot find upscaler for hires: {self.hr_upscaler}")
+ self.enable_hr = False
+
+ self.ops.append('txt2img')
+ hypertile_set(self)
+ self.sampler = modules.sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ if hasattr(self.sampler, "initialize"):
+ self.sampler.initialize(self)
+ x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ shared.state.nextjob()
+ if not self.enable_hr or shared.state.interrupted or shared.state.skipped:
+ return samples
+
+ self.init_hr()
+ if self.is_hr_pass:
+ prev_job = shared.state.job
+ target_width = self.hr_upscale_to_x
+ target_height = self.hr_upscale_to_y
+ decoded_samples = None
+ if shared.opts.save and shared.opts.save_images_before_highres_fix and not self.do_not_save_samples:
+ decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
+ decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+ for i, x_sample in enumerate(decoded_samples):
+ x_sample = validate_sample(x_sample)
+ image = Image.fromarray(x_sample)
+ bak_extra_generation_params, bak_restore_faces = self.extra_generation_params, self.restore_faces
+ self.extra_generation_params = {}
+ self.restore_faces = False
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=i)
+ self.extra_generation_params, self.restore_faces = bak_extra_generation_params, bak_restore_faces
+ images.save_image(image, self.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires")
+ if latent_scale_mode is None or self.hr_force: # non-latent upscaling
+ shared.state.job = 'upscale'
+ if decoded_samples is None:
+ decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
+ decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+ batch_images = []
+ for _i, x_sample in enumerate(decoded_samples):
+ x_sample = validate_sample(x_sample)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(1, image, target_width, target_height, upscaler_name=self.hr_upscaler)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
+ resized_samples = torch.from_numpy(np.array(batch_images))
+ resized_samples = resized_samples.to(device=shared.device, dtype=devices.dtype_vae)
+ resized_samples = 2.0 * resized_samples - 1.0
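+ # sliced VAE encode processes one image at a time to reduce peak VRAM usage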
+ if shared.opts.sd_vae_sliced_encode and len(decoded_samples) > 1:
+ samples = torch.stack([self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(torch.unsqueeze(resized_sample, 0)))[0] for resized_sample in resized_samples])
+ else:
+ # TODO add TAESD support
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(resized_samples))
+ image_conditioning = self.img2img_image_conditioning(resized_samples, samples)
+ else:
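+ # latent upscaling: interpolate the latents directly and rebuild the image conditioning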
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // 8, target_width // 8), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples.to(dtype=devices.dtype_vae))
+ if self.hr_sampler_name == "PLMS":
+ self.hr_sampler_name = 'UniPC'
+ if self.hr_force or latent_scale_mode is not None:
+ shared.state.job = 'hires'
+ if self.denoising_strength > 0:
+ self.ops.append('hires')
+ devices.torch_gc() # GC now before running the next img2img to prevent running out of memory
+ self.sampler = modules.sd_samplers.create_sampler(self.hr_sampler_name or self.sampler_name, self.sd_model)
+ if hasattr(self.sampler, "initialize"):
+ self.sampler.initialize(self)
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
+ modules.sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+ hypertile_set(self, hr=True)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ modules.sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ else:
+ self.ops.append('upscale')
+ x = None
+ self.is_hr_pass = False
+ shared.state.job = prev_job
+ shared.state.nextjob()
+
+ return samples
+
+
+class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+
+ def __init__(self, init_images: list = None, resize_mode: int = 0, resize_name: str = 'None', denoising_strength: float = 0.3, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, scale_by: float = 1, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs):
+ super().__init__(**kwargs)
+ self.init_images = init_images
+ self.resize_mode: int = resize_mode
+ self.resize_name: str = resize_name
+ self.denoising_strength: float = denoising_strength
+ self.image_cfg_scale: float = image_cfg_scale
+ self.init_latent = None
+ self.image_mask = mask
+ self.latent_mask = None
+ self.mask_for_overlay = None
+ self.mask_blur_x = mask_blur # a1111 compatibility item
+ self.mask_blur_y = mask_blur # a1111 compatibility item
+ self.mask_blur = mask_blur
+ self.inpainting_fill = inpainting_fill
+ self.inpaint_full_res = inpaint_full_res
+ self.inpaint_full_res_padding = inpaint_full_res_padding
+ self.inpainting_mask_invert = inpainting_mask_invert
+ self.initial_noise_multiplier = shared.opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
+ self.mask = None
+ self.nmask = None
+ self.image_conditioning = None
+ self.refiner_steps = refiner_steps
+ self.refiner_start = refiner_start
+ self.refiner_prompt = refiner_prompt
+ self.refiner_negative = refiner_negative
+ self.enable_hr = None
+ self.is_batch = False
+ self.scale_by = scale_by
+ self.sampler = None
+ self.scripts = None
+ self.script_args = []
+
+ def init(self, all_prompts, all_seeds, all_subseeds):
+ if shared.backend == shared.Backend.DIFFUSERS and self.image_mask is not None and not self.is_control:
+ shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.INPAINTING)
+ elif shared.backend == shared.Backend.DIFFUSERS and self.image_mask is None and not self.is_control:
+ shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
+
+ if self.sampler_name == "PLMS":
+ self.sampler_name = 'UniPC'
+ self.sampler = modules.sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ if hasattr(self.sampler, "initialize"):
+ self.sampler.initialize(self)
+
+ if self.image_mask is not None:
+ self.ops.append('inpaint')
+ else:
+ self.ops.append('img2img')
+ crop_region = None
+
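+ # prepare the inpaint mask: optional invert and blur, then either crop to the masked region (inpaint full res) or resize to the target resolution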
+ if self.image_mask is not None:
+ if type(self.image_mask) == list:
+ self.image_mask = self.image_mask[0]
+ self.image_mask = create_binary_mask(self.image_mask)
+ if self.inpainting_mask_invert:
+ self.image_mask = ImageOps.invert(self.image_mask)
+ if self.mask_blur > 0:
+ np_mask = np.array(self.image_mask)
+ kernel_size = 2 * int(2.5 * self.mask_blur + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur)
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur)
+ self.image_mask = Image.fromarray(np_mask)
+ if self.inpaint_full_res:
+ self.mask_for_overlay = self.image_mask
+ mask = self.image_mask.convert('L')
+ crop_region = modules.masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
+ crop_region = modules.masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
+ x1, y1, x2, y2 = crop_region
+ mask = mask.crop(crop_region)
+ self.image_mask = images.resize_image(2, mask, self.width, self.height)
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
+ else:
+ self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
+ np_mask = np.array(self.image_mask)
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+ self.mask_for_overlay = Image.fromarray(np_mask)
+ self.overlay_images = []
+
+ latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
+
+ add_color_corrections = shared.opts.img2img_color_correction and self.color_corrections is None
+ if add_color_corrections:
+ self.color_corrections = []
+ processed = []
+ if getattr(self, 'init_images', None) is None:
+ return
+ if not isinstance(self.init_images, list):
+ self.init_images = [self.init_images]
+ for img in self.init_images:
+ if img is None:
+ shared.log.warning(f"Skipping empty image: images={self.init_images}")
+ continue
+ self.init_img_hash = hashlib.sha256(img.tobytes()).hexdigest()[0:8] # pylint: disable=attribute-defined-outside-init
+ self.init_img_width = img.width # pylint: disable=attribute-defined-outside-init
+ self.init_img_height = img.height # pylint: disable=attribute-defined-outside-init
+ if shared.opts.save_init_img:
+ images.save_image(img, path=shared.opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, suffix="-init-image")
+ image = images.flatten(img, shared.opts.img2img_background_color)
+ if crop_region is None and self.resize_mode != 4 and self.resize_mode > 0:
+ if image.width != self.width or image.height != self.height:
+ image = images.resize_image(self.resize_mode, image, self.width, self.height, self.resize_name)
+ self.width = image.width
+ self.height = image.height
+ if self.image_mask is not None:
+ try:
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_to_paste = image.convert("RGBA").convert("RGBa")
+ image_to_mask = ImageOps.invert(self.mask_for_overlay.convert('L')) if self.mask_for_overlay is not None else None
+ image_to_mask = image_to_mask.resize((image.width, image.height), Image.Resampling.BILINEAR) if image_to_mask is not None else None
+ image_masked.paste(image_to_paste, mask=image_to_mask)
+ self.overlay_images.append(image_masked.convert('RGBA'))
+ except Exception as e:
+ shared.log.error(f"Failed to apply mask to image: {e}")
+ if crop_region is not None: # crop_region is not None if we are doing inpaint full res
+ image = image.crop(crop_region)
+ if image.width != self.width or image.height != self.height:
+ image = images.resize_image(3, image, self.width, self.height, self.resize_name)
+ if self.image_mask is not None and self.inpainting_fill != 1:
+ image = modules.masking.fill(image, latent_mask)
+ if add_color_corrections:
+ self.color_corrections.append(setup_color_correction(image))
+ processed.append(image)
+ self.init_images = processed
+ self.batch_size = len(self.init_images)
+ if self.overlay_images is not None:
+ self.overlay_images = self.overlay_images * self.batch_size
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+ if shared.backend == shared.Backend.DIFFUSERS:
+ return # we've already set self.init_images and self.mask and we don't need any more processing
+
+ self.init_images = [np.moveaxis((np.array(image).astype(np.float32) / 255.0), 2, 0) for image in self.init_images]
+ if len(self.init_images) == 1:
+ batch_images = np.expand_dims(self.init_images[0], axis=0).repeat(self.batch_size, axis=0)
+ elif len(self.init_images) <= self.batch_size:
+ batch_images = np.array(self.init_images)
+ image = torch.from_numpy(batch_images)
+ image = 2. * image - 1.
+ image = image.to(device=shared.device, dtype=devices.dtype_vae)
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if self.resize_mode == 4:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // 8, self.width // 8), mode="bilinear")
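+ # build latent-resolution mask tensors: self.mask keeps the original latent (unmasked area), self.nmask marks the region to be re-generated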
+ if self.image_mask is not None:
+ init_mask = latent_mask
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
+ latmask = latmask[0]
+ latmask = np.tile(latmask[None], (4, 1, 1))
+ latmask = np.around(latmask)
+ self.mask = torch.asarray(1.0 - latmask).to(device=shared.device, dtype=self.sd_model.dtype)
+ self.nmask = torch.asarray(latmask).to(device=shared.device, dtype=self.sd_model.dtype)
+ if self.inpainting_fill == 2:
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ elif self.inpainting_fill == 3:
+ self.init_latent = self.init_latent * self.mask
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ hypertile_set(self)
+ x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x *= self.initial_noise_multiplier
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+ if self.mask is not None:
+ samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+ shared.state.nextjob()
+ return samples
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and shared.opts.token_merging_ratio) or shared.opts.token_merging_ratio_img2img or shared.opts.token_merging_ratio
diff --git a/modules/progress.py b/modules/progress.py
index 7da590a7c..cc6704d99 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -1,93 +1,93 @@
-import base64
-import io
-import time
-from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
-import modules.shared as shared
-
-
-current_task = None
-pending_tasks = {}
-finished_tasks = []
-recorded_results = []
-recorded_results_limit = 2
-
-
-def start_task(id_task):
- global current_task # pylint: disable=global-statement
- current_task = id_task
- pending_tasks.pop(id_task, None)
-
-
-def record_results(id_task, res):
- recorded_results.append((id_task, res))
- if len(recorded_results) > recorded_results_limit:
- recorded_results.pop(0)
-
-
-def finish_task(id_task):
- global current_task # pylint: disable=global-statement
- if current_task == id_task:
- current_task = None
- finished_tasks.append(id_task)
- if len(finished_tasks) > 16:
- finished_tasks.pop(0)
-
-
-def add_task_to_queue(id_job):
- pending_tasks[id_job] = time.time()
-
-
-class ProgressRequest(BaseModel):
- id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
- id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
-
-
-class InternalProgressResponse(BaseModel):
- job: str = Field(default=None, title="Job name", description="Internal job name")
- active: bool = Field(title="Whether the task is being worked on right now")
- queued: bool = Field(title="Whether the task is in queue")
- paused: bool = Field(title="Whether the task is paused")
- completed: bool = Field(title="Whether the task has already finished")
- progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
- eta: float = Field(default=None, title="ETA in secs")
- live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
- id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
- textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
-
-
-def progressapi(req: ProgressRequest):
- active = req.id_task == current_task
- queued = req.id_task in pending_tasks
- completed = req.id_task in finished_tasks
- paused = shared.state.paused
- if not active:
- return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, textinfo="Queued..." if queued else "Waiting...")
- if shared.state.job_no > shared.state.job_count:
- shared.state.job_count = shared.state.job_no
- batch_x = max(shared.state.job_no, 0)
- batch_y = max(shared.state.job_count, 1)
- step_x = max(shared.state.sampling_step, 0)
- step_y = max(shared.state.sampling_steps, 1)
- current = step_y * batch_x + step_x
- total = step_y * batch_y
- progress = min(1, abs(current / total) if total > 0 else 0)
- elapsed = time.time() - shared.state.time_start
- predicted = elapsed / progress if progress > 0 else None
- eta = predicted - elapsed if predicted is not None else None
- # shared.log.debug(f'Progress: step={step_x}:{step_y} batch={batch_x}:{batch_y} current={current} total={total} progress={progress} elapsed={elapsed} eta={eta}')
-
- id_live_preview = req.id_live_preview
- live_preview = None
- shared.state.set_current_image()
- if shared.opts.live_previews_enable and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
- buffered = io.BytesIO()
- shared.state.current_image.save(buffered, format='jpeg')
- live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
- id_live_preview = shared.state.id_live_preview
-
- res = InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
- return res
-
-
-def setup_progress_api(app):
- return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=InternalProgressResponse)
+import base64
+import io
+import time
+from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+recorded_results = []
+recorded_results_limit = 2
+
+
+def start_task(id_task):
+ global current_task # pylint: disable=global-statement
+ current_task = id_task
+ pending_tasks.pop(id_task, None)
+
+
+def record_results(id_task, res):
+ recorded_results.append((id_task, res))
+ if len(recorded_results) > recorded_results_limit:
+ recorded_results.pop(0)
+
+
+def finish_task(id_task):
+ global current_task # pylint: disable=global-statement
+ if current_task == id_task:
+ current_task = None
+ finished_tasks.append(id_task)
+ if len(finished_tasks) > 16:
+ finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+ pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+ id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")
+
+
+class InternalProgressResponse(BaseModel):
+ job: str = Field(default=None, title="Job name", description="Internal job name")
+ active: bool = Field(title="Whether the task is being worked on right now")
+ queued: bool = Field(title="Whether the task is in queue")
+ paused: bool = Field(title="Whether the task is paused")
+ completed: bool = Field(title="Whether the task has already finished")
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ eta: float = Field(default=None, title="ETA in secs")
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def progressapi(req: ProgressRequest):
+ active = req.id_task == current_task
+ queued = req.id_task in pending_tasks
+ completed = req.id_task in finished_tasks
+ paused = shared.state.paused
+ if not active:
+ return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, textinfo="Queued..." if queued else "Waiting...")
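+ # overall progress = (completed jobs * steps-per-job + current step) / (total jobs * steps-per-job)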
+ if shared.state.job_no > shared.state.job_count:
+ shared.state.job_count = shared.state.job_no
+ batch_x = max(shared.state.job_no, 0)
+ batch_y = max(shared.state.job_count, 1)
+ step_x = max(shared.state.sampling_step, 0)
+ step_y = max(shared.state.sampling_steps, 1)
+ current = step_y * batch_x + step_x
+ total = step_y * batch_y
+ progress = min(1, abs(current / total) if total > 0 else 0)
+ elapsed = time.time() - shared.state.time_start
+ predicted = elapsed / progress if progress > 0 else None
+ eta = predicted - elapsed if predicted is not None else None
+ # shared.log.debug(f'Progress: step={step_x}:{step_y} batch={batch_x}:{batch_y} current={current} total={total} progress={progress} elapsed={elapsed} eta={eta}')
+
+ id_live_preview = req.id_live_preview
+ live_preview = None
+ shared.state.set_current_image()
+ if shared.opts.live_previews_enable and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
+ buffered = io.BytesIO()
+ shared.state.current_image.save(buffered, format='jpeg')
+ live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
+ id_live_preview = shared.state.id_live_preview
+
+ res = InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+ return res
+
+
+def setup_progress_api(app):
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=InternalProgressResponse)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f3c6b30c6..10cf91f3e 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,401 +1,401 @@
-# pylint: disable=anomalous-backslash-in-string
-
-"""
-import os
-import sys
-from rich import print
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
-"""
-
-import os
-import re
-from collections import namedtuple
-from typing import List
-import lark
-import torch
-from compel import Compel
-from modules.shared import opts, log, backend, Backend
-
-# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
-# will be represented with prompt_schedule like this (assuming steps=100):
-# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
-# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
-# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
-# [75, 'fantasy landscape with a lake and an oak in background masterful']
-# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
-
-round_bracket_multiplier = 1.1
-square_bracket_multiplier = 1.0 / 1.1
-re_AND = re.compile(r"\bAND\b")
-# re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-schedule_parser = lark.Lark(r"""
-!start: (prompt | /[][():]/+)*
-prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
-!emphasized: "(" prompt ")"
- | "(" prompt ":" prompt ")"
- | "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
-alternate: "[" prompt ("|" prompt)+ "]"
-WHITESPACE: /\s+/
-plain: /([^\\\[\]():|]|\\.)+/
-%import common.SIGNED_NUMBER -> NUMBER
-""")
-re_clean = re.compile(r"^\W+", re.S)
-re_whitespace = re.compile(r"\s+", re.S)
-re_break = re.compile(r"\s*\bBREAK\b|##\s*", re.S)
-re_attention_v2 = re.compile(r"""
-\(|\[|\\\(|\\\[|\\|\\\\|
-:([+-]?[.\d]+)|
-\)|\]|\\\)|\\\]|
-[^\(\)\[\]:]+|
-:
-""", re.X)
-re_attention_v1 = re.compile(r"""
-\\\(|
-\\\)|
-\\\[|
-\\]|
-\\\\|
-\\|
-\(|
-\[|
-:([+-]?[.\d]+)\)|
-\)|
-]|
-[^\\()\[\]:]+|
-:
-""", re.X)
-
-
-debug_output = os.environ.get('SD_PROMPT_DEBUG', None)
-debug = log.trace if debug_output is not None else lambda *args, **kwargs: None
-debug('Trace: PROMPT')
-
-
-def get_learned_conditioning_prompt_schedules(prompts, steps):
- """
- >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
- >>> g("test")
- [[10, 'test']]
- >>> g("a [b:3]")
- [[3, 'a '], [10, 'a b']]
- >>> g("a [b: 3]")
- [[3, 'a '], [10, 'a b']]
- >>> g("a [[[b]]:2]")
- [[2, 'a '], [10, 'a [[b]]']]
- >>> g("[(a:2):3]")
- [[3, ''], [10, '(a:2)']]
- >>> g("a [b : c : 1] d")
- [[1, 'a b d'], [10, 'a c d']]
- >>> g("a[b:[c:d:2]:1]e")
- [[1, 'abe'], [2, 'ace'], [10, 'ade']]
- >>> g("a [unbalanced")
- [[10, 'a [unbalanced']]
- >>> g("a [b:.5] c")
- [[5, 'a c'], [10, 'a b c']]
- >>> g("a [{b|d{:.5] c") # not handling this right now
- [[5, 'a c'], [10, 'a {b|d{ c']]
- >>> g("((a][:b:c [d:3]")
- [[3, '((a][:b:c '], [10, '((a][:b:c d']]
- >>> g("[a|(b:1.1)]")
- [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
- """
-
- def collect_steps(steps, tree):
- res = [steps]
- class CollectSteps(lark.Visitor):
- def scheduled(self, tree):
- tree.children[-1] = float(tree.children[-1])
- if tree.children[-1] < 1:
- tree.children[-1] *= steps
- tree.children[-1] = min(steps, int(tree.children[-1]))
- res.append(tree.children[-1])
- def alternate(self, tree): # pylint: disable=unused-argument
- res.extend(range(1, steps+1))
- CollectSteps().visit(tree)
- return sorted(set(res))
-
- def at_step(step, tree):
- class AtStep(lark.Transformer):
- def scheduled(self, args):
- before, after, _, when = args
- yield before or () if step <= when else after
- def alternate(self, args):
- yield next(args[(step - 1)%len(args)]) # pylint: disable=stop-iteration-return
- def start(self, args):
- def flatten(x):
- if type(x) == str:
- yield x
- else:
- for gen in x:
- yield from flatten(gen)
- return ''.join(flatten(args))
- def plain(self, args):
- yield args[0].value
- def __default__(self, data, children, meta):
- for child in children:
- yield child
- return AtStep().transform(tree)
-
- def get_schedule(prompt):
- try:
- tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError:
- return [[steps, prompt]]
- return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
-
- promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
- return [promptdict[prompt] for prompt in prompts]
-
-
-def get_learned_conditioning(model, prompts, steps):
- """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
- and the sampling step at which this condition is to be replaced by the next one.
- Input:
- (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
- Output:
- [
- [ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) ],
- [ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
- ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')),
- ]
- ]
- """
- res = []
- prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
- cache = {}
- for prompt, prompt_schedule in zip(prompts, prompt_schedules):
- debug(f'Prompt schedule: {prompt_schedule}')
- cached = cache.get(prompt, None)
- if cached is not None:
- res.append(cached)
- continue
- texts = [x[1] for x in prompt_schedule]
- conds = model.get_learned_conditioning(texts)
- cond_schedule = []
- for i, (end_at_step, _text) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
- cache[prompt] = cond_schedule
- res.append(cond_schedule)
- return res
-
-
-def get_multicond_prompt_list(prompts):
- res_indexes = []
- prompt_flat_list = []
- prompt_indexes = {}
- for prompt in prompts:
- subprompts = re_AND.split(prompt)
- indexes = []
- for subprompt in subprompts:
- match = re_weight.search(subprompt)
- text, weight = match.groups() if match is not None else (subprompt, 1.0)
- weight = float(weight) if weight is not None else 1.0
- index = prompt_indexes.get(text, None)
- if index is None:
- index = len(prompt_flat_list)
- prompt_flat_list.append(text)
- prompt_indexes[text] = index
- indexes.append((index, weight))
- res_indexes.append(indexes)
- return res_indexes, prompt_flat_list, prompt_indexes
-
-
-class ComposableScheduledPromptConditioning:
- def __init__(self, schedules, weight=1.0):
- self.schedules: List[ScheduledPromptConditioning] = schedules
- self.weight: float = weight
-
-
-class MulticondLearnedConditioning:
- def __init__(self, shape, batch):
- self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
-
-
-def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
- """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
- For each prompt, the list is obtained by splitting the prompt using the AND separator.
- https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
- """
- res_indexes, prompt_flat_list, _prompt_indexes = get_multicond_prompt_list(prompts)
- learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
- res = []
- for indexes in res_indexes:
- res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
- return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
-
-
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
- param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
- for i, cond_schedule in enumerate(c):
- target_index = 0
- for current, (end_at, _cond) in enumerate(cond_schedule):
- if current_step <= end_at:
- target_index = current
- break
- res[i] = cond_schedule[target_index].cond
- return res
-
-
-def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
- param = c.batch[0][0].schedules[0].cond
- tensors = []
- conds_list = []
- for composable_prompts in c.batch:
- conds_for_batch = []
- for composable_prompt in composable_prompts:
- target_index = 0
- for current, entry in enumerate(composable_prompt.schedules):
- if current_step <= entry.end_at_step:
- target_index = current
- break
- conds_for_batch.append((len(tensors), composable_prompt.weight))
- tensors.append(composable_prompt.schedules[target_index].cond)
- conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
-
-
-def parse_prompt_attention(text):
- """
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
- Accepted tokens are:
- (abc) - increases attention to abc by a multiplier of 1.1
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
- [abc] - decreases attention to abc by a multiplier of 1.1
- \( - literal character '('
- \[ - literal character '['
- \) - literal character ')'
- \] - literal character ']'
- \\ - literal character '\'
- anything else - just text
- >>> parse_prompt_attention('normal text')
- [['normal text', 1.0]]
- >>> parse_prompt_attention('an (important) word')
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
- >>> parse_prompt_attention('(unbalanced')
- [['unbalanced', 1.1]]
- >>> parse_prompt_attention('\(literal\]')
- [['(literal]', 1.0]]
- >>> parse_prompt_attention('(unnecessary)(parens)')
- [['unnecessaryparens', 1.1]]
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
- [['a ', 1.0],
- ['house', 1.5730000000000004],
- [' ', 1.1],
- ['on', 1.0],
- [' a ', 1.1],
- ['hill', 0.55],
- [', sun, ', 1.1],
- ['sky', 1.4641000000000006],
- ['.', 1.1]]
- """
- res = []
- round_brackets = []
- square_brackets = []
- if opts.prompt_attention == 'Fixed attention':
- res = [[text, 1.0]]
- debug(f'Prompt: parser={opts.prompt_attention} {res}')
- return res
- elif opts.prompt_attention == 'Compel parser':
- conjunction = Compel.parse_prompt_string(text)
- if conjunction is None or conjunction.prompts is None or conjunction.prompts is None or len(conjunction.prompts[0].children) == 0:
- return [["", 1.0]]
- res = []
- for frag in conjunction.prompts[0].children:
- res.append([frag.text, frag.weight])
- debug(f'Prompt: parser={opts.prompt_attention} {res}')
- return res
- elif opts.prompt_attention == 'A1111 parser':
- re_attention = re_attention_v1
- whitespace = ''
- else:
- re_attention = re_attention_v1
- if backend == Backend.DIFFUSERS:
- text = text.replace('\n', ' BREAK ')
- else:
- text = text.replace('\n', ' ')
- whitespace = ' '
-
- def multiply_range(start_position, multiplier):
- try:
- for p in range(start_position, len(res)):
- res[p][1] *= multiplier
- except Exception as e:
- log(f'Prompt parser: {e}')
-
- for m in re_attention.finditer(text):
- try:
- section = m.group(0)
- weight = m.group(1)
- if section.startswith('\\'):
- res.append([section[1:], 1.0])
- elif section == '(':
- round_brackets.append(len(res))
- elif section == '[':
- square_brackets.append(len(res))
- elif weight is not None and len(round_brackets) > 0:
- multiply_range(round_brackets.pop(), float(weight))
- elif section == ')' and len(round_brackets) > 0:
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
- elif section == ']' and len(square_brackets) > 0:
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
- else:
- parts = re.split(re_break, section)
- for i, part in enumerate(parts):
- if i > 0:
- res.append(["BREAK", -1])
- if opts.prompt_attention == 'Full parser':
- part = re_clean.sub("", part)
- part = re_whitespace.sub(" ", part).strip()
- if len(part) == 0:
- continue
- res.append([part, 1.0])
- except Exception as e:
- log.error(f'Prompt parser: section="{text[m.start():m.end()]}" position={m.start()}:{m.end()} text="{text}" error={e}')
- for pos in round_brackets:
- multiply_range(pos, round_bracket_multiplier)
- for pos in square_brackets:
- multiply_range(pos, square_bracket_multiplier)
- if len(res) == 0:
- res = [["", 1.0]]
- # merge runs of identical weights
- i = 0
- while i + 1 < len(res):
- if res[i][1] == res[i + 1][1]:
- res[i][0] += whitespace + res[i + 1][0]
- res.pop(i + 1)
- else:
- i += 1
- debug(f'Prompt: parser={opts.prompt_attention} {res}')
- return res
-
-if __name__ == "__main__":
- input_text = '[black] [[grey]] (white) ((gray)) ((orange:1.1) yellow) ((purple) and [dark] red:1.1) [mouse:0.2] [(cat:1.1):0.5]'
- print(f'Prompt: {input_text}')
- all_schedules = get_learned_conditioning_prompt_schedules([input_text], 100)[0]
- print('Schedules', all_schedules)
- for schedule in all_schedules:
- print('Schedule', schedule[0])
- opts.data['prompt_attention'] = 'Fixed attention'
- output_list = parse_prompt_attention(schedule[1])
- print(' Fixed:', output_list)
- opts.data['prompt_attention'] = 'Compel parser'
- output_list = parse_prompt_attention(schedule[1])
- print(' Compel:', output_list)
- opts.data['prompt_attention'] = 'A1111 parser'
- output_list = parse_prompt_attention(schedule[1])
- print(' A1111:', output_list)
- opts.data['prompt_attention'] = 'Full parser'
- output_list = parse_prompt_attention(schedule[1])
- print(' Full :', output_list)
+# pylint: disable=anomalous-backslash-in-string
+
+"""
+import os
+import sys
+from rich import print
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+"""
+
+import os
+import re
+from collections import namedtuple
+from typing import List
+import lark
+import torch
+from compel import Compel
+from modules.shared import opts, log, backend, Backend
+
+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
+# will be represented with prompt_schedule like this (assuming steps=100):
+# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
+# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
+# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
+# [75, 'fantasy landscape with a lake and an oak in background masterful']
+# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+
+round_bracket_multiplier = 1.1
+square_bracket_multiplier = 1.0 / 1.1
+re_AND = re.compile(r"\bAND\b")
+# re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
+schedule_parser = lark.Lark(r"""
+!start: (prompt | /[][():]/+)*
+prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
+!emphasized: "(" prompt ")"
+ | "(" prompt ":" prompt ")"
+ | "[" prompt "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+alternate: "[" prompt ("|" prompt)+ "]"
+WHITESPACE: /\s+/
+plain: /([^\\\[\]():|]|\\.)+/
+%import common.SIGNED_NUMBER -> NUMBER
+""")
+re_clean = re.compile(r"^\W+", re.S)
+re_whitespace = re.compile(r"\s+", re.S)
+re_break = re.compile(r"\s*\bBREAK\b|##\s*", re.S)
+re_attention_v2 = re.compile(r"""
+\(|\[|\\\(|\\\[|\\|\\\\|
+:([+-]?[.\d]+)|
+\)|\]|\\\)|\\\]|
+[^\(\)\[\]:]+|
+:
+""", re.X)
+re_attention_v1 = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+debug_output = os.environ.get('SD_PROMPT_DEBUG', None)
+debug = log.trace if debug_output is not None else lambda *args, **kwargs: None
+debug('Trace: PROMPT')
+
+
+def get_learned_conditioning_prompt_schedules(prompts, steps):
+ """
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
+ >>> g("test")
+ [[10, 'test']]
+ >>> g("a [b:3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [b: 3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [[[b]]:2]")
+ [[2, 'a '], [10, 'a [[b]]']]
+ >>> g("[(a:2):3]")
+ [[3, ''], [10, '(a:2)']]
+ >>> g("a [b : c : 1] d")
+ [[1, 'a b d'], [10, 'a c d']]
+ >>> g("a[b:[c:d:2]:1]e")
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
+ >>> g("a [unbalanced")
+ [[10, 'a [unbalanced']]
+ >>> g("a [b:.5] c")
+ [[5, 'a c'], [10, 'a b c']]
+ >>> g("a [{b|d{:.5] c") # not handling this right now
+ [[5, 'a c'], [10, 'a {b|d{ c']]
+ >>> g("((a][:b:c [d:3]")
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ >>> g("[a|(b:1.1)]")
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+ """
+
+ def collect_steps(steps, tree):
+ res = [steps]
+ class CollectSteps(lark.Visitor):
+ def scheduled(self, tree):
+ tree.children[-1] = float(tree.children[-1])
+ if tree.children[-1] < 1:
+ tree.children[-1] *= steps
+ tree.children[-1] = min(steps, int(tree.children[-1]))
+ res.append(tree.children[-1])
+ def alternate(self, tree): # pylint: disable=unused-argument
+ res.extend(range(1, steps+1))
+ CollectSteps().visit(tree)
+ return sorted(set(res))
+
+ def at_step(step, tree):
+ class AtStep(lark.Transformer):
+ def scheduled(self, args):
+ before, after, _, when = args
+ yield before or () if step <= when else after
+ def alternate(self, args):
+ yield next(args[(step - 1)%len(args)]) # pylint: disable=stop-iteration-return
+ def start(self, args):
+ def flatten(x):
+ if type(x) == str:
+ yield x
+ else:
+ for gen in x:
+ yield from flatten(gen)
+ return ''.join(flatten(args))
+ def plain(self, args):
+ yield args[0].value
+ def __default__(self, data, children, meta):
+ for child in children:
+ yield child
+ return AtStep().transform(tree)
+
+ def get_schedule(prompt):
+ try:
+ tree = schedule_parser.parse(prompt)
+ except lark.exceptions.LarkError:
+ return [[steps, prompt]]
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
+ return [promptdict[prompt] for prompt in prompts]
+
+
+def get_learned_conditioning(model, prompts, steps):
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
+ and the sampling step at which this condition is to be replaced by the next one.
+ Input:
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+ Output:
+ [
+ [ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) ],
+ [ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')),
+ ]
+ ]
+ """
+ res = []
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+ cache = {}
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
+ debug(f'Prompt schedule: {prompt_schedule}')
+ cached = cache.get(prompt, None)
+ if cached is not None:
+ res.append(cached)
+ continue
+ texts = [x[1] for x in prompt_schedule]
+ conds = model.get_learned_conditioning(texts)
+ cond_schedule = []
+ for i, (end_at_step, _text) in enumerate(prompt_schedule):
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ cache[prompt] = cond_schedule
+ res.append(cond_schedule)
+ return res
+
+
+def get_multicond_prompt_list(prompts):
+ res_indexes = []
+ prompt_flat_list = []
+ prompt_indexes = {}
+ for prompt in prompts:
+ subprompts = re_AND.split(prompt)
+ indexes = []
+ for subprompt in subprompts:
+ match = re_weight.search(subprompt)
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
+ weight = float(weight) if weight is not None else 1.0
+ index = prompt_indexes.get(text, None)
+ if index is None:
+ index = len(prompt_flat_list)
+ prompt_flat_list.append(text)
+ prompt_indexes[text] = index
+ indexes.append((index, weight))
+ res_indexes.append(indexes)
+ return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+ def __init__(self, schedules, weight=1.0):
+ self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+ def __init__(self, shape, batch):
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
+
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+ """
+ res_indexes, prompt_flat_list, _prompt_indexes = get_multicond_prompt_list(prompts)
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+ res = []
+ for indexes in res_indexes:
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+ param = c[0][0].cond
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ for i, cond_schedule in enumerate(c):
+ target_index = 0
+ for current, (end_at, _cond) in enumerate(cond_schedule):
+ if current_step <= end_at:
+ target_index = current
+ break
+ res[i] = cond_schedule[target_index].cond
+ return res
+
+
+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+ param = c.batch[0][0].schedules[0].cond
+ tensors = []
+ conds_list = []
+ for composable_prompts in c.batch:
+ conds_for_batch = []
+ for composable_prompt in composable_prompts:
+ target_index = 0
+ for current, entry in enumerate(composable_prompt.schedules):
+ if current_step <= entry.end_at_step:
+ target_index = current
+ break
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
+ tensors.append(composable_prompt.schedules[target_index].cond)
+ conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+ res = []
+ round_brackets = []
+ square_brackets = []
+ if opts.prompt_attention == 'Fixed attention':
+ res = [[text, 1.0]]
+ debug(f'Prompt: parser={opts.prompt_attention} {res}')
+ return res
+ elif opts.prompt_attention == 'Compel parser':
+ conjunction = Compel.parse_prompt_string(text)
+ if conjunction is None or conjunction.prompts is None or len(conjunction.prompts) == 0 or len(conjunction.prompts[0].children) == 0:
+ return [["", 1.0]]
+ res = []
+ for frag in conjunction.prompts[0].children:
+ res.append([frag.text, frag.weight])
+ debug(f'Prompt: parser={opts.prompt_attention} {res}')
+ return res
+ elif opts.prompt_attention == 'A1111 parser':
+ re_attention = re_attention_v1
+ whitespace = ''
+ else:
+ re_attention = re_attention_v1
+ if backend == Backend.DIFFUSERS:
+ text = text.replace('\n', ' BREAK ')
+ else:
+ text = text.replace('\n', ' ')
+ whitespace = ' '
+
+ def multiply_range(start_position, multiplier):
+ try:
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+ except Exception as e:
+ log.error(f'Prompt parser: {e}')
+
+ for m in re_attention.finditer(text):
+ try:
+ section = m.group(0)
+ weight = m.group(1)
+ if section.startswith('\\'):
+ res.append([section[1:], 1.0])
+ elif section == '(':
+ round_brackets.append(len(res))
+ elif section == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif section == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif section == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ parts = re.split(re_break, section)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ if opts.prompt_attention == 'Full parser':
+ part = re_clean.sub("", part)
+ part = re_whitespace.sub(" ", part).strip()
+ if len(part) == 0:
+ continue
+ res.append([part, 1.0])
+ except Exception as e:
+ log.error(f'Prompt parser: section="{text[m.start():m.end()]}" position={m.start()}:{m.end()} text="{text}" error={e}')
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+ if len(res) == 0:
+ res = [["", 1.0]]
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += whitespace + res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+ debug(f'Prompt: parser={opts.prompt_attention} {res}')
+ return res
+
+if __name__ == "__main__":
+ input_text = '[black] [[grey]] (white) ((gray)) ((orange:1.1) yellow) ((purple) and [dark] red:1.1) [mouse:0.2] [(cat:1.1):0.5]'
+ print(f'Prompt: {input_text}')
+ all_schedules = get_learned_conditioning_prompt_schedules([input_text], 100)[0]
+ print('Schedules', all_schedules)
+ for schedule in all_schedules:
+ print('Schedule', schedule[0])
+ opts.data['prompt_attention'] = 'Fixed attention'
+ output_list = parse_prompt_attention(schedule[1])
+ print(' Fixed:', output_list)
+ opts.data['prompt_attention'] = 'Compel parser'
+ output_list = parse_prompt_attention(schedule[1])
+ print(' Compel:', output_list)
+ opts.data['prompt_attention'] = 'A1111 parser'
+ output_list = parse_prompt_attention(schedule[1])
+ print(' A1111:', output_list)
+ opts.data['prompt_attention'] = 'Full parser'
+ output_list = parse_prompt_attention(schedule[1])
+ print(' Full :', output_list)
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index 6ec9df66c..816a476ae 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -1,223 +1,223 @@
-import os
-import time
-import typing
-import torch
-from compel import ReturnedEmbeddingsType
-from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
-from modules import shared, prompt_parser, devices
-
-debug = shared.log.trace if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PROMPT')
-CLIP_SKIP_MAPPING = {
- None: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
- 1: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
- 2: ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED,
-}
-
-
-# from https://github.com/damian0815/compel/blob/main/src/compel/diffusers_textual_inversion_manager.py
-class DiffusersTextualInversionManager(BaseTextualInversionManager):
- def __init__(self, pipe, tokenizer):
- self.pipe = pipe
- self.tokenizer = tokenizer
- if hasattr(self.pipe, 'embedding_db'):
- self.pipe.embedding_db.embeddings_used.clear()
-
- # from https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py#L599
- def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer="PreTrainedTokenizer"):
- prompts = [prompt] if not isinstance(prompt, typing.List) else prompt
- prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
- if not isinstance(prompt, typing.List):
- return prompts[0]
- return prompts
-
- def _maybe_convert_prompt(self, prompt: str, tokenizer="PreTrainedTokenizer"):
- tokens = tokenizer.tokenize(prompt)
- unique_tokens = set(tokens)
- for token in unique_tokens:
- if token in tokenizer.added_tokens_encoder:
- if hasattr(self.pipe, 'embedding_db'):
- self.pipe.embedding_db.embeddings_used.append(token)
- replacement = token
- i = 1
- while f"{token}_{i}" in tokenizer.added_tokens_encoder:
- replacement += f" {token}_{i}"
- i += 1
- prompt = prompt.replace(token, replacement)
- if hasattr(self.pipe, 'embedding_db'):
- self.pipe.embedding_db.embeddings_used = list(set(self.pipe.embedding_db.embeddings_used))
- debug(f'Prompt: convert={prompt}')
- return prompt
-
- def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]:
- if len(token_ids) == 0:
- return token_ids
- prompt = self.pipe.tokenizer.decode(token_ids)
- prompt = self.maybe_convert_prompt(prompt, self.pipe.tokenizer)
- debug(f'Prompt: expand={prompt}')
- return self.pipe.tokenizer.encode(prompt, add_special_tokens=False)
-
-
-def get_prompt_schedule(p, prompt, steps): # pylint: disable=unused-argument
- t0 = time.time()
- temp = []
- schedule = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps)[0]
- if all(x == schedule[0] for x in schedule):
- return [prompt], False
- for chunk in schedule:
- for s in range(steps):
- if len(temp) < s + 1 <= chunk[0]:
- temp.append(chunk[1])
- debug(f'Prompt: schedule={temp} time={time.time() - t0}')
- return temp, len(schedule) > 1
-
-
-def encode_prompts(pipe, p, prompts: list, negative_prompts: list, steps: int, step: int = 1, clip_skip: typing.Optional[int] = None): # pylint: disable=unused-argument
- if 'StableDiffusion' not in pipe.__class__.__name__ and 'DemoFusion':
- shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
- return None, None, None, None
- else:
- t0 = time.time()
- positive_schedule, scheduled = get_prompt_schedule(p, prompts[0], steps)
- negative_schedule, neg_scheduled = get_prompt_schedule(p, negative_prompts[0], steps)
- p.scheduled_prompt = scheduled or neg_scheduled
-
- p.prompt_embeds = []
- p.positive_pooleds = []
- p.negative_embeds = []
- p.negative_pooleds = []
-
- cache = {}
- for i in range(max(len(positive_schedule), len(negative_schedule))):
- cached = cache.get(positive_schedule[i % len(positive_schedule)] + negative_schedule[i % len(negative_schedule)], None)
- if cached is not None:
- prompt_embed, positive_pooled, negative_embed, negative_pooled = cached
- else:
- prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe,
- positive_schedule[i % len(positive_schedule)],
- negative_schedule[i % len(negative_schedule)],
- clip_skip)
- if prompt_embed is not None:
- p.prompt_embeds.append(torch.cat([prompt_embed] * len(prompts), dim=0))
- if negative_embed is not None:
- p.negative_embeds.append(torch.cat([negative_embed] * len(negative_prompts), dim=0))
- if positive_pooled is not None and shared.sd_model_type == "sdxl":
- p.positive_pooleds.append(torch.cat([positive_pooled] * len(prompts), dim=0))
- if negative_pooled is not None and shared.sd_model_type == "sdxl":
- p.negative_pooleds.append(torch.cat([negative_pooled] * len(negative_prompts), dim=0))
- debug(f"Prompt Parser: Elapsed Time {time.time() - t0}")
- return
-
-
-def get_prompts_with_weights(prompt: str):
- manager = DiffusersTextualInversionManager(shared.sd_model, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
- prompt = manager.maybe_convert_prompt(prompt, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
- texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
- texts = [t for t, w in texts_and_weights]
- text_weights = [w for t, w in texts_and_weights]
- debug(f'Prompt: weights={texts_and_weights}')
- return texts, text_weights
-
-
-def prepare_embedding_providers(pipe, clip_skip):
- device = pipe.device if str(pipe.device) != 'meta' else devices.device
- embeddings_providers = []
- if 'XL' in pipe.__class__.__name__:
- embedding_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
- else:
- if clip_skip > 2:
- shared.log.warning(f"Prompt parser unsupported: clip_skip={clip_skip}")
- clip_skip = 2
- embedding_type = CLIP_SKIP_MAPPING[clip_skip]
- if getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
- embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
- embeddings_providers.append(embedding)
- if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None:
- embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, truncate=False, returned_embeddings_type=embedding_type, device=device)
- embeddings_providers.append(embedding)
- return embeddings_providers
-
-
-def pad_to_same_length(pipe, embeds):
- device = pipe.device if str(pipe.device) != 'meta' else devices.device
- try: # SDXL
- empty_embed = pipe.encode_prompt("")
- except Exception: # SD1.5
- empty_embed = pipe.encode_prompt("", device, 1, False)
- empty_batched = torch.cat([empty_embed[0].to(embeds[0].device)] * embeds[0].shape[0])
- max_token_count = max([embed.shape[1] for embed in embeds])
- for i, embed in enumerate(embeds):
- while embed.shape[1] < max_token_count:
- embed = torch.cat([embed, empty_batched], dim=1)
- embeds[i] = embed
- return embeds
-
-
-def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
- device = pipe.device if str(pipe.device) != 'meta' else devices.device
- prompt_2 = prompt.split("TE2:")[-1]
- neg_prompt_2 = neg_prompt.split("TE2:")[-1]
- prompt = prompt.split("TE2:")[0]
- neg_prompt = neg_prompt.split("TE2:")[0]
-
- ps = [get_prompts_with_weights(p) for p in [prompt, prompt_2]]
- positives = [t for t, w in ps]
- positive_weights = [w for t, w in ps]
- ns = [get_prompts_with_weights(p) for p in [neg_prompt, neg_prompt_2]]
- negatives = [t for t, w in ns]
- negative_weights = [w for t, w in ns]
- if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "tokenizer", None) is None:
- positives.pop(0)
- positive_weights.pop(0)
- negatives.pop(0)
- negative_weights.pop(0)
-
- embedding_providers = prepare_embedding_providers(pipe, clip_skip)
- prompt_embeds = []
- negative_prompt_embeds = []
- pooled_prompt_embeds = None
- negative_pooled_prompt_embeds = None
- for i in range(len(embedding_providers)):
- # add BREAK keyword that splits the prompt into multiple fragments
- text = positives[i]
- weights = positive_weights[i]
- text.append('BREAK')
- weights.append(-1)
- provider_embed = []
- while 'BREAK' in text:
- pos = text.index('BREAK')
- debug(f'Prompt: section="{text[:pos]}" len={len(text[:pos])} weights={weights[:pos]}')
- if len(text[:pos]) > 0:
- embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True)
- provider_embed.append(embed)
- text = text[pos + 1:]
- weights = weights[pos + 1:]
- prompt_embeds.append(torch.cat(provider_embed, dim=1))
- # negative prompt has no keywords
- embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True)
- negative_prompt_embeds.append(embed)
-
- if prompt_embeds[-1].shape[-1] > 768:
- if shared.opts.diffusers_pooled == "weighted":
- pooled_prompt_embeds = prompt_embeds[-1][
- torch.arange(prompt_embeds[-1].shape[0], device=device),
- (ptokens.to(dtype=torch.int, device=device) == 49407)
- .int()
- .argmax(dim=-1),
- ]
- negative_pooled_prompt_embeds = negative_prompt_embeds[-1][
- torch.arange(negative_prompt_embeds[-1].shape[0], device=device),
- (ntokens.to(dtype=torch.int, device=device) == 49407)
- .int()
- .argmax(dim=-1),
- ]
- else:
- pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[prompt_2], device=device) if prompt_embeds[-1].shape[-1] > 768 else None
- negative_pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[neg_prompt_2], device=device) if negative_prompt_embeds[-1].shape[-1] > 768 else None
-
- prompt_embeds = torch.cat(prompt_embeds, dim=-1) if len(prompt_embeds) > 1 else prompt_embeds[0]
- negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=-1) if len(negative_prompt_embeds) > 1 else negative_prompt_embeds[0]
- debug(f'Prompt: shape={prompt_embeds.shape} negative={negative_prompt_embeds.shape}')
- if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
- [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds])
- return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds
+import os
+import time
+import typing
+import torch
+from compel import ReturnedEmbeddingsType
+from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
+from modules import shared, prompt_parser, devices
+
+debug = shared.log.trace if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PROMPT')
+CLIP_SKIP_MAPPING = {
+ None: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
+ 1: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
+ 2: ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED,
+}
+
+
+# from https://github.com/damian0815/compel/blob/main/src/compel/diffusers_textual_inversion_manager.py
+class DiffusersTextualInversionManager(BaseTextualInversionManager):
+ def __init__(self, pipe, tokenizer):
+ self.pipe = pipe
+ self.tokenizer = tokenizer
+ if hasattr(self.pipe, 'embedding_db'):
+ self.pipe.embedding_db.embeddings_used.clear()
+
+ # from https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py#L599
+ def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer="PreTrainedTokenizer"):
+ prompts = [prompt] if not isinstance(prompt, typing.List) else prompt
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+ if not isinstance(prompt, typing.List):
+ return prompts[0]
+ return prompts
+
+ def _maybe_convert_prompt(self, prompt: str, tokenizer="PreTrainedTokenizer"):
+ tokens = tokenizer.tokenize(prompt)
+ unique_tokens = set(tokens)
+ for token in unique_tokens:
+ if token in tokenizer.added_tokens_encoder:
+ if hasattr(self.pipe, 'embedding_db'):
+ self.pipe.embedding_db.embeddings_used.append(token)
+ replacement = token
+ i = 1
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+ replacement += f" {token}_{i}"
+ i += 1
+ prompt = prompt.replace(token, replacement)
+ if hasattr(self.pipe, 'embedding_db'):
+ self.pipe.embedding_db.embeddings_used = list(set(self.pipe.embedding_db.embeddings_used))
+ debug(f'Prompt: convert={prompt}')
+ return prompt
+
+ def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]:
+ if len(token_ids) == 0:
+ return token_ids
+ prompt = self.pipe.tokenizer.decode(token_ids)
+ prompt = self.maybe_convert_prompt(prompt, self.pipe.tokenizer)
+ debug(f'Prompt: expand={prompt}')
+ return self.pipe.tokenizer.encode(prompt, add_special_tokens=False)
+
+
+def get_prompt_schedule(p, prompt, steps): # pylint: disable=unused-argument
+ t0 = time.time()
+ temp = []
+ schedule = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps)[0]
+ if all(x == schedule[0] for x in schedule):
+ return [prompt], False
+ for chunk in schedule:
+ for s in range(steps):
+ if len(temp) < s + 1 <= chunk[0]:
+ temp.append(chunk[1])
+ debug(f'Prompt: schedule={temp} time={time.time() - t0}')
+ return temp, len(schedule) > 1
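+# Illustrative note (hypothetical prompt, not part of the module): for a prompt like
+# "[a cat:a dog:10]" with steps=20, get_learned_conditioning_prompt_schedules typically yields
+# [[10, 'a cat'], [20, 'a dog']]; the loop above then expands that into a per-step list of 20
+# prompt strings and the function returns it together with scheduled=True.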
+
+
+def encode_prompts(pipe, p, prompts: list, negative_prompts: list, steps: int, step: int = 1, clip_skip: typing.Optional[int] = None): # pylint: disable=unused-argument
+ if 'StableDiffusion' not in pipe.__class__.__name__ and 'DemoFusion' not in pipe.__class__.__name__:
+ shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
+ return None, None, None, None
+ else:
+ t0 = time.time()
+ positive_schedule, scheduled = get_prompt_schedule(p, prompts[0], steps)
+ negative_schedule, neg_scheduled = get_prompt_schedule(p, negative_prompts[0], steps)
+ p.scheduled_prompt = scheduled or neg_scheduled
+
+ p.prompt_embeds = []
+ p.positive_pooleds = []
+ p.negative_embeds = []
+ p.negative_pooleds = []
+
+ cache = {}
+ for i in range(max(len(positive_schedule), len(negative_schedule))):
+ cached = cache.get(positive_schedule[i % len(positive_schedule)] + negative_schedule[i % len(negative_schedule)], None)
+ if cached is not None:
+ prompt_embed, positive_pooled, negative_embed, negative_pooled = cached
+ else:
+ prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe,
+ positive_schedule[i % len(positive_schedule)],
+ negative_schedule[i % len(negative_schedule)],
+ clip_skip)
+ cache[positive_schedule[i % len(positive_schedule)] + negative_schedule[i % len(negative_schedule)]] = (prompt_embed, positive_pooled, negative_embed, negative_pooled) # store result so repeated schedule entries reuse the computed embeddings
+ if prompt_embed is not None:
+ p.prompt_embeds.append(torch.cat([prompt_embed] * len(prompts), dim=0))
+ if negative_embed is not None:
+ p.negative_embeds.append(torch.cat([negative_embed] * len(negative_prompts), dim=0))
+ if positive_pooled is not None and shared.sd_model_type == "sdxl":
+ p.positive_pooleds.append(torch.cat([positive_pooled] * len(prompts), dim=0))
+ if negative_pooled is not None and shared.sd_model_type == "sdxl":
+ p.negative_pooleds.append(torch.cat([negative_pooled] * len(negative_prompts), dim=0))
+ debug(f"Prompt Parser: Elapsed Time {time.time() - t0}")
+ return
+
+
+def get_prompts_with_weights(prompt: str):
+ manager = DiffusersTextualInversionManager(shared.sd_model, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
+ prompt = manager.maybe_convert_prompt(prompt, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
+ texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
+ texts = [t for t, w in texts_and_weights]
+ text_weights = [w for t, w in texts_and_weights]
+ debug(f'Prompt: weights={texts_and_weights}')
+ return texts, text_weights
+
+
+def prepare_embedding_providers(pipe, clip_skip):
+ device = pipe.device if str(pipe.device) != 'meta' else devices.device
+ embeddings_providers = []
+ if 'XL' in pipe.__class__.__name__:
+ embedding_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
+ else:
+ if clip_skip > 2:
+ shared.log.warning(f"Prompt parser unsupported: clip_skip={clip_skip}")
+ clip_skip = 2
+ embedding_type = CLIP_SKIP_MAPPING[clip_skip]
+ if getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
+ embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ embeddings_providers.append(embedding)
+ if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None:
+ embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ embeddings_providers.append(embedding)
+ return embeddings_providers
+
+
+def pad_to_same_length(pipe, embeds):
+ device = pipe.device if str(pipe.device) != 'meta' else devices.device
+ try: # SDXL
+ empty_embed = pipe.encode_prompt("")
+ except Exception: # SD1.5
+ empty_embed = pipe.encode_prompt("", device, 1, False)
+ empty_batched = torch.cat([empty_embed[0].to(embeds[0].device)] * embeds[0].shape[0])
+ max_token_count = max([embed.shape[1] for embed in embeds])
+ for i, embed in enumerate(embeds):
+ while embed.shape[1] < max_token_count:
+ embed = torch.cat([embed, empty_batched], dim=1)
+ embeds[i] = embed
+ return embeds
+
+
+def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
+ device = pipe.device if str(pipe.device) != 'meta' else devices.device
+ prompt_2 = prompt.split("TE2:")[-1]
+ neg_prompt_2 = neg_prompt.split("TE2:")[-1]
+ prompt = prompt.split("TE2:")[0]
+ neg_prompt = neg_prompt.split("TE2:")[0]
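+ # Illustrative note (hypothetical prompt, not part of the module): the optional "TE2:" marker routes
+ # different text to the two text encoders when a second one is present, e.g.
+ # "photo of a cat TE2: watercolor style".split("TE2:") -> ['photo of a cat ', ' watercolor style'],
+ # so prompt feeds text_encoder and prompt_2 feeds text_encoder_2; without the marker both halves are the same string.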
+
+ ps = [get_prompts_with_weights(p) for p in [prompt, prompt_2]]
+ positives = [t for t, w in ps]
+ positive_weights = [w for t, w in ps]
+ ns = [get_prompts_with_weights(p) for p in [neg_prompt, neg_prompt_2]]
+ negatives = [t for t, w in ns]
+ negative_weights = [w for t, w in ns]
+ if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "tokenizer", None) is None:
+ positives.pop(0)
+ positive_weights.pop(0)
+ negatives.pop(0)
+ negative_weights.pop(0)
+
+ embedding_providers = prepare_embedding_providers(pipe, clip_skip)
+ prompt_embeds = []
+ negative_prompt_embeds = []
+ pooled_prompt_embeds = None
+ negative_pooled_prompt_embeds = None
+ for i in range(len(embedding_providers)):
+ # add BREAK keyword that splits the prompt into multiple fragments
+ text = positives[i]
+ weights = positive_weights[i]
+ text.append('BREAK')
+ weights.append(-1)
+ provider_embed = []
+ while 'BREAK' in text:
+ pos = text.index('BREAK')
+ debug(f'Prompt: section="{text[:pos]}" len={len(text[:pos])} weights={weights[:pos]}')
+ if len(text[:pos]) > 0:
+ embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True)
+ provider_embed.append(embed)
+ text = text[pos + 1:]
+ weights = weights[pos + 1:]
+ prompt_embeds.append(torch.cat(provider_embed, dim=1))
+ # negative prompt has no keywords
+ embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True)
+ negative_prompt_embeds.append(embed)
+
+ if prompt_embeds[-1].shape[-1] > 768:
+ if shared.opts.diffusers_pooled == "weighted":
+ pooled_prompt_embeds = prompt_embeds[-1][
+ torch.arange(prompt_embeds[-1].shape[0], device=device),
+ (ptokens.to(dtype=torch.int, device=device) == 49407)
+ .int()
+ .argmax(dim=-1),
+ ]
+ negative_pooled_prompt_embeds = negative_prompt_embeds[-1][
+ torch.arange(negative_prompt_embeds[-1].shape[0], device=device),
+ (ntokens.to(dtype=torch.int, device=device) == 49407)
+ .int()
+ .argmax(dim=-1),
+ ]
+ else:
+ pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[prompt_2], device=device) if prompt_embeds[-1].shape[-1] > 768 else None
+ negative_pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[neg_prompt_2], device=device) if negative_prompt_embeds[-1].shape[-1] > 768 else None
+
+ prompt_embeds = torch.cat(prompt_embeds, dim=-1) if len(prompt_embeds) > 1 else prompt_embeds[0]
+ negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=-1) if len(negative_prompt_embeds) > 1 else negative_prompt_embeds[0]
+ debug(f'Prompt: shape={prompt_embeds.shape} negative={negative_prompt_embeds.shape}')
+ if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
+ [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds])
+ return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds
diff --git a/modules/rife/model_ifnet.py b/modules/rife/model_ifnet.py
index ec667425b..eb6fbed07 100644
--- a/modules/rife/model_ifnet.py
+++ b/modules/rife/model_ifnet.py
@@ -1,134 +1,134 @@
-import os
-import sys
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-sys.path.append(os.path.dirname(__file__))
-from warplayer import warp # pylint: disable=wrong-import-position
-
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
- return nn.Sequential(
- nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
- padding=padding, dilation=dilation, bias=True),
- nn.LeakyReLU(0.2, True)
- )
-
-def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
- return nn.Sequential(
- nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
- padding=padding, dilation=dilation, bias=False),
- nn.BatchNorm2d(out_planes),
- nn.LeakyReLU(0.2, True)
- )
-
-class ResConv(nn.Module):
- def __init__(self, c, dilation=1):
- super(ResConv, self).__init__()
- self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\
-)
- self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
- self.relu = nn.LeakyReLU(0.2, True)
-
- def forward(self, x):
- return self.relu(self.conv(x) * self.beta + x)
-
-class IFBlock(nn.Module):
- def __init__(self, in_planes, c=64):
- super(IFBlock, self).__init__()
- self.conv0 = nn.Sequential(
- conv(in_planes, c//2, 3, 2, 1),
- conv(c//2, c, 3, 2, 1),
- )
- self.convblock = nn.Sequential(
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- )
- self.lastconv = nn.Sequential(
- nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
- nn.PixelShuffle(2)
- )
-
- def forward(self, x, flow=None, scale=1):
- x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
- if flow is not None:
- flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
- x = torch.cat((x, flow), 1)
- feat = self.conv0(x)
- feat = self.convblock(feat)
- tmp = self.lastconv(feat)
- tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
- flow = tmp[:, :4] * scale
- mask = tmp[:, 4:5]
- return flow, mask
-
-class IFNet(nn.Module):
- def __init__(self):
- super(IFNet, self).__init__()
- self.block0 = IFBlock(7, c=192)
- self.block1 = IFBlock(8+4, c=128)
- self.block2 = IFBlock(8+4, c=96)
- self.block3 = IFBlock(8+4, c=64)
- # self.contextnet = Contextnet()
- # self.unet = Unet()
-
- def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): # pylint: disable=dangerous-default-value # noqa: B006
- if training is False:
- channel = x.shape[1] // 2
- img0 = x[:, :channel]
- img1 = x[:, channel:]
- if not torch.is_tensor(timestep):
- timestep = (x[:, :1].clone() * 0 + 1) * timestep
- else:
- timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
- flow_list = []
- merged = []
- mask_list = []
- warped_img0 = img0
- warped_img1 = img1
- flow = None
- mask = None
- # loss_cons = 0
- block = [self.block0, self.block1, self.block2, self.block3]
- for i in range(4):
- if flow is None:
- flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=scale_list[i])
- if ensemble:
- f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=scale_list[i])
- flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
- mask = (mask + (-m1)) / 2
- else:
- f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=scale_list[i])
- if ensemble:
- f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) # pylint: disable=invalid-unary-operand-type
- f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
- m0 = (m0 + (-m1)) / 2
- flow = flow + f0
- mask = mask + m0
- mask_list.append(mask)
- flow_list.append(flow)
- warped_img0 = warp(img0, flow[:, :2])
- warped_img1 = warp(img1, flow[:, 2:4])
- merged.append((warped_img0, warped_img1))
- mask_list[3] = torch.sigmoid(mask_list[3])
- merged[3] = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
- if not fastmode:
- print('contextnet is removed')
- '''
- c0 = self.contextnet(img0, flow[:, :2])
- c1 = self.contextnet(img1, flow[:, 2:4])
- tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
- res = tmp[:, :3] * 2 - 1
- merged[3] = torch.clamp(merged[3] + res, 0, 1)
- '''
- return flow_list, mask_list[3], merged
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+sys.path.append(os.path.dirname(__file__))
+from warplayer import warp # pylint: disable=wrong-import-position
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, bias=True),
+ nn.LeakyReLU(0.2, True)
+ )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.LeakyReLU(0.2, True)
+ )
+
+class ResConv(nn.Module):
+ def __init__(self, c, dilation=1):
+ super(ResConv, self).__init__()
+ self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
+ self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
+ self.relu = nn.LeakyReLU(0.2, True)
+
+ def forward(self, x):
+ return self.relu(self.conv(x) * self.beta + x)
+
+class IFBlock(nn.Module):
+ def __init__(self, in_planes, c=64):
+ super(IFBlock, self).__init__()
+ self.conv0 = nn.Sequential(
+ conv(in_planes, c//2, 3, 2, 1),
+ conv(c//2, c, 3, 2, 1),
+ )
+ self.convblock = nn.Sequential(
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ )
+ self.lastconv = nn.Sequential(
+ nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
+ nn.PixelShuffle(2)
+ )
+
+ def forward(self, x, flow=None, scale=1):
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
+ if flow is not None:
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
+ x = torch.cat((x, flow), 1)
+ feat = self.conv0(x)
+ feat = self.convblock(feat)
+ tmp = self.lastconv(feat)
+ tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
+ flow = tmp[:, :4] * scale
+ mask = tmp[:, 4:5]
+ return flow, mask
+
+class IFNet(nn.Module):
+ def __init__(self):
+ super(IFNet, self).__init__()
+ self.block0 = IFBlock(7, c=192)
+ self.block1 = IFBlock(8+4, c=128)
+ self.block2 = IFBlock(8+4, c=96)
+ self.block3 = IFBlock(8+4, c=64)
+ # self.contextnet = Contextnet()
+ # self.unet = Unet()
+
+ def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): # pylint: disable=dangerous-default-value # noqa: B006
+ if training is False:
+ channel = x.shape[1] // 2
+ img0 = x[:, :channel]
+ img1 = x[:, channel:]
+ if not torch.is_tensor(timestep):
+ timestep = (x[:, :1].clone() * 0 + 1) * timestep
+ else:
+ timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
+ flow_list = []
+ merged = []
+ mask_list = []
+ warped_img0 = img0
+ warped_img1 = img1
+ flow = None
+ mask = None
+ # loss_cons = 0
+ block = [self.block0, self.block1, self.block2, self.block3]
+ for i in range(4):
+ if flow is None:
+ flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=scale_list[i])
+ if ensemble:
+ f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=scale_list[i])
+ flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+ mask = (mask + (-m1)) / 2
+ else:
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=scale_list[i])
+ if ensemble:
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) # pylint: disable=invalid-unary-operand-type
+ f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+ m0 = (m0 + (-m1)) / 2
+ flow = flow + f0
+ mask = mask + m0
+ mask_list.append(mask)
+ flow_list.append(flow)
+ warped_img0 = warp(img0, flow[:, :2])
+ warped_img1 = warp(img1, flow[:, 2:4])
+ merged.append((warped_img0, warped_img1))
+ mask_list[3] = torch.sigmoid(mask_list[3])
+ merged[3] = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
+ if not fastmode:
+ print('contextnet is removed')
+ '''
+ c0 = self.contextnet(img0, flow[:, :2])
+ c1 = self.contextnet(img1, flow[:, 2:4])
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+ res = tmp[:, :3] * 2 - 1
+ merged[3] = torch.clamp(merged[3] + res, 0, 1)
+ '''
+ return flow_list, mask_list[3], merged
diff --git a/modules/rife/refine.py b/modules/rife/refine.py
index a66488345..5d77582cc 100644
--- a/modules/rife/refine.py
+++ b/modules/rife/refine.py
@@ -1,90 +1,90 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from modules.rife.warplayer import warp
-
-
-c = 16
-
-
-def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
- return nn.Sequential(
- nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True),
- nn.LeakyReLU(0.2, True)
- )
-
-
-def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
- return nn.Sequential(
- nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True),
- )
-
-
-def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): # pylint: disable=unused-argument
- return nn.Sequential(
- torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
- nn.LeakyReLU(0.2, True)
- )
-
-
-class Conv2(nn.Module):
- def __init__(self, in_planes, out_planes, stride=2):
- super(Conv2, self).__init__()
- self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
- self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.conv2(x)
- return x
-
-
-class Contextnet(nn.Module):
- def __init__(self):
- super(Contextnet, self).__init__()
- self.conv1 = Conv2(3, c)
- self.conv2 = Conv2(c, 2*c)
- self.conv3 = Conv2(2*c, 4*c)
- self.conv4 = Conv2(4*c, 8*c)
-
- def forward(self, x, flow):
- x = self.conv1(x)
- flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
- f1 = warp(x, flow)
- x = self.conv2(x)
- flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
- f2 = warp(x, flow)
- x = self.conv3(x)
- flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
- f3 = warp(x, flow)
- x = self.conv4(x)
- flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
- f4 = warp(x, flow)
- return [f1, f2, f3, f4]
-
-
-class Unet(nn.Module):
- def __init__(self):
- super(Unet, self).__init__()
- self.down0 = Conv2(17, 2*c)
- self.down1 = Conv2(4*c, 4*c)
- self.down2 = Conv2(8*c, 8*c)
- self.down3 = Conv2(16*c, 16*c)
- self.up0 = deconv(32*c, 8*c)
- self.up1 = deconv(16*c, 4*c)
- self.up2 = deconv(8*c, 2*c)
- self.up3 = deconv(4*c, c)
- self.conv = nn.Conv2d(c, 3, 3, 1, 1)
-
- def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
- s0 = self.down0(
- torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
- s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
- s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
- s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
- x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
- x = self.up1(torch.cat((x, s2), 1))
- x = self.up2(torch.cat((x, s1), 1))
- x = self.up3(torch.cat((x, s0), 1))
- x = self.conv(x)
- return torch.sigmoid(x)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.rife.warplayer import warp
+
+
+c = 16
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True),
+ nn.LeakyReLU(0.2, True)
+ )
+
+
+def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True),
+ )
+
+
+def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): # pylint: disable=unused-argument
+ return nn.Sequential(
+ torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+ nn.LeakyReLU(0.2, True)
+ )
+
+
+class Conv2(nn.Module):
+ def __init__(self, in_planes, out_planes, stride=2):
+ super(Conv2, self).__init__()
+ self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
+ self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+class Contextnet(nn.Module):
+ def __init__(self):
+ super(Contextnet, self).__init__()
+ self.conv1 = Conv2(3, c)
+ self.conv2 = Conv2(c, 2*c)
+ self.conv3 = Conv2(2*c, 4*c)
+ self.conv4 = Conv2(4*c, 8*c)
+
+ def forward(self, x, flow):
+ x = self.conv1(x)
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+ f1 = warp(x, flow)
+ x = self.conv2(x)
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+ f2 = warp(x, flow)
+ x = self.conv3(x)
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+ f3 = warp(x, flow)
+ x = self.conv4(x)
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
+ f4 = warp(x, flow)
+ return [f1, f2, f3, f4]
+
+
+class Unet(nn.Module):
+ def __init__(self):
+ super(Unet, self).__init__()
+ self.down0 = Conv2(17, 2*c)
+ self.down1 = Conv2(4*c, 4*c)
+ self.down2 = Conv2(8*c, 8*c)
+ self.down3 = Conv2(16*c, 16*c)
+ self.up0 = deconv(32*c, 8*c)
+ self.up1 = deconv(16*c, 4*c)
+ self.up2 = deconv(8*c, 2*c)
+ self.up3 = deconv(4*c, c)
+ self.conv = nn.Conv2d(c, 3, 3, 1, 1)
+
+ def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
+ s0 = self.down0(
+ torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
+ s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
+ s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
+ s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
+ x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
+ x = self.up1(torch.cat((x, s2), 1))
+ x = self.up2(torch.cat((x, s1), 1))
+ x = self.up3(torch.cat((x, s0), 1))
+ x = self.conv(x)
+ return torch.sigmoid(x)
diff --git a/modules/safe.py b/modules/safe.py
index b3b98897e..21b8c08b2 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -1,179 +1,179 @@
-# this code is adapted from the script contributed by anon from /h/
-
-import pickle
-import collections
-import zipfile
-import re
-
-import torch
-import numpy as np
-import _codecs
-
-# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
-TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage # pylint: disable=protected-access
-
-
-def encode(*args):
- out = _codecs.encode(*args)
- return out
-
-
-class RestrictedUnpickler(pickle.Unpickler):
- extra_handler = None
-
- def persistent_load(self, saved_id):
- assert saved_id[0] == 'storage'
- try:
- return TypedStorage(_internal=True)
- except TypeError:
- return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
-
- def find_class(self, module, name):
- if self.extra_handler is not None:
- res = self.extra_handler(module, name)
- if res is not None:
- return res
-
- if module == 'collections' and name == 'OrderedDict':
- return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
- return getattr(torch._utils, name) # pylint: disable=protected-access
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
- return getattr(torch, name)
- if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
- return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
- return getattr(np.core.multiarray, name)
- if module == 'numpy' and name in ['dtype', 'ndarray']:
- return getattr(np, name)
- if module == '_codecs' and name == 'encode':
- return encode
- if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
- import pytorch_lightning.callbacks
- return pytorch_lightning.callbacks.model_checkpoint
- if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
- import pytorch_lightning.callbacks.model_checkpoint
- return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
- if module == "__builtin__" and name == 'set':
- return set
-
- # Forbid everything else.
- raise Exception(f"global '{module}/{name}' is forbidden") # pylint: disable=broad-exception-raised
-
-
-# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/'
-allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
-data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
-
-def check_zip_filenames(filename, names):
- for name in names:
- if allowed_zip_names_re.match(name):
- continue
-
- raise Exception(f"bad file inside {filename}: {name}") # pylint: disable=broad-exception-raised
-
-
-def check_pt(filename, extra_handler):
- try:
-
- # new pytorch format is a zip file
- with zipfile.ZipFile(filename) as z:
- check_zip_filenames(filename, z.namelist())
-
- # find filename of data.pkl in zip file: '/data.pkl'
- data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
- if len(data_pkl_filenames) == 0:
- raise Exception(f"data.pkl not found in {filename}") # pylint: disable=broad-exception-raised
- if len(data_pkl_filenames) > 1:
- raise Exception(f"Multiple data.pkl found in {filename}") # pylint: disable=broad-exception-raised
- with z.open(data_pkl_filenames[0]) as file:
- unpickler = RestrictedUnpickler(file)
- unpickler.extra_handler = extra_handler
- unpickler.load()
-
- except zipfile.BadZipfile:
-
- # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
- with open(filename, "rb") as file:
- unpickler = RestrictedUnpickler(file)
- unpickler.extra_handler = extra_handler
- for _i in range(5):
- unpickler.load()
-
-
-def load(filename, *args, **kwargs):
- return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
-
-
-def load_with_extra(filename, extra_handler=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
- """
- this function is intended to be used by extensions that want to load models with
- some extra classes in them that the usual unpickler would find suspicious.
-
- Use the extra_handler argument to specify a function that takes module and field name as text,
- and returns that field's value:
-
- ```python
- def extra(module, name):
- if module == 'collections' and name == 'OrderedDict':
- return collections.OrderedDict
-
- return None
-
- safe.load_with_extra('model.pt', extra_handler=extra)
- ```
-
- The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
- definitely unsafe.
- """
-
- from modules import shared, errors
-
- try:
- if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename, extra_handler)
- except Exception as e:
- errors.display(e, f'verifying pickled file {filename}')
- return None
-
- return unsafe_torch_load(filename, *args, **kwargs)
-
-
-class Extra:
- """
- A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
- (because it's not your code making the torch.load call). The intended use is like this:
-
-```
-import torch
-from modules import safe
-
-def handler(module, name):
- if module == 'torch' and name in ['float64', 'float16']:
- return getattr(torch, name)
-
- return None
-
-with safe.Extra(handler):
- x = torch.load('model.pt')
-```
- """
-
- def __init__(self, handler):
- self.handler = handler
-
- def __enter__(self):
- global global_extra_handler # pylint: disable=global-statement
-
- assert global_extra_handler is None, 'already inside an Extra() block'
- global_extra_handler = self.handler
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- global global_extra_handler # pylint: disable=global-statement
-
- global_extra_handler = None
-
-
-unsafe_torch_load = torch.load
-torch.load = load
-global_extra_handler = None
+# this code is adapted from the script contributed by anon from /h/
+
+import pickle
+import collections
+import zipfile
+import re
+
+import torch
+import numpy as np
+import _codecs
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage # pylint: disable=protected-access
+
+
+def encode(*args):
+ out = _codecs.encode(*args)
+ return out
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
+ def persistent_load(self, saved_id):
+ assert saved_id[0] == 'storage'
+ try:
+ return TypedStorage(_internal=True)
+ except TypeError:
+ return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
+
+ def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
+ if module == 'collections' and name == 'OrderedDict':
+ return getattr(collections, name)
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
+ return getattr(torch._utils, name) # pylint: disable=protected-access
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
+ return getattr(torch, name)
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
+ return getattr(torch.nn.modules.container, name)
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(np.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(np, name)
+ if module == '_codecs' and name == 'encode':
+ return encode
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
+ import pytorch_lightning.callbacks
+ return pytorch_lightning.callbacks.model_checkpoint
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
+ import pytorch_lightning.callbacks.model_checkpoint
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
+ if module == "__builtin__" and name == 'set':
+ return set
+
+ # Forbid everything else.
+ raise Exception(f"global '{module}/{name}' is forbidden") # pylint: disable=broad-exception-raised
+
+
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
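+# Illustrative examples (not part of the module) of what the whitelist above accepts, assuming a
+# checkpoint archive whose top-level directory is named 'archive':
+# allowed_zip_names_re.match('archive/data.pkl')     -> match
+# allowed_zip_names_re.match('archive/data/0')       -> match
+# allowed_zip_names_re.match('archive/version')      -> match
+# allowed_zip_names_re.match('archive/code/evil.py') -> None (rejected by check_zip_filenames below)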
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}") # pylint: disable=broad-exception-raised
+
+
+def check_pt(filename, extra_handler):
+ try:
+
+ # new pytorch format is a zip file
+ with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
+ # find filename of data.pkl in zip file: '<dirname>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}") # pylint: disable=broad-exception-raised
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}") # pylint: disable=broad-exception-raised
+ with z.open(data_pkl_filenames[0]) as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
+ unpickler.load()
+
+ except zipfile.BadZipfile:
+
+ # if it's not a zip file, it's the old pytorch format, with five objects written to the pickle
+ with open(filename, "rb") as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
+ for _i in range(5):
+ unpickler.load()
+
+
+def load(filename, *args, **kwargs):
+ return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
+ """
+ this function is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
+ from modules import shared, errors
+
+ try:
+ if not shared.cmd_opts.disable_safe_unpickle:
+ check_pt(filename, extra_handler)
+ except Exception as e:
+ errors.display(e, f'verifying pickled file {filename}')
+ return None
+
+ return unsafe_torch_load(filename, *args, **kwargs)
+
+
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler # pylint: disable=global-statement
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler # pylint: disable=global-statement
+
+ global_extra_handler = None
+
+
+unsafe_torch_load = torch.load
+torch.load = load
+global_extra_handler = None
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6454db2f2..9b57c69c8 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,507 +1,507 @@
-import os
-import sys
-import time
-from collections import namedtuple
-from typing import Optional, Dict, Any
-from fastapi import FastAPI
-from gradio import Blocks
-import modules.errors as errors
-
-
-def report_exception(e, c, job):
- errors.display(e, f'executing callback: {c.script} {job}')
-
-
-class ImageSaveParams:
- def __init__(self, image, p, filename, pnginfo):
- self.image = image
- """the PIL image itself"""
-
- self.p = p
- """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
-
- self.filename = filename
- """name of file that the image would be saved to"""
-
- self.pnginfo = pnginfo
- """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
-
-
-class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
- self.x = x
- """Latent image representation in the process of being denoised"""
-
- self.image_cond = image_cond
- """Conditioning image"""
-
- self.sigma = sigma
- """Current sigma noise step value"""
-
- self.sampling_step = sampling_step
- """Current Sampling step number"""
-
- self.total_sampling_steps = total_sampling_steps
- """Total number of sampling steps planned"""
-
- self.text_cond = text_cond
- """ Encoder hidden states of text conditioning from prompt"""
-
- self.text_uncond = text_uncond
- """ Encoder hidden states of text conditioning from negative prompt"""
-
-
-class CFGDenoisedParams:
- def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
- self.x = x
- """Latent image representation in the process of being denoised"""
-
- self.sampling_step = sampling_step
- """Current Sampling step number"""
-
- self.total_sampling_steps = total_sampling_steps
- """Total number of sampling steps planned"""
-
- self.inner_model = inner_model
- """Inner model reference used for denoising"""
-
-
-class AfterCFGCallbackParams:
- def __init__(self, x, sampling_step, total_sampling_steps):
- self.x = x
- """Latent image representation in the process of being denoised"""
-
- self.sampling_step = sampling_step
- """Current Sampling step number"""
-
- self.total_sampling_steps = total_sampling_steps
- """Total number of sampling steps planned"""
-
-
-class UiTrainTabParams:
- def __init__(self, txt2img_preview_params):
- self.txt2img_preview_params = txt2img_preview_params
-
-
-class ImageGridLoopParams:
- def __init__(self, imgs, cols, rows):
- self.imgs = imgs
- self.cols = cols
- self.rows = rows
-
-
-ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-callback_map = dict(
- callbacks_app_started=[],
- callbacks_before_process=[],
- callbacks_after_process=[],
- callbacks_model_loaded=[],
- callbacks_ui_tabs=[],
- callbacks_ui_train_tabs=[],
- callbacks_ui_settings=[],
- callbacks_before_image_saved=[],
- callbacks_image_saved=[],
- callbacks_image_save_btn=[],
- callbacks_cfg_denoiser=[],
- callbacks_cfg_denoised=[],
- callbacks_cfg_after_cfg=[],
- callbacks_before_component=[],
- callbacks_after_component=[],
- callbacks_image_grid=[],
- callbacks_infotext_pasted=[],
- callbacks_script_unloaded=[],
- callbacks_before_ui=[],
- callbacks_on_reload=[],
-)
-
-timers = {}
-def timer(t0: float, script, callback: str):
- t1 = time.time()
- k = f'{os.path.basename(script)}:{callback}'
- if k not in timers:
- timers[k] = 0
- timers[k] += t1 - t0
-
-
-def print_timers():
- for k, v in timers.items():
- if v > 0.05:
- errors.log.debug(f'Script: time={v:.2f} {k}')
-
-
-def clear_callbacks():
- for callback_list in callback_map.values():
- callback_list.clear()
-
-
-def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in callback_map['callbacks_app_started']:
- try:
- t0 = time.time()
- c.callback(demo, app)
- timer(t0, c.script, 'app_started')
- except Exception as e:
- report_exception(e, c, 'app_started_callback')
-
-
-def before_process_callback(p):
- for c in callback_map['callbacks_before_process']:
- try:
- t0 = time.time()
- c.callback(p)
- timer(t0, c.script, 'before_process')
- except Exception as e:
- report_exception(e, c, 'before_process_callback')
-
-
-def after_process_callback(p):
- for c in callback_map['callbacks_after_process']:
- try:
- t0 = time.time()
- c.callback(p)
- timer(t0, c.script, 'after_process')
- except Exception as e:
- report_exception(e, c, 'after_process_callback')
-
-
-def app_reload_callback():
- for c in callback_map['callbacks_on_reload']:
- try:
- t0 = time.time()
- c.callback()
- timer(t0, c.script, 'on_reload')
- except Exception as e:
- report_exception(e, c, 'callbacks_on_reload')
-
-
-def model_loaded_callback(sd_model):
- for c in callback_map['callbacks_model_loaded']:
- try:
- t0 = time.time()
- c.callback(sd_model)
- timer(t0, c.script, 'model_loaded')
- except Exception as e:
- report_exception(e, c, 'model_loaded_callback')
-
-
-def ui_tabs_callback():
- res = []
- for c in callback_map['callbacks_ui_tabs']:
- try:
- t0 = time.time()
- res += c.callback() or []
- timer(t0, c.script, 'ui_tabs')
- except Exception as e:
- report_exception(e, c, 'ui_tabs_callback')
- return res
-
-
-def ui_train_tabs_callback(params: UiTrainTabParams):
- for c in callback_map['callbacks_ui_train_tabs']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'ui_train_tabs')
- except Exception as e:
- report_exception(e, c, 'callbacks_ui_train_tabs')
-
-
-def ui_settings_callback():
- for c in callback_map['callbacks_ui_settings']:
- try:
- t0 = time.time()
- c.callback()
- timer(t0, c.script, 'ui_settings')
- except Exception as e:
- report_exception(e, c, 'ui_settings_callback')
-
-
-def before_image_saved_callback(params: ImageSaveParams):
- for c in callback_map['callbacks_before_image_saved']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'before_image_saved')
- except Exception as e:
- report_exception(e, c, 'before_image_saved_callback')
-
-
-def image_saved_callback(params: ImageSaveParams):
- for c in callback_map['callbacks_image_saved']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'image_saved')
- except Exception as e:
- report_exception(e, c, 'image_saved_callback')
-
-
-def image_save_btn_callback(filename: str):
- for c in callback_map['callbacks_image_save_btn']:
- try:
- t0 = time.time()
- c.callback(filename)
- timer(t0, c.script, 'image_save_btn')
- except Exception as e:
- report_exception(e, c, 'image_save_btn_callback')
-
-
-def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in callback_map['callbacks_cfg_denoiser']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'cfg_denoiser')
- except Exception as e:
- report_exception(e, c, 'cfg_denoiser_callback')
-
-
-def cfg_denoised_callback(params: CFGDenoisedParams):
- for c in callback_map['callbacks_cfg_denoised']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'cfg_denoised')
- except Exception as e:
- report_exception(e, c, 'cfg_denoised_callback')
-
-
-def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
- for c in callback_map['callbacks_cfg_after_cfg']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'cfg_after_cfg')
- except Exception as e:
- report_exception(e, c, 'cfg_after_cfg_callback')
-
-
-def before_component_callback(component, **kwargs):
- for c in callback_map['callbacks_before_component']:
- try:
- t0 = time.time()
- c.callback(component, **kwargs)
- timer(t0, c.script, 'before_component')
- except Exception as e:
- report_exception(e, c, 'before_component_callback')
-
-
-def after_component_callback(component, **kwargs):
- for c in callback_map['callbacks_after_component']:
- try:
- t0 = time.time()
- c.callback(component, **kwargs)
- timer(t0, c.script, 'after_component')
- except Exception as e:
- report_exception(e, c, 'after_component_callback')
-
-
-def image_grid_callback(params: ImageGridLoopParams):
- for c in callback_map['callbacks_image_grid']:
- try:
- t0 = time.time()
- c.callback(params)
- timer(t0, c.script, 'image_grid')
- except Exception as e:
- report_exception(e, c, 'image_grid')
-
-
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
- for c in callback_map['callbacks_infotext_pasted']:
- try:
- t0 = time.time()
- c.callback(infotext, params)
- timer(t0, c.script, 'infotext_pasted')
- except Exception as e:
- report_exception(e, c, 'infotext_pasted')
-
-
-def script_unloaded_callback():
- for c in reversed(callback_map['callbacks_script_unloaded']):
- try:
- t0 = time.time()
- c.callback()
- timer(t0, c.script, 'script_unloaded')
- except Exception as e:
- report_exception(e, c, 'script_unloaded')
-
-
-def before_ui_callback():
- for c in reversed(callback_map['callbacks_before_ui']):
- try:
- t0 = time.time()
- c.callback()
- timer(t0, c.script, 'before_ui')
- except Exception as e:
- report_exception(e, c, 'before_ui')
-
-
-def add_callback(callbacks, fun):
- # stack = [x for x in inspect.stack(0) if x.filename != __file__]
- # filename = stack[0].filename if len(stack) > 0 else 'unknown file'
- filename = sys._getframe().f_back.f_back.f_code.co_filename # pylint: disable=protected-access
- callbacks.append(ScriptCallback(filename, fun))
-
-
-def remove_current_script_callbacks():
- # stack = [x for x in inspect.stack() if x.filename != __file__]
- # filename = stack[0].filename if len(stack) > 0 else 'unknown file'
- # if filename == 'unknown file':
- # return
- filename = sys._getframe().f_back.f_back.f_code.co_filename # pylint: disable=protected-access
- for callback_list in callback_map.values():
- for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
- callback_list.remove(callback_to_remove)
-
-
-def remove_callbacks_for_function(callback_func):
- for callback_list in callback_map.values():
- for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
- callback_list.remove(callback_to_remove)
-
-
-def on_app_started(callback):
- """register a function to be called when the webui started, the gradio `Block` component and
- fastapi `FastAPI` object are passed as the arguments"""
- add_callback(callback_map['callbacks_app_started'], callback)
-
-
-def on_before_process(callback):
- """register a function to be called just before processing starts"""
- add_callback(callback_map['callbacks_before_process'], callback)
-
-
-def on_after_process(callback):
- """register a function to be called just after processing ends"""
- add_callback(callback_map['callbacks_after_process'], callback)
-
-
-def on_before_reload(callback):
- """register a function to be called just before the server reloads."""
- add_callback(callback_map['callbacks_on_reload'], callback)
-
-
-def on_model_loaded(callback):
- """register a function to be called when the stable diffusion model is created; the model is
- passed as an argument; this function is also called when the script is reloaded. """
- add_callback(callback_map['callbacks_model_loaded'], callback)
-
-
-def on_ui_tabs(callback):
- """register a function to be called when the UI is creating new tabs.
- The function must either return a None, which means no new tabs to be added, or a list, where
- each element is a tuple:
- (gradio_component, title, elem_id)
-
- gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
- title is tab text displayed to user in the UI
- elem_id is HTML id for the tab
- """
- add_callback(callback_map['callbacks_ui_tabs'], callback)
-
-
-def on_ui_train_tabs(callback):
- """register a function to be called when the UI is creating new tabs for the train tab.
- Create your new tabs with gr.Tab.
- """
- add_callback(callback_map['callbacks_ui_train_tabs'], callback)
-
-
-def on_ui_settings(callback):
- """register a function to be called before UI settings are populated; add your settings
- by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(callback_map['callbacks_ui_settings'], callback)
-
-
-def on_before_image_saved(callback):
- """register a function to be called before an image is saved to a file.
- The callback is called with one argument:
- - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
- """
- add_callback(callback_map['callbacks_before_image_saved'], callback)
-
-
-def on_image_saved(callback):
- """register a function to be called after an image is saved to a file.
- The callback is called with one argument:
- - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
- """
- add_callback(callback_map['callbacks_image_saved'], callback)
-
-
-def on_image_save_btn(callback):
- """register a function to be called after an image save button is pressed.
- The callback is called with one argument:
- - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
- """
- add_callback(callback_map['callbacks_image_save_btn'], callback)
-
-
-def on_cfg_denoiser(callback):
- """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
- The callback is called with one argument:
- - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
- """
- add_callback(callback_map['callbacks_cfg_denoiser'], callback)
-
-
-def on_cfg_denoised(callback):
- """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
- The callback is called with one argument:
- - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
- """
- add_callback(callback_map['callbacks_cfg_denoised'], callback)
-
-
-def on_cfg_after_cfg(callback):
- """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
- The callback is called with one argument:
- - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
- """
- add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
-
-
-def on_before_component(callback):
- """register a function to be called before a component is created.
- The callback is called with arguments:
- - component - gradio component that is about to be created.
- - **kwargs - args to gradio.components.IOComponent.__init__ function
-
- Use elem_id/label fields of kwargs to figure out which component it is.
- This can be useful to inject your own components somewhere in the middle of vanilla UI.
- """
- add_callback(callback_map['callbacks_before_component'], callback)
-
-
-def on_after_component(callback):
- """register a function to be called after a component is created. See on_before_component for more."""
- add_callback(callback_map['callbacks_after_component'], callback)
-
-
-def on_image_grid(callback):
- """register a function to be called before making an image grid.
- The callback is called with one argument:
- - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
- """
- add_callback(callback_map['callbacks_image_grid'], callback)
-
-
-def on_infotext_pasted(callback):
- """register a function to be called before applying an infotext.
- The callback is called with two arguments:
- - infotext: str - raw infotext.
- - result: Dict[str, any] - parsed infotext parameters.
- """
- add_callback(callback_map['callbacks_infotext_pasted'], callback)
-
-
-def on_script_unloaded(callback):
- """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
- the script did should be reverted here"""
-
- add_callback(callback_map['callbacks_script_unloaded'], callback)
-
-
-def on_before_ui(callback):
- """register a function to be called before the UI is created."""
- add_callback(callback_map['callbacks_before_ui'], callback)
+import os
+import sys
+import time
+from collections import namedtuple
+from typing import Optional, Dict, Any
+from fastapi import FastAPI
+from gradio import Blocks
+import modules.errors as errors
+
+
+def report_exception(e, c, job):
+ errors.display(e, f'executing callback: {c.script} {job}')
+
+
+class ImageSaveParams:
+ def __init__(self, image, p, filename, pnginfo):
+ self.image = image
+ """the PIL image itself"""
+
+ self.p = p
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
+
+ self.filename = filename
+ """name of file that the image would be saved to"""
+
+ self.pnginfo = pnginfo
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+
+
+class CFGDenoiserParams:
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.image_cond = image_cond
+ """Conditioning image"""
+
+ self.sigma = sigma
+ """Current sigma noise step value"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.text_cond = text_cond
+ """ Encoder hidden states of text conditioning from prompt"""
+
+ self.text_uncond = text_uncond
+ """ Encoder hidden states of text conditioning from negative prompt"""
+
+
+class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.inner_model = inner_model
+ """Inner model reference used for denoising"""
+
+
+class AfterCFGCallbackParams:
+ def __init__(self, x, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
+class UiTrainTabParams:
+ def __init__(self, txt2img_preview_params):
+ self.txt2img_preview_params = txt2img_preview_params
+
+
+class ImageGridLoopParams:
+ def __init__(self, imgs, cols, rows):
+ self.imgs = imgs
+ self.cols = cols
+ self.rows = rows
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callback_map = dict(
+ callbacks_app_started=[],
+ callbacks_before_process=[],
+ callbacks_after_process=[],
+ callbacks_model_loaded=[],
+ callbacks_ui_tabs=[],
+ callbacks_ui_train_tabs=[],
+ callbacks_ui_settings=[],
+ callbacks_before_image_saved=[],
+ callbacks_image_saved=[],
+ callbacks_image_save_btn=[],
+ callbacks_cfg_denoiser=[],
+ callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
+ callbacks_before_component=[],
+ callbacks_after_component=[],
+ callbacks_image_grid=[],
+ callbacks_infotext_pasted=[],
+ callbacks_script_unloaded=[],
+ callbacks_before_ui=[],
+ callbacks_on_reload=[],
+)
+
+timers = {}
+def timer(t0: float, script, callback: str):
+ t1 = time.time()
+ k = f'{os.path.basename(script)}:{callback}'
+ if k not in timers:
+ timers[k] = 0
+ timers[k] += t1 - t0
+
+
+def print_timers():
+ for k, v in timers.items():
+ if v > 0.05:
+ errors.log.debug(f'Script: time={v:.2f} {k}')
+
+
+def clear_callbacks():
+ for callback_list in callback_map.values():
+ callback_list.clear()
+
+
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+ for c in callback_map['callbacks_app_started']:
+ try:
+ t0 = time.time()
+ c.callback(demo, app)
+ timer(t0, c.script, 'app_started')
+ except Exception as e:
+ report_exception(e, c, 'app_started_callback')
+
+
+def before_process_callback(p):
+ for c in callback_map['callbacks_before_process']:
+ try:
+ t0 = time.time()
+ c.callback(p)
+ timer(t0, c.script, 'before_process')
+ except Exception as e:
+ report_exception(e, c, 'before_process_callback')
+
+
+def after_process_callback(p):
+ for c in callback_map['callbacks_after_process']:
+ try:
+ t0 = time.time()
+ c.callback(p)
+ timer(t0, c.script, 'after_process')
+ except Exception as e:
+ report_exception(e, c, 'after_process_callback')
+
+
+def app_reload_callback():
+ for c in callback_map['callbacks_on_reload']:
+ try:
+ t0 = time.time()
+ c.callback()
+ timer(t0, c.script, 'on_reload')
+ except Exception as e:
+ report_exception(e, c, 'callbacks_on_reload')
+
+
+def model_loaded_callback(sd_model):
+ for c in callback_map['callbacks_model_loaded']:
+ try:
+ t0 = time.time()
+ c.callback(sd_model)
+ timer(t0, c.script, 'model_loaded')
+ except Exception as e:
+ report_exception(e, c, 'model_loaded_callback')
+
+
+def ui_tabs_callback():
+ res = []
+ for c in callback_map['callbacks_ui_tabs']:
+ try:
+ t0 = time.time()
+ res += c.callback() or []
+ timer(t0, c.script, 'ui_tabs')
+ except Exception as e:
+ report_exception(e, c, 'ui_tabs_callback')
+ return res
+
+
+def ui_train_tabs_callback(params: UiTrainTabParams):
+ for c in callback_map['callbacks_ui_train_tabs']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'ui_train_tabs')
+ except Exception as e:
+ report_exception(e, c, 'callbacks_ui_train_tabs')
+
+
+def ui_settings_callback():
+ for c in callback_map['callbacks_ui_settings']:
+ try:
+ t0 = time.time()
+ c.callback()
+ timer(t0, c.script, 'ui_settings')
+ except Exception as e:
+ report_exception(e, c, 'ui_settings_callback')
+
+
+def before_image_saved_callback(params: ImageSaveParams):
+ for c in callback_map['callbacks_before_image_saved']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'before_image_saved')
+ except Exception as e:
+ report_exception(e, c, 'before_image_saved_callback')
+
+
+def image_saved_callback(params: ImageSaveParams):
+ for c in callback_map['callbacks_image_saved']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'image_saved')
+ except Exception as e:
+ report_exception(e, c, 'image_saved_callback')
+
+
+def image_save_btn_callback(filename: str):
+ for c in callback_map['callbacks_image_save_btn']:
+ try:
+ t0 = time.time()
+ c.callback(filename)
+ timer(t0, c.script, 'image_save_btn')
+ except Exception as e:
+ report_exception(e, c, 'image_save_btn_callback')
+
+
+def cfg_denoiser_callback(params: CFGDenoiserParams):
+ for c in callback_map['callbacks_cfg_denoiser']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'cfg_denoiser')
+ except Exception as e:
+ report_exception(e, c, 'cfg_denoiser_callback')
+
+
+def cfg_denoised_callback(params: CFGDenoisedParams):
+ for c in callback_map['callbacks_cfg_denoised']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'cfg_denoised')
+ except Exception as e:
+ report_exception(e, c, 'cfg_denoised_callback')
+
+
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'cfg_after_cfg')
+ except Exception as e:
+ report_exception(e, c, 'cfg_after_cfg_callback')
+
+
+def before_component_callback(component, **kwargs):
+ for c in callback_map['callbacks_before_component']:
+ try:
+ t0 = time.time()
+ c.callback(component, **kwargs)
+ timer(t0, c.script, 'before_component')
+ except Exception as e:
+ report_exception(e, c, 'before_component_callback')
+
+
+def after_component_callback(component, **kwargs):
+ for c in callback_map['callbacks_after_component']:
+ try:
+ t0 = time.time()
+ c.callback(component, **kwargs)
+ timer(t0, c.script, 'after_component')
+ except Exception as e:
+ report_exception(e, c, 'after_component_callback')
+
+
+def image_grid_callback(params: ImageGridLoopParams):
+ for c in callback_map['callbacks_image_grid']:
+ try:
+ t0 = time.time()
+ c.callback(params)
+ timer(t0, c.script, 'image_grid')
+ except Exception as e:
+ report_exception(e, c, 'image_grid')
+
+
+def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+ for c in callback_map['callbacks_infotext_pasted']:
+ try:
+ t0 = time.time()
+ c.callback(infotext, params)
+ timer(t0, c.script, 'infotext_pasted')
+ except Exception as e:
+ report_exception(e, c, 'infotext_pasted')
+
+
+def script_unloaded_callback():
+ for c in reversed(callback_map['callbacks_script_unloaded']):
+ try:
+ t0 = time.time()
+ c.callback()
+ timer(t0, c.script, 'script_unloaded')
+ except Exception as e:
+ report_exception(e, c, 'script_unloaded')
+
+
+def before_ui_callback():
+ for c in reversed(callback_map['callbacks_before_ui']):
+ try:
+ t0 = time.time()
+ c.callback()
+ timer(t0, c.script, 'before_ui')
+ except Exception as e:
+ report_exception(e, c, 'before_ui')
+
+
+def add_callback(callbacks, fun):
+ # stack = [x for x in inspect.stack(0) if x.filename != __file__]
+ # filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ filename = sys._getframe().f_back.f_back.f_code.co_filename # pylint: disable=protected-access
+ callbacks.append(ScriptCallback(filename, fun))
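+
+# Editor's note on the frame walk above (illustration only): an extension script calls a
+# registration helper such as on_app_started(fn), which in turn calls add_callback(); from
+# inside add_callback, f_back is the on_* helper's frame and f_back.f_back is the extension
+# script's frame, so co_filename records which file registered the callback. That filename
+# keys the per-script timers and lets remove_current_script_callbacks() (below, which uses
+# the same trick) drop only the calling script's callbacks.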
+
+
+def remove_current_script_callbacks():
+ # stack = [x for x in inspect.stack() if x.filename != __file__]
+ # filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ # if filename == 'unknown file':
+ # return
+ filename = sys._getframe().f_back.f_back.f_code.co_filename # pylint: disable=protected-access
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+ callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+ callback_list.remove(callback_to_remove)
+
+
+def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
+ add_callback(callback_map['callbacks_app_started'], callback)
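+
+# Usage sketch (editor's illustration, not executed here; the route path and function names
+# are hypothetical). An extension script can use app_started to attach its own API routes:
+#
+#   from modules import script_callbacks
+#
+#   def _register_api(demo, app):  # demo: gr.Blocks or None, app: fastapi.FastAPI
+#       @app.get('/sdapi/v1/example-ping')
+#       def ping():
+#           return {'status': 'ok'}
+#
+#   script_callbacks.on_app_started(_register_api)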
+
+
+def on_before_process(callback):
+ """register a function to be called just before processing starts"""
+ add_callback(callback_map['callbacks_before_process'], callback)
+
+
+def on_after_process(callback):
+ """register a function to be called just after processing ends"""
+ add_callback(callback_map['callbacks_after_process'], callback)
+
+
+def on_before_reload(callback):
+ """register a function to be called just before the server reloads."""
+ add_callback(callback_map['callbacks_on_reload'], callback)
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument; this function is also called when the script is reloaded. """
+ add_callback(callback_map['callbacks_model_loaded'], callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+ The function must either return None, meaning no new tabs are to be added, or a list, where
+ each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
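+
+# Usage sketch (editor's illustration; the tab title and elem_id are hypothetical), returning
+# the (gradio_component, title, elem_id) tuple format described above:
+#
+#   import gradio as gr
+#   from modules import script_callbacks
+#
+#   def _example_tab():
+#       with gr.Blocks() as tab:
+#           gr.Markdown('Hello from an example extension tab')
+#       return [(tab, 'Example', 'example_tab')]
+#
+#   script_callbacks.on_ui_tabs(_example_tab)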
+
+
+def on_ui_train_tabs(callback):
+ """register a function to be called when the UI is creating new tabs for the train tab.
+ Create your new tabs with gr.Tab.
+ """
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settings are populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ add_callback(callback_map['callbacks_ui_settings'], callback)
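+
+# Usage sketch (editor's illustration; the option key, labels and exact OptionInfo/section
+# arguments are assumptions following the shared.opts.add_option pattern referenced above):
+#
+#   from modules import script_callbacks, shared
+#
+#   def _add_settings():
+#       shared.opts.add_option('example_enable', shared.OptionInfo(True, 'Enable example feature', section=('example', 'Example')))
+#
+#   script_callbacks.on_ui_settings(_add_settings)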
+
+
+def on_before_image_saved(callback):
+ """register a function to be called before an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+ """
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
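+
+# Usage sketch (editor's illustration; the pnginfo key is hypothetical). Fields of the
+# ImageSaveParams object may be changed here before the file is written:
+#
+#   from modules import script_callbacks
+#
+#   def _tag_image(params: script_callbacks.ImageSaveParams):
+#       params.pnginfo['example'] = 'added-by-example-extension'
+#
+#   script_callbacks.on_before_image_saved(_tag_image)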
+
+
+def on_image_saved(callback):
+ """register a function to be called after an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
+ """
+ add_callback(callback_map['callbacks_image_saved'], callback)
+
+
+def on_image_save_btn(callback):
+ """register a function to be called after an image save button is pressed.
+ The callback is called with one argument:
+ - filename: str - name of the file that was saved
+ """
+ add_callback(callback_map['callbacks_image_save_btn'], callback)
+
+
+def on_cfg_denoiser(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
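+
+# Usage sketch (editor's illustration; purely observational, it does not modify the latents):
+#
+#   from modules import script_callbacks
+#
+#   def _log_denoise(params: script_callbacks.CFGDenoiserParams):
+#       if params.sampling_step == 0:
+#           print(f'denoising: sigma={params.sigma} steps={params.total_sampling_steps}')
+#
+#   script_callbacks.on_cfg_denoiser(_log_denoise)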
+
+
+def on_cfg_denoised(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_denoised'], callback)
+
+
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
+ The callback is called with one argument:
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
+def on_before_component(callback):
+ """register a function to be called before a component is created.
+ The callback is called with arguments:
+ - component - gradio component that is about to be created.
+ - **kwargs - args to gradio.components.IOComponent.__init__ function
+
+ Use elem_id/label fields of kwargs to figure out which component it is.
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
+ """
+ add_callback(callback_map['callbacks_before_component'], callback)
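+
+# Usage sketch (editor's illustration; the elem_id value checked below is an assumption):
+#
+#   from modules import script_callbacks
+#
+#   def _watch_component(component, **kwargs):
+#       if kwargs.get('elem_id') == 'txt2img_prompt':
+#           print('txt2img prompt textbox is about to be created')
+#
+#   script_callbacks.on_before_component(_watch_component)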
+
+
+def on_after_component(callback):
+ """register a function to be called after a component is created. See on_before_component for more."""
+ add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+ """register a function to be called before making an image grid.
+ The callback is called with one argument:
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+ """
+ add_callback(callback_map['callbacks_image_grid'], callback)
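+
+# Usage sketch (editor's illustration): lay every image out in a single row by modifying the
+# ImageGridLoopParams fields in place:
+#
+#   from modules import script_callbacks
+#
+#   def _single_row(params: script_callbacks.ImageGridLoopParams):
+#       params.rows = 1
+#       params.cols = len(params.imgs)
+#
+#   script_callbacks.on_image_grid(_single_row)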
+
+
+def on_infotext_pasted(callback):
+ """register a function to be called before applying an infotext.
+ The callback is called with two arguments:
+ - infotext: str - raw infotext.
+ - params: Dict[str, Any] - parsed infotext parameters.
+ """
+ add_callback(callback_map['callbacks_infotext_pasted'], callback)
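+
+# Usage sketch (editor's illustration; the parameter keys are hypothetical). The parsed
+# dictionary can be adjusted in place before it is applied:
+#
+#   from modules import script_callbacks
+#
+#   def _remap_keys(infotext: str, params: dict):
+#       if 'Example old key' in params:
+#           params['Example new key'] = params.pop('Example old key')
+#
+#   script_callbacks.on_infotext_pasted(_remap_keys)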
+
+
+def on_script_unloaded(callback):
+ """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
+ the script did should be reverted here"""
+
+ add_callback(callback_map['callbacks_script_unloaded'], callback)
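+
+# Usage sketch (editor's illustration; it assumes the extension monkeypatched
+# gr.Textbox.__init__ at load time, which is hypothetical here). Revert such hooks when the
+# script is unloaded:
+#
+#   import gradio as gr
+#   from modules import script_callbacks
+#
+#   _original_textbox_init = gr.Textbox.__init__
+#
+#   def _cleanup():
+#       gr.Textbox.__init__ = _original_textbox_init
+#
+#   script_callbacks.on_script_unloaded(_cleanup)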
+
+
+def on_before_ui(callback):
+ """register a function to be called before the UI is created."""
+ add_callback(callback_map['callbacks_before_ui'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 37971eaf1..cfb0eb097 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,56 +1,56 @@
-import io
-import os
-import contextlib
-import importlib.util
-import modules.errors as errors
-from installer import setup_logging, args
-
-
-preloaded = []
-debug = os.environ.get('SD_SCRIPT_DEBUG', None)
-
-
-def load_module(path):
- module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
- module = importlib.util.module_from_spec(module_spec)
- if args.profile:
- import cProfile
- pr = cProfile.Profile()
- pr.enable()
- try:
- if '/sd-extension-' in path: # safe extensions without stdout intercept
- module_spec.loader.exec_module(module)
- else:
- if debug:
- module_spec.loader.exec_module(module)
- stdout = io.StringIO()
- else:
- with contextlib.redirect_stdout(io.StringIO()) as stdout:
- module_spec.loader.exec_module(module)
- setup_logging() # reset since scripts can hijaack logging
- for line in stdout.getvalue().splitlines():
- if len(line) > 0:
- errors.log.info(f"Extension: script='{os.path.relpath(path)}' {line.strip()}")
- except Exception as e:
- errors.display(e, f'Module load: {path}')
- if args.profile:
- errors.profile(pr, f'Scripts: {path}')
- return module
-
-
-def preload_extensions(extensions_dir, parser):
- if not os.path.isdir(extensions_dir):
- return
- for dirname in sorted(os.listdir(extensions_dir)):
- if dirname in preloaded:
- continue
- preloaded.append(dirname)
- preload_script = os.path.join(extensions_dir, dirname, "preload.py")
- if not os.path.isfile(preload_script):
- continue
- try:
- module = load_module(preload_script)
- if hasattr(module, 'preload'):
- module.preload(parser)
- except Exception as e:
- errors.display(e, f'Extension preload: {preload_script}')
+import io
+import os
+import contextlib
+import importlib.util
+import modules.errors as errors
+from installer import setup_logging, args
+
+
+preloaded = []
+debug = os.environ.get('SD_SCRIPT_DEBUG', None)
+
+
+def load_module(path):
+ module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
+ module = importlib.util.module_from_spec(module_spec)
+ if args.profile:
+ import cProfile
+ pr = cProfile.Profile()
+ pr.enable()
+ try:
+ if '/sd-extension-' in path: # safe extensions without stdout intercept
+ module_spec.loader.exec_module(module)
+ else:
+ if debug:
+ module_spec.loader.exec_module(module)
+ stdout = io.StringIO()
+ else:
+ with contextlib.redirect_stdout(io.StringIO()) as stdout:
+ module_spec.loader.exec_module(module)
+ setup_logging() # reset since scripts can hijack logging
+ for line in stdout.getvalue().splitlines():
+ if len(line) > 0:
+ errors.log.info(f"Extension: script='{os.path.relpath(path)}' {line.strip()}")
+ except Exception as e:
+ errors.display(e, f'Module load: {path}')
+ if args.profile:
+ errors.profile(pr, f'Scripts: {path}')
+ return module
+
+
+def preload_extensions(extensions_dir, parser):
+ if not os.path.isdir(extensions_dir):
+ return
+ for dirname in sorted(os.listdir(extensions_dir)):
+ if dirname in preloaded:
+ continue
+ preloaded.append(dirname)
+ preload_script = os.path.join(extensions_dir, dirname, "preload.py")
+ if not os.path.isfile(preload_script):
+ continue
+ try:
+ module = load_module(preload_script)
+ if hasattr(module, 'preload'):
+ module.preload(parser)
+ except Exception as e:
+ errors.display(e, f'Extension preload: {preload_script}')
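+
+# Editor's sketch of the preload.py contract implied above (the flag name is hypothetical):
+# preload_extensions() imports extensions/<name>/preload.py and, if the module defines
+# preload(parser), calls it with the command-line parser passed in by the caller (an
+# argparse-style parser in practice) so the extension can register its own arguments before
+# they are parsed.
+#
+#   # extensions/example-extension/preload.py
+#   def preload(parser):
+#       parser.add_argument('--example-flag', action='store_true', help='enable example feature')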
diff --git a/modules/scripts.py b/modules/scripts.py
index d6d920cf1..0f308a13f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,683 +1,683 @@
-import os
-import re
-import sys
-import time
-from collections import namedtuple
-import gradio as gr
-from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
-from installer import log
-
-
-AlwaysVisible = object()
-time_component = {}
-time_setup = {}
-debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
-
-
-class PostprocessImageArgs:
- def __init__(self, image):
- self.image = image
-
-
-class PostprocessBatchListArgs:
- def __init__(self, images):
- self.images = images
-
-
-class Script:
- parent = None
- name = None
- filename = None
- args_from = None
- args_to = None
- alwayson = False
- is_txt2img = False
- is_img2img = False
- api_info = None
- group = None
- infotext_fields = None
- paste_field_names = None
- section = None
-
- def title(self):
- """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
- raise NotImplementedError
-
- def ui(self, is_img2img):
- """this function should create gradio UI elements. See https://gradio.app/docs/#components
- The return value should be an array of all components that are used in processing.
- Values of those returned components will be passed to run() and process() functions.
- """
- pass # pylint: disable=unnecessary-pass
-
- def show(self, is_img2img): # pylint: disable=unused-argument
- """
- is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
- This function should return:
- - False if the script should not be shown in UI at all
- - True if the script should be shown in UI if it's selected in the scripts dropdown
- - script.AlwaysVisible if the script should be shown in UI at all times
- """
- return True
-
- def run(self, p, *args):
- """
- This function is called if the script has been selected in the script dropdown.
- It must do all processing and return the Processed object with results, same as
- one returned by processing.process_images.
- Usually the processing is done by calling the processing.process_images function.
- args contains all values returned by components from ui()
- """
- pass # pylint: disable=unnecessary-pass
-
- def setup(self, p, *args):
- """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
- args contains all values returned by components from ui().
- """
- pass # pylint: disable=unnecessary-pass
-
- def before_process(self, p, *args):
- """
- This function is called very early during processing begins for AlwaysVisible scripts.
- You can modify the processing object (p) here, inject hooks, etc.
- args contains all values returned by components from ui()
- """
- pass # pylint: disable=unnecessary-pass
-
- def process(self, p, *args):
- """
- This function is called before processing begins for AlwaysVisible scripts.
- You can modify the processing object (p) here, inject hooks, etc.
- args contains all values returned by components from ui()
- """
- pass # pylint: disable=unnecessary-pass
-
- def before_process_batch(self, p, *args, **kwargs):
- """
- Called before extra networks are parsed from the prompt, so you can add
- new extra network keywords to the prompt with this callback.
- **kwargs will have those items:
- - batch_number - index of current batch, from 0 to number of batches-1
- - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- - seeds - list of seeds for current batch
- - subseeds - list of subseeds for current batch
- """
- pass # pylint: disable=unnecessary-pass
-
- def process_batch(self, p, *args, **kwargs):
- """
- Same as process(), but called for every batch.
- **kwargs will have those items:
- - batch_number - index of current batch, from 0 to number of batches-1
- - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- - seeds - list of seeds for current batch
- - subseeds - list of subseeds for current batch
- """
- pass # pylint: disable=unnecessary-pass
-
- def postprocess_batch(self, p, *args, **kwargs):
- """
- Same as process_batch(), but called for every batch after it has been generated.
- **kwargs will have same items as process_batch, and also:
- - batch_number - index of current batch, from 0 to number of batches-1
- - images - torch tensor with all generated images, with values ranging from 0 to 1;
- """
- pass # pylint: disable=unnecessary-pass
-
- def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
- """
- Called for every image after it has been generated.
- """
- pass # pylint: disable=unnecessary-pass
-
- def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
- """
- Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
- This is useful when you want to update the entire batch instead of individual images.
- You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
- If the number of images is different from the batch size when returning,
- then the script has the responsibility to also update the following attributes in the processing object (p):
- - p.prompts
- - p.negative_prompts
- - p.seeds
- - p.subseeds
- **kwargs will have same items as process_batch, and also:
- - batch_number - index of current batch, from 0 to number of batches-1
- """
- pass # pylint: disable=unnecessary-pass
-
- def postprocess(self, p, processed, *args):
- """
- This function is called after processing ends for AlwaysVisible scripts.
- args contains all values returned by components from ui()
- """
- pass # pylint: disable=unnecessary-pass
-
- def before_component(self, component, **kwargs):
- """
- Called before a component is created.
- Use elem_id/label fields of kwargs to figure out which component it is.
- This can be useful to inject your own components somewhere in the middle of vanilla UI.
- You can return created components in the ui() function to add them to the list of arguments for your processing functions
- """
- pass # pylint: disable=unnecessary-pass
-
- def after_component(self, component, **kwargs):
- """
- Called after a component is created. Same as above.
- """
- pass # pylint: disable=unnecessary-pass
-
- def describe(self):
- """unused"""
- return ""
-
- def elem_id(self, item_id):
- """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
- title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
- return f'script_{self.parent}_{title}_{item_id}'
-
-
-current_basedir = paths.script_path
-
-
-def basedir():
- """returns the base directory for the current script. For scripts in the main scripts directory,
- this is the main directory (where webui.py resides), and for scripts in extensions directory
- (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
- """
- return current_basedir
-
-
-ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
-scripts_data = []
-postprocessing_scripts_data = []
-ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
-
-
-def list_scripts(scriptdirname, extension):
- tmp_list = []
- base = os.path.join(paths.script_path, scriptdirname)
- if os.path.exists(base):
- for filename in sorted(os.listdir(base)):
- tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
- for ext in extensions.active():
- tmp_list += ext.list_files(scriptdirname, extension)
- priority_list = []
- for script in tmp_list:
- if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
- if script.basedir == paths.script_path:
- priority = '0'
- elif script.basedir.startswith(os.path.join(paths.script_path, 'scripts')):
- priority = '1'
- elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions-builtin')):
- priority = '2'
- elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions')):
- priority = '3'
- else:
- priority = '9'
- if os.path.isfile(os.path.join(base, "..", ".priority")):
- with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f:
- priority = priority + str(f.read().strip())
- log.debug(f'Script priority override: ${script.name}:{priority}')
- else:
- priority = priority + script.priority
- priority_list.append(ScriptFile(script.basedir, script.filename, script.path, priority))
- debug(f'Adding script: {script.basedir} {script.filename} {script.path} {priority}')
- priority_sort = sorted(priority_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
- return priority_sort
-
-
-def list_files_with_name(filename):
- res = []
- dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
- for dirpath in dirs:
- if not os.path.isdir(dirpath):
- continue
- path = os.path.join(dirpath, filename)
- if os.path.isfile(path):
- res.append(path)
- return res
-
-
-def load_scripts():
- t = timer.Timer()
- t0 = time.time()
- global current_basedir # pylint: disable=global-statement
- scripts_data.clear()
- postprocessing_scripts_data.clear()
- script_callbacks.clear_callbacks()
- scripts_list = list_scripts("scripts", ".py")
- syspath = sys.path
-
- def register_scripts_from_module(module, scriptfile):
- for script_class in module.__dict__.values():
- if type(script_class) != type:
- continue
- debug(f'Registering script: {scriptfile.path}')
- if issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
- elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
- postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
-
- for scriptfile in scripts_list:
- try:
- if scriptfile.basedir != paths.script_path:
- sys.path = [scriptfile.basedir] + sys.path
- current_basedir = scriptfile.basedir
- script_module = script_loading.load_module(scriptfile.path)
- register_scripts_from_module(script_module, scriptfile)
- except Exception as e:
- errors.display(e, f'Load script: {scriptfile.filename}')
- finally:
- current_basedir = paths.script_path
- t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename)
- sys.path = syspath
- global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_control = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
- return t, time.time()-t0
-
-
-def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
- try:
- res = func(*args, **kwargs)
- return res
- except Exception as e:
- errors.display(e, f'Calling script: {filename}/{funcname}')
- return default
-
-
-class ScriptSummary:
- def __init__(self, op):
- self.start = time.time()
- self.update = time.time()
- self.op = op
- self.time = {}
-
- def record(self, script):
- self.update = time.time()
- self.time[script] = round(time.time() - self.update, 2)
-
- def report(self):
- total = sum(self.time.values())
- if total == 0:
- return
- scripts = [f'{k}:{v}' for k, v in self.time.items() if v > 0]
- log.debug(f'Script: op={self.op} total={total} scripts={scripts}')
-
-
-class ScriptRunner:
- def __init__(self):
- self.scripts = []
- self.selectable_scripts = []
- self.alwayson_scripts = []
- self.titles = []
- self.infotext_fields = []
- self.paste_field_names = []
- self.script_load_ctr = 0
- self.is_img2img = False
- self.inputs = [None]
-
- def initialize_scripts(self, is_img2img):
- from modules import scripts_auto_postprocessing
-
- self.scripts.clear()
- self.selectable_scripts.clear()
- self.alwayson_scripts.clear()
- self.titles.clear()
- self.infotext_fields.clear()
- self.paste_field_names.clear()
- self.script_load_ctr = 0
- self.is_img2img = is_img2img
- self.scripts.clear()
- self.alwayson_scripts.clear()
- self.selectable_scripts.clear()
- auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
-
- for script_class, path, _basedir, _script_module in auto_processing_scripts + scripts_data:
- try:
- script = script_class()
- script.filename = path
- script.is_txt2img = not is_img2img
- script.is_img2img = is_img2img
- visibility = script.show(script.is_img2img)
- if visibility == AlwaysVisible:
- self.scripts.append(script)
- self.alwayson_scripts.append(script)
- script.alwayson = True
- elif visibility:
- self.scripts.append(script)
- self.selectable_scripts.append(script)
- except Exception as e:
- log.error(f'Script initialize: {path} {e}')
-
- """
- def create_script_ui(self, script):
- import modules.api.models as api_models
- script.args_from = len(self.inputs)
- script.args_to = len(self.inputs)
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
- if controls is None:
- return
- script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
- api_args = []
- for control in controls:
- if not isinstance(control, gr.components.IOComponent):
- log.error(f'Invalid script control: "{script.filename}" control={control}')
- continue
- control.custom_script_source = os.path.basename(script.filename)
- arg_info = api_models.ScriptArg(label=control.label or "")
- for field in ("value", "minimum", "maximum", "step", "choices"):
- v = getattr(control, field, None)
- if v is not None:
- setattr(arg_info, field, v)
- api_args.append(arg_info)
- script.api_info = api_models.ScriptInfo(name=script.name, is_img2img=script.is_img2img, is_alwayson=script.alwayson, args=api_args)
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
- if script.paste_field_names is not None:
- self.paste_field_names += script.paste_field_names
- self.inputs += controls
- script.args_to = len(self.inputs)
-
- def setup_ui_for_section(self, section, scriptlist=None):
- if scriptlist is None:
- scriptlist = self.alwayson_scripts
- for script in scriptlist:
- if script.alwayson and script.section != section:
- continue
- if script.create_group:
- with gr.Group(visible=script.alwayson) as group:
- self.create_script_ui(script)
- script.group = group
- else:
- self.create_script_ui(script)
- """
-
- def prepare_ui(self):
- self.inputs = [None]
-
- def setup_ui(self, parent='unknown', accordion=True):
- import modules.api.models as api_models
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
- inputs = []
- inputs_alwayson = [True]
-
- def create_script_ui(script: Script, inputs, inputs_alwayson):
- script.parent = parent
- script.args_from = len(inputs)
- script.args_to = len(inputs)
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
- if controls is None:
- return
- script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
- api_args = []
- for control in controls:
- debug(f'Script control: parent={script.parent} script="{script.name}" label="{control.label}" type={control} id={control.elem_id}')
- if not isinstance(control, gr.components.IOComponent):
- log.error(f'Invalid script control: "{script.filename}" control={control}')
- continue
- control.custom_script_source = os.path.basename(script.filename)
- arg_info = api_models.ScriptArg(label=control.label or "")
- for field in ("value", "minimum", "maximum", "step", "choices"):
- v = getattr(control, field, None)
- if v is not None:
- setattr(arg_info, field, v)
- api_args.append(arg_info)
-
- script.api_info = api_models.ScriptInfo(
- name=script.name,
- is_img2img=script.is_img2img,
- is_alwayson=script.alwayson,
- args=api_args,
- )
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
- if script.paste_field_names is not None:
- self.paste_field_names += script.paste_field_names
- inputs += controls
- inputs_alwayson += [script.alwayson for _ in controls]
- script.args_to = len(inputs)
-
- dropdown = gr.Dropdown(label="Script", elem_id=f'{parent}_script_list', choices=["None"] + self.titles, value="None", type="index")
- inputs.insert(0, dropdown)
- for script in self.selectable_scripts:
- with gr.Group(visible=False) as group:
- t0 = time.time()
- create_script_ui(script, inputs, inputs_alwayson)
- time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
- script.group = group
-
- def select_script(script_index):
- selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None
- return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
-
- def init_field(title):
- if title == 'None': # called when an initial value is set from ui-config.json to show script's UI components
- return
- script_index = self.titles.index(title)
- self.selectable_scripts[script_index].group.visible = True
-
- dropdown.init_field = init_field
- dropdown.change(fn=select_script, inputs=[dropdown], outputs=[script.group for script in self.selectable_scripts])
-
- def onload_script_visibility(params):
- title = params.get('Script', None)
- if title:
- title_index = self.titles.index(title)
- visibility = title_index == self.script_load_ctr
- self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
- return gr.update(visible=visibility)
- else:
- return gr.update(visible=False)
-
- with gr.Accordion(label="Extensions", elem_id=f'{parent}_script_alwayson') if accordion else gr.Group():
- for script in self.alwayson_scripts:
- t0 = time.time()
- with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['extension-script']) as group:
- create_script_ui(script, inputs, inputs_alwayson)
- script.group = group
- time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
-
- self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
- self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
- return inputs
-
- def run(self, p, *args):
- s = ScriptSummary('run')
- script_index = args[0]
- if script_index == 0:
- return None
- script = self.selectable_scripts[script_index-1]
- if script is None:
- return None
- parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
- processed = script.run(p, *parsed)
- s.record(script.title())
- s.report()
- return processed
-
- def before_process(self, p, **kwargs):
- s = ScriptSummary('before-process')
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.before_process(p, *script_args, **kwargs)
- except Exception as e:
- errors.display(e, f"Error running before process: {script.filename}")
- s.record(script.title())
- s.report()
-
- def process(self, p, **kwargs):
- s = ScriptSummary('process')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.process(p, *args, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script process: {script.filename}')
- s.record(script.title())
- s.report()
-
- def before_process_batch(self, p, **kwargs):
- s = ScriptSummary('before-process-batch')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.before_process_batch(p, *args, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script before process batch: {script.filename}')
- s.record(script.title())
- s.report()
-
- def process_batch(self, p, **kwargs):
- s = ScriptSummary('process-batch')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.process_batch(p, *args, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script process batch: {script.filename}')
- s.record(script.title())
- s.report()
-
- def postprocess(self, p, processed):
- s = ScriptSummary('postprocess')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.postprocess(p, processed, *args)
- except Exception as e:
- errors.display(e, f'Running script postprocess: {script.filename}')
- s.record(script.title())
- s.report()
-
- def postprocess_batch(self, p, images, **kwargs):
- s = ScriptSummary('postprocess-batch')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.postprocess_batch(p, *args, images=images, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script before postprocess batch: {script.filename}')
- s.record(script.title())
- s.report()
-
- def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
- s = ScriptSummary('postprocess-batch-list')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.postprocess_batch_list(p, pp, *args, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script before postprocess batch list: {script.filename}')
- s.record(script.title())
- s.report()
-
- def postprocess_image(self, p, pp: PostprocessImageArgs):
- s = ScriptSummary('postprocess-image')
- for script in self.alwayson_scripts:
- try:
- args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
- script.postprocess_image(p, pp, *args)
- except Exception as e:
- errors.display(e, f'Running script postprocess image: {script.filename}')
- s.record(script.title())
- s.report()
-
- def before_component(self, component, **kwargs):
- s = ScriptSummary('before-component')
- for script in self.scripts:
- try:
- script.before_component(component, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script before component: {script.filename}')
- s.record(script.title())
- s.report()
-
- def after_component(self, component, **kwargs):
- s = ScriptSummary('after-component')
- for script in self.scripts:
- try:
- script.after_component(component, **kwargs)
- except Exception as e:
- errors.display(e, f'Running script after component: {script.filename}')
- s.record(script.title())
- s.report()
-
- def reload_sources(self, cache):
- s = ScriptSummary('reload-sources')
- for si, script in list(enumerate(self.scripts)):
- args_from = script.args_from
- args_to = script.args_to
- filename = script.filename
- module = cache.get(filename, None)
- if module is None:
- module = script_loading.load_module(script.filename)
- cache[filename] = module
- for script_class in module.__dict__.values():
- if type(script_class) == type and issubclass(script_class, Script):
- self.scripts[si] = script_class()
- self.scripts[si].filename = filename
- self.scripts[si].args_from = args_from
- self.scripts[si].args_to = args_to
- s.record(script.title())
- s.report()
-
-
-scripts_txt2img: ScriptRunner = None
-scripts_img2img: ScriptRunner = None
-scripts_control: ScriptRunner = None
-scripts_current: ScriptRunner = None
-scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
-reload_scripts = load_scripts # compatibility alias
-
-
-def reload_script_body_only():
- cache = {}
- scripts_txt2img.reload_sources(cache)
- scripts_img2img.reload_sources(cache)
- scripts_control.reload_sources(cache)
-
-
-def add_classes_to_gradio_component(comp):
- """
- this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
- """
- elem_classes = []
- if hasattr(comp, "elem_classes"):
- elem_classes = comp.elem_classes
- if elem_classes is None:
- elem_classes = []
- comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
- if getattr(comp, 'multiselect', False):
- comp.elem_classes.append('multiselect')
-
-
-def IOComponent_init(self, *args, **kwargs):
- if scripts_current is not None:
- scripts_current.before_component(self, **kwargs)
- script_callbacks.before_component_callback(self, **kwargs)
- res = original_IOComponent_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
- add_classes_to_gradio_component(self)
- script_callbacks.after_component_callback(self, **kwargs)
- if scripts_current is not None:
- scripts_current.after_component(self, **kwargs)
- return res
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-gr.components.IOComponent.__init__ = IOComponent_init
-
-
-def BlockContext_init(self, *args, **kwargs):
- res = original_BlockContext_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
- add_classes_to_gradio_component(self)
- return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init
+import os
+import re
+import sys
+import time
+from collections import namedtuple
+import gradio as gr
+from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
+from installer import log
+
+
+AlwaysVisible = object()
+time_component = {}
+time_setup = {}
+debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
+
+
+class PostprocessImageArgs:
+ def __init__(self, image):
+ self.image = image
+
+
+class PostprocessBatchListArgs:
+ def __init__(self, images):
+ self.images = images
+
+
+class Script:
+ parent = None
+ name = None
+ filename = None
+ args_from = None
+ args_to = None
+ alwayson = False
+ is_txt2img = False
+ is_img2img = False
+ api_info = None
+ group = None
+ infotext_fields = None
+ paste_field_names = None
+ section = None
+
+ def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+ raise NotImplementedError
+
+ def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+ Values of those returned components will be passed to run() and process() functions.
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def show(self, is_img2img): # pylint: disable=unused-argument
+ """
+ is_img2img is True if this function is called for the img2img interface, and False otherwise
+ This function should return:
+ - False if the script should not be shown in UI at all
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+ return True
+
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+ It must do all processing and return the Processed object with results, same as
+ one returned by processing.process_images.
+ Usually the processing is done by calling the processing.process_images function.
+ args contains all values returned by components from ui()
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def setup(self, p, *args):
+ """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
+ args contains all values returned by components from ui().
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def before_process(self, p, *args):
+ """
+ This function is called very early, before processing begins, for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def process(self, p, *args):
+ """
+ This function is called before processing begins for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def before_process_batch(self, p, *args, **kwargs):
+ """
+ Called before extra networks are parsed from the prompt, so you can add
+ new extra network keywords to the prompt with this callback.
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def process_batch(self, p, *args, **kwargs):
+ """
+ Same as process(), but called for every batch.
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
+ """
+ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
+ This is useful when you want to update the entire batch instead of individual images.
+ You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
+ If the number of images is different from the batch size when returning,
+ then the script has the responsibility to also update the following attributes in the processing object (p):
+ - p.prompts
+ - p.negative_prompts
+ - p.seeds
+ - p.subseeds
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def before_component(self, component, **kwargs):
+ """
+ Called before a component is created.
+ Use elem_id/label fields of kwargs to figure out which component it is.
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def after_component(self, component, **kwargs):
+ """
+ Called after a component is created. Same as above.
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def describe(self):
+ """unused"""
+ return ""
+
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
+ return f'script_{self.parent}_{title}_{item_id}'
+
+
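The Script base class above is the extension API that ScriptRunner drives. As a rough sketch (not part of this change), a selectable script built on it might look like the following; the class name, label and slider are illustrative placeholders, and it assumes the usual webui environment where modules.scripts and modules.processing are importable and the file lives in an extension's scripts/ directory:

```python
import gradio as gr
from modules import scripts, processing

class ExampleScript(scripts.Script):
    def title(self):
        # shown in the Script dropdown and used as the per_script_args key
        return "Example script"

    def show(self, is_img2img):
        # True: selectable from the dropdown; scripts.AlwaysVisible: always-on
        return True

    def ui(self, is_img2img):
        strength = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Strength",
                             elem_id=self.elem_id("strength"))
        # returned components are flattened into the args_from:args_to slice for run()/process()
        return [strength]

    def run(self, p, strength):
        # a selectable script does the whole job and returns the Processed result
        p.extra_generation_params["Example strength"] = strength
        return processing.process_images(p)
```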
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
+scripts_data = []
+postprocessing_scripts_data = []
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
+
+
+def list_scripts(scriptdirname, extension):
+ tmp_list = []
+ base = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(base):
+ for filename in sorted(os.listdir(base)):
+ tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
+ for ext in extensions.active():
+ tmp_list += ext.list_files(scriptdirname, extension)
+ priority_list = []
+ for script in tmp_list:
+ if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
+ if script.basedir == paths.script_path:
+ priority = '0'
+ elif script.basedir.startswith(os.path.join(paths.script_path, 'scripts')):
+ priority = '1'
+ elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions-builtin')):
+ priority = '2'
+ elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions')):
+ priority = '3'
+ else:
+ priority = '9'
+ if os.path.isfile(os.path.join(base, "..", ".priority")):
+ with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f:
+ priority = priority + str(f.read().strip())
+ log.debug(f'Script priority override: {script.filename}:{priority}')
+ else:
+ priority = priority + script.priority
+ priority_list.append(ScriptFile(script.basedir, script.filename, script.path, priority))
+ debug(f'Adding script: {script.basedir} {script.filename} {script.path} {priority}')
+ priority_sort = sorted(priority_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
+ return priority_sort
+
+
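The ordering in list_scripts() is purely lexicographic: a category digit ('0' for the core scripts directory, '1'-'3' for bundled and extension scripts, '9' otherwise) is concatenated with the per-file priority string ('50' by default, or the contents of an extension's .priority file), and entries are sorted by that string plus the lowercased path. A standalone sketch of that ordering, with made-up paths:

```python
# standalone sketch of the string-based ordering used above; paths are illustrative
entries = [
    ("350", "/webui/extensions/foo/scripts/a.py"),          # extension, default priority 50
    ("250", "/webui/extensions-builtin/bar/scripts/b.py"),  # builtin extension
    ("050", "/webui/scripts/c.py"),                         # core scripts directory
    ("310", "/webui/extensions/zzz/scripts/d.py"),          # extension with a .priority file of "10"
]
for priority, path in sorted(entries, key=lambda item: item[0] + item[1].lower()):
    print(priority, path)
# core first, then builtin extensions, then extensions ordered by their priority string
```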
+def list_files_with_name(filename):
+ res = []
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
+ continue
+ path = os.path.join(dirpath, filename)
+ if os.path.isfile(path):
+ res.append(path)
+ return res
+
+
+def load_scripts():
+ t = timer.Timer()
+ t0 = time.time()
+ global current_basedir # pylint: disable=global-statement
+ scripts_data.clear()
+ postprocessing_scripts_data.clear()
+ script_callbacks.clear_callbacks()
+ scripts_list = list_scripts("scripts", ".py")
+ syspath = sys.path
+
+ def register_scripts_from_module(module, scriptfile):
+ for script_class in module.__dict__.values():
+ if type(script_class) != type:
+ continue
+ debug(f'Registering script: {scriptfile.path}')
+ if issubclass(script_class, Script):
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
+ for scriptfile in scripts_list:
+ try:
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
+ script_module = script_loading.load_module(scriptfile.path)
+ register_scripts_from_module(script_module, scriptfile)
+ except Exception as e:
+ errors.display(e, f'Load script: {scriptfile.filename}')
+ finally:
+ current_basedir = paths.script_path
+ t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename)
+ sys.path = syspath
+ global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
+ scripts_control = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+ return t, time.time()-t0
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+ try:
+ res = func(*args, **kwargs)
+ return res
+ except Exception as e:
+ errors.display(e, f'Calling script: {filename}/{funcname}')
+ return default
+
+
+class ScriptSummary:
+ def __init__(self, op):
+ self.start = time.time()
+ self.update = time.time()
+ self.op = op
+ self.time = {}
+
+ def record(self, script):
+ self.time[script] = round(time.time() - self.update, 2)
+ self.update = time.time()
+
+ def report(self):
+ total = sum(self.time.values())
+ if total == 0:
+ return
+ scripts = [f'{k}:{v}' for k, v in self.time.items() if v > 0]
+ log.debug(f'Script: op={self.op} total={total} scripts={scripts}')
+
+
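ScriptSummary is the lightweight per-hook profiler used by every ScriptRunner callback below: record() stores the time elapsed since the previous record() (or since construction) and report() logs only scripts that took a measurable amount of time. A minimal usage sketch, assuming the class above is in scope inside the webui environment:

```python
import time

summary = ScriptSummary('demo')   # the op name appears in the log line
time.sleep(0.10)
summary.record('first-script')    # ~0.10s since construction
time.sleep(0.05)
summary.record('second-script')   # ~0.05s since the previous record()
summary.report()                  # logs something like: op=demo total=0.15 scripts=['first-script:0.1', ...]
```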
+class ScriptRunner:
+ def __init__(self):
+ self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
+ self.titles = []
+ self.infotext_fields = []
+ self.paste_field_names = []
+ self.script_load_ctr = 0
+ self.is_img2img = False
+ self.inputs = [None]
+
+ def initialize_scripts(self, is_img2img):
+ from modules import scripts_auto_postprocessing
+
+ self.scripts.clear()
+ self.selectable_scripts.clear()
+ self.alwayson_scripts.clear()
+ self.titles.clear()
+ self.infotext_fields.clear()
+ self.paste_field_names.clear()
+ self.script_load_ctr = 0
+ self.is_img2img = is_img2img
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+
+ for script_class, path, _basedir, _script_module in auto_processing_scripts + scripts_data:
+ try:
+ script = script_class()
+ script.filename = path
+ script.is_txt2img = not is_img2img
+ script.is_img2img = is_img2img
+ visibility = script.show(script.is_img2img)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
+ except Exception as e:
+ log.error(f'Script initialize: {path} {e}')
+
+ """
+ def create_script_ui(self, script):
+ import modules.api.models as api_models
+ script.args_from = len(self.inputs)
+ script.args_to = len(self.inputs)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ if controls is None:
+ return
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
+ for control in controls:
+ if not isinstance(control, gr.components.IOComponent):
+ log.error(f'Invalid script control: "{script.filename}" control={control}')
+ continue
+ control.custom_script_source = os.path.basename(script.filename)
+ arg_info = api_models.ScriptArg(label=control.label or "")
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
+ api_args.append(arg_info)
+ script.api_info = api_models.ScriptInfo(name=script.name, is_img2img=script.is_img2img, is_alwayson=script.alwayson, args=api_args)
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+ self.inputs += controls
+ script.args_to = len(self.inputs)
+
+ def setup_ui_for_section(self, section, scriptlist=None):
+ if scriptlist is None:
+ scriptlist = self.alwayson_scripts
+ for script in scriptlist:
+ if script.alwayson and script.section != section:
+ continue
+ if script.create_group:
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
+ script.group = group
+ else:
+ self.create_script_ui(script)
+ """
+
+ def prepare_ui(self):
+ self.inputs = [None]
+
+ def setup_ui(self, parent='unknown', accordion=True):
+ import modules.api.models as api_models
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+ inputs = []
+ inputs_alwayson = [True]
+
+ def create_script_ui(script: Script, inputs, inputs_alwayson):
+ script.parent = parent
+ script.args_from = len(inputs)
+ script.args_to = len(inputs)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ if controls is None:
+ return
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
+ for control in controls:
+ debug(f'Script control: parent={script.parent} script="{script.name}" label="{control.label}" type={control} id={control.elem_id}')
+ if not isinstance(control, gr.components.IOComponent):
+ log.error(f'Invalid script control: "{script.filename}" control={control}')
+ continue
+ control.custom_script_source = os.path.basename(script.filename)
+ arg_info = api_models.ScriptArg(label=control.label or "")
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
+ api_args.append(arg_info)
+
+ script.api_info = api_models.ScriptInfo(
+ name=script.name,
+ is_img2img=script.is_img2img,
+ is_alwayson=script.alwayson,
+ args=api_args,
+ )
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+ inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
+ script.args_to = len(inputs)
+
+ dropdown = gr.Dropdown(label="Script", elem_id=f'{parent}_script_list', choices=["None"] + self.titles, value="None", type="index")
+ inputs.insert(0, dropdown)
+ for script in self.selectable_scripts:
+ with gr.Group(visible=False) as group:
+ t0 = time.time()
+ create_script_ui(script, inputs, inputs_alwayson)
+ time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
+ script.group = group
+
+ def select_script(script_index):
+ selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
+
+ def init_field(title):
+ if title == 'None': # called when an initial value is set from ui-config.json to show script's UI components
+ return
+ script_index = self.titles.index(title)
+ self.selectable_scripts[script_index].group.visible = True
+
+ dropdown.init_field = init_field
+ dropdown.change(fn=select_script, inputs=[dropdown], outputs=[script.group for script in self.selectable_scripts])
+
+ def onload_script_visibility(params):
+ title = params.get('Script', None)
+ if title:
+ title_index = self.titles.index(title)
+ visibility = title_index == self.script_load_ctr
+ self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
+ return gr.update(visible=visibility)
+ else:
+ return gr.update(visible=False)
+
+ with gr.Accordion(label="Extensions", elem_id=f'{parent}_script_alwayson') if accordion else gr.Group():
+ for script in self.alwayson_scripts:
+ t0 = time.time()
+ with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['extension-script']) as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+ script.group = group
+ time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
+
+ self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
+ self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+ return inputs
+
+ def run(self, p, *args):
+ s = ScriptSummary('run')
+ script_index = args[0]
+ if script_index == 0:
+ return None
+ script = self.selectable_scripts[script_index-1]
+ if script is None:
+ return None
+ parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
+ processed = script.run(p, *parsed)
+ s.record(script.title())
+ s.report()
+ return processed
+
+ def before_process(self, p, **kwargs):
+ s = ScriptSummary('before-process')
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process(p, *script_args, **kwargs)
+ except Exception as e:
+ errors.display(e, f"Error running before process: {script.filename}")
+ s.record(script.title())
+ s.report()
+
+ def process(self, p, **kwargs):
+ s = ScriptSummary('process')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.process(p, *args, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script process: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def before_process_batch(self, p, **kwargs):
+ s = ScriptSummary('before-process-batch')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.before_process_batch(p, *args, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script before process batch: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def process_batch(self, p, **kwargs):
+ s = ScriptSummary('process-batch')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.process_batch(p, *args, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script process batch: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def postprocess(self, p, processed):
+ s = ScriptSummary('postprocess')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.postprocess(p, processed, *args)
+ except Exception as e:
+ errors.display(e, f'Running script postprocess: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def postprocess_batch(self, p, images, **kwargs):
+ s = ScriptSummary('postprocess-batch')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.postprocess_batch(p, *args, images=images, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script postprocess batch: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+ s = ScriptSummary('postprocess-batch-list')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.postprocess_batch_list(p, pp, *args, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script postprocess batch list: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
+ s = ScriptSummary('postprocess-image')
+ for script in self.alwayson_scripts:
+ try:
+ args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
+ script.postprocess_image(p, pp, *args)
+ except Exception as e:
+ errors.display(e, f'Running script postprocess image: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def before_component(self, component, **kwargs):
+ s = ScriptSummary('before-component')
+ for script in self.scripts:
+ try:
+ script.before_component(component, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script before component: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def after_component(self, component, **kwargs):
+ s = ScriptSummary('after-component')
+ for script in self.scripts:
+ try:
+ script.after_component(component, **kwargs)
+ except Exception as e:
+ errors.display(e, f'Running script after component: {script.filename}')
+ s.record(script.title())
+ s.report()
+
+ def reload_sources(self, cache):
+ s = ScriptSummary('reload-sources')
+ for si, script in list(enumerate(self.scripts)):
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
+ module = cache.get(filename, None)
+ if module is None:
+ module = script_loading.load_module(script.filename)
+ cache[filename] = module
+ for script_class in module.__dict__.values():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
+ s.record(script.title())
+ s.report()
+
+
+scripts_txt2img: ScriptRunner = None
+scripts_img2img: ScriptRunner = None
+scripts_control: ScriptRunner = None
+scripts_current: ScriptRunner = None
+scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
+reload_scripts = load_scripts # compatibility alias
+
+
+def reload_script_body_only():
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
+ scripts_control.reload_sources(cache)
+
+
+def add_classes_to_gradio_component(comp):
+ """
+ this adds a gradio-* class to the component for CSS styling (e.g. gradio-button for gr.Button), as well as some others
+ """
+ elem_classes = []
+ if hasattr(comp, "elem_classes"):
+ elem_classes = comp.elem_classes
+ if elem_classes is None:
+ elem_classes = []
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+ if scripts_current is not None:
+ scripts_current.before_component(self, **kwargs)
+ script_callbacks.before_component_callback(self, **kwargs)
+ res = original_IOComponent_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
+ add_classes_to_gradio_component(self)
+ script_callbacks.after_component_callback(self, **kwargs)
+ if scripts_current is not None:
+ scripts_current.after_component(self, **kwargs)
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init
+
+
+def BlockContext_init(self, *args, **kwargs):
+ res = original_BlockContext_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
+ add_classes_to_gradio_component(self)
+ return res
+
+
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+gr.blocks.BlockContext.__init__ = BlockContext_init
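
Throughout ScriptRunner, every hook resolves a script's arguments the same way: setup_ui() records args_from/args_to as offsets into the flat list of UI values, and each process*/postprocess* call slices p.script_args with those offsets (or takes a per-script override from p.per_script_args). A standalone sketch of that dispatch, using simplified stand-ins rather than the real webui objects:

```python
class DemoScript:
    def __init__(self, title, args_from, args_to):
        self.title_, self.args_from, self.args_to = title, args_from, args_to

    def title(self):
        return self.title_

    def process(self, p, *args):
        print(f"{self.title()} received {args}")

class DemoProcessing:
    def __init__(self, script_args):
        self.script_args = script_args   # index 0 is the Script dropdown selection
        self.per_script_args = {}        # optional per-script overrides, keyed by title

alwayson = [DemoScript("scale", 1, 3), DemoScript("tag", 3, 4)]
p = DemoProcessing(script_args=[0, 0.5, True, "landscape"])

for script in alwayson:
    args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
    script.process(p, *args)
# scale received (0.5, True)
# tag received ('landscape',)
```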
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
index dbb04ec9a..51d948a21 100644
--- a/modules/scripts_auto_postprocessing.py
+++ b/modules/scripts_auto_postprocessing.py
@@ -1,36 +1,36 @@
-from modules import scripts, scripts_postprocessing, shared
-
-
-class ScriptPostprocessingForMainUI(scripts.Script):
- def __init__(self, script_postproc):
- self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
- self.postprocessing_controls = None
-
- def title(self):
- return self.script.name
-
- def show(self, is_img2img): # pylint: disable=unused-argument
- return scripts.AlwaysVisible
-
- def ui(self, is_img2img): # pylint: disable=unused-argument
- self.postprocessing_controls = self.script.ui()
- return self.postprocessing_controls.values()
-
- def postprocess_image(self, p, script_pp, *args): # pylint: disable=arguments-differ
- args_dict = dict(zip(self.postprocessing_controls, args))
- pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
- pp.info = {}
- self.script.process(pp, **args_dict)
- p.extra_generation_params.update(pp.info)
- script_pp.image = pp.image
-
-
-def create_auto_preprocessing_script_data():
- res = []
- for name in shared.opts.postprocessing_enable_in_main_ui:
- script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
- if script is None:
- continue
- constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) # pylint: disable=unnecessary-lambda-assignment
- res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
- return res
+from modules import scripts, scripts_postprocessing, shared
+
+
+class ScriptPostprocessingForMainUI(scripts.Script):
+ def __init__(self, script_postproc):
+ self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
+ self.postprocessing_controls = None
+
+ def title(self):
+ return self.script.name
+
+ def show(self, is_img2img): # pylint: disable=unused-argument
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img): # pylint: disable=unused-argument
+ self.postprocessing_controls = self.script.ui()
+ return self.postprocessing_controls.values()
+
+ def postprocess_image(self, p, script_pp, *args): # pylint: disable=arguments-differ
+ args_dict = dict(zip(self.postprocessing_controls, args))
+ pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
+ pp.info = {}
+ self.script.process(pp, **args_dict)
+ p.extra_generation_params.update(pp.info)
+ script_pp.image = pp.image
+
+
+def create_auto_preprocessing_script_data():
+ res = []
+ for name in shared.opts.postprocessing_enable_in_main_ui:
+ script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
+ if script is None:
+ continue
+ constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) # pylint: disable=unnecessary-lambda-assignment
+ res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
+ return res
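
The constructor lambda in create_auto_preprocessing_script_data() binds the loop variable through a default argument (s=script); without it, every constructor would capture the last script in the loop. A standalone sketch of why that matters:

```python
# late binding: all three closures see the final value of `name`
late = [lambda: name for name in ("a", "b", "c")]
print([f() for f in late])      # ['c', 'c', 'c']

# default-argument binding, as used above: each closure keeps its own value
bound = [lambda n=name: n for name in ("a", "b", "c")]
print([f() for f in bound])     # ['a', 'b', 'c']
```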
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index f4825da17..3bb92b4b5 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -1,134 +1,134 @@
-import os
-import gradio as gr
-from modules import errors, shared
-
-
-class PostprocessedImage:
- def __init__(self, image):
- self.image = image
- self.info = {}
-
-
-class ScriptPostprocessing:
- filename = None
- controls = None
- args_from = None
- args_to = None
- order = 1000 # scripts will be ordred by this value in postprocessing UI
- name = None # this function should return the title of the script
- group = None # A gr.Group component that has all script's UI inside it
-
- def ui(self):
- """
- This function should create gradio UI elements. See https://gradio.app/docs/#components
- The return value should be a dictionary that maps parameter names to components used in processing.
- Values of those components will be passed to process() function.
- """
- pass # pylint: disable=unnecessary-pass
-
- def process(self, pp: PostprocessedImage, **args):
- """
- This function is called to postprocess the image.
- args contains a dictionary with all values returned by components from ui()
- """
- pass # pylint: disable=unnecessary-pass
-
- def image_changed(self):
- pass
-
-
-def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
- try:
- res = func(*args, **kwargs)
- return res
- except Exception as e:
- errors.display(e, f"calling {filename}/{funcname}")
-
- return default
-
-
-class ScriptPostprocessingRunner:
- def __init__(self):
- self.scripts = None
- self.ui_created = False
-
- def initialize_scripts(self, scripts_data):
- self.scripts = []
- for script_class, path, _basedir, _script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
- if script.name == "Simple Upscale":
- continue
- self.scripts.append(script)
-
- def create_script_ui(self, script, inputs):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
- script.controls = wrap_call(script.ui, script.filename, "ui")
- for control in script.controls.values():
- control.custom_script_source = os.path.basename(script.filename)
- inputs += list(script.controls.values())
- script.args_to = len(inputs)
-
- def scripts_in_preferred_order(self):
- if self.scripts is None:
- import modules.scripts
- self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
- scripts_order = shared.opts.postprocessing_operation_order
-
- def script_score(name):
- for i, possible_match in enumerate(scripts_order):
- if possible_match == name:
- return i
- return len(self.scripts)
-
- script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
- return sorted(self.scripts, key=lambda x: script_scores[x.name])
-
- def setup_ui(self):
- inputs = []
- for script in self.scripts_in_preferred_order():
- with gr.Accordion(label=script.name, open=False, elem_classes=['postprocess']) as group:
- self.create_script_ui(script, inputs)
- script.group = group
- self.ui_created = True
- return inputs
-
- def run(self, pp: PostprocessedImage, args):
- for script in self.scripts_in_preferred_order():
- shared.state.job = script.name
- script_args = args[script.args_from:script.args_to]
- process_args = {}
- for (name, _component), value in zip(script.controls.items(), script_args):
- process_args[name] = value
- shared.log.debug(f'Process: script={script.name} args={process_args}')
- script.process(pp, **process_args)
-
- def create_args_for_run(self, scripts_args):
- if not self.ui_created:
- with gr.Blocks(analytics_enabled=False):
- self.setup_ui()
- scripts = self.scripts_in_preferred_order()
- args = [None] * max([x.args_to for x in scripts])
- for script in scripts:
- script_args_dict = scripts_args.get(script.name, None)
- if script_args_dict is not None:
- for i, name in enumerate(script.controls):
- args[script.args_from + i] = script_args_dict.get(name, None)
- return args
-
- def image_changed(self):
- for script in self.scripts_in_preferred_order():
- script.image_changed()
-
- def postprocess(self, filenames, args):
- for script in self.scripts_in_preferred_order():
- if not hasattr(script, 'postprocess'):
- continue
- shared.state.job = script.name
- script_args = args[script.args_from:script.args_to]
- process_args = {}
- for (name, _component), value in zip(script.controls.items(), script_args):
- process_args[name] = value
- shared.log.debug(f'Postprocess: script={script.name} args={process_args}')
- script.postprocess(filenames, **process_args)
+import os
+import gradio as gr
+from modules import errors, shared
+
+
+class PostprocessedImage:
+ def __init__(self, image):
+ self.image = image
+ self.info = {}
+
+
+class ScriptPostprocessing:
+ filename = None
+ controls = None
+ args_from = None
+ args_to = None
+ order = 1000 # scripts will be ordered by this value in the postprocessing UI
+ name = None # the title of the script, shown in the postprocessing UI
+ group = None # A gr.Group component that has all script's UI inside it
+
+ def ui(self):
+ """
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be a dictionary that maps parameter names to components used in processing.
+ Values of those components will be passed to process() function.
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def process(self, pp: PostprocessedImage, **args):
+ """
+ This function is called to postprocess the image.
+ args contains a dictionary with all values returned by components from ui()
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def image_changed(self):
+ pass
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+ try:
+ res = func(*args, **kwargs)
+ return res
+ except Exception as e:
+ errors.display(e, f"calling {filename}/{funcname}")
+
+ return default
+
+
+class ScriptPostprocessingRunner:
+ def __init__(self):
+ self.scripts = None
+ self.ui_created = False
+
+ def initialize_scripts(self, scripts_data):
+ self.scripts = []
+ for script_class, path, _basedir, _script_module in scripts_data:
+ script: ScriptPostprocessing = script_class()
+ script.filename = path
+ if script.name == "Simple Upscale":
+ continue
+ self.scripts.append(script)
+
+ def create_script_ui(self, script, inputs):
+ script.args_from = len(inputs)
+ script.args_to = len(inputs)
+ script.controls = wrap_call(script.ui, script.filename, "ui")
+ for control in script.controls.values():
+ control.custom_script_source = os.path.basename(script.filename)
+ inputs += list(script.controls.values())
+ script.args_to = len(inputs)
+
+ def scripts_in_preferred_order(self):
+ if self.scripts is None:
+ import modules.scripts
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
+ scripts_order = shared.opts.postprocessing_operation_order
+
+ def script_score(name):
+ for i, possible_match in enumerate(scripts_order):
+ if possible_match == name:
+ return i
+ return len(self.scripts)
+
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
+
+ def setup_ui(self):
+ inputs = []
+ for script in self.scripts_in_preferred_order():
+ with gr.Accordion(label=script.name, open=False, elem_classes=['postprocess']) as group:
+ self.create_script_ui(script, inputs)
+ script.group = group
+ self.ui_created = True
+ return inputs
+
+ def run(self, pp: PostprocessedImage, args):
+ for script in self.scripts_in_preferred_order():
+ shared.state.job = script.name
+ script_args = args[script.args_from:script.args_to]
+ process_args = {}
+ for (name, _component), value in zip(script.controls.items(), script_args):
+ process_args[name] = value
+ shared.log.debug(f'Process: script={script.name} args={process_args}')
+ script.process(pp, **process_args)
+
+ def create_args_for_run(self, scripts_args):
+ if not self.ui_created:
+ with gr.Blocks(analytics_enabled=False):
+ self.setup_ui()
+ scripts = self.scripts_in_preferred_order()
+ args = [None] * max([x.args_to for x in scripts])
+ for script in scripts:
+ script_args_dict = scripts_args.get(script.name, None)
+ if script_args_dict is not None:
+ for i, name in enumerate(script.controls):
+ args[script.args_from + i] = script_args_dict.get(name, None)
+ return args
+
+ def image_changed(self):
+ for script in self.scripts_in_preferred_order():
+ script.image_changed()
+
+ def postprocess(self, filenames, args):
+ for script in self.scripts_in_preferred_order():
+ if not hasattr(script, 'postprocess'):
+ continue
+ shared.state.job = script.name
+ script_args = args[script.args_from:script.args_to]
+ process_args = {}
+ for (name, _component), value in zip(script.controls.items(), script_args):
+ process_args[name] = value
+ shared.log.debug(f'Postprocess: script={script.name} args={process_args}')
+ script.postprocess(filenames, **process_args)
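
A rough sketch of a postprocessing script that plugs into the runner above; the name, slider and rotate operation are illustrative only, and it assumes the webui environment (where pp.image is a PIL image) rather than being a definitive implementation:

```python
import gradio as gr
from modules import scripts_postprocessing

class ScriptPostprocessingExample(scripts_postprocessing.ScriptPostprocessing):
    name = "Example rotate"   # shown as the accordion label and matched against ordering options
    order = 2000              # sorts after the built-in operations unless reordered in settings

    def ui(self):
        angle = gr.Slider(minimum=0, maximum=359, value=0, label="Angle")
        # the dict keys become the keyword arguments of process()
        return {"angle": angle}

    def process(self, pp: scripts_postprocessing.PostprocessedImage, angle=0):
        if angle:
            pp.image = pp.image.rotate(angle)       # assumes a PIL image, as passed in by the runner
            pp.info["Postprocess rotate"] = angle   # merged into extra_generation_params
```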
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c30525c30..e9ac1be92 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -1,92 +1,92 @@
-import ldm.modules.encoders.modules
-import open_clip
-import torch
-import transformers.utils.hub
-
-
-class DisableInitialization:
- """
- When an object of this class enters a `with` block, it starts:
- - preventing torch's layer initialization functions from working
- - changes CLIP and OpenCLIP to not download model weights
- - changes CLIP to not make requests to check if there is a new version of a file you already have
-
- When it leaves the block, it reverts everything to how it was before.
-
- Use it like this:
- ```
- with DisableInitialization():
- do_things()
- ```
- """
-
- def __init__(self, disable_clip=True):
- self.replaced = []
- self.disable_clip = disable_clip
-
- def replace(self, obj, field, func):
- original = getattr(obj, field, None)
- if original is None:
- return None
-
- self.replaced.append((obj, field, original))
- setattr(obj, field, func)
-
- return original
-
- def __enter__(self):
- def do_nothing(*args, **kwargs): # pylint: disable=unused-argument
- pass
-
- def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): # pylint: disable=unused-argument
- return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
-
- def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
- res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
- res.name_or_path = pretrained_model_name_or_path
- return res
-
- def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
- args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
- return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
-
- def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
-
- # this file is always 404, prevent making request
- if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
- return None
-
- try:
- res = original(url, *args, local_files_only=True, **kwargs)
- if res is None:
- res = original(url, *args, local_files_only=False, **kwargs)
- return res
- except Exception:
- return original(url, *args, local_files_only=False, **kwargs)
-
- def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
- return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
-
- def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
- return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
-
- def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
- return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
-
- self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
- self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
- self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
-
- if self.disable_clip:
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) # pylint: disable=attribute-defined-outside-init
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) # pylint: disable=attribute-defined-outside-init
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) # pylint: disable=attribute-defined-outside-init
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) # pylint: disable=attribute-defined-outside-init
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) # pylint: disable=attribute-defined-outside-init
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) # pylint: disable=attribute-defined-outside-init
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- for obj, field, original in self.replaced:
- setattr(obj, field, original)
-
- self.replaced.clear()
+import ldm.modules.encoders.modules
+import open_clip
+import torch
+import transformers.utils.hub
+
+
+class DisableInitialization:
+ """
+ When an object of this class enters a `with` block, it starts:
+ - preventing torch's layer initialization functions from working
+ - changes CLIP and OpenCLIP to not download model weights
+ - changes CLIP to not make requests to check if there is a new version of a file you already have
+
+ When it leaves the block, it reverts everything to how it was before.
+
+ Use it like this:
+ ```
+ with DisableInitialization():
+ do_things()
+ ```
+ """
+
+ def __init__(self, disable_clip=True):
+ self.replaced = []
+ self.disable_clip = disable_clip
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
+ def __enter__(self):
+ def do_nothing(*args, **kwargs): # pylint: disable=unused-argument
+ pass
+
+ def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): # pylint: disable=unused-argument
+ return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
+
+ def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res.name_or_path = pretrained_model_name_or_path
+ return res
+
+ def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
+ args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
+ return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
+
+ def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
+
+ # this file is always 404, prevent making request
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
+ return None
+
+ try:
+ res = original(url, *args, local_files_only=True, **kwargs)
+ if res is None:
+ res = original(url, *args, local_files_only=False, **kwargs)
+ return res
+ except Exception:
+ return original(url, *args, local_files_only=False, **kwargs)
+
+ def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
+ return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
+
+ def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
+ return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
+
+ def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): # pylint: disable=unused-argument
+ return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
+
+ self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) # pylint: disable=attribute-defined-outside-init
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) # pylint: disable=attribute-defined-outside-init
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) # pylint: disable=attribute-defined-outside-init
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) # pylint: disable=attribute-defined-outside-init
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) # pylint: disable=attribute-defined-outside-init
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) # pylint: disable=attribute-defined-outside-init
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
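
The replace()/__exit__ pair above is a general monkeypatch-and-restore pattern: remember the original attribute on entry, swap in the stub, and put everything back on exit so the patches cannot leak. A standalone sketch of the same pattern, independent of the CLIP-specific stubs:

```python
import math

class PatchAttributes:
    """temporarily replace (obj, field) attributes, restoring them on exit"""

    def __init__(self, patches):   # patches: iterable of (obj, field, replacement)
        self.patches = list(patches)
        self.replaced = []

    def __enter__(self):
        for obj, field, func in self.patches:
            original = getattr(obj, field, None)
            if original is None:
                continue
            self.replaced.append((obj, field, original))
            setattr(obj, field, func)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)
        self.replaced.clear()

with PatchAttributes([(math, "sqrt", lambda x: 0.0)]):
    assert math.sqrt(4) == 0.0   # patched inside the block
assert math.sqrt(4) == 2.0       # original restored afterwards
```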
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 816c35d76..1dd8b41ea 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,313 +1,313 @@
-from types import MethodType, SimpleNamespace
-import io
-import contextlib
-import torch
-from torch.nn.functional import silu
-
-from modules import shared
-shared.log.debug('Importing LDM')
-stdout = io.StringIO()
-with contextlib.redirect_stdout(stdout):
- import ldm.modules.attention
- import ldm.modules.distributions.distributions
- import ldm.modules.diffusionmodules.model
- import ldm.modules.diffusionmodules.openaimodel
- import ldm.models.diffusion.ddim
- import ldm.models.diffusion.plms
- import ldm.modules.encoders.modules
-
-import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
-from modules.hypernetworks import hypernetwork
-
-attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
-diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
-diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-
-# new memory efficient cross attention blocks do not support hypernets and we already
-# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
-ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
-ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
-
-# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
-
-current_optimizer = SimpleNamespace(**{ "name": "none" })
-
-def apply_optimizations():
- undo_optimizations()
- ldm.modules.diffusionmodules.model.nonlinearity = silu
- ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
- optimization_method = None
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
- if devices.device == torch.device("cpu"):
- if shared.opts.cross_attention_optimization == "Scaled-Dot-Product":
- shared.log.warning("Cross-attention: Scaled dot product is not available on CPU")
- can_use_sdp = False
- if shared.opts.cross_attention_optimization == "xFormers":
- shared.log.warning("Cross-attention: xFormers is not available on CPU")
- shared.xformers_available = False
-
- shared.log.info(f"Cross-attention: optimization={shared.opts.cross_attention_optimization} options={shared.opts.cross_attention_options}")
- if shared.opts.cross_attention_optimization == "Disabled":
- optimization_method = 'none'
- if can_use_sdp and shared.opts.cross_attention_optimization == "Scaled-Dot-Product" and 'SDP disable memory attention' in shared.opts.cross_attention_options:
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
- optimization_method = 'sdp-no-mem'
- elif can_use_sdp and shared.opts.cross_attention_optimization == "Scaled-Dot-Product":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
- optimization_method = 'sdp'
- if shared.xformers_available and shared.opts.cross_attention_optimization == "xFormers":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
- optimization_method = 'xformers'
- if shared.opts.cross_attention_optimization == "Sub-quadratic":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
- optimization_method = 'sub-quadratic'
- if shared.opts.cross_attention_optimization == "Split attention":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'v1'
- if shared.opts.cross_attention_optimization == "InvokeAI's":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'invokeai'
- if shared.opts.cross_attention_optimization == "Doggettx's":
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- optimization_method = 'doggettx'
- current_optimizer.name = optimization_method
- return optimization_method
-
-
-def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
- ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-
-
-def fix_checkpoint():
- """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
- checkpoints to be added when not training (there's a warning)"""
- pass # pylint: disable=unnecessary-pass
-
-
-def weighted_loss(sd_model, pred, target, mean=True):
- #Calculate the weight normally, but ignore the mean
- loss = sd_model._old_get_loss(pred, target, mean=False) # pylint: disable=protected-access
-
- #Check if we have weights available
- weight = getattr(sd_model, '_custom_loss_weight', None)
- if weight is not None:
- loss *= weight
-
- #Return the loss, as mean if specified
- return loss.mean() if mean else loss
-
-def weighted_forward(sd_model, x, c, w, *args, **kwargs):
- try:
- #Temporarily append weights to a place accessible during loss calc
- sd_model._custom_loss_weight = w # pylint: disable=protected-access
-
- #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
- #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
- if not hasattr(sd_model, '_old_get_loss'):
- sd_model._old_get_loss = sd_model.get_loss # pylint: disable=protected-access
- sd_model.get_loss = MethodType(weighted_loss, sd_model)
-
- #Run the standard forward function, but with the patched 'get_loss'
- return sd_model.forward(x, c, *args, **kwargs)
- finally:
- try:
- #Delete temporary weights if appended
- del sd_model._custom_loss_weight
- except AttributeError:
- pass
-
- #If we have an old loss function, reset the loss function to the original one
- if hasattr(sd_model, '_old_get_loss'):
- sd_model.get_loss = sd_model._old_get_loss # pylint: disable=protected-access
- del sd_model._old_get_loss
-
-def apply_weighted_forward(sd_model):
- #Add new function 'weighted_forward' that can be called to calc weighted loss
- sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
-
-def undo_weighted_forward(sd_model):
- try:
- del sd_model.weighted_forward
- except AttributeError:
- pass
-
-
-class StableDiffusionModelHijack:
- fixes = None
- comments = []
- layers = None
- circular_enabled = False
- clip = None
- optimization_method = None
-
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
- def __init__(self):
- self.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
-
- def hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- model_embeddings = m.cond_stage_model.roberta.embeddings
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
- m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
- elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
- m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
- elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
- m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
- m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
- apply_weighted_forward(m)
- if m.cond_stage_key == "edit":
- sd_hijack_unet.hijack_ddpm_edit()
-
- if shared.opts.ipex_optimize and shared.backend == shared.Backend.ORIGINAL:
- try:
- import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
- m.model.training = False
- m.model = ipex.optimize(m.model, dtype=devices.dtype_unet, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init
- shared.log.info("Applied IPEX Optimize.")
- except Exception as err:
- shared.log.warning(f"IPEX Optimize not supported: {err}")
-
- if (shared.opts.cuda_compile or shared.opts.cuda_compile_vae or shared.opts.cuda_compile_upscaler) and shared.opts.cuda_compile_backend != 'none' and shared.backend == shared.Backend.ORIGINAL:
- try:
- import logging
- shared.log.info(f"Compiling pipeline={m.model.__class__.__name__} mode={shared.opts.cuda_compile_backend}")
- import torch._dynamo # pylint: disable=unused-import,redefined-outer-name
- log_level = logging.WARNING if shared.opts.cuda_compile_verbose else logging.CRITICAL # pylint: disable=protected-access
- if hasattr(torch, '_logging'):
- torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access
- torch._dynamo.config.verbose = shared.opts.cuda_compile_verbose # pylint: disable=protected-access
- torch._dynamo.config.suppress_errors = shared.opts.cuda_compile_errors # pylint: disable=protected-access
- torch.backends.cudnn.benchmark = True
- if shared.opts.cuda_compile_backend == 'hidet':
- import hidet # pylint: disable=import-error
- hidet.torch.dynamo_config.use_tensor_core(True)
- hidet.torch.dynamo_config.search_space(2)
- m.model = torch.compile(m.model, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph, dynamic=False)
- shared.log.info("Model complilation done.")
- except Exception as err:
- shared.log.warning(f"Model compile not supported: {err}")
- finally:
- from installer import setup_logging
- setup_logging()
-
- self.optimization_method = apply_optimizations()
- self.clip = m.cond_stage_model
-
- def flatten(el):
- flattened = [flatten(children) for children in el.children()]
- res = [el]
- for c in flattened:
- res += c
- return res
-
- self.layers = flatten(m)
-
- def undo_hijack(self, m):
- if not hasattr(m, 'cond_stage_model'):
- return # not ldm model
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- m.cond_stage_model = m.cond_stage_model.wrapped
- elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
- m.cond_stage_model = m.cond_stage_model.wrapped
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
- model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
- elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
- m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
- m.cond_stage_model = m.cond_stage_model.wrapped
- undo_optimizations()
- undo_weighted_forward(m)
- self.apply_circular(False)
- self.layers = None
- self.clip = None
-
- def apply_circular(self, enable):
- if self.circular_enabled == enable:
- return
- self.circular_enabled = enable
- for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
- layer.padding_mode = 'circular' if enable else 'zeros'
-
- def clear_comments(self):
- self.comments = []
-
- def get_prompt_lengths(self, text):
- if self.clip is None:
- return 0, 0
- _, token_count = self.clip.process_texts([text])
- return token_count, self.clip.get_target_prompt_token_count(token_count)
-
-
-class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
- super().__init__()
- self.wrapped = wrapped
- self.embeddings = embeddings
-
- def forward(self, input_ids):
- batch_fixes = self.embeddings.fixes
- self.embeddings.fixes = None
-
- inputs_embeds = self.wrapped(input_ids)
-
- if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
- return inputs_embeds
-
- vecs = []
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
- emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
-
- vecs.append(tensor)
-
- return torch.stack(vecs)
-
-
-def add_circular_option_to_conv_2d():
- conv2d_constructor = torch.nn.Conv2d.__init__
-
- def conv2d_constructor_circular(self, *args, **kwargs):
- return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
-
- torch.nn.Conv2d.__init__ = conv2d_constructor_circular
-
-
-model_hijack = StableDiffusionModelHijack()
-
-
-def register_buffer(self, name, attr):
- """
- Fix register buffer bug for Mac OS.
- """
-
- if type(attr) == torch.Tensor:
- if attr.device != devices.device:
- attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
-
- setattr(self, name, attr)
-
-
-ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
-ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
-
-# Ensure samping from Guassian for DDPM follows types
-ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample = lambda self: self.mean.to(self.parameters.dtype) + self.std.to(self.parameters.dtype) * torch.randn(self.mean.shape, dtype=self.parameters.dtype).to(device=self.parameters.device)
+from types import MethodType, SimpleNamespace
+import io
+import contextlib
+import torch
+from torch.nn.functional import silu
+
+from modules import shared
+shared.log.debug('Importing LDM')
+stdout = io.StringIO()
+with contextlib.redirect_stdout(stdout):
+ import ldm.modules.attention
+ import ldm.modules.distributions.distributions
+ import ldm.modules.diffusionmodules.model
+ import ldm.modules.diffusionmodules.openaimodel
+ import ldm.models.diffusion.ddim
+ import ldm.models.diffusion.plms
+ import ldm.modules.encoders.modules
+
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack_optimizations
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules.hypernetworks import hypernetwork
+
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+# new memory efficient cross attention blocks do not support hypernets and we already
+# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
+ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
+ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
+
+# silence new console spam from SD2
+ldm.modules.attention.print = lambda *args: None
+ldm.modules.diffusionmodules.model.print = lambda *args: None
+
+current_optimizer = SimpleNamespace(**{ "name": "none" })
+
+def apply_optimizations():
+ undo_optimizations()
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
+ ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ optimization_method = None
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+ if devices.device == torch.device("cpu"):
+ if shared.opts.cross_attention_optimization == "Scaled-Dot-Product":
+ shared.log.warning("Cross-attention: Scaled dot product is not available on CPU")
+ can_use_sdp = False
+ if shared.opts.cross_attention_optimization == "xFormers":
+ shared.log.warning("Cross-attention: xFormers is not available on CPU")
+ shared.xformers_available = False
+
+ shared.log.info(f"Cross-attention: optimization={shared.opts.cross_attention_optimization} options={shared.opts.cross_attention_options}")
+ if shared.opts.cross_attention_optimization == "Disabled":
+ optimization_method = 'none'
+ if can_use_sdp and shared.opts.cross_attention_optimization == "Scaled-Dot-Product" and 'SDP disable memory attention' in shared.opts.cross_attention_options:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+ optimization_method = 'sdp-no-mem'
+ elif can_use_sdp and shared.opts.cross_attention_optimization == "Scaled-Dot-Product":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+ optimization_method = 'sdp'
+ if shared.xformers_available and shared.opts.cross_attention_optimization == "xFormers":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ optimization_method = 'xformers'
+ if shared.opts.cross_attention_optimization == "Sub-quadratic":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+ optimization_method = 'sub-quadratic'
+ if shared.opts.cross_attention_optimization == "Split attention":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'v1'
+ if shared.opts.cross_attention_optimization == "InvokeAI's":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'invokeai'
+ if shared.opts.cross_attention_optimization == "Doggettx's":
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+ optimization_method = 'doggettx'
+ current_optimizer.name = optimization_method
+ return optimization_method
+
+
+def undo_optimizations():
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+def fix_checkpoint():
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
+ checkpoints to be added when not training (there's a warning)"""
+ pass # pylint: disable=unnecessary-pass
+
+
+def weighted_loss(sd_model, pred, target, mean=True):
+    #Calculate the loss normally, but without reducing it to the mean
+ loss = sd_model._old_get_loss(pred, target, mean=False) # pylint: disable=protected-access
+
+ #Check if we have weights available
+ weight = getattr(sd_model, '_custom_loss_weight', None)
+ if weight is not None:
+ loss *= weight
+
+ #Return the loss, as mean if specified
+ return loss.mean() if mean else loss
+
+def weighted_forward(sd_model, x, c, w, *args, **kwargs):
+ try:
+ #Temporarily append weights to a place accessible during loss calc
+ sd_model._custom_loss_weight = w # pylint: disable=protected-access
+
+ #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
+ #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
+ if not hasattr(sd_model, '_old_get_loss'):
+ sd_model._old_get_loss = sd_model.get_loss # pylint: disable=protected-access
+ sd_model.get_loss = MethodType(weighted_loss, sd_model)
+
+ #Run the standard forward function, but with the patched 'get_loss'
+ return sd_model.forward(x, c, *args, **kwargs)
+ finally:
+ try:
+ #Delete temporary weights if appended
+ del sd_model._custom_loss_weight
+ except AttributeError:
+ pass
+
+ #If we have an old loss function, reset the loss function to the original one
+ if hasattr(sd_model, '_old_get_loss'):
+ sd_model.get_loss = sd_model._old_get_loss # pylint: disable=protected-access
+ del sd_model._old_get_loss
+
+def apply_weighted_forward(sd_model):
+ #Add new function 'weighted_forward' that can be called to calc weighted loss
+ sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
+
+def undo_weighted_forward(sd_model):
+ try:
+ del sd_model.weighted_forward
+ except AttributeError:
+ pass
+
+
+class StableDiffusionModelHijack:
+ fixes = None
+ comments = []
+ layers = None
+ circular_enabled = False
+ clip = None
+ optimization_method = None
+
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
+
+ def __init__(self):
+ self.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
+
+ def hijack(self, m):
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ model_embeddings = m.cond_stage_model.roberta.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
+ m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
+ m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+ apply_weighted_forward(m)
+ if m.cond_stage_key == "edit":
+ sd_hijack_unet.hijack_ddpm_edit()
+
+ if shared.opts.ipex_optimize and shared.backend == shared.Backend.ORIGINAL:
+ try:
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+ m.model.training = False
+ m.model = ipex.optimize(m.model, dtype=devices.dtype_unet, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init
+ shared.log.info("Applied IPEX Optimize.")
+ except Exception as err:
+ shared.log.warning(f"IPEX Optimize not supported: {err}")
+
+ if (shared.opts.cuda_compile or shared.opts.cuda_compile_vae or shared.opts.cuda_compile_upscaler) and shared.opts.cuda_compile_backend != 'none' and shared.backend == shared.Backend.ORIGINAL:
+ try:
+ import logging
+                shared.log.info(f"Compiling pipeline={m.model.__class__.__name__} backend={shared.opts.cuda_compile_backend}")
+ import torch._dynamo # pylint: disable=unused-import,redefined-outer-name
+ log_level = logging.WARNING if shared.opts.cuda_compile_verbose else logging.CRITICAL # pylint: disable=protected-access
+ if hasattr(torch, '_logging'):
+ torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access
+ torch._dynamo.config.verbose = shared.opts.cuda_compile_verbose # pylint: disable=protected-access
+ torch._dynamo.config.suppress_errors = shared.opts.cuda_compile_errors # pylint: disable=protected-access
+ torch.backends.cudnn.benchmark = True
+ if shared.opts.cuda_compile_backend == 'hidet':
+ import hidet # pylint: disable=import-error
+ hidet.torch.dynamo_config.use_tensor_core(True)
+ hidet.torch.dynamo_config.search_space(2)
+ m.model = torch.compile(m.model, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph, dynamic=False)
+                shared.log.info("Model compilation done.")
+ except Exception as err:
+ shared.log.warning(f"Model compile not supported: {err}")
+ finally:
+ from installer import setup_logging
+ setup_logging()
+
+ self.optimization_method = apply_optimizations()
+ self.clip = m.cond_stage_model
+
+ def flatten(el):
+ flattened = [flatten(children) for children in el.children()]
+ res = [el]
+ for c in flattened:
+ res += c
+ return res
+
+ self.layers = flatten(m)
+
+ def undo_hijack(self, m):
+ if not hasattr(m, 'cond_stage_model'):
+ return # not ldm model
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+ elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
+ m.cond_stage_model = m.cond_stage_model.wrapped
+ undo_optimizations()
+ undo_weighted_forward(m)
+ self.apply_circular(False)
+ self.layers = None
+ self.clip = None
+
+ def apply_circular(self, enable):
+ if self.circular_enabled == enable:
+ return
+ self.circular_enabled = enable
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
+ layer.padding_mode = 'circular' if enable else 'zeros'
+
+ def clear_comments(self):
+ self.comments = []
+
+ def get_prompt_lengths(self, text):
+ if self.clip is None:
+ return 0, 0
+ _, token_count = self.clip.process_texts([text])
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
+
+
+class EmbeddingsWithFixes(torch.nn.Module):
+ def __init__(self, wrapped, embeddings):
+ super().__init__()
+ self.wrapped = wrapped
+ self.embeddings = embeddings
+
+ def forward(self, input_ids):
+ batch_fixes = self.embeddings.fixes
+ self.embeddings.fixes = None
+
+ inputs_embeds = self.wrapped(input_ids)
+
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+ return inputs_embeds
+
+ vecs = []
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, embedding in fixes:
+ emb = devices.cond_cast_unet(embedding.vec)
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
+
+ vecs.append(tensor)
+
+ return torch.stack(vecs)
+
+
+def add_circular_option_to_conv_2d():
+ conv2d_constructor = torch.nn.Conv2d.__init__
+
+ def conv2d_constructor_circular(self, *args, **kwargs):
+ return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
+
+ torch.nn.Conv2d.__init__ = conv2d_constructor_circular
+
+
+model_hijack = StableDiffusionModelHijack()
+
+
+def register_buffer(self, name, attr):
+ """
+ Fix register buffer bug for Mac OS.
+ """
+
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
+
+ setattr(self, name, attr)
+
+
+ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
+ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
+
+# Ensure sampling from the Gaussian for DDPM follows the parameters' dtype
+ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample = lambda self: self.mean.to(self.parameters.dtype) + self.std.to(self.parameters.dtype) * torch.randn(self.mean.shape, dtype=self.parameters.dtype).to(device=self.parameters.device)
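A minimal stand-alone sketch of the MethodType re-binding used by weighted_forward() and weighted_loss() above, so the patch/restore behaviour can be followed in isolation. ToyModel and _weighted_loss are hypothetical stand-ins, not part of this patch; only the bind/restore mechanics mirror the code in modules/sd_hijack.py.

from types import MethodType
import torch

class ToyModel:
    def get_loss(self, pred, target, mean=True):
        loss = (pred - target) ** 2                        # plain squared error per element
        return loss.mean() if mean else loss

    def forward(self, pred, target):
        return self.get_loss(pred, target)                 # the real LDM forward also calls self.get_loss()

def _weighted_loss(model, pred, target, mean=True):
    loss = model._old_get_loss(pred, target, mean=False)   # un-reduced loss from the saved original
    w = getattr(model, '_custom_loss_weight', None)
    if w is not None:
        loss = loss * w
    return loss.mean() if mean else loss

model = ToyModel()
pred, target = torch.ones(4), torch.zeros(4)

# patch: the same three steps weighted_forward() performs
model._custom_loss_weight = torch.tensor([1.0, 1.0, 2.0, 2.0])
model._old_get_loss = model.get_loss
model.get_loss = MethodType(_weighted_loss, model)          # bound to this instance only
print(model.forward(pred, target))                          # tensor(1.5000) - weighted mean

# restore: what the finally: block in weighted_forward() does
del model._custom_loss_weight
model.get_loss = model._old_get_loss
del model._old_get_loss
print(model.forward(pred, target))                          # tensor(1.) - plain mean again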
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 14510a689..c8fe52885 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,256 +1,256 @@
-import math
-from collections import namedtuple
-import torch
-from modules import prompt_parser, devices, sd_hijack
-from modules.shared import opts
-
-
-class PromptChunk:
- """
- This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
- If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
- Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
- so just 75 tokens from prompt.
- """
- def __init__(self):
- self.tokens = []
- self.multipliers = []
- self.fixes = []
-
-
-PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
-"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
-chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
-are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
-
-
-class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
- """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
- have unlimited prompt length and assign weights to tokens in prompt.
- """
- def __init__(self, wrapped, hijack):
- super().__init__()
- self.wrapped = wrapped
- """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
- depending on model."""
- self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
- self.chunk_length = 75
-
- def empty_chunk(self):
- """creates an empty PromptChunk and returns it"""
- chunk = PromptChunk()
- chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
- chunk.multipliers = [1.0] * (self.chunk_length + 2)
- return chunk
-
- def get_target_prompt_token_count(self, token_count):
- """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
- return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
-
- def tokenize(self, texts):
- """Converts a batch of texts into a batch of token ids"""
- raise NotImplementedError
-
- def encode_with_transformers(self, tokens):
- """
- converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
- All python lists with tokens are assumed to have same length, usually 77.
- if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
- model - can be 768 and 1024.
- Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
- """
- raise NotImplementedError
-
- def encode_embedding_init_text(self, init_text, nvpt):
- """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
- transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
- raise NotImplementedError
-
- def tokenize_line(self, line):
- """
- this transforms a single prompt into a list of PromptChunk objects - as many as needed to
- represent the prompt.
- Returns the list and the total number of tokens in the prompt.
- """
- parsed = prompt_parser.parse_prompt_attention(line)
- tokenized = self.tokenize([text for text, _ in parsed])
- chunks = []
- chunk = PromptChunk()
- token_count = 0
- last_comma = -1
-
- def next_chunk(is_last=False):
- """puts current chunk into the list of results and produces the next one - empty;
- if is_last is true, tokens tokens at the end won't add to token_count"""
- nonlocal token_count
- nonlocal last_comma
- nonlocal chunk
- if is_last:
- token_count += len(chunk.tokens)
- else:
- token_count += self.chunk_length
- to_add = self.chunk_length - len(chunk.tokens)
- if to_add > 0:
- chunk.tokens += [self.id_end] * to_add
- chunk.multipliers += [1.0] * to_add
- chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
- chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
- last_comma = -1
- chunks.append(chunk)
- chunk = PromptChunk()
-
- for tokens, (text, weight) in zip(tokenized, parsed):
- if text == 'BREAK' and weight == -1:
- next_chunk()
- continue
- position = 0
- while position < len(tokens):
- token = tokens[position]
- if token == self.comma_token:
- last_comma = len(chunk.tokens)
- # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
- # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
- elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
- break_location = last_comma + 1
- reloc_tokens = chunk.tokens[break_location:]
- reloc_mults = chunk.multipliers[break_location:]
- chunk.tokens = chunk.tokens[:break_location]
- chunk.multipliers = chunk.multipliers[:break_location]
- next_chunk()
- chunk.tokens = reloc_tokens
- chunk.multipliers = reloc_mults
- if len(chunk.tokens) == self.chunk_length:
- next_chunk()
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
- if embedding is None:
- chunk.tokens.append(token)
- chunk.multipliers.append(weight)
- position += 1
- continue
- emb_len = int(embedding.vec.shape[0])
- if len(chunk.tokens) + emb_len > self.chunk_length:
- next_chunk()
- chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
- chunk.tokens += [0] * emb_len
- chunk.multipliers += [weight] * emb_len
- position += embedding_length_in_tokens
- if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk(is_last=True)
- return chunks, token_count
-
- def process_texts(self, texts):
- """
- Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
- length, in tokens, of all texts.
- """
- token_count = 0
- cache = {}
- batch_chunks = []
- for line in texts:
- if line in cache:
- chunks = cache[line]
- else:
- chunks, current_token_count = self.tokenize_line(line)
- token_count = max(current_token_count, token_count)
- cache[line] = chunks
- batch_chunks.append(chunks)
- return batch_chunks, token_count
-
- def forward(self, texts):
- """
- Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
- Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
- An example shape returned by this function can be: (2, 77, 768).
- Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
- is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
- """
- batch_chunks, _token_count = self.process_texts(texts)
- used_embeddings = {}
- chunk_count = max([len(x) for x in batch_chunks])
- zs = []
- for i in range(chunk_count):
- batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
- tokens = [x.tokens for x in batch_chunk]
- multipliers = [x.multipliers for x in batch_chunk]
- self.hijack.fixes = [x.fixes for x in batch_chunk]
- for fixes in self.hijack.fixes:
- for _position, embedding in fixes:
- used_embeddings[embedding.name] = embedding
- z = self.process_tokens(tokens, multipliers)
- zs.append(z)
- self.hijack.embedding_db.embeddings_used = list(used_embeddings)
- return torch.hstack(zs)
-
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- """
- sends one single prompt chunk to be encoded by transformers neural network.
- remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
- there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
- Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
- corresponds to one token.
- """
- tokens = torch.asarray(remade_batch_tokens).to(devices.device)
- # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
- if self.id_end != self.id_pad:
- for batch_pos in range(len(remade_batch_tokens)):
- index = remade_batch_tokens[batch_pos].index(self.id_end)
- tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
-
- z = self.encode_with_transformers(tokens)
- # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
- if opts.prompt_mean_norm:
- original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
- new_mean = z.mean()
- z = z * (original_mean / new_mean)
- else:
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
- return z
-
-
-class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
- self.tokenizer = wrapped.tokenizer
- vocab = self.tokenizer.get_vocab()
-        self.comma_token = vocab.get(',</w>', None)
- self.token_mults = {}
- tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
- for text, ident in tokens_with_parens:
- mult = 1.0
- for c in text:
- if c == '[':
- mult /= 1.1
- if c == ']':
- mult *= 1.1
- if c == '(':
- mult *= 1.1
- if c == ')':
- mult /= 1.1
- if mult != 1.0:
- self.token_mults[ident] = mult
- self.id_start = self.wrapped.tokenizer.bos_token_id
- self.id_end = self.wrapped.tokenizer.eos_token_id
- self.id_pad = self.id_end
-
- def tokenize(self, texts):
- tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
- return tokenized
-
- def encode_with_transformers(self, tokens):
- clip_skip = opts.data['clip_skip'] or 1
- outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-clip_skip)
- if clip_skip > 1:
- z = outputs.hidden_states[-clip_skip]
- z = self.wrapped.transformer.text_model.final_layer_norm(z)
- else:
- z = outputs.last_hidden_state
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt):
- embedding_layer = self.wrapped.transformer.text_model.embeddings
- ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
- return embedded
+import math
+from collections import namedtuple
+import torch
+from modules import prompt_parser, devices, sd_hijack
+from modules.shared import opts
+
+
+class PromptChunk:
+ """
+    This object contains token ids, weights (multipliers such as 1.4) and textual inversion embedding info for a chunk of prompt.
+    If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
+    Each PromptChunk contains an exact number of tokens - 77, which includes one start and one end token,
+    so just 75 tokens come from the prompt.
+ """
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
+chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
+
+
+class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+    """A PyTorch module that wraps the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
+    have unlimited prompt length and to assign weights to tokens in the prompt.
+ """
+ def __init__(self, wrapped, hijack):
+ super().__init__()
+ self.wrapped = wrapped
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+ depending on model."""
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
+
+ def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+ raise NotImplementedError
+
+ def encode_with_transformers(self, tokens):
+ """
+        Converts a batch of token ids (in python lists) into a single tensor with the numeric representation of those tokens;
+        all python lists with tokens are assumed to have the same length, usually 77.
+        If the input is a list with B elements and each element has T tokens, the expected output shape is (B, T, C), where C depends on
+        the model - it can be 768 or 1024.
+        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
+ """
+ raise NotImplementedError
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+        the transformer. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
+ raise NotImplementedError
+
+ def tokenize_line(self, line):
+ """
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+ parsed = prompt_parser.parse_prompt_attention(line)
+ tokenized = self.tokenize([text for text, _ in parsed])
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
+ last_comma = -1
+
+ def next_chunk(is_last=False):
+            """puts the current chunk into the list of results and produces the next one - empty;
+            if is_last is true, padding tokens at the end won't add to token_count"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+ if is_last:
+ token_count += len(chunk.tokens)
+ else:
+ token_count += self.chunk_length
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ if text == 'BREAK' and weight == -1:
+ next_chunk()
+ continue
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
+ if token == self.comma_token:
+ last_comma = len(chunk.tokens)
+                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
+ if embedding is None:
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
+ next_chunk(is_last=True)
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+ token_count = 0
+ cache = {}
+ batch_chunks = []
+ for line in texts:
+ if line in cache:
+ chunks = cache[line]
+ else:
+ chunks, current_token_count = self.tokenize_line(line)
+ token_count = max(current_token_count, token_count)
+ cache[line] = chunks
+ batch_chunks.append(chunks)
+ return batch_chunks, token_count
+
+ def forward(self, texts):
+ """
+        Accepts an array of texts; passes the texts through the transformer network to create a tensor with the numerical representation of those texts.
+        Returns a tensor with shape of (B, T, C), where B is the length of the array; T is the length, in tokens, of the texts (including padding) - T will
+        be a multiple of 77; and C is the dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        An example shape returned by this function can be: (2, 77, 768).
+        Webui usually sends just one text at a time through this function - the only time texts is an array with more than one element
+        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
+ batch_chunks, _token_count = self.process_texts(texts)
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+ for fixes in self.hijack.fixes:
+ for _position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+ self.hijack.embedding_db.embeddings_used = list(used_embeddings)
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+        Sends a single prompt chunk to be encoded by the transformer neural network.
+        remade_batch_tokens is a batch of tokens - a list where every element is a list of tokens; usually
+        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+        Multipliers are used to give more or less weight to the outputs of the transformer network. Each multiplier
+        corresponds to one token.
+ """
+ tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
+ if self.id_end != self.id_pad:
+ for batch_pos in range(len(remade_batch_tokens)):
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
+ tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
+
+ z = self.encode_with_transformers(tokens)
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
+ if opts.prompt_mean_norm:
+ original_mean = z.mean()
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z = z * (original_mean / new_mean)
+ else:
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ return z
+
+
+class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+ self.tokenizer = wrapped.tokenizer
+ vocab = self.tokenizer.get_vocab()
+        self.comma_token = vocab.get(',</w>', None)
+ self.token_mults = {}
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ for text, ident in tokens_with_parens:
+ mult = 1.0
+ for c in text:
+ if c == '[':
+ mult /= 1.1
+ if c == ']':
+ mult *= 1.1
+ if c == '(':
+ mult *= 1.1
+ if c == ')':
+ mult /= 1.1
+ if mult != 1.0:
+ self.token_mults[ident] = mult
+ self.id_start = self.wrapped.tokenizer.bos_token_id
+ self.id_end = self.wrapped.tokenizer.eos_token_id
+ self.id_pad = self.id_end
+
+ def tokenize(self, texts):
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ clip_skip = opts.data['clip_skip'] or 1
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-clip_skip)
+ if clip_skip > 1:
+ z = outputs.hidden_states[-clip_skip]
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.transformer.text_model.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
+ return embedded
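The chunking arithmetic behind tokenize_line() and get_target_prompt_token_count() above, restated as a small stand-alone sketch (not part of this patch); the sample prompt lengths are illustrative only.

import math

CHUNK = 75  # usable tokens per chunk; +2 for the start/end tokens gives the familiar 77

def target_token_count(token_count: int) -> int:
    # same rounding as get_target_prompt_token_count(): pad up to a whole number of chunks
    return math.ceil(max(token_count, 1) / CHUNK) * CHUNK

for n in (1, 75, 76, 150, 151):
    chunks = target_token_count(n) // CHUNK
    print(f"{n} prompt tokens -> {chunks} chunk(s) -> conditioning length {chunks * (CHUNK + 2)}")
# e.g. 1 or 75 tokens -> 1 chunk -> length 77; 76 tokens already spill into a second
# chunk -> length 154, which is why forward() hstacks one 77-token z per chunk and
# the returned tensor is always (B, 77*k, C).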
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
index 21af997e9..ba6c38754 100644
--- a/modules/sd_hijack_clip_old.py
+++ b/modules/sd_hijack_clip_old.py
@@ -1,82 +1,82 @@
-from modules import sd_hijack_clip
-from modules import shared
-
-
-def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
-
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
-
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-
-def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
- self.hijack.comments.append(f"Used embeddings: {embedding_names}")
-
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
+ self.hijack.comments.append(f"Used embeddings: {embedding_names}")
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
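process_text_old() above consumes the per-token multipliers that FrozenCLIPEmbedderWithCustomWords.__init__ builds from bracket characters (see modules/sd_hijack_clip.py earlier in this patch). A stand-alone restatement of that bracket rule, not part of this patch, using made-up token strings:

def bracket_multiplier(token_text: str) -> float:
    # same per-character rule as the token_mults loop in FrozenCLIPEmbedderWithCustomWords
    mult = 1.0
    for c in token_text:
        if c == '(':
            mult *= 1.1   # each '(' emphasises by 10%
        elif c == ')':
            mult /= 1.1
        elif c == '[':
            mult /= 1.1   # each '[' de-emphasises by 10%
        elif c == ']':
            mult *= 1.1
    return mult

print(bracket_multiplier('('))   # 1.1
print(bracket_multiplier('(('))  # ~1.21
print(bracket_multiplier('['))   # ~0.909
print(bracket_multiplier('a'))   # 1.0 - no brackets, so this token never enters token_mults
# In process_text_old(), hitting a bracket token multiplies the running 'mult' (a matching
# closing token divides it back), and every token emitted afterwards records the current
# 'mult', which later scales that token's CLIP output in process_tokens().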
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index 0ff673490..3c06ed54d 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -1,29 +1,29 @@
-import open_clip.tokenizer
-import torch
-
-from modules import sd_hijack_clip, devices
-
-tokenizer = open_clip.tokenizer._tokenizer # pylint: disable=protected-access
-
-
-class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
-        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
-        self.id_start = tokenizer.encoder["<start_of_text>"]
-        self.id_end = tokenizer.encoder["<end_of_text>"]
- self.id_pad = 0
-
- def tokenize(self, texts):
- tokenized = [tokenizer.encode(text) for text in texts]
- return tokenized
-
- def encode_with_transformers(self, tokens):
- z = self.wrapped.encode_with_transformer(tokens)
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt): # pylint: disable=unused-argument
- ids = tokenizer.encode(init_text)
- ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
- embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
- return embedded
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+
+tokenizer = open_clip.tokenizer._tokenizer # pylint: disable=protected-access
+
+
+class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+ self.id_pad = 0
+
+ def tokenize(self, texts):
+ tokenized = [tokenizer.encode(text) for text in texts]
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ z = self.wrapped.encode_with_transformer(tokens)
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt): # pylint: disable=unused-argument
+ ids = tokenizer.encode(init_text)
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+ return embedded
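The wrapper above sets id_pad = 0 while the SD1 CLIP wrapper uses id_pad = id_end; that difference is what triggers the padding-rewrite branch in process_tokens(). A stand-alone walk through that branch, not part of this patch, using illustrative token ids:

import torch

id_start, id_end, id_pad = 49406, 49407, 0   # illustrative ids; what matters is id_pad != id_end
remade = [[id_start, 320, 1125, id_end, id_end, id_end, id_end]]  # one chunk, padded with id_end

tokens = torch.asarray(remade)
if id_end != id_pad:                          # same condition process_tokens() checks
    for batch_pos in range(len(remade)):
        index = remade[batch_pos].index(id_end)                 # first end-of-text token
        tokens[batch_pos, index + 1:tokens.shape[1]] = id_pad   # everything after it becomes padding
print(tokens)  # everything after the first 49407 is now 0
# With the SD1 wrapper id_pad == id_end, so this branch is skipped and the id_end
# padding produced by next_chunk() is left untouched.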
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 9a3d0d237..60a0d3f22 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,520 +1,520 @@
-from __future__ import annotations
-import sys
-import math
-import psutil
-
-import torch
-from torch import einsum
-
-from ldm.util import default
-from einops import rearrange
-
-from modules import shared, errors, devices
-from modules.hypernetworks import hypernetwork
-
-from .sub_quadratic_attention import efficient_dot_product_attention # pylint: disable=relative-beyond-top-level
-
-
-if shared.opts.cross_attention_optimization == "xFormers":
- try:
- import xformers.ops # pylint: disable=import-error
- shared.xformers_available = True
- except Exception:
- pass
-else:
- if sys.modules.get("xformers", None) is not None:
- shared.log.debug('Unloading xFormers')
- sys.modules["xformers"] = None
- sys.modules["xformers.ops"] = None
-
-
-def get_available_vram():
- if shared.device.type == 'cuda' or shared.device.type == 'xpu':
- try:
- stats = torch.cuda.memory_stats(shared.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
- except Exception:
- mem_free_total = 1024 * 1024 * 1024
- return mem_free_total
- elif shared.device.type == 'privateuseone':
- return torch.dml.mem_get_info(shared.device)[0]
- else:
- return psutil.virtual_memory().available
-
-
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None): # pylint: disable=unused-argument
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
- del context, context_k, context_v, x
-
- q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
- s2 = s1.softmax(dim=-1)
- del s1
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
- del q, k, v
-
- r1 = r1.to(dtype)
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None): # pylint: disable=unused-argument
- h = self.heads
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
-
- dtype = q_in.dtype
- if shared.opts.upcast_attn:
- q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- k_in = k_in * self.scale
- del context, x
- q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
- del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- mem_free_total = get_available_vram()
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
- del q, k, v
- r1 = r1.to(dtype)
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-
-def einsum_op_compvis(q, k, v):
- s = einsum('b i d, b j d -> b i j', q, k)
- s = s.softmax(dim=-1, dtype=s.dtype)
- return einsum('b i j, b j d -> b i d', s, v)
-
-def einsum_op_slice_0(q, k, v, slice_size):
- r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[0], slice_size):
- end = i + slice_size
- r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
- return r
-
-def einsum_op_slice_1(q, k, v, slice_size):
- r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
- return r
-
-def einsum_op_mps_v1(q, k, v):
- if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
- return einsum_op_compvis(q, k, v)
- else:
- slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
- if slice_size % 4096 == 0:
- slice_size -= 1
- return einsum_op_slice_1(q, k, v, slice_size)
-
-def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
- return einsum_op_compvis(q, k, v)
- else:
- return einsum_op_slice_0(q, k, v, 1)
-
-def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
- size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
- if size_mb <= max_tensor_mb:
- return einsum_op_compvis(q, k, v)
- div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
- if div <= q.shape[0]:
- return einsum_op_slice_0(q, k, v, q.shape[0] // div)
- return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
-
-def einsum_op_cuda(q, k, v):
- try:
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
- except Exception:
- mem_free_total = 1024 * 1024 * 1024
- # Divide factor of safety as there's copying and fragmentation
- return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-def einsum_op_dml(q, k, v):
- mem_free, mem_total = torch.dml.mem_get_info(q.device)
- mem_active = mem_total - mem_free
- mem_reserved = mem_total * 0.7
- return einsum_op_tensor_mem(q, k, v, (mem_reserved - mem_active) if mem_reserved > mem_active else 1)
-
-def einsum_op(q, k, v):
- if q.device.type == 'cuda' or q.device.type == 'xpu':
- return einsum_op_cuda(q, k, v)
-
- if q.device.type == 'mps':
- if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
- return einsum_op_mps_v1(q, k, v)
- return einsum_op_mps_v2(q, k, v)
-
- if q.device.type == 'privateuseone':
- return einsum_op_dml(q, k, v)
-
- # Smaller slices are faster due to L2/L3/SLC caches.
- # Tested on i7 with 8MB L3 cache.
- return einsum_op_tensor_mem(q, k, v, 32)
-
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # pylint: disable=unused-argument
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
- del context, context_k, context_v, x
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- k = k * self.scale
- q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
- r = einsum_op(q, k, v)
- r = r.to(dtype)
- return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
-
-# -- End of code from https://github.com/invoke-ai/InvokeAI --
-
-
-# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
-# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
- assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
-
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
- del context, context_k, context_v, x
-
- q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
- k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
- v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
-
- if q.device.type == 'mps':
- q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
-
- x = sub_quad_attention(q, k, v, q_chunk_size=shared.opts.sub_quad_q_chunk_size, kv_chunk_size=shared.opts.sub_quad_kv_chunk_size, chunk_threshold=shared.opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
-
- x = x.to(dtype)
-
- x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
-
- out_proj, dropout = self.to_out
- x = out_proj(x)
- x = dropout(x)
-
- return x
-
-def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
- bytes_per_token = torch.finfo(q.dtype).bits//8
- batch_x_heads, q_tokens, _ = q.shape
- _, k_tokens, _ = k.shape
- qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
-
- if chunk_threshold is None:
- chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
- elif chunk_threshold == 0:
- chunk_threshold_bytes = None
- else:
- chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
-
- if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
- kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
- elif kv_chunk_size_min == 0:
- kv_chunk_size_min = None
-
- if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
- # the big matmul fits into our memory limit; do everything in 1 chunk,
- # i.e. send it down the unchunked fast-path
- kv_chunk_size = k_tokens
-
- with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
- q,
- k,
- v,
- query_chunk_size=q_chunk_size,
- kv_chunk_size=kv_chunk_size,
- kv_chunk_size_min = kv_chunk_size_min,
- use_checkpoint=use_checkpoint,
- )
-
-
-def get_xformers_flash_attention_op(q, k, v):
- if 'xFormers enable flash Attention' not in shared.opts.cross_attention_options:
- return None
-
- try:
- flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp # pylint: disable=used-before-assignment
- fw, _bw = flash_attention_op
- if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
- return flash_attention_op
- except Exception as e:
- errors.display_once(e, "enabling flash attention")
-
- return None
-
-
-def xformers_attention_forward(self, x, context=None, mask=None): # pylint: disable=unused-argument
- h = self.heads
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
-
- q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v.float()
-
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
-
- out = out.to(dtype)
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- return self.to_out(out)
-
-# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
-# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
- batch_size, sequence_length, inner_dim = x.shape
-
- if mask is not None:
- mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
- mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
-
- h = self.heads
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
-
- head_dim = inner_dim // h
- q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
- k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
- v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
- del q_in, k_in, v_in
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v.float()
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- hidden_states = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
- hidden_states = hidden_states.to(dtype)
-
- # linear proj
- hidden_states = self.to_out[0](hidden_states)
- # dropout
- hidden_states = self.to_out[1](hidden_states)
- return hidden_states
-
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
- return scaled_dot_product_attention_forward(self, x, context, mask)
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- mem_free_total = get_available_vram()
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
-
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
-
- h2 = h_.reshape(b, c, h, w)
- del h_
-
- h3 = self.proj_out(h2)
- del h2
-
- h3 += x
-
- return h3
-
-def xformers_attnblock_forward(self, x):
- try:
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
- b, c, h, w = q.shape # pylint: disable=unused-variable
- q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
- q = q.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
- out = out.to(dtype)
- out = rearrange(out, 'b (h w) c -> b c h w', h=h)
- out = self.proj_out(out)
- return x + out
- except NotImplementedError:
- return cross_attention_attnblock_forward(self, x)
-
-def sdp_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
- b, c, h, w = q.shape # pylint: disable=unused-variable
-
-    # SDP optimization kernels are built for operations with multiple attention heads.
- # Four dimensional tensors are required for mem_efficient and flash attention to work.
- # We add an attention head dimension `a` to allow these kernels to be used.
- q, k, v = (rearrange(t, '(b a) c h w -> b a (h w) c', a=1) for t in (q, k, v))
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v.float()
- q = q.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
- out = out.to(dtype)
- out = rearrange(out, 'b a (h w) c -> (b a) c h w', h=h) # remove the one attention head dimension `a`
- out = self.proj_out(out)
- return x + out
-
-def sdp_no_mem_attnblock_forward(self, x):
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
- return sdp_attnblock_forward(self, x)
-
-def sub_quad_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
- b, c, h, w = q.shape # pylint: disable=unused-variable
- q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
- q = q.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- out = sub_quad_attention(q, k, v, q_chunk_size=shared.opts.sub_quad_q_chunk_size, kv_chunk_size=shared.opts.sub_quad_kv_chunk_size, chunk_threshold=shared.opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
- out = rearrange(out, 'b (h w) c -> b c h w', h=h)
- out = self.proj_out(out)
- return x + out
+from __future__ import annotations
+import sys
+import math
+import psutil
+
+import torch
+from torch import einsum
+
+from ldm.util import default
+from einops import rearrange
+
+from modules import shared, errors, devices
+from modules.hypernetworks import hypernetwork
+
+from .sub_quadratic_attention import efficient_dot_product_attention # pylint: disable=relative-beyond-top-level
+
+
+if shared.opts.cross_attention_optimization == "xFormers":
+ try:
+ import xformers.ops # pylint: disable=import-error
+ shared.xformers_available = True
+ except Exception:
+ pass
+else:
+ if sys.modules.get("xformers", None) is not None:
+ shared.log.debug('Unloading xFormers')
+ sys.modules["xformers"] = None
+ sys.modules["xformers.ops"] = None
+
+
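+# Best-effort estimate of free memory for the active device: CUDA/XPU adds free device memory
+# to memory torch has reserved but not allocated, DirectML asks torch.dml.mem_get_info, and
+# anything else falls back to free system RAM; 1 GB is assumed if the CUDA stats query fails.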
+def get_available_vram():
+ if shared.device.type == 'cuda' or shared.device.type == 'xpu':
+ try:
+ stats = torch.cuda.memory_stats(shared.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ except Exception:
+ mem_free_total = 1024 * 1024 * 1024
+ return mem_free_total
+ elif shared.device.type == 'privateuseone':
+ return torch.dml.mem_get_info(shared.device)[0]
+ else:
+ return psutil.virtual_memory().available
+
+
+# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
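+# v1 slices the (batch*heads) dimension in fixed steps of 2, so the (b*h, n, n) score tensor
+# stays small regardless of resolution; slow, but it never probes VRAM.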
+def split_cross_attention_forward_v1(self, x, context=None, mask=None): # pylint: disable=unused-argument
+ h = self.heads
+
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+ s2 = s1.softmax(dim=-1)
+ del s1
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ del q, k, v
+
+ r1 = r1.to(dtype)
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+# taken from https://github.com/Doggettx/stable-diffusion and modified
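+# Estimates the bytes needed for the full q@k^T score tensor and, when that exceeds the free
+# memory reported by get_available_vram(), splits the query tokens into the smallest power-of-two
+# number of slices that fits, erroring out with a suggested max resolution above 64 slices.
+# (Rough illustration, assuming standard SD1 shapes: self-attention at 512x512 with a cond+uncond
+# batch gives q of shape (16, 4096, 40), so the fp16 score tensor alone is about 0.5 GB before
+# the safety modifier is applied.)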
+def split_cross_attention_forward(self, x, context=None, mask=None): # pylint: disable=unused-argument
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ dtype = q_in.dtype
+ if shared.opts.upcast_attn:
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k_in = k_in * self.scale
+ del context, x
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ mem_free_total = get_available_vram()
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+ del q, k, v
+ r1 = r1.to(dtype)
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
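+# einsum_op_compvis is the plain CompVis attention; the slice_0/slice_1 helpers run it in chunks
+# along the batch*heads or the query-token dimension respectively to bound peak memory.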
+def einsum_op_compvis(q, k, v):
+ s = einsum('b i d, b j d -> b i j', q, k)
+ s = s.softmax(dim=-1, dtype=s.dtype)
+ return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], slice_size):
+ end = i + slice_size
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+ return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+ return r
+
+def einsum_op_mps_v1(q, k, v):
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
+ return einsum_op_compvis(q, k, v)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
+ return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
+ return einsum_op_compvis(q, k, v)
+ else:
+ return einsum_op_slice_0(q, k, v, 1)
+
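+# Picks a power-of-two divisor so each chunk of the q@k^T product stays under max_tensor_mb,
+# slicing the batch dimension when possible and the query-token dimension otherwise.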
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+ if size_mb <= max_tensor_mb:
+ return einsum_op_compvis(q, k, v)
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+ if div <= q.shape[0]:
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+ try:
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ except Exception:
+ mem_free_total = 1024 * 1024 * 1024
+    # Divide by a factor of safety, as there's copying and fragmentation

+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op_dml(q, k, v):
+ mem_free, mem_total = torch.dml.mem_get_info(q.device)
+ mem_active = mem_total - mem_free
+ mem_reserved = mem_total * 0.7
+ return einsum_op_tensor_mem(q, k, v, (mem_reserved - mem_active) if mem_reserved > mem_active else 1)
+
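+# Backend dispatch: CUDA/XPU size their chunks from live VRAM stats, MPS uses heuristics keyed
+# to total system RAM, DirectML uses its own allocator info, and anything else (CPU) uses the
+# small cache-friendly 32 MB budget noted below.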
+def einsum_op(q, k, v):
+ if q.device.type == 'cuda' or q.device.type == 'xpu':
+ return einsum_op_cuda(q, k, v)
+
+ if q.device.type == 'mps':
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
+ return einsum_op_mps_v1(q, k, v)
+ return einsum_op_mps_v2(q, k, v)
+
+ if q.device.type == 'privateuseone':
+ return einsum_op_dml(q, k, v)
+
+ # Smaller slices are faster due to L2/L3/SLC caches.
+ # Tested on i7 with 8MB L3 cache.
+ return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # pylint: disable=unused-argument
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k = k * self.scale
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
+ r = einsum_op(q, k, v)
+ r = r.to(dtype)
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+ if q.device.type == 'mps':
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.opts.sub_quad_q_chunk_size, kv_chunk_size=shared.opts.sub_quad_kv_chunk_size, chunk_threshold=shared.opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+ x = x.to(dtype)
+
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+ out_proj, dropout = self.to_out
+ x = out_proj(x)
+ x = dropout(x)
+
+ return x
+
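+# chunk_threshold is a percentage of free VRAM (None means 70%, or 90% on MPS; 0 disables the
+# unchunked fast path); if the full q@k^T matmul fits under that budget, kv chunking is skipped.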
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+ bytes_per_token = torch.finfo(q.dtype).bits//8
+ batch_x_heads, q_tokens, _ = q.shape
+ _, k_tokens, _ = k.shape
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+ if chunk_threshold is None:
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ elif chunk_threshold == 0:
+ chunk_threshold_bytes = None
+ else:
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+ elif kv_chunk_size_min == 0:
+ kv_chunk_size_min = None
+
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
+ # i.e. send it down the unchunked fast-path
+ kv_chunk_size = k_tokens
+
+ with devices.without_autocast(disable=q.dtype == v.dtype):
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
+
+
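+# Returns xFormers' explicit Flash Attention (fw, bw) op pair when the user enabled it and the
+# forward op reports support for these inputs; returning None lets xFormers pick an op itself.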
+def get_xformers_flash_attention_op(q, k, v):
+ if 'xFormers enable flash Attention' not in shared.opts.cross_attention_options:
+ return None
+
+ try:
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp # pylint: disable=used-before-assignment
+ fw, _bw = flash_attention_op
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ return flash_attention_op
+ except Exception as e:
+ errors.display_once(e, "enabling flash attention")
+
+ return None
+
+
+def xformers_attention_forward(self, x, context=None, mask=None): # pylint: disable=unused-argument
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
+
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
+
+ out = out.to(dtype)
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
+# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, inner_dim = x.shape
+
+ if mask is not None:
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ head_dim = inner_dim // h
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ del q_in, k_in, v_in
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
+ hidden_states = hidden_states.to(dtype)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return scaled_dot_product_attention_forward(self, x, context, mask)
+
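+# Sliced attention for AttnBlock modules operating on (b, c, h, w) feature maps: the (b, hw, hw)
+# weight matrix is built in query-dimension slices whenever the estimated requirement (2.5x the
+# weight tensor) exceeds the reported free memory.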
+def cross_attention_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q1.shape
+
+ q2 = q1.reshape(b, c, h*w)
+ del q1
+
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
+
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ mem_free_total = get_available_vram()
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
+
+def xformers_attnblock_forward(self, x):
+ try:
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape # pylint: disable=unused-variable
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
+ except NotImplementedError:
+ return cross_attention_attnblock_forward(self, x)
+
+def sdp_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape # pylint: disable=unused-variable
+
+    # SDP optimization kernels are built for operations with multiple attention heads.
+ # Four dimensional tensors are required for mem_efficient and flash attention to work.
+ # We add an attention head dimension `a` to allow these kernels to be used.
+ q, k, v = (rearrange(t, '(b a) c h w -> b a (h w) c', a=1) for t in (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+ out = out.to(dtype)
+ out = rearrange(out, 'b a (h w) c -> (b a) c h w', h=h) # remove the one attention head dimension `a`
+ out = self.proj_out(out)
+ return x + out
+
+def sdp_no_mem_attnblock_forward(self, x):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return sdp_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape # pylint: disable=unused-variable
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.opts.sub_quad_q_chunk_size, kv_chunk_size=shared.opts.sub_quad_kv_chunk_size, chunk_threshold=shared.opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index ea8773147..d8d356071 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -1,79 +1,79 @@
-import torch
-from packaging import version
-
-from modules import devices
-from modules.sd_hijack_utils import CondFunc
-
-
-class TorchHijackForUnet:
- """
- This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
- this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
- """
-
- def __getattr__(self, item):
- if item == 'cat':
- return self.cat
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def cat(self, tensors, *args, **kwargs):
- if len(tensors) == 2:
- a, b = tensors
- if a.shape[-2:] != b.shape[-2:]:
- a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
-
- tensors = (a, b)
-
- return torch.cat(tensors, *args, **kwargs)
-
-
-th = TorchHijackForUnet()
-
-
-# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
-def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
-
- if isinstance(cond, dict):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
-
- with devices.autocast():
- return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
-
-
-class GELUHijack(torch.nn.GELU, torch.nn.Module):
- def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
- torch.nn.GELU.__init__(self, *args, **kwargs)
- def forward(self, input): # pylint: disable=redefined-builtin
- if devices.unet_needs_upcast:
- return torch.nn.GELU.forward(self.float(), input.float()).to(devices.dtype_unet)
- else:
- return torch.nn.GELU.forward(self, input)
-
-
-ddpm_edit_hijack = None
-def hijack_ddpm_edit():
- global ddpm_edit_hijack # pylint: disable=global-statement
- if not ddpm_edit_hijack:
- CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
- CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
- ddpm_edit_hijack = CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-
-
-unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast # pylint: disable=unnecessary-lambda-assignment
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
-if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
- CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
- CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
- CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
-
-first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 # pylint: disable=unnecessary-lambda-assignment
-first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) # pylint: disable=unnecessary-lambda-assignment
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
+import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
+
+
+class TorchHijackForUnet:
+ """
+ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+ this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+ """
+
+ def __getattr__(self, item):
+ if item == 'cat':
+ return self.cat
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+ def cat(self, tensors, *args, **kwargs):
+ if len(tensors) == 2:
+ a, b = tensors
+ if a.shape[-2:] != b.shape[-2:]:
+ a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+ tensors = (a, b)
+
+ return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
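+# `th` is meant to stand in for `torch` where the UNet concatenates skip connections (the actual
+# substitution is wired up elsewhere in the hijack code), so mismatched spatial sizes get the
+# nearest-neighbour resize described in the docstring above.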
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ with devices.autocast():
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
+
+class GELUHijack(torch.nn.GELU, torch.nn.Module):
+ def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
+ torch.nn.GELU.__init__(self, *args, **kwargs)
+ def forward(self, input): # pylint: disable=redefined-builtin
+ if devices.unet_needs_upcast:
+ return torch.nn.GELU.forward(self.float(), input.float()).to(devices.dtype_unet)
+ else:
+ return torch.nn.GELU.forward(self, input)
+
+
+ddpm_edit_hijack = None
+def hijack_ddpm_edit():
+ global ddpm_edit_hijack # pylint: disable=global-statement
+ if not ddpm_edit_hijack:
+ CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+ CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+ ddpm_edit_hijack = CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+
+
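+# Each CondFunc below swaps the named ldm/open_clip callable for an upcasting wrapper, active
+# only while its condition (devices.unet_needs_upcast, or the first-stage fp16 check) holds.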
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast # pylint: disable=unnecessary-lambda-assignment
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+
+first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 # pylint: disable=unnecessary-lambda-assignment
+first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) # pylint: disable=unnecessary-lambda-assignment
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index d281be635..179ebc78e 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -1,28 +1,28 @@
-import importlib
-
-class CondFunc:
- def __new__(cls, orig_func, sub_func, cond_func):
- self = super(CondFunc, cls).__new__(cls)
- if isinstance(orig_func, str):
- func_path = orig_func.split('.')
- for i in range(len(func_path)-1, -1, -1):
- try:
- resolved_obj = importlib.import_module('.'.join(func_path[:i]))
- break
- except ImportError:
- pass
- for attr_name in func_path[i:-1]:
- resolved_obj = getattr(resolved_obj, attr_name)
- orig_func = getattr(resolved_obj, func_path[-1])
- setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
- self.__init__(orig_func, sub_func, cond_func)
- return lambda *args, **kwargs: self(*args, **kwargs)
- def __init__(self, orig_func, sub_func, cond_func):
- self.__orig_func = orig_func
- self.__sub_func = sub_func
- self.__cond_func = cond_func
- def __call__(self, *args, **kwargs):
- if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
- return self.__sub_func(self.__orig_func, *args, **kwargs)
- else:
- return self.__orig_func(*args, **kwargs)
+import importlib
+
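+# CondFunc(orig, sub_func, cond_func): when orig is a dotted-path string, the target module is
+# imported, the named attribute is replaced with a wrapper, and every call is routed to
+# sub_func(orig_func, *args, **kwargs) whenever cond_func is falsy or cond_func(orig_func, ...)
+# returns True; otherwise the original callable runs unchanged.
+# A simplified sketch of the usage in sd_hijack_unet.py:
+#   CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model',
+#            lambda orig, self, *a, **kw: orig(self, *a, **kw).float(),
+#            lambda orig, *a, **kw: devices.unet_needs_upcast)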
+class CondFunc:
+ def __new__(cls, orig_func, sub_func, cond_func):
+ self = super(CondFunc, cls).__new__(cls)
+ if isinstance(orig_func, str):
+ func_path = orig_func.split('.')
+ for i in range(len(func_path)-1, -1, -1):
+ try:
+ resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+ break
+ except ImportError:
+ pass
+ for attr_name in func_path[i:-1]:
+ resolved_obj = getattr(resolved_obj, attr_name)
+ orig_func = getattr(resolved_obj, func_path[-1])
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ self.__init__(orig_func, sub_func, cond_func)
+ return lambda *args, **kwargs: self(*args, **kwargs)
+ def __init__(self, orig_func, sub_func, cond_func):
+ self.__orig_func = orig_func
+ self.__sub_func = sub_func
+ self.__cond_func = cond_func
+ def __call__(self, *args, **kwargs):
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
+ else:
+ return self.__orig_func(*args, **kwargs)
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
index 28528329b..2aa646aea 100644
--- a/modules/sd_hijack_xlmr.py
+++ b/modules/sd_hijack_xlmr.py
@@ -1,32 +1,32 @@
-import torch
-
-from modules import sd_hijack_clip, devices
-
-
-class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
-
- self.id_start = wrapped.config.bos_token_id
- self.id_end = wrapped.config.eos_token_id
- self.id_pad = wrapped.config.pad_token_id
-
- self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma
-
- def encode_with_transformers(self, tokens):
- # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
- # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
- # layer to work with - you have to use the last
-
- attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
- features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
- z = features['projection_state']
-
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt):
- embedding_layer = self.wrapped.roberta.embeddings
- ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
-
- return embedded
+import torch
+
+from modules import sd_hijack_clip, devices
+
+
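+# Routes the XLM-Roberta text encoder used for Alt Diffusion through the same prompt-processing
+# path as the CLIP embedder, with roberta-specific start/end/pad token ids.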
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.id_start = wrapped.config.bos_token_id
+ self.id_end = wrapped.config.eos_token_id
+ self.id_pad = wrapped.config.pad_token_id
+
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma
+
+ def encode_with_transformers(self, tokens):
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+ # layer to work with - you have to use the last
+
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+ z = features['projection_state']
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.roberta.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a9f2f5dc7..02d15b281 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,1344 +1,1344 @@
-import re
-import io
-import sys
-import json
-import time
-import copy
-import logging
-import contextlib
-import collections
-import os.path
-from os import mkdir
-from urllib import request
-from enum import Enum
-from rich import progress # pylint: disable=redefined-builtin
-import torch
-import safetensors.torch
-import diffusers
-from omegaconf import OmegaConf
-import tomesd
-from transformers import logging as transformers_logging
-from ldm.util import instantiate_from_config
-from modules import paths, shared, shared_items, shared_state, modelloader, devices, script_callbacks, sd_vae, errors, hashes, sd_models_config, sd_models_compile
-from modules.timer import Timer
-from modules.memstats import memory_stats
-from modules.paths import models_path, script_path
-from modules.modeldata import model_data
-
-
-transformers_logging.set_verbosity_error()
-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
-checkpoints_list = {}
-checkpoint_aliases = {}
-checkpoints_loaded = collections.OrderedDict()
-sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
-sd_metadata = None
-sd_metadata_pending = 0
-sd_metadata_timer = 0
-
-
-class CheckpointInfo:
- def __init__(self, filename):
- self.name = None
- self.hash = None
- self.filename = filename
- self.type = ''
- relname = filename
- app_path = os.path.abspath(script_path)
-
- def rel(fn, path):
- try:
- return os.path.relpath(fn, path)
- except Exception:
- return fn
-
- if relname.startswith('..'):
- relname = os.path.abspath(relname)
- if relname.startswith(shared.opts.ckpt_dir):
- relname = rel(filename, shared.opts.ckpt_dir)
- elif relname.startswith(shared.opts.diffusers_dir):
- relname = rel(filename, shared.opts.diffusers_dir)
- elif relname.startswith(model_path):
- relname = rel(filename, model_path)
- elif relname.startswith(script_path):
- relname = rel(filename, script_path)
- elif relname.startswith(app_path):
- relname = rel(filename, app_path)
- else:
- relname = os.path.abspath(relname)
- relname, ext = os.path.splitext(relname)
- ext = ext.lower()[1:]
-
- if os.path.isfile(filename): # ckpt or safetensor
- self.name = relname
- self.filename = filename
- self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
- self.type = ext
- # self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
- else: # maybe a diffuser
- repo = [r for r in modelloader.diffuser_repos if filename == r['name']]
- if len(repo) == 0:
- self.name = relname
- self.filename = filename
- self.sha256 = None
- self.type = 'unknown'
- else:
- self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
- self.filename = repo[0]['path']
- self.sha256 = repo[0]['hash']
- self.type = 'diffusers'
-
- self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
- self.path = self.filename
- self.model_name = os.path.basename(self.name)
- self.metadata = read_metadata_from_safetensors(filename)
- # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')
-
- def register(self):
- checkpoints_list[self.title] = self
- for i in [self.name, self.filename, self.shorthash, self.title]:
- if i is not None:
- checkpoint_aliases[i] = self
-
- def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
- if self.sha256 is None:
- return None
- self.shorthash = self.sha256[0:10]
- checkpoints_list.pop(self.title)
- self.title = f'{self.name} [{self.shorthash}]'
- self.register()
- return self.shorthash
-
-
-class NoWatermark:
- def apply_watermark(self, img):
- return img
-
-
-def setup_model():
- if not os.path.exists(model_path):
- os.makedirs(model_path, exist_ok=True)
- list_models()
- if shared.backend == shared.Backend.ORIGINAL:
- enable_midas_autodownload()
-
-
-def checkpoint_tiles(use_short=False): # pylint: disable=unused-argument
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
-
-
-def list_models():
- t0 = time.time()
- global checkpoints_list # pylint: disable=global-statement
- checkpoints_list.clear()
- checkpoint_aliases.clear()
- if shared.opts.sd_disable_ckpt or shared.backend == shared.Backend.DIFFUSERS:
- ext_filter = [".safetensors"]
- else:
- ext_filter = [".ckpt", ".safetensors"]
- model_list = modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
- if shared.backend == shared.Backend.DIFFUSERS:
- model_list += modelloader.load_diffusers_models(model_path=os.path.join(models_path, 'Diffusers'), command_path=shared.opts.diffusers_dir, clear=True)
- for filename in sorted(model_list, key=str.lower):
- checkpoint_info = CheckpointInfo(filename)
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- if shared.cmd_opts.ckpt is not None:
- if not os.path.exists(shared.cmd_opts.ckpt) and shared.backend == shared.Backend.ORIGINAL:
- if shared.cmd_opts.ckpt.lower() != "none":
- shared.log.warning(f"Requested checkpoint not found: {shared.cmd_opts.ckpt}")
- else:
- checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
- shared.log.warning(f"Checkpoint not found: {shared.cmd_opts.ckpt}")
- shared.log.info(f'Available models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}')
-
- checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
- """
- if len(checkpoints_list) == 0:
- if not shared.cmd_opts.no_download:
- key = input('Download the default model? (y/N) ')
- if key.lower().startswith('y'):
- if shared.backend == shared.Backend.ORIGINAL:
- model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
- shared.opts.data['sd_model_checkpoint'] = "v1-5-pruned-emaonly.safetensors"
- model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
- else:
- default_model_id = "runwayml/stable-diffusion-v1-5"
- modelloader.download_diffusers_model(default_model_id, shared.opts.diffusers_dir)
- model_list = modelloader.load_diffusers_models(model_path=os.path.join(models_path, 'Diffusers'), command_path=shared.opts.diffusers_dir)
-
- for filename in sorted(model_list, key=str.lower):
- checkpoint_info = CheckpointInfo(filename)
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- """
-
-def update_model_hashes():
- txt = []
- lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
- # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
- for ckpt in lst:
- ckpt.hash = model_hash(ckpt.filename)
- # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}')
- # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models')
- lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
- shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
- for ckpt in lst:
- ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
- ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
- if ckpt.sha256 is not None:
- txt.append(f'Calculated full hash: {ckpt.title} {ckpt.shorthash}')
- else:
- txt.append(f'Skipped hash calculation: {ckpt.title} ')
- txt.append(f'Updated hashes for {len(lst)} out of {len(checkpoints_list)} models')
- txt = ' '.join(txt)
- return txt
-
-
-def get_closet_checkpoint_match(search_string):
- checkpoint_info = checkpoint_aliases.get(search_string, None)
- if checkpoint_info is not None:
- return checkpoint_info
- found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
- if found:
- return found[0]
- found = sorted([info for info in checkpoints_list.values() if search_string.split(' ')[0] in info.title], key=lambda x: len(x.title))
- if found:
- return found[0]
- return None
-
-
-def model_hash(filename):
- """old hash that only looks at a small part of the file and is prone to collisions"""
- try:
- with open(filename, "rb") as file:
- import hashlib
- # t0 = time.time()
- m = hashlib.sha256()
- file.seek(0x100000)
- m.update(file.read(0x10000))
- shorthash = m.hexdigest()[0:8]
- # t1 = time.time()
- # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
- return shorthash
- except FileNotFoundError:
- return 'NOFILE'
- except Exception:
- return 'NOHASH'
-
-
-def select_checkpoint(op='model'):
- if op == 'dict':
- model_checkpoint = shared.opts.sd_model_dict
- elif op == 'refiner':
- model_checkpoint = shared.opts.data.get('sd_model_refiner', None)
- else:
- model_checkpoint = shared.opts.sd_model_checkpoint
- if model_checkpoint is None or model_checkpoint == 'None':
- return None
- checkpoint_info = get_closet_checkpoint_match(model_checkpoint)
- if checkpoint_info is not None:
- shared.log.info(f'Select: {op}="{checkpoint_info.title if checkpoint_info is not None else None}"')
- return checkpoint_info
- if len(checkpoints_list) == 0 and not shared.cmd_opts.no_download:
- shared.log.warning("Cannot generate without a checkpoint")
- shared.log.info("Set system paths to use existing folders in a different location")
- shared.log.info("Or use --ckpt to force using existing checkpoint")
- return None
- checkpoint_info = next(iter(checkpoints_list.values()))
- if model_checkpoint is not None:
- if model_checkpoint != 'model.ckpt' and model_checkpoint != 'runwayml/stable-diffusion-v1-5':
- shared.log.warning(f"Selected checkpoint not found: {model_checkpoint}")
- else:
- shared.log.info("Selecting first available checkpoint")
- # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
- shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- shared.log.info(f'Select: {op}="{checkpoint_info.title if checkpoint_info is not None else None}"')
- return checkpoint_info
-
-
-checkpoint_dict_replacements = {
- 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
- 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
- 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
-}
-
-
-def transform_checkpoint_dict_key(k):
- for text, replacement in checkpoint_dict_replacements.items():
- if k.startswith(text):
- k = replacement + k[len(text):]
- return k
-
-
-def get_state_dict_from_checkpoint(pl_sd):
- pl_sd = pl_sd.pop("state_dict", pl_sd)
- pl_sd.pop("state_dict", None)
- sd = {}
- for k, v in pl_sd.items():
- new_key = transform_checkpoint_dict_key(k)
- if new_key is not None:
- sd[new_key] = v
- pl_sd.clear()
- pl_sd.update(sd)
- return pl_sd
-
-
-def write_metadata():
- global sd_metadata_pending # pylint: disable=global-statement
- if sd_metadata_pending == 0:
- shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
- return
- shared.writefile(sd_metadata, sd_metadata_file)
- shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
- sd_metadata_pending = 0
-
-
-def scrub_dict(dict_obj, keys):
- for key in list(dict_obj.keys()):
- if not isinstance(dict_obj, dict):
- continue
- if key in keys:
- dict_obj.pop(key, None)
- elif isinstance(dict_obj[key], dict):
- scrub_dict(dict_obj[key], keys)
- elif isinstance(dict_obj[key], list):
- for item in dict_obj[key]:
- scrub_dict(item, keys)
-
-
-def read_metadata_from_safetensors(filename):
- global sd_metadata # pylint: disable=global-statement
- if sd_metadata is None:
- if not os.path.isfile(sd_metadata_file):
- sd_metadata = {}
- else:
- sd_metadata = shared.readfile(sd_metadata_file, lock=True)
- res = sd_metadata.get(filename, None)
- if res is not None:
- return res
- if not filename.endswith(".safetensors"):
- return {}
- if shared.cmd_opts.no_metadata:
- return {}
- res = {}
- try:
- t0 = time.time()
- with open(filename, mode="rb") as file:
- metadata_len = file.read(8)
- metadata_len = int.from_bytes(metadata_len, "little")
- json_start = file.read(2)
- if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
- shared.log.error(f"Not a valid safetensors file: {filename}")
- json_data = json_start + file.read(metadata_len-2)
- json_obj = json.loads(json_data)
- for k, v in json_obj.get("__metadata__", {}).items():
- if v.startswith("data:"):
- v = 'data'
- if k == 'format' and v == 'pt':
- continue
- large = True if len(v) > 2048 else False
- if large and k == 'ss_datasets':
- continue
- if large and k == 'workflow':
- continue
- if large and k == 'prompt':
- continue
- if large and k == 'ss_bucket_info':
- continue
- if v[0:1] == '{':
- try:
- v = json.loads(v)
- if large and k == 'ss_tag_frequency':
- v = { i: len(j) for i, j in v.items() }
- if large and k == 'sd_merge_models':
- scrub_dict(v, ['sd_merge_recipe'])
- except Exception:
- pass
- res[k] = v
- sd_metadata[filename] = res
- global sd_metadata_pending # pylint: disable=global-statement
- sd_metadata_pending += 1
- t1 = time.time()
- global sd_metadata_timer # pylint: disable=global-statement
- sd_metadata_timer += (t1 - t0)
- except Exception as e:
- shared.log.error(f"Error reading metadata from: {filename} {e}")
- return res
-
-
-def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unused-argument
- if not os.path.isfile(checkpoint_file):
- shared.log.error(f"Model is not a file: {checkpoint_file}")
- return None
- try:
- pl_sd = None
- with progress.open(checkpoint_file, 'rb', description=f'[cyan]Loading model: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
- _, extension = os.path.splitext(checkpoint_file)
- if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
- shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
- return None
- if shared.opts.stream_load:
- if extension.lower() == ".safetensors":
- # shared.log.debug('Model weights loading: type=safetensors mode=buffered')
- buffer = f.read()
- pl_sd = safetensors.torch.load(buffer)
- else:
- # shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
- buffer = io.BytesIO(f.read())
- pl_sd = torch.load(buffer, map_location='cpu')
- else:
- if extension.lower() == ".safetensors":
- # shared.log.debug('Model weights loading: type=safetensors mode=mmap')
- pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
- else:
- # shared.log.debug('Model weights loading: type=checkpoint mode=direct')
- pl_sd = torch.load(f, map_location='cpu')
- sd = get_state_dict_from_checkpoint(pl_sd)
- del pl_sd
- except Exception as e:
- errors.display(e, f'Load model: {checkpoint_file}')
- sd = None
- return sd
-
-
-def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
- if not os.path.isfile(checkpoint_info.filename):
- return None
- if checkpoint_info in checkpoints_loaded:
- shared.log.info("Model weights loading: from cache")
- checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
- return checkpoints_loaded[checkpoint_info]
- res = read_state_dict(checkpoint_info.filename)
- if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL:
- # cache newly loaded model
- checkpoints_loaded[checkpoint_info] = res
- # clean up cache if limit is reached
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False)
- timer.record("load")
- return res
-
-
-def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer):
- _pipeline, _model_type = detect_pipeline(checkpoint_info.path, 'model')
- shared.log.debug(f'Model weights loading: {memory_stats()}')
- timer.record("hash")
- if model_data.sd_dict == 'None':
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
- if state_dict is None:
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- try:
- model.load_state_dict(state_dict, strict=False)
- except Exception as e:
- shared.log.error(f'Error loading model weights: {checkpoint_info.filename}')
- shared.log.error(' '.join(str(e).splitlines()[:2]))
- return False
- del state_dict
- timer.record("apply")
- if shared.opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
- timer.record("channels")
- if not shared.opts.no_half:
- vae = model.first_stage_model
- depth_model = getattr(model, 'depth_model', None)
- # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
- if shared.opts.no_half_vae:
- model.first_stage_model = None
- # with --upcast-sampling, don't convert the depth model weights to float16
- if shared.opts.upcast_sampling and depth_model:
- model.depth_model = None
- model.half()
- model.first_stage_model = vae
- if depth_model:
- model.depth_model = depth_model
- if shared.opts.cuda_cast_unet:
- devices.dtype_unet = model.model.diffusion_model.dtype
- else:
- model.model.diffusion_model.to(devices.dtype_unet)
- model.first_stage_model.to(devices.dtype_vae)
- model.sd_model_hash = checkpoint_info.calculate_shorthash()
- model.sd_model_checkpoint = checkpoint_info.filename
- model.sd_checkpoint_info = checkpoint_info
- model.is_sdxl = False # a1111 compatibility item
- model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
- model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
- model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training
- shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
- sd_vae.delete_base_vae()
- sd_vae.clear_loaded_vae()
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
- sd_vae.load_vae(model, vae_file, vae_source)
- timer.record("vae")
- return True
-
-
-def enable_midas_autodownload():
-    """
-    Add automatic downloading to the ldm.modules.midas.api.load_model function.
-
-    When the 512-depth-ema model (or a future model like it) is loaded, it
-    calls midas.api.load_model to load the associated MiDaS depth model.
-    This function wraps load_model so that the weights are downloaded to the
-    correct location automatically.
-    """
- import ldm.modules.midas.api
- midas_path = os.path.join(paths.models_path, 'midas')
- for k, v in ldm.modules.midas.api.ISL_PATHS.items():
- file_name = os.path.basename(v)
- ldm.modules.midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
- midas_urls = {
- "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
- "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
- "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
- "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
- }
- ldm.modules.midas.api.load_model_inner = ldm.modules.midas.api.load_model
-
- def load_model_wrapper(model_type):
- path = ldm.modules.midas.api.ISL_PATHS[model_type]
- if not os.path.exists(path):
- if not os.path.exists(midas_path):
- mkdir(midas_path)
- shared.log.info(f"Downloading midas model weights for {model_type} to {path}")
- request.urlretrieve(midas_urls[model_type], path)
- shared.log.info(f"{model_type} downloaded")
- return ldm.modules.midas.api.load_model_inner(model_type)
-
- ldm.modules.midas.api.load_model = load_model_wrapper
-
-
-def repair_config(sd_config):
- if "use_ema" not in sd_config.model.params:
- sd_config.model.params.use_ema = False
- if shared.opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False
- if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
- sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
- # For UnCLIP-L, override the hardcoded karlo directory
- if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
- karlo_path = os.path.join(paths.models_path, 'karlo')
- sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
-
-
-sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
-sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-
-
-def change_backend():
- shared.log.info(f'Backend changed: from={shared.backend} to={shared.opts.sd_backend}')
- shared.log.warning('Full server restart required to apply all changes')
- unload_model_weights()
- shared.backend = shared.Backend.ORIGINAL if shared.opts.sd_backend == 'original' else shared.Backend.DIFFUSERS
- checkpoints_loaded.clear()
- from modules.sd_samplers import list_samplers
- list_samplers(shared.backend)
- list_models()
- from modules.sd_vae import refresh_vae_list
- refresh_vae_list()
-
-
-def detect_pipeline(f: str, op: str = 'model', warning=True):
- if not f.endswith('.safetensors'):
- return None, None
- guess = shared.opts.diffusers_pipeline
- warn = shared.log.warning if warning else lambda *args, **kwargs: None
- if guess == 'Autodetect':
- try:
- # guess by size
- size = round(os.path.getsize(f) / 1024 / 1024)
- if size < 128:
- warn(f'Model size smaller than expected: {f} size={size} MB')
- elif (size >= 316 and size <= 324) or (size >= 156 and size <= 164): # 320 or 160
- warn(f'Model detected as VAE model, but attempting to load as model: {op}={f} size={size} MB')
- guess = 'VAE'
- elif size >= 5351 and size <= 5359: # 5353
- guess = 'Stable Diffusion' # SD v2
- elif size >= 5791 and size <= 5799: # 5795
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as SD-XL refiner model, but attempting to load using backend=original: {op}={f} size={size} MB')
- if op == 'model':
- warn(f'Model detected as SD-XL refiner model, but attempting to load a base model: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL'
- elif (size >= 6611 and size <= 6619) or (size >= 6771 and size <= 6779): # 6617, HassakuXL is 6776
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as SD-XL base model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL'
- elif size >= 3361 and size <= 3369: # 3368
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as SD upscale model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Stable Diffusion Upscale'
- elif size >= 4891 and size <= 4899: # 4897
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as SD XL inpaint model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL Inpaint'
- elif size >= 9791 and size <= 9799: # 9794
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as SD XL instruct pix2pix model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL Instruct'
- elif size > 3138 and size < 3142: #3140
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as Segmind Vega model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL'
- else:
- guess = 'Stable Diffusion'
- # guess by name
- """
- if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper():
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Latent Consistency Model'
- """
- if 'PixArt' in f:
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as PixArt Alpha model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'PixArt Alpha'
- # switch for specific variant
- if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
- guess = 'Stable Diffusion Inpaint'
- elif guess == 'Stable Diffusion' and 'instruct' in f.lower():
- guess = 'Stable Diffusion Instruct'
- if guess == 'Stable Diffusion XL' and 'inpaint' in f.lower():
- guess = 'Stable Diffusion XL Inpaint'
- elif guess == 'Stable Diffusion XL' and 'instruct' in f.lower():
- guess = 'Stable Diffusion XL Instruct'
- # get actual pipeline
- pipeline = shared_items.get_pipelines().get(guess, None)
- shared.log.info(f'Autodetect: {op}="{guess}" class={pipeline.__name__} file="{f}" size={size}MB')
- except Exception as e:
- shared.log.error(f'Error detecting diffusers pipeline: model={f} {e}')
- return None, None
- else:
- try:
- size = round(os.path.getsize(f) / 1024 / 1024)
- pipeline = shared_items.get_pipelines().get(guess, None)
- shared.log.info(f'Diffusers: {op}="{guess}" class={pipeline.__name__} file="{f}" size={size}MB')
- except Exception as e:
- shared.log.error(f'Error loading diffusers pipeline: model={f} {e}')
-
- if pipeline is None:
- shared.log.warning(f'Autodetect: pipeline not recognized: {guess}: {op}={f} size={size}')
- pipeline = diffusers.StableDiffusionPipeline
- return pipeline, guess
-
-
-def copy_diffuser_options(new_pipe, orig_pipe):
- new_pipe.sd_checkpoint_info = orig_pipe.sd_checkpoint_info
- new_pipe.sd_model_checkpoint = orig_pipe.sd_model_checkpoint
- new_pipe.embedding_db = getattr(orig_pipe, 'embedding_db', None)
- new_pipe.sd_model_hash = getattr(orig_pipe, 'sd_model_hash', None)
- new_pipe.has_accelerate = getattr(orig_pipe, 'has_accelerate', False)
- new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item
- new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False)
- new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True)
-
-
-def set_diffuser_options(sd_model, vae = None, op: str = 'model'):
- if sd_model is None:
- shared.log.warning(f'{op} is not loaded')
- return
- if (shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) and (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram):
- shared.log.warning(f'Setting {op}: Model CPU offload and Sequential CPU offload are not compatible')
- shared.log.debug(f'Setting {op}: disabling model CPU offload')
- shared.opts.diffusers_model_cpu_offload=False
- shared.cmd_opts.medvram=False
-
- if hasattr(sd_model, "watermark"):
- sd_model.watermark = NoWatermark()
- sd_model.has_accelerate = False
- if hasattr(sd_model, "vae"):
- if vae is not None:
- sd_model.vae = vae
- if shared.opts.diffusers_vae_upcast != 'default':
- if shared.opts.diffusers_vae_upcast == 'true':
- sd_model.vae.config.force_upcast = True
- else:
- sd_model.vae.config.force_upcast = False
- if shared.opts.no_half_vae:
- devices.dtype_vae = torch.float32
- sd_model.vae.to(devices.dtype_vae)
- shared.log.debug(f'Setting {op} VAE: name={sd_vae.loaded_vae_file} upcast={sd_model.vae.config.get("force_upcast", None)}')
- if hasattr(sd_model, "enable_model_cpu_offload"):
- if (shared.cmd_opts.medvram and devices.backend != "directml") or shared.opts.diffusers_model_cpu_offload:
- shared.log.debug(f'Setting {op}: enable model CPU offload')
- if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
- shared.opts.diffusers_move_base = False
- shared.opts.diffusers_move_unet = False
- shared.opts.diffusers_move_refiner = False
- shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled')
- sd_model.enable_model_cpu_offload()
- sd_model.has_accelerate = True
- if hasattr(sd_model, "enable_sequential_cpu_offload"):
- if shared.cmd_opts.lowvram or shared.opts.diffusers_seq_cpu_offload:
- shared.log.debug(f'Setting {op}: enable sequential CPU offload')
- if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
- shared.opts.diffusers_move_base = False
- shared.opts.diffusers_move_unet = False
- shared.opts.diffusers_move_refiner = False
- shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled')
- sd_model.enable_sequential_cpu_offload(device=devices.device)
- sd_model.has_accelerate = True
- if hasattr(sd_model, "enable_vae_slicing"):
- if shared.cmd_opts.lowvram or shared.opts.diffusers_vae_slicing:
- shared.log.debug(f'Setting {op}: enable VAE slicing')
- sd_model.enable_vae_slicing()
- else:
- sd_model.disable_vae_slicing()
- if hasattr(sd_model, "enable_vae_tiling"):
- if shared.cmd_opts.lowvram or shared.opts.diffusers_vae_tiling:
- shared.log.debug(f'Setting {op}: enable VAE tiling')
- sd_model.enable_vae_tiling()
- else:
- sd_model.disable_vae_tiling()
- if hasattr(sd_model, "enable_attention_slicing"):
- if shared.cmd_opts.lowvram or shared.opts.diffusers_attention_slicing:
- shared.log.debug(f'Setting {op}: enable attention slicing')
- sd_model.enable_attention_slicing()
- else:
- sd_model.disable_attention_slicing()
- if hasattr(sd_model, "vqvae"):
- sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16
- if shared.opts.cross_attention_optimization == "xFormers" and hasattr(sd_model, 'enable_xformers_memory_efficient_attention'):
- sd_model.enable_xformers_memory_efficient_attention()
- if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'):
- shared.log.debug(f'Setting {op}: enable fused projections')
- sd_model.fuse_qkv_projections()
- if shared.opts.diffusers_eval:
- if hasattr(sd_model, "unet") and hasattr(sd_model.unet, "requires_grad_"):
- sd_model.unet.requires_grad_(False)
- sd_model.unet.eval()
- if hasattr(sd_model, "vae") and hasattr(sd_model.vae, "requires_grad_"):
- sd_model.vae.requires_grad_(False)
- sd_model.vae.eval()
- if hasattr(sd_model, "text_encoder") and hasattr(sd_model.text_encoder, "requires_grad_"):
- sd_model.text_encoder.requires_grad_(False)
- sd_model.text_encoder.eval()
- if shared.opts.diffusers_quantization:
- sd_model = sd_models_compile.dynamic_quantization(sd_model)
-
- if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'):
- shared.log.debug(f'Setting {op}: enable channels last')
- sd_model.unet.to(memory_format=torch.channels_last)
-
-
-def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): # pylint: disable=unused-argument
- import torch # pylint: disable=reimported,redefined-outer-name
- if shared.cmd_opts.profile:
- import cProfile
- pr = cProfile.Profile()
- pr.enable()
- if timer is None:
- timer = Timer()
- logging.getLogger("diffusers").setLevel(logging.ERROR)
- timer.record("diffusers")
- devices.set_cuda_params()
- diffusers_load_config = {
- "low_cpu_mem_usage": True,
- "torch_dtype": devices.dtype,
- "safety_checker": None,
- "requires_safety_checker": False,
- "load_safety_checker": False,
- "load_connected_pipeline": True,
-        # TODO: use_safetensors can't be enabled for all checkpoints just yet
- }
- if shared.opts.diffusers_model_load_variant == 'default':
- if devices.dtype == torch.float16:
- diffusers_load_config['variant'] = 'fp16'
- elif shared.opts.diffusers_model_load_variant == 'fp32':
- pass
- else:
- diffusers_load_config['variant'] = shared.opts.diffusers_model_load_variant
-
- if shared.opts.diffusers_pipeline == 'Custom Diffusers Pipeline' and len(shared.opts.custom_diffusers_pipeline) > 0:
- shared.log.debug(f'Diffusers custom pipeline: {shared.opts.custom_diffusers_pipeline}')
- diffusers_load_config['custom_pipeline'] = shared.opts.custom_diffusers_pipeline
-
- # if 'LCM' in checkpoint_info.path:
- # diffusers_load_config['custom_pipeline'] = 'latent_consistency_txt2img'
-
- if shared.opts.data.get('sd_model_checkpoint', '') == 'model.ckpt' or shared.opts.data.get('sd_model_checkpoint', '') == '':
- shared.opts.data['sd_model_checkpoint'] = "runwayml/stable-diffusion-v1-5"
-
- if op == 'model' or op == 'dict':
- if (model_data.sd_model is not None) and (checkpoint_info is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
- return
- else:
- if (model_data.sd_refiner is not None) and (checkpoint_info is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
- return
-
- sd_model = None
-
- try:
- if shared.cmd_opts.ckpt is not None and os.path.isdir(shared.cmd_opts.ckpt) and model_data.initial: # initial load
- ckpt_basename = os.path.basename(shared.cmd_opts.ckpt)
- model_name = modelloader.find_diffuser(ckpt_basename)
- if model_name is not None:
- shared.log.info(f'Load model {op}: {model_name}')
- model_file = modelloader.download_diffusers_model(hub_id=model_name)
- try:
- shared.log.debug(f'Model load {op} config: {diffusers_load_config}')
- sd_model = diffusers.DiffusionPipeline.from_pretrained(model_file, **diffusers_load_config)
- except Exception as e:
- shared.log.error(f'Failed loading model: {model_file} {e}')
- list_models() # rescan for downloaded model
- checkpoint_info = CheckpointInfo(model_name)
-
- checkpoint_info = checkpoint_info or select_checkpoint(op=op)
- if checkpoint_info is None:
- unload_model_weights(op=op)
- return
-
- vae = None
- sd_vae.loaded_vae_file = None
- if op == 'model' or op == 'refiner':
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
- vae = sd_vae.load_vae_diffusers(checkpoint_info.path, vae_file, vae_source)
- if vae is not None:
- diffusers_load_config["vae"] = vae
-
- shared.log.debug(f'Diffusers loading: path="{checkpoint_info.path}"')
- if os.path.isdir(checkpoint_info.path):
- err1 = None
- err2 = None
- err3 = None
- try: # try autopipeline first, best choice but not all pipelines are available
- sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- except Exception as e:
- err1 = e
- # shared.log.error(f'AutoPipeline: {e}')
- try: # try diffusion pipeline next second-best choice, works for most non-linked pipelines
- if err1 is not None:
- sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- except Exception as e:
- err2 = e
- # shared.log.error(f'DiffusionPipeline: {e}')
- try: # try basic pipeline next just in case
- if err2 is not None:
- sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- except Exception as e:
- err3 = e # ignore last error
- shared.log.error(f'StableDiffusionPipeline: {e}')
- if err3 is not None:
- shared.log.error(f'Failed loading {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
- return
- elif os.path.isfile(checkpoint_info.path) and checkpoint_info.path.lower().endswith('.safetensors'):
- # diffusers_load_config["local_files_only"] = True
- diffusers_load_config["extract_ema"] = shared.opts.diffusers_extract_ema
- pipeline, model_type = detect_pipeline(checkpoint_info.path, op)
- if pipeline is None:
- shared.log.error(f'Diffusers {op} pipeline not initialized: {shared.opts.diffusers_pipeline}')
- return
- try:
- if model_type.startswith('Stable Diffusion'):
-                    diffusers_load_config['force_zeros_for_empty_prompt'] = shared.opts.diffusers_force_zeros
- diffusers_load_config['requires_aesthetics_score'] = shared.opts.diffusers_aesthetics_score
- if 'inpainting' in checkpoint_info.path.lower():
- diffusers_load_config['config_files'] = {
- 'v1': 'configs/v1-inpainting-inference.yaml',
- 'v2': 'configs/v2-inference-768-v.yaml',
- 'xl': 'configs/sd_xl_base.yaml',
- 'xl_refiner': 'configs/sd_xl_refiner.yaml',
- }
- else:
- diffusers_load_config['config_files'] = {
- 'v1': 'configs/v1-inference.yaml',
- 'v2': 'configs/v2-inference-768-v.yaml',
- 'xl': 'configs/sd_xl_base.yaml',
- 'xl_refiner': 'configs/sd_xl_refiner.yaml',
- }
- if hasattr(pipeline, 'from_single_file'):
- diffusers_load_config['use_safetensors'] = True
- sd_model = pipeline.from_single_file(checkpoint_info.path, **diffusers_load_config)
- if sd_model is not None and hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpainting' in checkpoint_info.path.lower():
- shared.log.debug('Model patch: type=inpaint')
- sd_model.unet.config.in_channels = 9
- elif hasattr(pipeline, 'from_ckpt'):
- sd_model = pipeline.from_ckpt(checkpoint_info.path, **diffusers_load_config)
- else:
- shared.log.error(f'Diffusers {op} cannot load safetensor model: {checkpoint_info.path} {shared.opts.diffusers_pipeline}')
- return
- if sd_model is not None:
- diffusers_load_config.pop('vae', None)
- diffusers_load_config.pop('safety_checker', None)
- diffusers_load_config.pop('requires_safety_checker', None)
- diffusers_load_config.pop('load_safety_checker', None)
- diffusers_load_config.pop('config_files', None)
- diffusers_load_config.pop('local_files_only', None)
- shared.log.debug(f'Setting {op}: pipeline={sd_model.__class__.__name__} config={diffusers_load_config}') # pylint: disable=protected-access
- except Exception as e:
- shared.log.error(f'Diffusers failed loading: {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__} {e}')
- errors.display(e, f'loading {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__}')
- return
- else:
- shared.log.error(f'Diffusers cannot load: {op}={checkpoint_info.path}')
- return
-
- if "StableDiffusion" in sd_model.__class__.__name__:
- pass # scheduler is created on first use
- elif "Kandinsky" in sd_model.__class__.__name__:
- sd_model.scheduler.name = 'DDIM'
-
- set_diffuser_options(sd_model, vae, op)
-
- base_sent_to_cpu=False
- if (shared.opts.cuda_compile and shared.opts.cuda_compile_backend != 'none') or shared.opts.ipex_optimize:
- if op == 'refiner' and not getattr(sd_model, 'has_accelerate', False):
- gpu_vram = memory_stats().get('gpu', {})
- free_vram = gpu_vram.get('total', 0) - gpu_vram.get('used', 0)
-                refiner_enough_vram = free_vram >= (7 if "StableDiffusionXL" in sd_model.__class__.__name__ else 3)
- if not shared.opts.diffusers_move_base and refiner_enough_vram:
- sd_model.to(devices.device)
- base_sent_to_cpu=False
- else:
- if not refiner_enough_vram and not (shared.opts.diffusers_move_base and shared.opts.diffusers_move_refiner):
- shared.log.warning(f"Insufficient GPU memory, using system memory as fallback: free={free_vram} GB")
-                        if not shared.opts.diffusers_seq_cpu_offload and not shared.opts.diffusers_model_cpu_offload:
- shared.log.debug('Enabled moving base model to CPU')
- shared.log.debug('Enabled moving refiner model to CPU')
- shared.opts.diffusers_move_base=True
- shared.opts.diffusers_move_refiner=True
- shared.log.debug('Moving base model to CPU')
- if model_data.sd_model is not None:
- model_data.sd_model.to(devices.cpu)
- devices.torch_gc(force=True)
- sd_model.to(devices.device)
- base_sent_to_cpu=True
- elif not getattr(sd_model, 'has_accelerate', False):
- sd_model.to(devices.device)
-
- sd_models_compile.compile_diffusers(sd_model)
-
- if sd_model is None:
- shared.log.error('Diffuser model not loaded')
- return
- sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init
- sd_model.sd_checkpoint_info = checkpoint_info # pylint: disable=attribute-defined-outside-init
- sd_model.sd_model_checkpoint = checkpoint_info.filename # pylint: disable=attribute-defined-outside-init
- sd_model.is_sdxl = False # a1111 compatibility item
- sd_model.is_sd2 = hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model') # a1111 compatibility item
- sd_model.is_sd1 = not sd_model.is_sd2 # a1111 compatibility item
- sd_model.logvar = sd_model.logvar.to(devices.device) if hasattr(sd_model, 'logvar') else None # fix for training
- shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
- if hasattr(sd_model, "set_progress_bar_config"):
- sd_model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining}', ncols=80, colour='#327fba')
- if op == 'refiner' and shared.opts.diffusers_move_refiner and not getattr(sd_model, 'has_accelerate', False):
- shared.log.debug('Moving refiner model to CPU')
- sd_model.to(devices.cpu)
- elif not getattr(sd_model, 'has_accelerate', False): # In offload modes, accelerate will move models around
- sd_model.to(devices.device)
- if op == 'refiner' and base_sent_to_cpu:
- shared.log.debug('Moving base model back to GPU')
- model_data.sd_model.to(devices.device)
- except Exception as e:
- shared.log.error("Failed to load diffusers model")
- errors.display(e, "loading Diffusers model")
-
- if sd_model is not None:
- from modules.textual_inversion import textual_inversion
- sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
- if op == 'refiner':
- model_data.sd_refiner = sd_model
- else:
- model_data.sd_model = sd_model
- sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
- sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
-
- timer.record("load")
- devices.torch_gc(force=True)
- if shared.cmd_opts.profile:
- errors.profile(pr, 'Load')
- script_callbacks.model_loaded_callback(sd_model)
- shared.log.info(f"Load {op}: time={timer.summary()} native={get_native(sd_model)} {memory_stats()}")
-
-
-class DiffusersTaskType(Enum):
- TEXT_2_IMAGE = 1
- IMAGE_2_IMAGE = 2
- INPAINTING = 3
- INSTRUCT = 4
-
-
-def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType:
- if pipe.__class__.__name__ == "StableDiffusionXLInstructPix2PixPipeline":
- return DiffusersTaskType.INSTRUCT
- elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.values():
- return DiffusersTaskType.IMAGE_2_IMAGE
- elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING.values():
- return DiffusersTaskType.INPAINTING
- else:
- return DiffusersTaskType.TEXT_2_IMAGE
-
-
-def switch_diffuser_pipe(pipeline, cls):
- try:
- new_pipe = None
- if isinstance(pipeline, cls):
- return pipeline
- elif isinstance(pipeline, diffusers.StableDiffusionXLPipeline):
- new_pipe = cls(
- vae=pipeline.vae,
- text_encoder=pipeline.text_encoder,
- text_encoder_2=pipeline.text_encoder_2,
- tokenizer=pipeline.tokenizer,
- tokenizer_2=pipeline.tokenizer_2,
- unet=pipeline.unet,
- scheduler=pipeline.scheduler,
- feature_extractor=getattr(pipeline, 'feature_extractor', None),
- ).to(pipeline.device)
- elif isinstance(pipeline, diffusers.StableDiffusionPipeline):
- new_pipe = cls(
- vae=pipeline.vae,
- text_encoder=pipeline.text_encoder,
- tokenizer=pipeline.tokenizer,
- unet=pipeline.unet,
- scheduler=pipeline.scheduler,
- feature_extractor=getattr(pipeline, 'feature_extractor', None),
- requires_safety_checker=False,
- safety_checker=None,
- ).to(pipeline.device)
- else:
- shared.log.error(f'Pipeline switch error: {pipeline.__class__.__name__} unrecognized')
- return pipeline
- if new_pipe is not None:
- copy_diffuser_options(new_pipe, pipeline)
- shared.log.debug(f'Pipeline switch: from={pipeline.__class__.__name__} to={new_pipe.__class__.__name__}')
- return new_pipe
- else:
- shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} empty pipeline')
- except Exception as e:
- shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} {e}')
- return pipeline
-
-
-def set_diffuser_pipe(pipe, new_pipe_type):
- sd_checkpoint_info = getattr(pipe, "sd_checkpoint_info", None)
- sd_model_checkpoint = getattr(pipe, "sd_model_checkpoint", None)
- sd_model_hash = getattr(pipe, "sd_model_hash", None)
- has_accelerate = getattr(pipe, "has_accelerate", None)
- embedding_db = getattr(pipe, "embedding_db", None)
- image_encoder = getattr(pipe, "image_encoder", None)
- feature_extractor = getattr(pipe, "feature_extractor", None)
-
- # skip specific pipelines
- if pipe.__class__.__name__ == 'StableDiffusionReferencePipeline' or pipe.__class__.__name__ == 'StableDiffusionAdapterPipeline':
- return pipe
-
- try:
- if new_pipe_type == DiffusersTaskType.TEXT_2_IMAGE:
- new_pipe = diffusers.AutoPipelineForText2Image.from_pipe(pipe)
- elif new_pipe_type == DiffusersTaskType.IMAGE_2_IMAGE:
- new_pipe = diffusers.AutoPipelineForImage2Image.from_pipe(pipe)
- elif new_pipe_type == DiffusersTaskType.INPAINTING:
- new_pipe = diffusers.AutoPipelineForInpainting.from_pipe(pipe)
- except Exception as e: # pylint: disable=unused-variable
- shared.log.warning(f'Failed to change: type={new_pipe_type} pipeline={pipe.__class__.__name__} {e}')
- return pipe
-
- if pipe.__class__ == new_pipe.__class__:
- return pipe
- new_pipe.sd_checkpoint_info = sd_checkpoint_info
- new_pipe.sd_model_checkpoint = sd_model_checkpoint
- new_pipe.sd_model_hash = sd_model_hash
- new_pipe.has_accelerate = has_accelerate
- new_pipe.embedding_db = embedding_db
- new_pipe.image_encoder = image_encoder
- new_pipe.feature_extractor = feature_extractor
- new_pipe.is_sdxl = getattr(pipe, 'is_sdxl', False) # a1111 compatibility item
- new_pipe.is_sd2 = getattr(pipe, 'is_sd2', False)
- new_pipe.is_sd1 = getattr(pipe, 'is_sd1', True)
- shared.log.debug(f"Pipeline class change: original={pipe.__class__.__name__} target={new_pipe.__class__.__name__}")
- pipe = new_pipe
- return pipe
-
-
-def get_native(pipe: diffusers.DiffusionPipeline):
- if hasattr(pipe, "vae") and hasattr(pipe.vae.config, "sample_size"):
- # Stable Diffusion
- size = pipe.vae.config.sample_size
- elif hasattr(pipe, "movq") and hasattr(pipe.movq.config, "sample_size"):
- # Kandinsky
- size = pipe.movq.config.sample_size
- elif hasattr(pipe, "unet") and hasattr(pipe.unet.config, "sample_size"):
- size = pipe.unet.config.sample_size
- else:
- size = 0
- return size
-
-
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
- from modules import lowvram, sd_hijack
- checkpoint_info = checkpoint_info or select_checkpoint(op=op)
- if checkpoint_info is None:
- return
- if op == 'model' or op == 'dict':
- if model_data.sd_model is not None and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
- return
- else:
- if model_data.sd_refiner is not None and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
- return
- shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}')
- if timer is None:
- timer = Timer()
- current_checkpoint_info = None
- if op == 'model' or op == 'dict':
- if model_data.sd_model is not None:
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- current_checkpoint_info = model_data.sd_model.sd_checkpoint_info
- unload_model_weights(op=op)
- else:
- if model_data.sd_refiner is not None:
- sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
- current_checkpoint_info = model_data.sd_refiner.sd_checkpoint_info
- unload_model_weights(op=op)
-
- if shared.backend == shared.Backend.ORIGINAL:
- from modules import sd_hijack_inpainting
- sd_hijack_inpainting.do_inpainting_hijack()
-
- devices.set_cuda_params()
- if already_loaded_state_dict is not None:
- state_dict = already_loaded_state_dict
- else:
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- if state_dict is None or checkpoint_config is None:
-        shared.log.error(f"Failed to load checkpoint: {checkpoint_info.filename}")
- if current_checkpoint_info is not None:
- shared.log.info(f"Restoring previous checkpoint: {current_checkpoint_info.filename}")
- load_model(current_checkpoint_info, None)
- return
- shared.log.debug(f'Model dict loaded: {memory_stats()}')
- sd_config = OmegaConf.load(checkpoint_config)
- repair_config(sd_config)
- timer.record("config")
- shared.log.debug(f'Model config loaded: {memory_stats()}')
- sd_model = None
- stdout = io.StringIO()
- if os.environ.get('SD_LDM_DEBUG', None) is not None:
- sd_model = instantiate_from_config(sd_config.model)
- else:
- with contextlib.redirect_stdout(stdout):
- """
- try:
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
- shared.log.error(f'LDM: instantiate from config: {e}')
- sd_model = instantiate_from_config(sd_config.model)
- """
- sd_model = instantiate_from_config(sd_config.model)
- for line in stdout.getvalue().splitlines():
- if len(line) > 0:
- shared.log.info(f'LDM: {line.strip()}')
- shared.log.debug(f"Model created from config: {checkpoint_config}")
- sd_model.used_config = checkpoint_config
- sd_model.has_accelerate = False
- timer.record("create")
- ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- if not ok:
- model_data.sd_model = sd_model
- current_checkpoint_info = None
- unload_model_weights(op=op)
- shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}')
- if op == 'refiner':
- # shared.opts.data['sd_model_refiner'] = 'None'
- shared.opts.sd_model_refiner = 'None'
- return
- else:
- shared.log.debug(f'Model weights loaded: {memory_stats()}')
- timer.record("load")
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
- else:
- sd_model.to(devices.device)
- timer.record("move")
- shared.log.debug(f'Model weights moved: {memory_stats()}')
- sd_hijack.model_hijack.hijack(sd_model)
- timer.record("hijack")
- sd_model.eval()
- if op == 'refiner':
- model_data.sd_refiner = sd_model
- else:
- model_data.sd_model = sd_model
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
- timer.record("embeddings")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("callbacks")
- shared.log.info(f"Model loaded in {timer.summary()}")
- current_checkpoint_info = None
- devices.torch_gc(force=True)
- shared.log.info(f'Model load finished: {memory_stats()} cached={len(checkpoints_loaded.keys())}')
-
-
-def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model'):
- load_dict = shared.opts.sd_model_dict != model_data.sd_dict
- from modules import lowvram, sd_hijack
- checkpoint_info = info or select_checkpoint(op=op) # are we selecting model or dictionary
- next_checkpoint_info = info or select_checkpoint(op='dict' if load_dict else 'model') if load_dict else None
- if checkpoint_info is None:
- unload_model_weights(op=op)
- return None
- orig_state = copy.deepcopy(shared.state)
- shared.state = shared_state.State()
- shared.state.begin('load')
- if load_dict:
- shared.log.debug(f'Model dict: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
- else:
- model_data.sd_dict = 'None'
- shared.log.debug(f'Load model weights: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
- if sd_model is None:
- sd_model = model_data.sd_model if op == 'model' or op == 'dict' else model_data.sd_refiner
- if sd_model is None: # previous model load failed
- current_checkpoint_info = None
- else:
- current_checkpoint_info = getattr(sd_model, 'sd_checkpoint_info', None)
- if current_checkpoint_info is not None and checkpoint_info is not None and current_checkpoint_info.filename == checkpoint_info.filename:
- return None
- if not getattr(sd_model, 'has_accelerate', False):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
- if (reuse_dict or shared.opts.model_reuse_dict) and not getattr(sd_model, 'has_accelerate', False):
- shared.log.info('Reusing previous model dictionary')
- sd_hijack.model_hijack.undo_hijack(sd_model)
- else:
- unload_model_weights(op=op)
- sd_model = None
- timer = Timer()
-    state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if shared.backend == shared.Backend.ORIGINAL else None # TODO: revisit after Diffusers enables state_dict loading
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- timer.record("config")
- if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None):
- sd_model = None
- if shared.backend == shared.Backend.ORIGINAL:
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
- model_data.sd_dict = shared.opts.sd_model_dict
- else:
- load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
- if load_dict and next_checkpoint_info is not None:
- model_data.sd_dict = shared.opts.sd_model_dict
- shared.opts.data["sd_model_checkpoint"] = next_checkpoint_info.title
-        reload_model_weights(reuse_dict=True) # dict is loaded, now reload to put the model on top of it
- shared.state.end()
- shared.state = orig_state
- # data['sd_model_checkpoint']
- if op == 'model' or op == 'dict':
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
- return model_data.sd_model
- else:
- shared.opts.data["sd_model_refiner"] = checkpoint_info.title
- return model_data.sd_refiner
-
- # fallback
- shared.log.info(f"Loading using fallback: {op} model={checkpoint_info.title}")
- try:
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception:
- shared.log.error("Load model failed: restoring previous")
- load_model_weights(sd_model, current_checkpoint_info, None, timer)
- finally:
- sd_hijack.model_hijack.hijack(sd_model)
- timer.record("hijack")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("callbacks")
- if sd_model is not None and not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram and not getattr(sd_model, 'has_accelerate', False):
- sd_model.to(devices.device)
- timer.record("device")
- shared.state.end()
- shared.state = orig_state
- shared.log.info(f"Load: {op} time={timer.summary()}")
- return sd_model
-
-
-def convert_to_faketensors(tensor):
- fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
- if hasattr(tensor, "weight"):
- tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
- return tensor
-
-
-def disable_offload(sd_model):
- from accelerate.hooks import remove_hook_from_module
- if not getattr(sd_model, 'has_accelerate', False):
- return
- for _name, model in sd_model.components.items():
- if not isinstance(model, torch.nn.Module):
- continue
- remove_hook_from_module(model, recurse=True)
-
-
-def unload_model_weights(op='model'):
- if shared.compiled_model_state is not None:
- shared.compiled_model_state.compiled_cache.clear()
- shared.compiled_model_state.partitioned_modules.clear()
- if op == 'model' or op == 'dict':
- if model_data.sd_model:
- if shared.backend == shared.Backend.ORIGINAL:
- from modules import sd_hijack
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- elif not (shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
- disable_offload(model_data.sd_model)
- model_data.sd_model.to('meta')
- model_data.sd_model = None
- shared.log.debug(f'Unload weights {op}: {memory_stats()}')
- else:
- if model_data.sd_refiner:
- if shared.backend == shared.Backend.ORIGINAL:
- from modules import sd_hijack
-                model_data.sd_refiner.to(devices.cpu)
-                sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
-            else:
-                disable_offload(model_data.sd_refiner)
- model_data.sd_refiner.to('meta')
- model_data.sd_refiner = None
- shared.log.debug(f'Unload weights {op}: {memory_stats()}')
- devices.torch_gc(force=True)
-
-
-def apply_token_merging(sd_model, token_merging_ratio=0):
- current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
- if token_merging_ratio is None or current_token_merging_ratio is None or current_token_merging_ratio == token_merging_ratio:
- return
- try:
- if current_token_merging_ratio > 0:
- tomesd.remove_patch(sd_model)
- except Exception:
- pass
- if token_merging_ratio > 0:
- if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
- shared.log.warning('Token merging not supported with HyperTile for UNet')
- return
- try:
- tomesd.apply_patch(
- sd_model,
- ratio=token_merging_ratio,
- use_rand=False, # can cause issues with some samplers
- merge_attn=True,
- merge_crossattn=False,
- merge_mlp=False
- )
- shared.log.info(f'Applying token merging: ratio={token_merging_ratio}')
- sd_model.applied_token_merged_ratio = token_merging_ratio
- except Exception:
- shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
- else:
- sd_model.applied_token_merged_ratio = 0
+import re
+import io
+import sys
+import json
+import time
+import copy
+import logging
+import contextlib
+import collections
+import os.path
+from os import mkdir
+from urllib import request
+from enum import Enum
+from rich import progress # pylint: disable=redefined-builtin
+import torch
+import safetensors.torch
+import diffusers
+from omegaconf import OmegaConf
+import tomesd
+from transformers import logging as transformers_logging
+from ldm.util import instantiate_from_config
+from modules import paths, shared, shared_items, shared_state, modelloader, devices, script_callbacks, sd_vae, errors, hashes, sd_models_config, sd_models_compile
+from modules.timer import Timer
+from modules.memstats import memory_stats
+from modules.paths import models_path, script_path
+from modules.modeldata import model_data
+
+
+transformers_logging.set_verbosity_error()
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
+checkpoints_list = {}
+checkpoint_aliases = {}
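+# in-memory cache of loaded checkpoint state dicts; an OrderedDict so get_checkpoint_state_dict can evict least-recently-used entries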
+checkpoints_loaded = collections.OrderedDict()
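+# safetensors metadata cache, loaded lazily from metadata.json and flushed to disk by write_metadata()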
+sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
+sd_metadata = None
+sd_metadata_pending = 0
+sd_metadata_timer = 0
+
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.name = None
+ self.hash = None
+ self.filename = filename
+ self.type = ''
+ relname = filename
+ app_path = os.path.abspath(script_path)
+
+ def rel(fn, path):
+ try:
+ return os.path.relpath(fn, path)
+ except Exception:
+ return fn
+
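+        # derive a display name relative to whichever known model or script directory contains the file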
+ if relname.startswith('..'):
+ relname = os.path.abspath(relname)
+ if relname.startswith(shared.opts.ckpt_dir):
+ relname = rel(filename, shared.opts.ckpt_dir)
+ elif relname.startswith(shared.opts.diffusers_dir):
+ relname = rel(filename, shared.opts.diffusers_dir)
+ elif relname.startswith(model_path):
+ relname = rel(filename, model_path)
+ elif relname.startswith(script_path):
+ relname = rel(filename, script_path)
+ elif relname.startswith(app_path):
+ relname = rel(filename, app_path)
+ else:
+ relname = os.path.abspath(relname)
+ relname, ext = os.path.splitext(relname)
+ ext = ext.lower()[1:]
+
+ if os.path.isfile(filename): # ckpt or safetensor
+ self.name = relname
+ self.filename = filename
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
+ self.type = ext
+ # self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ else: # maybe a diffuser
+ repo = [r for r in modelloader.diffuser_repos if filename == r['name']]
+ if len(repo) == 0:
+ self.name = relname
+ self.filename = filename
+ self.sha256 = None
+ self.type = 'unknown'
+ else:
+ self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
+ self.filename = repo[0]['path']
+ self.sha256 = repo[0]['hash']
+ self.type = 'diffusers'
+
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+ self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
+ self.path = self.filename
+ self.model_name = os.path.basename(self.name)
+ self.metadata = read_metadata_from_safetensors(filename)
+ # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for i in [self.name, self.filename, self.shorthash, self.title]:
+ if i is not None:
+ checkpoint_aliases[i] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
+ if self.sha256 is None:
+ return None
+ self.shorthash = self.sha256[0:10]
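+        # re-register under the new title that now includes the short hash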
+ checkpoints_list.pop(self.title)
+ self.title = f'{self.name} [{self.shorthash}]'
+ self.register()
+ return self.shorthash
+
+
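+# no-op watermark stub; apply_watermark returns the image unchanged to disable pipeline watermarking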
+class NoWatermark:
+ def apply_watermark(self, img):
+ return img
+
+
+def setup_model():
+ if not os.path.exists(model_path):
+ os.makedirs(model_path, exist_ok=True)
+ list_models()
+ if shared.backend == shared.Backend.ORIGINAL:
+ enable_midas_autodownload()
+
+
+def checkpoint_tiles(use_short=False): # pylint: disable=unused-argument
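+    # natural sort: digit runs compare numerically, so "model-10" sorts after "model-2"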
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+
+
+def list_models():
+ t0 = time.time()
+ global checkpoints_list # pylint: disable=global-statement
+ checkpoints_list.clear()
+ checkpoint_aliases.clear()
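+    # diffusers backend scans only safetensors; the original backend also accepts ckpt unless disabled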
+ if shared.opts.sd_disable_ckpt or shared.backend == shared.Backend.DIFFUSERS:
+ ext_filter = [".safetensors"]
+ else:
+ ext_filter = [".ckpt", ".safetensors"]
+ model_list = modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
+ if shared.backend == shared.Backend.DIFFUSERS:
+ model_list += modelloader.load_diffusers_models(model_path=os.path.join(models_path, 'Diffusers'), command_path=shared.opts.diffusers_dir, clear=True)
+ for filename in sorted(model_list, key=str.lower):
+ checkpoint_info = CheckpointInfo(filename)
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ if shared.cmd_opts.ckpt is not None:
+ if not os.path.exists(shared.cmd_opts.ckpt) and shared.backend == shared.Backend.ORIGINAL:
+ if shared.cmd_opts.ckpt.lower() != "none":
+ shared.log.warning(f"Requested checkpoint not found: {shared.cmd_opts.ckpt}")
+ else:
+ checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
+ elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
+ shared.log.warning(f"Checkpoint not found: {shared.cmd_opts.ckpt}")
+ shared.log.info(f'Available models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}')
+
+ checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
+ """
+ if len(checkpoints_list) == 0:
+ if not shared.cmd_opts.no_download:
+ key = input('Download the default model? (y/N) ')
+ if key.lower().startswith('y'):
+ if shared.backend == shared.Backend.ORIGINAL:
+ model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
+ shared.opts.data['sd_model_checkpoint'] = "v1-5-pruned-emaonly.safetensors"
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
+ else:
+ default_model_id = "runwayml/stable-diffusion-v1-5"
+ modelloader.download_diffusers_model(default_model_id, shared.opts.diffusers_dir)
+ model_list = modelloader.load_diffusers_models(model_path=os.path.join(models_path, 'Diffusers'), command_path=shared.opts.diffusers_dir)
+
+ for filename in sorted(model_list, key=str.lower):
+ checkpoint_info = CheckpointInfo(filename)
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ """
+
+def update_model_hashes():
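+    # fill in missing legacy short hashes and sha256-based hashes, returning a text summary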
+ txt = []
+ lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
+ # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
+ for ckpt in lst:
+ ckpt.hash = model_hash(ckpt.filename)
+ # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}')
+ # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models')
+ lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
+ shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
+ for ckpt in lst:
+ ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
+ ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
+ if ckpt.sha256 is not None:
+ txt.append(f'Calculated full hash: {ckpt.title} {ckpt.shorthash}')
+ else:
+ txt.append(f'Skipped hash calculation: {ckpt.title} ')
+ txt.append(f'Updated hashes for {len(lst)} out of {len(checkpoints_list)} models')
+ txt = ' '.join(txt)
+ return txt
+
+
+def get_closet_checkpoint_match(search_string):
+ checkpoint_info = checkpoint_aliases.get(search_string, None)
+ if checkpoint_info is not None:
+ return checkpoint_info
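+    # no exact alias match: prefer the shortest title containing the search string, then retry with just its first token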
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+ found = sorted([info for info in checkpoints_list.values() if search_string.split(' ')[0] in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+ return None
+
+
+def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+ try:
+ with open(filename, "rb") as file:
+ import hashlib
+ # t0 = time.time()
+ m = hashlib.sha256()
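+            # hash 64KB starting at the 1MB offset; fast but collision-prone (see docstring)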
+ file.seek(0x100000)
+ m.update(file.read(0x10000))
+ shorthash = m.hexdigest()[0:8]
+ # t1 = time.time()
+ # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
+ return shorthash
+ except FileNotFoundError:
+ return 'NOFILE'
+ except Exception:
+ return 'NOHASH'
+
+
+def select_checkpoint(op='model'):
+ if op == 'dict':
+ model_checkpoint = shared.opts.sd_model_dict
+ elif op == 'refiner':
+ model_checkpoint = shared.opts.data.get('sd_model_refiner', None)
+ else:
+ model_checkpoint = shared.opts.sd_model_checkpoint
+ if model_checkpoint is None or model_checkpoint == 'None':
+ return None
+ checkpoint_info = get_closet_checkpoint_match(model_checkpoint)
+ if checkpoint_info is not None:
+ shared.log.info(f'Select: {op}="{checkpoint_info.title if checkpoint_info is not None else None}"')
+ return checkpoint_info
+ if len(checkpoints_list) == 0 and not shared.cmd_opts.no_download:
+ shared.log.warning("Cannot generate without a checkpoint")
+ shared.log.info("Set system paths to use existing folders in a different location")
+ shared.log.info("Or use --ckpt to force using existing checkpoint")
+ return None
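+    # fall back to the first registered checkpoint when the configured one cannot be resolved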
+ checkpoint_info = next(iter(checkpoints_list.values()))
+ if model_checkpoint is not None:
+ if model_checkpoint != 'model.ckpt' and model_checkpoint != 'runwayml/stable-diffusion-v1-5':
+ shared.log.warning(f"Selected checkpoint not found: {model_checkpoint}")
+ else:
+ shared.log.info("Selecting first available checkpoint")
+ # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
+ shared.log.info(f'Select: {op}="{checkpoint_info.title if checkpoint_info is not None else None}"')
+ return checkpoint_info
+
+
+checkpoint_dict_replacements = {
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+}
+
+
+def transform_checkpoint_dict_key(k):
+ for text, replacement in checkpoint_dict_replacements.items():
+ if k.startswith(text):
+ k = replacement + k[len(text):]
+ return k
+
+
+def get_state_dict_from_checkpoint(pl_sd):
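+    # unwrap a nested "state_dict" if present, then remap legacy CLIP text-encoder keys in place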
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
+ pl_sd.pop("state_dict", None)
+ sd = {}
+ for k, v in pl_sd.items():
+ new_key = transform_checkpoint_dict_key(k)
+ if new_key is not None:
+ sd[new_key] = v
+ pl_sd.clear()
+ pl_sd.update(sd)
+ return pl_sd
+
+
+def write_metadata():
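+    # flush the in-memory safetensors metadata cache to disk when there are pending changes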
+ global sd_metadata_pending # pylint: disable=global-statement
+ if sd_metadata_pending == 0:
+ shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
+ return
+ shared.writefile(sd_metadata, sd_metadata_file)
+ shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
+ sd_metadata_pending = 0
+
+
+def scrub_dict(dict_obj, keys):
+    if not isinstance(dict_obj, dict):
+        return
+    for key in list(dict_obj.keys()):
+        if key in keys:
+            dict_obj.pop(key, None)
+        elif isinstance(dict_obj[key], dict):
+            scrub_dict(dict_obj[key], keys)
+        elif isinstance(dict_obj[key], list):
+            for item in dict_obj[key]:
+                scrub_dict(item, keys)
+
+
+def read_metadata_from_safetensors(filename):
+ global sd_metadata # pylint: disable=global-statement
+ if sd_metadata is None:
+ if not os.path.isfile(sd_metadata_file):
+ sd_metadata = {}
+ else:
+ sd_metadata = shared.readfile(sd_metadata_file, lock=True)
+ res = sd_metadata.get(filename, None)
+ if res is not None:
+ return res
+ if not filename.endswith(".safetensors"):
+ return {}
+ if shared.cmd_opts.no_metadata:
+ return {}
+ res = {}
+ try:
+ t0 = time.time()
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+ if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
+ shared.log.error(f"Not a valid safetensors file: {filename}")
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+ for k, v in json_obj.get("__metadata__", {}).items():
+ if v.startswith("data:"):
+ v = 'data'
+ if k == 'format' and v == 'pt':
+ continue
+                large = len(v) > 2048
+                if large and k in ('ss_datasets', 'workflow', 'prompt', 'ss_bucket_info'):
+                    continue
+ if v[0:1] == '{':
+ try:
+ v = json.loads(v)
+ if large and k == 'ss_tag_frequency':
+ v = { i: len(j) for i, j in v.items() }
+ if large and k == 'sd_merge_models':
+ scrub_dict(v, ['sd_merge_recipe'])
+ except Exception:
+ pass
+ res[k] = v
+ sd_metadata[filename] = res
+ global sd_metadata_pending # pylint: disable=global-statement
+ sd_metadata_pending += 1
+ t1 = time.time()
+ global sd_metadata_timer # pylint: disable=global-statement
+ sd_metadata_timer += (t1 - t0)
+ except Exception as e:
+ shared.log.error(f"Error reading metadata from: {filename} {e}")
+ return res
+
+
+def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unused-argument
+ if not os.path.isfile(checkpoint_file):
+ shared.log.error(f"Model is not a file: {checkpoint_file}")
+ return None
+ try:
+ pl_sd = None
+ with progress.open(checkpoint_file, 'rb', description=f'[cyan]Loading model: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
+ _, extension = os.path.splitext(checkpoint_file)
+ if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
+ shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
+ return None
+ if shared.opts.stream_load:
+ if extension.lower() == ".safetensors":
+ # shared.log.debug('Model weights loading: type=safetensors mode=buffered')
+ buffer = f.read()
+ pl_sd = safetensors.torch.load(buffer)
+ else:
+ # shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
+ buffer = io.BytesIO(f.read())
+ pl_sd = torch.load(buffer, map_location='cpu')
+ else:
+ if extension.lower() == ".safetensors":
+ # shared.log.debug('Model weights loading: type=safetensors mode=mmap')
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
+ else:
+ # shared.log.debug('Model weights loading: type=checkpoint mode=direct')
+ pl_sd = torch.load(f, map_location='cpu')
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ del pl_sd
+ except Exception as e:
+ errors.display(e, f'Load model: {checkpoint_file}')
+ sd = None
+ return sd
+
+
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
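+    # return a cached state dict when available (LRU), otherwise read it from disk and optionally cache it (original backend only)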
+ if not os.path.isfile(checkpoint_info.filename):
+ return None
+ if checkpoint_info in checkpoints_loaded:
+ shared.log.info("Model weights loading: from cache")
+ checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
+ return checkpoints_loaded[checkpoint_info]
+ res = read_state_dict(checkpoint_info.filename)
+ if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = res
+ # clean up cache if limit is reached
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False)
+ timer.record("load")
+ return res
+
+
+def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer):
+ _pipeline, _model_type = detect_pipeline(checkpoint_info.path, 'model')
+ shared.log.debug(f'Model weights loading: {memory_stats()}')
+ timer.record("hash")
+ if model_data.sd_dict == 'None':
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if state_dict is None:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+ try:
+ model.load_state_dict(state_dict, strict=False)
+ except Exception as e:
+ shared.log.error(f'Error loading model weights: {checkpoint_info.filename}')
+ shared.log.error(' '.join(str(e).splitlines()[:2]))
+ return False
+ del state_dict
+ timer.record("apply")
+ if shared.opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
+ timer.record("channels")
+ if not shared.opts.no_half:
+ vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.opts.no_half_vae:
+ model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.opts.upcast_sampling and depth_model:
+ model.depth_model = None
+ model.half()
+ model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
+ if shared.opts.cuda_cast_unet:
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ else:
+ model.model.diffusion_model.to(devices.dtype_unet)
+ model.first_stage_model.to(devices.dtype_vae)
+ model.sd_model_hash = checkpoint_info.calculate_shorthash()
+ model.sd_model_checkpoint = checkpoint_info.filename
+ model.sd_checkpoint_info = checkpoint_info
+ model.is_sdxl = False # a1111 compatibility item
+ model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
+ model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item
+ model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
+ sd_vae.delete_base_vae()
+ sd_vae.clear_loaded_vae()
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
+ timer.record("vae")
+ return True
+
+
+def enable_midas_autodownload():
+ """
+    Wraps ldm.modules.midas.api.load_model so MiDaS depth models are downloaded automatically.
+
+    When the 512-depth-ema model (or other models like it) is loaded, it calls
+    midas.api.load_model to load the associated MiDaS depth model. This function
+    wraps that call so the model is downloaded to the correct location first.
+ """
+ import ldm.modules.midas.api
+ midas_path = os.path.join(paths.models_path, 'midas')
+ for k, v in ldm.modules.midas.api.ISL_PATHS.items():
+ file_name = os.path.basename(v)
+ ldm.modules.midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+ midas_urls = {
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+ }
+ ldm.modules.midas.api.load_model_inner = ldm.modules.midas.api.load_model
+
+ def load_model_wrapper(model_type):
+ path = ldm.modules.midas.api.ISL_PATHS[model_type]
+ if not os.path.exists(path):
+ if not os.path.exists(midas_path):
+ mkdir(midas_path)
+ shared.log.info(f"Downloading midas model weights for {model_type} to {path}")
+ request.urlretrieve(midas_urls[model_type], path)
+ shared.log.info(f"{model_type} downloaded")
+ return ldm.modules.midas.api.load_model_inner(model_type)
+
+ ldm.modules.midas.api.load_model = load_model_wrapper
+
+
+def repair_config(sd_config):
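+    # patch the loaded OmegaConf model config: default use_ema off, fp16 per user options, and fall back from xformers attention when xformers is unavailable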
+ if "use_ema" not in sd_config.model.params:
+ sd_config.model.params.use_ema = False
+ if shared.opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = sys.platform != 'darwin'
+ if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+ sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+ # For UnCLIP-L, override the hardcoded karlo directory
+ if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
+ karlo_path = os.path.join(paths.models_path, 'karlo')
+ sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+
+
+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
+
+def change_backend():
+ shared.log.info(f'Backend changed: from={shared.backend} to={shared.opts.sd_backend}')
+ shared.log.warning('Full server restart required to apply all changes')
+ unload_model_weights()
+ shared.backend = shared.Backend.ORIGINAL if shared.opts.sd_backend == 'original' else shared.Backend.DIFFUSERS
+ checkpoints_loaded.clear()
+ from modules.sd_samplers import list_samplers
+ list_samplers(shared.backend)
+ list_models()
+ from modules.sd_vae import refresh_vae_list
+ refresh_vae_list()
+
+
+def detect_pipeline(f: str, op: str = 'model', warning=True):
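+    # guess the diffusers pipeline for a .safetensors file, mainly from its file size, unless a specific pipeline was selected in options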
+ if not f.endswith('.safetensors'):
+ return None, None
+    guess = shared.opts.diffusers_pipeline
+    warn = shared.log.warning if warning else lambda *args, **kwargs: None
+    size = 0
+    pipeline = None # pre-set so the fallback below is safe even if detection raises
+ if guess == 'Autodetect':
+ try:
+ # guess by size
+ size = round(os.path.getsize(f) / 1024 / 1024)
+ if size < 128:
+ warn(f'Model size smaller than expected: {f} size={size} MB')
+ elif (size >= 316 and size <= 324) or (size >= 156 and size <= 164): # 320 or 160
+ warn(f'Model detected as VAE model, but attempting to load as model: {op}={f} size={size} MB')
+ guess = 'VAE'
+ elif size >= 5351 and size <= 5359: # 5353
+ guess = 'Stable Diffusion' # SD v2
+ elif size >= 5791 and size <= 5799: # 5795
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as SD-XL refiner model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ if op == 'model':
+ warn(f'Model detected as SD-XL refiner model, but attempting to load a base model: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL'
+ elif (size >= 6611 and size <= 6619) or (size >= 6771 and size <= 6779): # 6617, HassakuXL is 6776
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as SD-XL base model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL'
+ elif size >= 3361 and size <= 3369: # 3368
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as SD upscale model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion Upscale'
+ elif size >= 4891 and size <= 4899: # 4897
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as SD XL inpaint model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL Inpaint'
+ elif size >= 9791 and size <= 9799: # 9794
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as SD XL instruct pix2pix model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL Instruct'
+ elif size > 3138 and size < 3142: #3140
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as Segmind Vega model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL'
+ else:
+ guess = 'Stable Diffusion'
+ # guess by name
+ """
+ if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper():
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Latent Consistency Model'
+ """
+ if 'PixArt' in f:
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as PixArt Alpha model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'PixArt Alpha'
+ # switch for specific variant
+ if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
+ guess = 'Stable Diffusion Inpaint'
+ elif guess == 'Stable Diffusion' and 'instruct' in f.lower():
+ guess = 'Stable Diffusion Instruct'
+ if guess == 'Stable Diffusion XL' and 'inpaint' in f.lower():
+ guess = 'Stable Diffusion XL Inpaint'
+ elif guess == 'Stable Diffusion XL' and 'instruct' in f.lower():
+ guess = 'Stable Diffusion XL Instruct'
+ # get actual pipeline
+ pipeline = shared_items.get_pipelines().get(guess, None)
+ shared.log.info(f'Autodetect: {op}="{guess}" class={pipeline.__name__} file="{f}" size={size}MB')
+ except Exception as e:
+ shared.log.error(f'Error detecting diffusers pipeline: model={f} {e}')
+ return None, None
+ else:
+ try:
+ size = round(os.path.getsize(f) / 1024 / 1024)
+ pipeline = shared_items.get_pipelines().get(guess, None)
+ shared.log.info(f'Diffusers: {op}="{guess}" class={pipeline.__name__} file="{f}" size={size}MB')
+ except Exception as e:
+ shared.log.error(f'Error loading diffusers pipeline: model={f} {e}')
+
+ if pipeline is None:
+ shared.log.warning(f'Autodetect: pipeline not recognized: {guess}: {op}={f} size={size}')
+ pipeline = diffusers.StableDiffusionPipeline
+ return pipeline, guess
+
+
+def copy_diffuser_options(new_pipe, orig_pipe):
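+    # carry custom bookkeeping attributes (checkpoint info, hash, embeddings, compatibility flags) over to a newly built pipeline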
+ new_pipe.sd_checkpoint_info = orig_pipe.sd_checkpoint_info
+ new_pipe.sd_model_checkpoint = orig_pipe.sd_model_checkpoint
+ new_pipe.embedding_db = getattr(orig_pipe, 'embedding_db', None)
+ new_pipe.sd_model_hash = getattr(orig_pipe, 'sd_model_hash', None)
+ new_pipe.has_accelerate = getattr(orig_pipe, 'has_accelerate', False)
+ new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item
+ new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False)
+ new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True)
+
+
+def set_diffuser_options(sd_model, vae = None, op: str = 'model'):
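+    # apply user memory/performance options (CPU offload, VAE slicing/tiling, attention slicing, quantization, channels-last) to a diffusers pipeline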
+ if sd_model is None:
+ shared.log.warning(f'{op} is not loaded')
+ return
+ if (shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) and (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram):
+ shared.log.warning(f'Setting {op}: Model CPU offload and Sequential CPU offload are not compatible')
+ shared.log.debug(f'Setting {op}: disabling model CPU offload')
+ shared.opts.diffusers_model_cpu_offload=False
+ shared.cmd_opts.medvram=False
+
+ if hasattr(sd_model, "watermark"):
+ sd_model.watermark = NoWatermark()
+ sd_model.has_accelerate = False
+ if hasattr(sd_model, "vae"):
+ if vae is not None:
+ sd_model.vae = vae
+ if shared.opts.diffusers_vae_upcast != 'default':
+ if shared.opts.diffusers_vae_upcast == 'true':
+ sd_model.vae.config.force_upcast = True
+ else:
+ sd_model.vae.config.force_upcast = False
+ if shared.opts.no_half_vae:
+ devices.dtype_vae = torch.float32
+ sd_model.vae.to(devices.dtype_vae)
+ shared.log.debug(f'Setting {op} VAE: name={sd_vae.loaded_vae_file} upcast={sd_model.vae.config.get("force_upcast", None)}')
+ if hasattr(sd_model, "enable_model_cpu_offload"):
+ if (shared.cmd_opts.medvram and devices.backend != "directml") or shared.opts.diffusers_model_cpu_offload:
+ shared.log.debug(f'Setting {op}: enable model CPU offload')
+ if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
+ shared.opts.diffusers_move_base = False
+ shared.opts.diffusers_move_unet = False
+ shared.opts.diffusers_move_refiner = False
+ shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled')
+ sd_model.enable_model_cpu_offload()
+ sd_model.has_accelerate = True
+ if hasattr(sd_model, "enable_sequential_cpu_offload"):
+ if shared.cmd_opts.lowvram or shared.opts.diffusers_seq_cpu_offload:
+ shared.log.debug(f'Setting {op}: enable sequential CPU offload')
+ if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
+ shared.opts.diffusers_move_base = False
+ shared.opts.diffusers_move_unet = False
+ shared.opts.diffusers_move_refiner = False
+ shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled')
+ sd_model.enable_sequential_cpu_offload(device=devices.device)
+ sd_model.has_accelerate = True
+ if hasattr(sd_model, "enable_vae_slicing"):
+ if shared.cmd_opts.lowvram or shared.opts.diffusers_vae_slicing:
+ shared.log.debug(f'Setting {op}: enable VAE slicing')
+ sd_model.enable_vae_slicing()
+ else:
+ sd_model.disable_vae_slicing()
+ if hasattr(sd_model, "enable_vae_tiling"):
+ if shared.cmd_opts.lowvram or shared.opts.diffusers_vae_tiling:
+ shared.log.debug(f'Setting {op}: enable VAE tiling')
+ sd_model.enable_vae_tiling()
+ else:
+ sd_model.disable_vae_tiling()
+ if hasattr(sd_model, "enable_attention_slicing"):
+ if shared.cmd_opts.lowvram or shared.opts.diffusers_attention_slicing:
+ shared.log.debug(f'Setting {op}: enable attention slicing')
+ sd_model.enable_attention_slicing()
+ else:
+ sd_model.disable_attention_slicing()
+ if hasattr(sd_model, "vqvae"):
+ sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16
+ if shared.opts.cross_attention_optimization == "xFormers" and hasattr(sd_model, 'enable_xformers_memory_efficient_attention'):
+ sd_model.enable_xformers_memory_efficient_attention()
+ if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'):
+ shared.log.debug(f'Setting {op}: enable fused projections')
+ sd_model.fuse_qkv_projections()
+ if shared.opts.diffusers_eval:
+ if hasattr(sd_model, "unet") and hasattr(sd_model.unet, "requires_grad_"):
+ sd_model.unet.requires_grad_(False)
+ sd_model.unet.eval()
+ if hasattr(sd_model, "vae") and hasattr(sd_model.vae, "requires_grad_"):
+ sd_model.vae.requires_grad_(False)
+ sd_model.vae.eval()
+ if hasattr(sd_model, "text_encoder") and hasattr(sd_model.text_encoder, "requires_grad_"):
+ sd_model.text_encoder.requires_grad_(False)
+ sd_model.text_encoder.eval()
+ if shared.opts.diffusers_quantization:
+ sd_model = sd_models_compile.dynamic_quantization(sd_model)
+
+ if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'):
+ shared.log.debug(f'Setting {op}: enable channels last')
+ sd_model.unet.to(memory_format=torch.channels_last)
+
+
+def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): # pylint: disable=unused-argument
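+    # load path for the diffusers backend: handles HF hub references, local diffusers folders and single-file safetensors checkpoints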
+ import torch # pylint: disable=reimported,redefined-outer-name
+ if shared.cmd_opts.profile:
+ import cProfile
+ pr = cProfile.Profile()
+ pr.enable()
+ if timer is None:
+ timer = Timer()
+ logging.getLogger("diffusers").setLevel(logging.ERROR)
+ timer.record("diffusers")
+ devices.set_cuda_params()
+ diffusers_load_config = {
+ "low_cpu_mem_usage": True,
+ "torch_dtype": devices.dtype,
+ "safety_checker": None,
+ "requires_safety_checker": False,
+ "load_safety_checker": False,
+ "load_connected_pipeline": True,
+ # TODO: use_safetensors cant enable for all checkpoints just yet
+ }
+ if shared.opts.diffusers_model_load_variant == 'default':
+ if devices.dtype == torch.float16:
+ diffusers_load_config['variant'] = 'fp16'
+ elif shared.opts.diffusers_model_load_variant == 'fp32':
+ pass
+ else:
+ diffusers_load_config['variant'] = shared.opts.diffusers_model_load_variant
+
+ if shared.opts.diffusers_pipeline == 'Custom Diffusers Pipeline' and len(shared.opts.custom_diffusers_pipeline) > 0:
+ shared.log.debug(f'Diffusers custom pipeline: {shared.opts.custom_diffusers_pipeline}')
+ diffusers_load_config['custom_pipeline'] = shared.opts.custom_diffusers_pipeline
+
+ # if 'LCM' in checkpoint_info.path:
+ # diffusers_load_config['custom_pipeline'] = 'latent_consistency_txt2img'
+
+ if shared.opts.data.get('sd_model_checkpoint', '') == 'model.ckpt' or shared.opts.data.get('sd_model_checkpoint', '') == '':
+ shared.opts.data['sd_model_checkpoint'] = "runwayml/stable-diffusion-v1-5"
+
+ if op == 'model' or op == 'dict':
+ if (model_data.sd_model is not None) and (checkpoint_info is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
+ return
+ else:
+ if (model_data.sd_refiner is not None) and (checkpoint_info is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
+ return
+
+ sd_model = None
+
+ try:
+ if shared.cmd_opts.ckpt is not None and os.path.isdir(shared.cmd_opts.ckpt) and model_data.initial: # initial load
+ ckpt_basename = os.path.basename(shared.cmd_opts.ckpt)
+ model_name = modelloader.find_diffuser(ckpt_basename)
+ if model_name is not None:
+ shared.log.info(f'Load model {op}: {model_name}')
+ model_file = modelloader.download_diffusers_model(hub_id=model_name)
+ try:
+ shared.log.debug(f'Model load {op} config: {diffusers_load_config}')
+ sd_model = diffusers.DiffusionPipeline.from_pretrained(model_file, **diffusers_load_config)
+ except Exception as e:
+ shared.log.error(f'Failed loading model: {model_file} {e}')
+ list_models() # rescan for downloaded model
+ checkpoint_info = CheckpointInfo(model_name)
+
+ checkpoint_info = checkpoint_info or select_checkpoint(op=op)
+ if checkpoint_info is None:
+ unload_model_weights(op=op)
+ return
+
+ vae = None
+ sd_vae.loaded_vae_file = None
+ if op == 'model' or op == 'refiner':
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ vae = sd_vae.load_vae_diffusers(checkpoint_info.path, vae_file, vae_source)
+ if vae is not None:
+ diffusers_load_config["vae"] = vae
+
+ shared.log.debug(f'Diffusers loading: path="{checkpoint_info.path}"')
+ if os.path.isdir(checkpoint_info.path):
+ err1 = None
+ err2 = None
+ err3 = None
+ try: # try autopipeline first, best choice but not all pipelines are available
+ sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except Exception as e:
+ err1 = e
+ # shared.log.error(f'AutoPipeline: {e}')
+ try: # try diffusion pipeline next second-best choice, works for most non-linked pipelines
+ if err1 is not None:
+ sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except Exception as e:
+ err2 = e
+ # shared.log.error(f'DiffusionPipeline: {e}')
+ try: # try basic pipeline next just in case
+ if err2 is not None:
+ sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except Exception as e:
+ err3 = e # ignore last error
+ shared.log.error(f'StableDiffusionPipeline: {e}')
+ if err3 is not None:
+ shared.log.error(f'Failed loading {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
+ return
+ elif os.path.isfile(checkpoint_info.path) and checkpoint_info.path.lower().endswith('.safetensors'):
+ # diffusers_load_config["local_files_only"] = True
+ diffusers_load_config["extract_ema"] = shared.opts.diffusers_extract_ema
+ pipeline, model_type = detect_pipeline(checkpoint_info.path, op)
+ if pipeline is None:
+ shared.log.error(f'Diffusers {op} pipeline not initialized: {shared.opts.diffusers_pipeline}')
+ return
+ try:
+ if model_type.startswith('Stable Diffusion'):
+                    diffusers_load_config['force_zeros_for_empty_prompt'] = shared.opts.diffusers_force_zeros
+ diffusers_load_config['requires_aesthetics_score'] = shared.opts.diffusers_aesthetics_score
+ if 'inpainting' in checkpoint_info.path.lower():
+ diffusers_load_config['config_files'] = {
+ 'v1': 'configs/v1-inpainting-inference.yaml',
+ 'v2': 'configs/v2-inference-768-v.yaml',
+ 'xl': 'configs/sd_xl_base.yaml',
+ 'xl_refiner': 'configs/sd_xl_refiner.yaml',
+ }
+ else:
+ diffusers_load_config['config_files'] = {
+ 'v1': 'configs/v1-inference.yaml',
+ 'v2': 'configs/v2-inference-768-v.yaml',
+ 'xl': 'configs/sd_xl_base.yaml',
+ 'xl_refiner': 'configs/sd_xl_refiner.yaml',
+ }
+ if hasattr(pipeline, 'from_single_file'):
+ diffusers_load_config['use_safetensors'] = True
+ sd_model = pipeline.from_single_file(checkpoint_info.path, **diffusers_load_config)
+ if sd_model is not None and hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpainting' in checkpoint_info.path.lower():
+ shared.log.debug('Model patch: type=inpaint')
+ sd_model.unet.config.in_channels = 9
+ elif hasattr(pipeline, 'from_ckpt'):
+ sd_model = pipeline.from_ckpt(checkpoint_info.path, **diffusers_load_config)
+ else:
+ shared.log.error(f'Diffusers {op} cannot load safetensor model: {checkpoint_info.path} {shared.opts.diffusers_pipeline}')
+ return
+ if sd_model is not None:
+ diffusers_load_config.pop('vae', None)
+ diffusers_load_config.pop('safety_checker', None)
+ diffusers_load_config.pop('requires_safety_checker', None)
+ diffusers_load_config.pop('load_safety_checker', None)
+ diffusers_load_config.pop('config_files', None)
+ diffusers_load_config.pop('local_files_only', None)
+ shared.log.debug(f'Setting {op}: pipeline={sd_model.__class__.__name__} config={diffusers_load_config}') # pylint: disable=protected-access
+ except Exception as e:
+ shared.log.error(f'Diffusers failed loading: {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__} {e}')
+ errors.display(e, f'loading {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__}')
+ return
+ else:
+ shared.log.error(f'Diffusers cannot load: {op}={checkpoint_info.path}')
+ return
+
+ if "StableDiffusion" in sd_model.__class__.__name__:
+ pass # scheduler is created on first use
+ elif "Kandinsky" in sd_model.__class__.__name__:
+ sd_model.scheduler.name = 'DDIM'
+
+ set_diffuser_options(sd_model, vae, op)
+
+ base_sent_to_cpu=False
+ if (shared.opts.cuda_compile and shared.opts.cuda_compile_backend != 'none') or shared.opts.ipex_optimize:
+ if op == 'refiner' and not getattr(sd_model, 'has_accelerate', False):
+ gpu_vram = memory_stats().get('gpu', {})
+ free_vram = gpu_vram.get('total', 0) - gpu_vram.get('used', 0)
+                refiner_enough_vram = free_vram >= (7 if "StableDiffusionXL" in sd_model.__class__.__name__ else 3) # ~7 GB free for SD-XL, ~3 GB otherwise
+ if not shared.opts.diffusers_move_base and refiner_enough_vram:
+ sd_model.to(devices.device)
+ base_sent_to_cpu=False
+ else:
+ if not refiner_enough_vram and not (shared.opts.diffusers_move_base and shared.opts.diffusers_move_refiner):
+ shared.log.warning(f"Insufficient GPU memory, using system memory as fallback: free={free_vram} GB")
+                        if not shared.opts.diffusers_seq_cpu_offload and not shared.opts.diffusers_model_cpu_offload:
+ shared.log.debug('Enabled moving base model to CPU')
+ shared.log.debug('Enabled moving refiner model to CPU')
+ shared.opts.diffusers_move_base=True
+ shared.opts.diffusers_move_refiner=True
+ shared.log.debug('Moving base model to CPU')
+ if model_data.sd_model is not None:
+ model_data.sd_model.to(devices.cpu)
+ devices.torch_gc(force=True)
+ sd_model.to(devices.device)
+ base_sent_to_cpu=True
+ elif not getattr(sd_model, 'has_accelerate', False):
+ sd_model.to(devices.device)
+
+ sd_models_compile.compile_diffusers(sd_model)
+
+ if sd_model is None:
+ shared.log.error('Diffuser model not loaded')
+ return
+ sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init
+ sd_model.sd_checkpoint_info = checkpoint_info # pylint: disable=attribute-defined-outside-init
+ sd_model.sd_model_checkpoint = checkpoint_info.filename # pylint: disable=attribute-defined-outside-init
+ sd_model.is_sdxl = False # a1111 compatibility item
+ sd_model.is_sd2 = hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model') # a1111 compatibility item
+ sd_model.is_sd1 = not sd_model.is_sd2 # a1111 compatibility item
+ sd_model.logvar = sd_model.logvar.to(devices.device) if hasattr(sd_model, 'logvar') else None # fix for training
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
+ if hasattr(sd_model, "set_progress_bar_config"):
+ sd_model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining}', ncols=80, colour='#327fba')
+ if op == 'refiner' and shared.opts.diffusers_move_refiner and not getattr(sd_model, 'has_accelerate', False):
+ shared.log.debug('Moving refiner model to CPU')
+ sd_model.to(devices.cpu)
+ elif not getattr(sd_model, 'has_accelerate', False): # In offload modes, accelerate will move models around
+ sd_model.to(devices.device)
+ if op == 'refiner' and base_sent_to_cpu:
+ shared.log.debug('Moving base model back to GPU')
+ model_data.sd_model.to(devices.device)
+ except Exception as e:
+ shared.log.error("Failed to load diffusers model")
+ errors.display(e, "loading Diffusers model")
+
+ if sd_model is not None:
+ from modules.textual_inversion import textual_inversion
+ sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
+ if op == 'refiner':
+ model_data.sd_refiner = sd_model
+ else:
+ model_data.sd_model = sd_model
+ sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
+ sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ timer.record("load")
+ devices.torch_gc(force=True)
+ if shared.cmd_opts.profile:
+ errors.profile(pr, 'Load')
+ script_callbacks.model_loaded_callback(sd_model)
+ shared.log.info(f"Load {op}: time={timer.summary()} native={get_native(sd_model)} {memory_stats()}")
+
+
+class DiffusersTaskType(Enum):
+ TEXT_2_IMAGE = 1
+ IMAGE_2_IMAGE = 2
+ INPAINTING = 3
+ INSTRUCT = 4
+
+
+def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType:
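+    # map a pipeline class to its task type using diffusers' auto-pipeline registries;
+    # e.g. an img2img pipeline such as StableDiffusionImg2ImgPipeline should map to DiffusersTaskType.IMAGE_2_IMAGE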
+ if pipe.__class__.__name__ == "StableDiffusionXLInstructPix2PixPipeline":
+ return DiffusersTaskType.INSTRUCT
+ elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.values():
+ return DiffusersTaskType.IMAGE_2_IMAGE
+ elif pipe.__class__ in diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING.values():
+ return DiffusersTaskType.INPAINTING
+ else:
+ return DiffusersTaskType.TEXT_2_IMAGE
+
+
+def switch_diffuser_pipe(pipeline, cls):
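+    # rebuild the pipeline as class cls while reusing the already-loaded components; only SD and SD-XL source pipelines are handled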
+ try:
+ new_pipe = None
+ if isinstance(pipeline, cls):
+ return pipeline
+ elif isinstance(pipeline, diffusers.StableDiffusionXLPipeline):
+ new_pipe = cls(
+ vae=pipeline.vae,
+ text_encoder=pipeline.text_encoder,
+ text_encoder_2=pipeline.text_encoder_2,
+ tokenizer=pipeline.tokenizer,
+ tokenizer_2=pipeline.tokenizer_2,
+ unet=pipeline.unet,
+ scheduler=pipeline.scheduler,
+ feature_extractor=getattr(pipeline, 'feature_extractor', None),
+ ).to(pipeline.device)
+ elif isinstance(pipeline, diffusers.StableDiffusionPipeline):
+ new_pipe = cls(
+ vae=pipeline.vae,
+ text_encoder=pipeline.text_encoder,
+ tokenizer=pipeline.tokenizer,
+ unet=pipeline.unet,
+ scheduler=pipeline.scheduler,
+ feature_extractor=getattr(pipeline, 'feature_extractor', None),
+ requires_safety_checker=False,
+ safety_checker=None,
+ ).to(pipeline.device)
+ else:
+ shared.log.error(f'Pipeline switch error: {pipeline.__class__.__name__} unrecognized')
+ return pipeline
+ if new_pipe is not None:
+ copy_diffuser_options(new_pipe, pipeline)
+ shared.log.debug(f'Pipeline switch: from={pipeline.__class__.__name__} to={new_pipe.__class__.__name__}')
+ return new_pipe
+ else:
+ shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} empty pipeline')
+ except Exception as e:
+ shared.log.error(f'Pipeline switch error: from={pipeline.__class__.__name__} to={cls.__name__} {e}')
+ return pipeline
+
+
+def set_diffuser_pipe(pipe, new_pipe_type):
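+    # capture custom attributes up front since diffusers' from_pipe() does not carry them over to the new pipeline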
+ sd_checkpoint_info = getattr(pipe, "sd_checkpoint_info", None)
+ sd_model_checkpoint = getattr(pipe, "sd_model_checkpoint", None)
+ sd_model_hash = getattr(pipe, "sd_model_hash", None)
+ has_accelerate = getattr(pipe, "has_accelerate", None)
+ embedding_db = getattr(pipe, "embedding_db", None)
+ image_encoder = getattr(pipe, "image_encoder", None)
+ feature_extractor = getattr(pipe, "feature_extractor", None)
+
+ # skip specific pipelines
+ if pipe.__class__.__name__ == 'StableDiffusionReferencePipeline' or pipe.__class__.__name__ == 'StableDiffusionAdapterPipeline':
+ return pipe
+
+    try:
+        if new_pipe_type == DiffusersTaskType.TEXT_2_IMAGE:
+            new_pipe = diffusers.AutoPipelineForText2Image.from_pipe(pipe)
+        elif new_pipe_type == DiffusersTaskType.IMAGE_2_IMAGE:
+            new_pipe = diffusers.AutoPipelineForImage2Image.from_pipe(pipe)
+        elif new_pipe_type == DiffusersTaskType.INPAINTING:
+            new_pipe = diffusers.AutoPipelineForInpainting.from_pipe(pipe)
+        else:
+            return pipe # no auto-pipeline mapping for this task type (e.g. instruct)
+    except Exception as e:
+        shared.log.warning(f'Failed to change: type={new_pipe_type} pipeline={pipe.__class__.__name__} {e}')
+        return pipe
+
+ if pipe.__class__ == new_pipe.__class__:
+ return pipe
+ new_pipe.sd_checkpoint_info = sd_checkpoint_info
+ new_pipe.sd_model_checkpoint = sd_model_checkpoint
+ new_pipe.sd_model_hash = sd_model_hash
+ new_pipe.has_accelerate = has_accelerate
+ new_pipe.embedding_db = embedding_db
+ new_pipe.image_encoder = image_encoder
+ new_pipe.feature_extractor = feature_extractor
+ new_pipe.is_sdxl = getattr(pipe, 'is_sdxl', False) # a1111 compatibility item
+ new_pipe.is_sd2 = getattr(pipe, 'is_sd2', False)
+ new_pipe.is_sd1 = getattr(pipe, 'is_sd1', True)
+ shared.log.debug(f"Pipeline class change: original={pipe.__class__.__name__} target={new_pipe.__class__.__name__}")
+ pipe = new_pipe
+ return pipe
+
+
+def get_native(pipe: diffusers.DiffusionPipeline):
+ if hasattr(pipe, "vae") and hasattr(pipe.vae.config, "sample_size"):
+ # Stable Diffusion
+ size = pipe.vae.config.sample_size
+ elif hasattr(pipe, "movq") and hasattr(pipe.movq.config, "sample_size"):
+ # Kandinsky
+ size = pipe.movq.config.sample_size
+ elif hasattr(pipe, "unet") and hasattr(pipe.unet.config, "sample_size"):
+ size = pipe.unet.config.sample_size
+ else:
+ size = 0
+ return size
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'):
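+    # load path for the original (ldm) backend: resolve the model config, instantiate from config, then load weights and apply the hijack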
+ from modules import lowvram, sd_hijack
+ checkpoint_info = checkpoint_info or select_checkpoint(op=op)
+ if checkpoint_info is None:
+ return
+ if op == 'model' or op == 'dict':
+ if model_data.sd_model is not None and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model
+ return
+ else:
+ if model_data.sd_refiner is not None and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model
+ return
+ shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}')
+ if timer is None:
+ timer = Timer()
+ current_checkpoint_info = None
+ if op == 'model' or op == 'dict':
+ if model_data.sd_model is not None:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ current_checkpoint_info = model_data.sd_model.sd_checkpoint_info
+ unload_model_weights(op=op)
+ else:
+ if model_data.sd_refiner is not None:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
+ current_checkpoint_info = model_data.sd_refiner.sd_checkpoint_info
+ unload_model_weights(op=op)
+
+ if shared.backend == shared.Backend.ORIGINAL:
+ from modules import sd_hijack_inpainting
+ sd_hijack_inpainting.do_inpainting_hijack()
+
+ devices.set_cuda_params()
+ if already_loaded_state_dict is not None:
+ state_dict = already_loaded_state_dict
+ else:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+ if state_dict is None or checkpoint_config is None:
+ shared.log.error(f"Failed to load checkpooint: {checkpoint_info.filename}")
+ if current_checkpoint_info is not None:
+ shared.log.info(f"Restoring previous checkpoint: {current_checkpoint_info.filename}")
+ load_model(current_checkpoint_info, None)
+ return
+ shared.log.debug(f'Model dict loaded: {memory_stats()}')
+ sd_config = OmegaConf.load(checkpoint_config)
+ repair_config(sd_config)
+ timer.record("config")
+ shared.log.debug(f'Model config loaded: {memory_stats()}')
+ sd_model = None
+ stdout = io.StringIO()
+ if os.environ.get('SD_LDM_DEBUG', None) is not None:
+ sd_model = instantiate_from_config(sd_config.model)
+ else:
+ with contextlib.redirect_stdout(stdout):
+ """
+ try:
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
+ sd_model = instantiate_from_config(sd_config.model)
+ except Exception as e:
+ shared.log.error(f'LDM: instantiate from config: {e}')
+ sd_model = instantiate_from_config(sd_config.model)
+ """
+ sd_model = instantiate_from_config(sd_config.model)
+ for line in stdout.getvalue().splitlines():
+ if len(line) > 0:
+ shared.log.info(f'LDM: {line.strip()}')
+ shared.log.debug(f"Model created from config: {checkpoint_config}")
+ sd_model.used_config = checkpoint_config
+ sd_model.has_accelerate = False
+ timer.record("create")
+ ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ if not ok:
+ model_data.sd_model = sd_model
+ current_checkpoint_info = None
+ unload_model_weights(op=op)
+ shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}')
+ if op == 'refiner':
+ # shared.opts.data['sd_model_refiner'] = 'None'
+ shared.opts.sd_model_refiner = 'None'
+ return
+ else:
+ shared.log.debug(f'Model weights loaded: {memory_stats()}')
+ timer.record("load")
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
+ else:
+ sd_model.to(devices.device)
+ timer.record("move")
+ shared.log.debug(f'Model weights moved: {memory_stats()}')
+ sd_hijack.model_hijack.hijack(sd_model)
+ timer.record("hijack")
+ sd_model.eval()
+ if op == 'refiner':
+ model_data.sd_refiner = sd_model
+ else:
+ model_data.sd_model = sd_model
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+ timer.record("embeddings")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("callbacks")
+ shared.log.info(f"Model loaded in {timer.summary()}")
+ current_checkpoint_info = None
+ devices.torch_gc(force=True)
+ shared.log.info(f'Model load finished: {memory_stats()} cached={len(checkpoints_loaded.keys())}')
+
+
+def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model'):
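+    # reload weights for the model, refiner or model dict, dispatching to load_model or load_diffuser depending on the active backend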
+ load_dict = shared.opts.sd_model_dict != model_data.sd_dict
+ from modules import lowvram, sd_hijack
+ checkpoint_info = info or select_checkpoint(op=op) # are we selecting model or dictionary
+    next_checkpoint_info = (info or select_checkpoint(op='dict' if load_dict else 'model')) if load_dict else None
+ if checkpoint_info is None:
+ unload_model_weights(op=op)
+ return None
+ orig_state = copy.deepcopy(shared.state)
+ shared.state = shared_state.State()
+ shared.state.begin('load')
+ if load_dict:
+ shared.log.debug(f'Model dict: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
+ else:
+ model_data.sd_dict = 'None'
+ shared.log.debug(f'Load model weights: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
+ if sd_model is None:
+ sd_model = model_data.sd_model if op == 'model' or op == 'dict' else model_data.sd_refiner
+ if sd_model is None: # previous model load failed
+ current_checkpoint_info = None
+ else:
+ current_checkpoint_info = getattr(sd_model, 'sd_checkpoint_info', None)
+ if current_checkpoint_info is not None and checkpoint_info is not None and current_checkpoint_info.filename == checkpoint_info.filename:
+ return None
+ if not getattr(sd_model, 'has_accelerate', False):
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ sd_model.to(devices.cpu)
+ if (reuse_dict or shared.opts.model_reuse_dict) and not getattr(sd_model, 'has_accelerate', False):
+ shared.log.info('Reusing previous model dictionary')
+ sd_hijack.model_hijack.undo_hijack(sd_model)
+ else:
+ unload_model_weights(op=op)
+ sd_model = None
+ timer = Timer()
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if shared.backend == shared.Backend.ORIGINAL else None # TODO: revisit after diffusers enables state_dict loading
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+ timer.record("config")
+ if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None):
+ sd_model = None
+ if shared.backend == shared.Backend.ORIGINAL:
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
+ model_data.sd_dict = shared.opts.sd_model_dict
+ else:
+ load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
+ if load_dict and next_checkpoint_info is not None:
+ model_data.sd_dict = shared.opts.sd_model_dict
+ shared.opts.data["sd_model_checkpoint"] = next_checkpoint_info.title
+ reload_model_weights(reuse_dict=True) # ok we loaded dict now lets redo and load model on top of it
+ shared.state.end()
+ shared.state = orig_state
+ # data['sd_model_checkpoint']
+ if op == 'model' or op == 'dict':
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ return model_data.sd_model
+ else:
+ shared.opts.data["sd_model_refiner"] = checkpoint_info.title
+ return model_data.sd_refiner
+
+ # fallback
+ shared.log.info(f"Loading using fallback: {op} model={checkpoint_info.title}")
+ try:
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ except Exception:
+ shared.log.error("Load model failed: restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info, None, timer)
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ timer.record("hijack")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("callbacks")
+ if sd_model is not None and not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram and not getattr(sd_model, 'has_accelerate', False):
+ sd_model.to(devices.device)
+ timer.record("device")
+ shared.state.end()
+ shared.state = orig_state
+ shared.log.info(f"Load: {op} time={timer.summary()}")
+ return sd_model
+
+
+def convert_to_faketensors(tensor):
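+    # swap the module's weight for a FakeTensor, keeping shape/dtype metadata without allocating real storage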
+ fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
+ if hasattr(tensor, "weight"):
+ tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
+ return tensor
+
+
+def disable_offload(sd_model):
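+    # remove accelerate offload hooks from every torch module in the pipeline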
+ from accelerate.hooks import remove_hook_from_module
+ if not getattr(sd_model, 'has_accelerate', False):
+ return
+ for _name, model in sd_model.components.items():
+ if not isinstance(model, torch.nn.Module):
+ continue
+ remove_hook_from_module(model, recurse=True)
+
+
+def unload_model_weights(op='model'):
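+    # release the current model or refiner: the original backend moves it to CPU and undoes the hijack, the diffusers backend strips offload hooks and moves weights to the meta device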
+ if shared.compiled_model_state is not None:
+ shared.compiled_model_state.compiled_cache.clear()
+ shared.compiled_model_state.partitioned_modules.clear()
+ if op == 'model' or op == 'dict':
+ if model_data.sd_model:
+ if shared.backend == shared.Backend.ORIGINAL:
+ from modules import sd_hijack
+ model_data.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ elif not (shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"):
+ disable_offload(model_data.sd_model)
+ model_data.sd_model.to('meta')
+ model_data.sd_model = None
+ shared.log.debug(f'Unload weights {op}: {memory_stats()}')
+ else:
+ if model_data.sd_refiner:
+ if shared.backend == shared.Backend.ORIGINAL:
+ from modules import sd_hijack
+                model_data.sd_refiner.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner)
+ else:
+                disable_offload(model_data.sd_refiner)
+ model_data.sd_refiner.to('meta')
+ model_data.sd_refiner = None
+ shared.log.debug(f'Unload weights {op}: {memory_stats()}')
+ devices.torch_gc(force=True)
+
+
+def apply_token_merging(sd_model, token_merging_ratio=0):
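+    # apply or remove ToMe (token merging) patches so the model matches the requested ratio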
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+ if token_merging_ratio is None or current_token_merging_ratio is None or current_token_merging_ratio == token_merging_ratio:
+ return
+ try:
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+ except Exception:
+ pass
+ if token_merging_ratio > 0:
+ if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
+ shared.log.warning('Token merging not supported with HyperTile for UNet')
+ return
+ try:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+ shared.log.info(f'Applying token merging: ratio={token_merging_ratio}')
+ sd_model.applied_token_merged_ratio = token_merging_ratio
+ except Exception:
+ shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
+ else:
+ sd_model.applied_token_merged_ratio = 0
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index d2bc0ac4c..50b8659d8 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,114 +1,114 @@
-import os
-
-import torch
-
-from modules import paths, devices
-
-sd_repo_configs_path = 'configs'
-config_default = paths.sd_default_config
-config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml")
-config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml")
-config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
-config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
-config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
-config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
-config_inpainting = os.path.join(paths.sd_configs_path, "v1-inpainting-inference.yaml")
-config_instruct_pix2pix = os.path.join(paths.sd_configs_path, "instruct-pix2pix.yaml")
-config_alt_diffusion = os.path.join(paths.sd_configs_path, "alt-diffusion-inference.yaml")
-
-
-def is_using_v_parameterization_for_sd2(state_dict):
- """
- Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
- """
- from modules import sd_disable_initialization
- import ldm.modules.diffusionmodules.openaimodel
-
- device = devices.cpu
- with sd_disable_initialization.DisableInitialization():
- unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
- use_checkpoint=True,
- use_fp16=False,
- image_size=32,
- in_channels=4,
- out_channels=4,
- model_channels=320,
- attention_resolutions=[4, 2, 1],
- num_res_blocks=2,
- channel_mult=[1, 2, 4, 4],
- num_head_channels=64,
- use_spatial_transformer=True,
- use_linear_in_transformer=True,
- transformer_depth=1,
- context_dim=1024,
- legacy=False
- )
- unet.eval()
-
- with devices.inference_context():
- unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
- unet.load_state_dict(unet_sd, strict=True)
- unet.to(device=device, dtype=torch.float)
-
- test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
- x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
-
- out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
-
- return out < -1
-
-
-def guess_model_config_from_state_dict(sd, _filename):
- if sd is None:
- return None
- sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
- diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
- sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
-
- if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
- return config_depth_model
- elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
- return config_unclip
- elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
- return config_unopenclip
-
- if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
- if diffusion_model_input.shape[1] == 9:
- return config_sd2_inpainting
- elif is_using_v_parameterization_for_sd2(sd):
- return config_sd2v
- else:
- return config_sd2
-
- if diffusion_model_input is not None:
- if diffusion_model_input.shape[1] == 9:
- return config_inpainting
- if diffusion_model_input.shape[1] == 8:
- return config_instruct_pix2pix
-
- if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
- return config_alt_diffusion
-
- return config_default
-
-
-def find_checkpoint_config(state_dict, info):
- if info is None:
- return guess_model_config_from_state_dict(state_dict, "")
-
- config = find_checkpoint_config_near_filename(info)
- if config is not None:
- return config
-
- return guess_model_config_from_state_dict(state_dict, info.filename)
-
-
-def find_checkpoint_config_near_filename(info):
- if info is None:
- return None
-
- config = f"{os.path.splitext(info.filename)[0]}.yaml"
- if os.path.exists(config):
- return config
-
- return None
+import os
+
+import torch
+
+from modules import paths, devices
+
+sd_repo_configs_path = 'configs'
+config_default = paths.sd_default_config
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml")
+config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
+config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
+config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
+config_inpainting = os.path.join(paths.sd_configs_path, "v1-inpainting-inference.yaml")
+config_instruct_pix2pix = os.path.join(paths.sd_configs_path, "instruct-pix2pix.yaml")
+config_alt_diffusion = os.path.join(paths.sd_configs_path, "alt-diffusion-inference.yaml")
+
+
+def is_using_v_parameterization_for_sd2(state_dict):
+ """
+ Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
+ """
+ from modules import sd_disable_initialization
+ import ldm.modules.diffusionmodules.openaimodel
+
+ device = devices.cpu
+ with sd_disable_initialization.DisableInitialization():
+ unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
+ use_checkpoint=True,
+ use_fp16=False,
+ image_size=32,
+ in_channels=4,
+ out_channels=4,
+ model_channels=320,
+ attention_resolutions=[4, 2, 1],
+ num_res_blocks=2,
+ channel_mult=[1, 2, 4, 4],
+ num_head_channels=64,
+ use_spatial_transformer=True,
+ use_linear_in_transformer=True,
+ transformer_depth=1,
+ context_dim=1024,
+ legacy=False
+ )
+ unet.eval()
+
+ with devices.inference_context():
+ unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
+ unet.load_state_dict(unet_sd, strict=True)
+ unet.to(device=device, dtype=torch.float)
+
+ test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
+ x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
+
+ out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
+
+ return out < -1
+
+
+def guess_model_config_from_state_dict(sd, _filename):
+ if sd is None:
+ return None
+ sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
+ diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+ sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
+
+ if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+ return config_depth_model
+ elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
+ return config_unclip
+ elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
+ return config_unopenclip
+
+ if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
+ if diffusion_model_input.shape[1] == 9:
+ return config_sd2_inpainting
+ elif is_using_v_parameterization_for_sd2(sd):
+ return config_sd2v
+ else:
+ return config_sd2
+
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ return config_inpainting
+ if diffusion_model_input.shape[1] == 8:
+ return config_instruct_pix2pix
+
+ if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+ return config_alt_diffusion
+
+ return config_default
+
+
+def find_checkpoint_config(state_dict, info):
+ if info is None:
+ return guess_model_config_from_state_dict(state_dict, "")
+
+ config = find_checkpoint_config_near_filename(info)
+ if config is not None:
+ return config
+
+ return guess_model_config_from_state_dict(state_dict, info.filename)
+
+
+def find_checkpoint_config_near_filename(info):
+ if info is None:
+ return None
+
+ config = f"{os.path.splitext(info.filename)[0]}.yaml"
+ if os.path.exists(config):
+ return config
+
+ return None
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index e1cfd0bf6..806a32e00 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,84 +1,84 @@
-import os
-from modules import shared
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import
-
-debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: SAMPLER')
-all_samplers = []
-all_samplers = []
-all_samplers_map = {}
-samplers = all_samplers
-samplers_for_img2img = all_samplers
-samplers_map = {}
-
-
-def list_samplers(backend_name = shared.backend):
- global all_samplers # pylint: disable=global-statement
- global all_samplers_map # pylint: disable=global-statement
- global samplers # pylint: disable=global-statement
- global samplers_for_img2img # pylint: disable=global-statement
- global samplers_map # pylint: disable=global-statement
- if backend_name == shared.Backend.ORIGINAL:
- from modules import sd_samplers_compvis, sd_samplers_kdiffusion
- all_samplers = [*sd_samplers_compvis.samplers_data_compvis, *sd_samplers_kdiffusion.samplers_data_k_diffusion]
- else:
- from modules import sd_samplers_diffusers
- all_samplers = [*sd_samplers_diffusers.samplers_data_diffusers]
- all_samplers_map = {x.name: x for x in all_samplers}
- samplers = all_samplers
- samplers_for_img2img = all_samplers
- samplers_map = {}
- # shared.log.debug(f'Available samplers: {[x.name for x in all_samplers]}')
-
-
-def find_sampler_config(name):
- if name is not None and name != 'None':
- config = all_samplers_map.get(name, None)
- else:
- config = all_samplers[0]
- return config
-
-
-def visible_sampler_names():
- visible_samplers = [x for x in all_samplers if x.name in shared.opts.show_samplers] if len(shared.opts.show_samplers) > 0 else all_samplers
- return visible_samplers
-
-
-def create_sampler(name, model):
- if name == 'Default' and hasattr(model, 'scheduler'):
- config = {k: v for k, v in model.scheduler.config.items() if not k.startswith('_')}
- shared.log.debug(f'Sampler default {type(model.scheduler).__name__}: {config}')
- return model.scheduler
- config = find_sampler_config(name)
- if config is None:
- shared.log.error(f'Attempting to use unknown sampler: {name}')
- config = all_samplers[0]
- if shared.backend == shared.Backend.ORIGINAL:
- sampler = config.constructor(model)
- sampler.config = config
- sampler.initialize(p=None)
- sampler.name = name
- shared.log.debug(f'Sampler: sampler="{sampler.name}" config={sampler.config.options}')
- return sampler
- elif shared.backend == shared.Backend.DIFFUSERS:
- sampler = config.constructor(model)
- if not hasattr(model, 'scheduler_config'):
- model.scheduler_config = sampler.sampler.config.copy()
- model.scheduler = sampler.sampler
- shared.log.debug(f'Sampler: sampler="{sampler.name}" config={sampler.config}')
- return sampler.sampler
- else:
- return None
-
-
-def set_samplers():
- global samplers # pylint: disable=global-statement
- global samplers_for_img2img # pylint: disable=global-statement
- samplers = visible_sampler_names()
- # samplers_for_img2img = [x for x in samplers if x.name != "PLMS"]
- samplers_for_img2img = samplers
- samplers_map.clear()
- for sampler in all_samplers:
- samplers_map[sampler.name.lower()] = sampler.name
- for alias in sampler.aliases:
- samplers_map[alias.lower()] = sampler.name
+import os
+from modules import shared
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import
+
+debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: SAMPLER')
+all_samplers = []
+all_samplers_map = {}
+samplers = all_samplers
+samplers_for_img2img = all_samplers
+samplers_map = {}
+
+
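+# Rebuild the global sampler registries for the selected backend: CompVis and k-diffusion samplers for the original backend, Diffusers schedulers otherwise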
+def list_samplers(backend_name=shared.backend):
+ global all_samplers # pylint: disable=global-statement
+ global all_samplers_map # pylint: disable=global-statement
+ global samplers # pylint: disable=global-statement
+ global samplers_for_img2img # pylint: disable=global-statement
+ global samplers_map # pylint: disable=global-statement
+ if backend_name == shared.Backend.ORIGINAL:
+ from modules import sd_samplers_compvis, sd_samplers_kdiffusion
+ all_samplers = [*sd_samplers_compvis.samplers_data_compvis, *sd_samplers_kdiffusion.samplers_data_k_diffusion]
+ else:
+ from modules import sd_samplers_diffusers
+ all_samplers = [*sd_samplers_diffusers.samplers_data_diffusers]
+ all_samplers_map = {x.name: x for x in all_samplers}
+ samplers = all_samplers
+ samplers_for_img2img = all_samplers
+ samplers_map = {}
+ # shared.log.debug(f'Available samplers: {[x.name for x in all_samplers]}')
+
+
+def find_sampler_config(name):
+ if name is not None and name != 'None':
+ config = all_samplers_map.get(name, None)
+ else:
+ config = all_samplers[0]
+ return config
+
+
+def visible_sampler_names():
+ visible_samplers = [x for x in all_samplers if x.name in shared.opts.show_samplers] if len(shared.opts.show_samplers) > 0 else all_samplers
+ return visible_samplers
+
+
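+# Instantiate a sampler by name and attach it to the model; 'Default' reuses the scheduler already configured on a Diffusers model, and unknown names fall back to the first registered sampler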
+def create_sampler(name, model):
+ if name == 'Default' and hasattr(model, 'scheduler'):
+ config = {k: v for k, v in model.scheduler.config.items() if not k.startswith('_')}
+ shared.log.debug(f'Sampler default {type(model.scheduler).__name__}: {config}')
+ return model.scheduler
+ config = find_sampler_config(name)
+ if config is None:
+ shared.log.error(f'Attempting to use unknown sampler: {name}')
+ config = all_samplers[0]
+ if shared.backend == shared.Backend.ORIGINAL:
+ sampler = config.constructor(model)
+ sampler.config = config
+ sampler.initialize(p=None)
+ sampler.name = name
+ shared.log.debug(f'Sampler: sampler="{sampler.name}" config={sampler.config.options}')
+ return sampler
+ elif shared.backend == shared.Backend.DIFFUSERS:
+ sampler = config.constructor(model)
+ if not hasattr(model, 'scheduler_config'):
+ model.scheduler_config = sampler.sampler.config.copy()
+ model.scheduler = sampler.sampler
+ shared.log.debug(f'Sampler: sampler="{sampler.name}" config={sampler.config}')
+ return sampler.sampler
+ else:
+ return None
+
+
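+# Refresh the visible sampler lists and rebuild the lowercase name/alias lookup map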
+def set_samplers():
+ global samplers # pylint: disable=global-statement
+ global samplers_for_img2img # pylint: disable=global-statement
+ samplers = visible_sampler_names()
+ # samplers_for_img2img = [x for x in samplers if x.name != "PLMS"]
+ samplers_for_img2img = samplers
+ samplers_map.clear()
+ for sampler in all_samplers:
+ samplers_map[sampler.name.lower()] = sampler.name
+ for alias in sampler.aliases:
+ samplers_map[alias.lower()] = sampler.name
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 8de26b8f3..6685a4d38 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -1,233 +1,233 @@
-# TODO a1111 compatibility module
-# TODO cfg_denoiser implementation missing
-
-import torch
-from modules import prompt_parser, devices, sd_samplers_common
-
-from modules.shared import opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
-
-
-def catenate_conds(conds):
- if not isinstance(conds[0], dict):
- return torch.cat(conds)
-
- return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
-
-
-def subscript_cond(cond, a, b):
- if not isinstance(cond, dict):
- return cond[a:b]
-
- return {key: vec[a:b] for key, vec in cond.items()}
-
-
-def pad_cond(tensor, repeats, empty):
- if not isinstance(tensor, dict):
- return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
-
- tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
- return tensor
-
-
-class CFGDenoiser(torch.nn.Module):
- """
- Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
- that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
- instead of one. Originally, the second prompt is just an empty string, but we use non-empty
- negative prompt.
- """
-
- def __init__(self, sampler):
- super().__init__()
- self.model_wrap = None
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.steps = None
- """number of steps as specified by user in UI"""
-
- self.total_steps = None
- """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
-
- self.step = 0
- self.image_cfg_scale = None
- self.padded_cond_uncond = False
- self.sampler = sampler
- self.model_wrap = None
- self.p = None
- self.mask_before_denoising = False
-
- @property
- def inner_model(self):
- raise NotImplementedError
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def combine_denoised_for_edit_model(self, x_out, cond_scale):
- out_cond, out_img_cond, out_uncond = x_out.chunk(3)
- denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
- return denoised
-
- def get_pred_x0(self, x_in, x_out, sigma): # pylint: disable=unused-argument
- return x_out
-
- def update_inner_model(self):
- self.model_wrap = None
-
- c, uc = self.p.get_conds()
- self.sampler.sampler_extra_args['cond'] = c
- self.sampler.sampler_extra_args['uncond'] = uc
-
- def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- # TODO cfg_scale implementation missing
- # if sd_samplers_common.apply_refiner(self):
- # cond = self.sampler.sampler_extra_args['cond']
- # uncond = self.sampler.sampler_extra_args['uncond']
-
- # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
- # so is_edit_model is set to False to support AND composition.
- is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
- if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- if shared.sd_model.model.conditioning_key == "crossattn-adm":
- image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} # pylint: disable=unnecessary-lambda-assignment
- else:
- image_uncond = image_cond
- if isinstance(uncond, dict):
- make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} # pylint: disable=unnecessary-lambda-assignment
- else:
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} # pylint: disable=unnecessary-lambda-assignment
-
- if not is_edit_model:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
- else:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
- tensor = denoiser_params.text_cond
- uncond = denoiser_params.text_uncond
- skip_uncond = False
-
- # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
- if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
- skip_uncond = True
- x_in = x_in[:-batch_size]
- sigma_in = sigma_in[:-batch_size]
-
- self.padded_cond_uncond = False
- if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
-
- if tensor.shape[1] == uncond.shape[1] or skip_uncond:
- if is_edit_model:
- cond_in = catenate_conds([tensor, uncond, uncond])
- elif skip_uncond:
- cond_in = tensor
- else:
- cond_in = catenate_conds([tensor, uncond])
-
- if shared.opts.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
-
- if not is_edit_model:
- c_crossattn = subscript_cond(tensor, a, b)
- else:
- c_crossattn = torch.cat([tensor[a:b]], uncond)
-
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
-
- if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
-
- denoised_image_indexes = [x[0][0] for x in conds_list]
- if skip_uncond:
- fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
- x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
-
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
- cfg_denoised_callback(denoised_params)
-
- devices.test_for_nans(x_out, "unet")
-
- if is_edit_model:
- denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
- elif skip_uncond:
- denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
- else:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
- if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
-
- if opts.live_preview_content == "Prompt":
- preview = self.sampler.last_latent
- elif opts.live_preview_content == "Negative prompt":
- preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
- else:
- preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
-
- sd_samplers_common.store_latent(preview)
-
- after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
- cfg_after_cfg_callback(after_cfg_callback_params)
- denoised = after_cfg_callback_params.x
-
- self.step += 1
- return denoised
+# TODO a1111 compatibility module
+# TODO cfg_denoiser implementation missing
+
+import torch
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+
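+# Helpers that accept conds given either as plain tensors or as dicts of tensors (concatenate, slice or pad the embeddings accordingly)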
+def catenate_conds(conds):
+ if not isinstance(conds[0], dict):
+ return torch.cat(conds)
+
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+ if not isinstance(cond, dict):
+ return cond[a:b]
+
+ return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+ if not isinstance(tensor, dict):
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+ return tensor
+
+
+class CFGDenoiser(torch.nn.Module):
+    """
+    Classifier-free guidance denoiser. A wrapper for the Stable Diffusion model (specifically its UNet)
+    that takes a noisy picture and produces a noise-free picture using two guidances (prompts)
+    instead of one. Originally the second prompt was just an empty string, but we use a non-empty
+    negative prompt.
+    """
+
+ def __init__(self, sampler):
+ super().__init__()
+ self.model_wrap = None
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.steps = None
+ """number of steps as specified by user in UI"""
+
+ self.total_steps = None
+ """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
+
+ self.step = 0
+ self.image_cfg_scale = None
+ self.padded_cond_uncond = False
+ self.sampler = sampler
+ self.p = None
+ self.mask_before_denoising = False
+
+ @property
+ def inner_model(self):
+ raise NotImplementedError
+
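+    # classifier-free guidance combine: start from the uncond prediction and add each cond prediction's difference from it, weighted by the prompt weight and cond_scale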
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
+ def get_pred_x0(self, x_in, x_out, sigma): # pylint: disable=unused-argument
+ return x_out
+
+ def update_inner_model(self):
+ self.model_wrap = None
+
+ c, uc = self.p.get_conds()
+ self.sampler.sampler_extra_args['cond'] = c
+ self.sampler.sampler_extra_args['uncond'] = uc
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ # TODO cfg_scale implementation missing
+ # if sd_samplers_common.apply_refiner(self):
+ # cond = self.sampler.sampler_extra_args['cond']
+ # uncond = self.sampler.sampler_extra_args['uncond']
+
+        # at self.image_cfg_scale == 1.0, results produced for the edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+ if self.mask_before_denoising and self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
+ image_uncond = torch.zeros_like(image_cond)
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} # pylint: disable=unnecessary-lambda-assignment
+ else:
+ image_uncond = image_cond
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} # pylint: disable=unnecessary-lambda-assignment
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} # pylint: disable=unnecessary-lambda-assignment
+
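+        # build the batched model inputs: every latent is repeated once per sub-prompt (AND composition) and the unconditional batch is appended last; edit models additionally append a second latent copy with zeroed image conditioning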
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
+ skip_uncond = False
+
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ self.padded_cond_uncond = False
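+        # if prompt and negative prompt encode to different token lengths, pad the shorter one with repeats of the empty-prompt embedding so they can be batched together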
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = pad_cond(tensor, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
+ cond_in = catenate_conds([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = catenate_conds([tensor, uncond])
+
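+            # batch_cond_uncond evaluates the whole cond+uncond batch in a single UNet call; otherwise it is processed in chunks of batch_size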
+ if shared.opts.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+
+ if not is_edit_model:
+ c_crossattn = subscript_cond(tensor, a, b)
+ else:
+                    c_crossattn = torch.cat([tensor[a:b], uncond])  # concatenate the cond slice with uncond along the batch dimension
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
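+        # index of the first cond-denoised sample for every image in the batch; used for previews and for substituting the skipped uncond pass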
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+ cfg_denoised_callback(denoised_params)
+
+ devices.test_for_nans(x_out, "unet")
+
+ if is_edit_model:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+ if not self.mask_before_denoising and self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
+
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
+ self.step += 1
+ return denoised
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index cf4f8b8fb..86c090324 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,128 +1,128 @@
-from collections import namedtuple
-import torch
-import torchvision.transforms as T
-from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, shared
-import modules.taesd.sd_vae_taesd as sd_vae_taesd
-
-
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
-approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
-warned = False
-
-
-def warn_once(message):
- global warned # pylint: disable=global-statement
- if not warned:
- shared.log.warning(message)
- warned = True
-
-
-def setup_img2img_steps(p, steps=None):
- if shared.opts.img2img_fix_steps or steps is not None:
- requested_steps = (steps or p.steps)
- steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = requested_steps - 1
- else:
- steps = p.steps
- t_enc = int(min(p.denoising_strength, 0.999) * steps)
-
- return steps, t_enc
-
-
-def single_sample_to_image(sample, approximation=None):
- if approximation is None:
- approximation = approximation_indexes.get(shared.opts.show_progress_type, None)
- if approximation is None:
- warn_once('Unknown decode type, please reset preview method')
- approximation = 0
-
- # normal sample is [4,64,64]
- if sample.dtype == torch.bfloat16:
- sample = sample.to(torch.float16)
- if len(sample.shape) > 4: # likely unknown video latent (e.g. svd)
- return Image.new(mode="RGB", size=(512, 512))
- if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent
- sample = sample.permute(1, 0, 2, 3)[0]
- if approximation == 0: # Simple
- x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
- elif approximation == 1: # Approximate
- x_sample = sd_vae_approx.nn_approximation(sample) * 0.5 + 0.5
- if shared.sd_model_type == "sdxl":
- x_sample = x_sample[[2,1,0], :, :] # BGR to RGB
- elif approximation == 2: # TAESD
- x_sample = sd_vae_taesd.decode(sample)
- x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range
- elif approximation == 3: # Full VAE
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
- else:
- warn_once(f"Unknown latent decode type: {approximation}")
- return Image.new(mode="RGB", size=(512, 512))
-
- try:
- if x_sample.dtype == torch.bfloat16:
- x_sample.to(torch.float16)
- transform = T.ToPILImage()
- image = transform(x_sample)
- except Exception as e:
- warn_once(f'Live preview: {e}')
- image = Image.new(mode="RGB", size=(512, 512))
- return image
-
-
-def sample_to_image(samples, index=0, approximation=None):
- return single_sample_to_image(samples[index], approximation)
-
-
-def samples_to_image_grid(samples, approximation=None):
- return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
-
-
-def images_tensor_to_samples(image, approximation=None, model=None):
- '''image[0, 1] -> latent'''
- if approximation is None:
- approximation = approximation_indexes.get(shared.opts.show_progress_type, 0)
- if approximation == 2:
- image = image.to(devices.device, devices.dtype)
- x_latent = sd_vae_taesd.encode(image)
- else:
- if model is None:
- model = shared.sd_model
- model.first_stage_model.to(devices.dtype_vae)
- image = image.to(shared.device, dtype=devices.dtype_vae)
- image = image * 2 - 1
- if len(image) > 1:
- image_latents = [model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0] for img in image]
- x_latent = torch.stack(image_latents)
- else:
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
- return x_latent
-
-
-def store_latent(decoded):
- shared.state.current_latent = decoded
- if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % shared.opts.show_progress_every_n_steps == 0:
- if not shared.parallel_processing_allowed:
- image = sample_to_image(decoded)
- shared.state.assign_current_image(image)
-
-
-def is_sampler_using_eta_noise_seed_delta(p):
- """returns whether sampler from config will use eta noise seed delta for image creation"""
- sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
- eta = 0
- if hasattr(p, "eta"):
- eta = p.eta
- if not hasattr(p.sampler, "eta"):
- return False
- if eta is None and p.sampler is not None:
- eta = p.sampler.eta
- if eta is None and sampler_config is not None:
- eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
- if eta == 0:
- return False
- return True
-
-
-class InterruptedException(BaseException):
- pass
+from collections import namedtuple
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, shared
+import modules.taesd.sd_vae_taesd as sd_vae_taesd
+
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
+warned = False
+
+
+def warn_once(message):
+ global warned # pylint: disable=global-statement
+ if not warned:
+ shared.log.warning(message)
+ warned = True
+
+
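+# Translate the requested step count and denoising strength into the full schedule length and the img2img encode step (t_enc)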
+def setup_img2img_steps(p, steps=None):
+ if shared.opts.img2img_fix_steps or steps is not None:
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
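+# Decode a single latent into a PIL image for live previews using the selected approximation: cheap projection, small NN approximation, TAESD or the full VAE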
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(shared.opts.show_progress_type, None)
+ if approximation is None:
+ warn_once('Unknown decode type, please reset preview method')
+ approximation = 0
+
+ # normal sample is [4,64,64]
+ if sample.dtype == torch.bfloat16:
+ sample = sample.to(torch.float16)
+ if len(sample.shape) > 4: # likely unknown video latent (e.g. svd)
+ return Image.new(mode="RGB", size=(512, 512))
+ if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent
+ sample = sample.permute(1, 0, 2, 3)[0]
+ if approximation == 0: # Simple
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
+ elif approximation == 1: # Approximate
+ x_sample = sd_vae_approx.nn_approximation(sample) * 0.5 + 0.5
+ if shared.sd_model_type == "sdxl":
+ x_sample = x_sample[[2,1,0], :, :] # BGR to RGB
+ elif approximation == 2: # TAESD
+ x_sample = sd_vae_taesd.decode(sample)
+ x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range
+ elif approximation == 3: # Full VAE
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+ else:
+ warn_once(f"Unknown latent decode type: {approximation}")
+ return Image.new(mode="RGB", size=(512, 512))
+
+ try:
+ if x_sample.dtype == torch.bfloat16:
+            x_sample = x_sample.to(torch.float16)  # .to() is not in-place, so assign the converted tensor
+ transform = T.ToPILImage()
+ image = transform(x_sample)
+ except Exception as e:
+ warn_once(f'Live preview: {e}')
+ image = Image.new(mode="RGB", size=(512, 512))
+ return image
+
+
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def images_tensor_to_samples(image, approximation=None, model=None):
+ '''image[0, 1] -> latent'''
+ if approximation is None:
+ approximation = approximation_indexes.get(shared.opts.show_progress_type, 0)
+ if approximation == 2:
+ image = image.to(devices.device, devices.dtype)
+ x_latent = sd_vae_taesd.encode(image)
+ else:
+ if model is None:
+ model = shared.sd_model
+ model.first_stage_model.to(devices.dtype_vae)
+ image = image.to(shared.device, dtype=devices.dtype_vae)
+ image = image * 2 - 1
+ if len(image) > 1:
+ image_latents = [model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0] for img in image]
+ x_latent = torch.stack(image_latents)
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ return x_latent
+
+
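+# Record the latest latent and, when live previews are enabled and parallel processing is not in use, periodically decode it and publish the image to the UI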
+def store_latent(decoded):
+ shared.state.current_latent = decoded
+ if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % shared.opts.show_progress_every_n_steps == 0:
+ if not shared.parallel_processing_allowed:
+ image = sample_to_image(decoded)
+ shared.state.assign_current_image(image)
+
+
+def is_sampler_using_eta_noise_seed_delta(p):
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+ eta = 0
+ if hasattr(p, "eta"):
+ eta = p.eta
+ if not hasattr(p.sampler, "eta"):
+ return False
+ if eta is None and p.sampler is not None:
+ eta = p.sampler.eta
+ if eta is None and sampler_config is not None:
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+ if eta == 0:
+ return False
+ return True
+
+
+class InterruptedException(BaseException):
+ pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index c5ec87a1c..6e56dfd75 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -1,215 +1,215 @@
-import math
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-import numpy as np
-import torch
-from modules import sd_samplers_common, prompt_parser, shared
-import modules.unipc
-
-
-samplers_data_compvis = [
- sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.unipc.UniPCSampler, model), [], {}),
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
-]
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
- self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.is_unipc = isinstance(self.sampler, modules.unipc.UniPCSampler)
- self.orig_p_sample_ddim = None
- if self.is_plms:
- self.orig_p_sample_ddim = self.sampler.p_sample_plms
- elif self.is_ddim:
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.config = None
- self.last_latent = None
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p): # pylint: disable=unused-argument
- return 0
-
- def launch_sampling(self, steps, func):
- shared.state.sampling_steps = steps
- shared.state.sampling_step = 0
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
- res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
- x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
- return res
-
- def before_sample(self, x, ts, cond, unconditional_conditioning):
- if shared.state.interrupted or shared.state.skipped:
- raise sd_samplers_common.InterruptedException
- if shared.state.paused:
- shared.log.debug('Sampling paused')
- while shared.state.paused:
- if shared.state.interrupted or shared.state.skipped:
- raise sd_samplers_common.InterruptedException
- import time
- time.sleep(0.1)
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- uc_image_conditioning = None
- if isinstance(cond, dict):
- if self.conditioning_key == "crossattn-adm":
- image_conditioning = cond["c_adm"]
- uc_image_conditioning = unconditional_conditioning["c_adm"]
- else:
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- encode_fn = self.sampler.model.q_sample
- if self.is_unipc:
- encode_fn = self.sampler.stochastic_encode
- img_orig = encode_fn(self.init_latent, ts)
- x = img_orig * self.mask + self.nmask * x
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
- unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
- else:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- return x, ts, cond, unconditional_conditioning
-
- def update_step(self, last_latent):
- self.last_latent = self.init_latent * self.mask + self.nmask * last_latent if self.mask is not None else last_latent
- sd_samplers_common.store_latent(self.last_latent)
- self.step += 1
- shared.state.sampling_step = self.step
-
- def after_sample(self, x, ts, cond, uncond, res):
- if not self.is_unipc:
- self.update_step(res[1])
- return x, ts, cond, uncond, res
-
- def unipc_after_update(self, x, model_x): # pylint: disable=unused-argument
- self.update_step(x)
-
- def initialize(self, p):
- if p is not None:
- if self.is_ddim:
- self.eta = p.eta if p.eta is not None else shared.opts.scheduler_eta
- else:
- self.eta = 0.0
- if self.eta != 0.0:
- p.extra_generation_params["Sampler Eta"] = self.eta
- if self.is_unipc:
- keys = [
- ('Solver order', 'schedulers_solver_order'),
- ('Sampler low order', 'schedulers_use_loworder'),
- ('UniPC variant', 'uni_pc_variant'),
- ('UniPC skip type', 'uni_pc_skip_type'),
- ]
- for name, key in keys:
- v = getattr(shared.opts, key)
- if v != shared.opts.get_default(key):
- p.extra_generation_params[name] = v
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
- if self.is_unipc:
- self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
- if self.config.name == 'UniPC' and num_steps < shared.opts.schedulers_solver_order:
- num_steps = shared.opts.schedulers_solver_order
- valid_step = 999 / (1000 // num_steps)
- if valid_step == math.floor(valid_step):
- return min(int(valid_step) + 1, num_steps)
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
- else:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
- else:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
+import math
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+import numpy as np
+import torch
+from modules import sd_samplers_common, prompt_parser, shared
+import modules.unipc
+
+
+samplers_data_compvis = [
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.unipc.UniPCSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+]
+
+
+class VanillaStableDiffusionSampler:
+ def __init__(self, constructor, sd_model):
+ self.sampler = constructor(sd_model)
+ self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.is_unipc = isinstance(self.sampler, modules.unipc.UniPCSampler)
+ self.orig_p_sample_ddim = None
+ if self.is_plms:
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms
+ elif self.is_ddim:
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.sampler_noises = None
+ self.step = 0
+ self.stop_at = None
+ self.eta = None
+ self.config = None
+ self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def number_of_needed_noises(self, p): # pylint: disable=unused-argument
+ return 0
+
+ def launch_sampling(self, steps, func):
+ shared.state.sampling_steps = steps
+ shared.state.sampling_step = 0
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
+ x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
+ return res
+
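+    # runs before every p_sample_* call: honors pause/interrupt/stop_at, unwraps inpainting conditioning, reconstructs scheduled prompts, matches uncond length to cond, blends the masked init latent back in and re-wraps the conditioning dicts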
+ def before_sample(self, x, ts, cond, unconditional_conditioning):
+ if shared.state.interrupted or shared.state.skipped:
+ raise sd_samplers_common.InterruptedException
+ if shared.state.paused:
+ shared.log.debug('Sampling paused')
+ while shared.state.paused:
+ if shared.state.interrupted or shared.state.skipped:
+ raise sd_samplers_common.InterruptedException
+ import time
+ time.sleep(0.1)
+
+ if self.stop_at is not None and self.step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ uc_image_conditioning = None
+ if isinstance(cond, dict):
+ if self.conditioning_key == "crossattn-adm":
+ image_conditioning = cond["c_adm"]
+ uc_image_conditioning = unconditional_conditioning["c_adm"]
+ else:
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
+ cond = tensor
+
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
+ # filling unconditional_conditioning with repeats of the last vector to match length is
+ # not 100% correct but should work well enough
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
+ last_vector = unconditional_conditioning[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
+ if self.mask is not None:
+ encode_fn = self.sampler.model.q_sample
+ if self.is_unipc:
+ encode_fn = self.sampler.stochastic_encode
+ img_orig = encode_fn(self.init_latent, ts)
+ x = img_orig * self.mask + self.nmask * x
+
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ if self.conditioning_key == "crossattn-adm":
+ cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
+ else:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+ return x, ts, cond, unconditional_conditioning
+
+ def update_step(self, last_latent):
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent if self.mask is not None else last_latent
+ sd_samplers_common.store_latent(self.last_latent)
+ self.step += 1
+ shared.state.sampling_step = self.step
+
+ def after_sample(self, x, ts, cond, uncond, res):
+ if not self.is_unipc:
+ self.update_step(res[1])
+ return x, ts, cond, uncond, res
+
+ def unipc_after_update(self, x, model_x): # pylint: disable=unused-argument
+ self.update_step(x)
+
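+    # set eta, record non-default sampler options in extra_generation_params and hook p_sample_ddim/p_sample_plms (or the UniPC hooks) so prompt scheduling and previews work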
+ def initialize(self, p):
+ if p is not None:
+ if self.is_ddim:
+ self.eta = p.eta if p.eta is not None else shared.opts.scheduler_eta
+ else:
+ self.eta = 0.0
+ if self.eta != 0.0:
+ p.extra_generation_params["Sampler Eta"] = self.eta
+ if self.is_unipc:
+ keys = [
+ ('Solver order', 'schedulers_solver_order'),
+ ('Sampler low order', 'schedulers_use_loworder'),
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ]
+ for name, key in keys:
+ v = getattr(shared.opts, key)
+ if v != shared.opts.get_default(key):
+ p.extra_generation_params[name] = v
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+ if self.is_unipc:
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
+
+ self.mask = p.mask if hasattr(p, 'mask') else None
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+
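+    # DDIM/PLMS/UniPC only accept certain step counts; adjust the requested number of steps to a valid value (UniPC additionally needs at least solver-order steps)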
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
+ if self.config.name == 'UniPC' and num_steps < shared.opts.schedulers_solver_order:
+ num_steps = shared.opts.schedulers_solver_order
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == math.floor(valid_step):
+ return min(int(valid_step) + 1, num_steps)
+
+ return num_steps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+ steps = self.adjust_steps_if_invalid(p, steps)
+ self.initialize(p)
+
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
+
+ self.init_latent = x
+ self.last_latent = x
+ self.step = 0
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ if self.conditioning_key == "crossattn-adm":
+ conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
+ else:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ self.initialize(p)
+
+ self.init_latent = None
+ self.last_latent = x
+ self.step = 0
+
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+ if image_conditioning is not None:
+ if self.conditioning_key == "crossattn-adm":
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
+ else:
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+
+ return samples_ddim
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 3210edb09..399c05317 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,385 +1,385 @@
-import sys
-import time
-import inspect
-from collections import deque
-import torch
-from modules import prompt_parser
-from modules import devices
-from modules import sd_samplers_common
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
-
-
-# deal with k-diffusion imports
-k_sampling = None
-try:
- import k_diffusion # pylint: disable=wrong-import-order
- k_sampling = k_diffusion.sampling
-except ImportError:
- pass
-try:
- if k_sampling is None:
- import importlib
- k_diffusion = importlib.import_module('modules.k-diffusion.k_diffusion')
- k_sampling = k_diffusion.sampling
-except Exception:
- pass
-if k_sampling is None:
- shared.log.info(f'Path search: {sys.path}')
- shared.log.error("Module not found: k-diffusion")
- sys.exit(1)
-
-
-samplers_k_diffusion = [
- ('Euler', 'sample_euler', ['k_euler'], {"scheduler": "default"}),
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"scheduler": "default", "brownian_noise": False}),
- ('Heun', 'sample_heun', ['k_heun'], {"scheduler": "default"}),
- ('LMS', 'sample_lms', ['k_lms'], {"scheduler": "default"}),
- ('DPM Adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"scheduler": "default", "brownian_noise": False}),
- ('DPM Fast', 'sample_dpm_fast', ['k_dpm_fast'], {"scheduler": "default", "brownian_noise": False}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True, "scheduler": "default", "brownian_noise": False}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "second_order": True, "scheduler": "default", "brownian_noise": False}),
- ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {"scheduler": "default", "brownian_noise": False}),
- ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'discard_next_to_last_sigma': True, "scheduler": "default", "brownian_noise": False}),
- ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"solver_type": "heun", "scheduler": "default", "brownian_noise": False}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"second_order": True, "scheduler": "default", "brownian_noise": False}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "scheduler": "default", "brownian_noise": False}),
- ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "scheduler": "default", "brownian_noise": False}),
-]
-
-samplers_data_k_diffusion = [
- sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
- for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_sampling, funcname)
-]
-
-sampler_extra_params = {
- 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-}
-
-
-class CFGDenoiser(torch.nn.Module):
- """
- Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
- that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
- instead of one. Originally, the second prompt is just an empty string, but we use non-empty
- negative prompt.
- """
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.step = 0
- self.image_cfg_scale = None
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
- return denoised
-
- def combine_denoised_for_edit_model(self, x_out, cond_scale):
- out_cond, out_img_cond, out_uncond = x_out.chunk(3)
- denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
- return denoised
-
- def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
- if shared.state.interrupted or shared.state.skipped:
- raise sd_samplers_common.InterruptedException
- if shared.state.paused:
- shared.log.debug('Sampling paused')
- while shared.state.paused:
- if shared.state.interrupted or shared.state.skipped:
- raise sd_samplers_common.InterruptedException
- time.sleep(0.1)
- # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
- # so is_edit_model is set to False to support AND composition.
- is_edit_model = (shared.sd_model is not None) and hasattr(shared.sd_model, 'cond_stage_key') and (shared.sd_model.cond_stage_key == "edit") and (self.image_cfg_scale is not None) and (self.image_cfg_scale != 1.0)
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
- if shared.sd_model.model.conditioning_key == "crossattn-adm":
- image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} # pylint: disable=C3001
- else:
- image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} # pylint: disable=C3001
- if not is_edit_model:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
- else:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, shared.state.sampling_step, shared.state.sampling_steps, tensor, uncond)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
- tensor = denoiser_params.text_cond
- uncond = denoiser_params.text_uncond
- skip_uncond = False
- # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
- if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
- skip_uncond = True
- x_in = x_in[:-batch_size]
- sigma_in = sigma_in[:-batch_size]
-
- if tensor.shape[1] == uncond.shape[1] or skip_uncond:
- if is_edit_model:
- cond_in = torch.cat([tensor, uncond, uncond])
- elif skip_uncond:
- cond_in = tensor
- else:
- cond_in = torch.cat([tensor, uncond])
- """
- adjusted_cond_scale = cond_scale # Adjusted cond_scale for uncond
- last_uncond_steps = max(0, state.sampling_steps - 2) # Determine the last two steps before uncond stops
- if self.step >= last_uncond_steps: # Check if we're in the last two steps before uncond stops
- adjusted_cond_scale *= 1.5 # Apply uncond with 150% cond_scale
- else:
- if (self.step - last_uncond_steps) % 3 == 0: # Check if it's one of every three steps after uncond stops
- adjusted_cond_scale *= 1.5 # Apply uncond with 150% cond_scale
- """
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
- if not is_edit_model:
- c_crossattn = [tensor[a:b]]
- else:
- c_crossattn = torch.cat([tensor[a:b]], uncond)
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
- denoised_image_indexes = [x[0][0] for x in conds_list]
- if skip_uncond:
- fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
- x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, shared.state.sampling_step, shared.state.sampling_steps, self.inner_model)
- cfg_denoised_callback(denoised_params)
- devices.test_for_nans(x_out, "unet")
- if shared.opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
- elif shared.opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- if is_edit_model:
- denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
- elif skip_uncond:
- denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
- else:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- if self.mask is not None:
- if devices.backend == "directml":
- self.init_latent = self.init_latent.float()
- denoised = self.init_latent * self.mask + self.nmask * denoised
- self.init_latent = self.init_latent.half()
- else:
- denoised = self.init_latent * self.mask + self.nmask * denoised
- after_cfg_callback_params = AfterCFGCallbackParams(denoised, shared.state.sampling_step, shared.state.sampling_steps)
- cfg_after_cfg_callback(after_cfg_callback_params)
- denoised = after_cfg_callback_params.x
- self.step += 1
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
- if hasattr(torch, item):
- return getattr(torch, item)
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
- if x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
-
-
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
- denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
- self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = getattr(k_sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.config = None # set by the function calling the constructor
- self.last_latent = None
- self.s_min_uncond = None
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if shared.opts.live_preview_content == "Combined":
- sd_samplers_common.store_latent(latent)
- self.last_latent = latent
- if self.stop_at is not None and step > self.stop_at:
- raise sd_samplers_common.InterruptedException
- shared.state.sampling_step = step
-
- def launch_sampling(self, steps, func):
- shared.state.sampling_steps = steps
- shared.state.sampling_step = 0
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p):
- if self.config.options.get('brownian_noise', None) is not None:
- self.config.options['brownian_noise'] = shared.opts.data.get('schedulers_brownian_noise', False)
- if self.config.options.get('scheduler', None) is not None:
- self.config.options['scheduler'] = shared.opts.data.get('schedulers_sigma', None)
- if p is None:
- return {}
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.eta = p.eta if p.eta is not None else shared.opts.scheduler_eta
- self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
- k_sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
- if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
- p.extra_generation_params["Sampler Eta"] = self.eta
- extra_params_kwargs['eta'] = self.eta
- return extra_params_kwargs
-
- def get_sigmas(self, p, steps): # pylint: disable=unused-argument
- discard_next_to_last_sigma = shared.opts.data.get('schedulers_discard_penultimate', True) if self.config.options.get('discard_next_to_last_sigma', None) is not None else False
- steps += 1 if discard_next_to_last_sigma else 0
- if self.config.options.get('scheduler', None) == 'default' or self.config.options.get('scheduler', None) is None:
- sigmas = self.model_wrap.get_sigmas(steps)
- elif self.config.options.get('scheduler', None) == 'karras':
- sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
- sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
- sigmas = k_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- elif self.config.options.get('scheduler', None) == 'exponential':
- sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
- sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
- sigmas = k_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- elif self.config.options.get('scheduler', None) == 'polyexponential':
- sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
- sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
- sigmas = k_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- elif self.config.options.get('scheduler', None) == 'vp':
- sigmas = k_sampling.get_sigmas_vp(n=steps, device=shared.device)
- if discard_next_to_last_sigma:
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
- return sigmas
-
- def create_noise_sampler(self, x, sigmas, p):
- """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
- if shared.opts.no_dpmpp_sde_batch_determinism:
- return None
- positive_sigmas = sigmas[sigmas > 0]
- if positive_sigmas.numel() > 0:
- sigma_min = positive_sigmas.min(dim=0)[0]
- else:
- sigma_min = 0
- sigma_max = sigmas.max()
- current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
- return k_sampling.BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
- sigmas = self.get_sigmas(p, steps)
- sigma_sched = sigmas[steps - t_enc - 1:]
- xi = x + noise * sigma_sched[0]
- extra_params_kwargs = self.initialize(p)
- parameters = inspect.signature(self.func).parameters
- if 'sigma_min' in parameters:
- ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
- extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in parameters:
- extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in parameters:
- extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in parameters:
- extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in parameters:
- extra_params_kwargs['sigmas'] = sigma_sched
- if self.config.options.get('brownian_noise', False) and 'noise_sampler' in parameters:
- noise_sampler = self.create_noise_sampler(x, sigmas, p)
- extra_params_kwargs['noise_sampler'] = noise_sampler
- self.model_wrap_cfg.init_latent = x
- self.last_latent = x
- extra_args = {
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale,
- 's_min_uncond': self.s_min_uncond
- }
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- samples = samples.type(devices.dtype)
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps = steps or p.steps
- sigmas = self.get_sigmas(p, steps)
- x = x * sigmas[0]
- extra_params_kwargs = self.initialize(p)
- parameters = inspect.signature(self.func).parameters
- if 'sigma_min' in parameters:
- extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
- extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in parameters:
- extra_params_kwargs['n'] = steps
- else:
- extra_params_kwargs['sigmas'] = sigmas
- if self.config.options.get('brownian_noise', False) and 'noise_sampler' in parameters:
- noise_sampler = self.create_noise_sampler(x, sigmas, p)
- extra_params_kwargs['noise_sampler'] = noise_sampler
- self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale,
- 's_min_uncond': self.s_min_uncond
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
- return samples
+import sys
+import time
+import inspect
+from collections import deque
+import torch
+from modules import prompt_parser
+from modules import devices
+from modules import sd_samplers_common
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+
+# deal with k-diffusion imports
+k_sampling = None
+try:
+ import k_diffusion # pylint: disable=wrong-import-order
+ k_sampling = k_diffusion.sampling
+except ImportError:
+ pass
+try:
+ if k_sampling is None:
+ import importlib
+ k_diffusion = importlib.import_module('modules.k-diffusion.k_diffusion')
+ k_sampling = k_diffusion.sampling
+except Exception:
+ pass
+if k_sampling is None:
+ shared.log.info(f'Path search: {sys.path}')
+ shared.log.error("Module not found: k-diffusion")
+ sys.exit(1)
+
+
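+# each entry: (UI label, k-diffusion sampler function name, aliases, default sampler options);
+# options such as "scheduler", "brownian_noise" and "discard_next_to_last_sigma" are
+# re-resolved from user settings in KDiffusionSampler.initialize / get_sigmas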
+samplers_k_diffusion = [
+ ('Euler', 'sample_euler', ['k_euler'], {"scheduler": "default"}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"scheduler": "default", "brownian_noise": False}),
+ ('Heun', 'sample_heun', ['k_heun'], {"scheduler": "default"}),
+ ('LMS', 'sample_lms', ['k_lms'], {"scheduler": "default"}),
+ ('DPM Adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"scheduler": "default", "brownian_noise": False}),
+ ('DPM Fast', 'sample_dpm_fast', ['k_dpm_fast'], {"scheduler": "default", "brownian_noise": False}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True, "scheduler": "default", "brownian_noise": False}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "second_order": True, "scheduler": "default", "brownian_noise": False}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {"scheduler": "default", "brownian_noise": False}),
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'discard_next_to_last_sigma': True, "scheduler": "default", "brownian_noise": False}),
+ ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"solver_type": "heun", "scheduler": "default", "brownian_noise": False}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"second_order": True, "scheduler": "default", "brownian_noise": False}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "scheduler": "default", "brownian_noise": False}),
+ ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "scheduler": "default", "brownian_noise": False}),
+]
+
+samplers_data_k_diffusion = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_k_diffusion
+ if hasattr(k_sampling, funcname)
+]
+
+sampler_extra_params = {
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+}
+
+
+class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.step = 0
+ self.image_cfg_scale = None
+
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
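+        # standard CFG combination: start from the unconditional prediction and add each
+        # conditional delta (cond - uncond) scaled by its prompt weight and cond_scale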
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+ return denoised
+
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
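+        # InstructPix2Pix-style guidance: x_out holds three chunks (text cond, image cond, uncond);
+        # text guidance and image guidance are applied as two separate deltas scaled by
+        # cond_scale and image_cfg_scale respectively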
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+ return denoised
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ if shared.state.interrupted or shared.state.skipped:
+ raise sd_samplers_common.InterruptedException
+ if shared.state.paused:
+ shared.log.debug('Sampling paused')
+ while shared.state.paused:
+ if shared.state.interrupted or shared.state.skipped:
+ raise sd_samplers_common.InterruptedException
+ time.sleep(0.1)
+        # at self.image_cfg_scale == 1.0, results produced for the edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+ is_edit_model = (shared.sd_model is not None) and hasattr(shared.sd_model, 'cond_stage_key') and (shared.sd_model.cond_stage_key == "edit") and (self.image_cfg_scale is not None) and (self.image_cfg_scale != 1.0)
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
+ image_uncond = torch.zeros_like(image_cond)
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} # pylint: disable=C3001
+ else:
+ image_uncond = image_cond
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} # pylint: disable=C3001
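+        # build one large batch: each latent is repeated once per sub-prompt (AND composition),
+        # followed by the unconditional batch (and, for edit models, an extra zeroed image-cond batch)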
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, shared.state.sampling_step, shared.state.sampling_steps, tensor, uncond)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
+ skip_uncond = False
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
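+            # drop the trailing unconditional batch that was appended to x_in / sigma_in above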
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
+ cond_in = torch.cat([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = torch.cat([tensor, uncond])
+ """
+ adjusted_cond_scale = cond_scale # Adjusted cond_scale for uncond
+ last_uncond_steps = max(0, state.sampling_steps - 2) # Determine the last two steps before uncond stops
+ if self.step >= last_uncond_steps: # Check if we're in the last two steps before uncond stops
+ adjusted_cond_scale *= 1.5 # Apply uncond with 150% cond_scale
+ else:
+ if (self.step - last_uncond_steps) % 3 == 0: # Check if it's one of every three steps after uncond stops
+ adjusted_cond_scale *= 1.5 # Apply uncond with 150% cond_scale
+ """
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+ if not is_edit_model:
+ c_crossattn = [tensor[a:b]]
+ else:
+                        c_crossattn = torch.cat([tensor[a:b], uncond])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+ denoised_params = CFGDenoisedParams(x_out, shared.state.sampling_step, shared.state.sampling_steps, self.inner_model)
+ cfg_denoised_callback(denoised_params)
+ devices.test_for_nans(x_out, "unet")
+ if shared.opts.live_preview_content == "Prompt":
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
+ elif shared.opts.live_preview_content == "Negative prompt":
+ sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
+ if is_edit_model:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if self.mask is not None:
+ if devices.backend == "directml":
+ self.init_latent = self.init_latent.float()
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.init_latent = self.init_latent.half()
+ else:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, shared.state.sampling_step, shared.state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+ self.step += 1
+ return denoised
+
+
+class TorchHijack:
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+ if hasattr(torch, item):
+ return getattr(torch, item)
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+ def randn_like(self, x):
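+        # serve pre-generated sampler noises (if any) in order; otherwise fall back to
+        # torch.randn_like (generated on CPU first when running on MPS)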
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+class KDiffusionSampler:
+ def __init__(self, funcname, sd_model):
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+ self.funcname = funcname
+ self.func = getattr(k_sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None # set by the function calling the constructor
+ self.last_latent = None
+ self.s_min_uncond = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def callback_state(self, d):
+ step = d['i']
+ latent = d["denoised"]
+ if shared.opts.live_preview_content == "Combined":
+ sd_samplers_common.store_latent(latent)
+ self.last_latent = latent
+ if self.stop_at is not None and step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+ shared.state.sampling_step = step
+
+ def launch_sampling(self, steps, func):
+ shared.state.sampling_steps = steps
+ shared.state.sampling_step = 0
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p):
+ if self.config.options.get('brownian_noise', None) is not None:
+ self.config.options['brownian_noise'] = shared.opts.data.get('schedulers_brownian_noise', False)
+ if self.config.options.get('scheduler', None) is not None:
+ self.config.options['scheduler'] = shared.opts.data.get('schedulers_sigma', None)
+ if p is None:
+ return {}
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+ self.eta = p.eta if p.eta is not None else shared.opts.scheduler_eta
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+ k_sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params["Sampler Eta"] = self.eta
+ extra_params_kwargs['eta'] = self.eta
+ return extra_params_kwargs
+
+ def get_sigmas(self, p, steps): # pylint: disable=unused-argument
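+        # when the penultimate sigma is to be discarded, request one extra step so the final
+        # schedule still yields the expected number of sigmas after dropping it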
+ discard_next_to_last_sigma = shared.opts.data.get('schedulers_discard_penultimate', True) if self.config.options.get('discard_next_to_last_sigma', None) is not None else False
+ steps += 1 if discard_next_to_last_sigma else 0
+ if self.config.options.get('scheduler', None) == 'default' or self.config.options.get('scheduler', None) is None:
+ sigmas = self.model_wrap.get_sigmas(steps)
+ elif self.config.options.get('scheduler', None) == 'karras':
+ sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
+ sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
+ sigmas = k_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ elif self.config.options.get('scheduler', None) == 'exponential':
+ sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
+ sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
+ sigmas = k_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ elif self.config.options.get('scheduler', None) == 'polyexponential':
+ sigma_min = p.s_min if p.s_min > 0 else self.model_wrap.sigmas[0].item()
+ sigma_max = p.s_max if p.s_max > 0 else self.model_wrap.sigmas[-1].item()
+ sigmas = k_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ elif self.config.options.get('scheduler', None) == 'vp':
+ sigmas = k_sampling.get_sigmas_vp(n=steps, device=shared.device)
+ if discard_next_to_last_sigma:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+ return sigmas
+
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
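+        # the Brownian tree noise sampler needs a usable [sigma_min, sigma_max] range,
+        # so use the smallest strictly positive sigma as the lower bound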
+ positive_sigmas = sigmas[sigmas > 0]
+ if positive_sigmas.numel() > 0:
+ sigma_min = positive_sigmas.min(dim=0)[0]
+ else:
+ sigma_min = 0
+ sigma_max = sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return k_sampling.BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+ sigmas = self.get_sigmas(p, steps)
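+        # img2img starts denoising part-way through the schedule: keep only the tail of the
+        # sigma schedule and noise the init latent up to the first sigma that is kept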
+ sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+ if 'sigma_min' in parameters:
+            # the last sigma is zero, which isn't allowed by DPM Fast & Adaptive, so take the value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
+ if self.config.options.get('brownian_noise', False) and 'noise_sampler' in parameters:
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = samples.type(devices.dtype)
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+ sigmas = self.get_sigmas(p, steps)
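+        # txt2img starts from pure noise: scale the initial latent noise to the first (largest) sigma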
+ x = x * sigmas[0]
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+ if 'sigma_min' in parameters:
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = steps
+ else:
+ extra_params_kwargs['sigmas'] = sigmas
+ if self.config.options.get('brownian_noise', False) and 'noise_sampler' in parameters:
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ return samples
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index a15c76a4a..03716ee08 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -1,139 +1,139 @@
-# TODO a1111 compatibility module
-
-import torch
-import tqdm
-import k_diffusion.sampling
-import numpy as np
-
-from modules import shared
-from modules.unipc import uni_pc
-
-
-@torch.no_grad()
-def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
- alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) # pylint: disable=not-callable
- sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
- sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
-
- extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones((x.shape[0]))
- s_x = x.new_ones((x.shape[0], 1, 1, 1))
- for i in tqdm.trange(len(timesteps) - 1, disable=disable):
- index = len(timesteps) - 1 - i
-
- e_t = model(x, timesteps[index].item() * s_in, **extra_args)
-
- a_t = alphas[index].item() * s_x
- a_prev = alphas_prev[index].item() * s_x
- sigma_t = sigmas[index].item() * s_x
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
-
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
- noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
- x = a_prev.sqrt() * pred_x0 + dir_xt + noise
-
- if callback is not None:
- callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
-
- return x
-
-
-@torch.no_grad()
-def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
- alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) # pylint: disable=not-callable
- sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
-
- extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones([x.shape[0]])
- s_x = x.new_ones((x.shape[0], 1, 1, 1))
- old_eps = []
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = alphas[index].item() * s_x
- a_prev = alphas_prev[index].item() * s_x
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
-
- # direction pointing to x_t
- dir_xt = (1. - a_prev).sqrt() * e_t
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt
- return x_prev, pred_x0
-
- for i in tqdm.trange(len(timesteps) - 1, disable=disable):
- index = len(timesteps) - 1 - i
- ts = timesteps[index].item() * s_in
- t_next = timesteps[max(index - 1, 0)].item() * s_in
-
- e_t = model(x, ts, **extra_args)
-
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = model(x_prev, t_next, **extra_args)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- else:
- # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- old_eps.append(e_t)
- if len(old_eps) >= 4:
- old_eps.pop(0)
-
- x = x_prev
-
- if callback is not None:
- callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
-
- return x
-
-
-class UniPCCFG(uni_pc.UniPC):
- def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
- super().__init__(None, *args, **kwargs)
-
- def after_update(x, model_x):
- callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
- self.index += 1
-
- self.cfg_model = cfg_model
- self.extra_args = extra_args
- self.callback = callback
- self.index = 0
- self.after_update = after_update
-
- def get_model_input_time(self, t_continuous):
- return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
-
- def model(self, x, t):
- t_input = self.get_model_input_time(t)
-
- res = self.cfg_model(x, t_input, **self.extra_args)
-
- return res
-
-
-def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): # pylint: disable=unused-argument
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
-
- ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
- t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
- unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
- x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
-
- return x
+# TODO a1111 compatibility module
+
+import torch
+import tqdm
+import k_diffusion.sampling
+import numpy as np
+
+from modules import shared
+from modules.unipc import uni_pc
+
+
+@torch.no_grad()
+def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) # pylint: disable=not-callable
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones((x.shape[0]))
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
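+    # reverse DDIM loop: predict x0 from the eps estimate, then form the previous sample with the
+    # deterministic direction term plus optional eta-controlled noise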
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+
+ e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sigma_t = sigmas[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+@torch.no_grad()
+def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) # pylint: disable=not-callable
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
+ old_eps = []
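+    # old_eps stores up to three previous eps predictions; the multistep branches below
+    # combine them with Adams-Bashforth coefficients of increasing order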
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev).sqrt() * e_t
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev, pred_x0
+
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+ ts = timesteps[index].item() * s_in
+ t_next = timesteps[max(index - 1, 0)].item() * s_in
+
+ e_t = model(x, ts, **extra_args)
+
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = model(x_prev, t_next, **extra_args)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ else:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+
+ x = x_prev
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+class UniPCCFG(uni_pc.UniPC):
+ def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
+ super().__init__(None, *args, **kwargs)
+
+ def after_update(x, model_x):
+ callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
+ self.index += 1
+
+ self.cfg_model = cfg_model
+ self.extra_args = extra_args
+ self.callback = callback
+ self.index = 0
+ self.after_update = after_update
+
+ def get_model_input_time(self, t_continuous):
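+        # map continuous time in (1/N, 1] to the discrete 0..999 timestep range expected by the wrapped model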
+ return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
+
+ def model(self, x, t):
+ t_input = self.get_model_input_time(t)
+
+ res = self.cfg_model(x, t_input, **self.extra_args)
+
+ return res
+
+
+def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): # pylint: disable=unused-argument
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+
+ ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit; if someone wants to fix it, please do
+ unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
+ x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 14a67fca1..0b6db41bc 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -1,80 +1,80 @@
-import os
-import torch
-from torch import nn
-from modules import devices, paths, shared
-
-
-sd_vae_approx_model = None
-
-
-class VAEApprox(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv1 = nn.Conv2d(4, 8, (7, 7))
- self.conv2 = nn.Conv2d(8, 16, (5, 5))
- self.conv3 = nn.Conv2d(16, 32, (3, 3))
- self.conv4 = nn.Conv2d(32, 64, (3, 3))
- self.conv5 = nn.Conv2d(64, 32, (3, 3))
- self.conv6 = nn.Conv2d(32, 16, (3, 3))
- self.conv7 = nn.Conv2d(16, 8, (3, 3))
- self.conv8 = nn.Conv2d(8, 3, (3, 3))
-
- def forward(self, x):
- extra = 11
- try:
- x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
- x = nn.functional.pad(x, (extra, extra, extra, extra)) # pylint: disable=not-callable
- for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
- x = layer(x)
- x = nn.functional.leaky_relu(x, 0.1)
- except Exception:
- pass
- return x
-
-
-def nn_approximation(sample): # Approximate NN
- global sd_vae_approx_model # pylint: disable=global-statement
- if sd_vae_approx_model is None:
- model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
- sd_vae_approx_model = VAEApprox()
- if not os.path.exists(model_path):
- model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
- approx_weights = torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)
- sd_vae_approx_model.load_state_dict(approx_weights)
- sd_vae_approx_model.eval()
- sd_vae_approx_model.to(devices.device, sample.dtype)
- shared.log.debug(f'Load VAE decode approximate: model="{model_path}"')
- try:
- in_sample = sample.to(devices.device).unsqueeze(0)
- x_sample = sd_vae_approx_model(in_sample)
- x_sample = x_sample[0]
- return x_sample
- except Exception as e:
- shared.log.error(f'Decode approximate: {e}')
- return sample
-
-
-def cheap_approximation(sample): # Approximate simple
- # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- if shared.sd_model_type == "sdxl":
- simple_weights = torch.tensor([
- [0.4543,-0.2868, 0.1566,-0.4748],
- [0.5008, 0.0952, 0.2155,-0.3268],
- [0.5294, 0.1625,-0.0624,-0.3793]
- ]).reshape(3, 4, 1, 1)
- simple_bias = torch.tensor([0.1375, 0.0144, -0.0675])
- else:
- simple_weights = torch.tensor([
- [0.298, 0.187,-0.158,-0.184],
- [0.207, 0.286, 0.189,-0.271],
- [0.208, 0.173, 0.264,-0.473],
- ]).reshape(3, 4, 1, 1)
- simple_bias = None
- try:
- weights = simple_weights.to(sample.device, sample.dtype)
- bias = simple_bias.to(sample.device, sample.dtype) if simple_bias is not None else None
- x_sample = nn.functional.conv2d(sample, weights, bias) # pylint: disable=not-callable
- return x_sample
- except Exception as e:
- shared.log.error(f'Decode simple: {e}')
- return sample
+import os
+import torch
+from torch import nn
+from modules import devices, paths, shared
+
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+ def forward(self, x):
+ extra = 11
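+        # the eight unpadded convs shrink each spatial dim by 22 pixels in total,
+        # so pad 11 on every side after the 2x upsample to roughly preserve the output size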
+ try:
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ x = nn.functional.pad(x, (extra, extra, extra, extra)) # pylint: disable=not-callable
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+ x = layer(x)
+ x = nn.functional.leaky_relu(x, 0.1)
+ except Exception:
+ pass
+ return x
+
+
+def nn_approximation(sample): # Approximate NN
+ global sd_vae_approx_model # pylint: disable=global-statement
+ if sd_vae_approx_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
+ sd_vae_approx_model = VAEApprox()
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
+ approx_weights = torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)
+ sd_vae_approx_model.load_state_dict(approx_weights)
+ sd_vae_approx_model.eval()
+ sd_vae_approx_model.to(devices.device, sample.dtype)
+ shared.log.debug(f'Load VAE decode approximate: model="{model_path}"')
+ try:
+ in_sample = sample.to(devices.device).unsqueeze(0)
+ x_sample = sd_vae_approx_model(in_sample)
+ x_sample = x_sample[0]
+ return x_sample
+ except Exception as e:
+ shared.log.error(f'Decode approximate: {e}')
+ return sample
+
+
+def cheap_approximation(sample): # Approximate simple
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
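+    # approximate latent-to-RGB decode as a fixed 1x1 convolution: a per-model 3x4 matrix
+    # maps the 4 latent channels directly to RGB, with an optional bias for SDXL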
+ if shared.sd_model_type == "sdxl":
+ simple_weights = torch.tensor([
+ [0.4543,-0.2868, 0.1566,-0.4748],
+ [0.5008, 0.0952, 0.2155,-0.3268],
+ [0.5294, 0.1625,-0.0624,-0.3793]
+ ]).reshape(3, 4, 1, 1)
+ simple_bias = torch.tensor([0.1375, 0.0144, -0.0675])
+ else:
+ simple_weights = torch.tensor([
+ [0.298, 0.187,-0.158,-0.184],
+ [0.207, 0.286, 0.189,-0.271],
+ [0.208, 0.173, 0.264,-0.473],
+ ]).reshape(3, 4, 1, 1)
+ simple_bias = None
+ try:
+ weights = simple_weights.to(sample.device, sample.dtype)
+ bias = simple_bias.to(sample.device, sample.dtype) if simple_bias is not None else None
+ x_sample = nn.functional.conv2d(sample, weights, bias) # pylint: disable=not-callable
+ return x_sample
+ except Exception as e:
+ shared.log.error(f'Decode simple: {e}')
+ return sample
diff --git a/modules/shared.py b/modules/shared.py
index 762f729f5..4c4396536 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,986 +1,986 @@
-import io
-import os
-import sys
-import time
-import json
-import contextlib
-from types import SimpleNamespace
-from urllib.parse import urlparse
-from enum import Enum
-import requests
-import gradio as gr
-import fasteners
-import orjson
-import diffusers
-from rich.console import Console
-from modules import errors, shared_items, shared_state, cmd_args, ui_components, theme
-from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
-from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
-import modules.interrogate
-import modules.memmon
-import modules.styles
-import modules.devices as devices # pylint: disable=R0402
-import modules.paths as paths
-from installer import print_dict
-from installer import log as central_logger # pylint: disable=E0611
-
-
-errors.install([gr])
-demo: gr.Blocks = None
-log = central_logger
-progress_print_out = sys.stdout
-parser = cmd_args.parser
-url = 'https://github.com/vladmandic/automatic'
-cmd_opts, _ = parser.parse_known_args()
-hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
-xformers_available = False
-clip_model = None
-interrogator = modules.interrogate.InterrogateModels("interrogate")
-sd_upscalers = []
-face_restorers = []
-tab_names = []
-extra_networks = []
-options_templates = {}
-hypernetworks = {}
-loaded_hypernetworks = []
-settings_components = None
-latent_upscale_default_mode = "None"
-latent_upscale_modes = {
- "Latent Nearest": {"mode": "nearest", "antialias": False},
- "Latent Nearest-exact": {"mode": "nearest-exact", "antialias": False},
- "Latent Area": {"mode": "area", "antialias": False},
- "Latent Bilinear": {"mode": "bilinear", "antialias": False},
- "Latent Bicubic": {"mode": "bicubic", "antialias": False},
- "Latent Bilinear antialias": {"mode": "bilinear", "antialias": True},
- "Latent Bicubic antialias": {"mode": "bicubic", "antialias": True},
- # "Latent Linear": {"mode": "linear", "antialias": False}, # not supported for latents with channels=4
- # "Latent Trilinear": {"mode": "trilinear", "antialias": False}, # not supported for latents with channels=4
-}
-restricted_opts = {
- "samples_filename_pattern",
- "directories_filename_pattern",
- "outdir_samples",
- "outdir_txt2img_samples",
- "outdir_img2img_samples",
- "outdir_extras_samples",
- "outdir_grids",
- "outdir_txt2img_grids",
- "outdir_save",
- "outdir_init_images"
-}
-resize_modes = ["None", "Fixed", "Crop", "Fill", "Latent"]
-compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order']
-console = Console(log_time=True, log_time_format='%H:%M:%S-%f')
-dir_timestamps = {}
-dir_cache = {}
-
-
-class Backend(Enum):
- ORIGINAL = 1
- DIFFUSERS = 2
-
-
-state = shared_state.State()
-if not hasattr(cmd_opts, "use_openvino"):
- cmd_opts.use_openvino = False
-
-
-def readfile(filename, silent=False, lock=False):
- data = {}
- lock_file = None
- locked = False
- try:
- # if not os.path.exists(filename):
- # return {}
- t0 = time.time()
- if lock:
- lock_file = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
- locked = lock_file.acquire_read_lock(blocking=True, timeout=3)
- with open(filename, "rb") as file:
- b = file.read()
- data = orjson.loads(b) # pylint: disable=no-member
- # if type(data) is str:
- # data = json.loads(data)
- t1 = time.time()
- if not silent:
- log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f}')
- except Exception as e:
- if not silent:
- log.error(f'Reading failed: {filename} {e}')
- finally:
- if lock_file is not None:
- lock_file.release_read_lock()
- if locked and os.path.exists(f"{filename}.lock"):
- os.remove(f"{filename}.lock")
- return data
-
-
-def writefile(data, filename, mode='w', silent=False, atomic=False):
- lock = None
- locked = False
- import tempfile
-
-
- def default(obj):
- log.error(f"Saving: {filename} not a valid object: {obj}")
- return str(obj)
-
- try:
- t0 = time.time()
- # skipkeys=True, ensure_ascii=True, check_circular=True, allow_nan=True
- if type(data) == dict:
- output = json.dumps(data, indent=2, default=default)
- elif type(data) == list:
- output = json.dumps(data, indent=2, default=default)
- elif isinstance(data, object):
- simple = {}
- for k in data.__dict__:
- if data.__dict__[k] is not None:
- simple[k] = data.__dict__[k]
- output = json.dumps(simple, indent=2, default=default)
- else:
- raise ValueError('not a valid object')
- lock = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
- locked = lock.acquire_write_lock(blocking=True, timeout=3)
- if atomic:
- with tempfile.NamedTemporaryFile(mode=mode, encoding="utf8", delete=False, dir=os.path.dirname(filename)) as f:
- f.write(output)
- f.flush()
- os.fsync(f.fileno())
- os.replace(f.name, filename)
- else:
- with open(filename, mode=mode, encoding="utf8") as file:
- file.write(output)
- t1 = time.time()
- if not silent:
- log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)} time={t1-t0:.3f}')
- except Exception as e:
- log.error(f'Saving failed: {filename} {e}')
- errors.display(e, 'Saving failed')
- finally:
- if lock is not None:
-            lock.release_write_lock()
- if locked and os.path.exists(f"{filename}.lock"):
- os.remove(f"{filename}.lock")
-
-
-# early select backend
-default_backend = 'original'
-early_opts = readfile(cmd_opts.config, silent=True)
-early_backend = early_opts.get('sd_backend', default_backend)
-backend = Backend.DIFFUSERS if early_backend.lower() == 'diffusers' else Backend.ORIGINAL
-if cmd_opts.backend is not None: # override with args
- backend = Backend.DIFFUSERS if cmd_opts.backend.lower() == 'diffusers' else Backend.ORIGINAL
-if cmd_opts.use_openvino: # override for openvino
- backend = Backend.DIFFUSERS
-
-
-class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, folder=None, submit=None, comment_before='', comment_after=''):
- self.default = default
- self.label = label
- self.component = component
- self.component_args = component_args
- self.onchange = onchange
- self.section = section
- self.refresh = refresh
- self.folder = folder
- self.comment_before = comment_before # HTML text that will be added after label in UI
- self.comment_after = comment_after # HTML text that will be added before label in UI
- self.submit = submit
-
-    def link(self, label, uri):
-        self.comment_before += f"[<a href='{uri}' target='_blank'>{label}</a>]"
-        return self
-
-    def js(self, label, js_func):
-        self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
-        return self
-
- def info(self, info):
- self.comment_after += f"({info}) "
- return self
-
- def html(self, info):
- self.comment_after += f"{info} "
- return self
-
- def needs_restart(self):
- self.comment_after += " (requires restart) "
- return self
-
-
-def options_section(section_identifier, options_dict):
- for v in options_dict.values():
- v.section = section_identifier
- return options_dict
-
-
-def list_checkpoint_tiles():
- import modules.sd_models # pylint: disable=W0621
- return modules.sd_models.checkpoint_tiles()
-
-default_checkpoint = list_checkpoint_tiles()[0] if len(list_checkpoint_tiles()) > 0 else "model.ckpt"
-
-
-def is_url(string):
- parsed_url = urlparse(string)
- return all([parsed_url.scheme, parsed_url.netloc])
-
-
-def reload_hypernetworks():
- from modules.hypernetworks import hypernetwork
- global hypernetworks # pylint: disable=W0603
- hypernetworks = hypernetwork.list_hypernetworks(opts.hypernetwork_dir)
-
-
-def refresh_checkpoints():
- import modules.sd_models # pylint: disable=W0621
- return modules.sd_models.list_models()
-
-
-def refresh_vaes():
- import modules.sd_vae # pylint: disable=W0621
- modules.sd_vae.refresh_vae_list()
-
-
-def refresh_upscalers():
- import modules.modelloader # pylint: disable=W0621
- modules.modelloader.load_upscalers()
-
-
-def list_samplers():
- import modules.sd_samplers # pylint: disable=W0621
- modules.sd_samplers.set_samplers()
- return modules.sd_samplers.all_samplers
-
-
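-# Build the list of extensions to temporarily disable for this session, based on --safe mode
-# and the active backend, skipping anything the user has already disabled explicitly.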
-def temp_disable_extensions():
- disable_safe = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-agent-scheduler', 'clip-interrogator-ext', 'stable-diffusion-webui-rembg', 'sd-extension-chainner', 'stable-diffusion-webui-images-browser']
- disable_diffusers = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff']
- disable_original = []
- disabled = []
- if cmd_opts.safe:
- for ext in disable_safe:
- if ext not in opts.disabled_extensions:
- disabled.append(ext)
- if backend == Backend.DIFFUSERS:
- for ext in disable_diffusers:
- if ext not in opts.disabled_extensions:
- disabled.append(ext)
- if backend == Backend.ORIGINAL:
- for ext in disable_original:
- if ext not in opts.disabled_extensions:
- disabled.append(ext)
- cmd_opts.controlnet_loglevel = 'WARNING'
- return disabled
-
-
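-# Pick the default cross-attention optimization method based on the detected compute backend.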
-if devices.backend == "cpu":
- cross_attention_optimization_default = "Doggettx's"
-elif devices.backend == "mps":
- cross_attention_optimization_default = "Doggettx's"
-elif devices.backend == "ipex":
- cross_attention_optimization_default = "Scaled-Dot-Product"
-elif devices.backend == "directml":
- cross_attention_optimization_default = "Sub-quadratic"
-elif devices.backend == "rocm":
- cross_attention_optimization_default = "Sub-quadratic"
-else: # cuda
-    cross_attention_optimization_default = "Scaled-Dot-Product"
-
-
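-# Settings definitions, grouped into sections; options_section tags every entry with its (id, title) tuple.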
-options_templates.update(options_section(('sd', "Execution & Models"), {
- "sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["original", "diffusers"] }),
- "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
- "sd_checkpoint_autoload": OptionInfo(True, "Model autoload on server start"),
- "sd_model_dict": OptionInfo('None', "Use baseline data from a different model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "stream_load": OptionInfo(False, "Load models using stream loading method", gr.Checkbox, {"visible": backend == Backend.ORIGINAL }),
- "model_reuse_dict": OptionInfo(False, "When loading models attempt to reuse previous model dictionary", gr.Checkbox, {"visible": False}),
- "prompt_attention": OptionInfo("Full parser", "Prompt attention parser", gr.Radio, {"choices": ["Full parser", "Compel parser", "A1111 parser", "Fixed attention"] }),
- "prompt_mean_norm": OptionInfo(True, "Prompt attention mean normalization", gr.Checkbox, {"visible": backend == Backend.ORIGINAL }),
- "comma_padding_backtrack": OptionInfo(20, "Prompt padding for long prompts", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1, "visible": backend == Backend.ORIGINAL }),
- "sd_checkpoint_cache": OptionInfo(0, "Number of cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": backend == Backend.ORIGINAL }),
- "sd_vae_checkpoint_cache": OptionInfo(0, "Number of cached VAEs", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": False}),
- "sd_disable_ckpt": OptionInfo(False, "Disallow usage of models in ckpt format"),
-}))
-
-options_templates.update(options_section(('cuda', "Compute Settings"), {
- "math_sep": OptionInfo("Execution precision ", "", gr.HTML),
- "precision": OptionInfo("Autocast", "Precision type", gr.Radio, {"choices": ["Autocast", "Full"]}),
- "cuda_dtype": OptionInfo("FP32" if sys.platform == "darwin" or cmd_opts.use_openvino else "BF16" if devices.backend == "ipex" else "FP16", "Device precision type", gr.Radio, {"choices": ["FP32", "FP16", "BF16"]}),
- "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for model (--no-half)", None, None, None),
- "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for VAE (--no-half-vae)"),
- "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Enable upcast sampling"),
- "upcast_attn": OptionInfo(False, "Enable upcast cross attention layer"),
- "cuda_cast_unet": OptionInfo(False, "Use fixed UNet precision"),
- "disable_nan_check": OptionInfo(True, "Disable NaN check in produced images/latent spaces", gr.Checkbox, {"visible": False}),
- "rollback_vae": OptionInfo(False, "Attempt VAE roll back when produced NaN values"),
-
- "cross_attention_sep": OptionInfo("Cross-attention ", "", gr.HTML),
- "cross_attention_optimization": OptionInfo(cross_attention_optimization_default, "Cross-attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention() }),
- "cross_attention_options": OptionInfo([], "Cross-attention advanced options", gr.CheckboxGroup, {"choices": ['xFormers enable flash Attention', 'SDP disable memory attention']}),
- "sub_quad_sep": OptionInfo("Sub-quadratic options ", "", gr.HTML),
- "sub_quad_q_chunk_size": OptionInfo(512, "cross-attention query chunk size", gr.Slider, {"minimum": 16, "maximum": 8192, "step": 8}),
- "sub_quad_kv_chunk_size": OptionInfo(512, "cross-attention kv chunk size", gr.Slider, {"minimum": 0, "maximum": 8192, "step": 8}),
- "sub_quad_chunk_threshold": OptionInfo(80, "cross-attention chunking threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
-
-    "other_sep": OptionInfo("Other options ", "", gr.HTML),
- "opt_channelslast": OptionInfo(False, "Use channels last as torch memory format "),
- "cudnn_benchmark": OptionInfo(False, "Enable full-depth cuDNN benchmark feature"),
- "diffusers_fuse_projections": OptionInfo(False, "Enable fused projections"),
- "torch_gc_threshold": OptionInfo(80, "Memory usage threshold before running Torch GC", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
-
- "cuda_compile_sep": OptionInfo("Model Compile ", "", gr.HTML),
- "cuda_compile": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile UNet"),
- "cuda_compile_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile VAE"),
- "cuda_compile_text_encoder": OptionInfo(False, "Compile Text Encoder"),
- "cuda_compile_upscaler": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile Upscaler"),
- "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'ipex', 'openvino_fx', 'stable-fast']}),
- "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}),
- "cuda_compile_fullgraph": OptionInfo(False, "Model compile fullgraph"),
- "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"),
- "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"),
- "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"),
- "diffusers_quantization": OptionInfo(False, "Enable dynamic quantization with torchao"),
-
- "nncf_sep": OptionInfo("NNCF ", "", gr.HTML),
- "nncf_compress_weights": OptionInfo(False, "Compress Model weights with NNCF"),
- "nncf_compress_vae_weights": OptionInfo(False, "Compress VAE weights with NNCF"),
- "nncf_compress_text_encoder_weights": OptionInfo(False, "Compress Text Encoder weights with NNCF"),
-
- "directml_sep": OptionInfo("DirectML ", "", gr.HTML),
- "directml_memory_provider": OptionInfo(default_memory_provider, 'DirectML memory stats provider', gr.Radio, {"choices": memory_providers}),
-    "directml_catch_nan": OptionInfo(False, "DirectML: retry operation when NaN is produced, if possible (makes generation slower)"),
-
- "ipex_sep": OptionInfo("IPEX ", "", gr.HTML),
- "ipex_optimize": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with UNet"),
- "ipex_optimize_vae": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with VAE"),
- "ipex_optimize_text_encoder": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with Text Encoder"),
- "ipex_optimize_upscaler": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with Upscalers"),
-
- "openvino_sep": OptionInfo("OpenVINO ", "", gr.HTML),
- "openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching"),
- "openvino_hetero_gpu": OptionInfo(False, "OpenVINO use Hetero Device for single inference with multiple devices"),
- "openvino_remove_cpu_from_hetero": OptionInfo(False, "OpenVINO remove CPU from Hetero Device"),
- "openvino_remove_igpu_from_hetero": OptionInfo(False, "OpenVINO remove iGPU from Hetero Device"),
- "nncf_compress_weights_mode": OptionInfo("INT8", "OpenVINO compress mode for NNCF (CPU Only)", gr.Radio, {"choices": ['INT8', 'INT4_SYM', 'INT4_ASYM', 'NF4']}),
- "nncf_compress_weights_raito": OptionInfo(1.0, "OpenVINO compress ratio for NNCF with 4-bit modes", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
-}))
-
-options_templates.update(options_section(('advanced', "Inference Settings"), {
- "token_merging_sep": OptionInfo("Token merging ", "", gr.HTML),
- "token_merging_ratio": OptionInfo(0.0, "Token merging ratio (txt2img)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
- "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio (img2img)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
- "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio (hires)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
-
- "freeu_sep": OptionInfo("FreeU ", "", gr.HTML),
- "freeu_enabled": OptionInfo(False, "FreeU enabled"),
- "freeu_b1": OptionInfo(1.2, "1st stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
- "freeu_b2": OptionInfo(1.4, "2nd stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
- "freeu_s1": OptionInfo(0.9, "1st stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "freeu_s2": OptionInfo(0.2, "2nd stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-
- "hypertile_sep": OptionInfo("HyperTile ", "", gr.HTML),
- "hypertile_unet_enabled": OptionInfo(False, "HyperTile for UNet enabled"),
- "hypertile_unet_tile": OptionInfo(256, "HyperTile for UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
- "hypertile_vae_enabled": OptionInfo(False, "HyperTile for VAE enabled", gr.Checkbox),
- "hypertile_vae_tile": OptionInfo(128, "HyperTile for VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
-
- "inference_other_sep": OptionInfo("Other ", "", gr.HTML),
- "batch_frame_mode": OptionInfo(False, "Process multiple images in batch in parallel"),
- "inference_mode": OptionInfo("no-grad", "Torch inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}),
- "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode"),
-}))
-
-options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
- "diffusers_pipeline": OptionInfo('Autodetect', 'Diffusers pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()) }),
- "diffusers_move_base": OptionInfo(True, "Move base model to CPU when using refiner"),
- "diffusers_move_unet": OptionInfo(True, "Move base model to CPU when using VAE"),
- "diffusers_move_refiner": OptionInfo(True, "Move refiner model to CPU when not in use"),
- "diffusers_extract_ema": OptionInfo(True, "Use model EMA weights when possible"),
- "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}),
- "diffusers_model_cpu_offload": OptionInfo(False, "Enable model CPU offload (--medvram)"),
- "diffusers_seq_cpu_offload": OptionInfo(False, "Enable sequential CPU offload (--lowvram)"),
- "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}),
- "diffusers_vae_slicing": OptionInfo(True, "Enable VAE slicing"),
- "diffusers_vae_tiling": OptionInfo(True if not cmd_opts.use_openvino else False, "Enable VAE tiling"),
- "diffusers_attention_slicing": OptionInfo(False, "Enable attention slicing"),
- "diffusers_model_load_variant": OptionInfo("default", "Diffusers model loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
- "diffusers_vae_load_variant": OptionInfo("default", "Diffusers VAE loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
- "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
- "diffusers_eval": OptionInfo(True, "Force model eval"),
- "diffusers_force_zeros": OptionInfo(True, "Force zeros for prompts when empty"),
- "diffusers_aesthetics_score": OptionInfo(False, "Require aesthetics score"),
- "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds (experimental)", gr.Radio, {"choices": ['default', 'weighted']}),
-}))
-
-options_templates.update(options_section(('system-paths', "System Paths"), {
- "models_paths_sep_options": OptionInfo("Models paths ", "", gr.HTML),
- "models_dir": OptionInfo('models', "Base path where all models are stored", folder=True),
- "ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True),
-    "diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Huggingface models", folder=True),
-    "hfcache_dir": OptionInfo(os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub'), "Folder for Huggingface cache", folder=True),
- "vae_dir": OptionInfo(os.path.join(paths.models_path, 'VAE'), "Folder with VAE files", folder=True),
- "sd_lora": OptionInfo("", "Add LoRA to prompt", gr.Textbox, {"visible": False}),
- "lora_dir": OptionInfo(os.path.join(paths.models_path, 'Lora'), "Folder with LoRA network(s)", folder=True),
- "lyco_dir": OptionInfo(os.path.join(paths.models_path, 'LyCORIS'), "Folder with LyCORIS network(s)", gr.Text, {"visible": False}),
- "styles_dir": OptionInfo(os.path.join(paths.data_path, 'styles.csv'), "File or Folder with user-defined styles", folder=True),
- "embeddings_dir": OptionInfo(os.path.join(paths.models_path, 'embeddings'), "Folder with textual inversion embeddings", folder=True),
- "hypernetwork_dir": OptionInfo(os.path.join(paths.models_path, 'hypernetworks'), "Folder with Hypernetwork models", folder=True),
- "control_dir": OptionInfo(os.path.join(paths.models_path, 'control'), "Folder with Control models", folder=True),
- "codeformer_models_path": OptionInfo(os.path.join(paths.models_path, 'Codeformer'), "Folder with codeformer models", folder=True),
- "gfpgan_models_path": OptionInfo(os.path.join(paths.models_path, 'GFPGAN'), "Folder with GFPGAN models", folder=True),
- "esrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'ESRGAN'), "Folder with ESRGAN models", folder=True),
- "bsrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'BSRGAN'), "Folder with BSRGAN models", folder=True),
- "realesrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'RealESRGAN'), "Folder with RealESRGAN models", folder=True),
- "scunet_models_path": OptionInfo(os.path.join(paths.models_path, 'SCUNet'), "Folder with SCUNet models", folder=True),
- "swinir_models_path": OptionInfo(os.path.join(paths.models_path, 'SwinIR'), "Folder with SwinIR models", folder=True),
- "ldsr_models_path": OptionInfo(os.path.join(paths.models_path, 'LDSR'), "Folder with LDSR models", folder=True),
- "clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True),
-
- "other_paths_sep_options": OptionInfo("Other paths ", "", gr.HTML),
- "openvino_cache_path": OptionInfo('cache', "Directory for OpenVINO cache", folder=True),
- "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
- "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
-}))
-
-options_templates.update(options_section(('saving-images', "Image Options"), {
- "keep_incomplete": OptionInfo(True, "Keep incomplete images"),
- "samples_save": OptionInfo(True, "Always save all generated images"),
- "samples_format": OptionInfo('jpg', 'File format for generated images', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}),
- "jpeg_quality": OptionInfo(90, "Quality for saved images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "img_max_size_mp": OptionInfo(250, "Maximum image size (MP)", gr.Slider, {"minimum": 100, "maximum": 2000, "step": 1}),
- "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
- "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
- "samples_save_zip": OptionInfo(True, "Create zip archive when downloading multiple images"),
-
- "image_sep_metadata": OptionInfo("Metadata/Logging ", "", gr.HTML),
- "image_metadata": OptionInfo(True, "Include metadata in saved images"),
-    "save_txt": OptionInfo(False, "Create info file for every image"),
- "save_log_fn": OptionInfo("", "Create JSON log file for each saved image", component_args=hide_dirs),
- "image_watermark_enabled": OptionInfo(False, "Include watermark in saved images"),
- "image_watermark": OptionInfo('', "Image watermark string"),
- "image_sep_grid": OptionInfo("Grid Options ", "", gr.HTML),
- "grid_save": OptionInfo(True, "Always save all generated image grids"),
- "grid_format": OptionInfo('jpg', 'File format for grids', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}),
- "n_rows": OptionInfo(-1, "Grid row count", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
- "grid_background": OptionInfo("#000000", "Grid background color", ui_components.FormColorPicker, {}),
- "font": OptionInfo("", "Font file"),
- "font_color": OptionInfo("#FFFFFF", "Font color", ui_components.FormColorPicker, {}),
-
- "save_sep_options": OptionInfo("Intermediate Image Saving ", "", gr.HTML),
- "save_init_img": OptionInfo(False, "Save copy of img2img init images"),
- "save_images_before_highres_fix": OptionInfo(False, "Save copy of image before applying hires"),
- "save_images_before_refiner": OptionInfo(False, "Save copy of image before running refiner"),
- "save_images_before_face_restoration": OptionInfo(False, "Save copy of image before doing face restoration"),
- "save_images_before_color_correction": OptionInfo(False, "Save copy of image before applying color correction"),
- "save_mask": OptionInfo(False, "Save copy of the inpainting greyscale mask"),
- "save_mask_composite": OptionInfo(False, "Save copy of inpainting masked composite"),
-}))
-
-options_templates.update(options_section(('saving-paths', "Image Naming & Paths"), {
- "saving_sep_images": OptionInfo("Save options ", "", gr.HTML),
- "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
- "use_original_name_batch": OptionInfo(True, "Use original name during batch process"),
- "use_upscaler_name_as_suffix": OptionInfo(True, "Use upscaler as suffix", gr.Checkbox, {"visible": False}),
- "samples_filename_pattern": OptionInfo("[seq]-[model_name]-[prompt_words]", "Images filename pattern", component_args=hide_dirs),
- "directories_max_prompt_words": OptionInfo(8, "Max words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}),
- "use_save_to_dirs_for_ui": OptionInfo(False, "Save images to a subdirectory when using Save button", gr.Checkbox, {"visible": False}),
- "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
-
- "outdir_sep_dirs": OptionInfo("Directories ", "", gr.HTML),
- "outdir_samples": OptionInfo("", "Output directory for images", component_args=hide_dirs, folder=True),
-    "outdir_txt2img_samples": OptionInfo("outputs/text", 'Directory for txt2img outputs', component_args=hide_dirs, folder=True),
-    "outdir_img2img_samples": OptionInfo("outputs/image", 'Directory for img2img outputs', component_args=hide_dirs, folder=True),
-    "outdir_control_samples": OptionInfo("outputs/control", 'Directory for control outputs', component_args=hide_dirs, folder=True),
- "outdir_extras_samples": OptionInfo("outputs/extras", 'Directory for processed images', component_args=hide_dirs, folder=True),
- "outdir_save": OptionInfo("outputs/save", "Directory for manually saved images", component_args=hide_dirs, folder=True),
- "outdir_video": OptionInfo("outputs/video", "Directory for videos", component_args=hide_dirs, folder=True),
- "outdir_init_images": OptionInfo("outputs/init-images", "Directory for init images", component_args=hide_dirs, folder=True),
-
- "outdir_sep_grids": OptionInfo("Grids ", "", gr.HTML),
- "grid_extended_filename": OptionInfo(True, "Add extended info (seed, prompt) to filename when saving grid", gr.Checkbox, {"visible": False}),
- "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory", gr.Checkbox, {"visible": False}),
- "outdir_grids": OptionInfo("", "Output directory for grids", component_args=hide_dirs, folder=True),
- "outdir_txt2img_grids": OptionInfo("outputs/grids", 'Output directory for txt2img grids', component_args=hide_dirs, folder=True),
- "outdir_img2img_grids": OptionInfo("outputs/grids", 'Output directory for img2img grids', component_args=hide_dirs, folder=True),
- "outdir_control_grids": OptionInfo("outputs/grids", 'Output directory for control grids', component_args=hide_dirs, folder=True),
-}))
-
-options_templates.update(options_section(('ui', "User Interface"), {
- "motd": OptionInfo(True, "Show MOTD"),
- "gradio_theme": OptionInfo("black-teal", "UI theme", gr.Dropdown, lambda: {"choices": theme.list_themes()}, refresh=theme.refresh_themes),
- "theme_style": OptionInfo("Auto", "Theme mode", gr.Radio, {"choices": ["Auto", "Dark", "Light"]}),
- "font_size": OptionInfo(16, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}),
-    "tooltips": OptionInfo("UI tooltips", "UI tooltips", gr.Radio, {"choices": ["None", "Browser default", "UI tooltips"], "visible": False}),
- "gallery_height": OptionInfo("", "Gallery height", gr.Textbox),
- "compact_view": OptionInfo(False, "Compact view"),
- "return_grid": OptionInfo(True, "Show grid in results"),
- "return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"),
- "return_mask_composite": OptionInfo(False, "Inpainting include masked composite in results"),
- "disable_weights_auto_swap": OptionInfo(True, "Do not change selected model when reading generation parameters"),
- "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
- "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001, "visible": False}),
- "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001, "visible": False}),
- "keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters", gr.Textbox, { "visible": False }), # pylint: disable=anomalous-backslash-in-string
- "quicksettings_list": OptionInfo(["sd_model_checkpoint"] if backend == Backend.ORIGINAL else ["sd_model_checkpoint", "sd_model_refiner"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
- "ui_scripts_reorder": OptionInfo("", "UI scripts order", gr.Textbox, { "visible": False }),
-}))
-
-options_templates.update(options_section(('live-preview', "Live Previews"), {
- "show_progressbar": OptionInfo(True, "Show progressbar", gr.Checkbox, {"visible": False}),
- "live_previews_enable": OptionInfo(True, "Show live previews of the created image", gr.Checkbox, {"visible": False}),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid", gr.Checkbox, {"visible": False}),
- "notification_audio_enable": OptionInfo(False, "Play a sound when images are finished generating"),
- "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True),
- "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}),
- "live_preview_content": OptionInfo("Combined", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"], "visible": False}),
- "live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}),
- "logmonitor_show": OptionInfo(True, "Show log view"),
- "logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}),
-}))
-
-options_templates.update(options_section(('sampler-params', "Sampler Settings"), {
- "show_samplers": OptionInfo([], "Show samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
- 'eta_noise_seed_delta': OptionInfo(0, "Noise seed delta (eta)", gr.Number, {"precision": 0}),
- "scheduler_eta": OptionInfo(1.0, "Noise multiplier (eta)", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "schedulers_solver_order": OptionInfo(2, "Solver order (where applicable)", gr.Slider, {"minimum": 1, "maximum": 5, "step": 1}),
-
- # managed from ui.py for backend original
- "schedulers_brownian_noise": OptionInfo(True, "Use Brownian noise", gr.Checkbox, {"visible": False}),
- "schedulers_discard_penultimate": OptionInfo(True, "Discard penultimate sigma", gr.Checkbox, {"visible": False}),
- "schedulers_sigma": OptionInfo("default", "Sigma algorithm", gr.Radio, {"choices": ['default', 'karras', 'exponential', 'polyexponential'], "visible": False}),
- "schedulers_use_karras": OptionInfo(True, "Use Karras sigmas", gr.Checkbox, {"visible": False}),
- "schedulers_use_thresholding": OptionInfo(False, "Use dynamic thresholding", gr.Checkbox, {"visible": False}),
- "schedulers_use_loworder": OptionInfo(True, "Use simplified solvers in final steps", gr.Checkbox, {"visible": False}),
- "schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction']}),
-
- # managed from ui.py for backend diffusers
- "schedulers_sep_diffusers": OptionInfo("Diffusers specific config ", "", gr.HTML),
- "schedulers_dpm_solver": OptionInfo("sde-dpmsolver++", "DPM solver algorithm", gr.Radio, {"choices": ['dpmsolver', 'dpmsolver++', 'sde-dpmsolver', 'sde-dpmsolver++']}),
- "schedulers_beta_schedule": OptionInfo("default", "Beta schedule", gr.Radio, {"choices": ['default', 'linear', 'scaled_linear', 'squaredcos_cap_v2']}),
- 'schedulers_beta_start': OptionInfo(0, "Beta start", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}),
- 'schedulers_beta_end': OptionInfo(0, "Beta end", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}),
- 'schedulers_timesteps_range': OptionInfo(1000, "Timesteps range", gr.Slider, {"minimum": 250, "maximum": 4000, "step": 1}),
- "schedulers_rescale_betas": OptionInfo(False, "Rescale betas with zero terminal SNR", gr.Checkbox),
-
- # managed from ui.py for backend original k-diffusion
- "schedulers_sep_kdiffusers": OptionInfo("K-Diffusion specific config ", "", gr.HTML),
-    "always_batch_cond_uncond": OptionInfo(False, "Batch conditional and unconditional prompts even on low-memory systems"),
- "enable_quantization": OptionInfo(True, "Enable quantization for sharper and cleaner results"),
- 's_churn': OptionInfo(0.0, "Sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min_uncond': OptionInfo(0.0, "Sigma negative guidance minimum ", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
- 's_tmin': OptionInfo(0.0, "Sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_noise': OptionInfo(1.0, "Sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min': OptionInfo(0.0, "Sigma min", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_max': OptionInfo(0.0, "Sigma max", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 1.0}),
- "schedulers_sep_compvis": OptionInfo("CompVis specific config ", "", gr.HTML),
- 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
- 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
- "ddim_discretize": OptionInfo('uniform', "DDIM discretize img2img", gr.Radio, {"choices": ['uniform', 'quad']}),
- # TODO pad_cond_uncond implementation missing for original backend
- "pad_cond_uncond": OptionInfo(True, "Pad prompt and negative prompt to be same length", gr.Checkbox, {"visible": False}),
- # TODO batch_cond-uncond implementation missing for original backend
- "batch_cond_uncond": OptionInfo(True, "Do conditional and unconditional denoising in one batch", gr.Checkbox, {"visible": False}),
-}))
-
-options_templates.update(options_section(('postprocessing', "Postprocessing"), {
- 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable additional postprocessing operations", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
-
- "postprocessing_sep_img2img": OptionInfo("Img2Img & Inpainting ", "", gr.HTML),
- "img2img_color_correction": OptionInfo(False, "Apply color correction"),
- "img2img_fix_steps": OptionInfo(False, "For image processing do exact number of steps as specified", gr.Checkbox, { "visible": False }),
- "img2img_background_color": OptionInfo("#ffffff", "Image transparent color fill", ui_components.FormColorPicker, {}),
- "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for image processing", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.01}),
- "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 8, "step": 1, "visible": False}),
-
- "postprocessing_sep_face_restoration": OptionInfo("Face Restoration ", "", gr.HTML),
- "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
- "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
-
- "postprocessing_sep_upscalers": OptionInfo("Upscaling ", "", gr.HTML),
- "upscaler_unload": OptionInfo(False, "Unload upscaler after processing"),
- "upscaler_for_img2img": OptionInfo("None", "Default upscaler for image resize operations", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers], "visible": False}, refresh=refresh_upscalers),
- "upscaler_tile_size": OptionInfo(192, "Upscaler tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "upscaler_tile_overlap": OptionInfo(8, "Upscaler tile overlap", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
-}))
-
-options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible"),
- "pin_memory": OptionInfo(True, "Pin training dataset to memory"),
- "save_optimizer_state": OptionInfo(False, "Save resumable optimizer state when training"),
- "save_training_settings_to_txt": OptionInfo(True, "Save training settings to a text file on training start"),
- "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
- "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
- "embeddings_templates_dir": OptionInfo(os.path.join(paths.script_path, 'train', 'templates'), "Embeddings train templates directory", folder=True),
- "training_image_repeats_per_epoch": OptionInfo(1, "Image repeats per epoch", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "training_write_csv_every": OptionInfo(0, "Save loss CSV file every n steps"),
- "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging"),
- "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard"),
- "training_tensorboard_flush_every": OptionInfo(120, "Tensorboard flush period"),
-}))
-
-options_templates.update(options_section(('interrogate', "Interrogate"), {
- "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_return_ranks": OptionInfo(True, "Interrogate: include ranks of model tags matches in results"),
- "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(32, "Interrogate: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(192, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(2048, "CLIP: maximum number of lines in text file", gr.Slider, { "visible": False }),
- "interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
- "interrogate_deepbooru_score_threshold": OptionInfo(0.65, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "deepbooru_sort_alpha": OptionInfo(False, "Interrogate: deepbooru sort alphabetically"),
- "deepbooru_use_spaces": OptionInfo(False, "Use spaces for tags in deepbooru"),
- "deepbooru_escape": OptionInfo(True, "Escape brackets in deepbooru"),
- "deepbooru_filter_tags": OptionInfo("", "Filter out tags from deepbooru output"),
-}))
-
-options_templates.update(options_section(('extra_networks', "Extra Networks"), {
- "extra_networks_sep1": OptionInfo("Extra networks UI ", "", gr.HTML),
- "extra_networks": OptionInfo(["All"], "Extra networks", ui_components.DropdownMulti, lambda: {"choices": ['All'] + [en.title for en in extra_networks]}),
- "extra_networks_view": OptionInfo("gallery", "UI view", gr.Radio, {"choices": ["gallery", "list"]}),
- "extra_networks_card_cover": OptionInfo("sidebar", "UI position", gr.Radio, {"choices": ["cover", "inline", "sidebar"]}),
- "extra_networks_height": OptionInfo(53, "UI height (%)", gr.Slider, {"minimum": 10, "maximum": 100, "step": 1}),
- "extra_networks_sidebar_width": OptionInfo(35, "UI sidebar width (%)", gr.Slider, {"minimum": 10, "maximum": 80, "step": 1}),
- "extra_networks_card_size": OptionInfo(160, "UI card size (px)", gr.Slider, {"minimum": 20, "maximum": 2000, "step": 1}),
- "extra_networks_card_square": OptionInfo(True, "UI disable variable aspect ratio"),
- "extra_networks_card_fit": OptionInfo("cover", "UI image contain method", gr.Radio, {"choices": ["contain", "cover", "fill"], "visible": False}),
- "extra_networks_sep2": OptionInfo("Extra networks general ", "", gr.HTML),
- "extra_network_skip_indexing": OptionInfo(False, "Build info on first access", gr.Checkbox),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "extra_networks_sep3": OptionInfo("Extra networks settings ", "", gr.HTML),
- "extra_networks_styles": OptionInfo(True, "Show built-in styles"),
-    "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"]}),
- "lora_add_hashes_to_infotext": OptionInfo(True, "LoRA add hash info"),
- "lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use alternative loading method"),
- "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use merge when using alternative method"),
- "lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}),
- "lora_functional": OptionInfo(False, "Use Kohya method for handling multiple LoRA", gr.Checkbox, { "visible": False }),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, { "choices": ["None"], "visible": False }),
-}))
-
-options_templates.update(options_section((None, "Hidden options"), {
- "disabled_extensions": OptionInfo([], "Disable these extensions"),
- "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "user", "all"]}),
- "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
-}))
-
-options_templates.update()
-
-
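-# Runtime container for all settings: values live in self.data, definitions in data_labels (options_templates).
-# Attribute access is routed through __getattr__/__setattr__ so reads fall back to defaults and writes
-# honor --freeze and the restricted_opts list.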
-class Options:
- data = None
- data_labels = options_templates
- filename = None
- typemap = {int: float}
-
- def __init__(self):
- self.data = {k: v.default for k, v in self.data_labels.items()}
-
- def __setattr__(self, key, value): # pylint: disable=inconsistent-return-statements
- if self.data is not None:
- if key in self.data or key in self.data_labels:
- if cmd_opts.freeze:
- log.warning(f'Settings are frozen: {key}')
- return
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- log.warning(f'Settings key is restricted: {key}')
- return
- self.data[key] = value
- return
- return super(Options, self).__setattr__(key, value) # pylint: disable=super-with-arguments
-
- def __getattr__(self, item):
- if self.data is not None:
- if item in self.data:
- return self.data[item]
- if item in self.data_labels:
- return self.data_labels[item].default
- return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments
-
- def set(self, key, value):
- """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
- oldval = self.data.get(key, None)
- if oldval is None:
- oldval = self.data_labels[key].default
- if oldval == value:
- return False
- try:
- setattr(self, key, value)
- except RuntimeError:
- return False
- if self.data_labels[key].onchange is not None:
- try:
- self.data_labels[key].onchange()
- except Exception as e:
- log.error(f'Error in onchange callback: {key} {value} {e}')
- setattr(self, key, oldval)
- return False
- return True
-
- def get_default(self, key):
- """returns the default value for the key"""
- data_label = self.data_labels.get(key)
- return data_label.default if data_label is not None else None
-
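-    # Persist only values that differ from their defaults (list values are always written);
-    # unknown keys are kept for compatibility but reported as unused.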
- def save(self, filename=None, silent=False):
- if filename is None:
- filename = self.filename
- if cmd_opts.freeze:
- log.warning(f'Settings saving is disabled: {filename}')
- return
- try:
- # output = json.dumps(self.data, indent=2)
- diff = {}
- unused_settings = []
- for k, v in self.data.items():
- if k in self.data_labels:
- if type(v) is list:
- diff[k] = v
- if self.data_labels[k].default != v:
- diff[k] = v
- else:
- if k not in compatibility_opts:
- unused_settings.append(k)
- diff[k] = v
- writefile(diff, filename, silent=silent)
- if len(unused_settings) > 0:
- log.debug(f"Unused settings: {unused_settings}")
- except Exception as e:
- log.error(f'Saving settings failed: {filename} {e}')
-
- def same_type(self, x, y):
- if x is None or y is None:
- return True
- type_x = self.typemap.get(type(x), type(x))
- type_y = self.typemap.get(type(y), type(y))
- return type_x == type_y
-
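-    # Load settings from disk, creating a default config if missing, migrate the legacy
-    # quicksettings value and report type mismatches and unknown keys.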
- def load(self, filename=None):
- if filename is None:
- filename = self.filename
- if not os.path.isfile(filename):
- log.debug(f'Created default config: {filename}')
- self.save(filename)
- return
- self.data = readfile(filename, lock=True)
- if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
- self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
- unknown_settings = []
- for k, v in self.data.items():
- info = self.data_labels.get(k, None)
- if info is not None and not self.same_type(info.default, v):
- log.error(f"Error: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})")
- if info is None and k not in compatibility_opts:
- unknown_settings.append(k)
- if len(unknown_settings) > 0:
- log.debug(f"Unknown settings: {unknown_settings}")
-
- def onchange(self, key, func, call=True):
- item = self.data_labels.get(key)
- item.onchange = func
- if call:
- func()
-
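-    # Serialize current values plus per-key metadata (is_stored, tab_name) for the UI/API.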
- def dumpjson(self):
- d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
- metadata = {
- k: {
- "is_stored": k in self.data and self.data[k] != self.data_labels[k].default, # pylint: disable=unnecessary-dict-index-lookup
- "tab_name": v.section[0]
- } for k, v in self.data_labels.items()
- }
- return json.dumps({"values": d, "metadata": metadata})
-
- def add_option(self, key, info):
- self.data_labels[key] = info
-
- def reorder(self):
- """reorder settings so that all items related to section always go together"""
- section_ids = {}
- settings_items = self.data_labels.items()
- for _k, item in settings_items:
- if item.section not in section_ids:
- section_ids[item.section] = len(section_ids)
- self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
-
- def cast_value(self, key, value):
-        """casts an arbitrary value to the same type as the value of the setting identified by key
- Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
- """
- if value is None:
- return None
- default_value = self.data_labels[key].default
- if default_value is None:
- default_value = getattr(self, key, None)
- if default_value is None:
- return None
- expected_type = type(default_value)
- if expected_type == bool and value == "False":
- value = False
- elif expected_type == type(value):
- pass
- else:
- value = expected_type(value)
- return value
-
-profiler = None
-opts = Options()
-config_filename = cmd_opts.config
-opts.load(config_filename)
-cmd_opts = cmd_args.compatibility_args(opts, cmd_opts)
-if cmd_opts.use_xformers:
- opts.data['cross_attention_optimization'] = 'xFormers'
-opts.data['uni_pc_lower_order_final'] = opts.schedulers_use_loworder # compatibility
-opts.data['uni_pc_order'] = opts.schedulers_solver_order # compatibility
-log.info(f'Engine: backend={backend} compute={devices.backend} device={devices.get_optimal_device_name()} attention="{opts.cross_attention_optimization}" mode={devices.inference_context.__name__}')
-log.info(f'Device: {print_dict(devices.get_gpu_info())}')
-
-prompt_styles = modules.styles.StyleDatabase(opts)
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or (cmd_opts.server_name or False)) and not cmd_opts.insecure
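-# Map each workload (sd, interrogate, gfpgan, esrgan, codeformer) to the CPU if listed in --use-cpu, otherwise to the optimal device.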
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
-device = devices.device
-batch_cond_uncond = opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
-parallel_processing_allowed = not cmd_opts.lowvram
-mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device)
-max_workers = 4
-if devices.backend == "directml":
- directml_do_hijack()
-
-
-class TotalTQDM: # compatibility with previous global-tqdm
- # import tqdm
- def __init__(self):
- pass
- def reset(self):
- pass
- def update(self):
- pass
- def updateTotal(self, new_total):
- pass
- def clear(self):
- pass
-total_tqdm = TotalTQDM()
-
-
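-# Gracefully shut down the gradio server, optionally flagging it for restart via demo.server.wants_restart.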
-def restart_server(restart=True):
- if demo is None:
- return
- log.warning('Server shutdown requested')
- try:
- sys.tracebacklimit = 0
- stdout = io.StringIO()
- stderr = io.StringIO()
-        with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
- demo.server.wants_restart = restart
- demo.server.should_exit = True
- demo.server.force_exit = True
- demo.close(verbose=False)
- demo.server.close()
- demo.fns = []
- time.sleep(1)
- sys.tracebacklimit = 100
- # os._exit(0)
- except (Exception, BaseException) as e:
- log.error(f'Server shutdown error: {e}')
- if restart:
- log.info('Server will restart')
-
-
-def restore_defaults(restart=True):
- if os.path.exists(cmd_opts.config):
- log.info('Restoring server defaults')
- os.remove(cmd_opts.config)
- restart_server(restart)
-
-
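-# Directory listing cached by path modification time to avoid repeated os.listdir calls.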
-def listdir(path):
- if not os.path.exists(path):
- return []
- mtime = os.path.getmtime(path)
- if path in dir_timestamps and mtime == dir_timestamps[path]:
- return dir_cache[path]
- else:
- dir_cache[path] = [os.path.join(path, f) for f in os.listdir(path)]
- dir_timestamps[path] = mtime
- return dir_cache[path]
-
-
-def walk_files(path, allowed_extensions=None):
- if not os.path.exists(path):
- return
- if allowed_extensions is not None:
- allowed_extensions = set(allowed_extensions)
- for root, _dirs, files in os.walk(path, followlinks=True):
- for filename in files:
- if allowed_extensions is not None:
- _, ext = os.path.splitext(filename)
- if ext not in allowed_extensions:
- continue
- yield os.path.join(root, filename)
-
-
-def html_path(filename):
- return os.path.join(paths.script_path, "html", filename)
-
-
-def html(filename):
- path = html_path(filename)
- if os.path.exists(path):
- with open(path, encoding="utf8") as file:
- return file.read()
- return ""
-
-
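-# Best-effort version info gathered from git (commit hash, date, origin and branch); falls back to the app name only.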
-def get_version():
- version = None
- if version is None:
- try:
- import subprocess
- res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
- ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ' '
- githash, updated = ver.split(' ')
- res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
- origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
- res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
- branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
- version = {
- 'app': 'sd.next',
- 'updated': updated,
- 'hash': githash,
- 'url': origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
- }
- except Exception:
- version = { 'app': 'sd.next' }
- return version
-
-
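-# Simple GET helper; on failure logs the error and returns a SimpleNamespace with status_code 500 instead of raising.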
-def req(url_addr, headers = None, **kwargs):
- if headers is None:
- headers = { 'Content-type': 'application/json' }
- try:
- res = requests.get(url_addr, timeout=30, headers=headers, verify=False, allow_redirects=True, **kwargs)
- except Exception as e:
- log.error(f'HTTP request error: url={url_addr} {e}')
- res = { 'status_code': 500, 'text': f'HTTP request error: url={url_addr} {e}' }
- res = SimpleNamespace(**res)
- return res
-
-
-sd_model: diffusers.DiffusionPipeline = None # dummy and overwritten by class
-sd_refiner: diffusers.DiffusionPipeline = None # dummy and overwritten by class
-sd_model_type: str = '' # dummy and overwritten by class
-sd_refiner_type: str = '' # dummy and overwritten by class
-compiled_model_state = None
-
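-# Swap this module's class, presumably so that module-level attributes such as sd_model and sd_refiner
-# can be served by properties defined on modules.modeldata.Shared.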
-from modules.modeldata import Shared # pylint: disable=ungrouped-imports
-sys.modules[__name__].__class__ = Shared
+import io
+import os
+import sys
+import time
+import json
+import contextlib
+from types import SimpleNamespace
+from urllib.parse import urlparse
+from enum import Enum
+import requests
+import gradio as gr
+import fasteners
+import orjson
+import diffusers
+from rich.console import Console
+from modules import errors, shared_items, shared_state, cmd_args, ui_components, theme
+from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
+from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
+import modules.interrogate
+import modules.memmon
+import modules.styles
+import modules.devices as devices # pylint: disable=R0402
+import modules.paths as paths
+from installer import print_dict
+from installer import log as central_logger # pylint: disable=E0611
+
+
+errors.install([gr])
+demo: gr.Blocks = None
+log = central_logger
+progress_print_out = sys.stdout
+parser = cmd_args.parser
+url = 'https://github.com/vladmandic/automatic'
+cmd_opts, _ = parser.parse_known_args()
+hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
+xformers_available = False
+clip_model = None
+interrogator = modules.interrogate.InterrogateModels("interrogate")
+sd_upscalers = []
+face_restorers = []
+tab_names = []
+extra_networks = []
+options_templates = {}
+hypernetworks = {}
+loaded_hypernetworks = []
+settings_components = None
+latent_upscale_default_mode = "None"
+latent_upscale_modes = {
+ "Latent Nearest": {"mode": "nearest", "antialias": False},
+ "Latent Nearest-exact": {"mode": "nearest-exact", "antialias": False},
+ "Latent Area": {"mode": "area", "antialias": False},
+ "Latent Bilinear": {"mode": "bilinear", "antialias": False},
+ "Latent Bicubic": {"mode": "bicubic", "antialias": False},
+ "Latent Bilinear antialias": {"mode": "bilinear", "antialias": True},
+ "Latent Bicubic antialias": {"mode": "bicubic", "antialias": True},
+ # "Latent Linear": {"mode": "linear", "antialias": False}, # not supported for latents with channels=4
+ # "Latent Trilinear": {"mode": "trilinear", "antialias": False}, # not supported for latents with channels=4
+}
+restricted_opts = {
+ "samples_filename_pattern",
+ "directories_filename_pattern",
+ "outdir_samples",
+ "outdir_txt2img_samples",
+ "outdir_img2img_samples",
+ "outdir_extras_samples",
+ "outdir_grids",
+ "outdir_txt2img_grids",
+ "outdir_save",
+ "outdir_init_images"
+}
+resize_modes = ["None", "Fixed", "Crop", "Fill", "Latent"]
+compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order']
+console = Console(log_time=True, log_time_format='%H:%M:%S-%f')
+dir_timestamps = {}
+dir_cache = {}
+
+
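+# Execution backend selector: ORIGINAL (legacy stable-diffusion codepath) or DIFFUSERS (huggingface diffusers pipelines).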
+class Backend(Enum):
+ ORIGINAL = 1
+ DIFFUSERS = 2
+
+
+state = shared_state.State()
+if not hasattr(cmd_opts, "use_openvino"):
+ cmd_opts.use_openvino = False
+
+
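+# Read a JSON file using orjson, optionally holding an inter-process read lock; returns an empty dict on failure.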
+def readfile(filename, silent=False, lock=False):
+ data = {}
+ lock_file = None
+ locked = False
+ try:
+ # if not os.path.exists(filename):
+ # return {}
+ t0 = time.time()
+ if lock:
+ lock_file = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
+ locked = lock_file.acquire_read_lock(blocking=True, timeout=3)
+ with open(filename, "rb") as file:
+ b = file.read()
+ data = orjson.loads(b) # pylint: disable=no-member
+ # if type(data) is str:
+ # data = json.loads(data)
+ t1 = time.time()
+ if not silent:
+ log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f}')
+ except Exception as e:
+ if not silent:
+ log.error(f'Reading failed: {filename} {e}')
+ finally:
+ if lock_file is not None:
+ lock_file.release_read_lock()
+ if locked and os.path.exists(f"{filename}.lock"):
+ os.remove(f"{filename}.lock")
+ return data
+
+
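+# Serialize data to JSON and write it to disk while holding an inter-process write lock;
+# dict/list values are dumped directly, other objects via their __dict__ (None values are dropped).
+# With atomic=True the payload goes to a temp file first and is moved into place with os.replace.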
+def writefile(data, filename, mode='w', silent=False, atomic=False):
+ lock = None
+ locked = False
+    import tempfile
+
+ def default(obj):
+ log.error(f"Saving: {filename} not a valid object: {obj}")
+ return str(obj)
+
+ try:
+ t0 = time.time()
+ # skipkeys=True, ensure_ascii=True, check_circular=True, allow_nan=True
+        if type(data) in (dict, list):
+            output = json.dumps(data, indent=2, default=default)
+ elif isinstance(data, object):
+ simple = {}
+ for k in data.__dict__:
+ if data.__dict__[k] is not None:
+ simple[k] = data.__dict__[k]
+ output = json.dumps(simple, indent=2, default=default)
+ else:
+ raise ValueError('not a valid object')
+ lock = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
+ locked = lock.acquire_write_lock(blocking=True, timeout=3)
+ if atomic:
+ with tempfile.NamedTemporaryFile(mode=mode, encoding="utf8", delete=False, dir=os.path.dirname(filename)) as f:
+ f.write(output)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(f.name, filename)
+ else:
+ with open(filename, mode=mode, encoding="utf8") as file:
+ file.write(output)
+ t1 = time.time()
+ if not silent:
+ log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)} time={t1-t0:.3f}')
+ except Exception as e:
+ log.error(f'Saving failed: {filename} {e}')
+ errors.display(e, 'Saving failed')
+ finally:
+ if lock is not None:
+            lock.release_write_lock()
+ if locked and os.path.exists(f"{filename}.lock"):
+ os.remove(f"{filename}.lock")
+
+
+# early select backend
+default_backend = 'original'
+early_opts = readfile(cmd_opts.config, silent=True)
+early_backend = early_opts.get('sd_backend', default_backend)
+backend = Backend.DIFFUSERS if early_backend.lower() == 'diffusers' else Backend.ORIGINAL
+if cmd_opts.backend is not None: # override with args
+ backend = Backend.DIFFUSERS if cmd_opts.backend.lower() == 'diffusers' else Backend.ORIGINAL
+if cmd_opts.use_openvino: # override for openvino
+ backend = Backend.DIFFUSERS
+
+
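+# Describes a single settings entry: default value, UI label, gradio component and its arguments,
+# plus optional onchange/refresh callbacks and HTML comments rendered around the label.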
+class OptionInfo:
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, folder=None, submit=None, comment_before='', comment_after=''):
+ self.default = default
+ self.label = label
+ self.component = component
+ self.component_args = component_args
+ self.onchange = onchange
+ self.section = section
+ self.refresh = refresh
+ self.folder = folder
+ self.comment_before = comment_before # HTML text that will be added after label in UI
+ self.comment_after = comment_after # HTML text that will be added before label in UI
+ self.submit = submit
+
+ def link(self, label, uri):
+ self.comment_before += f"[{label} ]"
+ return self
+
+ def js(self, label, js_func):
+ self.comment_before += f"[{label} ]"
+ return self
+
+ def info(self, info):
+ self.comment_after += f"({info}) "
+ return self
+
+ def html(self, info):
+ self.comment_after += f"{info} "
+ return self
+
+ def needs_restart(self):
+ self.comment_after += " (requires restart) "
+ return self
+
+
+def options_section(section_identifier, options_dict):
+ for v in options_dict.values():
+ v.section = section_identifier
+ return options_dict
+
+
+def list_checkpoint_tiles():
+ import modules.sd_models # pylint: disable=W0621
+ return modules.sd_models.checkpoint_tiles()
+
+default_checkpoint = list_checkpoint_tiles()[0] if len(list_checkpoint_tiles()) > 0 else "model.ckpt"
+
+
+def is_url(string):
+ parsed_url = urlparse(string)
+ return all([parsed_url.scheme, parsed_url.netloc])
+
+
+def reload_hypernetworks():
+ from modules.hypernetworks import hypernetwork
+ global hypernetworks # pylint: disable=W0603
+ hypernetworks = hypernetwork.list_hypernetworks(opts.hypernetwork_dir)
+
+
+def refresh_checkpoints():
+ import modules.sd_models # pylint: disable=W0621
+ return modules.sd_models.list_models()
+
+
+def refresh_vaes():
+ import modules.sd_vae # pylint: disable=W0621
+ modules.sd_vae.refresh_vae_list()
+
+
+def refresh_upscalers():
+ import modules.modelloader # pylint: disable=W0621
+ modules.modelloader.load_upscalers()
+
+
+def list_samplers():
+ import modules.sd_samplers # pylint: disable=W0621
+ modules.sd_samplers.set_samplers()
+ return modules.sd_samplers.all_samplers
+
+
+def temp_disable_extensions():
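+    # returns the list of extensions to disable for this session based on cmd_opts.safe and the active backend,
+    # without modifying the persisted opts.disabled_extensions list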
+ disable_safe = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-agent-scheduler', 'clip-interrogator-ext', 'stable-diffusion-webui-rembg', 'sd-extension-chainner', 'stable-diffusion-webui-images-browser']
+ disable_diffusers = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff']
+ disable_original = []
+ disabled = []
+ if cmd_opts.safe:
+ for ext in disable_safe:
+ if ext not in opts.disabled_extensions:
+ disabled.append(ext)
+ if backend == Backend.DIFFUSERS:
+ for ext in disable_diffusers:
+ if ext not in opts.disabled_extensions:
+ disabled.append(ext)
+ if backend == Backend.ORIGINAL:
+ for ext in disable_original:
+ if ext not in opts.disabled_extensions:
+ disabled.append(ext)
+ cmd_opts.controlnet_loglevel = 'WARNING'
+ return disabled
+
+
+if devices.backend in ("cpu", "mps"):
+    cross_attention_optimization_default = "Doggettx's"
+elif devices.backend == "ipex":
+    cross_attention_optimization_default = "Scaled-Dot-Product"
+elif devices.backend in ("directml", "rocm"):
+    cross_attention_optimization_default = "Sub-quadratic"
+else: # cuda
+    cross_attention_optimization_default = "Scaled-Dot-Product"
+
+
+options_templates.update(options_section(('sd', "Execution & Models"), {
+ "sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["original", "diffusers"] }),
+ "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+ "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+ "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
+ "sd_checkpoint_autoload": OptionInfo(True, "Model autoload on server start"),
+ "sd_model_dict": OptionInfo('None', "Use baseline data from a different model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+ "stream_load": OptionInfo(False, "Load models using stream loading method", gr.Checkbox, {"visible": backend == Backend.ORIGINAL }),
+ "model_reuse_dict": OptionInfo(False, "When loading models attempt to reuse previous model dictionary", gr.Checkbox, {"visible": False}),
+ "prompt_attention": OptionInfo("Full parser", "Prompt attention parser", gr.Radio, {"choices": ["Full parser", "Compel parser", "A1111 parser", "Fixed attention"] }),
+ "prompt_mean_norm": OptionInfo(True, "Prompt attention mean normalization", gr.Checkbox, {"visible": backend == Backend.ORIGINAL }),
+ "comma_padding_backtrack": OptionInfo(20, "Prompt padding for long prompts", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1, "visible": backend == Backend.ORIGINAL }),
+ "sd_checkpoint_cache": OptionInfo(0, "Number of cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": backend == Backend.ORIGINAL }),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "Number of cached VAEs", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": False}),
+ "sd_disable_ckpt": OptionInfo(False, "Disallow usage of models in ckpt format"),
+}))
+
+options_templates.update(options_section(('cuda', "Compute Settings"), {
+ "math_sep": OptionInfo("Execution precision ", "", gr.HTML),
+ "precision": OptionInfo("Autocast", "Precision type", gr.Radio, {"choices": ["Autocast", "Full"]}),
+ "cuda_dtype": OptionInfo("FP32" if sys.platform == "darwin" or cmd_opts.use_openvino else "BF16" if devices.backend == "ipex" else "FP16", "Device precision type", gr.Radio, {"choices": ["FP32", "FP16", "BF16"]}),
+ "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for model (--no-half)", None, None, None),
+ "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for VAE (--no-half-vae)"),
+ "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Enable upcast sampling"),
+ "upcast_attn": OptionInfo(False, "Enable upcast cross attention layer"),
+ "cuda_cast_unet": OptionInfo(False, "Use fixed UNet precision"),
+ "disable_nan_check": OptionInfo(True, "Disable NaN check in produced images/latent spaces", gr.Checkbox, {"visible": False}),
+    "rollback_vae": OptionInfo(False, "Attempt VAE rollback when NaN values are produced"),
+
+ "cross_attention_sep": OptionInfo("Cross-attention ", "", gr.HTML),
+ "cross_attention_optimization": OptionInfo(cross_attention_optimization_default, "Cross-attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention() }),
+ "cross_attention_options": OptionInfo([], "Cross-attention advanced options", gr.CheckboxGroup, {"choices": ['xFormers enable flash Attention', 'SDP disable memory attention']}),
+ "sub_quad_sep": OptionInfo("Sub-quadratic options ", "", gr.HTML),
+ "sub_quad_q_chunk_size": OptionInfo(512, "cross-attention query chunk size", gr.Slider, {"minimum": 16, "maximum": 8192, "step": 8}),
+ "sub_quad_kv_chunk_size": OptionInfo(512, "cross-attention kv chunk size", gr.Slider, {"minimum": 0, "maximum": 8192, "step": 8}),
+ "sub_quad_chunk_threshold": OptionInfo(80, "cross-attention chunking threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
+
+ "other_sep": OptionInfo("Execution precision ", "", gr.HTML),
+ "opt_channelslast": OptionInfo(False, "Use channels last as torch memory format "),
+ "cudnn_benchmark": OptionInfo(False, "Enable full-depth cuDNN benchmark feature"),
+ "diffusers_fuse_projections": OptionInfo(False, "Enable fused projections"),
+ "torch_gc_threshold": OptionInfo(80, "Memory usage threshold before running Torch GC", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
+
+ "cuda_compile_sep": OptionInfo("Model Compile ", "", gr.HTML),
+ "cuda_compile": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile UNet"),
+ "cuda_compile_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile VAE"),
+ "cuda_compile_text_encoder": OptionInfo(False, "Compile Text Encoder"),
+ "cuda_compile_upscaler": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile Upscaler"),
+ "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'ipex', 'openvino_fx', 'stable-fast']}),
+ "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}),
+ "cuda_compile_fullgraph": OptionInfo(False, "Model compile fullgraph"),
+ "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"),
+ "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"),
+ "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"),
+ "diffusers_quantization": OptionInfo(False, "Enable dynamic quantization with torchao"),
+
+ "nncf_sep": OptionInfo("NNCF ", "", gr.HTML),
+ "nncf_compress_weights": OptionInfo(False, "Compress Model weights with NNCF"),
+ "nncf_compress_vae_weights": OptionInfo(False, "Compress VAE weights with NNCF"),
+ "nncf_compress_text_encoder_weights": OptionInfo(False, "Compress Text Encoder weights with NNCF"),
+
+ "directml_sep": OptionInfo("DirectML ", "", gr.HTML),
+ "directml_memory_provider": OptionInfo(default_memory_provider, 'DirectML memory stats provider', gr.Radio, {"choices": memory_providers}),
+    "directml_catch_nan": OptionInfo(False, "DirectML: retry operations that produce NaN values when possible (makes generation slower)"),
+
+ "ipex_sep": OptionInfo("IPEX ", "", gr.HTML),
+ "ipex_optimize": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with UNet"),
+ "ipex_optimize_vae": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with VAE"),
+ "ipex_optimize_text_encoder": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with Text Encoder"),
+ "ipex_optimize_upscaler": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with Upscalers"),
+
+ "openvino_sep": OptionInfo("OpenVINO ", "", gr.HTML),
+ "openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching"),
+ "openvino_hetero_gpu": OptionInfo(False, "OpenVINO use Hetero Device for single inference with multiple devices"),
+ "openvino_remove_cpu_from_hetero": OptionInfo(False, "OpenVINO remove CPU from Hetero Device"),
+ "openvino_remove_igpu_from_hetero": OptionInfo(False, "OpenVINO remove iGPU from Hetero Device"),
+ "nncf_compress_weights_mode": OptionInfo("INT8", "OpenVINO compress mode for NNCF (CPU Only)", gr.Radio, {"choices": ['INT8', 'INT4_SYM', 'INT4_ASYM', 'NF4']}),
+ "nncf_compress_weights_raito": OptionInfo(1.0, "OpenVINO compress ratio for NNCF with 4-bit modes", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+}))
+
+options_templates.update(options_section(('advanced', "Inference Settings"), {
+ "token_merging_sep": OptionInfo("Token merging ", "", gr.HTML),
+ "token_merging_ratio": OptionInfo(0.0, "Token merging ratio (txt2img)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
+ "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio (img2img)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
+ "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio (hires)", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
+
+ "freeu_sep": OptionInfo("FreeU ", "", gr.HTML),
+ "freeu_enabled": OptionInfo(False, "FreeU enabled"),
+ "freeu_b1": OptionInfo(1.2, "1st stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
+ "freeu_b2": OptionInfo(1.4, "2nd stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
+ "freeu_s1": OptionInfo(0.9, "1st stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "freeu_s2": OptionInfo(0.2, "2nd stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+
+ "hypertile_sep": OptionInfo("HyperTile ", "", gr.HTML),
+ "hypertile_unet_enabled": OptionInfo(False, "HyperTile for UNet enabled"),
+ "hypertile_unet_tile": OptionInfo(256, "HyperTile for UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
+ "hypertile_vae_enabled": OptionInfo(False, "HyperTile for VAE enabled", gr.Checkbox),
+ "hypertile_vae_tile": OptionInfo(128, "HyperTile for VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
+
+ "inference_other_sep": OptionInfo("Other ", "", gr.HTML),
+ "batch_frame_mode": OptionInfo(False, "Process multiple images in batch in parallel"),
+ "inference_mode": OptionInfo("no-grad", "Torch inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}),
+ "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode"),
+}))
+
+options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
+ "diffusers_pipeline": OptionInfo('Autodetect', 'Diffusers pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()) }),
+ "diffusers_move_base": OptionInfo(True, "Move base model to CPU when using refiner"),
+ "diffusers_move_unet": OptionInfo(True, "Move base model to CPU when using VAE"),
+ "diffusers_move_refiner": OptionInfo(True, "Move refiner model to CPU when not in use"),
+ "diffusers_extract_ema": OptionInfo(True, "Use model EMA weights when possible"),
+ "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}),
+ "diffusers_model_cpu_offload": OptionInfo(False, "Enable model CPU offload (--medvram)"),
+ "diffusers_seq_cpu_offload": OptionInfo(False, "Enable sequential CPU offload (--lowvram)"),
+ "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}),
+ "diffusers_vae_slicing": OptionInfo(True, "Enable VAE slicing"),
+ "diffusers_vae_tiling": OptionInfo(True if not cmd_opts.use_openvino else False, "Enable VAE tiling"),
+ "diffusers_attention_slicing": OptionInfo(False, "Enable attention slicing"),
+ "diffusers_model_load_variant": OptionInfo("default", "Diffusers model loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
+ "diffusers_vae_load_variant": OptionInfo("default", "Diffusers VAE loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
+ "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
+ "diffusers_eval": OptionInfo(True, "Force model eval"),
+ "diffusers_force_zeros": OptionInfo(True, "Force zeros for prompts when empty"),
+ "diffusers_aesthetics_score": OptionInfo(False, "Require aesthetics score"),
+ "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds (experimental)", gr.Radio, {"choices": ['default', 'weighted']}),
+}))
+
+options_templates.update(options_section(('system-paths', "System Paths"), {
+ "models_paths_sep_options": OptionInfo("Models paths ", "", gr.HTML),
+ "models_dir": OptionInfo('models', "Base path where all models are stored", folder=True),
+ "ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True),
+    "diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Huggingface models", folder=True),
+    "hfcache_dir": OptionInfo(os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub'), "Folder for Huggingface cache", folder=True),
+ "vae_dir": OptionInfo(os.path.join(paths.models_path, 'VAE'), "Folder with VAE files", folder=True),
+ "sd_lora": OptionInfo("", "Add LoRA to prompt", gr.Textbox, {"visible": False}),
+ "lora_dir": OptionInfo(os.path.join(paths.models_path, 'Lora'), "Folder with LoRA network(s)", folder=True),
+ "lyco_dir": OptionInfo(os.path.join(paths.models_path, 'LyCORIS'), "Folder with LyCORIS network(s)", gr.Text, {"visible": False}),
+ "styles_dir": OptionInfo(os.path.join(paths.data_path, 'styles.csv'), "File or Folder with user-defined styles", folder=True),
+ "embeddings_dir": OptionInfo(os.path.join(paths.models_path, 'embeddings'), "Folder with textual inversion embeddings", folder=True),
+ "hypernetwork_dir": OptionInfo(os.path.join(paths.models_path, 'hypernetworks'), "Folder with Hypernetwork models", folder=True),
+ "control_dir": OptionInfo(os.path.join(paths.models_path, 'control'), "Folder with Control models", folder=True),
+ "codeformer_models_path": OptionInfo(os.path.join(paths.models_path, 'Codeformer'), "Folder with codeformer models", folder=True),
+ "gfpgan_models_path": OptionInfo(os.path.join(paths.models_path, 'GFPGAN'), "Folder with GFPGAN models", folder=True),
+ "esrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'ESRGAN'), "Folder with ESRGAN models", folder=True),
+ "bsrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'BSRGAN'), "Folder with BSRGAN models", folder=True),
+ "realesrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'RealESRGAN'), "Folder with RealESRGAN models", folder=True),
+ "scunet_models_path": OptionInfo(os.path.join(paths.models_path, 'SCUNet'), "Folder with SCUNet models", folder=True),
+ "swinir_models_path": OptionInfo(os.path.join(paths.models_path, 'SwinIR'), "Folder with SwinIR models", folder=True),
+ "ldsr_models_path": OptionInfo(os.path.join(paths.models_path, 'LDSR'), "Folder with LDSR models", folder=True),
+ "clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True),
+
+ "other_paths_sep_options": OptionInfo("Other paths ", "", gr.HTML),
+ "openvino_cache_path": OptionInfo('cache', "Directory for OpenVINO cache", folder=True),
+ "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
+ "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
+}))
+
+options_templates.update(options_section(('saving-images', "Image Options"), {
+ "keep_incomplete": OptionInfo(True, "Keep incomplete images"),
+ "samples_save": OptionInfo(True, "Always save all generated images"),
+ "samples_format": OptionInfo('jpg', 'File format for generated images', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}),
+ "jpeg_quality": OptionInfo(90, "Quality for saved images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "img_max_size_mp": OptionInfo(250, "Maximum image size (MP)", gr.Slider, {"minimum": 100, "maximum": 2000, "step": 1}),
+ "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
+ "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "samples_save_zip": OptionInfo(True, "Create zip archive when downloading multiple images"),
+
+ "image_sep_metadata": OptionInfo("Metadata/Logging ", "", gr.HTML),
+ "image_metadata": OptionInfo(True, "Include metadata in saved images"),
+    "save_txt": OptionInfo(False, "Create info file for every image"),
+ "save_log_fn": OptionInfo("", "Create JSON log file for each saved image", component_args=hide_dirs),
+ "image_watermark_enabled": OptionInfo(False, "Include watermark in saved images"),
+ "image_watermark": OptionInfo('', "Image watermark string"),
+ "image_sep_grid": OptionInfo("Grid Options ", "", gr.HTML),
+ "grid_save": OptionInfo(True, "Always save all generated image grids"),
+ "grid_format": OptionInfo('jpg', 'File format for grids', gr.Dropdown, {"choices": ["jpg", "png", "webp", "tiff", "jp2"]}),
+ "n_rows": OptionInfo(-1, "Grid row count", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "grid_background": OptionInfo("#000000", "Grid background color", ui_components.FormColorPicker, {}),
+ "font": OptionInfo("", "Font file"),
+ "font_color": OptionInfo("#FFFFFF", "Font color", ui_components.FormColorPicker, {}),
+
+ "save_sep_options": OptionInfo("Intermediate Image Saving ", "", gr.HTML),
+ "save_init_img": OptionInfo(False, "Save copy of img2img init images"),
+ "save_images_before_highres_fix": OptionInfo(False, "Save copy of image before applying hires"),
+ "save_images_before_refiner": OptionInfo(False, "Save copy of image before running refiner"),
+ "save_images_before_face_restoration": OptionInfo(False, "Save copy of image before doing face restoration"),
+ "save_images_before_color_correction": OptionInfo(False, "Save copy of image before applying color correction"),
+ "save_mask": OptionInfo(False, "Save copy of the inpainting greyscale mask"),
+ "save_mask_composite": OptionInfo(False, "Save copy of inpainting masked composite"),
+}))
+
+options_templates.update(options_section(('saving-paths', "Image Naming & Paths"), {
+ "saving_sep_images": OptionInfo("Save options ", "", gr.HTML),
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
+ "use_original_name_batch": OptionInfo(True, "Use original name during batch process"),
+ "use_upscaler_name_as_suffix": OptionInfo(True, "Use upscaler as suffix", gr.Checkbox, {"visible": False}),
+ "samples_filename_pattern": OptionInfo("[seq]-[model_name]-[prompt_words]", "Images filename pattern", component_args=hide_dirs),
+ "directories_max_prompt_words": OptionInfo(8, "Max words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}),
+ "use_save_to_dirs_for_ui": OptionInfo(False, "Save images to a subdirectory when using Save button", gr.Checkbox, {"visible": False}),
+ "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
+
+ "outdir_sep_dirs": OptionInfo("Directories ", "", gr.HTML),
+ "outdir_samples": OptionInfo("", "Output directory for images", component_args=hide_dirs, folder=True),
+    "outdir_txt2img_samples": OptionInfo("outputs/text", 'Directory for text generation outputs', component_args=hide_dirs, folder=True),
+    "outdir_img2img_samples": OptionInfo("outputs/image", 'Directory for image generation outputs', component_args=hide_dirs, folder=True),
+    "outdir_control_samples": OptionInfo("outputs/control", 'Directory for control generation outputs', component_args=hide_dirs, folder=True),
+ "outdir_extras_samples": OptionInfo("outputs/extras", 'Directory for processed images', component_args=hide_dirs, folder=True),
+ "outdir_save": OptionInfo("outputs/save", "Directory for manually saved images", component_args=hide_dirs, folder=True),
+ "outdir_video": OptionInfo("outputs/video", "Directory for videos", component_args=hide_dirs, folder=True),
+ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for init images", component_args=hide_dirs, folder=True),
+
+ "outdir_sep_grids": OptionInfo("Grids ", "", gr.HTML),
+ "grid_extended_filename": OptionInfo(True, "Add extended info (seed, prompt) to filename when saving grid", gr.Checkbox, {"visible": False}),
+ "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory", gr.Checkbox, {"visible": False}),
+ "outdir_grids": OptionInfo("", "Output directory for grids", component_args=hide_dirs, folder=True),
+ "outdir_txt2img_grids": OptionInfo("outputs/grids", 'Output directory for txt2img grids', component_args=hide_dirs, folder=True),
+ "outdir_img2img_grids": OptionInfo("outputs/grids", 'Output directory for img2img grids', component_args=hide_dirs, folder=True),
+ "outdir_control_grids": OptionInfo("outputs/grids", 'Output directory for control grids', component_args=hide_dirs, folder=True),
+}))
+
+options_templates.update(options_section(('ui', "User Interface"), {
+ "motd": OptionInfo(True, "Show MOTD"),
+ "gradio_theme": OptionInfo("black-teal", "UI theme", gr.Dropdown, lambda: {"choices": theme.list_themes()}, refresh=theme.refresh_themes),
+ "theme_style": OptionInfo("Auto", "Theme mode", gr.Radio, {"choices": ["Auto", "Dark", "Light"]}),
+ "font_size": OptionInfo(16, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}),
+    "tooltips": OptionInfo("UI tooltips", "UI tooltips", gr.Radio, {"choices": ["None", "Browser default", "UI tooltips"], "visible": False}),
+ "gallery_height": OptionInfo("", "Gallery height", gr.Textbox),
+ "compact_view": OptionInfo(False, "Compact view"),
+ "return_grid": OptionInfo(True, "Show grid in results"),
+ "return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"),
+ "return_mask_composite": OptionInfo(False, "Inpainting include masked composite in results"),
+ "disable_weights_auto_swap": OptionInfo(True, "Do not change selected model when reading generation parameters"),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001, "visible": False}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001, "visible": False}),
+ "keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters", gr.Textbox, { "visible": False }), # pylint: disable=anomalous-backslash-in-string
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"] if backend == Backend.ORIGINAL else ["sd_model_checkpoint", "sd_model_refiner"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
+ "ui_scripts_reorder": OptionInfo("", "UI scripts order", gr.Textbox, { "visible": False }),
+}))
+
+options_templates.update(options_section(('live-preview', "Live Previews"), {
+ "show_progressbar": OptionInfo(True, "Show progressbar", gr.Checkbox, {"visible": False}),
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image", gr.Checkbox, {"visible": False}),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid", gr.Checkbox, {"visible": False}),
+ "notification_audio_enable": OptionInfo(False, "Play a sound when images are finished generating"),
+ "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True),
+ "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}),
+ "live_preview_content": OptionInfo("Combined", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"], "visible": False}),
+ "live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}),
+ "logmonitor_show": OptionInfo(True, "Show log view"),
+ "logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}),
+}))
+
+options_templates.update(options_section(('sampler-params', "Sampler Settings"), {
+ "show_samplers": OptionInfo([], "Show samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
+ 'eta_noise_seed_delta': OptionInfo(0, "Noise seed delta (eta)", gr.Number, {"precision": 0}),
+ "scheduler_eta": OptionInfo(1.0, "Noise multiplier (eta)", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "schedulers_solver_order": OptionInfo(2, "Solver order (where applicable)", gr.Slider, {"minimum": 1, "maximum": 5, "step": 1}),
+
+ # managed from ui.py for backend original
+ "schedulers_brownian_noise": OptionInfo(True, "Use Brownian noise", gr.Checkbox, {"visible": False}),
+ "schedulers_discard_penultimate": OptionInfo(True, "Discard penultimate sigma", gr.Checkbox, {"visible": False}),
+ "schedulers_sigma": OptionInfo("default", "Sigma algorithm", gr.Radio, {"choices": ['default', 'karras', 'exponential', 'polyexponential'], "visible": False}),
+ "schedulers_use_karras": OptionInfo(True, "Use Karras sigmas", gr.Checkbox, {"visible": False}),
+ "schedulers_use_thresholding": OptionInfo(False, "Use dynamic thresholding", gr.Checkbox, {"visible": False}),
+ "schedulers_use_loworder": OptionInfo(True, "Use simplified solvers in final steps", gr.Checkbox, {"visible": False}),
+ "schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction']}),
+
+ # managed from ui.py for backend diffusers
+ "schedulers_sep_diffusers": OptionInfo("Diffusers specific config ", "", gr.HTML),
+ "schedulers_dpm_solver": OptionInfo("sde-dpmsolver++", "DPM solver algorithm", gr.Radio, {"choices": ['dpmsolver', 'dpmsolver++', 'sde-dpmsolver', 'sde-dpmsolver++']}),
+ "schedulers_beta_schedule": OptionInfo("default", "Beta schedule", gr.Radio, {"choices": ['default', 'linear', 'scaled_linear', 'squaredcos_cap_v2']}),
+ 'schedulers_beta_start': OptionInfo(0, "Beta start", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}),
+ 'schedulers_beta_end': OptionInfo(0, "Beta end", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}),
+ 'schedulers_timesteps_range': OptionInfo(1000, "Timesteps range", gr.Slider, {"minimum": 250, "maximum": 4000, "step": 1}),
+ "schedulers_rescale_betas": OptionInfo(False, "Rescale betas with zero terminal SNR", gr.Checkbox),
+
+ # managed from ui.py for backend original k-diffusion
+ "schedulers_sep_kdiffusers": OptionInfo("K-Diffusion specific config ", "", gr.HTML),
+ "always_batch_cond_uncond": OptionInfo(False, "Disable conditional batching enabled on low memory systems"),
+ "enable_quantization": OptionInfo(True, "Enable quantization for sharper and cleaner results"),
+ 's_churn': OptionInfo(0.0, "Sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_min_uncond': OptionInfo(0.0, "Sigma negative guidance minimum ", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
+ 's_tmin': OptionInfo(0.0, "Sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_noise': OptionInfo(1.0, "Sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_min': OptionInfo(0.0, "Sigma min", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_max': OptionInfo(0.0, "Sigma max", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 1.0}),
+ "schedulers_sep_compvis": OptionInfo("CompVis specific config ", "", gr.HTML),
+ 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
+ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
+ "ddim_discretize": OptionInfo('uniform', "DDIM discretize img2img", gr.Radio, {"choices": ['uniform', 'quad']}),
+ # TODO pad_cond_uncond implementation missing for original backend
+ "pad_cond_uncond": OptionInfo(True, "Pad prompt and negative prompt to be same length", gr.Checkbox, {"visible": False}),
+    # TODO batch_cond_uncond implementation missing for original backend
+ "batch_cond_uncond": OptionInfo(True, "Do conditional and unconditional denoising in one batch", gr.Checkbox, {"visible": False}),
+}))
+
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+ 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable additional postprocessing operations", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+ 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+
+ "postprocessing_sep_img2img": OptionInfo("Img2Img & Inpainting ", "", gr.HTML),
+ "img2img_color_correction": OptionInfo(False, "Apply color correction"),
+ "img2img_fix_steps": OptionInfo(False, "For image processing do exact number of steps as specified", gr.Checkbox, { "visible": False }),
+ "img2img_background_color": OptionInfo("#ffffff", "Image transparent color fill", ui_components.FormColorPicker, {}),
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for image processing", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.01}),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 8, "step": 1, "visible": False}),
+
+ "postprocessing_sep_face_restoration": OptionInfo("Face Restoration ", "", gr.HTML),
+ "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
+ "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
+
+ "postprocessing_sep_upscalers": OptionInfo("Upscaling ", "", gr.HTML),
+ "upscaler_unload": OptionInfo(False, "Unload upscaler after processing"),
+ "upscaler_for_img2img": OptionInfo("None", "Default upscaler for image resize operations", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers], "visible": False}, refresh=refresh_upscalers),
+ "upscaler_tile_size": OptionInfo(192, "Upscaler tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "upscaler_tile_overlap": OptionInfo(8, "Upscaler tile overlap", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
+}))
+
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible"),
+ "pin_memory": OptionInfo(True, "Pin training dataset to memory"),
+ "save_optimizer_state": OptionInfo(False, "Save resumable optimizer state when training"),
+ "save_training_settings_to_txt": OptionInfo(True, "Save training settings to a text file on training start"),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "embeddings_templates_dir": OptionInfo(os.path.join(paths.script_path, 'train', 'templates'), "Embeddings train templates directory", folder=True),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Image repeats per epoch", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "training_write_csv_every": OptionInfo(0, "Save loss CSV file every n steps"),
+ "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging"),
+ "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard"),
+ "training_tensorboard_flush_every": OptionInfo(120, "Tensorboard flush period"),
+}))
+
+options_templates.update(options_section(('interrogate', "Interrogate"), {
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
+ "interrogate_return_ranks": OptionInfo(True, "Interrogate: include ranks of model tags matches in results"),
+ "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
+ "interrogate_clip_min_length": OptionInfo(32, "Interrogate: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
+ "interrogate_clip_max_length": OptionInfo(192, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(2048, "CLIP: maximum number of lines in text file", gr.Slider, { "visible": False }),
+ "interrogate_clip_skip_categories": OptionInfo(["artists", "movements", "flavors"], "Interrogate: skip categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.65, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(False, "Interrogate: deepbooru sort alphabetically"),
+ "deepbooru_use_spaces": OptionInfo(False, "Use spaces for tags in deepbooru"),
+ "deepbooru_escape": OptionInfo(True, "Escape brackets in deepbooru"),
+ "deepbooru_filter_tags": OptionInfo("", "Filter out tags from deepbooru output"),
+}))
+
+options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+ "extra_networks_sep1": OptionInfo("Extra networks UI ", "", gr.HTML),
+ "extra_networks": OptionInfo(["All"], "Extra networks", ui_components.DropdownMulti, lambda: {"choices": ['All'] + [en.title for en in extra_networks]}),
+ "extra_networks_view": OptionInfo("gallery", "UI view", gr.Radio, {"choices": ["gallery", "list"]}),
+ "extra_networks_card_cover": OptionInfo("sidebar", "UI position", gr.Radio, {"choices": ["cover", "inline", "sidebar"]}),
+ "extra_networks_height": OptionInfo(53, "UI height (%)", gr.Slider, {"minimum": 10, "maximum": 100, "step": 1}),
+ "extra_networks_sidebar_width": OptionInfo(35, "UI sidebar width (%)", gr.Slider, {"minimum": 10, "maximum": 80, "step": 1}),
+ "extra_networks_card_size": OptionInfo(160, "UI card size (px)", gr.Slider, {"minimum": 20, "maximum": 2000, "step": 1}),
+ "extra_networks_card_square": OptionInfo(True, "UI disable variable aspect ratio"),
+ "extra_networks_card_fit": OptionInfo("cover", "UI image contain method", gr.Radio, {"choices": ["contain", "cover", "fill"], "visible": False}),
+ "extra_networks_sep2": OptionInfo("Extra networks general ", "", gr.HTML),
+ "extra_network_skip_indexing": OptionInfo(False, "Build info on first access", gr.Checkbox),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "extra_networks_sep3": OptionInfo("Extra networks settings ", "", gr.HTML),
+ "extra_networks_styles": OptionInfo(True, "Show built-in styles"),
+    "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"]}),
+ "lora_add_hashes_to_infotext": OptionInfo(True, "LoRA add hash info"),
+ "lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use alternative loading method"),
+ "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use merge when using alternative method"),
+ "lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}),
+ "lora_functional": OptionInfo(False, "Use Kohya method for handling multiple LoRA", gr.Checkbox, { "visible": False }),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, { "choices": ["None"], "visible": False }),
+}))
+
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable these extensions"),
+ "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "user", "all"]}),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
+}))
+
+options_templates.update()
+
+
+class Options:
+ data = None
+ data_labels = options_templates
+ filename = None
+ typemap = {int: float}
+
+ def __init__(self):
+ self.data = {k: v.default for k, v in self.data_labels.items()}
+
+ def __setattr__(self, key, value): # pylint: disable=inconsistent-return-statements
+ if self.data is not None:
+ if key in self.data or key in self.data_labels:
+ if cmd_opts.freeze:
+ log.warning(f'Settings are frozen: {key}')
+ return
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ log.warning(f'Settings key is restricted: {key}')
+ return
+ self.data[key] = value
+ return
+ return super(Options, self).__setattr__(key, value) # pylint: disable=super-with-arguments
+
+ def __getattr__(self, item):
+ if self.data is not None:
+ if item in self.data:
+ return self.data[item]
+ if item in self.data_labels:
+ return self.data_labels[item].default
+ return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments
+
+ def set(self, key, value):
+ """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
+ oldval = self.data.get(key, None)
+ if oldval is None:
+ oldval = self.data_labels[key].default
+ if oldval == value:
+ return False
+ try:
+ setattr(self, key, value)
+ except RuntimeError:
+ return False
+ if self.data_labels[key].onchange is not None:
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ log.error(f'Error in onchange callback: {key} {value} {e}')
+ setattr(self, key, oldval)
+ return False
+ return True
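+    # illustrative: opts.set('sd_vae', 'Automatic') returns False when the value is unchanged
+    # or the onchange callback raises, and True when the option was updated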
+
+ def get_default(self, key):
+ """returns the default value for the key"""
+ data_label = self.data_labels.get(key)
+ return data_label.default if data_label is not None else None
+
+ def save(self, filename=None, silent=False):
+ if filename is None:
+ filename = self.filename
+ if cmd_opts.freeze:
+ log.warning(f'Settings saving is disabled: {filename}')
+ return
+ try:
+ # output = json.dumps(self.data, indent=2)
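+            # persist only values that differ from their defaults (list-typed values are always kept),
+            # so the config file stays a minimal diff against the built-in defaults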
+ diff = {}
+ unused_settings = []
+ for k, v in self.data.items():
+ if k in self.data_labels:
+ if type(v) is list:
+ diff[k] = v
+ if self.data_labels[k].default != v:
+ diff[k] = v
+ else:
+ if k not in compatibility_opts:
+ unused_settings.append(k)
+ diff[k] = v
+ writefile(diff, filename, silent=silent)
+ if len(unused_settings) > 0:
+ log.debug(f"Unused settings: {unused_settings}")
+ except Exception as e:
+ log.error(f'Saving settings failed: {filename} {e}')
+
+ def same_type(self, x, y):
+ if x is None or y is None:
+ return True
+ type_x = self.typemap.get(type(x), type(x))
+ type_y = self.typemap.get(type(y), type(y))
+ return type_x == type_y
+
+ def load(self, filename=None):
+ if filename is None:
+ filename = self.filename
+ if not os.path.isfile(filename):
+ log.debug(f'Created default config: {filename}')
+ self.save(filename)
+ return
+ self.data = readfile(filename, lock=True)
+ if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
+ self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+ unknown_settings = []
+ for k, v in self.data.items():
+ info = self.data_labels.get(k, None)
+ if info is not None and not self.same_type(info.default, v):
+ log.error(f"Error: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})")
+ if info is None and k not in compatibility_opts:
+ unknown_settings.append(k)
+ if len(unknown_settings) > 0:
+ log.debug(f"Unknown settings: {unknown_settings}")
+
+ def onchange(self, key, func, call=True):
+ item = self.data_labels.get(key)
+ item.onchange = func
+ if call:
+ func()
+
+ def dumpjson(self):
+ d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
+ metadata = {
+ k: {
+ "is_stored": k in self.data and self.data[k] != self.data_labels[k].default, # pylint: disable=unnecessary-dict-index-lookup
+ "tab_name": v.section[0]
+ } for k, v in self.data_labels.items()
+ }
+ return json.dumps({"values": d, "metadata": metadata})
+
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for _k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
+
+ def cast_value(self, key, value):
+        """casts an arbitrary value to the same type as this setting's value for the given key
+ Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
+ """
+ if value is None:
+ return None
+ default_value = self.data_labels[key].default
+ if default_value is None:
+ default_value = getattr(self, key, None)
+ if default_value is None:
+ return None
+ expected_type = type(default_value)
+ if expected_type == bool and value == "False":
+ value = False
+ elif expected_type == type(value):
+ pass
+ else:
+ value = expected_type(value)
+ return value
+
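+# module-level state: opts is the global settings singleton backed by cmd_opts.config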
+profiler = None
+opts = Options()
+config_filename = cmd_opts.config
+opts.load(config_filename)
+cmd_opts = cmd_args.compatibility_args(opts, cmd_opts)
+if cmd_opts.use_xformers:
+ opts.data['cross_attention_optimization'] = 'xFormers'
+opts.data['uni_pc_lower_order_final'] = opts.schedulers_use_loworder # compatibility
+opts.data['uni_pc_order'] = opts.schedulers_solver_order # compatibility
+log.info(f'Engine: backend={backend} compute={devices.backend} device={devices.get_optimal_device_name()} attention="{opts.cross_attention_optimization}" mode={devices.inference_context.__name__}')
+log.info(f'Device: {print_dict(devices.get_gpu_info())}')
+
+prompt_styles = modules.styles.StyleDatabase(opts)
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or (cmd_opts.server_name or False)) and not cmd_opts.insecure
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
+device = devices.device
+batch_cond_uncond = opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
+parallel_processing_allowed = not cmd_opts.lowvram
+mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device)
+max_workers = 4
+if devices.backend == "directml":
+ directml_do_hijack()
+
+
+class TotalTQDM: # compatibility with previous global-tqdm
+ # import tqdm
+ def __init__(self):
+ pass
+ def reset(self):
+ pass
+ def update(self):
+ pass
+ def updateTotal(self, new_total):
+ pass
+ def clear(self):
+ pass
+total_tqdm = TotalTQDM()
+
+
+def restart_server(restart=True):
+ if demo is None:
+ return
+ log.warning('Server shutdown requested')
+ try:
+ sys.tracebacklimit = 0
+ stdout = io.StringIO()
+ stderr = io.StringIO()
+        with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
+ demo.server.wants_restart = restart
+ demo.server.should_exit = True
+ demo.server.force_exit = True
+ demo.close(verbose=False)
+ demo.server.close()
+ demo.fns = []
+ time.sleep(1)
+ sys.tracebacklimit = 100
+ # os._exit(0)
+ except (Exception, BaseException) as e:
+ log.error(f'Server shutdown error: {e}')
+ if restart:
+ log.info('Server will restart')
+
+
+def restore_defaults(restart=True):
+ if os.path.exists(cmd_opts.config):
+ log.info('Restoring server defaults')
+ os.remove(cmd_opts.config)
+ restart_server(restart)
+
+
+def listdir(path):
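+    # cheap directory cache: re-list folder contents only when the directory mtime changes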
+ if not os.path.exists(path):
+ return []
+ mtime = os.path.getmtime(path)
+ if path in dir_timestamps and mtime == dir_timestamps[path]:
+ return dir_cache[path]
+ else:
+ dir_cache[path] = [os.path.join(path, f) for f in os.listdir(path)]
+ dir_timestamps[path] = mtime
+ return dir_cache[path]
+
+
+def walk_files(path, allowed_extensions=None):
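+    # recursive generator over files; allowed_extensions entries must include the leading dot (e.g. '.json')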
+ if not os.path.exists(path):
+ return
+ if allowed_extensions is not None:
+ allowed_extensions = set(allowed_extensions)
+ for root, _dirs, files in os.walk(path, followlinks=True):
+ for filename in files:
+ if allowed_extensions is not None:
+ _, ext = os.path.splitext(filename)
+ if ext not in allowed_extensions:
+ continue
+ yield os.path.join(root, filename)
+
+
+def html_path(filename):
+ return os.path.join(paths.script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+ return ""
+
+
+def get_version():
+    try:
+        import subprocess
+        res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True)
+        ver = res.stdout.decode(encoding='utf8', errors='ignore') if len(res.stdout) > 0 else ' '
+        githash, updated = ver.split(' ')
+        res = subprocess.run('git remote get-url origin', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True)
+        origin = res.stdout.decode(encoding='utf8', errors='ignore') if len(res.stdout) > 0 else ''
+        res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True)
+        branch = res.stdout.decode(encoding='utf8', errors='ignore') if len(res.stdout) > 0 else ''
+        version = {
+            'app': 'sd.next',
+            'updated': updated,
+            'hash': githash,
+            'url': origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
+        }
+    except Exception:
+        version = { 'app': 'sd.next' }
+    return version
+
+
+def req(url_addr, headers = None, **kwargs):
+ if headers is None:
+ headers = { 'Content-type': 'application/json' }
+ try:
+ res = requests.get(url_addr, timeout=30, headers=headers, verify=False, allow_redirects=True, **kwargs)
+ except Exception as e:
+ log.error(f'HTTP request error: url={url_addr} {e}')
+ res = { 'status_code': 500, 'text': f'HTTP request error: url={url_addr} {e}' }
+ res = SimpleNamespace(**res)
+ return res
+
+
+sd_model: diffusers.DiffusionPipeline = None # dummy and overwritten by class
+sd_refiner: diffusers.DiffusionPipeline = None # dummy and overwritten by class
+sd_model_type: str = '' # dummy and overwritten by class
+sd_refiner_type: str = '' # dummy and overwritten by class
+compiled_model_state = None
+
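+# swap this module's class so attribute access such as shared.sd_model is routed through Shared's properties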
+from modules.modeldata import Shared # pylint: disable=ungrouped-imports
+sys.modules[__name__].__class__ = Shared
diff --git a/modules/shared_items.py b/modules/shared_items.py
index b0d4f1ec4..0796cac43 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -1,55 +1,55 @@
-def postprocessing_scripts():
- import modules.scripts
- return modules.scripts.scripts_postproc.scripts
-
-
-def sd_vae_items():
- import modules.sd_vae
- return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
-
-
-def refresh_vae_list():
- import modules.sd_vae
- modules.sd_vae.refresh_vae_list()
-
-
-def list_crossattention():
- return [
- "Disabled",
- "xFormers",
- "Scaled-Dot-Product",
- "Doggettx's",
- "InvokeAI's",
- "Sub-quadratic",
- "Split attention"
- ]
-
-def get_pipelines():
- import diffusers
- from installer import log
- pipelines = { # note: not all pipelines can be used manually as they require prior pipeline next to decoder pipeline
- 'Autodetect': None,
- 'Stable Diffusion': getattr(diffusers, 'StableDiffusionPipeline', None),
- 'Stable Diffusion Inpaint': getattr(diffusers, 'StableDiffusionInpaintPipeline', None),
- 'Stable Diffusion Img2Img': getattr(diffusers, 'StableDiffusionImg2ImgPipeline', None),
- 'Stable Diffusion Instruct': getattr(diffusers, 'StableDiffusionInstructPix2PixPipeline', None),
- 'Stable Diffusion Upscale': getattr(diffusers, 'StableDiffusionUpscalePipeline', None),
- 'Stable Diffusion XL': getattr(diffusers, 'StableDiffusionXLPipeline', None),
- 'Stable Diffusion XL Img2Img': getattr(diffusers, 'StableDiffusionXLImg2ImgPipeline', None),
- 'Stable Diffusion XL Inpaint': getattr(diffusers, 'StableDiffusionXLInpaintPipeline', None),
- 'Stable Diffusion XL Instruct': getattr(diffusers, 'StableDiffusionXLInstructPix2PixPipeline', None),
- 'Latent Consistency Model': getattr(diffusers, 'LatentConsistencyModelPipeline', None),
- 'PixArt Alpha': getattr(diffusers, 'PixArtAlphaPipeline', None),
- 'UniDiffuser': getattr(diffusers, 'UniDiffuserPipeline', None),
- 'Wuerstchen': getattr(diffusers, 'WuerstchenCombinedPipeline', None),
- 'Kandinsky 2.1': getattr(diffusers, 'KandinskyPipeline', None),
- 'Kandinsky 2.2': getattr(diffusers, 'KandinskyV22Pipeline', None),
- 'Kandinsky 3': getattr(diffusers, 'Kandinsky3Pipeline', None),
- 'DeepFloyd IF': getattr(diffusers, 'IFPipeline', None),
- 'Custom Diffusers Pipeline': getattr(diffusers, 'DiffusionPipeline', None),
- # Segmind SSD-1B, Segmind Tiny
- }
- for k, v in pipelines.items():
- if k != 'Autodetect' and v is None:
- log.error(f'Not available: pipeline={k} diffusers={diffusers.__version__} path={diffusers.__file__}')
- return pipelines
+def postprocessing_scripts():
+ import modules.scripts
+ return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+ import modules.sd_vae
+ return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+ import modules.sd_vae
+ modules.sd_vae.refresh_vae_list()
+
+
+def list_crossattention():
+ return [
+ "Disabled",
+ "xFormers",
+ "Scaled-Dot-Product",
+ "Doggettx's",
+ "InvokeAI's",
+ "Sub-quadratic",
+ "Split attention"
+ ]
+
+def get_pipelines():
+ import diffusers
+ from installer import log
+ pipelines = { # note: not all pipelines can be used manually as they require prior pipeline next to decoder pipeline
+ 'Autodetect': None,
+ 'Stable Diffusion': getattr(diffusers, 'StableDiffusionPipeline', None),
+ 'Stable Diffusion Inpaint': getattr(diffusers, 'StableDiffusionInpaintPipeline', None),
+ 'Stable Diffusion Img2Img': getattr(diffusers, 'StableDiffusionImg2ImgPipeline', None),
+ 'Stable Diffusion Instruct': getattr(diffusers, 'StableDiffusionInstructPix2PixPipeline', None),
+ 'Stable Diffusion Upscale': getattr(diffusers, 'StableDiffusionUpscalePipeline', None),
+ 'Stable Diffusion XL': getattr(diffusers, 'StableDiffusionXLPipeline', None),
+ 'Stable Diffusion XL Img2Img': getattr(diffusers, 'StableDiffusionXLImg2ImgPipeline', None),
+ 'Stable Diffusion XL Inpaint': getattr(diffusers, 'StableDiffusionXLInpaintPipeline', None),
+ 'Stable Diffusion XL Instruct': getattr(diffusers, 'StableDiffusionXLInstructPix2PixPipeline', None),
+ 'Latent Consistency Model': getattr(diffusers, 'LatentConsistencyModelPipeline', None),
+ 'PixArt Alpha': getattr(diffusers, 'PixArtAlphaPipeline', None),
+ 'UniDiffuser': getattr(diffusers, 'UniDiffuserPipeline', None),
+ 'Wuerstchen': getattr(diffusers, 'WuerstchenCombinedPipeline', None),
+ 'Kandinsky 2.1': getattr(diffusers, 'KandinskyPipeline', None),
+ 'Kandinsky 2.2': getattr(diffusers, 'KandinskyV22Pipeline', None),
+ 'Kandinsky 3': getattr(diffusers, 'Kandinsky3Pipeline', None),
+ 'DeepFloyd IF': getattr(diffusers, 'IFPipeline', None),
+ 'Custom Diffusers Pipeline': getattr(diffusers, 'DiffusionPipeline', None),
+ # Segmind SSD-1B, Segmind Tiny
+ }
+ for k, v in pipelines.items():
+ if k != 'Autodetect' and v is None:
+ log.error(f'Not available: pipeline={k} diffusers={diffusers.__version__} path={diffusers.__file__}')
+ return pipelines
diff --git a/modules/styles.py b/modules/styles.py
index 1f9a6fd52..58cf3e2a5 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,237 +1,237 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-import re
-import os
-import csv
-import json
-import time
-from installer import log
-
-
-class Style():
- def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = "", mtime: float = 0):
- self.name = name
- self.description = desc
- self.prompt = prompt
- self.negative_prompt = negative_prompt
- self.extra = extra
- self.filename = filename
- self.preview = preview
- self.mtime = mtime
-
-def merge_prompts(style_prompt: str, prompt: str) -> str:
- if "{prompt}" in style_prompt:
- res = style_prompt.replace("{prompt}", prompt)
- else:
- original_prompt = prompt.strip()
- style_prompt = style_prompt.strip()
- parts = filter(None, (original_prompt, style_prompt))
- if original_prompt.endswith(","):
- res = " ".join(parts)
- else:
- res = ", ".join(parts)
- return res
-
-
-def apply_styles_to_prompt(prompt, styles):
- for style in styles:
- prompt = merge_prompts(style, prompt)
- return prompt
-
-
-def apply_styles_to_extra(p, style: Style):
- if style is None:
- return
- name_map = {
- 'sampler': 'sampler_name',
- }
- from modules.generation_parameters_copypaste import parse_generation_parameters
- extra = parse_generation_parameters(style.extra)
- extra.pop('Prompt', None)
- extra.pop('Negative prompt', None)
- fields = []
- for k, v in extra.items():
- k = k.lower()
- k = k.replace(' ', '_')
- if k in name_map: # rename some fields
- k = name_map[k]
- if hasattr(p, k):
- orig = getattr(p, k)
- if type(orig) != type(v) and orig is not None:
- v = type(orig)(v)
- setattr(p, k, v)
- fields.append(f'{k}={v}')
- log.info(f'Applying style: name="{style.name}" extra={fields}')
-
-
-class StyleDatabase:
- def __init__(self, opts):
- from modules import paths
-
- self.no_style = Style("None")
- self.styles = {}
- self.path = opts.styles_dir
- self.built_in = opts.extra_networks_styles
- if os.path.isfile(opts.styles_dir) or opts.styles_dir.endswith(".csv"):
- legacy_file = opts.styles_dir
- self.load_csv(legacy_file)
- opts.styles_dir = os.path.join(paths.models_path, "styles")
- self.path = opts.styles_dir
- os.makedirs(opts.styles_dir, exist_ok=True)
- self.save_styles(opts.styles_dir, verbose=True)
- log.debug(f'Migrated styles: file={legacy_file} folder={opts.styles_dir}')
- self.reload()
- if not os.path.isdir(opts.styles_dir):
- opts.styles_dir = os.path.join(paths.models_path, "styles")
- self.path = opts.styles_dir
- os.makedirs(opts.styles_dir, exist_ok=True)
-
- def load_style(self, fn, prefix=None):
- with open(fn, 'r', encoding='utf-8') as f:
- new_style = None
- try:
- all_styles = json.load(f)
- if type(all_styles) is dict:
- all_styles = [all_styles]
- for style in all_styles:
- if type(style) is not dict or "name" not in style:
- raise ValueError('cannot parse style')
- basename = os.path.splitext(os.path.basename(fn))[0]
- name = re.sub(r'[\t\r\n]', '', style.get("name", basename)).strip()
- if prefix is not None:
- name = os.path.join(prefix, name)
- else:
- name = os.path.join(os.path.dirname(os.path.relpath(fn, self.path)), name)
- new_style = Style(
- name=name,
- desc=style.get('description', name),
- prompt=style.get("prompt", ""),
- negative_prompt=style.get("negative", ""),
- extra=style.get("extra", ""),
- preview=style.get("preview", None),
- filename=fn,
- mtime=os.path.getmtime(fn),
- )
- self.styles[style["name"]] = new_style
- except Exception as e:
- log.error(f'Failed to load style: file={fn} error={e}')
- return new_style
-
-
- def reload(self):
- t0 = time.time()
- self.styles.clear()
-
- def list_folder(folder):
- import concurrent
- future_items = {}
- with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
- for filename in os.listdir(folder):
- fn = os.path.abspath(os.path.join(folder, filename))
- if os.path.isfile(fn) and fn.lower().endswith(".json"):
- future_items[executor.submit(self.load_style, fn, None)] = fn
- # self.load_style(fn)
- elif os.path.isdir(fn) and not fn.startswith('.'):
- list_folder(fn)
- self.styles = dict(sorted(self.styles.items(), key=lambda style: style[1].filename))
- if self.built_in:
- fn = os.path.join('html', 'art-styles.json')
- future_items[executor.submit(self.load_style, fn, 'built-in')] = fn
- for future in concurrent.futures.as_completed(future_items):
- future.result()
-
- list_folder(self.path)
- t1 = time.time()
- log.debug(f'Load styles: folder="{self.path}" items={len(self.styles.keys())} time={t1-t0:.2f}')
-
- def find_style(self, name):
- found = [style for style in self.styles.values() if style.name == name]
- return found[0] if len(found) > 0 else self.no_style
-
- def get_style_prompts(self, styles):
- if styles is None or not isinstance(styles, list):
- log.error(f'Invalid styles: {styles}')
- return []
- return [self.find_style(x).prompt for x in styles]
-
- def get_negative_style_prompts(self, styles):
- if styles is None or not isinstance(styles, list):
- log.error(f'Invalid styles: {styles}')
- return []
- return [self.find_style(x).negative_prompt for x in styles]
-
- def apply_styles_to_prompt(self, prompt, styles):
- if styles is None or not isinstance(styles, list):
- log.error(f'Invalid styles: {styles}')
- return prompt
- return apply_styles_to_prompt(prompt, [self.find_style(x).prompt for x in styles])
-
- def apply_negative_styles_to_prompt(self, prompt, styles):
- if styles is None or not isinstance(styles, list):
- log.error(f'Invalid styles: {styles}')
- return prompt
- return apply_styles_to_prompt(prompt, [self.find_style(x).negative_prompt for x in styles])
-
- def apply_styles_to_extra(self, p):
- if p.styles is None or not isinstance(p.styles, list):
- log.error(f'Invalid styles: {p.styles}')
- return
- for style in p.styles:
- s = self.find_style(style)
- apply_styles_to_extra(p, s)
-
- def save_styles(self, path, verbose=False):
- for name in list(self.styles):
- style = {
- "name": name,
- "prompt": self.styles[name].prompt,
- "negative": self.styles[name].negative_prompt,
- "extra": "",
- "preview": "",
- }
- keepcharacters = (' ','.','_')
- fn = "".join(c for c in name if c.isalnum() or c in keepcharacters).rstrip()
- fn = os.path.join(path, fn + ".json")
- try:
- with open(fn, 'w', encoding='utf-8') as f:
- json.dump(style, f, indent=2)
- if verbose:
- log.debug(f'Saved style: name={name} file={fn}')
- except Exception as e:
- log.error(f'Failed to save style: name={name} file={path} error={e}')
- count = len(list(self.styles))
- if count > 0:
- log.debug(f'Saved styles: folder="{path}" items={count}')
-
- def load_csv(self, legacy_file):
- if not os.path.isfile(legacy_file):
- return
- with open(legacy_file, "r", encoding="utf-8-sig", newline='') as file:
- reader = csv.DictReader(file, skipinitialspace=True)
- num = 0
- for row in reader:
- try:
- name = row["name"]
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative = row.get("negative_prompt", "") if "negative_prompt" in row else row.get("negative", "")
- self.styles[name] = Style(name, desc=name, prompt=prompt, negative_prompt=negative, extra="")
- log.debug(f'Migrated style: {self.styles[name].__dict__}')
- num += 1
- except Exception:
- log.error(f'Styles error: file="{legacy_file}" row={row}')
- log.info(f'Load legacy styles: file="{legacy_file}" loaded={num} created={len(list(self.styles))}')
-
- """
- def save_csv(self, path: str) -> None:
- import tempfile
- basedir = os.path.dirname(path)
- if basedir is not None and len(basedir) > 0:
- os.makedirs(basedir, exist_ok=True)
- fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
- writer = csv.DictWriter(file, fieldnames=Style._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
- log.debug(f'Saved legacy styles: {path} {len(self.styles.keys())}')
- shutil.move(temp_path, path)
- """
+# We need this so Python doesn't complain about the unknown StableDiffusionProcessing type hint at runtime
+from __future__ import annotations
+import re
+import os
+import csv
+import json
+import time
+from installer import log
+
+
+class Style():
+ def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = "", mtime: float = 0):
+ self.name = name
+ self.description = desc
+ self.prompt = prompt
+ self.negative_prompt = negative_prompt
+ self.extra = extra
+ self.filename = filename
+ self.preview = preview
+ self.mtime = mtime
+
+def merge_prompts(style_prompt: str, prompt: str) -> str:
+ if "{prompt}" in style_prompt:
+ res = style_prompt.replace("{prompt}", prompt)
+ else:
+ original_prompt = prompt.strip()
+ style_prompt = style_prompt.strip()
+ parts = filter(None, (original_prompt, style_prompt))
+ if original_prompt.endswith(","):
+ res = " ".join(parts)
+ else:
+ res = ", ".join(parts)
+ return res
+
+
+def apply_styles_to_prompt(prompt, styles):
+ for style in styles:
+ prompt = merge_prompts(style, prompt)
+ return prompt
+
+
+def apply_styles_to_extra(p, style: Style):
+ if style is None:
+ return
+ name_map = {
+ 'sampler': 'sampler_name',
+ }
+ from modules.generation_parameters_copypaste import parse_generation_parameters
+ extra = parse_generation_parameters(style.extra)
+ extra.pop('Prompt', None)
+ extra.pop('Negative prompt', None)
+ fields = []
+ for k, v in extra.items():
+ k = k.lower()
+ k = k.replace(' ', '_')
+ if k in name_map: # rename some fields
+ k = name_map[k]
+ if hasattr(p, k):
+ orig = getattr(p, k)
+ if type(orig) != type(v) and orig is not None:
+ v = type(orig)(v)
+ setattr(p, k, v)
+ fields.append(f'{k}={v}')
+ log.info(f'Applying style: name="{style.name}" extra={fields}')
+
+
+class StyleDatabase:
+ def __init__(self, opts):
+ from modules import paths
+
+ self.no_style = Style("None")
+ self.styles = {}
+ self.path = opts.styles_dir
+ self.built_in = opts.extra_networks_styles
+ if os.path.isfile(opts.styles_dir) or opts.styles_dir.endswith(".csv"):
+ legacy_file = opts.styles_dir
+ self.load_csv(legacy_file)
+ opts.styles_dir = os.path.join(paths.models_path, "styles")
+ self.path = opts.styles_dir
+ os.makedirs(opts.styles_dir, exist_ok=True)
+ self.save_styles(opts.styles_dir, verbose=True)
+ log.debug(f'Migrated styles: file={legacy_file} folder={opts.styles_dir}')
+ self.reload()
+ if not os.path.isdir(opts.styles_dir):
+ opts.styles_dir = os.path.join(paths.models_path, "styles")
+ self.path = opts.styles_dir
+ os.makedirs(opts.styles_dir, exist_ok=True)
+
+ def load_style(self, fn, prefix=None):
+ with open(fn, 'r', encoding='utf-8') as f:
+ new_style = None
+ try:
+ all_styles = json.load(f)
+ if type(all_styles) is dict:
+ all_styles = [all_styles]
+ for style in all_styles:
+ if type(style) is not dict or "name" not in style:
+ raise ValueError('cannot parse style')
+ basename = os.path.splitext(os.path.basename(fn))[0]
+ name = re.sub(r'[\t\r\n]', '', style.get("name", basename)).strip()
+ if prefix is not None:
+ name = os.path.join(prefix, name)
+ else:
+ name = os.path.join(os.path.dirname(os.path.relpath(fn, self.path)), name)
+ new_style = Style(
+ name=name,
+ desc=style.get('description', name),
+ prompt=style.get("prompt", ""),
+ negative_prompt=style.get("negative", ""),
+ extra=style.get("extra", ""),
+ preview=style.get("preview", None),
+ filename=fn,
+ mtime=os.path.getmtime(fn),
+ )
+ self.styles[style["name"]] = new_style
+ except Exception as e:
+ log.error(f'Failed to load style: file={fn} error={e}')
+ return new_style
+
+
+ def reload(self):
+ t0 = time.time()
+ self.styles.clear()
+
+ def list_folder(folder):
+ import concurrent
+ future_items = {}
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
+ for filename in os.listdir(folder):
+ fn = os.path.abspath(os.path.join(folder, filename))
+ if os.path.isfile(fn) and fn.lower().endswith(".json"):
+ future_items[executor.submit(self.load_style, fn, None)] = fn
+ # self.load_style(fn)
+ elif os.path.isdir(fn) and not fn.startswith('.'):
+ list_folder(fn)
+ self.styles = dict(sorted(self.styles.items(), key=lambda style: style[1].filename))
+ if self.built_in:
+ fn = os.path.join('html', 'art-styles.json')
+ future_items[executor.submit(self.load_style, fn, 'built-in')] = fn
+ for future in concurrent.futures.as_completed(future_items):
+ future.result()
+
+ list_folder(self.path)
+ t1 = time.time()
+ log.debug(f'Load styles: folder="{self.path}" items={len(self.styles.keys())} time={t1-t0:.2f}')
+
+ def find_style(self, name):
+ found = [style for style in self.styles.values() if style.name == name]
+ return found[0] if len(found) > 0 else self.no_style
+
+ def get_style_prompts(self, styles):
+ if styles is None or not isinstance(styles, list):
+ log.error(f'Invalid styles: {styles}')
+ return []
+ return [self.find_style(x).prompt for x in styles]
+
+ def get_negative_style_prompts(self, styles):
+ if styles is None or not isinstance(styles, list):
+ log.error(f'Invalid styles: {styles}')
+ return []
+ return [self.find_style(x).negative_prompt for x in styles]
+
+ def apply_styles_to_prompt(self, prompt, styles):
+ if styles is None or not isinstance(styles, list):
+ log.error(f'Invalid styles: {styles}')
+ return prompt
+ return apply_styles_to_prompt(prompt, [self.find_style(x).prompt for x in styles])
+
+ def apply_negative_styles_to_prompt(self, prompt, styles):
+ if styles is None or not isinstance(styles, list):
+ log.error(f'Invalid styles: {styles}')
+ return prompt
+ return apply_styles_to_prompt(prompt, [self.find_style(x).negative_prompt for x in styles])
+
+ def apply_styles_to_extra(self, p):
+ if p.styles is None or not isinstance(p.styles, list):
+ log.error(f'Invalid styles: {p.styles}')
+ return
+ for style in p.styles:
+ s = self.find_style(style)
+ apply_styles_to_extra(p, s)
+
+ def save_styles(self, path, verbose=False):
+ for name in list(self.styles):
+ style = {
+ "name": name,
+ "prompt": self.styles[name].prompt,
+ "negative": self.styles[name].negative_prompt,
+ "extra": "",
+ "preview": "",
+ }
+ keepcharacters = (' ','.','_')
+ fn = "".join(c for c in name if c.isalnum() or c in keepcharacters).rstrip()
+ fn = os.path.join(path, fn + ".json")
+ try:
+ with open(fn, 'w', encoding='utf-8') as f:
+ json.dump(style, f, indent=2)
+ if verbose:
+ log.debug(f'Saved style: name={name} file={fn}')
+ except Exception as e:
+ log.error(f'Failed to save style: name={name} file={path} error={e}')
+ count = len(list(self.styles))
+ if count > 0:
+ log.debug(f'Saved styles: folder="{path}" items={count}')
+
+ def load_csv(self, legacy_file):
+ if not os.path.isfile(legacy_file):
+ return
+ with open(legacy_file, "r", encoding="utf-8-sig", newline='') as file:
+ reader = csv.DictReader(file, skipinitialspace=True)
+ num = 0
+ for row in reader:
+ try:
+ name = row["name"]
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative = row.get("negative_prompt", "") if "negative_prompt" in row else row.get("negative", "")
+ self.styles[name] = Style(name, desc=name, prompt=prompt, negative_prompt=negative, extra="")
+ log.debug(f'Migrated style: {self.styles[name].__dict__}')
+ num += 1
+ except Exception:
+ log.error(f'Styles error: file="{legacy_file}" row={row}')
+ log.info(f'Load legacy styles: file="{legacy_file}" loaded={num} created={len(list(self.styles))}')
+
+ """
+ def save_csv(self, path: str) -> None:
+ import tempfile
+ basedir = os.path.dirname(path)
+ if basedir is not None and len(basedir) > 0:
+ os.makedirs(basedir, exist_ok=True)
+ fd, temp_path = tempfile.mkstemp(".csv")
+ with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
+ writer = csv.DictWriter(file, fieldnames=Style._fields)
+ writer.writeheader()
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+ log.debug(f'Saved legacy styles: {path} {len(self.styles.keys())}')
+ shutil.move(temp_path, path)
+ """
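
The rewritten modules/styles.py keeps StyleDatabase as the single entry point for loading, applying and saving styles. A minimal usage sketch, not part of the change: the options object, the folder path and the 'cinematic' style name are placeholders, and the code only makes sense inside the application environment where modules.styles and installer are importable.

    from types import SimpleNamespace
    from modules.styles import StyleDatabase

    # Hypothetical stand-in for the real options object; only the attributes read by
    # StyleDatabase.__init__ are supplied, and the styles folder is expected to exist.
    opts = SimpleNamespace(styles_dir='models/styles', extra_networks_styles=False)

    db = StyleDatabase(opts)                                   # loads every *.json style found under styles_dir
    styled = db.apply_styles_to_prompt('a castle at sunset', ['cinematic'])
    negative = db.apply_negative_styles_to_prompt('', ['cinematic'])
    db.save_styles(opts.styles_dir)                            # writes one <name>.json file per loaded style

Unknown style names fall back to the empty no_style entry, so a missing 'cinematic' leaves the prompt unchanged rather than raising.
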
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 2ccd4a27d..b4bf9828e 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,337 +1,337 @@
-import os
-import cv2
-import requests
-import numpy as np
-from PIL import ImageDraw
-
-GREEN = "#0F0"
-BLUE = "#00F"
-RED = "#F00"
-
-
-def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
-
- scale_by = 1
- if is_landscape(im.width, im.height):
- scale_by = settings.crop_height / im.height
- elif is_portrait(im.width, im.height):
- scale_by = settings.crop_width / im.width
- elif is_square(im.width, im.height):
- if is_square(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_landscape(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_portrait(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_height / im.height
-
- im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
- im_debug = im.copy()
-
- focus = focal_point(im_debug, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- results = []
-
- results.append(im.crop(tuple(crop)))
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im_debug)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- results.append(im_debug)
- if settings.destop_view_image:
- im_debug.show()
-
- return results
-
-def focal_point(im, settings):
- corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
- entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
- face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
-
- pois = []
-
- weight_pref_total = 0
- if len(corner_points) > 0:
- weight_pref_total += settings.corner_points_weight
- if len(entropy_points) > 0:
- weight_pref_total += settings.entropy_points_weight
- if len(face_points) > 0:
- weight_pref_total += settings.face_points_weight
-
- corner_centroid = None
- if len(corner_points) > 0:
- corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
- pois.append(corner_centroid)
-
- entropy_centroid = None
- if len(entropy_points) > 0:
- entropy_centroid = centroid(entropy_points)
- entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
- pois.append(entropy_centroid)
-
- face_centroid = None
- if len(face_points) > 0:
- face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
- pois.append(face_centroid)
-
- average_point = poi_average(pois, settings)
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- max_size = min(im.width, im.height) * 0.07
- if corner_centroid is not None:
- color = BLUE
- box = corner_centroid.bounding(max_size * corner_centroid.weight)
- d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(corner_points) > 1:
- for f in corner_points:
- d.rectangle(f.bounding(4), outline=color)
- if entropy_centroid is not None:
- color = "#ff0"
- box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
- d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(entropy_points) > 1:
- for f in entropy_points:
- d.rectangle(f.bounding(4), outline=color)
- if face_centroid is not None:
- color = RED
- box = face_centroid.bounding(max_size * face_centroid.weight)
- d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(face_points) > 1:
- for f in face_points:
- d.rectangle(f.bounding(4), outline=color)
-
- d.ellipse(average_point.bounding(max_size), outline=GREEN)
-
- return average_point
-
-
-def image_face_points(im, settings):
- if settings.dnn_model_path is not None:
- detector = cv2.FaceDetectorYN.create(
- settings.dnn_model_path,
- "",
- (im.width, im.height),
- 0.9, # score threshold
- 0.3, # nms threshold
- 5000 # keep top k before nms
- )
- faces = detector.detect(np.array(im))
- results = []
- if faces[1] is not None:
- for face in faces[1]:
- x = face[0]
- y = face[1]
- w = face[2]
- h = face[3]
- results.append(
- PointOfInterest(
- int(x + (w * 0.5)), # face focus left/right is center
- int(y + (h * 0.33)), # face focus up/down is close to the top of the head
- size = w,
- weight = 1/len(faces[1])
- )
- )
- return results
- else:
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
- for t in tries:
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except Exception:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
- return []
-
-
-def image_corner_points(im, settings): # pylint: disable=unused-argument
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.06,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
-
- return focal_points
-
-
-def image_entropy_points(im, settings):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
- else:
- return []
-
- e_max = 0
- crop_current = [0, 0, settings.crop_width, settings.crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if e > e_max:
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + settings.crop_width/2)
- y_mid = int(crop_best[1] + settings.crop_height/2)
-
- return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
-
-
-def image_entropy(im):
- # greyscale image entropy
- # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
-def centroid(pois):
- x = [poi.x for poi in pois]
- y = [poi.y for poi in pois]
- return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
-
-
-def poi_average(pois, settings): # pylint: disable=unused-argument
- weight = 0.0
- x = 0.0
- y = 0.0
- for poi in pois:
- weight += poi.weight
- x += poi.x * poi.weight
- y += poi.y * poi.weight
- avg_x = round(weight and x / weight)
- avg_y = round(weight and y / weight)
-
- return PointOfInterest(avg_x, avg_y)
-
-
-def is_landscape(w, h):
- return w > h
-
-
-def is_portrait(w, h):
- return h > w
-
-
-def is_square(w, h):
- return w == h
-
-
-def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
- if not os.path.exists(dirname):
- os.makedirs(dirname, exist_ok=True)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url, timeout=60*60*2)
- with open(cache_file, "wb") as f:
- f.write(response.content)
-
- if os.path.exists(cache_file):
- return cache_file
- return None
-
-
-class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
-
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
-
-
-class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = face_points_weight
- self.annotate_image = annotate_image
- self.destop_view_image = False
- self.dnn_model_path = dnn_model_path
+import os
+import cv2
+import requests
+import numpy as np
+from PIL import ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
+
+ pois = []
+
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.9, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except Exception:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings): # pylint: disable=unused-argument
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if e > e_max:
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+def poi_average(pois, settings): # pylint: disable=unused-argument
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(weight and x / weight)
+ avg_y = round(weight and y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+ if not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url, timeout=60*60*2)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
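
The autocrop helpers above are driven through Settings and crop_image. A hedged sketch of standalone use, not part of the change (file names and weights are placeholders; when dnn_model_path is None the face detector falls back to OpenCV Haar cascades):

    from PIL import Image
    from modules.textual_inversion import autocrop

    im = Image.open('input.png').convert('RGB')      # placeholder input file
    settings = autocrop.Settings(
        crop_width=512,
        crop_height=512,
        corner_points_weight=0.5,
        entropy_points_weight=0.5,
        face_points_weight=0.9,
        annotate_image=False,
        dnn_model_path=autocrop.download_and_cache_models('models/opencv'),  # may be None if the download fails
    )
    results = autocrop.crop_image(im, settings)      # results[0] is the cropped image
    results[0].save('cropped.png')

With annotate_image=True a second, annotated debug image is appended to results.
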
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index c530670e1..e80106ebe 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -1,231 +1,231 @@
-import os
-import re
-import random
-from collections import defaultdict
-import numpy as np
-import torch
-from PIL import Image
-from torch.utils.data import Dataset, DataLoader, Sampler
-from torchvision import transforms
-import tqdm
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from modules import devices, shared
-
-re_numbers_at_start = re.compile(r"^[-\d]+\s*")
-
-
-class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
- self.filename = filename
- self.filename_text = filename_text
- self.weight = weight
- self.latent_dist = latent_dist
- self.latent_sample = latent_sample
- self.cond = cond
- self.cond_text = cond_text
- self.pixel_values = pixel_values
-
-
-class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
- self.placeholder_token = placeholder_token
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
- self.dataset = []
- with open(template_file, "r", encoding="utf8") as file:
- lines = [x.strip() for x in file.readlines()]
- self.lines = lines
-
- assert data_root, 'dataset directory not specified'
- assert os.path.isdir(data_root), "Dataset directory doesn't exist"
- assert os.listdir(data_root), "Dataset directory is empty"
-
- self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
- self.shuffle_tags = shuffle_tags
- self.tag_drop_out = tag_drop_out
- groups = defaultdict(list)
- shared.log.info(f"TI Training: Preparing dataset: {data_root}")
- for path in tqdm.tqdm(self.image_paths):
- alpha_channel = None
- if shared.state.interrupted:
- raise RuntimeError("interrupted")
- try:
- image = Image.open(path)
- if use_weight and 'A' in image.getbands():
- alpha_channel = image.getchannel('A')
- image = image.convert('RGB')
- if not varsize:
- image = image.resize((width, height), Image.Resampling.BICUBIC)
- except Exception:
- continue
-
- text_filename = f"{os.path.splitext(path)[0]}.txt"
- filename = os.path.basename(path)
-
- if os.path.exists(text_filename):
- with open(text_filename, "r", encoding="utf8") as file:
- filename_text = file.read()
- else:
- filename_text = os.path.splitext(filename)[0]
- filename_text = re.sub(re_numbers_at_start, '', filename_text)
- if re_word:
- tokens = re_word.findall(filename_text)
- filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
-
- npimage = np.array(image).astype(np.uint8)
- npimage = (npimage / 127.5 - 1.0).astype(np.float32)
-
- torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
- latent_sample = None
-
- with devices.autocast():
- latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
-
- if latent_sampling_method == "deterministic":
- if isinstance(latent_dist, DiagonalGaussianDistribution):
- latent_dist.std = torch.exp(0 * latent_dist.logvar)
- else:
- latent_sampling_method = "once"
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
-
- if use_weight and alpha_channel is not None:
- channels, *latent_size = latent_sample.shape
- weight_img = alpha_channel.resize(latent_size)
- npweight = np.array(weight_img).astype(np.float32)
- #Repeat for every channel in the latent sample
- weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
-                #Normalize the weight to a minimum of 0 and a mean of 1, so that the loss is comparable to the default.
- weight -= weight.min()
- weight /= weight.mean()
- elif use_weight:
-                #If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
- weight = torch.ones(latent_sample.shape)
- else:
- weight = None
-
- if latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
- else:
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
-
- if not (self.tag_drop_out != 0 or self.shuffle_tags):
- entry.cond_text = self.create_text(filename_text)
-
- if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
- with devices.autocast():
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- groups[image.size].append(len(self.dataset))
- self.dataset.append(entry)
- del torchdata
- del latent_dist
- del latent_sample
- del weight
-
- self.length = len(self.dataset)
- self.groups = list(groups.values())
- assert self.length > 0, "No images have been found in the dataset."
- self.batch_size = min(batch_size, self.length)
- self.gradient_step = min(gradient_step, self.length // self.batch_size)
- self.latent_sampling_method = latent_sampling_method
-
- if len(groups) > 1:
- print("Buckets:")
- for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
- print(f" {w}x{h}: {len(ids)}")
- print()
-
- def create_text(self, filename_text):
- text = random.choice(self.lines)
- tags = filename_text.split(',')
- if self.tag_drop_out != 0:
- tags = [t for t in tags if random.random() > self.tag_drop_out]
- if self.shuffle_tags:
- random.shuffle(tags)
- text = text.replace("[filewords]", ','.join(tags))
- text = text.replace("[name]", self.placeholder_token)
- return text
-
- def __len__(self):
- return self.length
-
- def __getitem__(self, i):
- entry = self.dataset[i]
- if self.tag_drop_out != 0 or self.shuffle_tags:
- entry.cond_text = self.create_text(entry.filename_text)
- if self.latent_sampling_method == "random":
- entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
- return entry
-
-
-class GroupedBatchSampler(Sampler):
- def __init__(self, data_source: PersonalizedBase, batch_size: int):
- super().__init__(data_source)
-
- n = len(data_source)
- self.groups = data_source.groups
- self.len = n_batch = n // batch_size
- expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
- self.base = [int(e) // batch_size for e in expected]
- self.n_rand_batches = nrb = n_batch - sum(self.base)
- self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
- self.batch_size = batch_size
-
- def __len__(self):
- return self.len
-
- def __iter__(self):
- b = self.batch_size
-
- for g in self.groups:
- random.shuffle(g)
-
- batches = []
- for g in self.groups:
- batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
- for _ in range(self.n_rand_batches):
- rand_group = random.choices(self.groups, self.probs)[0]
- batches.append(random.choices(rand_group, k=b))
-
- random.shuffle(batches)
-
- yield from batches
-
-
-class PersonalizedDataLoader(DataLoader):
- def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
- if latent_sampling_method == "random":
- self.collate_fn = collate_wrapper_random
- else:
- self.collate_fn = collate_wrapper
-
-
-class BatchLoader:
- def __init__(self, data):
- self.cond_text = [entry.cond_text for entry in data]
- self.cond = [entry.cond for entry in data]
- self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
- if all(entry.weight is not None for entry in data):
- self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
- else:
- self.weight = None
- #self.emb_index = [entry.emb_index for entry in data]
- #print(self.latent_sample.device)
-
- def pin_memory(self):
- self.latent_sample = self.latent_sample.pin_memory()
- return self
-
-def collate_wrapper(batch):
- return BatchLoader(batch)
-
-class BatchLoaderRandom(BatchLoader):
- def __init__(self, data):
- super().__init__(data)
-
- def pin_memory(self):
- return self
-
-def collate_wrapper_random(batch):
- return BatchLoaderRandom(batch)
+import os
+import re
+import random
+from collections import defaultdict
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, Sampler
+from torchvision import transforms
+import tqdm
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from modules import devices, shared
+
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
+ self.filename = filename
+ self.filename_text = filename_text
+ self.weight = weight
+ self.latent_dist = latent_dist
+ self.latent_sample = latent_sample
+ self.cond = cond
+ self.cond_text = cond_text
+ self.pixel_values = pixel_values
+
+
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
+
+ self.placeholder_token = placeholder_token
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+ self.dataset = []
+ with open(template_file, "r", encoding="utf8") as file:
+ lines = [x.strip() for x in file.readlines()]
+ self.lines = lines
+
+ assert data_root, 'dataset directory not specified'
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ self.shuffle_tags = shuffle_tags
+ self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
+ shared.log.info(f"TI Training: Preparing dataset: {data_root}")
+ for path in tqdm.tqdm(self.image_paths):
+ alpha_channel = None
+ if shared.state.interrupted:
+ raise RuntimeError("interrupted")
+ try:
+ image = Image.open(path)
+ if use_weight and 'A' in image.getbands():
+ alpha_channel = image.getchannel('A')
+ image = image.convert('RGB')
+ if not varsize:
+ image = image.resize((width, height), Image.Resampling.BICUBIC)
+ except Exception:
+ continue
+
+ text_filename = f"{os.path.splitext(path)[0]}.txt"
+ filename = os.path.basename(path)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
+
+ npimage = np.array(image).astype(np.uint8)
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+ torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+ latent_sample = None
+
+ with devices.autocast():
+ latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
+
+ if latent_sampling_method == "deterministic":
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
+ latent_dist.std = torch.exp(0 * latent_dist.logvar)
+ else:
+ latent_sampling_method = "once"
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+
+ if use_weight and alpha_channel is not None:
+ channels, *latent_size = latent_sample.shape
+ weight_img = alpha_channel.resize(latent_size)
+ npweight = np.array(weight_img).astype(np.float32)
+ #Repeat for every channel in the latent sample
+ weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
+                #Normalize the weight to a minimum of 0 and a mean of 1, so that the loss is comparable to the default.
+ weight -= weight.min()
+ weight /= weight.mean()
+ elif use_weight:
+                #If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
+ weight = torch.ones(latent_sample.shape)
+ else:
+ weight = None
+
+ if latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
+ else:
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
+
+ if not (self.tag_drop_out != 0 or self.shuffle_tags):
+ entry.cond_text = self.create_text(filename_text)
+
+ if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+ with devices.autocast():
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+ groups[image.size].append(len(self.dataset))
+ self.dataset.append(entry)
+ del torchdata
+ del latent_dist
+ del latent_sample
+ del weight
+
+ self.length = len(self.dataset)
+ self.groups = list(groups.values())
+ assert self.length > 0, "No images have been found in the dataset."
+ self.batch_size = min(batch_size, self.length)
+ self.gradient_step = min(gradient_step, self.length // self.batch_size)
+ self.latent_sampling_method = latent_sampling_method
+
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
+
+ def create_text(self, filename_text):
+ text = random.choice(self.lines)
+ tags = filename_text.split(',')
+ if self.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > self.tag_drop_out]
+ if self.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
+ text = text.replace("[name]", self.placeholder_token)
+ return text
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ entry = self.dataset[i]
+ if self.tag_drop_out != 0 or self.shuffle_tags:
+ entry.cond_text = self.create_text(entry.filename_text)
+ if self.latent_sampling_method == "random":
+ entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+ return entry
+
+
+class GroupedBatchSampler(Sampler):
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ super().__init__(data_source)
+
+ n = len(data_source)
+ self.groups = data_source.groups
+ self.len = n_batch = n // batch_size
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+ self.base = [int(e) // batch_size for e in expected]
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
+ self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
+ self.batch_size = batch_size
+
+ def __len__(self):
+ return self.len
+
+ def __iter__(self):
+ b = self.batch_size
+
+ for g in self.groups:
+ random.shuffle(g)
+
+ batches = []
+ for g in self.groups:
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+ for _ in range(self.n_rand_batches):
+ rand_group = random.choices(self.groups, self.probs)[0]
+ batches.append(random.choices(rand_group, k=b))
+
+ random.shuffle(batches)
+
+ yield from batches
+
+
+class PersonalizedDataLoader(DataLoader):
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
+
+
+class BatchLoader:
+ def __init__(self, data):
+ self.cond_text = [entry.cond_text for entry in data]
+ self.cond = [entry.cond for entry in data]
+ self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ if all(entry.weight is not None for entry in data):
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ else:
+ self.weight = None
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
+
+ def pin_memory(self):
+ self.latent_sample = self.latent_sample.pin_memory()
+ return self
+
+def collate_wrapper(batch):
+ return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
+
+ def pin_memory(self):
+ return self
+
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch)
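
PersonalizedBase builds the in-memory dataset (including optional alpha-channel loss weights) and GroupedBatchSampler/PersonalizedDataLoader batch it by image size. A rough wiring sketch, not part of the change: it assumes a checkpoint is already loaded as shared.sd_model and that the data_root and template_file paths exist; the paths, token name and device string are placeholders.

    from modules import shared
    from modules.textual_inversion.dataset import PersonalizedBase, PersonalizedDataLoader

    ds = PersonalizedBase(
        data_root='train/images', width=512, height=512, repeats=1,
        placeholder_token='myconcept', model=shared.sd_model, cond_model=None,
        device='cuda', template_file='textual_inversion_templates/style.txt',
        batch_size=2, latent_sampling_method='once',
    )
    dl = PersonalizedDataLoader(ds, latent_sampling_method='once', batch_size=2, pin_memory=False)
    for batch in dl:                      # each batch is a BatchLoader instance
        latents = batch.latent_sample     # stacked latent samples for the batch
        prompts = batch.cond_text         # per-image prompt text built from the template
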
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index a59798bb4..498c78461 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -1,213 +1,213 @@
-import base64
-import json
-import numpy as np
-import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
-import torch
-from modules.shared import opts
-
-
-class EmbeddingEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, torch.Tensor):
- return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
- return json.JSONEncoder.default(self, obj)
-
-
-class EmbeddingDecoder(json.JSONDecoder):
- def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
-
- def object_hook(self, d):
- if 'TORCHTENSOR' in d:
- return torch.from_numpy(np.array(d['TORCHTENSOR']))
- return d
-
-
-def embedding_to_b64(data):
- d = json.dumps(data, cls=EmbeddingEncoder)
- return base64.b64encode(d.encode())
-
-
-def embedding_from_b64(data):
- d = base64.b64decode(data)
- return json.loads(d, cls=EmbeddingDecoder)
-
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
- while True:
- seed = (a * seed + c) % m
- yield seed % 255
-
-
-def xor_block(block):
- g = lcg()
- randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
- return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
-
-
-def style_block(block, sequence):
- im = Image.new('RGB', (block.shape[1], block.shape[0]))
- draw = ImageDraw.Draw(im)
- i = 0
- for x in range(-6, im.size[0], 8):
- for yi, y in enumerate(range(-6, im.size[1], 8)):
- offset = 0
- if yi % 2 == 0:
- offset = 4
- shade = sequence[i % len(sequence)]
- i += 1
- draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
-
- fg = np.array(im).astype(np.uint8) & 0xF0
-
- return block ^ fg
-
-
-def insert_image_data_embed(image, data):
- d = 3
- data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
- data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
- data_np_high = data_np_ >> 4
- data_np_low = data_np_ & 0x0F
-
- h = image.size[1]
- next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
- next_size = next_size + ((h*d)-(next_size % (h*d)))
-
- data_np_low = np.resize(data_np_low, next_size)
- data_np_low = data_np_low.reshape((h, -1, d))
-
- data_np_high = np.resize(data_np_high, next_size)
- data_np_high = data_np_high.reshape((h, -1, d))
-
- edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
- edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
-
- data_np_low = style_block(data_np_low, sequence=edge_style)
- data_np_low = xor_block(data_np_low)
- data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
- data_np_high = xor_block(data_np_high)
-
- im_low = Image.fromarray(data_np_low, mode='RGB')
- im_high = Image.fromarray(data_np_high, mode='RGB')
-
- background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
- background.paste(im_low, (0, 0))
- background.paste(image, (im_low.size[0]+1, 0))
- background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
-
- return background
-
-
-def crop_black(img, tol=0):
- mask = (img > tol).all(2)
- mask0, mask1 = mask.any(0), mask.any(1)
- col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
- row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
- return img[row_start:row_end, col_start:col_end]
-
-
-def extract_image_data_embed(image):
- d = 3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
- black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
- if black_cols[0].shape[0] < 2:
- return None
-
- data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
- data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
-
- data_block_lower = xor_block(data_block_lower)
- data_block_upper = xor_block(data_block_upper)
-
- data_block = (data_block_upper << 4) | (data_block_lower)
- data_block = data_block.flatten().tobytes()
-
- data = zlib.decompress(data_block)
- return json.loads(data, cls=EmbeddingDecoder)
-
-
-def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
- from math import cos
- image = srcimage.copy()
- fontsize = 32
- if textfont is None:
- textfont = opts.font or 'javascript/roboto.ttf'
-
- factor = 1.5
- gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
- for y in range(image.size[1]):
- mag = 1-cos(y/image.size[1]*factor)
- mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
- gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
- image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
-
- draw = ImageDraw.Draw(image)
-
- font = ImageFont.truetype(textfont, fontsize)
- padding = 10
-
- _, _, w, h = draw.textbbox((0, 0), title, font=font)
- fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
- font = ImageFont.truetype(textfont, fontsize)
- _, _, w, h = draw.textbbox((0, 0), title, font=font)
- draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
-
- _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
- fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
- fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
- fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
-
- font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
-
- draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
-
- return image
-
-
-if __name__ == '__main__':
-
- testEmbed = Image.open('test_embedding.png')
- data = extract_image_data_embed(testEmbed)
- assert data is not None
-
- data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
- assert data is not None
-
- image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
- cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
-
- test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
-
- embedded_image = insert_image_data_embed(cap_image, test_embed)
-
- retrived_embed = extract_image_data_embed(embedded_image)
-
- assert str(retrived_embed) == str(test_embed)
-
- embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
-
- assert embedded_image == embedded_image2
-
- g = lcg()
- shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
-
- reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
- 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
- 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
- 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
- 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
- 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
- 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
- 204, 86, 73, 222, 44, 198, 118, 240, 97]
-
- assert shared_random == reference_random
-
- hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
-
- assert 12731374 == hunna_kay_random_sum
+import base64
+import json
+import numpy as np
+import zlib
+from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+import torch
+from modules.shared import opts
+
+
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
+
+ def object_hook(self, d):
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
+ return d
+
+
+def embedding_to_b64(data):
+ d = json.dumps(data, cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+
+def embedding_from_b64(data):
+ d = base64.b64decode(data)
+ return json.loads(d, cls=EmbeddingDecoder)
+
+
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed % 255
+
+
+def xor_block(block):
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
+
+
+def style_block(block, sequence):
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i = 0
+ for x in range(-6, im.size[0], 8):
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
+ offset = 0
+ if yi % 2 == 0:
+ offset = 4
+ shade = sequence[i % len(sequence)]
+ i += 1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+
+ return block ^ fg
+
+
+def insert_image_data_embed(image, data):
+ d = 3
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
+ data_np_high = data_np_ >> 4
+ data_np_low = data_np_ & 0x0F
+
+ h = image.size[1]
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
+
+ data_np_low = np.resize(data_np_low, next_size)
+ data_np_low = data_np_low.reshape((h, -1, d))
+
+ data_np_high = np.resize(data_np_high, next_size)
+ data_np_high = data_np_high.reshape((h, -1, d))
+
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
+
+ data_np_low = style_block(data_np_low, sequence=edge_style)
+ data_np_low = xor_block(data_np_low)
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
+ data_np_high = xor_block(data_np_high)
+
+ im_low = Image.fromarray(data_np_low, mode='RGB')
+ im_high = Image.fromarray(data_np_high, mode='RGB')
+
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
+ background.paste(im_low, (0, 0))
+ background.paste(image, (im_low.size[0]+1, 0))
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
+
+ return background
+
+
+def crop_black(img, tol=0):
+ mask = (img > tol).all(2)
+ mask0, mask1 = mask.any(0), mask.any(1)
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end, col_start:col_end]
+
+
+def extract_image_data_embed(image):
+ d = 3
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
+ if black_cols[0].shape[0] < 2:
+ return None
+
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
+
+ data_block_lower = xor_block(data_block_lower)
+ data_block_upper = xor_block(data_block_upper)
+
+ data_block = (data_block_upper << 4) | (data_block_lower)
+ data_block = data_block.flatten().tobytes()
+
+ data = zlib.decompress(data_block)
+ return json.loads(data, cls=EmbeddingDecoder)
+
+
+def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from math import cos
+ image = srcimage.copy()
+ fontsize = 32
+ if textfont is None:
+ textfont = opts.font or 'javascript/roboto.ttf'
+
+ factor = 1.5
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
+ for y in range(image.size[1]):
+ mag = 1-cos(y/image.size[1]*factor)
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+ draw = ImageDraw.Draw(image)
+
+ font = ImageFont.truetype(textfont, fontsize)
+ padding = 10
+
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
+ font = ImageFont.truetype(textfont, fontsize)
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
+
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+
+ font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
+
+ return image
+
+
+if __name__ == '__main__':
+
+ testEmbed = Image.open('test_embedding.png')
+ data = extract_image_data_embed(testEmbed)
+ assert data is not None
+
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
+ assert data is not None
+
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
+
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
+
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
+
+ retrieved_embed = extract_image_data_embed(embedded_image)
+
+ assert str(retrieved_embed) == str(test_embed)
+
+ embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
+
+ assert embedded_image == embedded_image2
+
+ g = lcg()
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
+
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
+
+ assert shared_random == reference_random
+
+ hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
+
+ assert 12731374 == hunna_kay_random_sum
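
The module above hides the embedding in the picture itself: the tensor dict is JSON-serialized (EmbeddingEncoder), zlib-compressed, split into high and low 4-bit planes, XORed with a fixed LCG keystream, and written into two styled side panels around the captioned preview, separated from it by black columns that extract_image_data_embed later uses to locate the panels. Below is a minimal, self-contained sketch of just the nibble split and XOR round trip; the {"demo": ...} payload is made up, and the padding, style_block pattern, and panel layout are deliberately omitted.

```python
import json
import zlib
import numpy as np

def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    # same generator as above; note it yields 0..254 because of the % 255
    while True:
        seed = (a * seed + c) % m
        yield seed % 255

def keystream(n):
    # xor_block restarts the LCG for every block, so equal-sized planes share a keystream
    g = lcg()
    return np.array([next(g) for _ in range(n)], dtype=np.uint8) & 0x0F

payload = zlib.compress(json.dumps({"demo": [1, 2, 3]}).encode(), level=9)
data = np.frombuffer(payload, np.uint8)
high, low = data >> 4, data & 0x0F               # two 4-bit planes, as in insert_image_data_embed

low_x = low ^ keystream(low.shape[0])
high_x = high ^ keystream(high.shape[0])

# extraction mirrors extract_image_data_embed: undo the XOR, then recombine the nibbles
restored = ((high_x ^ keystream(high_x.shape[0])) << 4) | (low_x ^ keystream(low_x.shape[0]))
assert json.loads(zlib.decompress(restored.tobytes())) == {"demo": [1, 2, 3]}
```
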
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index f12a62824..45a2eeb54 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -1,77 +1,77 @@
-class LearnScheduleIterator:
- def __init__(self, learn_rate, max_steps, cur_step=0):
- """
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
- """
-
- pairs = learn_rate.split(',')
- self.rates = []
- self.it = 0
- self.maxit = 0
- try:
- for pair in pairs:
- if not pair.strip():
- continue
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
- return
- elif step == -1:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- assert self.rates
- except (ValueError, AssertionError) as e:
- raise RuntimeError('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
-
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.it < self.maxit:
- self.it += 1
- return self.rates[self.it - 1]
- else:
- raise StopIteration
-
-
-class LearnRateScheduler:
- def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
- self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
- (self.learn_rate, self.end_step) = next(self.schedules)
- self.verbose = verbose
-
- # if self.verbose:
- # print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
-
- self.finished = False
-
- def step(self, step_number):
- if step_number < self.end_step:
- return False
-
- try:
- (self.learn_rate, self.end_step) = next(self.schedules)
- except StopIteration:
- self.finished = True
- return False
- return True
-
- def apply(self, optimizer, step_number):
- if not self.step(step_number):
- return
-
- # if self.verbose:
- # tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
-
- for pg in optimizer.param_groups:
- pg['lr'] = self.learn_rate
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ try:
+ for pair in pairs:
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ assert self.rates
+ except (ValueError, AssertionError) as e:
+ raise RuntimeError('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
+
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ # if self.verbose:
+ # print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def step(self, step_number):
+ if step_number < self.end_step:
+ return False
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except StopIteration:
+ self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
+ return
+
+ # if self.verbose:
+ # tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
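
The docstring above defines the schedule syntax: comma-separated "rate:step" pairs, where each rate applies until the given step, and a bare rate (or a step of -1) applies until max_steps. A minimal usage sketch follows, assuming the module is importable and using a throwaway parameter purely for illustration; since apply() only changes the learning rate when a step boundary is crossed, the optimizer should be created with the scheduler's initial rate.

```python
import torch
from modules.textual_inversion.learn_schedule import LearnRateScheduler

param = torch.nn.Parameter(torch.zeros(4))                       # placeholder trainable parameter
scheduler = LearnRateScheduler("0.001:100, 0.00001:1000, 1e-5:10000", max_steps=10000, verbose=False)
optimizer = torch.optim.AdamW([param], lr=scheduler.learn_rate)  # start at the first scheduled rate

for step in range(10000):
    scheduler.apply(optimizer, step)   # updates param_groups[...]['lr'] at steps 100 and 1000
    if scheduler.finished:             # set once the schedule is exhausted
        break
    # forward / backward / optimizer.step() would go here
```
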
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index a2ef4c6e6..6813da77d 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,217 +1,217 @@
-import os
-import math
-from tqdm import tqdm
-from PIL import Image, ImageOps
-from modules import paths, shared, images, deepbooru
-from modules.textual_inversion import autocrop
-
-
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size=False, process_keep_channels=False, process_flip=False, process_split=False, process_caption_only=False, process_caption=False, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): # pylint: disable=unused-argument
- try:
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.model.start()
-
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_keep_channels, process_flip, process_split, process_caption, process_caption_deepbooru, process_caption_only, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
-
- finally:
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.model.stop()
-
-
-class PreprocessParams:
- src = None
- dstdir = None
- subindex = 0
- flip = False
- process_caption_only = False
- process_caption = False
- process_caption_deepbooru = False
- preprocess_txt_action = None
-
-
-def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None, existing_caption_filename=None):
- caption = ""
- if params.process_caption:
- caption += shared.interrogator.generate_caption(image)
- if params.process_caption_deepbooru:
- if len(caption) > 0:
- caption += ", "
- caption += deepbooru.model.tag_multi(image)
-
- filename_part = params.src
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{params.subindex}-{filename_part}"
- if not params.process_caption_only:
- image.save(os.path.join(params.dstdir, f"{basename}.png"))
-
- if params.preprocess_txt_action == 'prepend' and existing_caption:
- caption = f"{existing_caption} {caption}"
- elif params.preprocess_txt_action == 'append' and existing_caption:
- caption = f"{caption} {existing_caption}"
- elif params.preprocess_txt_action == 'copy' and existing_caption:
- caption = existing_caption
- caption = caption.strip()
- if len(caption) > 0:
- if params.process_caption_only:
- fn = os.path.join(params.dstdir, f"{filename_part}.txt")
- elif existing_caption_filename is not None:
- fn = existing_caption_filename
- else:
- fn = os.path.join(params.dstdir, f"{basename}.txt")
- with open(fn, "w", encoding="utf8") as file:
- file.write(caption)
-
- params.subindex += 1
-
-
-def save_pic(image, index, params, existing_caption=None, existing_caption_filename=None):
- save_pic_with_caption(image, index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
- if params.flip:
- save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
-
-
-def split_pic(image, inverse_xy, width, height, overlap_ratio):
- if inverse_xy:
- from_w, from_h = image.height, image.width
- to_w, to_h = height, width
- else:
- from_w, from_h = image.width, image.height
- to_w, to_h = width, height
- h = from_h * to_w // from_w
- if inverse_xy:
- image = image.resize((h, to_w))
- else:
- image = image.resize((to_w, h))
-
- split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
- y_step = (h - to_h) / (split_count - 1)
- for i in range(split_count):
- y = int(y_step * i)
- if inverse_xy:
- splitted = image.crop((y, 0, y + to_h, to_w))
- else:
- splitted = image.crop((0, y, to_w, y + to_h))
- yield splitted
-
-# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
-def center_crop(image: Image, w: int, h: int):
- iw, ih = image.size
- if ih / h < iw / w:
- sw = w * ih / h
- box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
- else:
- sh = h * iw / w
- box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
- return image.resize((w, h), Image.Resampling.LANCZOS, box)
-
-
-def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
- iw, ih = image.size
- err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) # pylint: disable=unnecessary-lambda-assignment,unnecessary-direct-lambda-call
- wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
- if minarea <= w * h <= maxarea and err(w, h) <= threshold),
- key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
- default=None
- )
- return wh and center_crop(image, *wh)
-
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_keep_channels, process_flip, process_split, process_caption, process_caption_deepbooru, process_caption_only, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold):
-
- width = process_width
- height = process_height
- src = os.path.abspath(process_src)
- dst = os.path.abspath(process_dst)
- split_threshold = max(0.0, min(1.0, split_threshold))
- overlap_ratio = max(0.0, min(0.9, overlap_ratio))
- assert src != dst, 'same directory specified as source and destination'
- os.makedirs(dst, exist_ok=True)
- files = os.listdir(src)
- shared.state.job = "preprocess"
- shared.state.textinfo = "Preprocessing..."
- shared.state.job_count = len(files)
- params = PreprocessParams()
- params.dstdir = dst
- params.flip = process_flip
- params.process_caption_only = process_caption_only
- params.process_caption = process_caption
- params.process_caption_deepbooru = process_caption_deepbooru
- params.preprocess_txt_action = preprocess_txt_action
- pbar = tqdm(files)
- for index, imagefile in enumerate(pbar):
- params.subindex = 0
- filename = os.path.join(src, imagefile)
- try:
- img = Image.open(filename)
- img = ImageOps.exif_transpose(img)
- if not process_keep_channels:
- img = img.convert("RGB")
- except Exception:
- continue
-
- description = f"Preprocessing image {index + 1}/{len(files)}"
- pbar.set_description(description)
- shared.state.textinfo = description
- params.src = filename
- existing_caption = None
- existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
- if os.path.exists(existing_caption_filename):
- with open(existing_caption_filename, 'r', encoding="utf8") as file:
- existing_caption = file.read()
- else:
- existing_caption_filename = None
- if shared.state.interrupted:
- break
- if img.height > img.width:
- ratio = (img.width * height) / (img.height * width)
- inverse_xy = False
- else:
- ratio = (img.height * width) / (img.width * height)
- inverse_xy = True
- process_default_resize = True
- if process_split and ratio < 1.0 and ratio <= split_threshold:
- for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
- save_pic(splitted, index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
- process_default_resize = False
- if process_focal_crop and img.height != img.width:
- dnn_model_path = None
- try:
- dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
- except Exception as e:
- print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
- autocrop_settings = autocrop.Settings(
- crop_width = width,
- crop_height = height,
- face_points_weight = process_focal_crop_face_weight,
- entropy_points_weight = process_focal_crop_entropy_weight,
- corner_points_weight = process_focal_crop_edges_weight,
- annotate_image = process_focal_crop_debug,
- dnn_model_path = dnn_model_path,
- )
- for focal in autocrop.crop_image(img, autocrop_settings):
- save_pic(focal, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_multicrop:
- cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
- if cropped is not None:
- save_pic(cropped, index, params, existing_caption=existing_caption)
- else:
- print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
- process_default_resize = False
- if process_keep_original_size:
- save_pic(img, index, params, existing_caption=existing_caption)
- process_default_resize = False
- if process_default_resize:
- img = images.resize_image(1, img, width, height)
- save_pic(img, index, params, existing_caption=existing_caption)
- shared.state.nextjob()
+import os
+import math
+from tqdm import tqdm
+from PIL import Image, ImageOps
+from modules import paths, shared, images, deepbooru
+from modules.textual_inversion import autocrop
+
+
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size=False, process_keep_channels=False, process_flip=False, process_split=False, process_caption_only=False, process_caption=False, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): # pylint: disable=unused-argument
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ deepbooru.model.start()
+
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_keep_channels, process_flip, process_split, process_caption, process_caption_deepbooru, process_caption_only, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.model.stop()
+
+
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption_only = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None, existing_caption_filename=None):
+ caption = ""
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.model.tag_multi(image)
+
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ if not params.process_caption_only:
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
+
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = f"{existing_caption} {caption}"
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = f"{caption} {existing_caption}"
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+ caption = caption.strip()
+ if len(caption) > 0:
+ if params.process_caption_only:
+ fn = os.path.join(params.dstdir, f"{filename_part}.txt")
+ elif existing_caption_filename is not None:
+ fn = existing_caption_filename
+ else:
+ fn = os.path.join(params.dstdir, f"{basename}.txt")
+ with open(fn, "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
+
+
+def save_pic(image, index, params, existing_caption=None, existing_caption_filename=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) # pylint: disable=unnecessary-lambda-assignment,unnecessary-direct-lambda-call
+ wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_keep_channels, process_flip, process_split, process_caption, process_caption_deepbooru, process_caption_only, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold):
+
+ width = process_width
+ height = process_height
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
+ assert src != dst, 'same directory specified as source and destination'
+ os.makedirs(dst, exist_ok=True)
+ files = os.listdir(src)
+ shared.state.job = "preprocess"
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption_only = process_caption_only
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
+ pbar = tqdm(files)
+ for index, imagefile in enumerate(pbar):
+ params.subindex = 0
+ filename = os.path.join(src, imagefile)
+ try:
+ img = Image.open(filename)
+ img = ImageOps.exif_transpose(img)
+ if not process_keep_channels:
+ img = img.convert("RGB")
+ except Exception:
+ continue
+
+ description = f"Preprocessing image {index + 1}/{len(files)}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
+ params.src = filename
+ existing_caption = None
+ existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
+ else:
+ existing_caption_filename = None
+ if shared.state.interrupted:
+ break
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
+ process_default_resize = True
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption, existing_caption_filename=existing_caption_filename)
+ process_default_resize = False
+ if process_focal_crop and img.height != img.width:
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
+ )
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_multicrop:
+ cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ if cropped is not None:
+ save_pic(cropped, index, params, existing_caption=existing_caption)
+ else:
+ print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
+ process_default_resize = False
+ if process_keep_original_size:
+ save_pic(img, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+ if process_default_resize:
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index, params, existing_caption=existing_caption)
+ shared.state.nextjob()
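
In preprocess.py above, multicrop_pic enumerates every 64-px-aligned (w, h) pair inside the dimension and area bounds, discards candidates whose aspect ratio deviates from the source by more than threshold, and keeps either the largest-area or the lowest-error survivor; center_crop then cuts a proportionally centred box of that aspect ratio and resizes it with LANCZOS. A standalone sketch of the selection rule with made-up numbers (1920x1080 source, 384-768 px search range):

```python
# hypothetical inputs, chosen only to illustrate the search
iw, ih = 1920, 1080
mindim, maxdim = 384, 768
minarea, maxarea = 384 * 384, 768 * 768
threshold = 0.10                        # maximum tolerated relative aspect-ratio error

def err(w, h):
    # relative mismatch between source aspect ratio and candidate aspect ratio
    r = (iw / ih) / (w / h)
    return 1 - (r if r < 1 else 1 / r)

candidates = [(w, h) for w in range(mindim, maxdim + 1, 64)
                     for h in range(mindim, maxdim + 1, 64)
                     if minarea <= w * h <= maxarea and err(w, h) <= threshold]

# the 'Maximize area' objective ranks by (area, -error); any other objective reverses the tuple, as in multicrop_pic
best = max(candidates, key=lambda wh: (wh[0] * wh[1], -err(*wh)), default=None)
print(best)                             # (768, 448) for these inputs
```
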
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 98295ef99..fd3fac891 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,692 +1,692 @@
-import os
-import html
-import csv
-import time
-from collections import namedtuple
-import torch
-from tqdm import tqdm
-import safetensors.torch
-import numpy as np
-from PIL import Image, PngImagePlugin
-from modules import shared, devices, processing, sd_models, images, errors
-import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
-from modules.textual_inversion.ti_logging import save_settings_to_file
-from modules.modelloader import directory_files, extension_filter, directory_mtime
-
-
-TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
-textual_inversion_templates = {}
-
-
-def list_textual_inversion_templates():
- textual_inversion_templates.clear()
- for root, _dirs, fns in os.walk(shared.opts.embeddings_templates_dir):
- for fn in fns:
- path = os.path.join(root, fn)
- textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
- return textual_inversion_templates
-
-
-class Embedding:
- def __init__(self, vec, name, filename=None, step=None):
- self.vec = vec
- self.name = name
- self.tag = name
- self.step = step
- self.filename = filename
- self.basename = os.path.relpath(filename, shared.opts.embeddings_dir) if filename is not None else None
- self.shape = None
- self.vectors = 0
- self.cached_checksum = None
- self.sd_checkpoint = None
- self.sd_checkpoint_name = None
- self.optimizer_state_dict = None
-
- def save(self, filename):
- embedding_data = {
- "string_to_token": {"*": 265},
- "string_to_param": {"*": self.vec},
- "name": self.name,
- "step": self.step,
- "sd_checkpoint": self.sd_checkpoint,
- "sd_checkpoint_name": self.sd_checkpoint_name,
- }
- torch.save(embedding_data, filename)
- if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
- optimizer_saved_dict = {
- 'hash': self.checksum(),
- 'optimizer_state_dict': self.optimizer_state_dict,
- }
- torch.save(optimizer_saved_dict, f"{filename}.optim")
-
- def checksum(self):
- if self.cached_checksum is not None:
- return self.cached_checksum
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
- self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
- return self.cached_checksum
-
-
-class DirWithTextualInversionEmbeddings:
- def __init__(self, path):
- self.path = path
- self.mtime = None
-
- def has_changed(self):
- if not os.path.isdir(self.path):
- return False
- return directory_mtime(self.path) != self.mtime
-
- def update(self):
- if not os.path.isdir(self.path):
- return
- self.mtime = directory_mtime(self.path)
-
-
-class EmbeddingDatabase:
- def __init__(self):
- self.ids_lookup = {}
- self.word_embeddings = {}
- self.skipped_embeddings = {}
- self.expected_shape = -1
- self.embedding_dirs = {}
- self.previously_displayed_embeddings = ()
- self.embeddings_used = []
-
- def add_embedding_dir(self, path):
- self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
-
- def clear_embedding_dirs(self):
- self.embedding_dirs.clear()
-
- def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
- if hasattr(model, 'cond_stage_model'):
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
- elif hasattr(model, 'tokenizer'):
- ids = model.tokenizer.convert_tokens_to_ids(embedding.name)
- if type(ids) != list:
- ids = [ids]
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
- return embedding
-
- def get_expected_shape(self):
- if shared.backend == shared.Backend.DIFFUSERS:
- return 0
- if shared.sd_model is None:
- shared.log.error('Model not loaded')
- return 0
- vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
- return vec.shape[1]
-
- def load_diffusers_embedding(self, filename: str, path: str):
- if shared.sd_model is None:
- return
- fn, ext = os.path.splitext(filename)
- if ext.lower() != ".pt" and ext.lower() != ".safetensors":
- return
- pipe = shared.sd_model
- name = os.path.basename(fn)
- embedding = Embedding(vec=None, name=name, filename=path)
- if not hasattr(pipe, "tokenizer") or not hasattr(pipe, 'text_encoder'):
- self.skipped_embeddings[name] = embedding
- return
- try:
- is_xl = hasattr(pipe, 'text_encoder_2')
- try:
- if not is_xl: # only use for sd15/sd21
- pipe.load_textual_inversion(path, token=name, cache_dir=shared.opts.diffusers_dir, local_files_only=True)
- self.register_embedding(embedding, shared.sd_model)
- except Exception:
- pass
- is_loaded = pipe.tokenizer.convert_tokens_to_ids(name)
- if type(is_loaded) != list:
- is_loaded = [is_loaded]
- is_loaded = is_loaded[0] > 49407
- if is_loaded:
- self.register_embedding(embedding, shared.sd_model)
- else:
- embeddings_dict = {}
- if ext.lower() in ['.safetensors']:
- with safetensors.torch.safe_open(path, framework="pt") as f:
- for k in f.keys():
- embeddings_dict[k] = f.get_tensor(k)
- else:
- raise NotImplementedError
- """
- # alternatively could disable load_textual_inversion and load everything here
- elif ext.lower() in ['.pt', '.bin']:
- data = torch.load(path, map_location="cpu")
- embedding.tag = data.get('name', None)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- param_dict = data.get('string_to_param', None)
- embeddings_dict['clip_l'] = []
- for tokens in param_dict.values():
- for vec in tokens:
- embeddings_dict['clip_l'].append(vec)
- """
- clip_l = pipe.text_encoder if hasattr(pipe, 'text_encoder') else None
- clip_g = pipe.text_encoder_2 if hasattr(pipe, 'text_encoder_2') else None
- is_sd = clip_l is not None and 'clip_l' in embeddings_dict and clip_g is None and 'clip_g' not in embeddings_dict
- is_xl = clip_l is not None and 'clip_l' in embeddings_dict and clip_g is not None and 'clip_g' in embeddings_dict
- tokens = []
- for i in range(len(embeddings_dict["clip_l"])):
- if (is_sd or is_xl) and (len(clip_l.get_input_embeddings().weight.data[0]) == len(embeddings_dict["clip_l"][i])):
- tokens.append(name if i == 0 else f"{name}_{i}")
- num_added = pipe.tokenizer.add_tokens(tokens)
- if num_added > 0:
- token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
- if is_sd: # only used for sd15 if load_textual_inversion failed and format is safetensors
- clip_l.resize_token_embeddings(len(pipe.tokenizer))
- for i in range(len(token_ids)):
- clip_l.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_l"][i]
- elif is_xl:
- pipe.tokenizer_2.add_tokens(tokens)
- clip_l.resize_token_embeddings(len(pipe.tokenizer))
- clip_g.resize_token_embeddings(len(pipe.tokenizer))
- for i in range(len(token_ids)):
- clip_l.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_l"][i]
- clip_g.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_g"][i]
- self.register_embedding(embedding, shared.sd_model)
- else:
- raise NotImplementedError
- except Exception:
- self.skipped_embeddings[name] = embedding
-
- def load_from_file(self, path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
- if shared.backend == shared.Backend.DIFFUSERS:
- self.load_diffusers_embedding(filename, path)
- return
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- if '.preview' in filename.lower():
- return
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- else:
- data = extract_image_data_embed(embed_image)
- if not data: # if data is None, this is not an embedding, just a preview image
- return
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- elif ext in ['.SAFETENSORS']:
- data = safetensors.torch.load_file(path, device="cpu")
- else:
- return
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- if len(data.keys()) != 1:
- self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
- return
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- # name = data.get('name', name)
- embedding = Embedding(vec=vec, name=name, filename=path)
- embedding.tag = data.get('name', None)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
- else:
- self.skipped_embeddings[name] = embedding
-
- def load_from_dir(self, embdir):
- if sd_models.model_data.sd_model is None:
- shared.log.info('Skipping embeddings load: model not loaded')
- return
- if not os.path.isdir(embdir.path):
- return
- is_ext = extension_filter(['.PNG', '.WEBP', '.JXL', '.AVIF', '.BIN', '.PT', '.SAFETENSORS'])
- is_not_preview = lambda fp: not next(iter(os.path.splitext(fp))).upper().endswith('.PREVIEW') # pylint: disable=unnecessary-lambda-assignment
- for file_path in [*filter(lambda fp: is_ext(fp) and is_not_preview(fp), directory_files(embdir.path))]:
- try:
- if os.stat(file_path).st_size == 0:
- continue
- fn = os.path.basename(file_path)
- self.load_from_file(file_path, fn)
- except Exception as e:
- errors.display(e, f'embedding load {fn}')
- continue
-
- def load_textual_inversion_embeddings(self, force_reload=False):
- if shared.sd_model is None:
- return
- t0 = time.time()
- if not force_reload:
- need_reload = False
- for embdir in self.embedding_dirs.values():
- if embdir.has_changed():
- need_reload = True
- break
- if not need_reload:
- return
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.embeddings_used.clear()
- self.expected_shape = self.get_expected_shape()
- for embdir in self.embedding_dirs.values():
- self.load_from_dir(embdir)
- embdir.update()
-
- # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
- # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
- sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
- self.word_embeddings.clear()
- self.word_embeddings.update(sorted_word_embeddings)
-
- displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
- if self.previously_displayed_embeddings != displayed_embeddings:
- self.previously_displayed_embeddings = displayed_embeddings
- t1 = time.time()
- shared.log.info(f"Load embeddings: loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")
-
-
- def find_embedding_at_position(self, tokens, offset):
- token = tokens[offset]
- possible_matches = self.ids_lookup.get(token, None)
- if possible_matches is None:
- return None, None
- for ids, embedding in possible_matches:
- if tokens[offset:offset + len(ids)] == ids:
- return embedding, len(ids)
- return None, None
-
-
-def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
- cond_model = shared.sd_model.cond_stage_model
- with devices.autocast():
- cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- #cond_model expects at least some text, so we provide '*' as backup.
- embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
- vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- #Only copy if we provided an init_text, otherwise keep vectors as zeros
- if init_text:
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
- fn = os.path.join(shared.opts.embeddings_dir, f"{name}.pt")
- if not overwrite_old and os.path.exists(fn):
- shared.log.warning(f"Embedding already exists: {fn}")
- else:
- embedding = Embedding(vec=vec, name=name, filename=fn)
- embedding.step = 0
- embedding.save(fn)
- shared.log.info(f'Created embedding: {fn} vectors {num_vectors_per_token} init {init_text}')
- return fn
-
-
-def write_loss(log_directory, filename, step, epoch_len, values):
- if shared.opts.training_write_csv_every == 0:
- return
- if step % shared.opts.training_write_csv_every != 0:
- return
- write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
- with open(os.path.join(log_directory, filename), "a+", newline='', encoding='utf-8') as fout:
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
- if write_csv_header:
- csv_writer.writeheader()
- epoch = (step - 1) // epoch_len
- epoch_step = (step - 1) % epoch_len
- csv_writer.writerow({
- "step": step,
- "epoch": epoch,
- "epoch_step": epoch_step,
- **values,
- })
-
-
-def tensorboard_setup(log_directory):
- from torch.utils.tensorboard import SummaryWriter
- os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
- return SummaryWriter(
- log_dir=os.path.join(log_directory, "tensorboard"),
- flush_secs=shared.opts.training_tensorboard_flush_every)
-
-
-def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
- tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
- tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
- tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
- tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
-
-
-def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
-
-
-def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
- # Convert a pil image to a torch tensor
- img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
- img_tensor = img_tensor.permute((2, 0, 1))
- tensorboard_writer.add_image(tag, img_tensor, global_step=step)
-
-
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, name="embedding"):
- assert model_name, f"{name} not selected"
- assert learn_rate, "Learning rate is empty or 0"
- assert isinstance(batch_size, int), "Batch size must be integer"
- assert batch_size > 0, "Batch size must be positive"
- assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
- assert gradient_step > 0, "Gradient accumulation step must be positive"
- assert data_root, "Dataset directory is empty"
- assert os.path.isdir(data_root), "Dataset directory doesn't exist"
- assert os.listdir(data_root), "Dataset directory is empty"
- assert template_filename, "Prompt template file not selected"
- assert template_file, f"Prompt template file {template_filename} not found"
- assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
- assert steps, "Max steps is empty or 0"
- assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0, "Max steps must be positive"
- assert isinstance(save_model_every, int), f"Save {name} must be integer"
- assert save_model_every >= 0, f"Save {name} must be positive or 0"
- assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0, "Create image must be positive or 0"
-
-
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # pylint: disable=unused-argument
- from modules import sd_hijack, sd_hijack_checkpoint
-
- shared.log.debug(f'train_embedding: embedding_name={embedding_name}|learn_rate={learn_rate}|batch_size={batch_size}|gradient_step={gradient_step}|data_root={data_root}|log_directory={log_directory}|training_width={training_width}|training_height={training_height}|varsize={varsize}|steps={steps}|clip_grad_mode={clip_grad_mode}|clip_grad_value={clip_grad_value}|shuffle_tags={shuffle_tags}|tag_drop_out={tag_drop_out}|latent_sampling_method={latent_sampling_method}|use_weight={use_weight}|create_image_every={create_image_every}|save_embedding_every={save_embedding_every}|template_filename={template_filename}|save_image_with_stored_embedding={save_image_with_stored_embedding}|preview_from_txt2img={preview_from_txt2img}|preview_prompt={preview_prompt}|preview_negative_prompt={preview_negative_prompt}|preview_steps={preview_steps}|preview_sampler_index={preview_sampler_index}|preview_cfg_scale={preview_cfg_scale}|preview_seed={preview_seed}|preview_width={preview_width}|preview_height={preview_height}')
- save_embedding_every = save_embedding_every or 0
- create_image_every = create_image_every or 0
- template_file = textual_inversion_templates.get(template_filename, None)
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, name="embedding")
- if log_directory is None or log_directory == '':
- log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
- template_file = template_file.path
-
- shared.state.job = "train"
- shared.state.textinfo = "Initializing textual inversion training..."
- shared.state.job_count = steps
-
- filename = os.path.join(shared.opts.embeddings_dir, f'{embedding_name}.pt')
-
- if log_directory == '':
- log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
- log_directory = os.path.join(log_directory, embedding_name)
- unload = shared.opts.unload_models_when_training
-
- if save_embedding_every > 0:
- embedding_dir = os.path.join(log_directory, "embeddings")
- os.makedirs(embedding_dir, exist_ok=True)
- else:
- embedding_dir = None
-
- if create_image_every > 0:
- images_dir = os.path.join(log_directory, "images")
- os.makedirs(images_dir, exist_ok=True)
- else:
- images_dir = None
-
- if create_image_every > 0 and save_image_with_stored_embedding:
- images_embeds_dir = os.path.join(log_directory, "image_embeddings")
- os.makedirs(images_embeds_dir, exist_ok=True)
- else:
- images_embeds_dir = None
-
- hijack = sd_hijack.model_hijack
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
- checkpoint = sd_models.select_checkpoint()
- initial_step = embedding.step or 0
- if initial_step >= steps:
- shared.state.textinfo = "Model has already been trained beyond specified max steps"
- return embedding, filename
- scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
- torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
- None
- if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
- # dataset loading may take a while, so input validations and early returns should be done before this
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- old_parallel_processing_allowed = shared.parallel_processing_allowed
-
- if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
-
- pin_memory = shared.opts.pin_memory
- # init dataset
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
-
- if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
- latent_sampling_method = ds.latent_sampling_method
- # init dataloader
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
- if unload:
- shared.parallel_processing_allowed = False
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
- if shared.opts.save_optimizer_state:
- optimizer_state_dict = None
- if os.path.exists(f"{filename}.optim"):
- optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
- if embedding.checksum() == optimizer_saved_dict.get('hash', None):
- optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
- if optimizer_state_dict is not None:
- optimizer.load_state_dict(optimizer_state_dict)
- shared.log.info("Load existing optimizer from checkpoint")
- else:
- shared.log.info("No saved optimizer exists in checkpoint")
-
- scaler = torch.cuda.amp.GradScaler()
-
- batch_size = ds.batch_size
- gradient_step = ds.gradient_step
- # one optimizer step consumes batch_size * gradient_step images
- steps_per_epoch = len(ds) // batch_size // gradient_step
- max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
- loss_step = 0
- _loss_step = 0 #internal
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
- embedding_yet_to_be_embedded = False
- is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- img_c = None
-
- pbar = tqdm(total=steps - initial_step)
- try:
- sd_hijack_checkpoint.add()
- for _i in range((steps-initial_step) * gradient_step):
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
- for j, batch in enumerate(dl):
- # works as a drop_last=True for gradient accumulation
- if j == max_steps_per_epoch:
- break
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
- if clip_grad:
- clip_grad_sched.step(embedding.step)
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- if use_weight:
- w = batch.weight.to(devices.device, non_blocking=pin_memory)
- c = shared.sd_model.cond_stage_model(batch.cond_text)
- if is_training_inpainting_model:
- if img_c is None:
- img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
- else:
- cond = c
- if use_weight:
- loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
- del w
- else:
- loss = shared.sd_model.forward(x, cond)[0] / gradient_step
- del x
- _loss_step += loss.item()
-
- scaler.scale(loss).backward()
- # go back until we reach gradient accumulation steps
- if (j + 1) % gradient_step != 0:
- continue
- if clip_grad:
- clip_grad(embedding.vec, clip_grad_sched.learn_rate)
-
- scaler.step(optimizer)
- scaler.update()
- embedding.step += 1
- pbar.update()
- optimizer.zero_grad(set_to_none=True)
- loss_step = _loss_step
- _loss_step = 0
- steps_done = embedding.step + 1
- epoch_num = embedding.step // steps_per_epoch
-
- description = f"Training textual inversion step {embedding.step} loss: {loss_step:.5f} lr: {scheduler.learn_rate:.5f}"
- pbar.set_description(description)
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding_name_every = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, f"{embedding_name}.csv", embedding.step, steps_per_epoch, { "loss": f"{loss_step:.7f}", "learn_rate": scheduler.learn_rate })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_name = processing.get_sampler_name(preview_sampler_index)
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = batch.cond_text[0]
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
-
- if unload:
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.assign_current_image(image)
- last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
-
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = f"<{data.get('name', '???')}>"
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception:
- vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = f'[{checkpoint.shorthash}]'
- footer_right = f'{vectorSize}v {steps_done}s'
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
-
- last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
-
- shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
-
-Loss: {loss_step:.7f}
-Step: {steps_done}
-Last prompt: {html.escape(batch.cond_text[0])}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-
-"""
- filename = os.path.join(shared.opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
- except Exception as e:
- errors.display(e, 'embedding train')
- finally:
- pbar.leave = False
- pbar.close()
- shared.sd_model.first_stage_model.to(devices.device)
- shared.parallel_processing_allowed = old_parallel_processing_allowed
- sd_hijack_checkpoint.remove()
- return embedding, filename
-
-
-def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
- old_embedding_name = embedding.name
- old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
- old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
- old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
- try:
- embedding.sd_checkpoint = checkpoint.shorthash
- embedding.sd_checkpoint_name = checkpoint.model_name
- if remove_cached_checksum:
- embedding.cached_checksum = None
- embedding.name = embedding_name
- embedding.optimizer_state_dict = optimizer.state_dict()
- embedding.save(filename)
- except Exception:
- embedding.sd_checkpoint = old_sd_checkpoint
- embedding.sd_checkpoint_name = old_sd_checkpoint_name
- embedding.name = old_embedding_name
- embedding.cached_checksum = old_cached_checksum
- raise
+import os
+import html
+import csv
+import time
+from collections import namedtuple
+import torch
+from tqdm import tqdm
+import safetensors.torch
+import numpy as np
+from PIL import Image, PngImagePlugin
+from modules import shared, devices, processing, sd_models, images, errors
+import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
+from modules.textual_inversion.ti_logging import save_settings_to_file
+from modules.modelloader import directory_files, extension_filter, directory_mtime
+
+
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+ for root, _dirs, fns in os.walk(shared.opts.embeddings_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+ return textual_inversion_templates
+
+
+class Embedding:
+ def __init__(self, vec, name, filename=None, step=None):
+ self.vec = vec
+ self.name = name
+ self.tag = name
+ self.step = step
+ self.filename = filename
+ self.basename = os.path.relpath(filename, shared.opts.embeddings_dir) if filename is not None else None
+ self.shape = None
+ self.vectors = 0
+ self.cached_checksum = None
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
+
+ def save(self, filename):
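+ # save in the classic textual-inversion .pt format; if enabled, optimizer state is written alongside as "<filename>.optim"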
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ "sd_checkpoint": self.sd_checkpoint,
+ "sd_checkpoint_name": self.sd_checkpoint_name,
+ }
+ torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, f"{filename}.optim")
+
+ def checksum(self):
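+ # cheap deterministic hash of the embedding vector, used to match a saved .optim state to its embedding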
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+ return directory_mtime(self.path) != self.mtime
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+ self.mtime = directory_mtime(self.path)
+
+
+class EmbeddingDatabase:
+ def __init__(self):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.skipped_embeddings = {}
+ self.expected_shape = -1
+ self.embedding_dirs = {}
+ self.previously_displayed_embeddings = ()
+ self.embeddings_used = []
+
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+
+ def register_embedding(self, embedding, model):
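+ # index the embedding by the first token id of its name so prompts can be matched during tokenization; longer token sequences take priority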
+ self.word_embeddings[embedding.name] = embedding
+ if hasattr(model, 'cond_stage_model'):
+ ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ elif hasattr(model, 'tokenizer'):
+ ids = model.tokenizer.convert_tokens_to_ids(embedding.name)
+ if type(ids) != list:
+ ids = [ids]
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
+ return embedding
+
+ def get_expected_shape(self):
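+ # query the loaded model for the per-token embedding width; the diffusers backend skips this check and returns 0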
+ if shared.backend == shared.Backend.DIFFUSERS:
+ return 0
+ if shared.sd_model is None:
+ shared.log.error('Model not loaded')
+ return 0
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
+
+ def load_diffusers_embedding(self, filename: str, path: str):
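+ # try the diffusers load_textual_inversion path first; fall back to manually registering tokens and copying vectors into the text encoder(s), which also covers SDXL and safetensors embeddings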
+ if shared.sd_model is None:
+ return
+ fn, ext = os.path.splitext(filename)
+ if ext.lower() not in ('.pt', '.safetensors'):
+ return
+ pipe = shared.sd_model
+ name = os.path.basename(fn)
+ embedding = Embedding(vec=None, name=name, filename=path)
+ if not hasattr(pipe, "tokenizer") or not hasattr(pipe, 'text_encoder'):
+ self.skipped_embeddings[name] = embedding
+ return
+ try:
+ is_xl = hasattr(pipe, 'text_encoder_2')
+ try:
+ if not is_xl: # only use for sd15/sd21
+ pipe.load_textual_inversion(path, token=name, cache_dir=shared.opts.diffusers_dir, local_files_only=True)
+ self.register_embedding(embedding, shared.sd_model)
+ except Exception:
+ pass
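+ # token ids above the base CLIP vocabulary (49407 is the end-of-text token) mean the embedding token was successfully added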
+ is_loaded = pipe.tokenizer.convert_tokens_to_ids(name)
+ if type(is_loaded) != list:
+ is_loaded = [is_loaded]
+ is_loaded = is_loaded[0] > 49407
+ if is_loaded:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ embeddings_dict = {}
+ if ext.lower() in ['.safetensors']:
+ with safetensors.torch.safe_open(path, framework="pt") as f:
+ for k in f.keys():
+ embeddings_dict[k] = f.get_tensor(k)
+ else:
+ raise NotImplementedError
+ """
+ # alternatively could disable load_textual_inversion and load everything here
+ elif ext.lower() in ['.pt', '.bin']:
+ data = torch.load(path, map_location="cpu")
+ embedding.tag = data.get('name', None)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ param_dict = data.get('string_to_param', None)
+ embeddings_dict['clip_l'] = []
+ for tokens in param_dict.values():
+ for vec in tokens:
+ embeddings_dict['clip_l'].append(vec)
+ """
+ clip_l = pipe.text_encoder if hasattr(pipe, 'text_encoder') else None
+ clip_g = pipe.text_encoder_2 if hasattr(pipe, 'text_encoder_2') else None
+ is_sd = clip_l is not None and 'clip_l' in embeddings_dict and clip_g is None and 'clip_g' not in embeddings_dict
+ is_xl = clip_l is not None and 'clip_l' in embeddings_dict and clip_g is not None and 'clip_g' in embeddings_dict
+ tokens = []
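+ # create one pseudo-token per embedding vector ("name", "name_1", ...), but only when the vector width matches the text encoder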
+ for i in range(len(embeddings_dict["clip_l"])):
+ if (is_sd or is_xl) and (len(clip_l.get_input_embeddings().weight.data[0]) == len(embeddings_dict["clip_l"][i])):
+ tokens.append(name if i == 0 else f"{name}_{i}")
+ num_added = pipe.tokenizer.add_tokens(tokens)
+ if num_added > 0:
+ token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
+ if is_sd: # only used for sd15 if load_textual_inversion failed and format is safetensors
+ clip_l.resize_token_embeddings(len(pipe.tokenizer))
+ for i in range(len(token_ids)):
+ clip_l.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_l"][i]
+ elif is_xl:
+ pipe.tokenizer_2.add_tokens(tokens)
+ clip_l.resize_token_embeddings(len(pipe.tokenizer))
+ clip_g.resize_token_embeddings(len(pipe.tokenizer))
+ for i in range(len(token_ids)):
+ clip_l.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_l"][i]
+ clip_g.get_input_embeddings().weight.data[token_ids[i]] = embeddings_dict["clip_g"][i]
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ raise NotImplementedError
+ except Exception:
+ self.skipped_embeddings[name] = embedding
+
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
+ if shared.backend == shared.Backend.DIFFUSERS:
+ self.load_diffusers_embedding(filename, path)
+ return
+
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if '.preview' in filename.lower():
+ return
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ else:
+ data = extract_image_data_embed(embed_image)
+ if not data: # if data is None, this is not an embedding, just a preview image
+ return
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
+ else:
+ return
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ if len(data.keys()) != 1:
+ self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
+ return
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ # name = data.get('name', name)
+ embedding = Embedding(vec=vec, name=name, filename=path)
+ embedding.tag = data.get('name', None)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
+
+ def load_from_dir(self, embdir):
+ if sd_models.model_data.sd_model is None:
+ shared.log.info('Skipping embeddings load: model not loaded')
+ return
+ if not os.path.isdir(embdir.path):
+ return
+ is_ext = extension_filter(['.PNG', '.WEBP', '.JXL', '.AVIF', '.BIN', '.PT', '.SAFETENSORS'])
+ is_not_preview = lambda fp: not next(iter(os.path.splitext(fp))).upper().endswith('.PREVIEW') # pylint: disable=unnecessary-lambda-assignment
+ for file_path in [*filter(lambda fp: is_ext(fp) and is_not_preview(fp), directory_files(embdir.path))]:
+ try:
+ if os.stat(file_path).st_size == 0:
+ continue
+ fn = os.path.basename(file_path)
+ self.load_from_file(file_path, fn)
+ except Exception as e:
+ errors.display(e, f'embedding load {fn}')
+ continue
+
+ def load_textual_inversion_embeddings(self, force_reload=False):
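+ # reload embeddings from all registered directories, skipping the reload when no directory has changed on disk (unless force_reload)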
+ if shared.sd_model is None:
+ return
+ t0 = time.time()
+ if not force_reload:
+ need_reload = False
+ for embdir in self.embedding_dirs.values():
+ if embdir.has_changed():
+ need_reload = True
+ break
+ if not need_reload:
+ return
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.embeddings_used.clear()
+ self.expected_shape = self.get_expected_shape()
+ for embdir in self.embedding_dirs.values():
+ self.load_from_dir(embdir)
+ embdir.update()
+
+ # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
+ # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
+ sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
+ self.word_embeddings.clear()
+ self.word_embeddings.update(sorted_word_embeddings)
+
+ displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
+ if self.previously_displayed_embeddings != displayed_embeddings:
+ self.previously_displayed_embeddings = displayed_embeddings
+ t1 = time.time()
+ shared.log.info(f"Load embeddings: loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")
+
+ def find_embedding_at_position(self, tokens, offset):
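+ # longest-match lookup against ids_lookup; candidates were pre-sorted by token-sequence length in register_embedding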
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+ if possible_matches is None:
+ return None, None
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding, len(ids)
+ return None, None
+
+
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
+ cond_model = shared.sd_model.cond_stage_model
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+ # cond_model expects at least some text, so we provide '*' as a backup
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+ # only copy if we provided an init_text, otherwise keep the vectors at zero
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join(x for x in name if (x.isalnum() or x in "._- "))
+ fn = os.path.join(shared.opts.embeddings_dir, f"{name}.pt")
+ if not overwrite_old and os.path.exists(fn):
+ shared.log.warning(f"Embedding already exists: {fn}")
+ else:
+ embedding = Embedding(vec=vec, name=name, filename=fn)
+ embedding.step = 0
+ embedding.save(fn)
+ shared.log.info(f'Created embedding: {fn} vectors {num_vectors_per_token} init {init_text}')
+ return fn
+
+
+def write_loss(log_directory, filename, step, epoch_len, values):
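+ # append one CSV row every training_write_csv_every steps; the header is written only when the file does not yet exist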
+ if shared.opts.training_write_csv_every == 0:
+ return
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+ write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
+ with open(os.path.join(log_directory, filename), "a+", newline='', encoding='utf-8') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+ if write_csv_header:
+ csv_writer.writeheader()
+ epoch = (step - 1) // epoch_len
+ epoch_step = (step - 1) % epoch_len
+ csv_writer.writerow({
+ "step": step,
+ "epoch": epoch,
+ "epoch_step": epoch_step,
+ **values,
+ })
+
+
+def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
+ os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+ return SummaryWriter(
+ log_dir=os.path.join(log_directory, "tensorboard"),
+ flush_secs=shared.opts.training_tensorboard_flush_every)
+
+
+def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
+ tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
+ tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
+
+
+def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
+ tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
+
+
+def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
+ # Convert a pil image to a torch tensor
+ img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
+ img_tensor = img_tensor.permute((2, 0, 1))
+ tensorboard_writer.add_image(tag, img_tensor, global_step=step)
+
+
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+ assert gradient_step > 0, "Gradient accumulation step must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0, "Max steps must be positive"
+ assert isinstance(save_model_every, int), f"Save {name} must be integer"
+ assert save_model_every >= 0, f"Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0, "Create image must be positive or 0"
+
+
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # pylint: disable=unused-argument
+ from modules import sd_hijack, sd_hijack_checkpoint
+
+ shared.log.debug(f'train_embedding: embedding_name={embedding_name}|learn_rate={learn_rate}|batch_size={batch_size}|gradient_step={gradient_step}|data_root={data_root}|log_directory={log_directory}|training_width={training_width}|training_height={training_height}|varsize={varsize}|steps={steps}|clip_grad_mode={clip_grad_mode}|clip_grad_value={clip_grad_value}|shuffle_tags={shuffle_tags}|tag_drop_out={tag_drop_out}|latent_sampling_method={latent_sampling_method}|use_weight={use_weight}|create_image_every={create_image_every}|save_embedding_every={save_embedding_every}|template_filename={template_filename}|save_image_with_stored_embedding={save_image_with_stored_embedding}|preview_from_txt2img={preview_from_txt2img}|preview_prompt={preview_prompt}|preview_negative_prompt={preview_negative_prompt}|preview_steps={preview_steps}|preview_sampler_index={preview_sampler_index}|preview_cfg_scale={preview_cfg_scale}|preview_seed={preview_seed}|preview_width={preview_width}|preview_height={preview_height}')
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, name="embedding")
+ if log_directory is None or log_directory == '':
+ log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
+ template_file = template_file.path
+
+ shared.state.job = "train"
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.opts.embeddings_dir, f'{embedding_name}.pt')
+
+ if log_directory == '':
+ log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
+ log_directory = os.path.join(log_directory, embedding_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
+ hijack = sd_hijack.model_hijack
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ checkpoint = sd_models.select_checkpoint()
+ initial_step = embedding.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
+ return embedding, filename
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
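+ # map the requested clip_grad_mode to the matching torch.nn.utils gradient clipping function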
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = tensorboard_setup(log_directory)
+
+ pin_memory = shared.opts.pin_memory
+ # init dataset
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
+
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ latent_sampling_method = ds.latent_sampling_method
+ # init dataloader
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(f"{filename}.optim"):
+ optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ shared.log.info("Load existing optimizer from checkpoint")
+ else:
+ shared.log.info("No saved optimizer exists in checkpoint")
+
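+ # mixed-precision training: GradScaler scales the loss to avoid fp16 gradient underflow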
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # each optimizer step consumes batch_size * gradient_step images
+ steps_per_epoch = len(ds) // batch_size // gradient_step
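+ # number of dataloader batches per epoch, truncated to a multiple of gradient_step (acts like drop_last for gradient accumulation)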
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
+ last_saved_file = ""
+ last_saved_image = ""
+ forced_filename = ""
+ embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
+ pbar = tqdm(total=steps - initial_step)
+ try:
+ sd_hijack_checkpoint.add()
+ for _i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, cond)[0] / gradient_step
+ del x
+ _loss_step += loss.item()
+
+ scaler.scale(loss).backward()
+ # keep accumulating gradients until gradient_step batches have been processed
+ if (j + 1) % gradient_step != 0:
+ continue
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
+ scaler.step(optimizer)
+ scaler.update()
+ embedding.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+ steps_done = embedding.step + 1
+ epoch_num = embedding.step // steps_per_epoch
+
+ description = f"Training textual inversion step {embedding.step} loss: {loss_step:.5f} lr: {scheduler.learn_rate:.5f}"
+ pbar.set_description(description)
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, f"{embedding_name}.csv", embedding.step, steps_per_epoch, { "loss": f"{loss_step:.7f}", "learn_rate": scheduler.learn_rate })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = processing.get_sampler_name(preview_sampler_index)
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.assign_current_image(image)
+ last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ title = f"<{data.get('name', '???')}>"
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception:
+ vectorSize = '?'
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = f'[{checkpoint.shorthash}]'
+ footer_right = f'{vectorSize}v {steps_done}s'
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
+ last_saved_image, _last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+ shared.state.textinfo = f"""
+
+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+ filename = os.path.join(shared.opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ except Exception as e:
+ errors.display(e, 'embedding train')
+ finally:
+ pbar.leave = False
+ pbar.close()
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
+ return embedding, filename
+
+
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.shorthash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
+ embedding.save(filename)
+ except Exception:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
diff --git a/modules/textual_inversion/ti_logging.py b/modules/textual_inversion/ti_logging.py
index a79696e3f..1116ec91d 100644
--- a/modules/textual_inversion/ti_logging.py
+++ b/modules/textual_inversion/ti_logging.py
@@ -1,22 +1,22 @@
-import datetime
-import json
-import os
-
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
-saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
-
-
-def save_settings_to_file(log_directory, all_params):
- now = datetime.datetime.now()
- params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
- keys = saved_params_all
- if all_params.get('preview_from_txt2img'):
- keys = keys | saved_params_previews
- params.update({k: v for k, v in all_params.items() if k in keys})
- filename = f"settings-{now.strftime('%Y-%m-%d_%H-%M-%S')}.json"
- with open(os.path.join(log_directory, filename), "w", encoding='utf-8') as file:
- print(f'Training settings file: {os.path.join(log_directory, filename)}')
- json.dump(params, file, indent=2)
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "latent_sampling_method"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+ now = datetime.datetime.now()
+ params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+ keys = saved_params_all
+ if all_params.get('preview_from_txt2img'):
+ keys = keys | saved_params_previews
+ params.update({k: v for k, v in all_params.items() if k in keys})
+ filename = f"settings-{now.strftime('%Y-%m-%d_%H-%M-%S')}.json"
+ with open(os.path.join(log_directory, filename), "w", encoding='utf-8') as file:
+ print(f'Training settings file: {os.path.join(log_directory, filename)}')
+ json.dump(params, file, indent=2)
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 1f488848c..9c5d7e3c4 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -1,35 +1,35 @@
-import html
-import gradio as gr
-import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
-from modules import shared
-
-
-def create_embedding(name, initialization_text, nvpt, overwrite_old):
- from modules import sd_hijack
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
- return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
-
-
-def preprocess(*args):
- modules.textual_inversion.preprocess.preprocess(*args)
- return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
-
-
-def train_embedding(*args):
- from modules import sd_hijack
- assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
- apply_optimizations = False
- try:
- if not apply_optimizations:
- sd_hijack.undo_optimizations()
- embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
- res = f"Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. Embedding saved to {html.escape(filename)}"
- return res, ""
- except Exception as e:
- shared.log.error(f"Exception in train_embedding: {e}")
- raise RuntimeError from e
- finally:
- if not apply_optimizations:
- sd_hijack.apply_optimizations()
+import html
+import gradio as gr
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import shared
+
+
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ from modules import sd_hijack
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+ return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+ return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
+
+
+def train_embedding(*args):
+ from modules import sd_hijack
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+ apply_optimizations = False
+ try:
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
+ res = f"Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. Embedding saved to {html.escape(filename)}"
+ return res, ""
+ except Exception as e:
+ shared.log.error(f"Exception in train_embedding: {e}")
+ raise RuntimeError from e
+ finally:
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
diff --git a/modules/timer.py b/modules/timer.py
index 17ebc559d..4f1e6a744 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -1,35 +1,35 @@
-import time
-
-
-class Timer:
- def __init__(self):
- self.start = time.time()
- self.records = {}
- self.total = 0
-
- def elapsed(self):
- end = time.time()
- res = end - self.start
- self.start = end
- return res
-
- def record(self, category, extra_time=0):
- e = self.elapsed()
- if category not in self.records:
- self.records[category] = 0
-
- self.records[category] += e + extra_time
- self.total += e + extra_time
-
- def summary(self, min_time=0.05):
- res = f"{self.total:.2f}"
- additions = [x for x in self.records.items() if x[1] >= min_time]
- if not additions:
- return res
- res += " { " + " ".join([f"{category}={time_taken:.2f}" for category, time_taken in additions]) + " }"
- return res
-
- def reset(self):
- self.__init__()
-
-startup = Timer()
+import time
+
+
+class Timer:
+ def __init__(self):
+ self.start = time.time()
+ self.records = {}
+ self.total = 0
+
+ def elapsed(self):
+ end = time.time()
+ res = end - self.start
+ self.start = end
+ return res
+
+ def record(self, category, extra_time=0):
+ e = self.elapsed()
+ if category not in self.records:
+ self.records[category] = 0
+
+ self.records[category] += e + extra_time
+ self.total += e + extra_time
+
+ def summary(self, min_time=0.05):
+ res = f"{self.total:.2f}"
+ additions = [x for x in self.records.items() if x[1] >= min_time]
+ if not additions:
+ return res
+ res += " { " + " ".join([f"{category}={time_taken:.2f}" for category, time_taken in additions]) + " }"
+ return res
+
+ def reset(self):
+ self.__init__()
+
+startup = Timer()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index edb8d87e8..8a52003b8 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,94 +1,94 @@
-import os
-import modules.scripts
-from modules import shared, processing
-from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.ui import plaintext_to_html
-
-
-debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PROCESS')
-
-
-def txt2img(id_task,
- prompt, negative_prompt, prompt_styles,
- steps, sampler_index, hr_sampler_index,
- full_quality, restore_faces, tiling,
- n_iter, batch_size,
- cfg_scale, image_cfg_scale, diffusers_guidance_rescale, sag_scale,
- clip_skip,
- seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
- height, width,
- enable_hr, denoising_strength,
- hr_scale, hr_upscaler, hr_force, hr_second_pass_steps, hr_resize_x, hr_resize_y,
- refiner_steps, refiner_start, refiner_prompt, refiner_negative,
- hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry,
- override_settings_texts,
- *args):
-
- debug(f'txt2img: id_task={id_task}|prompt={prompt}|negative={negative_prompt}|styles={prompt_styles}|steps={steps}|sampler_index={sampler_index}|hr_sampler_index={hr_sampler_index}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|batch_count={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|clip_skip={clip_skip}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|height={height}|width={width}|enable_hr={enable_hr}|denoising_strength={denoising_strength}|hr_scale={hr_scale}|hr_upscaler={hr_upscaler}|hr_force={hr_force}|hr_second_pass_steps={hr_second_pass_steps}|hr_resize_x={hr_resize_x}|hr_resize_y={hr_resize_y}|image_cfg_scale={image_cfg_scale}|diffusers_guidance_rescale={diffusers_guidance_rescale}|refiner_steps={refiner_steps}|refiner_start={refiner_start}|refiner_prompt={refiner_prompt}|refiner_negative={refiner_negative}|override_settings={override_settings_texts}')
-
- if shared.sd_model is None:
- shared.log.warning('Model not loaded')
- return [], '', '', 'Error: model not loaded'
-
- override_settings = create_override_settings_dict(override_settings_texts)
- if sampler_index is None:
- sampler_index = 0
- if hr_sampler_index is None:
- hr_sampler_index = 0
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- outpath_samples=shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples,
- outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids,
- prompt=prompt,
- styles=prompt_styles,
- negative_prompt=negative_prompt,
- seed=seed,
- subseed=subseed,
- subseed_strength=subseed_strength,
- seed_resize_from_h=seed_resize_from_h,
- seed_resize_from_w=seed_resize_from_w,
- seed_enable_extras=True,
- sampler_name = processing.get_sampler_name(sampler_index),
- hr_sampler_name = processing.get_sampler_name(hr_sampler_index),
- batch_size=batch_size,
- n_iter=n_iter,
- steps=steps,
- cfg_scale=cfg_scale,
- image_cfg_scale=image_cfg_scale,
- diffusers_guidance_rescale=diffusers_guidance_rescale,
- sag_scale=sag_scale,
- clip_skip=clip_skip,
- width=width,
- height=height,
- full_quality=full_quality,
- restore_faces=restore_faces,
- tiling=tiling,
- enable_hr=enable_hr,
- denoising_strength=denoising_strength,
- hr_scale=hr_scale,
- hr_upscaler=hr_upscaler,
- hr_force=hr_force,
- hr_second_pass_steps=hr_second_pass_steps,
- hr_resize_x=hr_resize_x,
- hr_resize_y=hr_resize_y,
- refiner_steps=refiner_steps,
- refiner_start=refiner_start,
- refiner_prompt=refiner_prompt,
- refiner_negative=refiner_negative,
- hdr_clamp=hdr_clamp, hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold,
- hdr_center=hdr_center, hdr_channel_shift=hdr_channel_shift, hdr_full_shift=hdr_full_shift,
- hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry,
- override_settings=override_settings,
- )
- p.scripts = modules.scripts.scripts_txt2img
- p.script_args = args
- processed = modules.scripts.scripts_txt2img.run(p, *args)
- if processed is None:
- processed = processing.process_images(p)
- p.close()
- if processed is None:
- return [], '', '', 'Error: processing failed'
- generation_info_js = processed.js() if processed is not None else ''
- return processed.images, generation_info_js, processed.info, plaintext_to_html(processed.comments)
+import os
+import modules.scripts
+from modules import shared, processing
+from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules.ui import plaintext_to_html
+
+
+debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PROCESS')
+
+
+def txt2img(id_task,
+ prompt, negative_prompt, prompt_styles,
+ steps, sampler_index, hr_sampler_index,
+ full_quality, restore_faces, tiling,
+ n_iter, batch_size,
+ cfg_scale, image_cfg_scale, diffusers_guidance_rescale, sag_scale,
+ clip_skip,
+ seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
+ height, width,
+ enable_hr, denoising_strength,
+ hr_scale, hr_upscaler, hr_force, hr_second_pass_steps, hr_resize_x, hr_resize_y,
+ refiner_steps, refiner_start, refiner_prompt, refiner_negative,
+ hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry,
+ override_settings_texts,
+ *args):
+
+ debug(f'txt2img: id_task={id_task}|prompt={prompt}|negative={negative_prompt}|styles={prompt_styles}|steps={steps}|sampler_index={sampler_index}|hr_sampler_index={hr_sampler_index}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|batch_count={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|clip_skip={clip_skip}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|height={height}|width={width}|enable_hr={enable_hr}|denoising_strength={denoising_strength}|hr_scale={hr_scale}|hr_upscaler={hr_upscaler}|hr_force={hr_force}|hr_second_pass_steps={hr_second_pass_steps}|hr_resize_x={hr_resize_x}|hr_resize_y={hr_resize_y}|image_cfg_scale={image_cfg_scale}|diffusers_guidance_rescale={diffusers_guidance_rescale}|refiner_steps={refiner_steps}|refiner_start={refiner_start}|refiner_prompt={refiner_prompt}|refiner_negative={refiner_negative}|override_settings={override_settings_texts}')
+
+ if shared.sd_model is None:
+ shared.log.warning('Model not loaded')
+ return [], '', '', 'Error: model not loaded'
+
+ override_settings = create_override_settings_dict(override_settings_texts)
+ if sampler_index is None:
+ sampler_index = 0
+ if hr_sampler_index is None:
+ hr_sampler_index = 0
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ outpath_samples=shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples,
+ outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids,
+ prompt=prompt,
+ styles=prompt_styles,
+ negative_prompt=negative_prompt,
+ seed=seed,
+ subseed=subseed,
+ subseed_strength=subseed_strength,
+ seed_resize_from_h=seed_resize_from_h,
+ seed_resize_from_w=seed_resize_from_w,
+ seed_enable_extras=True,
+ sampler_name = processing.get_sampler_name(sampler_index),
+ hr_sampler_name = processing.get_sampler_name(hr_sampler_index),
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ image_cfg_scale=image_cfg_scale,
+ diffusers_guidance_rescale=diffusers_guidance_rescale,
+ sag_scale=sag_scale,
+ clip_skip=clip_skip,
+ width=width,
+ height=height,
+ full_quality=full_quality,
+ restore_faces=restore_faces,
+ tiling=tiling,
+ enable_hr=enable_hr,
+ denoising_strength=denoising_strength,
+ hr_scale=hr_scale,
+ hr_upscaler=hr_upscaler,
+ hr_force=hr_force,
+ hr_second_pass_steps=hr_second_pass_steps,
+ hr_resize_x=hr_resize_x,
+ hr_resize_y=hr_resize_y,
+ refiner_steps=refiner_steps,
+ refiner_start=refiner_start,
+ refiner_prompt=refiner_prompt,
+ refiner_negative=refiner_negative,
+ hdr_clamp=hdr_clamp, hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold,
+ hdr_center=hdr_center, hdr_channel_shift=hdr_channel_shift, hdr_full_shift=hdr_full_shift,
+ hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry,
+ override_settings=override_settings,
+ )
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
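+ # a custom script may return its own Processed result; fall back to the default pipeline when none does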
+ if processed is None:
+ processed = processing.process_images(p)
+ p.close()
+ if processed is None:
+ return [], '', '', 'Error: processing failed'
+ generation_info_js = processed.js() if processed is not None else ''
+ return processed.images, generation_info_js, processed.info, plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
index 9397a2ef6..1bd7e0c4e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,483 +1,483 @@
-import os
-import mimetypes
-import gradio as gr
-import gradio.routes
-import gradio.utils
-from modules.call_queue import wrap_gradio_call
-from modules import timer, gr_hijack, shared, theme, sd_models, script_callbacks, modelloader, ui_common, ui_loadsave, ui_symbols, ui_javascript, generation_parameters_copypaste
-from modules.ui_components import FormRow
-from modules.paths import script_path, data_path # pylint: disable=unused-import
-from modules.dml import directml_override_opts
-import modules.scripts
-import modules.errors
-
-
-modules.errors.install()
-mimetypes.init()
-mimetypes.add_type('application/javascript', '.js')
-log = shared.log
-opts = shared.opts
-cmd_opts = shared.cmd_opts
-ui_system_tabs = None
-switch_values_symbol = ui_symbols.switch
-detect_image_size_symbol = ui_symbols.detect
-paste_symbol = ui_symbols.paste
-clear_prompt_symbol = ui_symbols.clear
-restore_progress_symbol = ui_symbols.apply
-folder_symbol = ui_symbols.folder
-extra_networks_symbol = ui_symbols.networks
-apply_style_symbol = ui_symbols.apply
-save_style_symbol = ui_symbols.save
-paste_function = None
-gr_hijack.init()
-
-
-if not cmd_opts.share and not cmd_opts.listen:
- # fix gradio phoning home
- gradio.utils.version_check = lambda: None
- gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
-
-
-def gr_show(visible=True):
- return {"visible": visible, "__type__": "update"}
-
-
-def create_output_panel(tabname, outdir): # pylint: disable=unused-argument # outdir is used by extensions
- a, b, c, _d, e = ui_common.create_output_panel(tabname)
- return a, b, c, e
-
-
-def plaintext_to_html(text): # may be referenced by extensions
- return ui_common.plaintext_to_html(text)
-
-
-def infotext_to_html(text): # may be referenced by extensions
- return ui_common.infotext_to_html(text)
-
-
-def send_gradio_gallery_to_image(x):
- if len(x) == 0:
- return None
- return generation_parameters_copypaste.image_from_url_text(x[0])
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- return ui_common.create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id)
-
-
-def connect_clear_prompt(button): # pylint: disable=unused-argument
- pass
-
-
-def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument
- pass
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
- if shared.cmd_opts.freeze:
- return gr.update()
- # dont allow model to be swapped when model hash exists in prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return gr.update()
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
- opts.save(shared.config_filename)
- return getattr(opts, key)
-
-
-def get_value_for_setting(key):
- value = getattr(opts, key)
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
- return gr.update(value=value, **args)
-
-
-def ordered_ui_categories():
- return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
-
-
-def create_ui(startup_timer = None):
- if startup_timer is None:
- timer.startup = timer.Timer()
- ui_javascript.reload_javascript()
- generation_parameters_copypaste.reset()
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- from modules import ui_txt2img
- ui_txt2img.create_ui()
- timer.startup.record("ui-txt2img")
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- from modules import ui_img2img
- ui_img2img.create_ui()
- timer.startup.record("ui-img2img")
-
- modules.scripts.scripts_current = None
-
- with gr.Blocks(analytics_enabled=False) as control_interface:
- if shared.backend == shared.Backend.DIFFUSERS:
- from modules import ui_control
- ui_control.create_ui()
- timer.startup.record("ui-control")
-
- with gr.Blocks(analytics_enabled=False) as extras_interface:
- from modules import ui_postprocessing
- ui_postprocessing.create_ui()
- timer.startup.record("ui-extras")
-
- with gr.Blocks(analytics_enabled=False) as train_interface:
- if shared.backend == shared.Backend.ORIGINAL:
- from modules import ui_train
- ui_train.create_ui()
- timer.startup.record("ui-train")
-
- with gr.Blocks(analytics_enabled=False) as models_interface:
- from modules import ui_models
- ui_models.create_ui()
- timer.startup.record("ui-models")
-
- with gr.Blocks(analytics_enabled=False) as interrogate_interface:
- from modules import ui_interrogate
- ui_interrogate.create_ui()
- timer.startup.record("ui-interrogate")
-
-
- def create_setting_component(key, is_quicksettings=False):
- def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
- info = opts.data_labels[key]
- t = type(info.default)
- args = (info.component_args() if callable(info.component_args) else info.component_args) or {}
- if info.component is not None:
- comp = info.component
- elif t == str:
- comp = gr.Textbox
- elif t == int:
- comp = gr.Number
- elif t == bool:
- comp = gr.Checkbox
- else:
- raise ValueError(f'bad options item type: {t} for key {key}')
- elem_id = f"setting_{key}"
-
- if not is_quicksettings:
- dirtyable_setting = gr.Group(elem_classes="dirtyable", visible=args.get("visible", True))
- dirtyable_setting.__enter__()
- dirty_indicator = gr.Button("", elem_classes="modification-indicator", elem_id="modification_indicator_" + key)
-
- if info.refresh is not None:
- if is_quicksettings:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
- ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- else:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
- ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- elif info.folder is not None:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
- # ui_common.create_browse_button(res, f"folder_{key}")
- else:
- try:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
- except Exception as e:
- log.error(f'Error creating setting: {key} {e}')
- res = None
-
- if res is not None and not is_quicksettings:
- res.change(fn=None, inputs=res, _js=f'(val) => markIfModified("{key}", val)')
- dirty_indicator.click(fn=lambda: getattr(opts, key), outputs=res, show_progress=False)
- dirtyable_setting.__exit__()
-
- return res
-
- def create_dirty_indicator(key, keys_to_reset, **kwargs):
- def get_opt_values():
- return [getattr(opts, _key) for _key in keys_to_reset]
-
- elements_to_reset = [component_dict[_key] for _key in keys_to_reset if component_dict[_key] is not None]
- indicator = gr.Button("", elem_classes="modification-indicator", elem_id=f"modification_indicator_{key}", **kwargs)
- indicator.click(fn=get_opt_values, outputs=elements_to_reset, show_progress=False)
- return indicator
-
- loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config)
- components = []
- component_dict = {}
- shared.settings_components = component_dict
-
- script_callbacks.ui_settings_callback()
- opts.reorder()
-
- def run_settings(*args):
- changed = []
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp == dummy_component or value=='dummy':
- continue
- if not opts.same_type(value, opts.data_labels[key].default):
- log.error(f'Setting bad value: {key}={value} expecting={type(opts.data_labels[key].default).__name__}')
- continue
- if opts.set(key, value):
- changed.append(key)
- if cmd_opts.use_directml:
- directml_override_opts()
- if cmd_opts.use_openvino:
- if not shared.opts.cuda_compile:
- shared.log.warning("OpenVINO: Enabling Torch Compile")
- shared.opts.cuda_compile = True
- if shared.opts.cuda_compile_backend != "openvino_fx":
- shared.log.warning("OpenVINO: Setting Torch Compiler backend to OpenVINO FX")
- shared.opts.cuda_compile_backend = "openvino_fx"
- if shared.opts.sd_backend != "diffusers":
- shared.log.warning("OpenVINO: Setting backend to Diffusers")
- shared.opts.sd_backend = "diffusers"
- try:
- if len(changed) > 0:
- opts.save(shared.config_filename)
- log.info(f'Settings: changed={len(changed)} {changed}')
- except RuntimeError:
- log.error(f'Settings failed: change={len(changed)} {changed}')
- return opts.dumpjson(), f'{len(changed)} Settings changed without save: {", ".join(changed)}'
- return opts.dumpjson(), f'{len(changed)} Settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}'
-
- def run_settings_single(value, key):
- if not opts.same_type(value, opts.data_labels[key].default):
- return gr.update(visible=True), opts.dumpjson()
- if not opts.set(key, value):
- return gr.update(value=getattr(opts, key)), opts.dumpjson()
- if cmd_opts.use_directml:
- directml_override_opts()
- opts.save(shared.config_filename)
- log.debug(f'Setting changed: key={key}, value={value}')
- return get_value_for_setting(key), opts.dumpjson()
-
- with gr.Blocks(analytics_enabled=False) as settings_interface:
- with gr.Row(elem_id="system_row"):
- restart_submit = gr.Button(value="Restart server", variant='primary', elem_id="restart_submit")
- shutdown_submit = gr.Button(value="Shutdown server", variant='primary', elem_id="shutdown_submit")
- unload_sd_model = gr.Button(value='Unload checkpoint', variant='primary', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload checkpoint', variant='primary', elem_id="sett_reload_sd_model")
-
- with gr.Tabs(elem_id="system") as system_tabs:
- global ui_system_tabs # pylint: disable=global-statement
- ui_system_tabs = system_tabs
- with gr.TabItem("Settings", id="system_settings", elem_id="tab_settings"):
- with gr.Row(elem_id="settings_row"):
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- preview_theme = gr.Button(value="Preview theme", variant='primary', elem_id="settings_preview_theme")
- defaults_submit = gr.Button(value="Restore defaults", variant='primary', elem_id="defaults_submit")
- with gr.Row():
- _settings_search = gr.Text(label="Search", elem_id="settings_search")
-
- result = gr.HTML(elem_id="settings_result")
- quicksettings_names = opts.quicksettings_list
- quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
- quicksettings_list = []
-
- previous_section = []
- tab_item_keys = []
- current_tab = None
- current_row = None
- dummy_component = gr.Textbox(visible=False, value='dummy')
- with gr.Tabs(elem_id="settings"):
- for i, (k, item) in enumerate(opts.data_labels.items()):
- section_must_be_skipped = item.section[0] is None
- if previous_section != item.section and not section_must_be_skipped:
- elem_id, text = item.section
- if current_tab is not None and len(previous_section) > 0:
- create_dirty_indicator(previous_section[0], tab_item_keys)
- tab_item_keys = []
- current_row.__exit__()
- current_tab.__exit__()
- current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
- current_tab.__enter__()
- current_row = gr.Column(variant='compact')
- current_row.__enter__()
- previous_section = item.section
- if k in quicksettings_names and not shared.cmd_opts.freeze:
- quicksettings_list.append((i, k, item))
- components.append(dummy_component)
- elif section_must_be_skipped:
- components.append(dummy_component)
- else:
- component = create_setting_component(k)
- component_dict[k] = component
- tab_item_keys.append(k)
- components.append(component)
- if current_tab is not None and len(previous_section) > 0:
- create_dirty_indicator(previous_section[0], tab_item_keys)
- tab_item_keys = []
- current_row.__exit__()
- current_tab.__exit__()
-
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications", visible=False)
- with gr.TabItem("Show all pages", elem_id="settings_show_all_pages"):
- create_dirty_indicator("show_all_pages", [], interactive=False)
-
- with gr.TabItem("User interface", id="system_config", elem_id="tab_config"):
- loadsave.create_ui()
- create_dirty_indicator("tab_defaults", [], interactive=False)
-
- with gr.TabItem("Change log", id="change_log", elem_id="system_tab_changelog"):
- with open('CHANGELOG.md', 'r', encoding='utf-8') as f:
- md = f.read()
- gr.Markdown(md)
-
- with gr.TabItem("Licenses", id="system_licenses", elem_id="system_tab_licenses"):
- gr.HTML(shared.html("licenses.html"), elem_id="licenses", elem_classes="licenses")
- create_dirty_indicator("tab_licenses", [], interactive=False)
-
- def unload_sd_weights():
- modules.sd_models.unload_model_weights(op='model')
- modules.sd_models.unload_model_weights(op='refiner')
-
- def reload_sd_weights():
- modules.sd_models.reload_model_weights()
-
- unload_sd_model.click(fn=unload_sd_weights, inputs=[], outputs=[])
- reload_sd_model.click(fn=reload_sd_weights, inputs=[], outputs=[])
- request_notifications.click(fn=lambda: None, inputs=[], outputs=[], _js='function(){}')
- preview_theme.click(fn=None, _js='previewTheme', inputs=[], outputs=[])
-
- timer.startup.record("ui-settings")
-
- interfaces = []
- interfaces += [(txt2img_interface, "Text", "txt2img")]
- interfaces += [(img2img_interface, "Image", "img2img")]
- interfaces += [(control_interface, "Control", "control")] if control_interface is not None else []
- interfaces += [(extras_interface, "Process", "process")]
- interfaces += [(interrogate_interface, "Interrogate", "interrogate")]
- interfaces += [(train_interface, "Train", "train")]
- interfaces += [(models_interface, "Models", "models")]
- interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "System", "system")]
-
- from modules import ui_extensions
- extensions_interface = ui_extensions.create_ui()
- interfaces += [(extensions_interface, "Extensions", "extensions")]
- timer.startup.record("ui-extensions")
-
- shared.tab_names = []
- for _interface, label, _ifid in interfaces:
- shared.tab_names.append(label)
-
- with gr.Blocks(theme=theme.gradio_theme, analytics_enabled=False, title="SD.Next") as ui_app:
- with gr.Row(elem_id="quicksettings", variant="compact"):
- for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
- component = create_setting_component(k, is_quicksettings=True)
- component_dict[k] = component
-
- generation_parameters_copypaste.connect_paste_params_buttons()
-
- with gr.Tabs(elem_id="tabs") as tabs:
- for interface, label, ifid in interfaces:
- if interface is None:
- continue
- # if label in shared.opts.hidden_tabs or label == '':
- # continue
- with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
- # log.debug(f'UI render: id={ifid}')
- interface.render()
- for interface, _label, ifid in interfaces:
- if interface is None:
- continue
- if ifid in ["extensions", "system"]:
- continue
- loadsave.add_block(interface, ifid)
- loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
- loadsave.setup_ui()
- if opts.notification_audio_enable and os.path.exists(os.path.join(script_path, opts.notification_audio_path)):
- gr.Audio(interactive=False, value=os.path.join(script_path, opts.notification_audio_path), elem_id="audio_notification", visible=False)
-
- text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
- components = [c for c in components if c is not None]
- settings_submit.click(
- fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
- inputs=components,
- outputs=[text_settings, result],
- )
- defaults_submit.click(fn=lambda: shared.restore_defaults(restart=True), _js="restartReload")
- restart_submit.click(fn=lambda: shared.restart_server(restart=True), _js="restartReload")
- shutdown_submit.click(fn=lambda: shared.restart_server(restart=False), _js="restartReload")
-
- for _i, k, _item in quicksettings_list:
- component = component_dict[k]
- info = opts.data_labels[k]
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, text_settings],
- show_progress=info.refresh is not None,
- )
-
- dummy_component = gr.Textbox(visible=False, value='dummy')
- button_set_checkpoint = gr.Button('Change model', elem_id='change_checkpoint', visible=False)
- button_set_checkpoint.click(
- fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
- _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_model_checkpoint'], dummy_component],
- outputs=[component_dict['sd_model_checkpoint'], text_settings],
- )
- button_set_refiner = gr.Button('Change refiner', elem_id='change_refiner', visible=False)
- button_set_refiner.click(
- fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
- _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_model_refiner'], dummy_component],
- outputs=[component_dict['sd_model_refiner'], text_settings],
- )
- button_set_vae = gr.Button('Change VAE', elem_id='change_vae', visible=False)
- button_set_vae.click(
- fn=lambda value, _: run_settings_single(value, key='sd_vae'),
- _js="function(v){ var res = desiredVAEName; desiredVAEName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_vae'], dummy_component],
- outputs=[component_dict['sd_vae'], text_settings],
- )
-
- def reference_submit(model):
- if '@' not in model: # diffusers
- loaded = modelloader.load_reference(model)
- return model if loaded else opts.sd_model_checkpoint
- else: # civitai
- model, url = model.split('@')
- loaded = modelloader.load_civitai(model, url)
- return loaded if loaded is not None else opts.sd_model_checkpoint
-
- button_set_reference = gr.Button('Change reference', elem_id='change_reference', visible=False)
- button_set_reference.click(
- fn=reference_submit,
- _js="function(v){ return desiredCheckpointName; }",
- inputs=[component_dict['sd_model_checkpoint']],
- outputs=[component_dict['sd_model_checkpoint']],
- )
- component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
- def get_settings_values():
- return [get_value_for_setting(key) for key in component_keys]
-
- ui_app.load(
- fn=get_settings_values,
- inputs=[],
- outputs=[component_dict[k] for k in component_keys if component_dict[k] is not None],
- queue=False,
- )
-
- timer.startup.record("ui-defaults")
- loadsave.dump_defaults()
- ui_app.ui_loadsave = loadsave
- return ui_app
+import os
+import mimetypes
+import gradio as gr
+import gradio.routes
+import gradio.utils
+from modules.call_queue import wrap_gradio_call
+from modules import timer, gr_hijack, shared, theme, sd_models, script_callbacks, modelloader, ui_common, ui_loadsave, ui_symbols, ui_javascript, generation_parameters_copypaste
+from modules.ui_components import FormRow
+from modules.paths import script_path, data_path # pylint: disable=unused-import
+from modules.dml import directml_override_opts
+import modules.scripts
+import modules.errors
+
+
+modules.errors.install()
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+log = shared.log
+opts = shared.opts
+cmd_opts = shared.cmd_opts
+ui_system_tabs = None
+switch_values_symbol = ui_symbols.switch
+detect_image_size_symbol = ui_symbols.detect
+paste_symbol = ui_symbols.paste
+clear_prompt_symbol = ui_symbols.clear
+restore_progress_symbol = ui_symbols.apply
+folder_symbol = ui_symbols.folder
+extra_networks_symbol = ui_symbols.networks
+apply_style_symbol = ui_symbols.apply
+save_style_symbol = ui_symbols.save
+paste_function = None
+gr_hijack.init()
+
+
+if not cmd_opts.share and not cmd_opts.listen:
+ # fix gradio phoning home
+ gradio.utils.version_check = lambda: None
+ gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+
+
+def gr_show(visible=True):
+ return {"visible": visible, "__type__": "update"}
+
+
+def create_output_panel(tabname, outdir): # pylint: disable=unused-argument # outdir is used by extensions
+ a, b, c, _d, e = ui_common.create_output_panel(tabname)
+ return a, b, c, e
+
+
+def plaintext_to_html(text): # may be referenced by extensions
+ return ui_common.plaintext_to_html(text)
+
+
+def infotext_to_html(text): # may be referenced by extensions
+ return ui_common.infotext_to_html(text)
+
+
+def send_gradio_gallery_to_image(x):
+ if len(x) == 0:
+ return None
+ return generation_parameters_copypaste.image_from_url_text(x[0])
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ return ui_common.create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id)
+
+
+def connect_clear_prompt(button): # pylint: disable=unused-argument
+ pass
+
+
+def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument
+ pass
+
+
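+# Compatibility helper: apply a single option value by key; returns a no-op gr.update() when the change is rejected (frozen config, hidden component, or unknown checkpoint).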
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+ if shared.cmd_opts.freeze:
+ return gr.update()
+ # dont allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return gr.update()
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data.get(key, None)
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+ opts.save(shared.config_filename)
+ return getattr(opts, key)
+
+
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+ return gr.update(value=value, **args)
+
+
+def ordered_ui_categories():
+ return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
+
+
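+# Build the complete gradio UI: the per-tab interfaces (txt2img, img2img, control, process, interrogate, train, models, extensions), the settings/system tabs and quicksettings, assembled into the main gr.Blocks app.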
+def create_ui(startup_timer = None):
+ if startup_timer is None:
+ timer.startup = timer.Timer()
+ ui_javascript.reload_javascript()
+ generation_parameters_copypaste.reset()
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ from modules import ui_txt2img
+ ui_txt2img.create_ui()
+ timer.startup.record("ui-txt2img")
+
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
+ from modules import ui_img2img
+ ui_img2img.create_ui()
+ timer.startup.record("ui-img2img")
+
+ modules.scripts.scripts_current = None
+
+ with gr.Blocks(analytics_enabled=False) as control_interface:
+ if shared.backend == shared.Backend.DIFFUSERS:
+ from modules import ui_control
+ ui_control.create_ui()
+ timer.startup.record("ui-control")
+
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
+ from modules import ui_postprocessing
+ ui_postprocessing.create_ui()
+ timer.startup.record("ui-extras")
+
+ with gr.Blocks(analytics_enabled=False) as train_interface:
+ if shared.backend == shared.Backend.ORIGINAL:
+ from modules import ui_train
+ ui_train.create_ui()
+ timer.startup.record("ui-train")
+
+ with gr.Blocks(analytics_enabled=False) as models_interface:
+ from modules import ui_models
+ ui_models.create_ui()
+ timer.startup.record("ui-models")
+
+ with gr.Blocks(analytics_enabled=False) as interrogate_interface:
+ from modules import ui_interrogate
+ ui_interrogate.create_ui()
+ timer.startup.record("ui-interrogate")
+
+
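+    # Build the gradio component for a single settings entry: the widget comes from the option definition or is derived from the default value type (str/int/bool), with an optional refresh button and, outside quicksettings, a modification indicator.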
+ def create_setting_component(key, is_quicksettings=False):
+ def fun():
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
+ args = (info.component_args() if callable(info.component_args) else info.component_args) or {}
+ if info.component is not None:
+ comp = info.component
+ elif t == str:
+ comp = gr.Textbox
+ elif t == int:
+ comp = gr.Number
+ elif t == bool:
+ comp = gr.Checkbox
+ else:
+ raise ValueError(f'bad options item type: {t} for key {key}')
+ elem_id = f"setting_{key}"
+
+ if not is_quicksettings:
+ dirtyable_setting = gr.Group(elem_classes="dirtyable", visible=args.get("visible", True))
+ dirtyable_setting.__enter__()
+ dirty_indicator = gr.Button("", elem_classes="modification-indicator", elem_id="modification_indicator_" + key)
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ else:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ elif info.folder is not None:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
+ # ui_common.create_browse_button(res, f"folder_{key}")
+ else:
+ try:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
+ except Exception as e:
+ log.error(f'Error creating setting: {key} {e}')
+ res = None
+
+ if res is not None and not is_quicksettings:
+ res.change(fn=None, inputs=res, _js=f'(val) => markIfModified("{key}", val)')
+ dirty_indicator.click(fn=lambda: getattr(opts, key), outputs=res, show_progress=False)
+ dirtyable_setting.__exit__()
+
+ return res
+
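+    # Per-tab modification indicator; clicking it resets the listed settings components to their currently saved values.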
+ def create_dirty_indicator(key, keys_to_reset, **kwargs):
+ def get_opt_values():
+ return [getattr(opts, _key) for _key in keys_to_reset]
+
+ elements_to_reset = [component_dict[_key] for _key in keys_to_reset if component_dict[_key] is not None]
+ indicator = gr.Button("", elem_classes="modification-indicator", elem_id=f"modification_indicator_{key}", **kwargs)
+ indicator.click(fn=get_opt_values, outputs=elements_to_reset, show_progress=False)
+ return indicator
+
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config)
+ components = []
+ component_dict = {}
+ shared.settings_components = component_dict
+
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
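+    # Validate and apply all submitted settings values, apply backend-specific overrides (DirectML, OpenVINO), then save the config and report what changed.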
+ def run_settings(*args):
+ changed = []
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ if comp == dummy_component or value=='dummy':
+ continue
+ if not opts.same_type(value, opts.data_labels[key].default):
+ log.error(f'Setting bad value: {key}={value} expecting={type(opts.data_labels[key].default).__name__}')
+ continue
+ if opts.set(key, value):
+ changed.append(key)
+ if cmd_opts.use_directml:
+ directml_override_opts()
+ if cmd_opts.use_openvino:
+ if not shared.opts.cuda_compile:
+ shared.log.warning("OpenVINO: Enabling Torch Compile")
+ shared.opts.cuda_compile = True
+ if shared.opts.cuda_compile_backend != "openvino_fx":
+ shared.log.warning("OpenVINO: Setting Torch Compiler backend to OpenVINO FX")
+ shared.opts.cuda_compile_backend = "openvino_fx"
+ if shared.opts.sd_backend != "diffusers":
+ shared.log.warning("OpenVINO: Setting backend to Diffusers")
+ shared.opts.sd_backend = "diffusers"
+ try:
+ if len(changed) > 0:
+ opts.save(shared.config_filename)
+ log.info(f'Settings: changed={len(changed)} {changed}')
+ except RuntimeError:
+ log.error(f'Settings failed: change={len(changed)} {changed}')
+ return opts.dumpjson(), f'{len(changed)} Settings changed without save: {", ".join(changed)}'
+ return opts.dumpjson(), f'{len(changed)} Settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}'
+
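+    # Apply a single setting (used by quicksettings and the change-model/VAE buttons); type-mismatched or rejected values leave the stored option untouched.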
+ def run_settings_single(value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
+ if cmd_opts.use_directml:
+ directml_override_opts()
+ opts.save(shared.config_filename)
+ log.debug(f'Setting changed: key={key}, value={value}')
+ return get_value_for_setting(key), opts.dumpjson()
+
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
+ with gr.Row(elem_id="system_row"):
+ restart_submit = gr.Button(value="Restart server", variant='primary', elem_id="restart_submit")
+ shutdown_submit = gr.Button(value="Shutdown server", variant='primary', elem_id="shutdown_submit")
+ unload_sd_model = gr.Button(value='Unload checkpoint', variant='primary', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload checkpoint', variant='primary', elem_id="sett_reload_sd_model")
+
+ with gr.Tabs(elem_id="system") as system_tabs:
+ global ui_system_tabs # pylint: disable=global-statement
+ ui_system_tabs = system_tabs
+ with gr.TabItem("Settings", id="system_settings", elem_id="tab_settings"):
+ with gr.Row(elem_id="settings_row"):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ preview_theme = gr.Button(value="Preview theme", variant='primary', elem_id="settings_preview_theme")
+ defaults_submit = gr.Button(value="Restore defaults", variant='primary', elem_id="defaults_submit")
+ with gr.Row():
+ _settings_search = gr.Text(label="Search", elem_id="settings_search")
+
+ result = gr.HTML(elem_id="settings_result")
+ quicksettings_names = opts.quicksettings_list
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
+ quicksettings_list = []
+
+ previous_section = []
+ tab_item_keys = []
+ current_tab = None
+ current_row = None
+ dummy_component = gr.Textbox(visible=False, value='dummy')
+ with gr.Tabs(elem_id="settings"):
+ for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
+ if current_tab is not None and len(previous_section) > 0:
+ create_dirty_indicator(previous_section[0], tab_item_keys)
+ tab_item_keys = []
+ current_row.__exit__()
+ current_tab.__exit__()
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
+ current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
+ previous_section = item.section
+ if k in quicksettings_names and not shared.cmd_opts.freeze:
+ quicksettings_list.append((i, k, item))
+ components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ component_dict[k] = component
+ tab_item_keys.append(k)
+ components.append(component)
+ if current_tab is not None and len(previous_section) > 0:
+ create_dirty_indicator(previous_section[0], tab_item_keys)
+ tab_item_keys = []
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications", visible=False)
+ with gr.TabItem("Show all pages", elem_id="settings_show_all_pages"):
+ create_dirty_indicator("show_all_pages", [], interactive=False)
+
+ with gr.TabItem("User interface", id="system_config", elem_id="tab_config"):
+ loadsave.create_ui()
+ create_dirty_indicator("tab_defaults", [], interactive=False)
+
+ with gr.TabItem("Change log", id="change_log", elem_id="system_tab_changelog"):
+ with open('CHANGELOG.md', 'r', encoding='utf-8') as f:
+ md = f.read()
+ gr.Markdown(md)
+
+ with gr.TabItem("Licenses", id="system_licenses", elem_id="system_tab_licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses", elem_classes="licenses")
+ create_dirty_indicator("tab_licenses", [], interactive=False)
+
+ def unload_sd_weights():
+ modules.sd_models.unload_model_weights(op='model')
+ modules.sd_models.unload_model_weights(op='refiner')
+
+ def reload_sd_weights():
+ modules.sd_models.reload_model_weights()
+
+ unload_sd_model.click(fn=unload_sd_weights, inputs=[], outputs=[])
+ reload_sd_model.click(fn=reload_sd_weights, inputs=[], outputs=[])
+ request_notifications.click(fn=lambda: None, inputs=[], outputs=[], _js='function(){}')
+ preview_theme.click(fn=None, _js='previewTheme', inputs=[], outputs=[])
+
+ timer.startup.record("ui-settings")
+
+ interfaces = []
+ interfaces += [(txt2img_interface, "Text", "txt2img")]
+ interfaces += [(img2img_interface, "Image", "img2img")]
+ interfaces += [(control_interface, "Control", "control")] if control_interface is not None else []
+ interfaces += [(extras_interface, "Process", "process")]
+ interfaces += [(interrogate_interface, "Interrogate", "interrogate")]
+ interfaces += [(train_interface, "Train", "train")]
+ interfaces += [(models_interface, "Models", "models")]
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "System", "system")]
+
+ from modules import ui_extensions
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+ timer.startup.record("ui-extensions")
+
+ shared.tab_names = []
+ for _interface, label, _ifid in interfaces:
+ shared.tab_names.append(label)
+
+ with gr.Blocks(theme=theme.gradio_theme, analytics_enabled=False, title="SD.Next") as ui_app:
+ with gr.Row(elem_id="quicksettings", variant="compact"):
+ for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ component = create_setting_component(k, is_quicksettings=True)
+ component_dict[k] = component
+
+ generation_parameters_copypaste.connect_paste_params_buttons()
+
+ with gr.Tabs(elem_id="tabs") as tabs:
+ for interface, label, ifid in interfaces:
+ if interface is None:
+ continue
+ # if label in shared.opts.hidden_tabs or label == '':
+ # continue
+ with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
+ # log.debug(f'UI render: id={ifid}')
+ interface.render()
+ for interface, _label, ifid in interfaces:
+ if interface is None:
+ continue
+ if ifid in ["extensions", "system"]:
+ continue
+ loadsave.add_block(interface, ifid)
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+ loadsave.setup_ui()
+ if opts.notification_audio_enable and os.path.exists(os.path.join(script_path, opts.notification_audio_path)):
+ gr.Audio(interactive=False, value=os.path.join(script_path, opts.notification_audio_path), elem_id="audio_notification", visible=False)
+
+ text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ components = [c for c in components if c is not None]
+ settings_submit.click(
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
+ inputs=components,
+ outputs=[text_settings, result],
+ )
+ defaults_submit.click(fn=lambda: shared.restore_defaults(restart=True), _js="restartReload")
+ restart_submit.click(fn=lambda: shared.restart_server(restart=True), _js="restartReload")
+ shutdown_submit.click(fn=lambda: shared.restart_server(restart=False), _js="restartReload")
+
+ for _i, k, _item in quicksettings_list:
+ component = component_dict[k]
+ info = opts.data_labels[k]
+ change_handler = component.release if hasattr(component, 'release') else component.change
+ change_handler(
+ fn=lambda value, k=k: run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, text_settings],
+ show_progress=info.refresh is not None,
+ )
+
+ dummy_component = gr.Textbox(visible=False, value='dummy')
+ button_set_checkpoint = gr.Button('Change model', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+ outputs=[component_dict['sd_model_checkpoint'], text_settings],
+ )
+ button_set_refiner = gr.Button('Change refiner', elem_id='change_refiner', visible=False)
+ button_set_refiner.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_refiner'], dummy_component],
+ outputs=[component_dict['sd_model_refiner'], text_settings],
+ )
+ button_set_vae = gr.Button('Change VAE', elem_id='change_vae', visible=False)
+ button_set_vae.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_vae'),
+ _js="function(v){ var res = desiredVAEName; desiredVAEName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_vae'], dummy_component],
+ outputs=[component_dict['sd_vae'], text_settings],
+ )
+
+ def reference_submit(model):
+ if '@' not in model: # diffusers
+ loaded = modelloader.load_reference(model)
+ return model if loaded else opts.sd_model_checkpoint
+ else: # civitai
+ model, url = model.split('@')
+ loaded = modelloader.load_civitai(model, url)
+ return loaded if loaded is not None else opts.sd_model_checkpoint
+
+ button_set_reference = gr.Button('Change reference', elem_id='change_reference', visible=False)
+ button_set_reference.click(
+ fn=reference_submit,
+ _js="function(v){ return desiredCheckpointName; }",
+ inputs=[component_dict['sd_model_checkpoint']],
+ outputs=[component_dict['sd_model_checkpoint']],
+ )
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [get_value_for_setting(key) for key in component_keys]
+
+ ui_app.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys if component_dict[k] is not None],
+ queue=False,
+ )
+
+ timer.startup.record("ui-defaults")
+ loadsave.dump_defaults()
+ ui_app.ui_loadsave = loadsave
+ return ui_app
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 4bd40b50a..84e374f48 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -1,380 +1,380 @@
-import json
-import html
-import os
-import shutil
-import platform
-import subprocess
-from functools import reduce
-import gradio as gr
-from modules import call_queue, shared, prompt_parser
-from modules import generation_parameters_copypaste
-from modules import ui_sections
-from modules.ui_components import FormRow, ToolButton
-import modules.ui_symbols as symbols
-import modules.images
-import modules.script_callbacks
-
-
-folder_symbol = symbols.folder
-debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: PASTE')
-
-
-def gr_show(visible=True):
- return {"visible": visible, "__type__": "update"}
-
-
-def update_generation_info(generation_info, html_info, img_index):
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info, generation_info
- infotext = generation_info["infotexts"][img_index]
- html_info_formatted = infotext_to_html(infotext)
- return html_info, html_info_formatted
- except Exception:
- pass
- return html_info, html_info
-
-
-def plaintext_to_html(text):
-    res = '<p>' + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + '</p>'
- return res
-
-
-def infotext_to_html(text):
- res = generation_parameters_copypaste.parse_generation_parameters(text)
- prompt = res.get('Prompt', '')
- negative = res.get('Negative prompt', '')
- res.pop('Prompt', None)
- res.pop('Negative prompt', None)
- params = [f'{k}: {v}' for k, v in res.items() if v is not None]
- params = '| '.join(params) if len(params) > 0 else ''
- code = f'''
- Prompt: {html.escape(prompt)}
- Negative: {html.escape(negative)}
- Parameters: {html.escape(params)}
- '''
- return code
-
-
-def delete_files(js_data, images, _html_info, index):
- try:
- data = json.loads(js_data)
- except Exception:
- data = { 'index_of_first_image': 0 }
- start_index = 0
- if index > -1 and shared.opts.save_selected_only and (index >= data['index_of_first_image']):
- images = [images[index]]
- start_index = index
- filenames = []
- fullfns = []
- for _image_index, filedata in enumerate(images, start_index):
- if 'name' in filedata and os.path.isfile(filedata['name']):
- fullfn = filedata['name']
- filenames.append(os.path.basename(fullfn))
- try:
- os.remove(fullfn)
- base, _ext = os.path.splitext(fullfn)
- desc = f'{base}.txt'
- if os.path.exists(desc):
- os.remove(desc)
- fullfns.append(fullfn)
- shared.log.info(f"Deleting image: {fullfn}")
- except Exception as e:
- shared.log.error(f'Error deleting file: {fullfn} {e}')
- images = [image for image in images if image['name'] not in fullfns]
- return images, plaintext_to_html(f"Deleted: {filenames[0] if len(filenames) > 0 else 'none'}")
-
-
-def save_files(js_data, images, html_info, index):
- os.makedirs(shared.opts.outdir_save, exist_ok=True)
-
- class PObject: # pylint: disable=too-few-public-methods
- def __init__(self, d=None):
- if d is not None:
- for k, v in d.items():
- setattr(self, k, v)
- self.prompt = getattr(self, 'prompt', None) or getattr(self, 'Prompt', None)
- self.all_prompts = getattr(self, 'all_prompts', [self.prompt])
- self.negative_prompt = getattr(self, 'negative_prompt', None)
- self.all_negative_prompt = getattr(self, 'all_negative_prompts', [self.negative_prompt])
- self.seed = getattr(self, 'seed', None) or getattr(self, 'Seed', None)
- self.all_seeds = getattr(self, 'all_seeds', [self.seed])
- self.subseed = getattr(self, 'subseed', None)
- self.all_subseeds = getattr(self, 'all_subseeds', [self.subseed])
- self.width = getattr(self, 'width', None)
- self.height = getattr(self, 'height', None)
- self.index_of_first_image = getattr(self, 'index_of_first_image', 0)
- self.infotexts = getattr(self, 'infotexts', [html_info])
- self.infotext = self.infotexts[0] if len(self.infotexts) > 0 else html_info
- self.outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids
- try:
- data = json.loads(js_data)
- except Exception:
- data = {}
- p = PObject(data)
- start_index = 0
- if index > -1 and shared.opts.save_selected_only and (index >= p.index_of_first_image): # ensures we are looking at a specific non-grid picture, and we have save_selected_only # pylint: disable=no-member
- images = [images[index]]
- start_index = index
- filenames = []
- fullfns = []
- for image_index, filedata in enumerate(images, start_index):
- is_grid = image_index < p.index_of_first_image # pylint: disable=no-member
- i = 0 if is_grid else (image_index - p.index_of_first_image) # pylint: disable=no-member
- while len(p.all_seeds) <= i:
- p.all_seeds.append(p.seed)
- while len(p.all_prompts) <= i:
- p.all_prompts.append(p.prompt)
- while len(p.infotexts) <= i:
- p.infotexts.append(p.infotext)
- if 'name' in filedata and ('tmp' not in filedata['name']) and os.path.isfile(filedata['name']):
- fullfn = filedata['name']
- filenames.append(os.path.basename(fullfn))
- fullfns.append(fullfn)
- destination = shared.opts.outdir_save
- namegen = modules.images.FilenameGenerator(p, seed=p.all_seeds[i], prompt=p.all_prompts[i], image=None) # pylint: disable=no-member
- dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
- destination = os.path.join(destination, dirname)
- destination = namegen.sanitize(destination)
- os.makedirs(destination, exist_ok = True)
- shutil.copy(fullfn, destination)
- shared.log.info(f'Copying image: file="{fullfn}" folder="{destination}"')
- tgt_filename = os.path.join(destination, os.path.basename(fullfn))
- if shared.opts.save_txt:
- try:
- from PIL import Image
- image = Image.open(fullfn)
- info, _ = images.read_info_from_image(image)
- filename_txt = f"{os.path.splitext(tgt_filename)[0]}.txt"
- with open(filename_txt, "w", encoding="utf8") as file:
- file.write(f"{info}\n")
- shared.log.debug(f'Saving: text="{filename_txt}"')
- except Exception as e:
- shared.log.warning(f'Image description save failed: {filename_txt} {e}')
- modules.script_callbacks.image_save_btn_callback(tgt_filename)
- else:
- image = generation_parameters_copypaste.image_from_url_text(filedata)
- info = p.infotexts[i + 1] if len(p.infotexts) > len(p.all_seeds) else p.infotexts[i] # infotexts may be offset by 1 because the first image is the grid
- fullfn, txt_fullfn = modules.images.save_image(image, shared.opts.outdir_save, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], info=info, extension=shared.opts.samples_format, grid=is_grid, p=p)
- if fullfn is None:
- continue
- filename = os.path.relpath(fullfn, shared.opts.outdir_save)
- filenames.append(filename)
- fullfns.append(fullfn)
- if txt_fullfn:
- filenames.append(os.path.basename(txt_fullfn))
- # fullfns.append(txt_fullfn)
- modules.script_callbacks.image_save_btn_callback(filename)
- if shared.opts.samples_save_zip and len(fullfns) > 1:
- zip_filepath = os.path.join(shared.opts.outdir_save, "images.zip")
- from zipfile import ZipFile
- with ZipFile(zip_filepath, "w") as zip_file:
- for i in range(len(fullfns)):
- if os.path.isfile(fullfns[i]):
- with open(fullfns[i], mode="rb") as f:
- zip_file.writestr(filenames[i], f.read())
- fullfns.insert(0, zip_filepath)
- return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0] if len(filenames) > 0 else 'none'}")
-
-
-def open_folder(result_gallery, gallery_index = 0):
- try:
- folder = os.path.dirname(result_gallery[gallery_index]['name'])
- except Exception:
- folder = shared.opts.outdir_samples
- if not os.path.exists(folder):
- shared.log.warning(f'Folder open: folder={folder} does not exist')
- return
- elif not os.path.isdir(folder):
- shared.log.warning(f"Folder open: folder={folder} not a folder")
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(folder)
- if platform.system() == "Windows":
- os.startfile(path) # pylint: disable=no-member
- elif platform.system() == "Darwin":
- subprocess.Popen(["open", path]) # pylint: disable=consider-using-with
- elif "microsoft-standard-WSL2" in platform.uname().release:
- subprocess.Popen(["wsl-open", path]) # pylint: disable=consider-using-with
- else:
- subprocess.Popen(["xdg-open", path]) # pylint: disable=consider-using-with
-
-
-def interrogate_clip(image):
- if image is None:
- shared.log.error("Interrogate: no image selected")
- return gr.update()
- prompt = shared.interrogator.interrogate(image)
- return gr.update() if prompt is None else prompt
-
-
-def interrogate_booru(image):
- if image is None:
- shared.log.error("Interrogate: no image selected")
- return gr.update()
- from modules import deepbooru
- prompt = deepbooru.model.tag(image)
- return gr.update() if prompt is None else prompt
-
-
-def create_output_panel(tabname, preview=True, prompt=None):
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- if tabname == "txt2img":
- gr.HTML(value="", elem_id="main_info", visible=False, elem_classes=["main-info"])
- # columns are for <576px, <768px, <992px, <1200px, <1400px, >1400px
- result_gallery = gr.Gallery(value=[], label='Output', show_label=False, show_download_button=True, allow_preview=True, elem_id=f"{tabname}_gallery", container=False, preview=preview, columns=5, object_fit='scale-down', height=shared.opts.gallery_height or None)
- if prompt is not None:
- interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons('control')
- interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt])
- interrogate_booru_btn.click(fn=interrogate_booru, inputs=[result_gallery], outputs=[prompt])
-
-
- with gr.Column(elem_id=f"{tabname}_footer", elem_classes="gallery_footer"):
- dummy_component = gr.Label(visible=False)
- with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
- if not shared.cmd_opts.listen:
- open_folder_button = gr.Button('Show', visible=not shared.cmd_opts.hide_ui_dir_config, elem_id=f'open_folder_{tabname}')
- open_folder_button.click(open_folder, _js="(gallery, dummy) => [gallery, selected_gallery_index()]", inputs=[result_gallery, dummy_component], outputs=[])
- else:
- clip_files = gr.Button('Copy', elem_id=f'open_folder_{tabname}')
- clip_files.click(fn=None, _js='clip_gallery_urls', inputs=[result_gallery], outputs=[])
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- delete = gr.Button('Delete', elem_id=f'delete_{tabname}')
- if shared.backend == shared.Backend.ORIGINAL:
- buttons = generation_parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
- else:
- buttons = generation_parameters_copypaste.create_buttons(["img2img", "inpaint", "control", "extras"])
-
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext", visible=False) # contains raw infotext as returned by wrapped call
- html_info_formatted = gr.HTML(elem_id=f'html_info_formatted_{tabname}', elem_classes="infotext", visible=True) # contains html formatted infotext
- html_info.change(fn=infotext_to_html, inputs=[html_info], outputs=[html_info_formatted], show_progress=False)
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
-
- generation_info_button.click(fn=update_generation_info, _js="(x, y, z) => [x, y, selected_gallery_index()]", show_progress=False, # triggered on gallery change from js
- inputs=[generation_info, html_info, html_info],
- outputs=[html_info, html_info_formatted],
- )
- save.click(fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, i) => [x, y, z, selected_gallery_index()]", show_progress=False,
- inputs=[generation_info, result_gallery, html_info, html_info],
- outputs=[download_files, html_log],
- )
- delete.click(fn=call_queue.wrap_gradio_call(delete_files), _js="(x, y, z, i) => [x, y, z, selected_gallery_index()]",
- inputs=[generation_info, result_gallery, html_info, html_info],
- outputs=[result_gallery, html_log],
- )
-
- if tabname == "txt2img":
- paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
- elif tabname == "img2img":
- paste_field_names = modules.scripts.scripts_img2img.paste_field_names
- elif tabname == "control":
- paste_field_names = modules.scripts.scripts_control.paste_field_names
- else:
- paste_field_names = []
- for paste_tabname, paste_button in buttons.items():
- debug(f'Create output panel: button={paste_button} tabname={paste_tabname}')
- bindings = generation_parameters_copypaste.ParamBinding(paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names)
- generation_parameters_copypaste.register_paste_params_button(bindings)
- return result_gallery, generation_info, html_info, html_info_formatted, html_log
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id, visible: bool = True):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
- for k, v in args.items():
- setattr(refresh_component, k, v)
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=symbols.refresh, elem_id=elem_id, visible=visible)
- refresh_button.click(fn=refresh, inputs=[], outputs=[refresh_component])
- return refresh_button
-
-
-def create_browse_button(browse_component, elem_id):
- def browse(folder):
- # import subprocess
- if folder is not None:
- return gr.update(value = folder)
- return gr.update()
-
- browse_button = ToolButton(value=symbols.folder, elem_id=elem_id)
- browse_button.click(fn=browse, _js="async () => await browseFolder()", inputs=[browse_component], outputs=[browse_component])
- # browse_button.click(fn=browse, inputs=[browse_component], outputs=[browse_component])
- return browse_button
-
-
-def create_override_inputs(tab): # pylint: disable=unused-argument
- with FormRow(elem_id=f"{tab}_override_settings_row"):
- override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True)
- override_settings.change(fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), inputs=[override_settings], outputs=[override_settings])
- return override_settings
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index: int):
- res = -1
- try:
- gen_info = json.loads(gen_info_string)
- shared.log.debug(f'Reuse: info={gen_info}')
- index -= gen_info.get('index_of_first_image', 0)
- index = int(index)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
- except json.decoder.JSONDecodeError:
- if gen_info_string != '':
- shared.log.error(f"Error parsing JSON generation info: {gen_info_string}")
- return [res, gr_show(False)]
-
- dummy_component = gr.Number(visible=False, value=0)
- reuse_seed.click(fn=copy_seed, _js="(x, y) => [x, selected_gallery_index()]", show_progress=False, inputs=[generation_info, dummy_component], outputs=[seed, dummy_component])
-
-
-def update_token_counter(text, steps):
- from modules import extra_networks, sd_hijack
- try:
- text, _ = extra_networks.parse_prompt(text)
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
- except Exception:
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- if shared.backend == shared.Backend.ORIGINAL:
- token_count, max_length = max([sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- elif shared.backend == shared.Backend.DIFFUSERS:
- if shared.sd_model is not None and hasattr(shared.sd_model, 'tokenizer'):
- tokenizer = shared.sd_model.tokenizer
- if tokenizer is None:
- token_count = 0
- max_length = 75
- else:
- has_bos_token = tokenizer.bos_token_id is not None
- has_eos_token = tokenizer.eos_token_id is not None
- ids = [shared.sd_model.tokenizer(prompt) for prompt in prompts]
- if len(ids) > 0 and hasattr(ids[0], 'input_ids'):
- ids = [x.input_ids for x in ids]
- token_count = max([len(x) for x in ids]) - int(has_bos_token) - int(has_eos_token)
- max_length = tokenizer.model_max_length - int(has_bos_token) - int(has_eos_token)
- else:
- token_count = 0
- max_length = 75
- return f"{token_count}/{max_length} "
+import json
+import html
+import os
+import shutil
+import platform
+import subprocess
+from functools import reduce
+import gradio as gr
+from modules import call_queue, shared, prompt_parser
+from modules import generation_parameters_copypaste
+from modules import ui_sections
+from modules.ui_components import FormRow, ToolButton
+import modules.ui_symbols as symbols
+import modules.images
+import modules.script_callbacks
+
+
+folder_symbol = symbols.folder
+debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: PASTE')
+
+
+def gr_show(visible=True):
+ return {"visible": visible, "__type__": "update"}
+
+
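+# Resolve the infotext of the currently selected gallery image and return it formatted as HTML.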
+def update_generation_info(generation_info, html_info, img_index):
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info, generation_info
+ infotext = generation_info["infotexts"][img_index]
+ html_info_formatted = infotext_to_html(infotext)
+ return html_info, html_info_formatted
+ except Exception:
+ pass
+ return html_info, html_info
+
+
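+# Escape plain text and wrap it as minimal HTML: one paragraph with <br> line breaks.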
+def plaintext_to_html(text):
+    res = '<p>' + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + '</p>'
+ return res
+
+
+def infotext_to_html(text):
+ res = generation_parameters_copypaste.parse_generation_parameters(text)
+ prompt = res.get('Prompt', '')
+ negative = res.get('Negative prompt', '')
+ res.pop('Prompt', None)
+ res.pop('Negative prompt', None)
+ params = [f'{k}: {v}' for k, v in res.items() if v is not None]
+ params = '| '.join(params) if len(params) > 0 else ''
+ code = f'''
+ Prompt: {html.escape(prompt)}
+ Negative: {html.escape(negative)}
+ Parameters: {html.escape(params)}
+ '''
+ return code
+
+
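+# Delete the selected gallery image(s) from disk along with any sidecar .txt description files.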
+def delete_files(js_data, images, _html_info, index):
+ try:
+ data = json.loads(js_data)
+ except Exception:
+ data = { 'index_of_first_image': 0 }
+ start_index = 0
+ if index > -1 and shared.opts.save_selected_only and (index >= data['index_of_first_image']):
+ images = [images[index]]
+ start_index = index
+ filenames = []
+ fullfns = []
+ for _image_index, filedata in enumerate(images, start_index):
+ if 'name' in filedata and os.path.isfile(filedata['name']):
+ fullfn = filedata['name']
+ filenames.append(os.path.basename(fullfn))
+ try:
+ os.remove(fullfn)
+ base, _ext = os.path.splitext(fullfn)
+ desc = f'{base}.txt'
+ if os.path.exists(desc):
+ os.remove(desc)
+ fullfns.append(fullfn)
+ shared.log.info(f"Deleting image: {fullfn}")
+ except Exception as e:
+ shared.log.error(f'Error deleting file: {fullfn} {e}')
+ images = [image for image in images if image['name'] not in fullfns]
+ return images, plaintext_to_html(f"Deleted: {filenames[0] if len(filenames) > 0 else 'none'}")
+
+
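+# Save or copy the selected gallery image(s) into opts.outdir_save, optionally writing .txt infotext sidecars and bundling multiple files into a zip archive.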
+def save_files(js_data, images, html_info, index):
+ os.makedirs(shared.opts.outdir_save, exist_ok=True)
+
+ class PObject: # pylint: disable=too-few-public-methods
+ def __init__(self, d=None):
+ if d is not None:
+ for k, v in d.items():
+ setattr(self, k, v)
+ self.prompt = getattr(self, 'prompt', None) or getattr(self, 'Prompt', None)
+ self.all_prompts = getattr(self, 'all_prompts', [self.prompt])
+ self.negative_prompt = getattr(self, 'negative_prompt', None)
+ self.all_negative_prompt = getattr(self, 'all_negative_prompts', [self.negative_prompt])
+ self.seed = getattr(self, 'seed', None) or getattr(self, 'Seed', None)
+ self.all_seeds = getattr(self, 'all_seeds', [self.seed])
+ self.subseed = getattr(self, 'subseed', None)
+ self.all_subseeds = getattr(self, 'all_subseeds', [self.subseed])
+ self.width = getattr(self, 'width', None)
+ self.height = getattr(self, 'height', None)
+ self.index_of_first_image = getattr(self, 'index_of_first_image', 0)
+ self.infotexts = getattr(self, 'infotexts', [html_info])
+ self.infotext = self.infotexts[0] if len(self.infotexts) > 0 else html_info
+ self.outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids
+ try:
+ data = json.loads(js_data)
+ except Exception:
+ data = {}
+ p = PObject(data)
+ start_index = 0
+ if index > -1 and shared.opts.save_selected_only and (index >= p.index_of_first_image): # ensures we are looking at a specific non-grid picture, and we have save_selected_only # pylint: disable=no-member
+ images = [images[index]]
+ start_index = index
+ filenames = []
+ fullfns = []
+ for image_index, filedata in enumerate(images, start_index):
+ is_grid = image_index < p.index_of_first_image # pylint: disable=no-member
+ i = 0 if is_grid else (image_index - p.index_of_first_image) # pylint: disable=no-member
+ while len(p.all_seeds) <= i:
+ p.all_seeds.append(p.seed)
+ while len(p.all_prompts) <= i:
+ p.all_prompts.append(p.prompt)
+ while len(p.infotexts) <= i:
+ p.infotexts.append(p.infotext)
+ if 'name' in filedata and ('tmp' not in filedata['name']) and os.path.isfile(filedata['name']):
+ fullfn = filedata['name']
+ filenames.append(os.path.basename(fullfn))
+ fullfns.append(fullfn)
+ destination = shared.opts.outdir_save
+ namegen = modules.images.FilenameGenerator(p, seed=p.all_seeds[i], prompt=p.all_prompts[i], image=None) # pylint: disable=no-member
+ dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
+ destination = os.path.join(destination, dirname)
+ destination = namegen.sanitize(destination)
+ os.makedirs(destination, exist_ok = True)
+ shutil.copy(fullfn, destination)
+ shared.log.info(f'Copying image: file="{fullfn}" folder="{destination}"')
+ tgt_filename = os.path.join(destination, os.path.basename(fullfn))
+ if shared.opts.save_txt:
+ try:
+ from PIL import Image
+ filename_txt = f"{os.path.splitext(tgt_filename)[0]}.txt" # computed before opening the image so the warning below can always reference it
+ image = Image.open(fullfn)
+ info, _ = modules.images.read_info_from_image(image) # use the images module, not the shadowing `images` argument (the gallery list)
+ with open(filename_txt, "w", encoding="utf8") as file:
+ file.write(f"{info}\n")
+ shared.log.debug(f'Saving: text="{filename_txt}"')
+ except Exception as e:
+ shared.log.warning(f'Image description save failed: {filename_txt} {e}')
+ modules.script_callbacks.image_save_btn_callback(tgt_filename)
+ else:
+ image = generation_parameters_copypaste.image_from_url_text(filedata)
+ info = p.infotexts[i + 1] if len(p.infotexts) > len(p.all_seeds) else p.infotexts[i] # infotexts may be offset by 1 because the first image is the grid
+ fullfn, txt_fullfn = modules.images.save_image(image, shared.opts.outdir_save, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], info=info, extension=shared.opts.samples_format, grid=is_grid, p=p)
+ if fullfn is None:
+ continue
+ filename = os.path.relpath(fullfn, shared.opts.outdir_save)
+ filenames.append(filename)
+ fullfns.append(fullfn)
+ if txt_fullfn:
+ filenames.append(os.path.basename(txt_fullfn))
+ # fullfns.append(txt_fullfn)
+ modules.script_callbacks.image_save_btn_callback(filename)
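+ # when enabled, bundle everything that was saved into a single zip and put it first in the returned download list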
+ if shared.opts.samples_save_zip and len(fullfns) > 1:
+ zip_filepath = os.path.join(shared.opts.outdir_save, "images.zip")
+ from zipfile import ZipFile
+ with ZipFile(zip_filepath, "w") as zip_file:
+ for i in range(len(fullfns)):
+ if os.path.isfile(fullfns[i]):
+ with open(fullfns[i], mode="rb") as f:
+ zip_file.writestr(filenames[i], f.read())
+ fullfns.insert(0, zip_filepath)
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0] if len(filenames) > 0 else 'none'}")
+
+
+def open_folder(result_gallery, gallery_index = 0):
+ try:
+ folder = os.path.dirname(result_gallery[gallery_index]['name'])
+ except Exception:
+ folder = shared.opts.outdir_samples
+ if not os.path.exists(folder):
+ shared.log.warning(f'Folder open: folder={folder} does not exist')
+ return
+ elif not os.path.isdir(folder):
+ shared.log.warning(f"Folder open: folder={folder} not a folder")
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(folder)
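+ # hand the folder to the platform's native opener: Explorer on Windows, `open` on macOS, wsl-open under WSL2, xdg-open elsewhere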
+ if platform.system() == "Windows":
+ os.startfile(path) # pylint: disable=no-member
+ elif platform.system() == "Darwin":
+ subprocess.Popen(["open", path]) # pylint: disable=consider-using-with
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ subprocess.Popen(["wsl-open", path]) # pylint: disable=consider-using-with
+ else:
+ subprocess.Popen(["xdg-open", path]) # pylint: disable=consider-using-with
+
+
+def interrogate_clip(image):
+ if image is None:
+ shared.log.error("Interrogate: no image selected")
+ return gr.update()
+ prompt = shared.interrogator.interrogate(image)
+ return gr.update() if prompt is None else prompt
+
+
+def interrogate_booru(image):
+ if image is None:
+ shared.log.error("Interrogate: no image selected")
+ return gr.update()
+ from modules import deepbooru
+ prompt = deepbooru.model.tag(image)
+ return gr.update() if prompt is None else prompt
+
+
+def create_output_panel(tabname, preview=True, prompt=None):
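+ # builds the shared results panel: output gallery, save/delete/show buttons, infotext fields and the send-to-tab buttons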
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ if tabname == "txt2img":
+ gr.HTML(value="", elem_id="main_info", visible=False, elem_classes=["main-info"])
+ # columns are for <576px, <768px, <992px, <1200px, <1400px, >1400px
+ result_gallery = gr.Gallery(value=[], label='Output', show_label=False, show_download_button=True, allow_preview=True, elem_id=f"{tabname}_gallery", container=False, preview=preview, columns=5, object_fit='scale-down', height=shared.opts.gallery_height or None)
+ if prompt is not None:
+ interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons('control')
+ interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt])
+ interrogate_booru_btn.click(fn=interrogate_booru, inputs=[result_gallery], outputs=[prompt])
+
+
+ with gr.Column(elem_id=f"{tabname}_footer", elem_classes="gallery_footer"):
+ dummy_component = gr.Label(visible=False)
+ with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
+ if not shared.cmd_opts.listen:
+ open_folder_button = gr.Button('Show', visible=not shared.cmd_opts.hide_ui_dir_config, elem_id=f'open_folder_{tabname}')
+ open_folder_button.click(open_folder, _js="(gallery, dummy) => [gallery, selected_gallery_index()]", inputs=[result_gallery, dummy_component], outputs=[])
+ else:
+ clip_files = gr.Button('Copy', elem_id=f'open_folder_{tabname}')
+ clip_files.click(fn=None, _js='clip_gallery_urls', inputs=[result_gallery], outputs=[])
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ delete = gr.Button('Delete', elem_id=f'delete_{tabname}')
+ if shared.backend == shared.Backend.ORIGINAL:
+ buttons = generation_parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+ else:
+ buttons = generation_parameters_copypaste.create_buttons(["img2img", "inpaint", "control", "extras"])
+
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext", visible=False) # contains raw infotext as returned by wrapped call
+ html_info_formatted = gr.HTML(elem_id=f'html_info_formatted_{tabname}', elem_classes="infotext", visible=True) # contains html formatted infotext
+ html_info.change(fn=infotext_to_html, inputs=[html_info], outputs=[html_info_formatted], show_progress=False)
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+
+ generation_info_button.click(fn=update_generation_info, _js="(x, y, z) => [x, y, selected_gallery_index()]", show_progress=False, # triggered on gallery change from js
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info_formatted],
+ )
+ save.click(fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, i) => [x, y, z, selected_gallery_index()]", show_progress=False,
+ inputs=[generation_info, result_gallery, html_info, html_info],
+ outputs=[download_files, html_log],
+ )
+ delete.click(fn=call_queue.wrap_gradio_call(delete_files), _js="(x, y, z, i) => [x, y, z, selected_gallery_index()]",
+ inputs=[generation_info, result_gallery, html_info, html_info],
+ outputs=[result_gallery, html_log],
+ )
+
+ if tabname == "txt2img":
+ paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
+ elif tabname == "img2img":
+ paste_field_names = modules.scripts.scripts_img2img.paste_field_names
+ elif tabname == "control":
+ paste_field_names = modules.scripts.scripts_control.paste_field_names
+ else:
+ paste_field_names = []
+ for paste_tabname, paste_button in buttons.items():
+ debug(f'Create output panel: button={paste_button} tabname={paste_tabname}')
+ bindings = generation_parameters_copypaste.ParamBinding(paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names)
+ generation_parameters_copypaste.register_paste_params_button(bindings)
+ return result_gallery, generation_info, html_info, html_info_formatted, html_log
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id, visible: bool = True):
+ def refresh():
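+ # run the supplied refresh method, push the recomputed args onto the component and mirror them in a gr.update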
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=symbols.refresh, elem_id=elem_id, visible=visible)
+ refresh_button.click(fn=refresh, inputs=[], outputs=[refresh_component])
+ return refresh_button
+
+
+def create_browse_button(browse_component, elem_id):
+ def browse(folder):
+ # import subprocess
+ if folder is not None:
+ return gr.update(value = folder)
+ return gr.update()
+
+ browse_button = ToolButton(value=symbols.folder, elem_id=elem_id)
+ browse_button.click(fn=browse, _js="async () => await browseFolder()", inputs=[browse_component], outputs=[browse_component])
+ # browse_button.click(fn=browse, inputs=[browse_component], outputs=[browse_component])
+ return browse_button
+
+
+def create_override_inputs(tab): # pylint: disable=unused-argument
+ with FormRow(elem_id=f"{tab}_override_settings_row"):
+ override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True)
+ override_settings.change(fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), inputs=[override_settings], outputs=[override_settings])
+ return override_settings
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
+ """Connects a 'reuse (sub)seed' button's click event so that it copies the last used
+ (sub)seed value from the generation info to the seed field. If copying a subseed and the subseed strength
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+ def copy_seed(gen_info_string: str, index: int):
+ res = -1
+ try:
+ gen_info = json.loads(gen_info_string)
+ shared.log.debug(f'Reuse: info={gen_info}')
+ index -= gen_info.get('index_of_first_image', 0)
+ index = int(index)
+
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+ all_subseeds = gen_info.get('all_subseeds', [-1])
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+ else:
+ all_seeds = gen_info.get('all_seeds', [-1])
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+ except json.decoder.JSONDecodeError:
+ if gen_info_string != '':
+ shared.log.error(f"Error parsing JSON generation info: {gen_info_string}")
+ return [res, gr_show(False)]
+
+ dummy_component = gr.Number(visible=False, value=0)
+ reuse_seed.click(fn=copy_seed, _js="(x, y) => [x, selected_gallery_index()]", show_progress=False, inputs=[generation_info, dummy_component], outputs=[seed, dummy_component])
+
+
+def update_token_counter(text, steps):
+ from modules import extra_networks, sd_hijack
+ try:
+ text, _ = extra_networks.parse_prompt(text)
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+ except Exception:
+ prompt_schedules = [[[steps, text]]]
+
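+ # flatten every scheduled prompt variant and report the longest token count against the tokenizer limit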
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ token_count, max_length = 0, 75 # safe defaults so the return below never hits an unbound variable when no tokenizer path applies
+ if shared.backend == shared.Backend.ORIGINAL:
+ token_count, max_length = max([sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+ elif shared.backend == shared.Backend.DIFFUSERS:
+ if shared.sd_model is not None and hasattr(shared.sd_model, 'tokenizer'):
+ tokenizer = shared.sd_model.tokenizer
+ if tokenizer is None:
+ token_count = 0
+ max_length = 75
+ else:
+ has_bos_token = tokenizer.bos_token_id is not None
+ has_eos_token = tokenizer.eos_token_id is not None
+ ids = [shared.sd_model.tokenizer(prompt) for prompt in prompts]
+ if len(ids) > 0 and hasattr(ids[0], 'input_ids'):
+ ids = [x.input_ids for x in ids]
+ token_count = max([len(x) for x in ids]) - int(has_bos_token) - int(has_eos_token)
+ max_length = tokenizer.model_max_length - int(has_bos_token) - int(has_eos_token)
+ else:
+ token_count = 0
+ max_length = 75
+ return f"{token_count}/{max_length} "
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 0e5862765..f9aed5c6c 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -1,130 +1,130 @@
-import gradio as gr
-
-
-class FormComponent:
- def get_expected_parent(self):
- return gr.components.Form
-
-
-gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
-
-
-class ToolButton(FormComponent, gr.Button):
- """Small button with single emoji as text, fits inside gradio forms"""
-
- def __init__(self, *args, **kwargs):
- classes = kwargs.pop("elem_classes", [])
- super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
-
- def get_block_name(self):
- return "button"
-
-
-class FormRow(FormComponent, gr.Row):
- """Same as gr.Row but fits inside gradio forms"""
-
- def get_block_name(self):
- return "row"
-
-
-class FormColumn(FormComponent, gr.Column):
- """Same as gr.Column but fits inside gradio forms"""
-
- def get_block_name(self):
- return "column"
-
-
-class FormGroup(FormComponent, gr.Group):
- """Same as gr.Row but fits inside gradio forms"""
-
- def get_block_name(self):
- return "group"
-
-
-class FormHTML(FormComponent, gr.HTML):
- """Same as gr.HTML but fits inside gradio forms"""
-
- def get_block_name(self):
- return "html"
-
-
-class FormColorPicker(FormComponent, gr.ColorPicker):
- """Same as gr.ColorPicker but fits inside gradio forms"""
-
- def get_block_name(self):
- return "colorpicker"
-
-
-class DropdownMulti(FormComponent, gr.Dropdown):
- """Same as gr.Dropdown but always multiselect"""
- def __init__(self, **kwargs):
- super().__init__(multiselect=True, **kwargs)
-
- def get_block_name(self):
- return "dropdown"
-
-
-class DropdownEditable(FormComponent, gr.Dropdown):
- """Same as gr.Dropdown but allows editing value"""
- def __init__(self, **kwargs):
- super().__init__(allow_custom_value=True, **kwargs)
-
- def get_block_name(self):
- return "dropdown"
-
-
-class InputAccordion(gr.Checkbox):
- """A gr.Accordion that can be used as an input - returns True if open, False if closed.
- Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
- """
- global_index = 0
-
- def __init__(self, value, **kwargs):
- self.accordion_id = kwargs.get('elem_id')
- if self.accordion_id is None:
- self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
- InputAccordion.global_index += 1
- kwargs_checkbox = {**kwargs, "elem_id": f"{self.accordion_id}-checkbox", "visible": False}
- super().__init__(value, **kwargs_checkbox)
- self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
- kwargs_accordion = {
- **kwargs,
- "elem_id": self.accordion_id,
- "label": kwargs.get('label', 'Accordion'),
- "elem_classes": ['input-accordion'],
- "open": value,
- }
- self.accordion = gr.Accordion(**kwargs_accordion)
-
- def extra(self):
- """Allows you to put something into the label of the accordion.
- Use it like this:
- ```
- with InputAccordion(False, label="Accordion") as acc:
- with acc.extra():
- FormHTML(value="hello", min_width=0)
- ...
- ```
- """
- return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
-
- def __enter__(self):
- self.accordion.__enter__()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.accordion.__exit__(exc_type, exc_val, exc_tb)
-
- def get_block_name(self):
- return "checkbox"
-
-
-class ResizeHandleRow(gr.Row):
- """Same as gr.Row but fits inside gradio forms"""
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.elem_classes.append("resize-handle-row")
-
- def get_block_name(self):
- return "row"
+import gradio as gr
+
+
+class FormComponent:
+ def get_expected_parent(self):
+ return gr.components.Form
+
+
+gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
+
+
+class ToolButton(FormComponent, gr.Button):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, *args, **kwargs):
+ classes = kwargs.pop("elem_classes", [])
+ super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class FormRow(FormComponent, gr.Row):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "row"
+
+
+class FormColumn(FormComponent, gr.Column):
+ """Same as gr.Column but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "column"
+
+
+class FormGroup(FormComponent, gr.Group):
+ """Same as gr.Group but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "group"
+
+
+class FormHTML(FormComponent, gr.HTML):
+ """Same as gr.HTML but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "html"
+
+
+class FormColorPicker(FormComponent, gr.ColorPicker):
+ """Same as gr.ColorPicker but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "colorpicker"
+
+
+class DropdownMulti(FormComponent, gr.Dropdown):
+ """Same as gr.Dropdown but always multiselect"""
+ def __init__(self, **kwargs):
+ super().__init__(multiselect=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
+
+
+class DropdownEditable(FormComponent, gr.Dropdown):
+ """Same as gr.Dropdown but allows editing value"""
+ def __init__(self, **kwargs):
+ super().__init__(allow_custom_value=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
+
+
+class InputAccordion(gr.Checkbox):
+ """A gr.Accordion that can be used as an input - returns True if open, False if closed.
+ Actually just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
+ """
+ global_index = 0
+
+ def __init__(self, value, **kwargs):
+ self.accordion_id = kwargs.get('elem_id')
+ if self.accordion_id is None:
+ self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
+ InputAccordion.global_index += 1
+ kwargs_checkbox = {**kwargs, "elem_id": f"{self.accordion_id}-checkbox", "visible": False}
+ super().__init__(value, **kwargs_checkbox)
+ self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
+ kwargs_accordion = {
+ **kwargs,
+ "elem_id": self.accordion_id,
+ "label": kwargs.get('label', 'Accordion'),
+ "elem_classes": ['input-accordion'],
+ "open": value,
+ }
+ self.accordion = gr.Accordion(**kwargs_accordion)
+
+ def extra(self):
+ """Allows you to put something into the label of the accordion.
+ Use it like this:
+ ```
+ with InputAccordion(False, label="Accordion") as acc:
+ with acc.extra():
+ FormHTML(value="hello", min_width=0)
+ ...
+ ```
+ """
+ return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
+
+ def __enter__(self):
+ self.accordion.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.accordion.__exit__(exc_type, exc_val, exc_tb)
+
+ def get_block_name(self):
+ return "checkbox"
+
+
+class ResizeHandleRow(gr.Row):
+ """Same as gr.Row but adds the resize-handle-row class so the UI can attach a resize handle"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.elem_classes.append("resize-handle-row")
+
+ def get_block_name(self):
+ return "row"
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 335223824..fbffa8c58 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,827 +1,827 @@
-import io
-import re
-import time
-import json
-import html
-import base64
-import os.path
-import urllib.parse
-import threading
-from datetime import datetime
-from types import SimpleNamespace
-from pathlib import Path
-from html.parser import HTMLParser
-from collections import OrderedDict
-import gradio as gr
-from PIL import Image
-from starlette.responses import FileResponse, JSONResponse
-from modules import paths, shared, scripts, modelloader, errors
-from modules.ui_components import ToolButton
-import modules.ui_symbols as symbols
-
-
-allowed_dirs = []
-refresh_time = 0
-extra_pages = shared.extra_networks
-debug = shared.log.trace if os.environ.get('SD_EN_DEBUG', None) is not None else lambda *args, **kwargs: None
-debug('Trace: EN')
-card_full = '''
-
-
-
-
-
-'''
-card_list = '''
-
-'''
-
-
-def init_api(app):
-
- def fetch_file(filename: str = ""):
- if not os.path.exists(filename):
- return JSONResponse({ "error": f"file {filename}: not found" }, status_code=404)
- if filename.startswith('html/') or filename.startswith('models/'):
- return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
- if not any(Path(folder).absolute() in Path(filename).absolute().parents for folder in allowed_dirs):
- return JSONResponse({ "error": f"file {filename}: must be in one of allowed directories" }, status_code=403)
- if os.path.splitext(filename)[1].lower() not in (".png", ".jpg", ".jpeg", ".webp"):
- return JSONResponse({"error": f"file {filename}: not an image file"}, status_code=403)
- return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
-
- def get_metadata(page: str = "", item: str = ""):
- page = next(iter([x for x in shared.extra_networks if x.name == page]), None)
- if page is None:
- return JSONResponse({ 'metadata': 'none' })
- metadata = page.metadata.get(item, 'none')
- if metadata is None:
- metadata = ''
- # shared.log.debug(f"Extra networks metadata: page='{page}' item={item} len={len(metadata)}")
- return JSONResponse({"metadata": metadata})
-
- def get_info(page: str = "", item: str = ""):
- page = next(iter([x for x in get_pages() if x.name == page]), None)
- if page is None:
- return JSONResponse({ 'info': 'none' })
- item = next(iter([x for x in page.items if x['name'] == item]), None)
- if item is None:
- return JSONResponse({ 'info': 'none' })
- info = page.find_info(item['filename'])
- if info is None:
- info = {}
- # shared.log.debug(f"Extra networks info: page='{page.name}' item={item['name']} len={len(info)}")
- return JSONResponse({"info": info})
-
- def get_desc(page: str = "", item: str = ""):
- page = next(iter([x for x in get_pages() if x.name == page]), None)
- if page is None:
- return JSONResponse({ 'description': 'none' })
- item = next(iter([x for x in page.items if x['name'] == item]), None)
- if item is None:
- return JSONResponse({ 'description': 'none' })
- desc = page.find_description(item['filename'])
- if desc is None:
- desc = ''
- # shared.log.debug(f"Extra networks desc: page='{page.name}' item={item['name']} len={len(desc)}")
- return JSONResponse({"description": desc})
-
- app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
- app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
- app.add_api_route("/sd_extra_networks/info", get_info, methods=["GET"])
- app.add_api_route("/sd_extra_networks/description", get_desc, methods=["GET"])
-
-
-class ExtraNetworksPage:
- def __init__(self, title):
- self.title = title
- self.name = title.lower()
- self.allow_negative_prompt = False
- self.metadata = {}
- self.info = {}
- self.html = ''
- self.items = []
- self.missing_thumbs = []
- self.refresh_time = 0
- self.page_time = 0
- self.list_time = 0
- self.info_time = 0
- self.desc_time = 0
- self.dirs = {}
- self.view = shared.opts.extra_networks_view
- self.card = card_full if shared.opts.extra_networks_view == 'gallery' else card_list
-
- def refresh(self):
- pass
-
- def create_xyz_grid(self):
- xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
-
- def add_prompt(p, opt, x):
- for item in [x for x in self.items if x["name"] == opt]:
- try:
- p.prompt = f'{p.prompt} {eval(item["prompt"])}' # pylint: disable=eval-used
- except Exception as e:
- shared.log.error(f'Cannot evaluate extra network prompt: {item["prompt"]} {e}')
-
- if not any(self.title in x.label for x in xyz_grid.axis_options):
- if self.title == 'Model':
- return
- opt = xyz_grid.AxisOption(f"[Network] {self.title}", str, add_prompt, choices=lambda: [x["name"] for x in self.items])
- xyz_grid.axis_options.append(opt)
-
- def link_preview(self, filename):
- quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
- mtime = os.path.getmtime(filename)
- preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
- return preview
-
- def search_terms_from_path(self, filename):
- return filename.replace('\\', '/')
-
- def is_empty(self, folder):
- for f in shared.listdir(folder):
- _fn, ext = os.path.splitext(f)
- if ext.lower() in ['.ckpt', '.safetensors', '.pt', '.json'] or os.path.isdir(os.path.join(folder, f)):
- return False
- return True
-
- def create_thumb(self):
- debug(f'EN create-thumb: {self.name}')
- created = 0
- for f in self.missing_thumbs:
- if not os.path.exists(f):
- continue
- fn, _ext = os.path.splitext(f)
- fn = fn.replace('.preview', '')
- fn = f'{fn}.thumb.jpg'
- if os.path.exists(fn):
- continue
- img = None
- try:
- img = Image.open(f)
- except Exception:
- img = None
- shared.log.warning(f'Extra network removing invalid image: {f}')
- try:
- if img is None:
- img = None
- os.remove(f)
- elif img.width > 1024 or img.height > 1024 or os.path.getsize(f) > 65536:
- img = img.convert('RGB')
- img.thumbnail((512, 512), Image.Resampling.HAMMING)
- img.save(fn, quality=50)
- img.close()
- created += 1
- except Exception as e:
- shared.log.warning(f'Extra network error creating thumbnail: {f} {e}')
- if created > 0:
- shared.log.info(f"Extra network thumbnails: {self.name} created={created}")
- self.missing_thumbs.clear()
-
- def create_items(self, tabname):
- if self.refresh_time is not None and self.refresh_time > refresh_time: # cached results
- return
- t0 = time.time()
- try:
- self.items = list(self.list_items())
- self.refresh_time = time.time()
- except Exception as e:
- self.items = []
- shared.log.error(f'Extra networks error listing items: class={self.__class__.__name__} tab={tabname} {e}')
- for item in self.items:
- if item is None:
- continue
- self.metadata[item["name"]] = item.get("metadata", {})
- t1 = time.time()
- debug(f'EN create-items: page={self.name} items={len(self.items)} time={t1-t0:.2f}')
- self.list_time += t1-t0
-
-
- def create_page(self, tabname, skip = False):
- debug(f'EN create-page: {self.name}')
- if self.page_time > refresh_time and len(self.html) > 0: # cached page
- return self.html
- self_name_id = self.name.replace(" ", "_")
- if skip:
- return f""
- subdirs = {}
- allowed_folders = [os.path.abspath(x) for x in self.allowed_directories_for_previews()]
- for parentdir, dirs in {d: modelloader.directory_list(d) for d in allowed_folders}.items():
- for tgt in dirs.keys():
- if os.path.join(paths.models_path, 'Reference') in tgt:
- subdirs['Reference'] = 1
- if shared.backend == shared.Backend.DIFFUSERS and shared.opts.diffusers_dir in tgt:
- subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
- if 'models--' in tgt:
- continue
- subdir = tgt[len(parentdir):].replace("\\", "/")
- while subdir.startswith("/"):
- subdir = subdir[1:]
- # if not self.is_empty(tgt):
- if not subdir.startswith("."):
- subdirs[subdir] = 1
- debug(f"Extra networks: page='{self.name}' subfolders={list(subdirs)}")
- subdirs = OrderedDict(sorted(subdirs.items()))
- if self.name == 'model':
- subdirs['Reference'] = 1
- subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
- subdirs.move_to_end(os.path.basename(shared.opts.diffusers_dir))
- subdirs.move_to_end('Reference')
- if self.name == 'style' and shared.opts.extra_networks_styles:
- subdirs['built-in'] = 1
- subdirs_html = "All "
- subdirs_html += "".join([f"{html.escape(subdir)} " for subdir in subdirs if subdir != ''])
- self.html = ''
- self.create_items(tabname)
- self.create_xyz_grid()
- htmls = []
- if len(self.items) > 0 and self.items[0].get('mtime', None) is not None:
- self.items.sort(key=lambda x: x["mtime"], reverse=True)
- for item in self.items:
- htmls.append(self.create_html(item, tabname))
- self.html += ''.join(htmls)
- self.page_time = time.time()
- if len(subdirs_html) > 0 or len(self.html) > 0:
- self.html = f""
- else:
- return ''
- shared.log.debug(f"Extra networks: page='{self.name}' items={len(self.items)} subfolders={len(subdirs)} tab={tabname} folders={self.allowed_directories_for_previews()} list={self.list_time:.2f} desc={self.desc_time:.2f} info={self.info_time:.2f} workers={shared.max_workers}")
- if len(self.missing_thumbs) > 0:
- threading.Thread(target=self.create_thumb).start()
- return self.html
-
- def list_items(self):
- raise NotImplementedError
-
- def allowed_directories_for_previews(self):
- return []
-
- def create_html(self, item, tabname):
- try:
- args = {
- "tabname": tabname,
- "page": self.name,
- "name": item["name"],
- "title": os.path.basename(item["name"].replace('_', ' ')),
- "filename": item["filename"],
- "tags": '|'.join([item.get("tags")] if isinstance(item.get("tags", {}), str) else list(item.get("tags", {}).keys())),
- "preview": html.escape(item.get("preview", self.link_preview('html/card-no-preview.png'))),
- "width": shared.opts.extra_networks_card_size,
- "height": shared.opts.extra_networks_card_size if shared.opts.extra_networks_card_square else 'auto',
- "fit": shared.opts.extra_networks_card_fit,
- "prompt": item.get("prompt", None),
- "search": item.get("search_term", ""),
- "description": item.get("description") or "",
- "card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
- "mtime": item.get("mtime", 0),
- "size": item.get("size", 0),
- }
- alias = item.get("alias", None)
- if alias is not None:
- args['title'] += f'\nAlias: {alias}'
- return self.card.format(**args)
- except Exception as e:
- shared.log.error(f'Extra networks item error: page={tabname} item={item["name"]} {e}')
- return ""
-
- def find_preview_file(self, path):
- if path is None:
- return 'html/card-no-preview.png'
- if shared.opts.diffusers_dir in path:
- path = os.path.relpath(path, shared.opts.diffusers_dir)
- ref = os.path.join('models', 'Reference')
- fn = os.path.join(ref, path.replace('models--', '').replace('\\', '/').split('/')[0])
- files = shared.listdir(ref)
- else:
- files = shared.listdir(os.path.dirname(path))
- fn = os.path.splitext(path)[0]
- exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"]
- for file in [f'{fn}{mid}{ext}' for ext in exts for mid in ['.thumb.', '.', '.preview.']]:
- if file in files:
- if 'Reference' not in file and '.thumb.' not in file:
- self.missing_thumbs.append(file)
- return file
- return 'html/card-no-preview.png'
-
- def find_preview(self, path):
- preview_file = self.find_preview_file(path)
- return self.link_preview(preview_file)
-
- def find_description(self, path, info=None):
- t0 = time.time()
- class HTMLFilter(HTMLParser):
- text = ""
- def handle_data(self, data):
- self.text += data
- def handle_endtag(self, tag):
- if tag == 'p':
- self.text += '\n'
-
- fn = os.path.splitext(path)[0] + '.txt'
- if fn in shared.listdir(os.path.dirname(path)):
- try:
- with open(fn, "r", encoding="utf-8", errors="replace") as f:
- txt = f.read()
- txt = re.sub('[<>]', '', txt)
- return txt
- except OSError:
- pass
- if info is None:
- info = self.find_info(path)
- desc = info.get('description', '') or ''
- f = HTMLFilter()
- f.feed(desc)
- t1 = time.time()
- self.desc_time += t1-t0
- return f.text
-
- def find_info(self, path):
- fn = os.path.splitext(path)[0] + '.json'
- data = {}
- if fn in shared.listdir(os.path.dirname(path)):
- t0 = time.time()
- data = shared.readfile(fn, silent=True)
- if type(data) is list:
- data = data[0]
- t1 = time.time()
- self.info_time += t1-t0
- return data
-
-
-def initialize():
- shared.extra_networks.clear()
-
-
-def register_page(page: ExtraNetworksPage):
- # registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions
- debug(f'EN register-page: {page}')
- if page in shared.extra_networks:
- debug(f'EN register-page: {page} already registered')
- return
- shared.extra_networks.append(page)
- # allowed_dirs.clear()
- # for pg in shared.extra_networks:
- for folder in page.allowed_directories_for_previews():
- if folder not in allowed_dirs:
- allowed_dirs.append(os.path.abspath(folder))
-
-
-def register_pages():
- from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
- from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
- from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
- from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
- from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs
- debug('EN register-pages')
- register_page(ExtraNetworksPageCheckpoints())
- register_page(ExtraNetworksPageStyles())
- register_page(ExtraNetworksPageTextualInversion())
- register_page(ExtraNetworksPageHypernetworks())
- register_page(ExtraNetworksPageVAEs())
-
-
-def get_pages(title=None):
- pages = []
- if 'All' in shared.opts.extra_networks:
- pages = shared.extra_networks
- else:
- titles = [page.title for page in shared.extra_networks]
- if title is None:
- for page in shared.opts.extra_networks:
- try:
- idx = titles.index(page)
- pages.append(shared.extra_networks[idx])
- except ValueError:
- continue
- else:
- try:
- idx = titles.index(title)
- pages.append(shared.extra_networks[idx])
- except ValueError:
- pass
- return pages
-
-
-class ExtraNetworksUi:
- def __init__(self):
- self.tabname: str = None
- self.pages: list(str) = None
- self.visible: gr.State = None
- self.state: gr.Textbox = None
- self.details: gr.Group = None
- self.tabs: gr.Tabs = None
- self.gallery: gr.Gallery = None
- self.description: gr.Textbox = None
- self.search: gr.Textbox = None
- self.button_details: gr.Button = None
- self.button_refresh: gr.Button = None
- self.button_scan: gr.Button = None
- self.button_view: gr.Button = None
- self.button_quicksave: gr.Button = None
- self.button_save: gr.Button = None
- self.button_sort: gr.Button = None
- self.button_apply: gr.Button = None
- self.button_close: gr.Button = None
- self.button_model: gr.Checkbox = None
- self.details_components: list = []
- self.last_item: dict = None
- self.last_page: ExtraNetworksPage = None
- self.state: gr.State = None
-
-
-def create_ui(container, button_parent, tabname, skip_indexing = False):
- debug(f'EN create-ui: {tabname}')
- ui = ExtraNetworksUi()
- ui.tabname = tabname
- ui.pages = []
- ui.state = gr.Textbox('{}', elem_id=f"{tabname}_extra_state", visible=False)
- ui.visible = gr.State(value=False) # pylint: disable=abstract-class-instantiated
- ui.details = gr.Group(elem_id=f"{tabname}_extra_details", visible=False)
- ui.tabs = gr.Tabs(elem_id=f"{tabname}_extra_tabs")
- ui.button_details = gr.Button('Details', elem_id=f"{tabname}_extra_details_btn", visible=False)
- state = {}
- if shared.cmd_opts.profile:
- import cProfile
- pr = cProfile.Profile()
- pr.enable()
-
- def get_item(state, params = None):
- if params is not None and type(params) == dict:
- page = next(iter([x for x in get_pages() if x.title == 'Style']), None)
- item = page.create_style(params)
- else:
- if state is None or not hasattr(state, 'page') or not hasattr(state, 'item'):
- return None, None
- page = next(iter([x for x in get_pages() if x.title == state.page]), None)
- if page is None:
- return None, None
- item = next(iter([x for x in page.items if x["name"] == state.item]), None)
- if item is None:
- return page, None
- item = SimpleNamespace(**item)
- ui.last_item = item
- ui.last_page = page
- return page, item
-
- # main event that is triggered when js updates state text field with json values, used to communicate js -> python
- def state_change(state_text):
- try:
- nonlocal state
- state = SimpleNamespace(**json.loads(state_text))
- except Exception as e:
- shared.log.error(f'Extra networks state error: {e}')
- return
- _page, _item = get_item(state)
- # shared.log.debug(f'Extra network: op={state.op} page={page.title if page is not None else None} item={item.filename if item is not None else None}')
-
- def toggle_visibility(is_visible):
- is_visible = not is_visible
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
-
- with ui.details:
- details_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_details_close", elem_classes=['extra-details-close'])
- details_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[ui.details])
- with gr.Row():
- with gr.Column(scale=1):
- text = gr.HTML('title')
- ui.details_components.append(text)
- with gr.Column(scale=1):
- img = gr.Image(value=None, show_label=False, interactive=False, container=False, show_download_button=False, show_info=False, elem_id=f"{tabname}_extra_details_img", elem_classes=['extra-details-img'])
- ui.details_components.append(img)
- with gr.Row():
- btn_save_img = gr.Button('Replace', elem_classes=['small-button'])
- btn_delete_img = gr.Button('Delete', elem_classes=['small-button'])
- with gr.Tabs():
- with gr.Tab('Description'):
- desc = gr.Textbox('', show_label=False, lines=8, placeholder="Extra network description...")
- ui.details_components.append(desc)
- with gr.Row():
- btn_save_desc = gr.Button('Save', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_save_desc')
- btn_delete_desc = gr.Button('Delete', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_delete_desc')
- btn_close_desc = gr.Button('Close', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_close_desc')
- btn_close_desc.click(fn=lambda: gr.update(visible=False), _js='refeshDetailsEN', inputs=[], outputs=[ui.details])
- with gr.Tab('Model metadata'):
- info = gr.JSON({}, show_label=False)
- ui.details_components.append(info)
- with gr.Row():
- btn_save_info = gr.Button('Save', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_save_info')
- btn_delete_info = gr.Button('Delete', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_delete_info')
- btn_close_info = gr.Button('Close', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_close_info')
- btn_close_info.click(fn=lambda: gr.update(visible=False), _js='refeshDetailsEN', inputs=[], outputs=[ui.details])
- with gr.Tab('Embedded metadata'):
- meta = gr.JSON({}, show_label=False)
- ui.details_components.append(meta)
-
- with ui.tabs:
- def ui_tab_change(page):
- scan_visible = page in ['Model', 'Lora', 'Hypernetwork', 'Embedding']
- save_visible = page in ['Style']
- model_visible = page in ['Model']
- return [gr.update(visible=scan_visible), gr.update(visible=save_visible), gr.update(visible=model_visible)]
-
- ui.button_refresh = ToolButton(symbols.refresh, elem_id=f"{tabname}_extra_refresh")
- ui.button_scan = ToolButton(symbols.scan, elem_id=f"{tabname}_extra_scan", visible=True)
- ui.button_quicksave = ToolButton(symbols.book, elem_id=f"{tabname}_extra_quicksave", visible=False)
- ui.button_save = ToolButton(symbols.book, elem_id=f"{tabname}_extra_save", visible=False)
- ui.button_sort = ToolButton(symbols.sort, elem_id=f"{tabname}_extra_sort", visible=True)
- ui.button_view = ToolButton(symbols.view, elem_id=f"{tabname}_extra_view", visible=True)
- ui.button_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_close", visible=True)
- ui.button_model = ToolButton(symbols.refine, elem_id=f"{tabname}_extra_model", visible=True)
- ui.search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False)
- ui.description = gr.Textbox('', show_label=False, elem_id=f"{tabname}_description", elem_classes="textbox", lines=2, interactive=False, container=False)
-
- if ui.tabname == 'txt2img': # refresh only once
- global refresh_time # pylint: disable=global-statement
- refresh_time = time.time()
- if not skip_indexing:
- threads = []
- for page in get_pages():
- if os.environ.get('SD_EN_DEBUG', None) is not None:
- threads.append(threading.Thread(target=page.create_items, args=[ui.tabname]))
- threads[-1].start()
- else:
- page.create_items(ui.tabname)
- for thread in threads:
- thread.join()
- for page in get_pages():
- page.create_page(ui.tabname, skip_indexing)
- with gr.Tab(page.title, id=page.title.lower().replace(" ", "_"), elem_classes="extra-networks-tab") as tab:
- page_html = gr.HTML(page.html, elem_id=f'{tabname}{page.name}_extra_page', elem_classes="extra-networks-page")
- ui.pages.append(page_html)
- tab.select(ui_tab_change, _js="getENActivePage", inputs=[ui.button_details], outputs=[ui.button_scan, ui.button_save, ui.button_model])
- if shared.cmd_opts.profile:
- errors.profile(pr, 'ExtraNetworks')
- pr.disable()
- # ui.tabs.change(fn=ui_tab_change, inputs=[], outputs=[ui.button_scan, ui.button_save])
-
- def fn_save_img(image):
- if ui.last_item is None or ui.last_item.local_preview is None:
- return 'html/card-no-preview.png'
- images = list(ui.gallery.temp_files) # gallery cannot be used as input component so looking at most recently registered temp files
- if len(images) < 1:
- shared.log.warning(f'Extra network no image: item={ui.last_item.name}')
- return 'html/card-no-preview.png'
- try:
- images.sort(key=lambda f: os.path.getmtime(f), reverse=True)
- image = Image.open(images[0])
- except Exception as e:
- shared.log.error(f'Extra network error opening image: item={ui.last_item.name} {e}')
- return 'html/card-no-preview.png'
- fn_delete_img(image)
- if image.width > 512 or image.height > 512:
- image = image.convert('RGB')
- image.thumbnail((512, 512), Image.Resampling.HAMMING)
- try:
- image.save(ui.last_item.local_preview, quality=50)
- shared.log.debug(f'Extra network save image: item={ui.last_item.name} filename="{ui.last_item.local_preview}"')
- except Exception as e:
- shared.log.error(f'Extra network save image: item={ui.last_item.name} filename="{ui.last_item.local_preview}" {e}')
- return image
-
- def fn_delete_img(_image):
- preview_extensions = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"]
- fn = os.path.splitext(ui.last_item.filename)[0]
- for file in [f'{fn}{mid}{ext}' for ext in preview_extensions for mid in ['.thumb.', '.preview.', '.']]:
- if os.path.exists(file):
- os.remove(file)
- shared.log.debug(f'Extra network delete image: item={ui.last_item.name} filename="{file}"')
- return 'html/card-no-preview.png'
-
- def fn_save_desc(desc):
- if hasattr(ui.last_item, 'type') and ui.last_item.type == 'Style':
- params = ui.last_page.parse_desc(desc)
- if params is not None:
- fn_save_info(params)
- else:
- fn = os.path.splitext(ui.last_item.filename)[0] + '.txt'
- with open(fn, 'w', encoding='utf-8') as f:
- f.write(desc)
- shared.log.debug(f'Extra network save desc: item={ui.last_item.name} filename="{fn}"')
- return desc
-
- def fn_delete_desc(desc):
- if ui.last_item is None:
- return desc
- if hasattr(ui.last_item, 'type') and ui.last_item.type == 'Style':
- fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
- else:
- fn = os.path.splitext(ui.last_item.filename)[0] + '.txt'
- if os.path.exists(fn):
- shared.log.debug(f'Extra network delete desc: item={ui.last_item.name} filename="{fn}"')
- os.remove(fn)
- return ''
- return desc
-
- def fn_save_info(info):
- fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
- shared.writefile(info, fn, silent=True)
- shared.log.debug(f'Extra network save info: item={ui.last_item.name} filename="{fn}"')
- return info
-
- def fn_delete_info(info):
- if ui.last_item is None:
- return info
- fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
- if os.path.exists(fn):
- shared.log.debug(f'Extra network delete info: item={ui.last_item.name} filename="{fn}"')
- os.remove(fn)
- return ''
- return info
-
- btn_save_img.click(fn=fn_save_img, _js='closeDetailsEN', inputs=[img], outputs=[img])
- btn_delete_img.click(fn=fn_delete_img, _js='closeDetailsEN', inputs=[img], outputs=[img])
- btn_save_desc.click(fn=fn_save_desc, _js='closeDetailsEN', inputs=[desc], outputs=[desc])
- btn_delete_desc.click(fn=fn_delete_desc, _js='closeDetailsEN', inputs=[desc], outputs=[desc])
- btn_save_info.click(fn=fn_save_info, _js='closeDetailsEN', inputs=[info], outputs=[info])
- btn_delete_info.click(fn=fn_delete_info, _js='closeDetailsEN', inputs=[info], outputs=[info])
-
- def show_details(text, img, desc, info, meta, params):
- page, item = get_item(state, params)
- if item is not None and hasattr(item, 'name'):
- stat = os.stat(item.filename) if os.path.exists(item.filename) else None
- desc = item.description
- fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True)
- if 'modelVersions' in fullinfo: # sanitize massive objects
- fullinfo['modelVersions'] = []
- info = fullinfo
- meta = page.metadata.get(item.name, {}) or {}
- if type(meta) is str:
- try:
- meta = json.loads(meta)
- except Exception:
- meta = {}
- if ui.last_item.preview.startswith('data:'):
- b64str = ui.last_item.preview.split(',',1)[1]
- img = Image.open(io.BytesIO(base64.b64decode(b64str)))
- elif hasattr(item, 'local_preview') and os.path.exists(item.local_preview):
- img = item.local_preview
- else:
- img = page.find_preview_file(item.filename)
- lora = ''
- model = ''
- style = ''
- note = ''
- if not os.path.exists(item.filename):
- note = f' Target filename: {item.filename}'
- if page.title == 'Model':
- merge = len(list(meta.get('sd_merge_models', {})))
- if merge > 0:
- model += f'Merge models {merge} recipes '
- if meta.get('modelspec.architecture', None) is not None:
- model += f'''
- Architecture {meta.get('modelspec.architecture', 'N/A')}
- Title {meta.get('modelspec.title', 'N/A')}
- Resolution {meta.get('modelspec.resolution', 'N/A')}
- '''
- if page.title == 'Lora':
- try:
- tags = getattr(item, 'tags', {})
- tags = [f'{name}:{tags[name]}' for i, name in enumerate(tags)]
- tags = ' '.join(tags)
- except Exception:
- tags = ''
- try:
- triggers = ' '.join(info.get('tags', []))
- except Exception:
- triggers = ''
- lora = f'''
- Model tags {tags}
- User tags {triggers}
- Base model {meta.get('ss_sd_model_name', 'N/A')}
- Resolution {meta.get('ss_resolution', 'N/A')}
- Training images {meta.get('ss_num_train_images', 'N/A')}
- Comment {meta.get('ss_training_comment', 'N/A')}
- '''
- if page.title == 'Style':
- style = f'''
- Name {item.name}
- Description {item.description}
- Preview Embedded {item.preview.startswith('data:')}
- '''
- desc = f'Name: {os.path.basename(item.name)}\nDescription: {item.description}\nPrompt: {item.prompt}\nNegative: {item.negative}\nExtra: {item.extra}\n'
- text = f'''
- {item.name}
-
- Type {page.title}
- Alias {getattr(item, 'alias', 'N/A')}
- Filename {item.filename}
- Hash {getattr(item, 'hash', 'N/A')}
- Size {round(stat.st_size/1024/1024, 2) if stat is not None else 'N/A'} MB
- Last modified {datetime.fromtimestamp(stat.st_mtime) if stat is not None else 'N/A'}
-
- {lora}
- {model}
- {style}
-
- {note}
- '''
- return [text, img, desc, info, meta, gr.update(visible=item is not None)]
-
- def ui_refresh_click(title):
- pages = []
- for page in get_pages():
- if title is None or title == '' or title == page.title or len(page.html) == 0:
- page.page_time = 0
- page.refresh_time = 0
- page.refresh()
- page.create_page(ui.tabname)
- shared.log.debug(f"Refreshing Extra networks: page='{page.title}' items={len(page.items)} tab={ui.tabname}")
- pages.append(page.html)
- ui.search.update(value = ui.search.value)
- return pages
-
- def ui_view_cards(title):
- pages = []
- for page in get_pages():
- if title is None or title == '' or title == page.title or len(page.html) == 0:
- shared.opts.extra_networks_view = page.view
- page.view = 'gallery' if page.view == 'list' else 'list'
- page.card = card_full if page.view == 'gallery' else card_list
- page.html = ''
- page.create_page(ui.tabname)
- shared.log.debug(f"Refreshing Extra networks: page='{page.title}' items={len(page.items)} tab={ui.tabname} view={page.view}")
- pages.append(page.html)
- ui.search.update(value = ui.search.value)
- return pages
-
- def ui_scan_click(title):
- from modules import ui_models
- if ui_models.search_metadata_civit is not None:
- ui_models.search_metadata_civit(True, title)
- return ui_refresh_click(title)
-
- def ui_save_click():
- from modules import generation_parameters_copypaste
- filename = os.path.join(paths.data_path, "params.txt")
- if os.path.exists(filename):
- with open(filename, "r", encoding="utf8") as file:
- prompt = file.read()
- else:
- prompt = ''
- params = generation_parameters_copypaste.parse_generation_parameters(prompt)
- res = show_details(text=None, img=None, desc=None, info=None, meta=None, params=params)
- return res
-
- def ui_quicksave_click(name):
- from modules import generation_parameters_copypaste
- fn = os.path.join(paths.data_path, "params.txt")
- if os.path.exists(fn):
- with open(fn, "r", encoding="utf8") as file:
- prompt = file.read()
- else:
- prompt = ''
- params = generation_parameters_copypaste.parse_generation_parameters(prompt)
- fn = os.path.join(shared.opts.styles_dir, os.path.splitext(name)[0] + '.json')
- prompt = params.get('Prompt', '')
- item = {
- "name": name,
- "description": '',
- "prompt": prompt,
- "negative": params.get('Negative prompt', ''),
- "extra": '',
- # "type": 'Style',
- # "title": name,
- # "filename": fn,
- # "search_term": None,
- # "preview": None,
- # "local_preview": None,
- }
- shared.writefile(item, fn, silent=True)
- if len(prompt) > 0:
- shared.log.debug(f"Extra network quick save style: item={name} filename='{fn}'")
- else:
- shared.log.warning(f"Extra network quick save model: item={name} filename='{fn}' prompt is empty")
-
- def ui_sort_cards(msg):
- shared.log.debug(f'Extra networks: {msg}')
- return msg
-
- dummy = gr.State(value=False) # pylint: disable=abstract-class-instantiated
- button_parent.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container, button_parent])
- ui.button_close.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container])
- ui.button_sort.click(fn=ui_sort_cards, _js='sortExtraNetworks', inputs=[ui.search], outputs=[ui.description])
- ui.button_view.click(fn=ui_view_cards, inputs=[ui.search], outputs=ui.pages)
- ui.button_refresh.click(fn=ui_refresh_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
- ui.button_scan.click(fn=ui_scan_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
- ui.button_save.click(fn=ui_save_click, inputs=[], outputs=ui.details_components + [ui.details])
- ui.button_quicksave.click(fn=ui_quicksave_click, _js="() => prompt('Prompt name', '')", inputs=[ui.search], outputs=[])
- ui.button_details.click(show_details, _js="getCardDetails", inputs=ui.details_components + [dummy], outputs=ui.details_components + [ui.details])
- ui.state.change(state_change, inputs=[ui.state], outputs=[])
- return ui
-
-
-def setup_ui(ui, gallery):
- ui.gallery = gallery
+import io
+import re
+import time
+import json
+import html
+import base64
+import os.path
+import urllib.parse
+import threading
+from datetime import datetime
+from types import SimpleNamespace
+from pathlib import Path
+from html.parser import HTMLParser
+from collections import OrderedDict
+import gradio as gr
+from PIL import Image
+from starlette.responses import FileResponse, JSONResponse
+from modules import paths, shared, scripts, modelloader, errors
+from modules.ui_components import ToolButton
+import modules.ui_symbols as symbols
+
+
+allowed_dirs = []
+refresh_time = 0
+extra_pages = shared.extra_networks
+debug = shared.log.trace if os.environ.get('SD_EN_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug('Trace: EN')
+card_full = '''
+
+
+
+
+
+'''
+card_list = '''
+
+'''
+
+
+def init_api(app):
+
+ def fetch_file(filename: str = ""):
+ if not os.path.exists(filename):
+ return JSONResponse({ "error": f"file {filename}: not found" }, status_code=404)
+ if filename.startswith('html/') or filename.startswith('models/'):
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+ if not any(Path(folder).absolute() in Path(filename).absolute().parents for folder in allowed_dirs):
+ return JSONResponse({ "error": f"file {filename}: must be in one of allowed directories" }, status_code=403)
+ if os.path.splitext(filename)[1].lower() not in (".png", ".jpg", ".jpeg", ".webp"):
+ return JSONResponse({"error": f"file {filename}: not an image file"}, status_code=403)
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+ def get_metadata(page: str = "", item: str = ""):
+ page = next(iter([x for x in shared.extra_networks if x.name == page]), None)
+ if page is None:
+ return JSONResponse({ 'metadata': 'none' })
+ metadata = page.metadata.get(item, 'none')
+ if metadata is None:
+ metadata = ''
+ # shared.log.debug(f"Extra networks metadata: page='{page}' item={item} len={len(metadata)}")
+ return JSONResponse({"metadata": metadata})
+
+ def get_info(page: str = "", item: str = ""):
+ page = next(iter([x for x in get_pages() if x.name == page]), None)
+ if page is None:
+ return JSONResponse({ 'info': 'none' })
+ item = next(iter([x for x in page.items if x['name'] == item]), None)
+ if item is None:
+ return JSONResponse({ 'info': 'none' })
+ info = page.find_info(item['filename'])
+ if info is None:
+ info = {}
+ # shared.log.debug(f"Extra networks info: page='{page.name}' item={item['name']} len={len(info)}")
+ return JSONResponse({"info": info})
+
+ def get_desc(page: str = "", item: str = ""):
+ page = next(iter([x for x in get_pages() if x.name == page]), None)
+ if page is None:
+ return JSONResponse({ 'description': 'none' })
+ item = next(iter([x for x in page.items if x['name'] == item]), None)
+ if item is None:
+ return JSONResponse({ 'description': 'none' })
+ desc = page.find_description(item['filename'])
+ if desc is None:
+ desc = ''
+ # shared.log.debug(f"Extra networks desc: page='{page.name}' item={item['name']} len={len(desc)}")
+ return JSONResponse({"description": desc})
+
+ app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/info", get_info, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/description", get_desc, methods=["GET"])
+
+
+class ExtraNetworksPage:
+ def __init__(self, title):
+ self.title = title
+ self.name = title.lower()
+ self.allow_negative_prompt = False
+ self.metadata = {}
+ self.info = {}
+ self.html = ''
+ self.items = []
+ self.missing_thumbs = []
+ self.refresh_time = 0
+ self.page_time = 0
+ self.list_time = 0
+ self.info_time = 0
+ self.desc_time = 0
+ self.dirs = {}
+ self.view = shared.opts.extra_networks_view
+ self.card = card_full if shared.opts.extra_networks_view == 'gallery' else card_list
+
+ def refresh(self):
+ pass
+
+ def create_xyz_grid(self):
+ xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
+
+ def add_prompt(p, opt, x):
+ for item in [x for x in self.items if x["name"] == opt]:
+ try:
+ p.prompt = f'{p.prompt} {eval(item["prompt"])}' # pylint: disable=eval-used
+ except Exception as e:
+ shared.log.error(f'Cannot evaluate extra network prompt: {item["prompt"]} {e}')
+
+ if not any(self.title in x.label for x in xyz_grid.axis_options):
+ if self.title == 'Model':
+ return
+ opt = xyz_grid.AxisOption(f"[Network] {self.title}", str, add_prompt, choices=lambda: [x["name"] for x in self.items])
+ xyz_grid.axis_options.append(opt)
+
+ def link_preview(self, filename):
+ quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
+ mtime = os.path.getmtime(filename)
+ preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
+ return preview
+
+ def search_terms_from_path(self, filename):
+ return filename.replace('\\', '/')
+
+ def is_empty(self, folder):
+ for f in shared.listdir(folder):
+ _fn, ext = os.path.splitext(f)
+ if ext.lower() in ['.ckpt', '.safetensors', '.pt', '.json'] or os.path.isdir(os.path.join(folder, f)):
+ return False
+ return True
+
+ def create_thumb(self):
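+ # create missing .thumb.jpg files: large previews are downscaled to 512px JPEG, unreadable images are removed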
+ debug(f'EN create-thumb: {self.name}')
+ created = 0
+ for f in self.missing_thumbs:
+ if not os.path.exists(f):
+ continue
+ fn, _ext = os.path.splitext(f)
+ fn = fn.replace('.preview', '')
+ fn = f'{fn}.thumb.jpg'
+ if os.path.exists(fn):
+ continue
+ img = None
+ try:
+ img = Image.open(f)
+ except Exception:
+ img = None
+ shared.log.warning(f'Extra network removing invalid image: {f}')
+ try:
+ if img is None:
+ img = None
+ os.remove(f)
+ elif img.width > 1024 or img.height > 1024 or os.path.getsize(f) > 65536:
+ img = img.convert('RGB')
+ img.thumbnail((512, 512), Image.Resampling.HAMMING)
+ img.save(fn, quality=50)
+ img.close()
+ created += 1
+ except Exception as e:
+ shared.log.warning(f'Extra network error creating thumbnail: {f} {e}')
+ if created > 0:
+ shared.log.info(f"Extra network thumbnails: {self.name} created={created}")
+ self.missing_thumbs.clear()
+
+ def create_items(self, tabname):
+ if self.refresh_time is not None and self.refresh_time > refresh_time: # cached results
+ return
+ t0 = time.time()
+ try:
+ self.items = list(self.list_items())
+ self.refresh_time = time.time()
+ except Exception as e:
+ self.items = []
+ shared.log.error(f'Extra networks error listing items: class={self.__class__.__name__} tab={tabname} {e}')
+ for item in self.items:
+ if item is None:
+ continue
+ self.metadata[item["name"]] = item.get("metadata", {})
+ t1 = time.time()
+ debug(f'EN create-items: page={self.name} items={len(self.items)} time={t1-t0:.2f}')
+ self.list_time += t1-t0
+
+
+ def create_page(self, tabname, skip = False):
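+ # build the page html: collect allowed subfolders, render one card per item (newest first when mtime is known) and schedule thumbnail creation in a background thread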
+ debug(f'EN create-page: {self.name}')
+ if self.page_time > refresh_time and len(self.html) > 0: # cached page
+ return self.html
+ self_name_id = self.name.replace(" ", "_")
+ if skip:
+ return f""
+ subdirs = {}
+ allowed_folders = [os.path.abspath(x) for x in self.allowed_directories_for_previews()]
+ for parentdir, dirs in {d: modelloader.directory_list(d) for d in allowed_folders}.items():
+ for tgt in dirs.keys():
+ if os.path.join(paths.models_path, 'Reference') in tgt:
+ subdirs['Reference'] = 1
+ if shared.backend == shared.Backend.DIFFUSERS and shared.opts.diffusers_dir in tgt:
+ subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
+ if 'models--' in tgt:
+ continue
+ subdir = tgt[len(parentdir):].replace("\\", "/")
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
+ # if not self.is_empty(tgt):
+ if not subdir.startswith("."):
+ subdirs[subdir] = 1
+ debug(f"Extra networks: page='{self.name}' subfolders={list(subdirs)}")
+ subdirs = OrderedDict(sorted(subdirs.items()))
+ if self.name == 'model':
+ subdirs['Reference'] = 1
+ subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1
+ subdirs.move_to_end(os.path.basename(shared.opts.diffusers_dir))
+ subdirs.move_to_end('Reference')
+ if self.name == 'style' and shared.opts.extra_networks_styles:
+ subdirs['built-in'] = 1
+ subdirs_html = "All "
+ subdirs_html += "".join([f"{html.escape(subdir)} " for subdir in subdirs if subdir != ''])
+ self.html = ''
+ self.create_items(tabname)
+ self.create_xyz_grid()
+ htmls = []
+ if len(self.items) > 0 and self.items[0].get('mtime', None) is not None:
+ self.items.sort(key=lambda x: x["mtime"], reverse=True)
+ for item in self.items:
+ htmls.append(self.create_html(item, tabname))
+ self.html += ''.join(htmls)
+ self.page_time = time.time()
+ if len(subdirs_html) > 0 or len(self.html) > 0:
+ self.html = f""
+ else:
+ return ''
+ shared.log.debug(f"Extra networks: page='{self.name}' items={len(self.items)} subfolders={len(subdirs)} tab={tabname} folders={self.allowed_directories_for_previews()} list={self.list_time:.2f} desc={self.desc_time:.2f} info={self.info_time:.2f} workers={shared.max_workers}")
+ if len(self.missing_thumbs) > 0:
+ threading.Thread(target=self.create_thumb).start()
+ return self.html
+
+ def list_items(self):
+ raise NotImplementedError
+
+ def allowed_directories_for_previews(self):
+ return []
+
+ def create_html(self, item, tabname):
+ try:
+ args = {
+ "tabname": tabname,
+ "page": self.name,
+ "name": item["name"],
+ "title": os.path.basename(item["name"].replace('_', ' ')),
+ "filename": item["filename"],
+ "tags": '|'.join([item.get("tags")] if isinstance(item.get("tags", {}), str) else list(item.get("tags", {}).keys())),
+ "preview": html.escape(item.get("preview", self.link_preview('html/card-no-preview.png'))),
+ "width": shared.opts.extra_networks_card_size,
+ "height": shared.opts.extra_networks_card_size if shared.opts.extra_networks_card_square else 'auto',
+ "fit": shared.opts.extra_networks_card_fit,
+ "prompt": item.get("prompt", None),
+ "search": item.get("search_term", ""),
+ "description": item.get("description") or "",
+ "card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
+ "mtime": item.get("mtime", 0),
+ "size": item.get("size", 0),
+ }
+ alias = item.get("alias", None)
+ if alias is not None:
+ args['title'] += f'\nAlias: {alias}'
+ return self.card.format(**args)
+ except Exception as e:
+ shared.log.error(f'Extra networks item error: page={tabname} item={item["name"]} {e}')
+ return ""
+
+ def find_preview_file(self, path):
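+ # look for a preview image next to the item, preferring '.thumb.' over plain and '.preview.' variants; non-thumb matches are queued for thumbnail generation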
+ if path is None:
+ return 'html/card-no-preview.png'
+ if shared.opts.diffusers_dir in path:
+ path = os.path.relpath(path, shared.opts.diffusers_dir)
+ ref = os.path.join('models', 'Reference')
+ fn = os.path.join(ref, path.replace('models--', '').replace('\\', '/').split('/')[0])
+ files = shared.listdir(ref)
+ else:
+ files = shared.listdir(os.path.dirname(path))
+ fn = os.path.splitext(path)[0]
+ exts = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"]
+ for file in [f'{fn}{mid}{ext}' for ext in exts for mid in ['.thumb.', '.', '.preview.']]:
+ if file in files:
+ if 'Reference' not in file and '.thumb.' not in file:
+ self.missing_thumbs.append(file)
+ return file
+ return 'html/card-no-preview.png'
+
+ def find_preview(self, path):
+ preview_file = self.find_preview_file(path)
+ return self.link_preview(preview_file)
+
+ def find_description(self, path, info=None):
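+ # prefer a sidecar .txt description; otherwise strip html tags from the 'description' field of the item info json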
+ t0 = time.time()
+ class HTMLFilter(HTMLParser):
+ text = ""
+ def handle_data(self, data):
+ self.text += data
+ def handle_endtag(self, tag):
+ if tag == 'p':
+ self.text += '\n'
+
+ fn = os.path.splitext(path)[0] + '.txt'
+ if fn in shared.listdir(os.path.dirname(path)):
+ try:
+ with open(fn, "r", encoding="utf-8", errors="replace") as f:
+ txt = f.read()
+ txt = re.sub('[<>]', '', txt)
+ return txt
+ except OSError:
+ pass
+ if info is None:
+ info = self.find_info(path)
+ desc = info.get('description', '') or ''
+ f = HTMLFilter()
+ f.feed(desc)
+ t1 = time.time()
+ self.desc_time += t1-t0
+ return f.text
+
+ def find_info(self, path):
+ fn = os.path.splitext(path)[0] + '.json'
+ data = {}
+ if fn in shared.listdir(os.path.dirname(path)):
+ t0 = time.time()
+ data = shared.readfile(fn, silent=True)
+ if type(data) is list:
+ data = data[0]
+ t1 = time.time()
+ self.info_time += t1-t0
+ return data
+
+
+def initialize():
+ shared.extra_networks.clear()
+
+
+def register_page(page: ExtraNetworksPage):
+ # registers an extra networks page for the UI; extensions should do this in the on_before_ui() callback
+ debug(f'EN register-page: {page}')
+ if page in shared.extra_networks:
+ debug(f'EN register-page: {page} already registered')
+ return
+ shared.extra_networks.append(page)
+ # allowed_dirs.clear()
+ # for pg in shared.extra_networks:
+ for folder in page.allowed_directories_for_previews():
+ if folder not in allowed_dirs:
+ allowed_dirs.append(os.path.abspath(folder))
+
+
+def register_pages():
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
+ from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
+ from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs
+ debug('EN register-pages')
+ register_page(ExtraNetworksPageCheckpoints())
+ register_page(ExtraNetworksPageStyles())
+ register_page(ExtraNetworksPageTextualInversion())
+ register_page(ExtraNetworksPageHypernetworks())
+ register_page(ExtraNetworksPageVAEs())
+
+
+def get_pages(title=None):
+ pages = []
+ if 'All' in shared.opts.extra_networks:
+ pages = shared.extra_networks
+ else:
+ titles = [page.title for page in shared.extra_networks]
+ if title is None:
+ for page in shared.opts.extra_networks:
+ try:
+ idx = titles.index(page)
+ pages.append(shared.extra_networks[idx])
+ except ValueError:
+ continue
+ else:
+ try:
+ idx = titles.index(title)
+ pages.append(shared.extra_networks[idx])
+ except ValueError:
+ pass
+ return pages
+
+
+class ExtraNetworksUi:
+ def __init__(self):
+ self.tabname: str = None
+ self.pages: list = None
+ self.visible: gr.State = None
+ self.state: gr.Textbox = None
+ self.details: gr.Group = None
+ self.tabs: gr.Tabs = None
+ self.gallery: gr.Gallery = None
+ self.description: gr.Textbox = None
+ self.search: gr.Textbox = None
+ self.button_details: gr.Button = None
+ self.button_refresh: gr.Button = None
+ self.button_scan: gr.Button = None
+ self.button_view: gr.Button = None
+ self.button_quicksave: gr.Button = None
+ self.button_save: gr.Button = None
+ self.button_sort: gr.Button = None
+ self.button_apply: gr.Button = None
+ self.button_close: gr.Button = None
+ self.button_model: gr.Checkbox = None
+ self.details_components: list = []
+ self.last_item: dict = None
+ self.last_page: ExtraNetworksPage = None
+
+
+def create_ui(container, button_parent, tabname, skip_indexing = False):
+ debug(f'EN create-ui: {tabname}')
+ ui = ExtraNetworksUi()
+ ui.tabname = tabname
+ ui.pages = []
+ ui.state = gr.Textbox('{}', elem_id=f"{tabname}_extra_state", visible=False)
+ ui.visible = gr.State(value=False) # pylint: disable=abstract-class-instantiated
+ ui.details = gr.Group(elem_id=f"{tabname}_extra_details", visible=False)
+ ui.tabs = gr.Tabs(elem_id=f"{tabname}_extra_tabs")
+ ui.button_details = gr.Button('Details', elem_id=f"{tabname}_extra_details_btn", visible=False)
+ state = {}
+ if shared.cmd_opts.profile:
+ import cProfile
+ pr = cProfile.Profile()
+ pr.enable()
+
+ def get_item(state, params = None):
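+ # resolve the current page/item: a params dict (style quicksave path) builds a transient style item, otherwise use the JS-provided state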
+ if isinstance(params, dict):
+ page = next(iter([x for x in get_pages() if x.title == 'Style']), None)
+ item = page.create_style(params)
+ else:
+ if state is None or not hasattr(state, 'page') or not hasattr(state, 'item'):
+ return None, None
+ page = next(iter([x for x in get_pages() if x.title == state.page]), None)
+ if page is None:
+ return None, None
+ item = next(iter([x for x in page.items if x["name"] == state.item]), None)
+ if item is None:
+ return page, None
+ item = SimpleNamespace(**item)
+ ui.last_item = item
+ ui.last_page = page
+ return page, item
+
+ # main event triggered when JS updates the state text field with JSON values; used to communicate JS -> Python
+ def state_change(state_text):
+ try:
+ nonlocal state
+ state = SimpleNamespace(**json.loads(state_text))
+ except Exception as e:
+ shared.log.error(f'Extra networks state error: {e}')
+ return
+ _page, _item = get_item(state)
+ # shared.log.debug(f'Extra network: op={state.op} page={page.title if page is not None else None} item={item.filename if item is not None else None}')
+
+ def toggle_visibility(is_visible):
+ is_visible = not is_visible
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ with ui.details:
+ details_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_details_close", elem_classes=['extra-details-close'])
+ details_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[ui.details])
+ with gr.Row():
+ with gr.Column(scale=1):
+ text = gr.HTML('title')
+ ui.details_components.append(text)
+ with gr.Column(scale=1):
+ img = gr.Image(value=None, show_label=False, interactive=False, container=False, show_download_button=False, show_info=False, elem_id=f"{tabname}_extra_details_img", elem_classes=['extra-details-img'])
+ ui.details_components.append(img)
+ with gr.Row():
+ btn_save_img = gr.Button('Replace', elem_classes=['small-button'])
+ btn_delete_img = gr.Button('Delete', elem_classes=['small-button'])
+ with gr.Tabs():
+ with gr.Tab('Description'):
+ desc = gr.Textbox('', show_label=False, lines=8, placeholder="Extra network description...")
+ ui.details_components.append(desc)
+ with gr.Row():
+ btn_save_desc = gr.Button('Save', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_save_desc')
+ btn_delete_desc = gr.Button('Delete', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_delete_desc')
+ btn_close_desc = gr.Button('Close', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_close_desc')
+ btn_close_desc.click(fn=lambda: gr.update(visible=False), _js='refeshDetailsEN', inputs=[], outputs=[ui.details])
+ with gr.Tab('Model metadata'):
+ info = gr.JSON({}, show_label=False)
+ ui.details_components.append(info)
+ with gr.Row():
+ btn_save_info = gr.Button('Save', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_save_info')
+ btn_delete_info = gr.Button('Delete', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_delete_info')
+ btn_close_info = gr.Button('Close', elem_classes=['small-button'], elem_id=f'{tabname}_extra_details_close_info')
+ btn_close_info.click(fn=lambda: gr.update(visible=False), _js='refeshDetailsEN', inputs=[], outputs=[ui.details])
+ with gr.Tab('Embedded metadata'):
+ meta = gr.JSON({}, show_label=False)
+ ui.details_components.append(meta)
+
+ with ui.tabs:
+ def ui_tab_change(page):
+ scan_visible = page in ['Model', 'Lora', 'Hypernetwork', 'Embedding']
+ save_visible = page in ['Style']
+ model_visible = page in ['Model']
+ return [gr.update(visible=scan_visible), gr.update(visible=save_visible), gr.update(visible=model_visible)]
+
+ ui.button_refresh = ToolButton(symbols.refresh, elem_id=f"{tabname}_extra_refresh")
+ ui.button_scan = ToolButton(symbols.scan, elem_id=f"{tabname}_extra_scan", visible=True)
+ ui.button_quicksave = ToolButton(symbols.book, elem_id=f"{tabname}_extra_quicksave", visible=False)
+ ui.button_save = ToolButton(symbols.book, elem_id=f"{tabname}_extra_save", visible=False)
+ ui.button_sort = ToolButton(symbols.sort, elem_id=f"{tabname}_extra_sort", visible=True)
+ ui.button_view = ToolButton(symbols.view, elem_id=f"{tabname}_extra_view", visible=True)
+ ui.button_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_close", visible=True)
+ ui.button_model = ToolButton(symbols.refine, elem_id=f"{tabname}_extra_model", visible=True)
+ ui.search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False)
+ ui.description = gr.Textbox('', show_label=False, elem_id=f"{tabname}_description", elem_classes="textbox", lines=2, interactive=False, container=False)
+
+ if ui.tabname == 'txt2img': # refresh only once
+ global refresh_time # pylint: disable=global-statement
+ refresh_time = time.time()
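+ # index all pages once per session; when SD_EN_DEBUG is set, each page is listed in its own thread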
+ if not skip_indexing:
+ threads = []
+ for page in get_pages():
+ if os.environ.get('SD_EN_DEBUG', None) is not None:
+ threads.append(threading.Thread(target=page.create_items, args=[ui.tabname]))
+ threads[-1].start()
+ else:
+ page.create_items(ui.tabname)
+ for thread in threads:
+ thread.join()
+ for page in get_pages():
+ page.create_page(ui.tabname, skip_indexing)
+ with gr.Tab(page.title, id=page.title.lower().replace(" ", "_"), elem_classes="extra-networks-tab") as tab:
+ page_html = gr.HTML(page.html, elem_id=f'{tabname}{page.name}_extra_page', elem_classes="extra-networks-page")
+ ui.pages.append(page_html)
+ tab.select(ui_tab_change, _js="getENActivePage", inputs=[ui.button_details], outputs=[ui.button_scan, ui.button_save, ui.button_model])
+ if shared.cmd_opts.profile:
+ errors.profile(pr, 'ExtraNetworks')
+ pr.disable()
+ # ui.tabs.change(fn=ui_tab_change, inputs=[], outputs=[ui.button_scan, ui.button_save])
+
+ def fn_save_img(image):
+ if ui.last_item is None or ui.last_item.local_preview is None:
+ return 'html/card-no-preview.png'
+ images = list(ui.gallery.temp_files) # gallery cannot be used as an input component, so look at the most recently registered temp files
+ if len(images) < 1:
+ shared.log.warning(f'Extra network no image: item={ui.last_item.name}')
+ return 'html/card-no-preview.png'
+ try:
+ images.sort(key=lambda f: os.path.getmtime(f), reverse=True)
+ image = Image.open(images[0])
+ except Exception as e:
+ shared.log.error(f'Extra network error opening image: item={ui.last_item.name} {e}')
+ return 'html/card-no-preview.png'
+ fn_delete_img(image)
+ if image.width > 512 or image.height > 512:
+ image = image.convert('RGB')
+ image.thumbnail((512, 512), Image.Resampling.HAMMING)
+ try:
+ image.save(ui.last_item.local_preview, quality=50)
+ shared.log.debug(f'Extra network save image: item={ui.last_item.name} filename="{ui.last_item.local_preview}"')
+ except Exception as e:
+ shared.log.error(f'Extra network save image: item={ui.last_item.name} filename="{ui.last_item.local_preview}" {e}')
+ return image
+
+ def fn_delete_img(_image):
+ preview_extensions = ["jpg", "jpeg", "png", "webp", "tiff", "jp2"]
+ fn = os.path.splitext(ui.last_item.filename)[0]
+ for file in [f'{fn}{mid}{ext}' for ext in preview_extensions for mid in ['.thumb.', '.preview.', '.']]:
+ if os.path.exists(file):
+ os.remove(file)
+ shared.log.debug(f'Extra network delete image: item={ui.last_item.name} filename="{file}"')
+ return 'html/card-no-preview.png'
+
+ def fn_save_desc(desc):
+ if hasattr(ui.last_item, 'type') and ui.last_item.type == 'Style':
+ params = ui.last_page.parse_desc(desc)
+ if params is not None:
+ fn_save_info(params)
+ else:
+ fn = os.path.splitext(ui.last_item.filename)[0] + '.txt'
+ with open(fn, 'w', encoding='utf-8') as f:
+ f.write(desc)
+ shared.log.debug(f'Extra network save desc: item={ui.last_item.name} filename="{fn}"')
+ return desc
+
+ def fn_delete_desc(desc):
+ if ui.last_item is None:
+ return desc
+ if hasattr(ui.last_item, 'type') and ui.last_item.type == 'Style':
+ fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
+ else:
+ fn = os.path.splitext(ui.last_item.filename)[0] + '.txt'
+ if os.path.exists(fn):
+ shared.log.debug(f'Extra network delete desc: item={ui.last_item.name} filename="{fn}"')
+ os.remove(fn)
+ return ''
+ return desc
+
+ def fn_save_info(info):
+ fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
+ shared.writefile(info, fn, silent=True)
+ shared.log.debug(f'Extra network save info: item={ui.last_item.name} filename="{fn}"')
+ return info
+
+ def fn_delete_info(info):
+ if ui.last_item is None:
+ return info
+ fn = os.path.splitext(ui.last_item.filename)[0] + '.json'
+ if os.path.exists(fn):
+ shared.log.debug(f'Extra network delete info: item={ui.last_item.name} filename="{fn}"')
+ os.remove(fn)
+ return ''
+ return info
+
+ btn_save_img.click(fn=fn_save_img, _js='closeDetailsEN', inputs=[img], outputs=[img])
+ btn_delete_img.click(fn=fn_delete_img, _js='closeDetailsEN', inputs=[img], outputs=[img])
+ btn_save_desc.click(fn=fn_save_desc, _js='closeDetailsEN', inputs=[desc], outputs=[desc])
+ btn_delete_desc.click(fn=fn_delete_desc, _js='closeDetailsEN', inputs=[desc], outputs=[desc])
+ btn_save_info.click(fn=fn_save_info, _js='closeDetailsEN', inputs=[info], outputs=[info])
+ btn_delete_info.click(fn=fn_delete_info, _js='closeDetailsEN', inputs=[info], outputs=[info])
+
+ def show_details(text, img, desc, info, meta, params):
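+ # build the details panel for the selected card: text summary, preview image, description, sidecar info json and embedded metadata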
+ page, item = get_item(state, params)
+ if item is not None and hasattr(item, 'name'):
+ stat = os.stat(item.filename) if os.path.exists(item.filename) else None
+ desc = item.description
+ fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True)
+ if 'modelVersions' in fullinfo: # sanitize massive objects
+ fullinfo['modelVersions'] = []
+ info = fullinfo
+ meta = page.metadata.get(item.name, {}) or {}
+ if type(meta) is str:
+ try:
+ meta = json.loads(meta)
+ except Exception:
+ meta = {}
+ if ui.last_item.preview.startswith('data:'):
+ b64str = ui.last_item.preview.split(',',1)[1]
+ img = Image.open(io.BytesIO(base64.b64decode(b64str)))
+ elif hasattr(item, 'local_preview') and os.path.exists(item.local_preview):
+ img = item.local_preview
+ else:
+ img = page.find_preview_file(item.filename)
+ lora = ''
+ model = ''
+ style = ''
+ note = ''
+ if not os.path.exists(item.filename):
+ note = f' Target filename: {item.filename}'
+ if page.title == 'Model':
+ merge = len(list(meta.get('sd_merge_models', {})))
+ if merge > 0:
+ model += f'Merge models {merge} recipes '
+ if meta.get('modelspec.architecture', None) is not None:
+ model += f'''
+ Architecture {meta.get('modelspec.architecture', 'N/A')}
+ Title {meta.get('modelspec.title', 'N/A')}
+ Resolution {meta.get('modelspec.resolution', 'N/A')}
+ '''
+ if page.title == 'Lora':
+ try:
+ tags = getattr(item, 'tags', {})
+ tags = [f'{name}:{tags[name]}' for name in tags]
+ tags = ' '.join(tags)
+ except Exception:
+ tags = ''
+ try:
+ triggers = ' '.join(info.get('tags', []))
+ except Exception:
+ triggers = ''
+ lora = f'''
+ Model tags {tags}
+ User tags {triggers}
+ Base model {meta.get('ss_sd_model_name', 'N/A')}
+ Resolution {meta.get('ss_resolution', 'N/A')}
+ Training images {meta.get('ss_num_train_images', 'N/A')}
+ Comment {meta.get('ss_training_comment', 'N/A')}
+ '''
+ if page.title == 'Style':
+ style = f'''
+ Name {item.name}
+ Description {item.description}
+ Preview Embedded {item.preview.startswith('data:')}
+ '''
+ desc = f'Name: {os.path.basename(item.name)}\nDescription: {item.description}\nPrompt: {item.prompt}\nNegative: {item.negative}\nExtra: {item.extra}\n'
+ text = f'''
+ {item.name}
+
+ Type {page.title}
+ Alias {getattr(item, 'alias', 'N/A')}
+ Filename {item.filename}
+ Hash {getattr(item, 'hash', 'N/A')}
+ Size {round(stat.st_size/1024/1024, 2) if stat is not None else 'N/A'} MB
+ Last modified {datetime.fromtimestamp(stat.st_mtime) if stat is not None else 'N/A'}
+
+ {lora}
+ {model}
+ {style}
+
+ {note}
+ '''
+ return [text, img, desc, info, meta, gr.update(visible=item is not None)]
+
+ def ui_refresh_click(title):
+ pages = []
+ for page in get_pages():
+ if title is None or title == '' or title == page.title or len(page.html) == 0:
+ page.page_time = 0
+ page.refresh_time = 0
+ page.refresh()
+ page.create_page(ui.tabname)
+ shared.log.debug(f"Refreshing Extra networks: page='{page.title}' items={len(page.items)} tab={ui.tabname}")
+ pages.append(page.html)
+ ui.search.update(value = ui.search.value)
+ return pages
+
+ def ui_view_cards(title):
+ pages = []
+ for page in get_pages():
+ if title is None or title == '' or title == page.title or len(page.html) == 0:
+ shared.opts.extra_networks_view = page.view
+ page.view = 'gallery' if page.view == 'list' else 'list'
+ page.card = card_full if page.view == 'gallery' else card_list
+ page.html = ''
+ page.create_page(ui.tabname)
+ shared.log.debug(f"Refreshing Extra networks: page='{page.title}' items={len(page.items)} tab={ui.tabname} view={page.view}")
+ pages.append(page.html)
+ ui.search.update(value = ui.search.value)
+ return pages
+
+ def ui_scan_click(title):
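+ # fetch civitai metadata for items on the active page, then rebuild the page html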
+ from modules import ui_models
+ if ui_models.search_metadata_civit is not None:
+ ui_models.search_metadata_civit(True, title)
+ return ui_refresh_click(title)
+
+ def ui_save_click():
+ from modules import generation_parameters_copypaste
+ filename = os.path.join(paths.data_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+ else:
+ prompt = ''
+ params = generation_parameters_copypaste.parse_generation_parameters(prompt)
+ res = show_details(text=None, img=None, desc=None, info=None, meta=None, params=params)
+ return res
+
+ def ui_quicksave_click(name):
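+ # save the last used prompt (read from params.txt) as a new style json in the styles folder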
+ from modules import generation_parameters_copypaste
+ fn = os.path.join(paths.data_path, "params.txt")
+ if os.path.exists(fn):
+ with open(fn, "r", encoding="utf8") as file:
+ prompt = file.read()
+ else:
+ prompt = ''
+ params = generation_parameters_copypaste.parse_generation_parameters(prompt)
+ fn = os.path.join(shared.opts.styles_dir, os.path.splitext(name)[0] + '.json')
+ prompt = params.get('Prompt', '')
+ item = {
+ "name": name,
+ "description": '',
+ "prompt": prompt,
+ "negative": params.get('Negative prompt', ''),
+ "extra": '',
+ # "type": 'Style',
+ # "title": name,
+ # "filename": fn,
+ # "search_term": None,
+ # "preview": None,
+ # "local_preview": None,
+ }
+ shared.writefile(item, fn, silent=True)
+ if len(prompt) > 0:
+ shared.log.debug(f"Extra network quick save style: item={name} filename='{fn}'")
+ else:
+ shared.log.warning(f"Extra network quick save style: item={name} filename='{fn}' prompt is empty")
+
+ def ui_sort_cards(msg):
+ shared.log.debug(f'Extra networks: {msg}')
+ return msg
+
+ dummy = gr.State(value=False) # pylint: disable=abstract-class-instantiated
+ button_parent.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container, button_parent])
+ ui.button_close.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container])
+ ui.button_sort.click(fn=ui_sort_cards, _js='sortExtraNetworks', inputs=[ui.search], outputs=[ui.description])
+ ui.button_view.click(fn=ui_view_cards, inputs=[ui.search], outputs=ui.pages)
+ ui.button_refresh.click(fn=ui_refresh_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
+ ui.button_scan.click(fn=ui_scan_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
+ ui.button_save.click(fn=ui_save_click, inputs=[], outputs=ui.details_components + [ui.details])
+ ui.button_quicksave.click(fn=ui_quicksave_click, _js="() => prompt('Prompt name', '')", inputs=[ui.search], outputs=[])
+ ui.button_details.click(show_details, _js="getCardDetails", inputs=ui.details_components + [dummy], outputs=ui.details_components + [ui.details])
+ ui.state.change(state_change, inputs=[ui.state], outputs=[])
+ return ui
+
+
+def setup_ui(ui, gallery):
+ ui.gallery = gallery
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index c217348c0..7bf7e197b 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -1,84 +1,84 @@
-import os
-import html
-import json
-import concurrent
-from modules import shared, ui_extra_networks, sd_models
-
-
-reference_dir = os.path.join('models', 'Reference')
-
-class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Model')
-
- def refresh(self):
- shared.refresh_checkpoints()
-
- def list_reference(self): # pylint: disable=inconsistent-return-statements
- reference_models = shared.readfile(os.path.join('html', 'reference.json'))
- for k, v in reference_models.items():
- if shared.backend != shared.Backend.DIFFUSERS:
- if not v.get('original', False):
- continue
- url = v.get('alt', None) or v['path']
- else:
- url = v['path']
- name = os.path.join(reference_dir, k)
- preview = v.get('preview', v['path'])
- yield {
- "type": 'Model',
- "name": name,
- "title": name,
- "filename": url,
- "search_term": self.search_terms_from_path(name),
- "preview": self.find_preview(os.path.join(reference_dir, preview)),
- "local_preview": self.find_preview_file(os.path.join(reference_dir, preview)),
- "onclick": '"' + html.escape(f"""return selectReference({json.dumps(url)})""") + '"',
- "hash": None,
- "mtime": 0,
- "size": 0,
- "info": {},
- "metadata": {},
- "description": v.get('desc', ''),
- }
-
- def create_item(self, name):
- record = None
- try:
- checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name)
- exists = os.path.exists(checkpoint.filename)
- record = {
- "type": 'Model',
- "name": checkpoint.name,
- "title": checkpoint.title,
- "filename": checkpoint.filename,
- "hash": checkpoint.shorthash,
- "search_term": self.search_terms_from_path(checkpoint.title),
- "preview": self.find_preview(checkpoint.filename),
- "local_preview": f"{os.path.splitext(checkpoint.filename)[0]}.{shared.opts.samples_format}",
- "metadata": checkpoint.metadata,
- "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "mtime": os.path.getmtime(checkpoint.filename) if exists else 0,
- "size": os.path.getsize(checkpoint.filename) if exists else 0,
- }
- record["info"] = self.find_info(checkpoint.filename)
- record["description"] = self.find_description(checkpoint.filename, record["info"])
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=model file={name} {e}")
- return record
-
- def list_items(self):
- with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
- future_items = {executor.submit(self.create_item, cp): cp for cp in list(sd_models.checkpoints_list.copy())}
- for future in concurrent.futures.as_completed(future_items):
- item = future.result()
- if item is not None:
- yield item
- for record in self.list_reference():
- yield record
-
- def allowed_directories_for_previews(self):
- if shared.backend == shared.Backend.DIFFUSERS:
- return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, reference_dir] if v is not None]
- else:
- return [v for v in [shared.opts.ckpt_dir, reference_dir, sd_models.model_path] if v is not None]
+import os
+import html
+import json
+import concurrent
+from modules import shared, ui_extra_networks, sd_models
+
+
+reference_dir = os.path.join('models', 'Reference')
+
+class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Model')
+
+ def refresh(self):
+ shared.refresh_checkpoints()
+
+ def list_reference(self): # pylint: disable=inconsistent-return-statements
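+ # yield pseudo-items for bundled reference models defined in html/reference.json; the original backend only lists entries marked 'original'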
+ reference_models = shared.readfile(os.path.join('html', 'reference.json'))
+ for k, v in reference_models.items():
+ if shared.backend != shared.Backend.DIFFUSERS:
+ if not v.get('original', False):
+ continue
+ url = v.get('alt', None) or v['path']
+ else:
+ url = v['path']
+ name = os.path.join(reference_dir, k)
+ preview = v.get('preview', v['path'])
+ yield {
+ "type": 'Model',
+ "name": name,
+ "title": name,
+ "filename": url,
+ "search_term": self.search_terms_from_path(name),
+ "preview": self.find_preview(os.path.join(reference_dir, preview)),
+ "local_preview": self.find_preview_file(os.path.join(reference_dir, preview)),
+ "onclick": '"' + html.escape(f"""return selectReference({json.dumps(url)})""") + '"',
+ "hash": None,
+ "mtime": 0,
+ "size": 0,
+ "info": {},
+ "metadata": {},
+ "description": v.get('desc', ''),
+ }
+
+ def create_item(self, name):
+ record = None
+ try:
+ checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name)
+ exists = os.path.exists(checkpoint.filename)
+ record = {
+ "type": 'Model',
+ "name": checkpoint.name,
+ "title": checkpoint.title,
+ "filename": checkpoint.filename,
+ "hash": checkpoint.shorthash,
+ "search_term": self.search_terms_from_path(checkpoint.title),
+ "preview": self.find_preview(checkpoint.filename),
+ "local_preview": f"{os.path.splitext(checkpoint.filename)[0]}.{shared.opts.samples_format}",
+ "metadata": checkpoint.metadata,
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+ "mtime": os.path.getmtime(checkpoint.filename) if exists else 0,
+ "size": os.path.getsize(checkpoint.filename) if exists else 0,
+ }
+ record["info"] = self.find_info(checkpoint.filename)
+ record["description"] = self.find_description(checkpoint.filename, record["info"])
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=model file={name} {e}")
+ return record
+
+ def list_items(self):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ future_items = {executor.submit(self.create_item, cp): cp for cp in list(sd_models.checkpoints_list.copy())}
+ for future in concurrent.futures.as_completed(future_items):
+ item = future.result()
+ if item is not None:
+ yield item
+ for record in self.list_reference():
+ yield record
+
+ def allowed_directories_for_previews(self):
+ if shared.backend == shared.Backend.DIFFUSERS:
+ return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, reference_dir] if v is not None]
+ else:
+ return [v for v in [shared.opts.ckpt_dir, reference_dir, sd_models.model_path] if v is not None]
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 28189dad4..b31ef975e 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -1,34 +1,34 @@
-import json
-import os
-from modules import shared, ui_extra_networks
-
-
-class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Hypernetwork')
-
- def refresh(self):
- shared.reload_hypernetworks()
-
- def list_items(self):
- for name, path in shared.hypernetworks.items():
- try:
- name = os.path.relpath(os.path.splitext(path)[0], shared.opts.hypernetwork_dir)
- yield {
- "type": 'Hypernetwork',
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "info": self.find_info(path),
- "search_term": self.search_terms_from_path(name),
- "prompt": json.dumps(f""),
- "local_preview": f"{os.path.splitext(path)[0]}.{shared.opts.samples_format}",
- "mtime": os.path.getmtime(path),
- "size": os.path.getsize(path),
- }
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=hypernetwork file={path} {e}")
-
- def allowed_directories_for_previews(self):
- return [shared.opts.hypernetwork_dir]
+import json
+import os
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Hypernetwork')
+
+ def refresh(self):
+ shared.reload_hypernetworks()
+
+ def list_items(self):
+ for name, path in shared.hypernetworks.items():
+ try:
+ name = os.path.relpath(os.path.splitext(path)[0], shared.opts.hypernetwork_dir)
+ yield {
+ "type": 'Hypernetwork',
+ "name": name,
+ "filename": path,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "info": self.find_info(path),
+ "search_term": self.search_terms_from_path(name),
+ "prompt": json.dumps(f""),
+ "local_preview": f"{os.path.splitext(path)[0]}.{shared.opts.samples_format}",
+ "mtime": os.path.getmtime(path),
+ "size": os.path.getsize(path),
+ }
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=hypernetwork file={path} {e}")
+
+ def allowed_directories_for_previews(self):
+ return [shared.opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_styles.py b/modules/ui_extra_networks_styles.py
index 9f46850e9..909414f6c 100644
--- a/modules/ui_extra_networks_styles.py
+++ b/modules/ui_extra_networks_styles.py
@@ -1,137 +1,137 @@
-import os
-import html
-import json
-import concurrent
-from modules import shared, extra_networks, ui_extra_networks, styles
-
-
-class ExtraNetworksPageStyles(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Style')
-
- def refresh(self):
- shared.prompt_styles.reload()
-
- def parse_desc(self, desc):
- lines = desc.strip().split("\n")
- params = { 'name': '', 'description': '', 'prompt': '', 'negative': '', 'extra': ''}
- found = ''
- for line in lines:
- line = line.strip()
- if line.lower().startswith('name:'):
- found = 'name'
- params['name'] = line[5:].strip()
- elif line.lower().startswith('description:'):
- found = 'description'
- params['description'] = line[12:].strip()
- elif line.lower().startswith('prompt:'):
- found = 'prompt'
- params['prompt'] = line[7:].strip()
- elif line.lower().startswith('negative:'):
- found = 'negative'
- params['negative'] = line[9:].strip()
- elif line.lower().startswith('extra:'):
- found = 'extra'
- params['extra'] = line[6:].strip()
- elif found != '':
- params[found] += '\n' + line
- if params['name'] == '':
- return None
- if params['description'] == '':
- params['description'] = params['name']
- return params
-
- def create_style(self, params):
- from modules.images import FilenameGenerator
- from hashlib import sha256
- namegen = FilenameGenerator(p=None, seed=None, prompt=params.get('Prompt', ''), image=None, grid=False)
- name = namegen.prompt_words()
- sha = sha256(json.dumps(name).encode()).hexdigest()[0:8]
- fn = os.path.join(shared.opts.styles_dir, sha + '.json')
- item = {
- "type": 'Style',
- "name": name,
- "title": name,
- "filename": fn,
- "search_term": f'{self.search_terms_from_path(fn)} {params.get("Prompt", "")}',
- "preview": self.find_preview(name),
- "description": '',
- "prompt": params.get('Prompt', ''),
- "negative": params.get('Negative prompt', ''),
- "extra": '',
- "local_preview": f"{name}.{shared.opts.samples_format}",
- }
- return item
-
- def create_item(self, k):
- item = None
- try:
- style = shared.prompt_styles.styles.get(k)
- fn = os.path.splitext(getattr(style, 'filename', ''))[0]
- name = getattr(style, 'name', '')
- if name == '':
- return item
- txt = f'Prompt: {getattr(style, "prompt", "")}'
- if len(getattr(style, 'negative_prompt', '')) > 0:
- txt += f'\nNegative: {style.negative_prompt}'
- item = {
- "type": 'Style',
- "name": name,
- "title": k,
- "filename": style.filename,
- "search_term": f'{self.search_terms_from_path(name)} {txt}',
- "preview": style.preview if getattr(style, 'preview', None) is not None and style.preview.startswith('data:') else self.find_preview(fn),
- "description": style.description if getattr(style, 'description', None) is not None and len(style.description) > 0 else txt,
- "prompt": getattr(style, 'prompt', ''),
- "negative": getattr(style, 'negative_prompt', ''),
- "extra": getattr(style, 'extra', ''),
- "local_preview": f"{fn}.{shared.opts.samples_format}",
- "onclick": '"' + html.escape(f"""return selectStyle({json.dumps(name)})""") + '"',
- "mtime": getattr(style, 'mtime', 0),
- "size": os.path.getsize(style.filename),
- }
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=style file={k} {e}")
- return item
-
- def list_items(self):
- with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
- future_items = {executor.submit(self.create_item, style): style for style in list(shared.prompt_styles.styles)}
- for future in concurrent.futures.as_completed(future_items):
- item = future.result()
- if item is not None:
- yield item
-
- def allowed_directories_for_previews(self):
- return [v for v in [shared.opts.styles_dir] if v is not None] + ['html']
-
-
-class ExtraNetworkStyles(extra_networks.ExtraNetwork):
- def __init__(self):
- super().__init__('style')
- self.indexes = {}
-
- def activate(self, p, params_list):
- for param in params_list:
- if len(param.items) > 0:
- style = None
- search = param.items[0]
- # style = shared.prompt_styles.find_style(param.items[0])
- match = [s for s in shared.prompt_styles.styles.values() if s.name == search]
- if len(match) > 0:
- style = match[0]
- else:
- match = [s for s in shared.prompt_styles.styles.values() if s.name.startswith(search)]
- if len(match) > 0:
- i = self.indexes.get(search, 0)
- self.indexes[search] = (i + 1) % len(match)
- style = match[self.indexes[search]]
- if style is not None:
- p.styles.append(style.name)
- p.prompts = [styles.merge_prompts(style.prompt, prompt) for prompt in p.prompts]
- p.negative_prompts = [styles.merge_prompts(style.negative_prompt, prompt) for prompt in p.negative_prompts]
- styles.apply_styles_to_extra(p, style)
-
-
- def deactivate(self, p):
- pass
+import os
+import html
+import json
+import concurrent
+from modules import shared, extra_networks, ui_extra_networks, styles
+
+
+class ExtraNetworksPageStyles(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Style')
+
+ def refresh(self):
+ shared.prompt_styles.reload()
+
+ def parse_desc(self, desc):
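+ # parse a free-form description into style fields using 'Name:', 'Description:', 'Prompt:', 'Negative:' and 'Extra:' prefixes; continuation lines are appended to the last field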
+ lines = desc.strip().split("\n")
+ params = { 'name': '', 'description': '', 'prompt': '', 'negative': '', 'extra': ''}
+ found = ''
+ for line in lines:
+ line = line.strip()
+ if line.lower().startswith('name:'):
+ found = 'name'
+ params['name'] = line[5:].strip()
+ elif line.lower().startswith('description:'):
+ found = 'description'
+ params['description'] = line[12:].strip()
+ elif line.lower().startswith('prompt:'):
+ found = 'prompt'
+ params['prompt'] = line[7:].strip()
+ elif line.lower().startswith('negative:'):
+ found = 'negative'
+ params['negative'] = line[9:].strip()
+ elif line.lower().startswith('extra:'):
+ found = 'extra'
+ params['extra'] = line[6:].strip()
+ elif found != '':
+ params[found] += '\n' + line
+ if params['name'] == '':
+ return None
+ if params['description'] == '':
+ params['description'] = params['name']
+ return params
+
+ def create_style(self, params):
+ from modules.images import FilenameGenerator
+ from hashlib import sha256
+ namegen = FilenameGenerator(p=None, seed=None, prompt=params.get('Prompt', ''), image=None, grid=False)
+ name = namegen.prompt_words()
+ sha = sha256(json.dumps(name).encode()).hexdigest()[0:8]
+ fn = os.path.join(shared.opts.styles_dir, sha + '.json')
+ item = {
+ "type": 'Style',
+ "name": name,
+ "title": name,
+ "filename": fn,
+ "search_term": f'{self.search_terms_from_path(fn)} {params.get("Prompt", "")}',
+ "preview": self.find_preview(name),
+ "description": '',
+ "prompt": params.get('Prompt', ''),
+ "negative": params.get('Negative prompt', ''),
+ "extra": '',
+ "local_preview": f"{name}.{shared.opts.samples_format}",
+ }
+ return item
+
+ def create_item(self, k):
+ item = None
+ try:
+ style = shared.prompt_styles.styles.get(k)
+ fn = os.path.splitext(getattr(style, 'filename', ''))[0]
+ name = getattr(style, 'name', '')
+ if name == '':
+ return item
+ txt = f'Prompt: {getattr(style, "prompt", "")}'
+ if len(getattr(style, 'negative_prompt', '')) > 0:
+ txt += f'\nNegative: {style.negative_prompt}'
+ item = {
+ "type": 'Style',
+ "name": name,
+ "title": k,
+ "filename": style.filename,
+ "search_term": f'{self.search_terms_from_path(name)} {txt}',
+ "preview": style.preview if getattr(style, 'preview', None) is not None and style.preview.startswith('data:') else self.find_preview(fn),
+ "description": style.description if getattr(style, 'description', None) is not None and len(style.description) > 0 else txt,
+ "prompt": getattr(style, 'prompt', ''),
+ "negative": getattr(style, 'negative_prompt', ''),
+ "extra": getattr(style, 'extra', ''),
+ "local_preview": f"{fn}.{shared.opts.samples_format}",
+ "onclick": '"' + html.escape(f"""return selectStyle({json.dumps(name)})""") + '"',
+ "mtime": getattr(style, 'mtime', 0),
+ "size": os.path.getsize(style.filename),
+ }
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=style file={k} {e}")
+ return item
+
+ def list_items(self):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ future_items = {executor.submit(self.create_item, style): style for style in list(shared.prompt_styles.styles)}
+ for future in concurrent.futures.as_completed(future_items):
+ item = future.result()
+ if item is not None:
+ yield item
+
+ def allowed_directories_for_previews(self):
+ return [v for v in [shared.opts.styles_dir] if v is not None] + ['html']
+
+
+class ExtraNetworkStyles(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('style')
+ self.indexes = {}
+
+ def activate(self, p, params_list):
+ for param in params_list:
+ if len(param.items) > 0:
+ style = None
+ search = param.items[0]
+ # style = shared.prompt_styles.find_style(param.items[0])
+ match = [s for s in shared.prompt_styles.styles.values() if s.name == search]
+ if len(match) > 0:
+ style = match[0]
+ else:
+ match = [s for s in shared.prompt_styles.styles.values() if s.name.startswith(search)]
+ if len(match) > 0:
+ i = self.indexes.get(search, 0)
+ self.indexes[search] = (i + 1) % len(match)
+ style = match[self.indexes[search]]
+ if style is not None:
+ p.styles.append(style.name)
+ p.prompts = [styles.merge_prompts(style.prompt, prompt) for prompt in p.prompts]
+ p.negative_prompts = [styles.merge_prompts(style.negative_prompt, prompt) for prompt in p.negative_prompts]
+ styles.apply_styles_to_extra(p, style)
+
+
+ def deactivate(self, p):
+ pass
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d94ce39dc..49c4f2d51 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,79 +1,79 @@
-import json
-import os
-import concurrent
-from modules import shared, sd_hijack, sd_models, ui_extra_networks
-from modules.textual_inversion.textual_inversion import Embedding
-
-
-class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Embedding')
- self.allow_negative_prompt = True
- self.embeddings = []
-
- def refresh(self):
- if sd_models.model_data.sd_model is None:
- return
- if shared.backend == shared.Backend.ORIGINAL:
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
- elif hasattr(sd_models.model_data.sd_model, 'embedding_db'):
- sd_models.model_data.sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
-
- def create_item(self, embedding: Embedding):
- record = None
- try:
- path, _ext = os.path.splitext(embedding.filename)
- tags = {}
- if embedding.tag is not None:
- tags[embedding.tag]=1
- name = os.path.splitext(embedding.basename)[0]
- record = {
- "type": 'Embedding',
- "name": name,
- "filename": embedding.filename,
- "preview": self.find_preview(embedding.filename),
- "search_term": self.search_terms_from_path(name),
- "prompt": json.dumps(f" {os.path.splitext(embedding.name)[0]}"),
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "tags": tags,
- "mtime": os.path.getmtime(embedding.filename),
- "size": os.path.getsize(embedding.filename),
- }
- record["info"] = self.find_info(embedding.filename)
- record["description"] = self.find_description(embedding.filename, record["info"])
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=embedding file={embedding.filename} {e}")
- return record
-
- def list_items(self):
-
- def list_folder(folder):
- for filename in os.listdir(folder):
- fn = os.path.join(folder, filename)
- if os.path.isfile(fn) and (fn.lower().endswith(".pt") or fn.lower().endswith(".safetensors")):
- embedding = Embedding(vec=0, name=os.path.basename(fn), filename=fn)
- embedding.filename = fn
- self.embeddings.append(embedding)
- elif os.path.isdir(fn) and not fn.startswith('.'):
- list_folder(fn)
-
- if sd_models.model_data.sd_model is None:
- self.embeddings = []
- list_folder(shared.opts.embeddings_dir)
- elif shared.backend == shared.Backend.ORIGINAL:
- self.embeddings = list(sd_hijack.model_hijack.embedding_db.word_embeddings.values())
- elif hasattr(sd_models.model_data.sd_model, 'embedding_db'):
- self.embeddings = list(sd_models.model_data.sd_model.embedding_db.word_embeddings.values())
- else:
- self.embeddings = []
- self.embeddings = sorted(self.embeddings, key=lambda emb: emb.filename)
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
- future_items = {executor.submit(self.create_item, net): net for net in self.embeddings}
- for future in concurrent.futures.as_completed(future_items):
- item = future.result()
- if item is not None:
- yield item
-
- def allowed_directories_for_previews(self):
- return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
+import json
+import os
+import concurrent
+from modules import shared, sd_hijack, sd_models, ui_extra_networks
+from modules.textual_inversion.textual_inversion import Embedding
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Embedding')
+ self.allow_negative_prompt = True
+ self.embeddings = []
+
+ def refresh(self):
+ if sd_models.model_data.sd_model is None:
+ return
+ if shared.backend == shared.Backend.ORIGINAL:
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+ elif hasattr(sd_models.model_data.sd_model, 'embedding_db'):
+ sd_models.model_data.sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def create_item(self, embedding: Embedding):
+ record = None
+ try:
+ path, _ext = os.path.splitext(embedding.filename)
+ tags = {}
+ if embedding.tag is not None:
+ tags[embedding.tag]=1
+ name = os.path.splitext(embedding.basename)[0]
+ record = {
+ "type": 'Embedding',
+ "name": name,
+ "filename": embedding.filename,
+ "preview": self.find_preview(embedding.filename),
+ "search_term": self.search_terms_from_path(name),
+ "prompt": json.dumps(f" {os.path.splitext(embedding.name)[0]}"),
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "tags": tags,
+ "mtime": os.path.getmtime(embedding.filename),
+ "size": os.path.getsize(embedding.filename),
+ }
+ record["info"] = self.find_info(embedding.filename)
+ record["description"] = self.find_description(embedding.filename, record["info"])
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=embedding file={embedding.filename} {e}")
+ return record
+
+ def list_items(self):
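+ # with no model loaded, scan embeddings_dir recursively for .pt/.safetensors files; otherwise use the loaded embedding database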
+
+ def list_folder(folder):
+ for filename in os.listdir(folder):
+ fn = os.path.join(folder, filename)
+ if os.path.isfile(fn) and (fn.lower().endswith(".pt") or fn.lower().endswith(".safetensors")):
+ embedding = Embedding(vec=0, name=os.path.basename(fn), filename=fn)
+ embedding.filename = fn
+ self.embeddings.append(embedding)
+ elif os.path.isdir(fn) and not fn.startswith('.'):
+ list_folder(fn)
+
+ if sd_models.model_data.sd_model is None:
+ self.embeddings = []
+ list_folder(shared.opts.embeddings_dir)
+ elif shared.backend == shared.Backend.ORIGINAL:
+ self.embeddings = list(sd_hijack.model_hijack.embedding_db.word_embeddings.values())
+ elif hasattr(sd_models.model_data.sd_model, 'embedding_db'):
+ self.embeddings = list(sd_models.model_data.sd_model.embedding_db.word_embeddings.values())
+ else:
+ self.embeddings = []
+ self.embeddings = sorted(self.embeddings, key=lambda emb: emb.filename)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ future_items = {executor.submit(self.create_item, net): net for net in self.embeddings}
+ for future in concurrent.futures.as_completed(future_items):
+ item = future.result()
+ if item is not None:
+ yield item
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_extra_networks_vae.py b/modules/ui_extra_networks_vae.py
index e05920a6b..f36afa7ef 100644
--- a/modules/ui_extra_networks_vae.py
+++ b/modules/ui_extra_networks_vae.py
@@ -1,38 +1,38 @@
-import html
-import json
-import os
-from modules import shared, ui_extra_networks, sd_vae, hashes
-
-
-class ExtraNetworksPageVAEs(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('VAE')
-
- def refresh(self):
- shared.refresh_vaes()
-
- def list_items(self):
- for name, filename in sd_vae.vae_dict.items():
- try:
- record = {
- "type": 'VAE',
- "name": name,
- "title": name,
- "filename": filename,
- "hash": hashes.sha256_from_cache(filename, f"vae/{filename}"),
- "search_term": self.search_terms_from_path(filename),
- "preview": self.find_preview(filename),
- "local_preview": f"{os.path.splitext(filename)[0]}.{shared.opts.samples_format}",
- "metadata": {},
- "onclick": '"' + html.escape(f"""return selectVAE({json.dumps(name)})""") + '"',
- "mtime": os.path.getmtime(filename),
- "size": os.path.getsize(filename),
- }
- record["info"] = self.find_info(filename)
- record["description"] = self.find_description(filename, record["info"])
- yield record
- except Exception as e:
- shared.log.debug(f"Extra networks error: type=vae file={filename} {e}")
-
- def allowed_directories_for_previews(self):
- return [v for v in [shared.opts.vae_dir] if v is not None]
+import html
+import json
+import os
+from modules import shared, ui_extra_networks, sd_vae, hashes
+
+
+class ExtraNetworksPageVAEs(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('VAE')
+
+ def refresh(self):
+ shared.refresh_vaes()
+
+ def list_items(self):
+ for name, filename in sd_vae.vae_dict.items():
+ try:
+ record = {
+ "type": 'VAE',
+ "name": name,
+ "title": name,
+ "filename": filename,
+ "hash": hashes.sha256_from_cache(filename, f"vae/{filename}"),
+ "search_term": self.search_terms_from_path(filename),
+ "preview": self.find_preview(filename),
+ "local_preview": f"{os.path.splitext(filename)[0]}.{shared.opts.samples_format}",
+ "metadata": {},
+ "onclick": '"' + html.escape(f"""return selectVAE({json.dumps(name)})""") + '"',
+ "mtime": os.path.getmtime(filename),
+ "size": os.path.getsize(filename),
+ }
+ record["info"] = self.find_info(filename)
+ record["description"] = self.find_description(filename, record["info"])
+ yield record
+ except Exception as e:
+ shared.log.debug(f"Extra networks error: type=vae file={filename} {e}")
+
+ def allowed_directories_for_previews(self):
+ return [v for v in [shared.opts.vae_dir] if v is not None]
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index 0fb103859..5effd83d3 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,87 +1,87 @@
-import json
-import gradio as gr
-from modules import scripts, shared, ui_common, postprocessing, call_queue
-import modules.generation_parameters_copypaste as parameters_copypaste
-from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call # pylint: disable=unused-import
-from modules.extras import run_pnginfo
-from modules.ui_common import infotext_to_html
-
-
-def wrap_pnginfo(image):
- _, geninfo, info = run_pnginfo(image)
- return infotext_to_html(geninfo), info, geninfo
-
-
-def submit_click(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, save_output, *script_inputs):
- result_images, geninfo, js_info = postprocessing.run_postprocessing(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs, save_output=save_output)
- return result_images, geninfo, json.dumps(js_info), ''
-
-
-def create_ui():
- tab_index = gr.State(value=0) # pylint: disable=abstract-class-instantiated
- with gr.Row(equal_height=False, variant='compact', elem_classes="extras"):
- with gr.Column(variant='compact'):
- with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
- image_batch = gr.Files(label="Batch process", interactive=True, elem_id="extras_image_batch")
- with gr.TabItem('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
- show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
- with gr.Row():
- buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "control"])
- with gr.Row():
- save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output")
- script_inputs = scripts.scripts_postproc.setup_ui()
- with gr.Column():
- id_part = 'extras'
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
- interrupt = gr.Button('Stop', elem_id=f"{id_part}_interrupt", variant='secondary')
- interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip", variant='secondary')
- skip.click(fn=lambda: shared.state.skip(), inputs=[], outputs=[])
- result_images, generation_info, html_info, html_info_formatted, html_log = ui_common.create_output_panel("extras")
- gr.HTML('File metadata')
- exif_info = gr.HTML(elem_id="pnginfo_html_info")
- gen_info = gr.Text(elem_id="pnginfo_gen_info", visible=False)
- for tabname, button in buttons.items():
- parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=gen_info, source_image_component=extras_image))
-
- tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
- tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
- tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
- extras_image.change(
- fn=wrap_gradio_call(wrap_pnginfo),
- inputs=[extras_image],
- outputs=[html_info_formatted, exif_info, gen_info],
- )
- submit.click(
- _js="submit_postprocessing",
- fn=call_queue.wrap_gradio_gpu_call(submit_click, extra_outputs=[None, '']),
- inputs=[
- tab_index,
- extras_image,
- image_batch,
- extras_batch_input_dir,
- extras_batch_output_dir,
- show_extras_results,
- save_output,
- *script_inputs,
- ],
- outputs=[
- result_images,
- html_info,
- generation_info,
- html_log,
- ]
- )
-
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
- extras_image.change(
- fn=scripts.scripts_postproc.image_changed,
- inputs=[], outputs=[]
- )
+import json
+import gradio as gr
+from modules import scripts, shared, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call # pylint: disable=unused-import
+from modules.extras import run_pnginfo
+from modules.ui_common import infotext_to_html
+
+
+def wrap_pnginfo(image):
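+    # read generation metadata embedded in the image and return it as formatted HTML, the raw info block, and the plain infotext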
+ _, geninfo, info = run_pnginfo(image)
+ return infotext_to_html(geninfo), info, geninfo
+
+
+def submit_click(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, save_output, *script_inputs):
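+    # run the postprocessing pipeline for the selected tab and return result images, infotext, JSON-encoded info, and an empty log string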
+ result_images, geninfo, js_info = postprocessing.run_postprocessing(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs, save_output=save_output)
+ return result_images, geninfo, json.dumps(js_info), ''
+
+
+def create_ui():
+ tab_index = gr.State(value=0) # pylint: disable=abstract-class-instantiated
+ with gr.Row(equal_height=False, variant='compact', elem_classes="extras"):
+ with gr.Column(variant='compact'):
+ with gr.Tabs(elem_id="mode_extras"):
+ with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+ with gr.TabItem('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.Files(label="Batch process", interactive=True, elem_id="extras_image_batch")
+ with gr.TabItem('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+ with gr.Row():
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "control"])
+ with gr.Row():
+ save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output")
+ script_inputs = scripts.scripts_postproc.setup_ui()
+ with gr.Column():
+ id_part = 'extras'
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+ interrupt = gr.Button('Stop', elem_id=f"{id_part}_interrupt", variant='secondary')
+ interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip", variant='secondary')
+ skip.click(fn=lambda: shared.state.skip(), inputs=[], outputs=[])
+ result_images, generation_info, html_info, html_info_formatted, html_log = ui_common.create_output_panel("extras")
+ gr.HTML('File metadata')
+ exif_info = gr.HTML(elem_id="pnginfo_html_info")
+ gen_info = gr.Text(elem_id="pnginfo_gen_info", visible=False)
+ for tabname, button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=gen_info, source_image_component=extras_image))
+
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+ extras_image.change(
+ fn=wrap_gradio_call(wrap_pnginfo),
+ inputs=[extras_image],
+ outputs=[html_info_formatted, exif_info, gen_info],
+ )
+ submit.click(
+ _js="submit_postprocessing",
+ fn=call_queue.wrap_gradio_gpu_call(submit_click, extra_outputs=[None, '']),
+ inputs=[
+ tab_index,
+ extras_image,
+ image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
+ save_output,
+ *script_inputs,
+ ],
+ outputs=[
+ result_images,
+ html_info,
+ generation_info,
+ html_log,
+ ]
+ )
+
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+ extras_image.change(
+ fn=scripts.scripts_postproc.image_changed,
+ inputs=[], outputs=[]
+ )
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index c7b7f75d5..7c28ac59e 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -1,105 +1,105 @@
-# TODO: a1111 compatibility item, not used
-
-import gradio as gr
-from modules import shared, styles
-
-styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
-styles_materialize_symbol = '\U0001f4cb' # 📋
-
-
-def select_style(name):
- style = shared.prompt_styles.styles.get(name)
- existing = style is not None
- empty = not name
- prompt = style.prompt if style else gr.update()
- negative_prompt = style.negative_prompt if style else gr.update()
- return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
-
-
-def save_style(name, prompt, negative_prompt):
- if not name:
- return gr.update(visible=False)
- style = styles.Style(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- shared.prompt_styles.save_styles('')
- return gr.update(visible=True)
-
-
-def delete_style(name):
- if name == "":
- return '', '', ''
- shared.prompt_styles.styles.pop(name, None)
- shared.prompt_styles.save_styles('')
- return '', '', ''
-
-
-def materialize_styles(prompt, negative_prompt, styles): # pylint: disable=redefined-outer-name
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
- negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
-
-
-def refresh_styles():
- return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
-
-
-class UiPromptStyles:
- def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): # pylint: disable=unused-argument
- self.dropdown = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles", choices=[style.name for style in shared.prompt_styles.styles.values()], value=[], multiselect=True)
-
- """
- def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
- self.tabname = tabname
-
- with gr.Row(elem_id=f"{tabname}_styles_row"):
- self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
- edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
-
- with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
- with gr.Row():
- self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
- ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
- self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
-
- with gr.Row():
- self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
-
- with gr.Row():
- self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
-
- with gr.Row():
- self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
- self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
- self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
-
- self.selection.change(
- fn=select_style,
- inputs=[self.selection],
- outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
- show_progress=False,
- )
-
- self.save.click(
- fn=save_style,
- inputs=[self.selection, self.prompt, self.neg_prompt],
- outputs=[self.delete],
- show_progress=False,
- ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
-
- self.delete.click(
- fn=delete_style,
- _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
- inputs=[self.selection],
- outputs=[self.selection, self.prompt, self.neg_prompt],
- show_progress=False,
- ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
-
- self.materialize.click(
- fn=materialize_styles,
- inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
- outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
- show_progress=False,
- ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
-
- ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
- """
+# TODO: a1111 compatibility item, not used
+
+import gradio as gr
+from modules import shared, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
+styles_materialize_symbol = '\U0001f4cb' # 📋
+
+
+def select_style(name):
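+    # look up the named style; the delete button is shown only for existing styles and the save button only when a name is entered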
+ style = shared.prompt_styles.styles.get(name)
+ existing = style is not None
+ empty = not name
+ prompt = style.prompt if style else gr.update()
+ negative_prompt = style.negative_prompt if style else gr.update()
+ return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
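+    # create or overwrite the named style and persist all styles; returns a visibility update used for the delete button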
+ if not name:
+ return gr.update(visible=False)
+ style = styles.Style(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ shared.prompt_styles.save_styles('')
+ return gr.update(visible=True)
+
+
+def delete_style(name):
+ if name == "":
+ return '', '', ''
+ shared.prompt_styles.styles.pop(name, None)
+ shared.prompt_styles.save_styles('')
+ return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles): # pylint: disable=redefined-outer-name
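+    # apply the selected styles to the positive and negative prompts and clear the styles dropdown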
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+ return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): # pylint: disable=unused-argument
+ self.dropdown = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles", choices=[style.name for style in shared.prompt_styles.styles.values()], value=[], multiselect=True)
+
+ """
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+ self.tabname = tabname
+
+ with gr.Row(elem_id=f"{tabname}_styles_row"):
+ self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+ edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+ with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+ with gr.Row():
+ self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+ ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+            self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in the main UI to the prompt.")
+
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+ with gr.Row():
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+ with gr.Row():
+ self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+ self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+ self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+ self.selection.change(
+ fn=select_style,
+ inputs=[self.selection],
+ outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+ show_progress=False,
+ )
+
+ self.save.click(
+ fn=save_style,
+ inputs=[self.selection, self.prompt, self.neg_prompt],
+ outputs=[self.delete],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.delete.click(
+ fn=delete_style,
+ _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+ inputs=[self.selection],
+ outputs=[self.selection, self.prompt, self.neg_prompt],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.materialize.click(
+ fn=materialize_styles,
+ inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+ ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+ """
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 10aab1b96..8b4196a0b 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -1,100 +1,100 @@
-import os
-import tempfile
-from collections import namedtuple
-from pathlib import Path
-import gradio as gr
-from PIL import Image, PngImagePlugin
-from modules import shared, errors, paths
-
-
-Savedfile = namedtuple("Savedfile", ["name"])
-debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
-
-
-def register_tmp_file(gradio, filename):
- if hasattr(gradio, 'temp_file_sets'):
- gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
-
-
-def check_tmp_file(gradio, filename):
- ok = False
- if hasattr(gradio, 'temp_file_sets'):
- ok = ok or any(filename in fileset for fileset in gradio.temp_file_sets)
- if shared.opts.outdir_samples != '':
- ok = ok or Path(shared.opts.outdir_samples).resolve() in Path(filename).resolve().parents
- else:
- ok = ok or Path(shared.opts.outdir_txt2img_samples).resolve() in Path(filename).resolve().parents
- ok = ok or Path(shared.opts.outdir_img2img_samples).resolve() in Path(filename).resolve().parents
- ok = ok or Path(shared.opts.outdir_extras_samples).resolve() in Path(filename).resolve().parents
- if shared.opts.outdir_grids != '':
- ok = ok or Path(shared.opts.outdir_grids).resolve() in Path(filename).resolve().parents
- else:
- ok = ok or Path(shared.opts.outdir_txt2img_grids).resolve() in Path(filename).resolve().parents
- ok = ok or Path(shared.opts.outdir_img2img_grids).resolve() in Path(filename).resolve().parents
- ok = ok or Path(shared.opts.outdir_save).resolve() in Path(filename).resolve().parents
- ok = ok or Path(shared.opts.outdir_init_images).resolve() in Path(filename).resolve().parents
- return ok
-
-
-def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint: disable=redefined-builtin,unused-argument
- """
- # original gradio implementation
- bytes_data = gr.processing_utils.encode_pil_to_bytes(img, format)
- temp_dir = Path(dir) / self.hash_bytes(bytes_data)
- temp_dir.mkdir(exist_ok=True, parents=True)
- filename = str(temp_dir / f"image.{format}")
- img.save(filename, pnginfo=gr.processing_utils.get_pil_metadata(img))
- """
- already_saved_as = getattr(img, 'already_saved_as', None)
- exists = os.path.isfile(already_saved_as) if already_saved_as is not None else False
- debug(f'Image lookup: {already_saved_as} exists={exists}')
- if already_saved_as and exists:
- register_tmp_file(shared.demo, already_saved_as)
- file_obj = Savedfile(already_saved_as)
- name = file_obj.name
- debug(f'Image registered: {name}')
- return name
- if shared.opts.temp_dir != "":
- dir = shared.opts.temp_dir
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in img.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
- if not os.path.exists(dir):
- os.makedirs(dir, exist_ok=True)
- shared.log.debug(f'Created temp folder: path="{dir}"')
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as tmp:
- name = tmp.name
- img.save(name, pnginfo=(metadata if use_metadata else None))
- img.already_saved_as = name
- size = os.path.getsize(name)
- shared.log.debug(f'Saving temp: image="{name}" resolution={img.width}x{img.height} size={size}')
- params = ', '.join([f'{k}: {v}' for k, v in img.info.items()])
- params = params[12:] if params.startswith('parameters: ') else params
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- file.write(params)
- return name
-
-
-# override save to file function so that it also writes PNG info
-gr.components.IOComponent.pil_to_temp_file = pil_to_temp_file # gradio >=3.32.0
-
-def on_tmpdir_changed():
- if shared.opts.temp_dir == "":
- return
- register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
-
-
-def cleanup_tmpdr():
- temp_dir = shared.opts.temp_dir
- if temp_dir == "" or not os.path.isdir(temp_dir):
- return
- for root, _dirs, files in os.walk(temp_dir, topdown=False):
- for name in files:
- _, extension = os.path.splitext(name)
- if extension != ".png" and extension != ".jpg" and extension != ".webp":
- continue
- filename = os.path.join(root, name)
- os.remove(filename)
+import os
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+import gradio as gr
+from PIL import Image, PngImagePlugin
+from modules import shared, errors, paths
+
+
+Savedfile = namedtuple("Savedfile", ["name"])
+debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
+
+
+def register_tmp_file(gradio, filename):
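+    # add the file's absolute path to gradio's temp-file set so the server is allowed to serve it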
+ if hasattr(gradio, 'temp_file_sets'):
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+
+def check_tmp_file(gradio, filename):
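+    # a file is servable if gradio already tracks it as a temp file or if it resides under one of the configured output/save folders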
+ ok = False
+ if hasattr(gradio, 'temp_file_sets'):
+ ok = ok or any(filename in fileset for fileset in gradio.temp_file_sets)
+ if shared.opts.outdir_samples != '':
+ ok = ok or Path(shared.opts.outdir_samples).resolve() in Path(filename).resolve().parents
+ else:
+ ok = ok or Path(shared.opts.outdir_txt2img_samples).resolve() in Path(filename).resolve().parents
+ ok = ok or Path(shared.opts.outdir_img2img_samples).resolve() in Path(filename).resolve().parents
+ ok = ok or Path(shared.opts.outdir_extras_samples).resolve() in Path(filename).resolve().parents
+ if shared.opts.outdir_grids != '':
+ ok = ok or Path(shared.opts.outdir_grids).resolve() in Path(filename).resolve().parents
+ else:
+ ok = ok or Path(shared.opts.outdir_txt2img_grids).resolve() in Path(filename).resolve().parents
+ ok = ok or Path(shared.opts.outdir_img2img_grids).resolve() in Path(filename).resolve().parents
+ ok = ok or Path(shared.opts.outdir_save).resolve() in Path(filename).resolve().parents
+ ok = ok or Path(shared.opts.outdir_init_images).resolve() in Path(filename).resolve().parents
+ return ok
+
+
+def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint: disable=redefined-builtin,unused-argument
+ """
+ # original gradio implementation
+ bytes_data = gr.processing_utils.encode_pil_to_bytes(img, format)
+ temp_dir = Path(dir) / self.hash_bytes(bytes_data)
+ temp_dir.mkdir(exist_ok=True, parents=True)
+ filename = str(temp_dir / f"image.{format}")
+ img.save(filename, pnginfo=gr.processing_utils.get_pil_metadata(img))
+ """
+ already_saved_as = getattr(img, 'already_saved_as', None)
+ exists = os.path.isfile(already_saved_as) if already_saved_as is not None else False
+ debug(f'Image lookup: {already_saved_as} exists={exists}')
+ if already_saved_as and exists:
+ register_tmp_file(shared.demo, already_saved_as)
+ file_obj = Savedfile(already_saved_as)
+ name = file_obj.name
+ debug(f'Image registered: {name}')
+ return name
+ if shared.opts.temp_dir != "":
+ dir = shared.opts.temp_dir
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in img.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+ if not os.path.exists(dir):
+ os.makedirs(dir, exist_ok=True)
+ shared.log.debug(f'Created temp folder: path="{dir}"')
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as tmp:
+ name = tmp.name
+ img.save(name, pnginfo=(metadata if use_metadata else None))
+ img.already_saved_as = name
+ size = os.path.getsize(name)
+ shared.log.debug(f'Saving temp: image="{name}" resolution={img.width}x{img.height} size={size}')
+ params = ', '.join([f'{k}: {v}' for k, v in img.info.items()])
+ params = params[12:] if params.startswith('parameters: ') else params
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+ file.write(params)
+ return name
+
+
+# override gradio's save-to-file function so that it also writes PNG metadata
+gr.components.IOComponent.pil_to_temp_file = pil_to_temp_file # gradio >=3.32.0
+
+def on_tmpdir_changed():
+ if shared.opts.temp_dir == "":
+ return
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
+
+
+def cleanup_tmpdr():
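+    # remove leftover .png/.jpg/.webp files from the configured temp folder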
+ temp_dir = shared.opts.temp_dir
+ if temp_dir == "" or not os.path.isdir(temp_dir):
+ return
+ for root, _dirs, files in os.walk(temp_dir, topdown=False):
+ for name in files:
+ _, extension = os.path.splitext(name)
+            if extension not in (".png", ".jpg", ".webp"):
+ continue
+ filename = os.path.join(root, name)
+ os.remove(filename)
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py b/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py
old mode 100755
new mode 100644
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/general.py b/repositories/codeformer/facelib/detection/yolov5face/utils/general.py
old mode 100755
new mode 100644
diff --git a/repositories/codeformer/scripts/crop_align_face.py b/repositories/codeformer/scripts/crop_align_face.py
old mode 100755
new mode 100644
diff --git a/requirements.txt b/requirements.txt
index 6f0861c2d..3922292be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,74 +1,74 @@
-setuptools
-addict
-aenum
-aiohttp
-anyio
-appdirs
-astunparse
-blendmodes
-clean-fid
-easydev
-extcolors
-facexlib
-filetype
-future
-gdown
-gfpgan
-GitPython
-httpcore
-inflection
-jsonmerge
-kornia
-lark
-lmdb
-lpips
-omegaconf
-open-clip-torch
-piexif
-psutil
-pyyaml
-resize-right
-rich
-safetensors
-scipy
-tb_nightly
-tensordict
-toml
-torchdiffeq
-voluptuous
-yapf
-scikit-image
-basicsr
-fasteners
-dctorch
-pymatting
-matplotlib
-peft
-orjson
-httpx==0.24.1
-compel==2.0.2
-torchsde==0.2.6
- clip-interrogator==0.6.0
-antlr4-python3-runtime==4.9.3
-requests==2.31.0
-tqdm==4.66.1
-accelerate==0.25.0
-opencv-contrib-python-headless==4.8.1.78
-diffusers==0.25.0
-einops==0.4.1
-gradio==3.43.2
-huggingface_hub==0.20.1
-numexpr==2.8.4
-numpy==1.26.2
-numba==0.58.1
-pandas==1.5.3
-protobuf==3.20.3
-pytorch_lightning==1.9.4
-tokenizers==0.15.0
-transformers==4.36.2
-tomesd==0.1.3
-urllib3==1.26.18
-Pillow==10.1.0
-timm==0.9.12
-pydantic==1.10.13
-typing-extensions==4.9.0
+setuptools
+addict
+aenum
+aiohttp
+anyio
+appdirs
+astunparse
+blendmodes
+clean-fid
+easydev
+extcolors
+facexlib
+filetype
+future
+gdown
+gfpgan
+GitPython
+httpcore
+inflection
+jsonmerge
+kornia
+lark
+lmdb
+lpips
+omegaconf
+open-clip-torch
+piexif
+psutil
+pyyaml
+resize-right
+rich
+safetensors
+scipy
+tb_nightly
+tensordict
+toml
+torchdiffeq
+voluptuous
+yapf
+scikit-image
+basicsr
+fasteners
+dctorch
+pymatting
+matplotlib
+peft
+orjson
+httpx==0.24.1
+compel==2.0.2
+torchsde==0.2.6
+clip-interrogator==0.6.0
+antlr4-python3-runtime==4.9.3
+requests==2.31.0
+tqdm==4.66.1
+accelerate==0.25.0
+opencv-contrib-python-headless==4.8.1.78
+diffusers==0.25.0
+einops==0.4.1
+gradio==3.43.2
+huggingface_hub==0.20.1
+numexpr==2.8.4
+numpy==1.26.2
+numba==0.58.1
+pandas==1.5.3
+protobuf==3.20.3
+pytorch_lightning==1.9.4
+tokenizers==0.15.0
+transformers==4.36.2
+tomesd==0.1.3
+urllib3==1.26.18
+Pillow==10.1.0
+timm==0.9.12
+pydantic==1.10.13
+typing-extensions==4.9.0
diff --git a/scripts/postprocessing_video.py b/scripts/postprocessing_video.py
index f34036961..8f5e388f5 100644
--- a/scripts/postprocessing_video.py
+++ b/scripts/postprocessing_video.py
@@ -1,47 +1,47 @@
-import gradio as gr
-import modules.images
-from modules import scripts_postprocessing
-
-
-class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
- name = "Video"
-
- def ui(self):
- def video_type_change(video_type):
- return [
- gr.update(visible=video_type != 'None'),
- gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- ]
-
- with gr.Row():
- video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
- duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
- with gr.Row():
- loop = gr.Checkbox(label='Loop', value=True, visible=False)
- pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
- interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
- scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False)
- change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False)
- with gr.Row():
- filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1)
- video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change])
- return {
- "filename": filename,
- "video_type": video_type,
- "duration": duration,
- "loop": loop,
- "pad": pad,
- "interpolate": interpolate,
- "scale": scale,
- "change": change,
- }
-
- def postprocess(self, images, filename, video_type, duration, loop, pad, interpolate, scale, change): # pylint: disable=arguments-differ
- filename = filename.strip()
- if video_type == 'None' or len(filename) == 0 or images is None or len(images) < 2:
- return
- modules.images.save_video(p=None, filename=filename, images=images, video_type=video_type, duration=duration, loop=loop, pad=pad, interpolate=interpolate, scale=scale, change=change)
+import gradio as gr
+import modules.images
+from modules import scripts_postprocessing
+
+
+class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
+ name = "Video"
+
+ def ui(self):
+ def video_type_change(video_type):
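+            # duration is shown for any video type, loop only for GIF/PNG, and the remaining options only for MP4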
+ return [
+ gr.update(visible=video_type != 'None'),
+ gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ ]
+
+ with gr.Row():
+ video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
+ duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
+ with gr.Row():
+ loop = gr.Checkbox(label='Loop', value=True, visible=False)
+ pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
+ interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
+ scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False)
+ change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False)
+ with gr.Row():
+ filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1)
+ video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change])
+ return {
+ "filename": filename,
+ "video_type": video_type,
+ "duration": duration,
+ "loop": loop,
+ "pad": pad,
+ "interpolate": interpolate,
+ "scale": scale,
+ "change": change,
+ }
+
+ def postprocess(self, images, filename, video_type, duration, loop, pad, interpolate, scale, change): # pylint: disable=arguments-differ
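+        # requires a video type, a filename, and at least two frames before assembling the video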
+ filename = filename.strip()
+ if video_type == 'None' or len(filename) == 0 or images is None or len(images) < 2:
+ return
+ modules.images.save_video(p=None, filename=filename, images=images, video_type=video_type, duration=duration, loop=loop, pad=pad, interpolate=interpolate, scale=scale, change=change)
diff --git a/webui.bat b/webui.bat
index 2f3317118..68fd17c86 100755
--- a/webui.bat
+++ b/webui.bat
@@ -1,86 +1,86 @@
-:: --------------------------------------------------------------------------------------------------------------
-:: Do not make any changes to this file. Instead, create a shortcut to this file and add needed arguments there.
-:: --------------------------------------------------------------------------------------------------------------
-
-@echo off
-
-if not defined PYTHON (set PYTHON=python)
-if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
-set ERROR_REPORTING=FALSE
-mkdir tmp 2>NUL
-
-%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :check_pip
-echo Cannot launch python
-goto :show_stdout_stderr
-
-:check_pip
-%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
-if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
-%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
-echo Cannot install pip
-goto :show_stdout_stderr
-
-:start_venv
-if ["%VENV_DIR%"] == ["-"] goto :skip_venv
-if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
-
-dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :activate_venv
-
-for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
-echo Using python: %PYTHON_FULLNAME%
-echo Creating VENV: %VENV_DIR%
-%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :activate_venv
-echo Failed creating VENV: "%VENV_DIR%"
-goto :show_stdout_stderr
-
-:activate_venv
-set PYTHON="%VENV_DIR%\Scripts\Python.exe"
-echo Using VENV: %VENV_DIR%
-
-:skip_venv
-if [%ACCELERATE%] == ["True"] goto :accelerate
-goto :launch
-
-:accelerate
-set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
-if EXIST %ACCELERATE% goto :accelerate_launch
-
-:launch
-%PYTHON% launch.py %*
-pause
-exit /b
-
-:accelerate_launch
-echo Using accelerate
-%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py %*
-pause
-exit /b
-
-:show_stdout_stderr
-
-echo.
-echo exit code: %errorlevel%
-
-for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
-if %size% equ 0 goto :show_stderr
-echo.
-echo stdout:
-type tmp\stdout.txt
-
-:show_stderr
-for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
-if %size% equ 0 goto :show_stderr
-echo.
-echo stderr:
-type tmp\stderr.txt
-
-:endofscript
-
-echo.
-echo Launch Failed
-pause
+:: --------------------------------------------------------------------------------------------------------------
+:: Do not make any changes to this file. Instead, create a shortcut to this file and add needed arguments there.
+:: --------------------------------------------------------------------------------------------------------------
+
+@echo off
+
+if not defined PYTHON (set PYTHON=python)
+if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+set ERROR_REPORTING=FALSE
+mkdir tmp 2>NUL
+
+%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :check_pip
+echo Cannot launch python
+goto :show_stdout_stderr
+
+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Cannot install pip
+goto :show_stdout_stderr
+
+:start_venv
+if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
+
+dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+
+for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
+echo Using python: %PYTHON_FULLNAME%
+echo Creating VENV: %VENV_DIR%
+%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+echo Failed creating VENV: "%VENV_DIR%"
+goto :show_stdout_stderr
+
+:activate_venv
+set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+echo Using VENV: %VENV_DIR%
+
+:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
+
+:accelerate
+set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
+:launch
+%PYTHON% launch.py %*
+pause
+exit /b
+
+:accelerate_launch
+echo Using accelerate
+%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py %*
+pause
+exit /b
+
+:show_stdout_stderr
+
+echo.
+echo exit code: %errorlevel%
+
+for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
+if %size% equ 0 goto :show_stderr
+echo.
+echo stdout:
+type tmp\stdout.txt
+
+:show_stderr
+for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
+if %size% equ 0 goto :endofscript
+echo.
+echo stderr:
+type tmp\stderr.txt
+
+:endofscript
+
+echo.
+echo Launch Failed
+pause