From 9f8fcc68c50a906a0e232670178b83a3e56273c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20Fern=C3=A1ndez=20Gra=C3=B1a?= <51716758+inafergra@users.noreply.github.com> Date: Mon, 14 Oct 2024 16:31:52 +0200 Subject: [PATCH] [Bug] Use torch.no_grad() in qng-spsa (#25) --- qadence_libs/qinfo_tools/qng.py | 46 ++++++++++++++++----------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/qadence_libs/qinfo_tools/qng.py b/qadence_libs/qinfo_tools/qng.py index 16baab2..f2aa5a7 100644 --- a/qadence_libs/qinfo_tools/qng.py +++ b/qadence_libs/qinfo_tools/qng.py @@ -215,29 +215,29 @@ def qng_spsa( See :class:`~qadence_libs.qinfo_tools.QuantumNaturalGradient` for details. """ + with torch.no_grad(): + # Get estimation of the QFI matrix + vparams_dict = dict(zip(vparams_keys, vparams_values)) + qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa( + circuit=circuit, + iteration=state["iter"], + vparams_dict=vparams_dict, + previous_qfi_estimator=state["qfi_estimator"], + epsilon=epsilon, + beta=beta, + ) - # Get estimation of the QFI matrix - vparams_dict = dict(zip(vparams_keys, vparams_values)) - qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa( - circuit=circuit, - iteration=state["iter"], - vparams_dict=vparams_dict, - previous_qfi_estimator=state["qfi_estimator"], - epsilon=epsilon, - beta=beta, - ) - - # Get transformed gradient vector solving the least squares problem - transf_grad = torch.linalg.lstsq( - 0.25 * qfi_mat_positive_sd, - grad_vec, - driver="gelsd", - ).solution + # Get transformed gradient vector solving the least squares problem + transf_grad = torch.linalg.lstsq( + 0.25 * qfi_mat_positive_sd, + grad_vec, + driver="gelsd", + ).solution - for i, p in enumerate(vparams_values): - if p.grad is None: - continue - p.data.add_(transf_grad[i], alpha=-lr) + for i, p in enumerate(vparams_values): + if p.grad is None: + continue + p.data.add_(transf_grad[i], alpha=-lr) - state["iter"] += 1 - state["qfi_estimator"] = qfi_estimator + state["iter"] += 1 + state["qfi_estimator"] = qfi_estimator