###############################################################################
# TrajMap_webserver v0.0 | 2024.12.2
# Matej Kožić | mkozic@chem.pmf.hr 
#
# Requirements:
#   MDTraj, Pandas, Numpy, Matplotlib
#
###############################################################################
#-----------------------------------------------------------------------------#
# IMPORTS IMPORTS IMPORTS IMPORTS IMPORTS IMPORTS IMPORTS IMPORTS IMPORTS IMP #
#-----------------------------------------------------------------------------#
###############################################################################

import mdtraj as mdt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator)
import matplotlib.colors as mcolors
import matplotlib

###############################################################################
#-----------------------------------------------------------------------------#
# COLORMAPS COLORMAPS COLORMAPS COLORMAPS COLORMAPS COLORMAPS COLORMAPS COLOR #
#-----------------------------------------------------------------------------#
###############################################################################

### Capped colormaps ###
# Each map is the stock 256-entry colormap with its lowest 5 entries
# (indices 0-4) and highest 6 entries (indices 250-255) overwritten by a
# solid colour, so values at the extremes of the colour range stand out
# from the gradient.

### Viridis capped ###
viridis = matplotlib.colormaps["viridis"]
colors = viridis(np.linspace(0.0, 1.0, num=256))
colors[:5] = mcolors.to_rgba("blue")
colors[250:] = mcolors.to_rgba("red")
viridis_capped = mcolors.LinearSegmentedColormap.from_list(
    "viridis_capped", colors)

### Seismic capped ###
seismic = matplotlib.colormaps["seismic"]
colors = seismic(np.linspace(0.0, 1.0, num=256))
colors[:5] = mcolors.to_rgba("cyan")
colors[250:] = mcolors.to_rgba("fuchsia")
seismic_capped = mcolors.LinearSegmentedColormap.from_list(
    "seismic_capped", colors)
   
###############################################################################
#=============================================================================#
# FUNCTIONS FUNCTIONS FUNCTIONS FUNCTIONS FUNCTIONS FUNCTIONS FUNCTIONS FUNCT #
#=============================================================================#
###############################################################################

#-------------------------------#
#                               #
#-- P R E P R O C E S S I N G --#
#                               #
#-------------------------------#

#-----------------------------------------------------------------------------#

### Trajectories to PDB trajectory for PROTEIN ###

def traj2pdb_PROT(topology, trajectories, stride, savename):

    """ Loads the trajectory of the protein and converts it to a .pdb
    trajectory that is used for creating a shift graph.

    Only backbone atoms are kept, and every frame is superposed onto the
    first frame before saving.

    Parameters:
            topology - topology passed to mdtraj.load(top=...)
            trajectories - trajectory file(s) passed to mdtraj.load()
            stride - keep every stride-th frame, e.g. 1000 frames with
                     stride 10 give 100 frames
            savename - output path (str or path-like); ".pdb" is appended
                       unless the name already ends with ".pdb"
    """

    savename = str(savename)
    # Check the suffix, not a substring: "run.pdb_backbone" still needs it.
    if not savename.endswith(".pdb"):
        savename = savename + ".pdb"

    print("Loading the trajectory file...")
    load_traj = mdt.load(trajectories, top = topology, stride = stride)
    print(load_traj)
    print("Selecting the backbone...")
    select = load_traj.topology.select("backbone")
    load_traj = load_traj.atom_slice(select)
    print("Aligning...")
    print(load_traj)
    # Superpose every frame onto frame 0 of the sliced trajectory.
    load_traj = load_traj.superpose(load_traj, 0)
    print("Saving to ", savename)
    load_traj.save_pdb(savename)
    
#-----------------------------------------------------------------------------#

### Trajectories to PDB trajectory for DNA ###


def traj2pdb_DNA(topology, trajectories, stride, savename):

    """ Loads the trajectory of DNA and converts it to a .pdb trajectory
    that is used for creating a shift graph.

    All non-water atoms are kept, and every frame is superposed onto the
    first frame before saving.

    Parameters:
            topology - topology passed to mdtraj.load(top=...)
            trajectories - trajectory file(s) passed to mdtraj.load()
            stride - keep every stride-th frame, e.g. 1000 frames with
                     stride 10 give 100 frames
            savename - output path (str or path-like); ".pdb" is appended
                       unless the name already ends with ".pdb"
    """

    savename = str(savename)
    # Check the suffix, not a substring: "run.pdb_dna" still needs it.
    if not savename.endswith(".pdb"):
        savename = savename + ".pdb"

    print("Loading the trajectory file...")
    load_traj = mdt.load(trajectories, top = topology, stride = stride)
    print(load_traj)
    print("Selecting non-water atoms...")
    select = load_traj.topology.select("not water")
    load_traj = load_traj.atom_slice(select)
    print("Aligning...")
    print(load_traj)
    # Superpose every frame onto frame 0 of the sliced trajectory.
    load_traj = load_traj.superpose(load_traj, 0)
    print("Saving to ", savename)
    load_traj.save_pdb(savename)

#-----------------------------------------------------------------------------#

### PDB to CSV for PROTEIN ###

def pdb2csv_PROT(file, savename, residues):

    """ Takes centered aligned trajectories in .pdb format and creates a
    shift matrix that is saved as .csv.

    Every frame after the first is compared with frame 0. For each residue,
    the mass-weighted centre of its four backbone rows (read in the order
    N, CA, C, O) is computed in both frames, and their Euclidean distance is
    stored.

    Parameters:
            file - backbone-only multi-model .pdb (e.g. from traj2pdb_PROT);
                   each line is split on whitespace
            savename - path of the .csv to write
            residues - number of residues per frame; used to locate the
                       first atom row of frame 1
    Output CSV: one row per residue number 0..max residue and one column per
    compared frame, headed 0, 1, 2, ... Column t holds the shift of frame
    t+1 relative to frame 0. Cells with no data are 0.
    Raises:
            ValueError - if the file has no coordinate records or no shift
                         could be computed (e.g. only one frame)
    """

    # Atomic masses weighting the backbone centre of mass.
    N = 14
    CA = C = 12
    O = 16
    weights = np.array([N, CA, C, O], dtype=float)
    xyz = [6, 7, 8]  # x, y, z fields of a whitespace-split ATOM record

    print ("Importing" , file)
    # " A1" -> " A 1" presumably splits chain ID A from residue numbers
    # >= 1000, which fixed-column PDB output writes without a separator.
    with open(file) as handle:
        fields = [line.replace(" A1", " A 1").split() for line in handle]
    print("Finishing import...")

    data = pd.DataFrame(fields).fillna("0")
    if data.shape[1] < 9:
        raise ValueError("No ATOM coordinate records found in " + str(file))
    data = data.loc[~data[2].isin(["OXT", "OT2"])
                    & ~data[0].isin(["TER", "MODEL", "REMARK",
                                     "CONNECT", "CONECT"])]

    def backbone_centre(row):
        # Explicit positions make a read past the last row raise IndexError.
        block = data.iloc[[row, row + 1, row + 2, row + 3], xyz]
        return np.average(block.astype(float).to_numpy(), axis=0,
                          weights=weights)

    # Row layout assumed after filtering: one header row (presumably CRYST1
    # as written by MDTraj -- TODO confirm for files without a unit cell).
    # Frame 0 starts at row 1 with 4 rows per residue and is followed by its
    # END row, so frame 1 starts at row residues*4 + 2.
    shifts = []  # (residue number, time index, distance)
    limit = len(data)
    time = 0
    i = int(residues) * 4 + 2
    h = 1
    try:
        while i < limit:
            record = str(data.iloc[i, 0])
            if "ATOM" in record:
                residue = data.iloc[i, 5]
                print("Calculating: Row: ", i ,"; Timestep: " , time,
                      "; Residue: ", residue, "; File:", file)
                delta = backbone_centre(i) - backbone_centre(h)
                shifts.append((int(residue), time,
                               float(np.sqrt(np.sum(delta ** 2)))))
                i += 4
                h += 4
            elif "END" in record:
                print("---------Finished step ", time, "----------")
                time += 1
                i += 1
                h = 1
            else:
                # Unrecognised record: step over it instead of spinning
                # forever on the same row.
                i += 1
    except IndexError:
        # Best effort: a truncated last frame just ends the scan.
        pass

    if not shifts:
        raise ValueError("No backbone shifts could be computed from "
                         + str(file))

    resids = np.array([resid for resid, _, _ in shifts])
    times = np.array([t for _, t, _ in shifts])
    # Rows = time index, columns = residue number. For duplicate cells the
    # later entry wins.
    grid = np.full((times.max() + 1, resids.max() + 1), np.nan)
    grid[times, resids] = [dist for _, _, dist in shifts]

    matrix = pd.DataFrame(grid,
                          index = np.arange(0, times.max() + 1, 1),
                          columns = np.arange(0, resids.max() + 1, 1))
    matrix = matrix.transpose()
    matrix = matrix.fillna(0)
    matrix.to_csv(savename, index = False)

    print("Saved coordinates from pdb to: ",savename)

#-----------------------------------------------------------------------------#

### PDB to CSV for DNA ###

def pdb2csv_DNA(file, savename, residues, lag=20):

    """ Takes centered aligned trajectories in .pdb format and creates a
    shift matrix that is saved as .csv.

    Only DNA base heavy atoms are used. Records whose fields contain "'"
    (sugar atoms), "H" (hydrogens), "P" (phosphate atoms) or "NA" are
    dropped. For each residue, the mass-weighted centre of selected base
    atoms (DA: N6, N1; DT: O4, N3; DC: N4, N3, O2; DG: O6, N1, N2) is
    compared with the same centre in a reference frame, and their Euclidean
    distance is stored.

    Parameters:
            file - multi-model .pdb (e.g. from traj2pdb_DNA); each line is
                   split on whitespace
            savename - path of the .csv to write
            residues - not used; kept so the signature matches
                       pdb2csv_PROT (the frame size is taken from the first
                       END record instead)
            lag - positive int: frame g (g >= 1) is compared with frame
                  max(g - lag, 0). Default 20. None compares every frame
                  with frame 0. Assumes every frame has as many rows as
                  frame 0.
    Output CSV: one row per residue number 0..max residue and one column per
    compared frame, headed 0, 1, 2, ... Column t holds the shift of frame
    t+1. Cells with no data are 0. When residue numbers drop within a frame
    (e.g. a second strand restarting at 1), they are renumbered to continue
    after the previous residue, so both strands get their own rows.
    Raises:
            ValueError - if the file has no coordinate records, no END
                         record, or no shift could be computed
    """

    # Atomic masses weighting the base centre of mass.
    N = 14
    O = 16
    # (residue-name fragment, base heavy atoms per residue after filtering,
    #  {offset of a weighted atom within the residue: mass}), tried in order.
    bases = (
        ("DA", 10, {5: N, 6: N}),        # N6, N1
        ("DT", 9, {5: O, 6: N}),         # O4, N3
        ("DC", 8, {4: N, 5: N, 7: O}),   # N4, N3, O2
        ("DG", 11, {5: O, 6: N, 8: N}),  # O6, N1, N2
    )
    xyz = [6, 7, 8]  # x, y, z fields of a whitespace-split ATOM record

    print ("Importing" , file)
    # " A1" -> " A 1" presumably splits chain ID A from residue numbers
    # >= 1000, which fixed-column PDB output writes without a separator.
    with open(file) as handle:
        fields = [line.replace(" A1", " A 1").split() for line in handle]
    print("Finishing import...")

    data = pd.DataFrame(fields).fillna("0")
    if data.shape[1] < 9:
        raise ValueError("No ATOM coordinate records found in " + str(file))
    data = data.loc[~data[2].isin(["OXT", "OT2"])
                    & ~data[0].isin(["TER", "MODEL", "REMARK", "CONNECT",
                                     "CONECT", "CRYST1"])]

    def rows_containing(frame, token):
        # Boolean mask of rows where any field contains token as a substring.
        hit = np.zeros(len(frame), dtype=bool)
        for column in frame.columns:
            hit |= frame[column].astype(str).str.contains(
                token, regex=False).to_numpy(dtype=bool)
        return hit

    # Drop sugar ('), hydrogen (H), phosphate (P) and "NA" records so that
    # only base heavy atoms remain, in the order the offsets above expect.
    for token in ("'", "H", "P", "NA"):
        data = data.loc[~rows_containing(data, token)]

    is_end = data[0].astype(str).str.contains(
        "END", regex=False).to_numpy(dtype=bool)
    if not is_end.any():
        raise ValueError("No END record found in " + str(file))
    # Rows per frame, counting the frame's own END record.
    stepfinder = int(np.argmax(is_end)) + 1
    limit = len(data)

    # CHECKPOINT

    framenum = (limit - stepfinder) / (stepfinder )
    print ("|---------------------------------------|")
    print ("   CALCULATED NUMBER OF FRAMES: ", framenum)
    print ("|---------------------------------------|")

    def base_centre(start, size, picks):
        # Every atom row of the residue is read, so a truncated residue
        # raises IndexError and ends the scan.
        block = data.iloc[[start + k for k in range(size)], xyz]
        block = block.astype(float).to_numpy()
        offsets = list(picks)
        return np.average(block[offsets], axis=0,
                          weights=[picks[k] for k in offsets])

    shifts = []        # (renumbered residue, time index, distance)
    time = 0
    i = stepfinder     # first row of frame 1
    h = 0              # first row of the reference frame
    last_resid = None  # renumbered id of the previous residue in this frame
    try:
        while i < limit:
            record = str(data.iloc[i, 0])
            if "ATOM" in record:
                residue = data.iloc[i, 5]
                print("Calculating: Row: ", i ,"; Timestep: " , time,
                      "; Residue: ", residue, "; File:", file)

                restype = str(data.iloc[i, 3])
                spec = next(((size, picks) for key, size, picks in bases
                             if key in restype), None)
                if spec is None:
                    # Unsupported residue: skip this atom row in both the
                    # current and the reference frame to keep them aligned.
                    print("Skipping unsupported residue", restype,
                          "at row", i)
                    i += 1
                    h += 1
                    continue

                size, picks = spec
                delta = base_centre(i, size, picks) - base_centre(h, size, picks)
                value = float(np.sqrt(np.sum(delta ** 2)))
                i += size
                h += size

                # A residue number lower than the previous one (new strand
                # restarting its numbering) continues after the previous id.
                resid = int(residue)
                if last_resid is not None and resid < last_resid:
                    resid = last_resid + 1
                last_resid = resid
                shifts.append((resid, time, value))

            elif "END" in record:
                print("---------Finished step ", time, "----------")
                time += 1
                i += 1
                # i now starts frame g = time + 1; its reference is frame
                # max(g - lag, 0), or frame 0 when lag is None.
                h = 0 if lag is None else max(i - lag * stepfinder, 0)
                last_resid = None
            else:
                # Unrecognised record: step over it instead of spinning
                # forever on the same row.
                i += 1
    except IndexError:
        # Best effort: a truncated last frame just ends the scan.
        pass

    if not shifts:
        raise ValueError("No base shifts could be computed from " + str(file))

    resids = np.array([resid for resid, _, _ in shifts])
    times = np.array([t for _, t, _ in shifts])
    # Rows = time index, columns = residue number. For duplicate cells the
    # later entry wins.
    grid = np.full((times.max() + 1, resids.max() + 1), np.nan)
    grid[times, resids] = [dist for _, _, dist in shifts]

    matrix = pd.DataFrame(grid,
                          index = np.arange(0, times.max() + 1, 1),
                          columns = np.arange(0, resids.max() + 1, 1))
    matrix = matrix.transpose()
    matrix = matrix.fillna(0)
    matrix.to_csv(savename, index = False)

    print("Saved coordinates from pdb to: ",savename)

#-----------------------------------------------------------------------------#

### Loads the CSV matrix ###

def csv2matrix (file):

    """ Reads a trajectory matrix previously written to CSV and returns it
    as a pandas DataFrame, e.g.
        matrix = csv2matrix(csvfile)
    The first CSV line is used as the column header; no column is used as
    the row index."""

    return pd.read_csv(file, index_col=None)

#-----------------------------------------------------------------------------#

### Creates a difference matrix of two CSV matrices ###

def matrix_substract(matrix_A, matrix_B, save_name):

    """ Loads two CSV matrices, subtracts the second from the first
    element-wise (diff = matrix_A - matrix_B) and writes the result to
    save_name as CSV without a row index.
    Parameters:
        matrix_A - path of the CSV matrix to subtract from
        matrix_B - path of the CSV matrix that is subtracted
        save_name - path of the output CSV file
    """

    minuend, subtrahend = (csv2matrix(path) for path in (matrix_A, matrix_B))
    difference = minuend - subtrahend

    difference.to_csv(save_name, index = False)

    print("Saved the difference of ", matrix_A, " - ", matrix_B, " to ", save_name)
    
#-----------------------------------------------------------------------------#

### Matrix to shift graph data ###

def matrix2shift (matrix, res1, res2, time1, time2):

    """ Creates shift-graph data for the residues res1..res2 (both
    inclusive) over the frames time1..time2 (time2 exclusive), used as
    data = matrix2shift(...), after which the variable data is visualized.
    Parameters:
            matrix - Loaded from csv2matrix(); rows are residues and
                     columns are frames
            res1 - first residue (row index)
            res2 - last residue (row index, included)
            time1 - starting frame (column index, included)
            time2 - ending frame (column index, excluded); values past
                    the last frame are clipped to the available frames
    Returns:
            res1 == res2 : 1-D numpy array with that residue's shifts
                           over the time window
            otherwise    : single-column DataFrame (column 0) with the
                           per-frame average over the residue window
    Raises:
            ValueError if res1 > res2
    For visualizing single residue res1 should equal res2.
    """

    matrix_array = np.asarray(matrix)

    if res1 == res2 :
        # Single residue: its raw trace, restricted to the requested window.
        return matrix_array[res1, time1:time2]

    if res1 > res2 :
        raise ValueError("res1 (%s) must not be greater than res2 (%s)"
                         % (res1, res2))

    # res2 + 1 so that the documented "last residue" is part of the average.
    window = matrix_array[res1:res2 + 1, time1:time2]

    # Average over residues (axis 0) for every frame; the result length
    # follows the actual window, so time2 past the last frame cannot raise.
    return pd.DataFrame(data = window.mean(axis = 0))

#-----------------------------------------------------------------------------#

#-------------------------------#
#                               #
#-- V I S U A L I Z A T I O N --#
#                               #
#-------------------------------#

#-----------------------------------------------------------------------------#

### Visualizes CSV matrix as a heatmap - Trajectory Map ###

def matrix2map (matrix, savefig, title, 
                x_major, x_minor, y_major, y_minor, vmin, vmax, 
                residues, cmap, aspect):
    
    """ Visualizes a trajectory matrix as a heatmap -> Trajectory Map.

    The map is drawn on a fresh figure that is closed after saving, so
    repeated calls in one process (e.g. a long-running web server) do not
    draw on top of each other or accumulate extra colorbars.

    Parameters are as follows:
        matrix - matrix loaded with csv2matrix from a CSV file; rows are
            drawn along the y axis (residues), columns along the x axis
            (frames)
        savefig - file name of the resulting figure
        title - title that will appear above the trajectory map
        x_major - spacing of major tickmarks on x axis, e.g. every 20
        x_minor - spacing of minor tickmarks on x axis, e.g. every 5
        y_major - spacing of major tickmarks on y axis, e.g. every 20
        y_minor - spacing of minor tickmarks on y axis, e.g. every 5
        vmin - minimal value of the z axis (colorbar). 0 for a regular
            trajectory map, and -vmax for a DiffMap
        vmax - maximal value of the z axis (colorbar)
        residues - number of residues; major y ticks are placed from 0 up
            to (but not including) this value
        cmap - numerical code for the colormap:
            0 - "magma", linear cmap for regular Trajectory Maps
            1 - "seismic", divergent cmap for DiffMaps
            2 - capped viridis to detect clipping values in regular TMs
            3 - capped seismic to detect clipping values in DiffMaps
            4 - "Greys", monochrome black-and-white linear map
            5 - "turbo", outdated colormap for when you want to confuse people
            Any other value is passed to imshow unchanged (e.g. a
            matplotlib colormap name or Colormap object).
        aspect - aspect ratio of pixels, with numerical codes:
            0 - "auto", pixels are stretched so the map fills the axes
            1 - "equal", pixels are squares, so the map's aspect ratio is
                number of frames : number of residues
            Any other value is passed to imshow unchanged.
    """
    
    print("Making the TrajMap...")
    matrix = np.array(matrix)

    # Translate numeric codes; any other value is handed to imshow as-is.
    if cmap == 0 : cmap = "magma"
    elif cmap == 1 :  cmap = "seismic"
    elif cmap == 2 : cmap = viridis_capped
    elif cmap == 3 : cmap = seismic_capped
    elif cmap == 4 : cmap = "Greys"
    elif cmap == 5 : cmap = "turbo"
     
    if aspect == 0 : aspect = "auto"
    elif aspect == 1 : aspect = "equal"

    ylabel = "Residue"
    xlabel = "Frame"
    size = 800
    frames = len(matrix[0])

    # A dedicated figure per call: plt.subplot(111) on the current figure
    # would reuse axes (and leave colorbars) from earlier plots.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        image = ax.imshow(matrix, cmap = cmap, vmin = vmin, vmax = vmax,
                          aspect = aspect)
        ax.set_title(title)

        # Put the first matrix row at the bottom instead of the top.
        ax.invert_yaxis()
        ax.set_xticks(np.arange(0, frames + 1, step = x_major))
        ax.set_yticks(np.arange(0, residues, step = y_major))
        ax.tick_params(axis = "both", which = "major", labelsize = 7)
        ax.yaxis.set_minor_locator(MultipleLocator(y_minor))
        ax.xaxis.set_minor_locator(MultipleLocator(x_minor))
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        fig.colorbar(image, ax = ax, label = "Shift / Å",
                     fraction = 0.029, pad = 0.028)
        fig.savefig(savefig, dpi = size, bbox_inches='tight')
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)
    print("TrajMap saved to:", savefig)

#-----------------------------------------------------------------------------#

### Visualizes shift graph data as a graph - Shift Graph ###

def shift2graph (shift_data, savefig, title, label, 
                 y_min, y_max, y_major, y_minor, 
                 x_min, x_max, x_major, x_minor,
                 roll_avg, color):
    
    """ Creates a graph from shift data and saves it to a file.

    The graph is drawn on a fresh figure that is closed after saving, so
    curves and legends from earlier calls (or an earlier trajectory map)
    never leak into the saved image.

    Parameters are as follows:
        shift_data - Data loaded with matrix2shift()
        savefig - Name of the figure that will be saved
        title - Title that will appear above the graph
        label - Label of the data that will be on the graph
        y_min - Minimum tickmark value on the y axis (also the lower y limit)
        y_max - Upper y limit; ticks are placed below this value
        y_major - Step between tickmarks on the y axis 
        y_minor - Minor tick marks on the y axis
        x_min - Minimum tickmark value on the x axis (also the lower x limit)
        x_max - Maximum tickmark value on the x axis (also the upper x limit)
        x_major - Step between tickmarks on the x axis 
        x_minor - Minor tickmarks on the x axis
        roll_avg - Number of frames for the rolling average window
        color - Color of the curves, e.g. "black"
    """
    
    size = 800
    shift_data = pd.DataFrame(shift_data, dtype = "float64")
    rolling_mean = shift_data.rolling(roll_avg).mean()

    # A dedicated figure per call: plt.subplot(111) on the current figure
    # would reuse axes and keep previously plotted lines.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(shift_data, linewidth=0.3, color=color, label = label)
        ax.legend(loc='upper left',
                  fancybox=True, shadow=True, prop={'size': 7.5})
        # Rolling average drawn thicker, unlabeled, after the legend.
        ax.plot(rolling_mean, linewidth=1, color=color)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_yticks(np.arange(y_min, y_max, step=y_major))
        ax.set_xticks(np.arange(x_min, x_max + x_major, step=x_major))
        ax.set_title(title)
        ax.set_xlabel("Frame")
        ax.set_ylabel("Shift / Å")
        ax.yaxis.set_minor_locator(MultipleLocator(y_minor))
        ax.xaxis.set_minor_locator(MultipleLocator(x_minor))

        fig.savefig(savefig, dpi= size)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

###############################################################################
#\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\#
# END END END END END END END END END END END END END END END END END END END #
#/////////////////////////////////////////////////////////////////////////////#
# END #########################################################################

