import os
import sys
from performance_curve import PerformanceCurve
from performance_curve import CreateAVGPerformanceCurve
from scipy.optimize import minimize
import numpy as np
import pickle
from promod3 import loop
import time
from ost import seq, io


def F2(values, frag_length, base_dir):

  print "enter F2 ",values[0]

  start_time = time.time()

  # intended to optimize following parameter set:
  # - sequence_profile
  # - structure_profile
  # 
  # The sequence profile weight is considered to be constantly 1.
  # values[0] is therefore the weight for structure profile to be optimized.
  #
  # the frag_length parameter is sufficient to find all training set data
  # in the predefined directory structure


  sequence_profile_weight = -1.0
  structure_profile_weight = values[0]

  
  test_set_dir = os.path.join(base_dir, "test_sets", "fragments_" + str(frag_length)) 

  # get the training fragments
  fh = open(os.path.join(test_set_dir, "training_fragments.txt"))
  training_fragments = [item.strip() for item in fh.readlines()]
  fh.close()
 
  # load other stuff
  infile = open(os.path.join(test_set_dir, "random_baseline.dat"),'rb')
  baselines = pickle.load(infile)
  infile.close() 
  structure_db = loop.StructureDB.Load(os.path.join(base_dir, "structure_db_without_testset.dat"))
  profile_db = seq.ProfileDB.Load(os.path.join(test_set_dir, "profile_db.dat"))

  curves = list()
  
  for tf in training_fragments:

    data_path = os.path.join(test_set_dir, tf + ".txt")
    data = open(data_path, 'r').readlines()
    frag_seq = data[0].strip()
    profile = profile_db.GetProfile(tf)
    target_path = os.path.join(test_set_dir, tf + ".pdb")
    target_ent = io.LoadPDB(target_path)    
    target_bb_list = loop.BackboneList(frag_seq, target_ent.residues)

    fragger = loop.Fragger(frag_seq)

    fragger.AddSequenceProfileParameters(sequence_profile_weight, profile)
    fragger.AddStructureProfileParameters(structure_profile_weight, profile)
    fragger.Fill(structure_db, 0.0, 100)

    values = list()
    for i in range(len(fragger)):
      values.append(fragger[i].CARMSD(target_bb_list,True))
    pcurve = PerformanceCurve(values, 3.0, 1000)
    baseline = baselines[tf]
    pcurve.Subtract(baseline)
    curves.append(pcurve)


  avg_curve = CreateAVGPerformanceCurve(curves)
  return_value = avg_curve.AUC()

  # we're using a minimizer here...
  return_value = - return_value

  eval_time = time.time() - start_time 

  print "result: ", return_value, "eval time: ", eval_time

  sys.stdout.flush()

  return return_value



def F3(values, frag_length, base_dir):

  print "enter F3 ", values[0], values[1]

  start_time = time.time()

  # intended to optimize following parameter set:
  # - ss_agreement
  # - torsion_fancy
  # - seq_sim
  # 
  # 
  # 
  # The ss_agreement weight is considered to be constantly 1.
  # values[0] is the weight for torsion_fancy to be optimized
  # values[1] is the weight for sequence similarity to be optimized
  #
  # the frag_length parameter is sufficient to find all training set data
  # in the predefined directory structure


  ss_agreement_weight = 1.0
  torsion_fancy_weight = values[0]
  seq_sim_weight = values[1]

  test_set_dir = os.path.join(base_dir, "test_sets","fragments_" + str(frag_length)) 

  # get the training fragments
  fh = open(os.path.join(test_set_dir, "training_fragments.txt"))
  training_fragments = [item.strip() for item in fh.readlines()]
  fh.close()
 
  # load other stuff
  infile = open(os.path.join(test_set_dir, "random_baseline.dat"),'rb')
  baselines = pickle.load(infile)
  infile.close() 
  structure_db = loop.StructureDB.Load(os.path.join(base_dir, "structure_db_without_testset.dat"))
  torsion_sampler_coil = loop.LoadTorsionSamplerCoil()
  torsion_sampler_helix = loop.LoadTorsionSamplerHelical()
  torsion_sampler_extended = loop.LoadTorsionSamplerExtended()
  subst_matrix = seq.alg.BLOSUM62
  


  curves = list()
  
  for tf in training_fragments:

    data_path = os.path.join(test_set_dir, tf + ".txt")
    data = open(data_path, 'r').readlines()
    frag_seq = data[0].strip()
    frag_psipred_pred = [item for item in data[1].strip()]
    frag_psipred_cfi = [int(item) for item in data[2].strip()]
    frag_psipred = loop.PsipredPrediction(frag_psipred_pred, frag_psipred_cfi)
    frag_resname_before = data[4].strip()
    frag_resname_after = data[5].strip()
    target_path = os.path.join(test_set_dir, tf + ".pdb")
    target_ent = io.LoadPDB(target_path)    
    target_bb_list = loop.BackboneList(frag_seq, target_ent.residues)

    fancy_samplers = list()
    for i in range(frag_length):
      pred = frag_psipred_pred[i]
      cfi = frag_psipred_cfi[i]
      if pred == 'H' and cfi >= 6:
        fancy_samplers.append(torsion_sampler_helix)
      elif pred == 'E' and cfi >= 6:
        fancy_samplers.append(torsion_sampler_extended)
      else:
        fancy_samplers.append(torsion_sampler_coil)

    fragger = loop.Fragger(frag_seq)
    fragger.AddSSAgreeParameters(ss_agreement_weight, frag_psipred)
    fragger.AddSeqSimParameters(seq_sim_weight, subst_matrix)
    fragger.AddTorsionProbabilityParameters(torsion_fancy_weight,
                                            fancy_samplers,
                                            frag_resname_before,
                                            frag_resname_after)

    fragger.Fill(structure_db, 0.0, 100)
    values = list()
    for i in range(len(fragger)):
      values.append(fragger[i].CARMSD(target_bb_list,True))
    pcurve = PerformanceCurve(values, 3.0, 1000)
    baseline = baselines[tf]
    pcurve.Subtract(baseline)
    curves.append(pcurve)



  avg_curve = CreateAVGPerformanceCurve(curves)
  return_value = avg_curve.AUC()

  # we're using a minimizer here...
  return_value = -return_value

  eval_time = time.time() - start_time 

  print "result: ", return_value, "eval time: ", eval_time

  sys.stdout.flush()

  return return_value


def F4(values, frag_length, base_dir):

  print "enter F4 ", values[0], values[1], values[2]

  start_time = time.time()

  # intended to optimize following parameter set:
  # - ss_agreement
  # - torsion_fancy
  # - sequence_profile
  # - structure_profile
  # 
  # 
  # The ss_agreement weight is considered to be constantly 1.
  # values[0] is the weight for torsion_fancy to be optimized
  # values[1] is the weight for sequence profile to be optimized
  # values[2] is the weight for structure profile to be optimized
  #
  # the frag_length parameter is sufficient to find all training set data
  # in the predefined directory structure

  ss_agreement_weight = 1.0
  torsion_fancy_weight = values[0]
  seq_profile_weight = values[1]
  struct_profile_weight = values[2]

  test_set_dir = os.path.join(base_dir, "test_sets","fragments_" + str(frag_length)) 

  # get the training fragments
  fh = open(os.path.join(test_set_dir, "training_fragments.txt"))
  training_fragments = [item.strip() for item in fh.readlines()]
  fh.close()
 
  # load other stuff
  infile = open(os.path.join(test_set_dir, "random_baseline.dat"),'rb')
  baselines = pickle.load(infile)
  infile.close() 
  structure_db = loop.StructureDB.Load(os.path.join(base_dir,"structure_db_without_testset.dat"))
  profile_db = seq.ProfileDB.Load(os.path.join(test_set_dir, "profile_db.dat"))
  torsion_sampler_coil = loop.LoadTorsionSamplerCoil()
  torsion_sampler_helix = loop.LoadTorsionSamplerHelical()
  torsion_sampler_extended = loop.LoadTorsionSamplerExtended()
  subst_matrix = seq.alg.BLOSUM62
  


  curves = list()
  
  for tf in training_fragments:

    data_path = os.path.join(test_set_dir, tf + ".txt")
    data = open(data_path, 'r').readlines()
    frag_seq = data[0].strip()
    profile = profile_db.GetProfile(tf)
    frag_psipred_pred = [item for item in data[1].strip()]
    frag_psipred_cfi = [int(item) for item in data[2].strip()]
    frag_psipred = loop.PsipredPrediction(frag_psipred_pred, frag_psipred_cfi)
    frag_resname_before = data[4].strip()
    frag_resname_after = data[5].strip()
    target_path = os.path.join(test_set_dir, tf + ".pdb")
    target_ent = io.LoadPDB(target_path)    
    target_bb_list = loop.BackboneList(frag_seq, target_ent.residues)

    fancy_samplers = list()
    for i in range(frag_length):
      pred = frag_psipred_pred[i]
      cfi = frag_psipred_cfi[i]
      if pred == 'H' and cfi >= 6:
        fancy_samplers.append(torsion_sampler_helix)
      elif pred == 'E' and cfi >= 6:
        fancy_samplers.append(torsion_sampler_extended)
      else:
        fancy_samplers.append(torsion_sampler_coil)

    fragger = loop.Fragger(frag_seq)
    fragger.AddSSAgreeParameters(ss_agreement_weight, frag_psipred)
    fragger.AddTorsionProbabilityParameters(torsion_fancy_weight,
                                            fancy_samplers,
                                            frag_resname_before,
                                            frag_resname_after)
    fragger.AddSequenceProfileParameters(seq_profile_weight, profile)
    fragger.AddStructureProfileParameters(struct_profile_weight, profile)

    fragger.Fill(structure_db, 0.0, 100)
    values = list()
    for i in range(len(fragger)):
      values.append(fragger[i].CARMSD(target_bb_list,True))
    pcurve = PerformanceCurve(values, 3.0, 1000)
    baseline = baselines[tf]
    pcurve.Subtract(baseline)
    curves.append(pcurve)



  avg_curve = CreateAVGPerformanceCurve(curves)
  return_value = avg_curve.AUC()

  # we're using a minimizer here...
  return_value = -return_value

  eval_time = time.time() - start_time 

  print "result: ", return_value, "eval time: ", eval_time

  sys.stdout.flush()

  return return_value


if not len(sys.argv) == 4:
  print "usage: ost scriptname.py weight_group frag_length base_dir"

start_weights_F4 = dict()
start_weights_F4[5] = [21.606, -1.094, -1.741]
start_weights_F4[7] = [20.497, -1.175, -1.993]
start_weights_F4[9] = [15.037859922, -0.839021062216, -2.91114701232] 
start_weights_F4[11] = [6.43439853872, -0.811787003143, -2.75538072379] 
start_weights_F4[15] = [6.40941367767, -0.797013002482, -4.45581400632]

start_weights_F3 = dict()
start_weights_F3[5] = [10.43, 3.00]
start_weights_F3[7] = [10.43, 3.00]
start_weights_F3[9] = [10.43, 3.00]
start_weights_F3[11] = [10.43, 3.00]
start_weights_F3[15] = [10.43, 3.00]

start_weights_F2 = dict()
start_weights_F2[5] = [-1.0]
start_weights_F2[7] = [-1.0]
start_weights_F2[9] = [-1.0]
start_weights_F2[11] = [-1.0]
start_weights_F2[15] = [-1.0]


weight_group = int(sys.argv[1])
frag_length = int(sys.argv[2])
base_dir = sys.argv[3]

if weight_group not in [2,3,4]:
  raise ValueError("weight group must be in [2,3,4]!")


if weight_group == 4:

  print "optimize weight group 4 for frag length", frag_length
  if frag_length not in start_weights_F4:
    raise ValueError("No start weights for this frag length!")
  initial_val = np.ndarray(3)
  initial_val[0] = start_weights_F4[frag_length][0]
  initial_val[1] = start_weights_F4[frag_length][1]
  initial_val[2] = start_weights_F4[frag_length][2]
  additional_args = tuple([frag_length, base_dir])
  minimize(F4, initial_val, method = "powell", tol = 0.00001, args = additional_args)

if weight_group == 3:

  print "optimize weight group 3 for frag length", frag_length
  if frag_length not in start_weights_F3:
    raise ValueError("No start weights for this frag length!")
  initial_val = np.ndarray(2)
  initial_val[0] = start_weights_F3[frag_length][0]
  initial_val[1] = start_weights_F3[frag_length][1]
  additional_args = tuple([frag_length, base_dir])
  minimize(F3, initial_val, method = "powell", tol = 0.00001, args = additional_args)

if weight_group == 2:

  print "optimize weight group 2 for frag length", frag_length
  if frag_length not in start_weights_F2:
    raise ValueError("No start weights for this frag length!")
  initial_val = np.ndarray(1)
  initial_val[0] = start_weights_F2[frag_length][0]
  additional_args = tuple([frag_length, base_dir])
  minimize(F2, initial_val, method = "powell", tol = 0.00001, args = additional_args)



