From 8a2f6bc86c1487c6852d0dd00a964ea65ecfdd4a Mon Sep 17 00:00:00 2001 From: bartandrews Date: Tue, 21 Jan 2025 17:09:00 +0100 Subject: [PATCH] Revert "try unintegrating bitstring_to_mps from statevector_to_mps" This reverts commit c084141e3167fbafc26dfad8b40e2c5783bdfebc. --- python/ffsim/tenpy/util.py | 96 +++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/python/ffsim/tenpy/util.py b/python/ffsim/tenpy/util.py index 027a4f824..d5041eee6 100644 --- a/python/ffsim/tenpy/util.py +++ b/python/ffsim/tenpy/util.py @@ -77,60 +77,60 @@ def statevector_to_mps( The MPS. """ - # # check if state vector is basis state - # basis_state = True if np.count_nonzero(statevector) == 1 else False - - # # generate the ffsim-ordered list of product states - # if basis_state: - # idx = int(np.flatnonzero(statevector)[0]) - # string = ffsim.addresses_to_strings( - # [idx], - # norb, - # nelec, - # concatenate=False, - # bitstring_type=ffsim.BitstringType.INT, - # ) - # bitstring = (string[0][0], string[1][0]) - # product_states = [_bitstring_to_product_state(bitstring, norb)] - # else: - product_states = _generate_product_states(norb, nelec) + # check if state vector is basis state + basis_state = True if np.count_nonzero(statevector) == 1 else False + + # generate the ffsim-ordered list of product states + if basis_state: + idx = int(np.flatnonzero(statevector)[0]) + string = ffsim.addresses_to_strings( + [idx], + norb, + nelec, + concatenate=False, + bitstring_type=ffsim.BitstringType.INT, + ) + bitstring = (string[0][0], string[1][0]) + product_states = [_bitstring_to_product_state(bitstring, norb)] + else: + product_states = _generate_product_states(norb, nelec) # construct the reference product state MPS shfs = SpinHalfFermionSite(cons_N="N", cons_Sz="Sz") mps_reference = MPS.from_product_state([shfs] * norb, product_states[0]) - # if basis_state: - # # compute swap factor - # swap_factor = _compute_swap_factor(mps_reference) - # - # # apply swap factor - # if swap_factor == -1: - # minus_identity_npc = npc.Array.from_ndarray( - # -shfs.get_op("Id").to_ndarray(), - # [shfs.leg, shfs.leg.conj()], - # labels=["p", "p*"], - # ) - # mps_reference.apply_local_op(0, minus_identity_npc) - # - # mps = mps_reference - # else: - # initialize the TeNPy ExactDiag class instance - charge_sector = mps_reference.get_total_charge(True) - exact_diag = ExactDiag(mpo_model, charge_sector=charge_sector) - statevector_reference = exact_diag.mps_to_full(mps_reference) - leg_charge = statevector_reference.legs[0] - - # determine the mapping from ffsim basis to TeNPy basis - basis_ordering_ffsim, swap_factors_ffsim = _map_tenpy_to_ffsim_basis( - product_states, exact_diag - ) - basis_ordering_tenpy = np.argsort(basis_ordering_ffsim) - swap_factors_tenpy = swap_factors_ffsim[np.argsort(basis_ordering_ffsim)] + if basis_state: + # compute swap factor + swap_factor = _compute_swap_factor(mps_reference) + + # apply swap factor + if swap_factor == -1: + minus_identity_npc = npc.Array.from_ndarray( + -shfs.get_op("Id").to_ndarray(), + [shfs.leg, shfs.leg.conj()], + labels=["p", "p*"], + ) + mps_reference.apply_local_op(0, minus_identity_npc) + + mps = mps_reference + else: + # initialize the TeNPy ExactDiag class instance + charge_sector = mps_reference.get_total_charge(True) + exact_diag = ExactDiag(mpo_model, charge_sector=charge_sector) + statevector_reference = exact_diag.mps_to_full(mps_reference) + leg_charge = statevector_reference.legs[0] + + # determine the mapping from ffsim basis to TeNPy basis + basis_ordering_ffsim, swap_factors_ffsim = _map_tenpy_to_ffsim_basis( + product_states, exact_diag + ) + basis_ordering_tenpy = np.argsort(basis_ordering_ffsim) + swap_factors_tenpy = swap_factors_ffsim[np.argsort(basis_ordering_ffsim)] - # convert ffsim statevector to TeNPy MPS - statevector = np.multiply(swap_factors_tenpy, statevector[basis_ordering_tenpy]) - statevector_npc = npc.Array.from_ndarray(statevector, [leg_charge]) - mps = exact_diag.full_to_mps(statevector_npc) + # convert ffsim statevector to TeNPy MPS + statevector = np.multiply(swap_factors_tenpy, statevector[basis_ordering_tenpy]) + statevector_npc = npc.Array.from_ndarray(statevector, [leg_charge]) + mps = exact_diag.full_to_mps(statevector_npc) return mps