Simulate force profiles for a blue-detuned MOT of CaOH, after first benchmarking the simulation against published ⁸⁷Rb and F = 1 → F′ = 1 results

using Revise
using
    QuantumStates,
    OpticalBlochEquations,
    DifferentialEquations,
    UnitsToValue
┌ Info: Precompiling OpticalBlochEquations [691d0331-80d3-41b1-b293-7891a6f4a14f]
└ @ Base loading.jl:1662
using Distributions
# Uniform distribution on [0, 2π); the scans below draw random positions `p.r0` from it
uniform_dist = Uniform(0, 2π)
"""
    sample_direction(r=1.0)

Return a 3-tuple `(x, y, z)` of length `r` pointing in a uniformly random direction.

The azimuth is drawn uniformly from `[0, 2π)` and the polar cosine uniformly from
`[-1, 1)`, which yields an isotropic distribution on the sphere.
"""
function sample_direction(r=1.0)
    azimuth = 2π * rand()
    cos_polar = rand() * 2 - 1
    radial = r * sqrt(1 - cos_polar^2)   # length of the projection onto the xy-plane
    x = radial * cos(azimuth)
    y = radial * sin(azimuth)
    return (x, y, r * cos_polar)
end
;

Reproduce Fig. 1 from “Blue-Detuned Magneto-Optical Trap” (PRL 120, 083201 (2018))

  • Values are taken from Steck's notes on the ⁸⁷Rb D2 line
  • Energies are given in Hz (i.e. E/h); Γ is an angular frequency in s⁻¹
λ = 780.241209686e-9          # D2 wavelength (m)
Γ = 2π * 6.065e6              # natural linewidth as an angular frequency (rad/s)
m = @with_unit 87 "u"         # ⁸⁷Rb mass
k = 2π / λ                    # wavenumber (1/m)
# Bohr magneton in MHz/G (assuming `μ_B / h` is in SI units of Hz/T). Needed here because
# the Rb field conversions below use `_μB` before the CaOH section defines it.
_μB = (μ_B / h) * (1e-6 * 1e-4)
;

\(5^2S_{1/2}\) state

# F = 1 (taken as the energy zero)
QN_bounds = (E = 0.0, F = 1)
F1_lower = enumerate_states(AngularMomentumState, QN_bounds)

# F = 2
QN_bounds = (E = 6.83468261090429e9, F = 2) # 6.834 GHz above F = 1, expressed in Hz
F2_lower = enumerate_states(AngularMomentumState, QN_bounds)
;

\(5^2P_{3/2}\) state

E_upper = 384.2304844685e12 # 384.23 THz reference energy of 5²P₃/₂, expressed in Hz

# F = 1 (hyperfine offsets below `E_upper`, in Hz)
QN_bounds = (E = E_upper - 72.9113e6 - 156.947e6, F = 1)
F1_upper = enumerate_states(AngularMomentumState, QN_bounds)

# F = 2
QN_bounds = (E = E_upper - 72.9113e6, F = 2)
F2_upper = enumerate_states(AngularMomentumState, QN_bounds)
;
# State ordering used throughout: ground F=1, ground F=2, excited F′=1, excited F′=2
ground_states = [F1_lower; F2_lower]
excited_states = [F1_upper; F2_upper]
states = [ground_states; excited_states]
;
# Electric-dipole transition moments d[i, j, q+2] for spherical component q ∈ (-1, 0, +1).
# Only the ground (row) → excited (column) entries are filled here.
d = zeros(length(states), length(states), 3)
for (i, state) in enumerate(ground_states)
    for (j, state′) in enumerate(excited_states)
        j += length(ground_states)  # column index of this excited state within `states`
        for p in -1:1
            tdm = TDM(state, state′, p)
            d[i,j,p+2] = tdm
            # d[j,i,p+2] = conj(tdm)
        end
    end
end

# Magnetic moments in units of MHz/G
# For now, only assuming a field along the z axis
# Only states of equal energy (the same hyperfine manifold) are coupled; the i < j entries
# collect both orderings of the matrix element.
d_m = zeros(length(states), length(states), 3)
for (i, state) in enumerate(states)
    for (j, state′) in enumerate(states)
        for p ∈ -1:1
            if j > i
                d_m[i,j,p+2] += TDM_magnetic(state, state′, p) * (state.E == state′.E)
                d_m[i,j,p+2] += TDM_magnetic(state′, state, p) * (state.E == state′.E)
            elseif i == j
                d_m[i,j,p+2] += TDM_magnetic(state, state′, p) * (state.E == state′.E)
            end
        end
    end
end

# Scale each hyperfine manifold by its Landé g_F factor. The index ranges follow the ordering
# `states = [F1_lower; F2_lower; F1_upper; F2_upper]` and are derived from the manifold sizes
# rather than hard-coded.
manifold_sizes = length.([F1_lower, F2_lower, F1_upper, F2_upper])
manifold_g_F = [-1/2, 1/2, 2/3, 2/3]
for (last_idx, n_sub, g_F) in zip(cumsum(manifold_sizes), manifold_sizes, manifold_g_F)
    idxs = (last_idx - n_sub + 1):last_idx
    d_m[idxs, idxs, :] .*= g_F
end
# d_m ./= (Γ / (μB * 1e6))
;
# Laser parameters
δf11 = +11.5e6   # detuning of the F=1 → F′=1 beams (Hz, added to the transition frequency)
δf22 = +26.0e6   # detuning of the F=2 → F′=2 beams (Hz)
θ₁ = 0.0         # common phase of the ±x beams
θ₂ = 0.0         # common phase of the ±y beams

# NOTE(review): presumably a total intensity of 113 (mW/cm²?) shared equally by the 12 beams,
# expressed in units of the saturation intensity `Isat` — confirm
Isat = 2.503
s = (113 / Isat) / 12

# pol = σ⁻
pol = σ⁺

# Create MOT beams for F=1 -> F'=1 transition: counter-propagating pairs along x, y and z
ω_F1_to_F1 = 2π * (F1_upper[1].E - F1_lower[1].E + δf11)

ϵ = exp(im * θ₁) * rotate_pol(pol, +x̂); laser1 = Laser(+x̂, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₁) * rotate_pol(pol, -x̂); laser2 = Laser(-x̂, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₂) * rotate_pol(pol, +ŷ); laser3 = Laser(+ŷ, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₂) * rotate_pol(pol, -ŷ); laser4 = Laser(-ŷ, ϵ, ω_F1_to_F1, s)
ϵ = rotate_pol(pol, +ẑ); laser5 = Laser(+ẑ, ϵ, ω_F1_to_F1, s)
ϵ = rotate_pol(pol, -ẑ); laser6 = Laser(-ẑ, ϵ, ω_F1_to_F1, s)

# Create MOT beams for F=2 -> F'=2 transition (same geometry)
ω_F2_to_F2 = 2π * (F2_upper[1].E - F2_lower[1].E + δf22)

ϵ = exp(im * θ₁) * rotate_pol(pol, +x̂); laser7 = Laser(+x̂, ϵ, ω_F2_to_F2, s)
ϵ = exp(im * θ₁) * rotate_pol(pol, -x̂); laser8 = Laser(-x̂, ϵ, ω_F2_to_F2, s)
ϵ = exp(im * θ₂) * rotate_pol(pol, +ŷ); laser9 = Laser(+ŷ, ϵ, ω_F2_to_F2, s)
ϵ = exp(im * θ₂) * rotate_pol(pol, -ŷ); laser10 = Laser(-ŷ, ϵ, ω_F2_to_F2, s)
ϵ = rotate_pol(pol, +ẑ); laser11 = Laser(+ẑ, ϵ, ω_F2_to_F2, s)
ϵ = rotate_pol(pol, -ẑ); laser12 = Laser(-ẑ, ϵ, ω_F2_to_F2, s)

lasers = [laser1, laser2, laser3, laser4, laser5, laser6, laser7, laser8, laser9, laser10, laser11, laser12]
# lasers = [laser5, laser6, laser11, laser12] # only z beams
;
freq_res = 1e-1
ω_min = freq_res
period = 2π / ω_min

# Start with all population in the third state (within the F = 1 ground manifold)
ρ0 = zeros(ComplexF64, length(states), length(states))
ρ0[3,3] = 1.0

particle = Particle()
particle.r0 = [0.0, 0.0, 0.0]
particle.v = [0.0, 0.0, 0.1] #5 .* sample_direction() #[0.0, 0.0, 0.5] # velocity is in m/s (divided by Γ/k for the simulation)

p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res)
# 5 G along z, converted to the simulation's units of Γ / μ_B
p.B .= (0., 0., 5.) ./ (Γ / (_μB * 1e6))
;
LoadError: UndefVarError: _μB not defined
# using BenchmarkTools
# @btime ρ!(dρ, ρ0, p, 0.0)
t_end = 20p.period
tspan = (0., t_end)
step_size = 1e-3   # sampling interval as a fraction of one period
n_periods = 2
# Save only the final `n_periods` periods on a uniform grid; the force analysis below takes
# `1/step_size` consecutive samples as one period, which requires exactly this grid.
times = range(t_end - n_periods * p.period, t_end, step=p.period * step_size)

prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, reltol=1e-5, saveat=times) # callback=AutoAbstol(false, init_curmax=0.0)) # optional DiffEqCallbacks callback that adapts abstol during the solve
;
@time sol = DifferentialEquations.solve(prob, alg=DP5(), abstol=1e-5);
 21.067959 seconds (35.42 M allocations: 4.191 GiB, 2.08% gc time, 96.56% compilation time)
using Plots
plot_us = sol.u
plot_ts = sol.t

# Plot every population ρᵢᵢ versus time; ρᵢᵢ sits at index n_states*(i-1) + i of each
# saved state vector
n_states = size(p.ρ_soa, 1)
plot(size=(800, 400), ylim=(-0.1, 1.1), legend=nothing)
for i in 1:n_states
    state_idx = n_states*(i-1) + i
    plot!(plot_ts, [real(u[state_idx]) for u in plot_us])
end
plot!()
# Dashed red lines bracket the final period (shifted back by `offset`)
offset = 0
vline!([sol.t[end] - p.period - offset, sol.t[end] - offset], color="red", linestyle=:dash)
# vline!([280, 380], color="red", linestyle=:dash)

# Last three entries of the final state vector (presumably the accumulated force components)
sol.u[end][end-2:end]
3-element Vector{ComplexF64}:
   1.9700131428947971 + 0.0im
 -0.06391252360445562 + 0.0im
   0.5584231316344115 + 0.0im
using Statistics, LinearAlgebra
offset = 0
period_idx = find_idx_for_time(p.period, sol.t, true)
# Assumes `sol` was saved on the uniform `times` grid (step `p.period * step_size`), so the
# final `1/step_size` samples span the last period (shifted back by `offset`).
samples_per_period = round(Int, 1 / step_size)
force_idxs = (length(sol.t) - samples_per_period - offset):(length(sol.t) - offset) # (period_idx - offset):(length(times) - offset)

@time F = calculate_force_from_period(p, sol; force_idxs)
# Each saved state is a flat vector with ρᵢᵢ at index n_states*(i-1) + i (as in the
# population plot); `diag` cannot be applied to it, so read the populations directly.
n_states = size(p.ρ_soa, 1)
ρ_mean = mean(sol.u[force_idxs])
populations = [ρ_mean[n_states*(i-1) + i] for i in 1:n_states]
println("Excited population: ", real(sum(populations[(length(ground_states) + 1):end])))
println("Force: ", 1e3 .* F) #sign.(force) .* abs(force ⋅ p.particle.v) / norm(p.particle.v))
println("Acceleration (10^3 m/s^2): ", 1e-3 * QuantumStates.ħ * p.k * Γ * F[3] / m)
  0.247017 seconds (709.97 k allocations: 35.210 MiB, 7.05% gc time, 98.87% compilation time)
LoadError: ArgumentError: use diagm instead of diag to construct a diagonal matrix

Force versus velocity

function prob_func!(p, scan_values, i)
    # Update velocity: random direction at speed `scan_values.v[i]`, rounded via `round_vel`
    # so that the OBEs stay periodic
    p.v .= sample_direction(scan_values.v[i])
    p.v .= round_vel(p.v, p.freq_res)
    # Update position: uniform on [0, 2π) per axis. `uniform_dist` already spans [0, 2π),
    # so the draw must not be rescaled by another factor of 2π.
    p.r0 .= rand(uniform_dist, 3)
    return nothing
end
# Scan parameter recorded for trajectory `i`: its requested speed
param_func(p, scan_values, i) = getindex(scan_values.v, i)
function output_func(p, sol)
    # The last three entries of the final saved state are the force components
    f = real.(sol[end][end-2:end])
    speed = norm(p.v)
    # Project onto the velocity direction; at zero speed that direction is undefined, so
    # report zero instead of NaN
    iszero(speed) && return zero(eltype(f))
    return (f ⋅ p.v) / speed
end
;
prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, save_on=false)
p.freq_res = 1e-2
prob.p.B .= (0.0, 0.0, 0.0)
# Speeds in m/s, converted to simulation units of Γ/k
vs = repeat([[0,1,2,3,4]; collect(5:5:40)], 50) ./ (Γ / k)
scan_values = (v = vs,)
;
@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func; n_threads=20);
vs, averaged_forces = average_forces(scan_params, forces)
averaged_forces[1] = 0.0  # presumably the v = 0 point, where the projection onto v̂ is undefined
# Convert speeds back to m/s and forces to accelerations in 10³ m/s²
plot(vs .* (Γ / k), (1e-3 * ħ * k * Γ / m) .* averaged_forces / t_end, legend=nothing)

Force versus magnetic field

function prob_func!(p, scan_values, i)
    # Update velocity and position
    p.v .= sample_direction(scan_values.vs[i])
    p.r0 .= rand(uniform_dist, 3)   # `uniform_dist` already spans [0, 2π); no extra 2π factor
    # Round `v` to ensure that the OBEs are periodic
    p.v .= round_vel(p.v, p.freq_res)
    # Apply the scanned field along z
    p.B .= (0.0, 0.0, scan_values.Bz[i])
    return nothing
end
# Scan parameter recorded for each trajectory: the z component of the applied field
param_func(p, scan_values, i) = p.B[3]
# Output recorded for each trajectory: real part of the last of the three trailing entries
# (the z force component) of the final saved state
function output_func(p, sol)
    final_state = sol[end]
    fx, fy, fz = real.(final_state[end-2:end])
    return fz #(f ⋅ p.v) / norm(p.v)
end
;
p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res)
prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, reltol=1e-5, save_on=false)

n_samples = 50

# Field values in G, converted to simulation units of Γ / μ_B
Bzs = (0:5:100.0) ./ (Γ / (_μB * 1e6))
vs  = 0.1 .* ones(length(Bzs))

scan_values = (Bz = repeat(Bzs, n_samples), vs = repeat(vs, n_samples))
@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func);
Bzs, averaged_forces = average_forces(scan_params, forces)
# Plot against the field in G (undoing the conversion above)
plot(Bzs .* (Γ / (_μB * 1e6)), 1e3 * averaged_forces, legend=nothing)

Reproduce Fig. 6 from New J. Phys. 18, 123017 (2016)

# Dimensionless units for this comparison: λ = 1, Γ = 2π, m = 1
λ = 1
Γ = 2π
m = 1
k_norm = 2π / λ
;
# F = 1 ground level at E = 0 and F′ = 1 excited level at E = 1 (dimensionless units)
# F = 1
QN_bounds = (E = 0.0, F = 1)
F1_lower = enumerate_states(AngularMomentumState, QN_bounds)
QN_bounds = (E = 1.0, F = 1)
F1_upper = enumerate_states(AngularMomentumState, QN_bounds)
;
# F = 1 → F′ = 1 system: ground states first, then excited states
ground_states = F1_lower
excited_states = F1_upper
states = [ground_states; excited_states]
;
# Electric-dipole transition moments d[i, j, q+2] for q ∈ (-1, 0, +1); both the
# ground → excited entries and their complex conjugates (excited → ground) are filled
d = zeros(length(states), length(states), 3)
for (i, state) in enumerate(ground_states)
    for (j, state′) in enumerate(excited_states)
        j += length(ground_states)  # column index of this excited state within `states`
        for p in -1:1
            tdm = TDM(state, state′, p)
            d[i,j,p+2] = tdm
            d[j,i,p+2] = conj(tdm)
        end
    end
end

# Magnetic moments in units of MHz/G
# For now, only assuming a field along the z axis
d_m = zeros(length(states), length(states), 3)
for (i, state) in enumerate(states)
    for (j, state′) in enumerate(states)
        for p ∈ -1:1
            # Only states of equal energy (the same F manifold) are coupled
            if j > i
                d_m[i,j,p+2] += TDM_magnetic(state, state′, p) * (state.E == state′.E)
                d_m[i,j,p+2] += TDM_magnetic(state′, state, p) * (state.E == state′.E)
            elseif i == j
                d_m[i,j,p+2] += TDM_magnetic(state, state′, p) * (state.E == state′.E)
            end
        end
    end
end
# d_m .+= permutedims(d_m, (2,1,3))
# d_m ./= (Γ / μB) # There's a 2π factor here too but it's included in the simulation code
;
# Laser parameters
Δ = -2.5Γ      # detuning, as an angular frequency
s = 1.0        # saturation parameter per beam
θ₁ = 0.0       # common phase of the ±x beams
θ₂ = 0.0       # common phase of the ±y beams

# Create MOT beams for F=1 -> F'=1 transition: σ⁺ counter-propagating pairs along x, y and z.
# The beam directions are written inline so that the wavenumber `k` is not overwritten.
ω_F1_to_F1 = 2π * (F1_upper[1].E - F1_lower[1].E) + Δ

ϵ = exp(im * θ₁) * rotate_pol(σ⁺, +x̂); laser1 = Laser(+x̂, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₁) * rotate_pol(σ⁺, -x̂); laser2 = Laser(-x̂, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₂) * rotate_pol(σ⁺, +ŷ); laser3 = Laser(+ŷ, ϵ, ω_F1_to_F1, s)
ϵ = exp(im * θ₂) * rotate_pol(σ⁺, -ŷ); laser4 = Laser(-ŷ, ϵ, ω_F1_to_F1, s)
ϵ = rotate_pol(σ⁺, +ẑ); laser5 = Laser(+ẑ, ϵ, ω_F1_to_F1, s)
ϵ = rotate_pol(σ⁺, -ẑ); laser6 = Laser(-ẑ, ϵ, ω_F1_to_F1, s)

lasers = [laser1, laser2, laser3, laser4, laser5, laser6]
# lasers = [laser1, laser2]
# lasers = [laser3, laser4]
# lasers = [laser5, laser6]
;
# Inspect the dipole tensor: d[:, :, q+2] holds the q = -1, 0, +1 spherical components
d
6×6×3 Array{Float64, 3}:
[:, :, 1] =
  0.0   0.0        0.0       -0.0       -0.0       -0.0
  0.0   0.0        0.0       -0.707107   0.0        0.0
  0.0   0.0        0.0       -0.0       -0.707107  -0.0
 -0.0  -0.707107  -0.0        0.0        0.0        0.0
 -0.0   0.0       -0.707107   0.0        0.0        0.0
 -0.0   0.0       -0.0        0.0        0.0        0.0

[:, :, 2] =
 0.0        0.0   0.0        0.707107   0.0   0.0
 0.0        0.0   0.0       -0.0       -0.0  -0.0
 0.0        0.0   0.0        0.0        0.0  -0.707107
 0.707107  -0.0   0.0        0.0        0.0   0.0
 0.0       -0.0   0.0        0.0        0.0   0.0
 0.0       -0.0  -0.707107   0.0        0.0   0.0

[:, :, 3] =
  0.0       0.0        0.0  -0.0   0.707107  -0.0
  0.0       0.0        0.0   0.0   0.0        0.707107
  0.0       0.0        0.0  -0.0  -0.0       -0.0
 -0.0       0.0       -0.0   0.0   0.0        0.0
  0.707107  0.0       -0.0   0.0   0.0        0.0
 -0.0       0.707107  -0.0   0.0   0.0        0.0
freq_res = 1e-1
ω_min = freq_res
period = 2π / ω_min

# Start with all population in the second ground state
ρ0 = zeros(ComplexF64, length(states), length(states))
ρ0[2,2] = 1.0

particle = Particle()
particle.r0 = [0.0, 0.0, 0.0] # ./ k_norm
particle.v = [0.0, 0.0, 0.1] #.* (Γ / k_norm)

p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res)
p.B .= (0.0, 0.0, 0.0)
;
# using BenchmarkTools
# dρ = deepcopy(ρ0)
# @btime ρ!(dρ, ρ0, p, 0.0)
# Integrate for four periods (`period = 2π / freq_res`), saving 1000 samples per period
t_end = 4period
tspan = (0., t_end)
times = range(0, t_end, step=period/1000)
# times = range(t_end - 3p.period, t_end, step=period/1000)

prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, saveat=times)#, callback=AutoAbstol(false, init_curmax=0.0)) # optional DiffEqCallbacks callback that adapts abstol during the solve
;
@time sol = DifferentialEquations.solve(prob, alg=DP5(), abstol=1e-4);
  0.004079 seconds (4.05 k allocations: 2.967 MiB)
using Plots
plot_us = sol.u
plot_ts = sol.t

# Plot every population ρᵢᵢ versus time; ρᵢᵢ sits at index n_states*(i-1) + i of each
# saved state vector
n_states = size(p.ρ_soa, 1)
plot(size=(800, 400), ylim=(-0.1, 1.1), legend=nothing)
for i in 1:n_states
    state_idx = n_states*(i-1) + i
    plot!(plot_ts, [real(u[state_idx]) for u in plot_us])
end
plot!()
# Dashed red lines bracket the final period (shifted back by `offset`)
offset = 0
vline!([sol.t[end] - p.period - offset, sol.t[end] - offset], color="red", linestyle=:dash)
# vline!([280, 380], color="red", linestyle=:dash)

using Statistics, LinearAlgebra
offset = 0
period_idx = find_idx_for_time(p.period, sol.t, true)
force_idxs = (period_idx - offset + 1):(length(times) - offset)

@time F = calculate_force_from_period(p, sol; force_idxs)
# Each saved state is a flat vector with ρᵢᵢ at index n_states*(i-1) + i (as in the
# population plot); `diag` cannot be applied to it, so read the populations directly.
n_states = size(p.ρ_soa, 1)
ρ_mean = mean(sol.u[period_idx:end])
populations = [ρ_mean[n_states*(i-1) + i] for i in 1:n_states]
println("Excited population: ", real(sum(populations[(length(ground_states) + 1):end])))
println("Force: ", 1e3 .* F) #sign.(force) .* abs(force ⋅ p.particle.v) / norm(p.particle.v))
println("Acceleration (10³ m/s^2): ", 1e-3 * QuantumStates.ħ * k_norm * Γ * F / m)

Force versus velocity

function prob_func!(p, scan_values, i)
    # Update velocity (scanned along x, with a fixed 0.1 along z) and position
    p.v .= (scan_values.v[i], 0.0, 0.1) #sample_direction(inner_config.v[i])
    p.r0 .= rand(uniform_dist, 3)   # `uniform_dist` already spans [0, 2π); no extra 2π factor
    return nothing
end
# Scan parameter: trajectory speed rounded to two decimals (presumably so that repeated
# samples of the same speed are grouped by `average_forces`)
param_func(p, scan_values, i) = round(norm(p.v); digits=2)
function output_func(p, sol)
    # Last entry of the final saved state vector, i.e. the z force component. `real` keeps
    # the result real, like the other output functions (the state vector is complex).
    # NOTE(review): `prob_func!` scans the velocity along x — confirm z is the intended axis.
    return real(sol[end][end])
end
;
# Speeds 0 to 7 in steps of 0.2 (simulation units), 100 random positions each
vs = repeat(0:0.2:7, 100)
scan_values = (v = vs,)
@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func);
vs, averaged_forces = average_forces(scan_params, forces)
# NOTE(review): unlike the field scan below, this is not divided by `t_end` — confirm
plot(vs, 1e3 .* averaged_forces, legend=nothing)

Force versus magnetic field

function prob_func!(p, scan_values, i)
    # Update velocity and position
    p.v .= sample_direction(scan_values.vs[i])
    p.r0 .= rand(uniform_dist, 3)   # `uniform_dist` already spans [0, 2π); no extra 2π factor
    # Round `v` to ensure that the OBEs are periodic
    p.v .= round_vel(p.v, p.freq_res)
    # Apply the scanned field along z
    p.B .= (0.0, 0.0, scan_values.Bz[i])
    return nothing
end
# Scan parameter recorded for each trajectory: the z component of the applied field
param_func(p, scan_values, i) = p.B[3]
# Output for each trajectory: real part of the last of the three trailing force entries
# (the z component) of the final saved state
function output_func(p, sol)
    force_components = real.(sol[end][end-2:end])
    # return f ⋅ p.v / norm(p.v)
    return last(force_components)
end
;
# Positional λ, Γ, freq_res, matching every other `obe` call in this notebook
p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res)
# Assign the new problem so the scan below uses this `p` rather than the earlier problem
prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, save_on=false)

n_samples = 300

# Field values in G, converted to simulation units of Γ / μ_B
Bzs = (0:5:50) ./ (Γ / (_μB * 1e6))
vs  = 0.1 .* ones(length(Bzs))
scan_values = (Bz = repeat(Bzs, n_samples), vs = repeat(vs, n_samples))

@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func);
Bzs, averaged_forces = average_forces(scan_params, forces)
plot(Bzs, 1e3 * averaged_forces / t_end, legend=nothing)
Attempt dense (interpolated) output of the state populations
# Dense output, how does it compare?
@time sol = DifferentialEquations.solve(prob, alg=DP5(), abstol=1e-6, reltol=1e-7, dense=true);
n_states = size(p.ρ_soa, 1)
# Sample the interpolant over the final period of the solution. A fixed window such as
# t = 500–1000 can lie outside `tspan` (here t_end = 4period).
ts = range(sol.t[end] - p.period, sol.t[end], length=1000)
# The state is a flat vector; population ρᵢᵢ sits at index n_states*(i-1) + i
plot(size=(800, 400), ylim=(-0.1, 1.1), legend=nothing)
for i in 1:n_states
    plot!(ts, [real(sol(t)[n_states*(i-1) + i]) for t in ts])
end
plot!()
offset = 0
vline!([sol.t[end] - period - offset, sol.t[end] - offset], color="red", linestyle=:dash)
# Time-averaged populations over the sampled window
ρ_mean = sum(sol(t) for t ∈ ts) ./ length(ts)
[ρ_mean[n_states*(i-1) + i] for i in 1:n_states]

CaOH

λ = 626e-9                # wavelength of the X → A transition (m)
Γ = 2π * 6.4e6            # natural linewidth as an angular frequency (rad/s)
m = @with_unit 57 "u"     # CaOH mass
k = 2π / λ                # wavenumber (1/m)
;
Load \(\tilde{X}(000)\) and \(\tilde{A}(000)\) Hamiltonians
# Directory holding the precomputed Hamiltonian files (machine-specific; edit as needed)
hamiltonian_dir = "X://My Drive//github//QuantumStates//Hamiltonians//"
HX = load_from_file("CaOH_000_N0to3_Hamiltonian", hamiltonian_dir)
HA = load_from_file("CaOH_A000_J12to52_Hamiltonian", hamiltonian_dir)

# Restrict to X(000) N = 1 and A(000) Ω = 1/2, J = 1/2, then diagonalize each
HX_N1 = subspace(HX, (N=1,))
HA_J12 = subspace(HA, (Ω=1/2, J=1/2,))
evaluate!(HX_N1); QuantumStates.solve!(HX_N1)
evaluate!(HA_J12); QuantumStates.solve!(HA_J12)

# Add Zeeman term to the X state Hamiltonian
_μB = (μ_B / h) * (1e-6 * 1e-4)   # Bohr magneton in MHz/G (assuming μ_B/h is in Hz/T)
Zeeman_z(state, state′) = Zeeman(state, state′, 0)
HX_N1 = add_to_H(HX_N1, :B_z, gS * _μB * Zeeman_z)
HX_N1.parameters.B_z = 0.0

# Convert A states from Hund's case (a) to case (b)
HX_0110 = load_from_file("CaOH_BendingMode_Hamiltonian", hamiltonian_dir)
states_A_J12_caseB = convert_basis(HA_J12.states, HX_0110.basis)
basis_idxs, reduced_A_J12_caseB_basis = states_to_basis(states_A_J12_caseB)
full_basis = [HX_N1.basis; reduced_A_J12_caseB_basis]

# Restrict each converted A state to the reduced basis
for i ∈ eachindex(states_A_J12_caseB)
    states_A_J12_caseB[i].coeffs = states_A_J12_caseB[i].coeffs[basis_idxs]
    states_A_J12_caseB[i].basis = reduced_A_J12_caseB_basis
end

# Combine X and A states and scale their energies by 1e6 (presumably MHz → Hz)
states = [HX_N1.states; states_A_J12_caseB]
for state ∈ states
    state.E *= 1e6
end
;
# Transition dipole moments (X ground → A excited) and magnetic dipole moments within X.
# Array sizes follow the numbers of X(N=1) and A(J=1/2) states loaded above.
n_X = length(HX_N1.states)
n_A = length(states_A_J12_caseB)
n_total = n_X + n_A

d = zeros(ComplexF64, n_total, n_total, 3)
d_ge = zeros(ComplexF64, n_X, n_A, 3)
basis_tdms = get_tdms_two_bases(HX_N1.basis, reduced_A_J12_caseB_basis, TDM)
tdms_between_states!(d_ge, basis_tdms, HX_N1.states, states_A_J12_caseB)
d[1:n_X, (n_X + 1):n_total, :] .= d_ge

d_m = zeros(ComplexF64, n_total, n_total, 3)
d_m_gg = zeros(ComplexF64, n_X, n_X, 3)
basis_tdms_m = get_basis_tdms(HX_N1.basis, TDM_magnetic)
tdms_between_states!(d_m_gg, basis_tdms_m, HX_N1.states, HX_N1.states)
d_m[1:n_X, 1:n_X, :] .= d_m_gg
;
# Laser parameters
# Reference energies: states[1] and states[5] are used as the J = 1/2 and J = 3/2 levels of
# X(N=1) (presumably the lowest state of each — confirm the ordering), and states[13] is the
# first A(J=1/2) state (the A states follow the 12 X states in `states`)
J12_energy = energy(states[1])
J32_energy = energy(states[5])
A_energy = energy(states[13])

δJ12 = +13.0e6   # detuning of the J = 1/2 beams (added to the transition frequency)
δJ32 = +13.0e6   # detuning of the J = 3/2 beams
θ₁ = 0.0         # common phase of the ±x beams
θ₂ = 0.0         # common phase of the ±y beams
s_J12 = 5.0      # saturation parameter per J = 1/2 beam
s_J32 = 5.0      # saturation parameter per J = 3/2 beam
pol = σ⁻

ω_J12 = 2π * (A_energy - J12_energy + δJ12)
ω_J32 = 2π * (A_energy - J32_energy + δJ32)

# Six counter-propagating beams (±x, ±y, ±z) addressing J = 1/2
ϵ = exp(im * θ₁) * rotate_pol(pol, +x̂); laser1 = Laser(+x̂, ϵ, ω_J12, s_J12)
ϵ = exp(im * θ₁) * rotate_pol(pol, -x̂); laser2 = Laser(-x̂, ϵ, ω_J12, s_J12)
ϵ = exp(im * θ₂) * rotate_pol(pol, +ŷ); laser3 = Laser(+ŷ, ϵ, ω_J12, s_J12)
ϵ = exp(im * θ₂) * rotate_pol(pol, -ŷ); laser4 = Laser(-ŷ, ϵ, ω_J12, s_J12)
ϵ = rotate_pol(pol, +ẑ); laser5 = Laser(+ẑ, ϵ, ω_J12, s_J12)
ϵ = rotate_pol(pol, -ẑ); laser6 = Laser(-ẑ, ϵ, ω_J12, s_J12)

# The same geometry addressing J = 3/2
ϵ = exp(im * θ₁) * rotate_pol(pol, +x̂); laser7 = Laser(+x̂, ϵ, ω_J32, s_J32)
ϵ = exp(im * θ₁) * rotate_pol(pol, -x̂); laser8 = Laser(-x̂, ϵ, ω_J32, s_J32)
ϵ = exp(im * θ₂) * rotate_pol(pol, +ŷ); laser9 = Laser(+ŷ, ϵ, ω_J32, s_J32)
ϵ = exp(im * θ₂) * rotate_pol(pol, -ŷ); laser10 = Laser(-ŷ, ϵ, ω_J32, s_J32)
ϵ = rotate_pol(pol, +ẑ); laser11 = Laser(+ẑ, ϵ, ω_J32, s_J32)
ϵ = rotate_pol(pol, -ẑ); laser12 = Laser(-ẑ, ϵ, ω_J32, s_J32)

lasers = [laser1, laser2, laser3, laser4, laser5, laser6, laser7, laser8, laser9, laser10, laser11, laser12]
;
# Set initial conditions
particle = Particle()
particle.r0 = [0.0, 0.0, 0.0]
particle.v = [0.0, 0.0, 0.0]

ρ0 = zeros(ComplexF64, length(states), length(states)) # TODO: consider a static array here
ρ0[1,1] = 1.0

freq_res = 1e-1
# The extra keyword data is kept on `p` so that the field scan can re-solve the X Hamiltonian
p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res; basis_tdms, basis_tdms_m, HX_N1, d_ge, d_m_gg)

# Field in G converted to simulation units of Γ / μ_B (zero here)
p.B .= (0.0, 0.0, 0.0) ./ (Γ / (_μB * 1e6))
p.v .= (0.0, 0.0, 4.0)

# Run for up to 200 periods; the +1 presumably lets the last periodic callback fire before
# the end of `tspan` — confirm
t_end = 200p.period+1
tspan = (0., t_end)
# prob = ODEProblem(ρ!, ρ0, tspan, p, callback=force_cb, save_on=false) #, callback=AutoAbstol(false, init_curmax=0.0)) # optional DiffEqCallbacks callback that adapts abstol

# ρ0_tmp = deepcopy(ρ0)
# ρ0_tmp_all = [ρ0_tmp for _ in 1:length(times)]
# @time sol = DifferentialEquations.solve(prob, DP5(), ρ0_tmp_all, abstol=1e-5);
;
# Implement a periodic callback to reset the force each period
"""
    reset_force!(integrator; force_tol=1e-6)

Periodic-callback action, run once every `integrator.p.period`.

The last three entries of `integrator.u` accumulate the force over the current period, so
dividing by the period gives the period-averaged force, which is stored in
`integrator.p.force_last_period`. If every component changed by less than `force_tol` since
the previous period, the integration is terminated. Otherwise the accumulator is zeroed for
the next period.
"""
function reset_force!(integrator; force_tol=1e-6)
    force_current_period = integrator.u[end-2:end] / integrator.p.period
    force_diff = abs.(force_current_period - integrator.p.force_last_period)
    integrator.p.force_last_period = force_current_period
    if all(<(force_tol), force_diff)
        terminate!(integrator)
    else
        integrator.u[end-2:end] .= 0.0
    end
    return nothing
end
cb = PeriodicCallback(reset_force!, p.period)
;
using StaticArrays
# using BenchmarkTools
# dρ = deepcopy(ρ0)
# @time ρ!(dρ, ρ0, p, 0.0)
# @btime ρ!($dρ, $ρ0, $p, 0.0)
# Single trajectory; `cb` resets the force accumulator every period and stops the solve once
# the period-averaged force has converged
prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p) 
@time sol = DifferentialEquations.solve(prob, DP5(), callback=cb, abstol=1e-5)
;
# Compare the requested time span with where the callback actually stopped the solve
tspan
sol.t[end]
using Plots
plot_us = sol.u
plot_ts = sol.t

# Plot every population ρᵢᵢ versus time; ρᵢᵢ sits at index n_states*(i-1) + i of each
# saved state vector
n_states = size(p.ρ_soa, 1)
plot(size=(800, 400), ylim=(-0.1, 1.1), legend=nothing)
for i in 1:n_states
    state_idx = n_states*(i-1) + i
    plot!(plot_ts, [real(u[state_idx]) for u in plot_us])
end
plot!()
# Dashed red lines bracket the final period (shifted back by `offset`)
offset = 0
vline!([sol.t[end] - p.period - offset, sol.t[end] - offset], color="red", linestyle=:dash)
# vline!([280, 380], color="red", linestyle=:dash)
using Statistics, LinearAlgebra
offset = 0
period_idx = find_idx_for_time(p.period, sol.t, true)
# Index into this section's solution; `times` belongs to an earlier section's grid
force_idxs = (period_idx - offset):(length(sol.t) - offset)
@time force_value = calculate_force_from_period(p, sol)

# @time force = calculate_force_from_period(p, sol)
# println("Excited population: ", real(sum(diag(mean(sol.u[force_idxs]))[9:end])))
println("Force: ", force_value)
println("Acceleration (10^3 m/s^2): ", 1e-3 * ħ * k * Γ * force_value[3] / m)

Force versus velocity

# Why not use `EnsembleProblem`? This makes a new copy for every single problem, which we won't need
# Instead, we make `nthreads` problems
# prob = ODEProblem(ρ!, ρ0, tspan, p)
# ensemble_prob = EnsembleProblem(prob, prob_func=prob_func_inner)
# @time sim = solve(ensemble_prob, DP5(), EnsembleThreads(), trajectories=100, batch_size=10, abstol=1e-5, save_everystep=false)
# ;
function prob_func!(p, scan_values, i)
    # Update velocity: random direction at speed `scan_values.v[i]`, rounded via `round_vel`
    # so that the OBEs stay periodic
    p.v .= sample_direction(scan_values.v[i])
    p.v .= round_vel(p.v, p.freq_res)
    # Update position: uniform on [0, 2π) per axis (`uniform_dist` already spans [0, 2π))
    p.r0 .= rand(uniform_dist, 3)
    return nothing
end
# Scan parameter recorded for trajectory `i`: its requested speed
param_func(p, scan_values, i) = getindex(scan_values.v, i)
function output_func(p, sol)
    # Period-averaged force stored by the periodic callback, projected onto the velocity
    f = p.force_last_period
    speed = norm(p.v)
    # The direction is undefined at zero speed; report zero instead of NaN
    iszero(speed) && return zero(eltype(f))
    return (f ⋅ p.v) / speed
end
;
freq_res = 1e-1
p = obe(ρ0, particle, states, lasers, d, d_m, true, true, λ, Γ, freq_res; basis_tdms, basis_tdms_m, HX_N1, d_ge, d_m_gg)

tspan = (0, 100p.period+1)
# NOTE(review): `output_func` reads `p.force_last_period`, which only the periodic
# `reset_force!` callback updates; confirm `force_scan` attaches it, since this problem does not
prob = ODEProblem(ρ!, p.ρ0_vec, tspan, p, save_on=false)

prob.p.B .= (0.0, 0.0, 0.0) ./ (Γ / (_μB * 1e6))

n_samples = 100
# Speeds in m/s, converted to simulation units of Γ/k
vs = repeat(0:1:40, n_samples) ./ (Γ / k)
scan_values = (v = vs,)
;
@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func);
vs, averaged_forces = average_forces(scan_params, forces)
averaged_forces[1] = 0.0
# `force_last_period` is already averaged over one period, so it is not divided by `t_end`
plot(vs .* (Γ / k), (1e-3 * ħ * k * Γ / m) .* averaged_forces, legend=nothing)

Force versus magnetic field

function prob_func!(p, scan_values, i)
    # Update velocity and position
    p.v .= sample_direction(scan_values.v[i])
    p.r0 .= rand(uniform_dist, 3)   # `uniform_dist` already spans [0, 2π); no extra 2π factor
    # Round `v` to ensure that the OBEs are periodic
    p.v .= round_vel(p.v, p.freq_res)

    # Solve Hamiltonian for new `Bz` value (this is expensive, so only do it if the value has changed)
    Bz = scan_values.Bz[i]
    if p.HX_N1.parameters.B_z != Bz
        p.B .= (0.0, 0.0, Bz)
        # Work only on this parameter set's own data (`p.HX_N1`, `p.d_ge`, `p.d_m_gg`) so the
        # states and moments written into `p` come from the Hamiltonian evaluated at this `Bz`,
        # even when `p` is a per-thread copy that does not alias the globals
        H = p.HX_N1
        H.parameters.B_z = Bz
        evaluate!(H)
        QuantumStates.solve!(H)
        for state ∈ H.states
            state.E *= 1e6   # same energy scaling as the initial setup
        end
        p.states[1:12] .= H.states

        # Update TDMs
        tdms_between_states!(p.d_ge, p.basis_tdms, H.states, states_A_J12_caseB)
        p.d[1:12, 13:16, :] .= p.d_ge
        tdms_between_states!(p.d_m_gg, p.basis_tdms_m, H.states, H.states)
        p.d_m[1:12, 1:12, :] .= p.d_m_gg
    end

    return nothing
end
# Scan parameter recorded for trajectory `i`: its requested z field
param_func(p, scan_values, i) = getindex(scan_values.Bz, i)
# Output for each trajectory: z component of the force computed from the final period
output_func(p, sol) = calculate_force_from_period(p, sol)[3] #(f ⋅ p.v) / norm(p.v)
;
n_samples = 10

# Field values in G, converted to simulation units of Γ / μ_B
Bzs = (0:2:50) ./ (Γ / (_μB * 1e6))
# Speed of 1 m/s, converted to simulation units of Γ/k
vs  = (1.0 / (Γ / k)) .* ones(length(Bzs))

scan_values = (Bz = repeat(Bzs, n_samples), v = repeat(vs, n_samples))
@time scan_params, forces = force_scan(prob, scan_values, prob_func!, param_func, output_func);
Bzs, averaged_forces = average_forces(scan_params, forces)
# Plot acceleration (10³ m/s²) against the field in G (undoing the conversion above)
plot(Bzs .* (Γ / (_μB * 1e6)), (1e-3 * ħ * k * Γ / m) .* averaged_forces, legend=nothing)