'''
MELTS: Calculate Olivine Control Line
=====================================

Extrapolates a melt composition to a more primitive composition under the assumption that only olivine has crystallised from the melt. This calculation might be used for estimating primary melt compositions, or the effect of post-entrapment crystallisation on olivine-hosted melt inclusions.

The calculation uses the rhyoliteMELTS v1.2 model.

Open this code in an executable MyBinder instance (MyBinder links may be slow to load — please be patient!):

.. image:: https://mybinder.org/badge_logo.svg
  :target: https://mybinder.org/v2/gl/swmatthews-research%2FThermoEngineLite/main?urlpath=%2Fdoc%2Ftree%2F.%2Fdoc%2Fsource%2Fauto_examples%2F_4_unorganized%2Fplot_olivine_control_line.ipynb

'''

# %%
# Initialization
# --------------
#
# Import the necessary packages

from thermoengine import model, magmaforge, redox
from thermoengine.core import chem
import pandas as pd
import numpy as np
from scipy.optimize import root_scalar
import matplotlib.pyplot as plt

# %%
# Retrieve the MELTS v1.2 database and extract the olivine and liquid models.

db = model.Database("MELTS_v1_2")
# Liquid and olivine phase models. These module-level names are read directly
# by the helper functions defined below, so they must keep these names.
liq, olv = db.get_phase('Liq'), db.get_phase('Ol')

# %%
# Set up the Calculations
# -----------------------

# %%
# This section sets up the algorithm for incrementally finding the equilibrium olivine composition at the liquidus, adding a small amount of it into the melt, and repeating.
#
# You must run the cells below, but you do not need to change anything inside them to run a calculation.

def calculate_olivine_control_line(
        starting_comp : pd.Series,
        P_bar : float,
        target_Mgn : float,
        step_size : float = 0.005,
        max_iter : int = 1000,
        ) -> tuple[pd.DataFrame,
                   pd.DataFrame]:
    """
    Calculate an olivine control line, starting at the start composition and
    ending when the equilibrium olivine has the target Mgn. The Fe3/tFe ratio
    must be set in the starting_comp.

    At each step the liquidus temperature and equilibrium olivine composition
    are found, ``step_size`` moles of that olivine are added to the liquid,
    and the process repeats until the olivine Mg# reaches ``target_Mgn`` or
    ``max_iter`` steps have been taken.

    Inputs
    ------
    starting_comp : pd.Series
        The starting composition in weight percent (or grams) of oxides,
        indexed by oxide name. Oxides missing from the Series are taken as 0.
    P_bar : float
        The pressure in bars
    target_Mgn : float
        The olivine Mg# at which to terminate the calculation, given either
        as a fraction (0 to 1) or as a percentage (above 1, up to 100).
    step_size : float
        The moles of olivine to add to the liquid composition at each step
    max_iter : int
        The maximum number of olivine-addition steps. If the target Mg# is
        not reached within this many steps, a warning is printed and the
        partial control line is returned.

    Outputs
    -------
    pd.DataFrame
        The melt compositions (wt%, normalised to 100) at each step on the
        olivine control line, indexed by temperature (K).
    pd.DataFrame
        The olivine endmember compositions at each step on the olivine
        control line, with the olivine Mg# (as a fraction) in the final
        'Mg#' column, indexed by temperature (K).

    Raises
    ------
    ValueError
        If target_Mgn is negative or greater than 100.
    """

    if 'Fe2O3' not in starting_comp or starting_comp['Fe2O3'] == 0.0:
        print("Warning: Fe2O3 set to 0.0")
    if 'FeO' not in starting_comp or starting_comp['FeO'] == 0.0:
        print("Warning: FeO is set to 0.0")

    # Accept the Mg# either as a fraction (<= 1) or as a percentage
    # (> 1, up to and including 100). Raise rather than assert, so the check
    # is not stripped when Python runs with -O.
    if target_Mgn < 0 or target_Mgn > 100:
        raise ValueError("Bad target Mg number.")
    if target_Mgn > 1:
        target_Mgn = target_Mgn / 100

    # Starting liquid in wt% over the full oxide list (missing oxides -> 0),
    # normalised to 100 so that it is comparable with the later steps.
    start_comp_arr = np.array(
        [starting_comp[ox] if ox in starting_comp else 0.0
         for ox in chem.OXIDE_ORDER],
        dtype=float)
    start_comp_arr = start_comp_arr / np.sum(start_comp_arr) * 100

    liq_moles = magmaforge.System.convert_wtoxides_to_liquid_endmems(
        starting_comp, liq, oxides=starting_comp.index)

    # Matrix expressing olivine endmembers in terms of liquid endmembers.
    olv_omni_comps = _calc_omni_comps_for_phases(olv, liq)

    T, olv_mols = find_liquidus_olivine(P_bar, liq_moles, olv_omni_comps)
    # Mg# from olivine endmember indices 1 and 5. These are presumably
    # fayalite and forsterite in the MELTS v1.2 olivine model (TODO confirm
    # the endmember ordering).
    Mgn = olv_mols[5] / (olv_mols[1] + olv_mols[5])

    check_phase_sat(T, P_bar, starting_comp)

    # Row 0 holds the starting state and rows 1..max_iter hold each step,
    # so the arrays need max_iter + 1 rows.
    liq_comps = np.zeros([max_iter + 1, len(chem.OXIDE_ORDER)])
    olv_comps = np.zeros([max_iter + 1, olv.endmember_num + 1])
    liq_comps[0, :] = start_comp_arr
    olv_comps[0, :-1] = olv_mols
    olv_comps[0, -1] = Mgn
    T_seq = [T]

    iternum = 0
    while Mgn < target_Mgn and iternum < max_iter:
        # Add a small increment of the current equilibrium olivine to the
        # liquid, expressed as moles of liquid endmembers.
        liq_moles += step_size * olv_omni_comps.T.dot(olv_mols)

        # Warm-start the root search from the previous liquidus temperature.
        T, olv_mols = find_liquidus_olivine(P_bar, liq_moles, olv_omni_comps, T0=T)
        Mgn = olv_mols[5] / (olv_mols[1] + olv_mols[5])

        liq_mol_ox = liq.endmember_mol_oxide_comps.T.dot(liq_moles)
        liq_wtpt = chem.mol_to_wt_oxide(liq_mol_ox, chem.OXIDE_ORDER)
        liq_comps[iternum + 1, :] = liq_wtpt / np.sum(liq_wtpt) * 100

        olv_comps[iternum + 1, :-1] = olv_mols
        olv_comps[iternum + 1, -1] = Mgn
        T_seq.append(T)
        iternum += 1

    if Mgn < target_Mgn:
        print(f"Warning: target Mg# {target_Mgn:.3f} not reached after "
              f"{iternum} steps (final olivine Mg# {Mgn:.3f}).")

    liq_comps = pd.DataFrame(liq_comps[:iternum + 1, :],
                             columns=chem.OXIDE_ORDER,
                             index=T_seq)
    olv_comps = pd.DataFrame(olv_comps[:iternum + 1, :],
                             columns=list(olv.endmember_names) + ['Mg#'],
                             index=T_seq)

    return liq_comps, olv_comps


def find_liquidus_olivine(P, liq_moles, olv_omni_comps, T0=1500.0):
    """
    Find the olivine-saturation (liquidus) temperature of a liquid and the
    olivine composition in equilibrium with it.

    The temperature is found as the root of the olivine affinity returned by
    ``_affncomp_root``. The search starts from ``T0``.

    Parameters
    ----------
    P : float
        Pressure, forwarded to the liquid and olivine models.
    liq_moles : array-like
        Moles of each liquid endmember.
    olv_omni_comps : np.ndarray
        Matrix expressing olivine endmembers in terms of liquid endmembers.
    T0 : float
        Initial temperature guess for the root search.

    Returns
    -------
    T : float
        Temperature at which the olivine affinity is zero.
    comp : array-like
        Olivine composition returned by the olivine model at ``T``.

    Raises
    ------
    RuntimeError
        If the root search does not converge.
    """
    res = root_scalar(_affncomp_root, x0=T0, args=(P, liq_moles, olv_omni_comps))
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if not res.converged:
        raise RuntimeError(
            "Something has gone wrong when finding the temperature, check inputs.")
    T = res.root
    _, comp = get_olv_affncomp_from_liq(T, P, liq_moles, olv_omni_comps)
    return T, comp

    
def _affncomp_root(T, P, liq_moles, omni_comps):
    """Objective for the liquidus root search: only the olivine affinity."""
    affinity, _composition = get_olv_affncomp_from_liq(T, P, liq_moles, omni_comps)
    return affinity

def get_olv_affncomp_from_liq(T, P, liq_moles, omni_comps):
    """
    Return ``(affinity, composition)`` of olivine relative to a liquid.

    The liquid endmember chemical potentials at (T, P) are projected onto
    the olivine endmembers through ``omni_comps`` and then handed to the
    olivine model. Uses the module-level ``liq`` and ``olv`` phase objects.
    """
    olivine_potentials = omni_comps.dot(
        liq.chem_potential(T, P, mol=liq_moles))
    affinity, composition = olv.affinity_and_comp(T, P, olivine_potentials)
    return affinity, composition

def _calc_omni_comps_for_phases(phase, omni_phase):
    """
    Calculates the matrix for conversion of phase compositions into
    omni_component endmembers.
    """

    omni_element_comps = omni_phase.endmember_element_comps
    inv_omni_element_comps = np.linalg.pinv(omni_element_comps)

    TOL = 1e-12

    phs_element_comps = phase.endmember_element_comps
    phs_omni_comps = phs_element_comps.dot(inv_omni_element_comps)

    phs_omni_comps[np.abs(phs_omni_comps) <= TOL] = 0
        
    return phs_omni_comps

def check_phase_sat(T, P, oxide_wtpt):
    """
    Print a warning for every phase other than liquid and olivine that is
    present in a MELTS v1.2 equilibrium calculation of ``oxide_wtpt`` at
    (T, P).

    This flags starting compositions that are saturated in additional phases
    at their olivine-saturation temperature.

    Parameters
    ----------
    T : float
        Temperature, passed to magmaforge.System as ``T_K``.
    P : float
        Pressure, passed to magmaforge.System as ``P_bar``.
    oxide_wtpt : pd.Series
        Oxide composition passed to magmaforge.System.
    """
    # Named `system` rather than `sys` to avoid shadowing the stdlib module.
    system = magmaforge.System(oxide_wtpt, T_K=T, P_bar=P, database='MELTS_v1_2')
    # NOTE(review): relies on private magmaforge attributes; may break if the
    # library internals change.
    phs_mass = system._system_calculator._total_mass_of_every_phase

    for phs, mass in phs_mass.items():
        if phs not in ('Liquid', 'Olivine') and mass > 0.0:
            print(f"Warning: {phs} is saturated: {mass:.2f} wt% present at olivine-saturation temperature.")

# %%
# User Inputs
# -----------
#
# Provide the starting composition of the magma here. You must have already determined the amount of FeO and Fe2O3 in the magma.

# Oxide concentrations in wt%. They sum to ~101; the control-line calculation
# renormalises the liquid to 100.
oxide_comp = pd.Series(dict(
    SiO2=48.32,
    TiO2=0.93,
    Al2O3=15.99,
    Fe2O3=0.67,
    FeO=8.36,
    MnO=0.16,
    MgO=10.56,
    CaO=14.01,
    Na2O=1.72,
    K2O=0.05,
    P2O5=0.03,
    H2O=0.3,
    CO2=0.05,
))
# Composition of olivine-saturated glass from Kistufell, Iceland (Breddam, 2002)


# %%
# Provide the pressure of crystallisation and the target Mg# for the equilibrium olivine:

pressure_bar = 1000.0  # pressure of crystallisation, in bars
target_Mgn = 90.0  # olivine Mg#; values above 1 are read as a percentage

# %%
# Run the calculation:

# Pass the user inputs above (rather than literals) so that editing them
# actually changes the calculation.
liq_comp, olv_comp = calculate_olivine_control_line(
                                            oxide_comp,
                                            P_bar=pressure_bar,
                                            target_Mgn=target_Mgn
                                            )


# %%
# Results
# -------
#
# We can see the results as tables of the liquid and olivine compositions along the extrapolated liquid line of descent:

# Liquid compositions (wt%, normalised to 100); the index is the liquidus
# temperature in K at each step.
liq_comp

# %%

# Olivine endmember compositions, with the Mg# (as a fraction) in the last
# column; same temperature index as the liquid table.
olv_comp

# %%
# We can save the tables (which can be downloaded if you are running this in a binder):

# Relative paths: the files are written to the current working directory.
liq_comp.to_csv('liquid_compositions.csv')
olv_comp.to_csv('olivine_compositions.csv')

# %%
# But we can also plot the results:

fig, ax = plt.subplots(3, 1, figsize=(3.75, 7))

# One entry per panel: (y-values, y-axis label). Every panel shares the
# liquid MgO content on the x-axis; temperatures are converted from K to °C.
panels = [
    (liq_comp['FeO'], 'FeO (wt%)'),
    (olv_comp['Mg#'] * 100, 'Fo (mol%)'),
    (liq_comp.index - 273.15, 'Temperature (°C)'),
]

for axis, (y_values, y_label) in zip(ax, panels):
    axis.plot(liq_comp['MgO'], y_values, marker='s', markersize=3)
    axis.set_xlabel('MgO (wt%)')
    axis.set_ylabel(y_label)

fig.tight_layout()

plt.show()