"""
Exact Gaussian-neural scheme forces and energies of five different non-periodic
configurations and three different periodic configurations have been calculated
in Mathematica, and are given below.  This script checks the values calculated
by the code (using the tensorflow-based neural-network model) against them,
both with and without the Fortran modules.

"""

import sys
import numpy as np
from collections import OrderedDict
from ase import Atoms

from amp import Amp
from amp.descriptor.gaussian import Gaussian


def check_perform():
    """Decide whether the tensorflow-based tests should be run.

    The tests are only performed under python 2.x with tensorflow
    importable.

    Returns a ``(perform, reason)`` tuple.  ``perform`` is True when the
    test should run, in which case ``reason`` is the empty string;
    otherwise ``perform`` is False and ``reason`` says why the test is
    being skipped.
    """
    if sys.version_info[0] >= 3:
        verdict = (False, 'amp.model.tflow not supported in python3.')
    else:
        try:
            # Imported purely to find out whether it is available.
            import tensorflow  # noqa: F401
        except ImportError:
            verdict = (False, 'Tensorflow not installed.')
        else:
            verdict = (True, '')
    return verdict


def non_periodic_test():
    """Gaussian/tflowNeural energies and forces for non-periodic images.

    Five non-periodic images are evaluated with an Amp calculator made of a
    Gaussian descriptor and the tensorflow NeuralNetwork model, using fixed
    weights, scalings and fingerprint ranges.  The predicted energies and
    forces must agree with the tabulated reference values to within 1e-3.
    The comparison is repeated with the descriptor's ``fortran`` flag off
    and on, and with ``cores`` running from 1 to 5.

    If check_perform() reports that the test should not run, the reason is
    printed and the function returns without testing anything.
    """
    should_run, why_not = check_perform()
    if not should_run:
        print('Skipping this test because {}'.format(why_not))
        return

    # Deferred until after the skip check above.
    from amp.model.tflow import NeuralNetwork

    def isolated(symbols, positions):
        """Build a non-periodic image in a unit cubic cell."""
        return Atoms(symbols=symbols,
                     pbc=np.array([False, False, False], dtype=bool),
                     cell=np.array([[1., 0., 0.],
                                    [0., 1., 0.],
                                    [0., 0., 1.]]),
                     positions=np.array(positions))

    def g2(element, eta):
        """Build the parameter dictionary of a G2 symmetry function."""
        return {'type': 'G2', 'element': element, 'eta': eta}

    def g4(elements, eta, gamma, zeta):
        """Build the parameter dictionary of a G4 symmetry function."""
        return {'type': 'G4', 'elements': elements, 'eta': eta,
                'gamma': gamma, 'zeta': zeta}

    # The non-periodic images
    images = [
        isolated('PdOPd2', [[0., 0., 0.],
                            [0., 2., 0.],
                            [0., 0., 3.],
                            [1., 0., 0.]]),
        isolated('PdOPd2', [[0., 1., 0.],
                            [1., 2., 1.],
                            [-1., 1., 2.],
                            [1., 3., 2.]]),
        isolated('PdO', [[2., 1., -1.],
                         [1., 2., 1.]]),
        isolated('Pd2O', [[-2., -1., -1.],
                          [1., 2., 1.],
                          [3., 4., 4.]]),
        isolated('Cu', [[0., 0., 0.]]),
    ]

    # Reference energies, one per image
    correct_energies = [
        14.231186811226152,
        14.327219917287948,
        5.5742510565528285,
        9.41456771216968,
        -0.5019297954597407,
    ]

    # Reference forces, indexed as [image][atom][direction]
    correct_forces = [
        [[-0.05095024246182649, -0.10709193432146558,
          -0.09734321482638622],
         [-0.044550772904033635, 0.2469763195486647,
          -0.07617425912869778],
         [-0.02352490951707703, -0.050782839419131864,
          0.24409220250631508],
         [0.11902592488293715, -0.08910154580806727,
          -0.07057472855123109]],
        [[-0.024868720575099375, -0.07417891957113862,
          -0.12121240797223251],
         [0.060156158438252574, 0.017517013378773042,
          -0.020047135079325505],
         [-0.10901144291312388, -0.06671262448352767,
          0.06581556263014315],
         [0.07372400504997068, 0.12337453067589325,
          0.07544398042141486]],
        [[0.10151747265164626, -0.10151747265164626,
          -0.20303494530329252],
         [-0.10151747265164626, 0.10151747265164626,
          0.20303494530329252]],
        [[-0.00031177673224312745, -0.00031177673224312745,
          -0.0002078511548287517],
         [0.004823209772264884, 0.004823209772264884,
          0.006975000714861393],
         [-0.004511433040021756, -0.004511433040021756,
          -0.006767149560032641]],
        [[0.0, 0.0, 0.0]],
    ]

    # Descriptor parameters: symmetry functions per element
    Gs = {
        'O': [g2('Pd', 0.8),
              g4(['Pd', 'Pd'], 0.2, 0.3, 1),
              g4(['O', 'Pd'], 0.3, 0.6, 0.5)],
        'Pd': [g2('Pd', 0.2),
               g4(['Pd', 'Pd'], 0.9, 0.75, 1.5),
               g4(['O', 'Pd'], 0.4, 0.3, 4)],
        'Cu': [g2('Cu', 0.8),
               g4(['Cu', 'O'], 0.2, 0.3, 1),
               g4(['Cu', 'Cu'], 0.3, 0.6, 0.5)],
    }

    # Model parameters
    hiddenlayers = {'O': (2, 1), 'Pd': (2, 1), 'Cu': (2, 1)}

    weights = OrderedDict()
    weights['O'] = OrderedDict([
        (1, np.matrix([[-2.0, 6.0],
                       [3.0, -3.0],
                       [1.5, -0.9],
                       [-2.5, -1.5]])),
        (2, np.matrix([[5.5],
                       [3.6],
                       [1.4]])),
    ])
    weights['Pd'] = OrderedDict([
        (1, np.matrix([[-1.0, 3.0],
                       [2.0, 4.2],
                       [1.0, -0.7],
                       [-3.0, 2.0]])),
        (2, np.matrix([[4.0],
                       [0.5],
                       [3.0]])),
    ])
    weights['Cu'] = OrderedDict([
        (1, np.matrix([[0.0, 1.0],
                       [-1.0, -2.0],
                       [2.5, -1.9],
                       [-3.5, 0.5]])),
        (2, np.matrix([[0.5],
                       [1.6],
                       [-1.4]])),
    ])

    scalings = OrderedDict()
    scalings['O'] = OrderedDict([('intercept', -2.3), ('slope', 4.5)])
    scalings['Pd'] = OrderedDict([('intercept', 1.6), ('slope', 2.5)])
    scalings['Cu'] = OrderedDict([('intercept', -0.3), ('slope', -0.5)])

    fingerprints_range = {
        "Cu": np.array([[0.0, 0.0],
                        [0.0, 0.0],
                        [0.0, 0.0]]),
        "O": np.array([[0.2139617720858539, 2.258090276328769],
                       [0.0, 1.085656080548734],
                       [0.0, 0.0]]),
        "Pd": np.array([[0.0, 1.4751761770313006],
                        [0.0, 0.28464992134267897],
                        [0.0, 0.20167521020630502]]),
    }

    # Largest absolute deviation accepted for energies and forces alike.
    tolerance = 10.**(-3.)

    # Exercise the descriptor with the fortran flag off and on, for
    # several core counts.
    for fortran in (False, True):
        for cores in range(1, 6):
            label = 'call-nonperiodic/%s-%i' % (fortran, cores)
            descriptor = Gaussian(cutoff=6.5,
                                  Gs=Gs,
                                  fortran=fortran)
            model = NeuralNetwork(hiddenlayers=hiddenlayers,
                                  weights=weights,
                                  scalings=scalings,
                                  activation='sigmoid',
                                  fprange=fingerprints_range)
            calc = Amp(descriptor=descriptor,
                       model=model,
                       label=label,
                       dblabel=label,
                       cores=cores)

            # All energies are computed before any of them is checked.
            energies = [calc.get_potential_energy(atoms) for atoms in images]

            for image_no, energy in enumerate(energies):
                expected_energy = correct_energies[image_no]
                print(energy)
                print(expected_energy)
                deviation = abs(energy - expected_energy)
                assert (deviation < tolerance), \
                    'The predicted energy of image %i is wrong!' % (
                        image_no + 1)

            # Likewise, all forces are computed before any comparison.
            forces = [calc.get_forces(atoms) for atoms in images]

            for image_no, image_forces in enumerate(forces):
                expected_forces = correct_forces[image_no]
                print('predicted forces:')
                print(image_forces)
                print('correct forces:')
                print(np.array(expected_forces))
                for index, atom_force in enumerate(image_forces):
                    for direction, component in enumerate(atom_force):
                        deviation = abs(
                            component - expected_forces[index][direction])
                        assert (deviation < tolerance), \
                            'The predicted %i force of atom %i of image %i ' \
                            'is wrong!' % (direction, index, image_no + 1)


def periodic_test():
    """Gaussian/tflowNeural energies and forces for periodic images.

    Three images with periodic boundary conditions along one or two axes
    are evaluated with an Amp calculator made of a Gaussian descriptor and
    the tensorflow NeuralNetwork model (``unit_type="double"``), using fixed
    weights, scalings and fingerprint ranges.  Predicted energies must match
    the tabulated reference values to within 1e-14 and forces to within
    1e-11.  The comparison is repeated with the descriptor's ``fortran``
    flag off and on, and with ``cores`` running from 1 to 3.

    If check_perform() reports that the test should not run, the reason is
    printed and the function returns without testing anything.
    """
    should_run, why_not = check_perform()
    if not should_run:
        print('Skipping this test because {}'.format(why_not))
        return

    # Deferred until after the skip check above.
    from amp.model.tflow import NeuralNetwork

    def boxed(symbols, pbc, edge, positions):
        """Build an image in a cubic cell of side ``edge``."""
        return Atoms(symbols=symbols,
                     pbc=np.array(pbc, dtype=bool),
                     cell=np.array([[edge, 0., 0.],
                                    [0., edge, 0.],
                                    [0., 0., edge]]),
                     positions=np.array(positions))

    def g2(element, eta):
        """Build the parameter dictionary of a G2 symmetry function."""
        return {'type': 'G2', 'element': element, 'eta': eta}

    def g4(elements, eta, gamma, zeta):
        """Build the parameter dictionary of a G4 symmetry function."""
        return {'type': 'G4', 'elements': elements, 'eta': eta,
                'gamma': gamma, 'zeta': zeta}

    # The periodic images
    images = [
        boxed('PdOPd', [True, False, False], 2.,
              [[0.5, 1., 0.5],
               [1., 0.5, 1.],
               [1.5, 1.5, 1.5]]),
        boxed('PdO', [True, True, False], 2.,
              [[0.5, 1., 0.5],
               [1., 0.5, 1.]]),
        boxed('Cu', [True, True, False], 1.8,
              [[0., 0., 0.]]),
    ]

    # Reference energies, one per image
    correct_energies = [
        3.8560954326995978,
        1.6120748520627273,
        0.19433107801410093,
    ]

    # Reference forces, indexed as [image][atom][direction]
    correct_forces = [
        [[0.14747720528015523, -3.3010645563584973, 3.3008168318984463],
         [0.03333579762326405, 9.050780376599887, -0.42608278400777605],
         [-0.1808130029034193, -5.7497158202413905, -2.8747340478906698]],
        [[6.5035267996045045 * (10.**(-6.)),
          -6.503526799604495 * (10.**(-6.)),
          0.00010834689201069249],
         [-6.5035267996045045 * (10.**(-6.)),
          6.503526799604495 * (10.**(-6.)),
          -0.00010834689201069249]],
        [[0.0, 0.0, 0.0]],
    ]

    # Descriptor parameters: symmetry functions per element
    Gs = {
        'O': [g2('Pd', 0.8),
              g4(['O', 'Pd'], 0.3, 0.6, 0.5)],
        'Pd': [g2('Pd', 0.2),
               g4(['Pd', 'Pd'], 0.9, 0.75, 1.5)],
        'Cu': [g2('Cu', 0.8),
               g4(['Cu', 'Cu'], 0.3, 0.6, 0.5)],
    }

    # Model parameters
    hiddenlayers = {'O': (2, 1), 'Pd': (2, 1), 'Cu': (2, 1)}

    weights = OrderedDict()
    weights['O'] = OrderedDict([
        (1, np.matrix([[-2.0, 6.0],
                       [3.0, -3.0],
                       [1.5, -0.9]])),
        (2, np.matrix([[5.5],
                       [3.6],
                       [1.4]])),
    ])
    weights['Pd'] = OrderedDict([
        (1, np.matrix([[-1.0, 3.0],
                       [2.0, 4.2],
                       [1.0, -0.7]])),
        (2, np.matrix([[4.0],
                       [0.5],
                       [3.0]])),
    ])
    weights['Cu'] = OrderedDict([
        (1, np.matrix([[0.0, 1.0],
                       [-1.0, -2.0],
                       [2.5, -1.9]])),
        (2, np.matrix([[0.5],
                       [1.6],
                       [-1.4]])),
    ])

    scalings = OrderedDict()
    scalings['O'] = OrderedDict([('intercept', -2.3), ('slope', 4.5)])
    scalings['Pd'] = OrderedDict([('intercept', 1.6), ('slope', 2.5)])
    scalings['Cu'] = OrderedDict([('intercept', -0.3), ('slope', -0.5)])

    fingerprints_range = {
        "Cu": np.array([[2.8636310860653253, 2.8636310860653253],
                        [1.5435994865298275, 1.5435994865298275]]),
        "O": np.array([[2.9409056366723028, 2.972494902604392],
                       [1.9522542722823606, 4.0720361595017245]]),
        "Pd": np.array([[2.4629488092411096, 2.6160138774087125],
                        [0.27127576524253594, 0.5898312261433813]]),
    }

    # Largest absolute deviations accepted for energies and for forces.
    energy_tolerance = 10.**(-14.)
    force_tolerance = 10.**(-11.)

    # Exercise the descriptor with the fortran flag off and on, for
    # several core counts.
    for fortran in (False, True):
        for cores in range(1, 4):
            label = 'call-periodic/%s-%i' % (fortran, cores)
            descriptor = Gaussian(cutoff=4.,
                                  Gs=Gs,
                                  fortran=fortran)
            model = NeuralNetwork(hiddenlayers=hiddenlayers,
                                  weights=weights,
                                  scalings=scalings,
                                  activation='tanh',
                                  fprange=fingerprints_range,
                                  unit_type="double")
            calc = Amp(descriptor=descriptor,
                       model=model,
                       label=label,
                       dblabel=label,
                       cores=cores)

            # All energies are computed before any of them is checked.
            energies = [calc.get_potential_energy(atoms) for atoms in images]

            for image_no, energy in enumerate(energies):
                expected_energy = correct_energies[image_no]
                print(energy)
                print(expected_energy)
                deviation = abs(energy - expected_energy)
                assert (deviation < energy_tolerance), \
                    'The predicted energy of image %i is wrong!' % (
                        image_no + 1)

            # Likewise, all forces are computed before any comparison.
            forces = [calc.get_forces(atoms) for atoms in images]

            for image_no, image_forces in enumerate(forces):
                expected_forces = correct_forces[image_no]
                print('predicted forces:')
                print(image_forces)
                print('correct forces:')
                print(np.array(expected_forces))
                for index, atom_force in enumerate(image_forces):
                    for direction, component in enumerate(atom_force):
                        deviation = abs(
                            component - expected_forces[index][direction])
                        assert (deviation < force_tolerance), \
                            'The predicted %i force of atom %i of image' \
                            ' %i is wrong!' % (direction,
                                               index,
                                               image_no + 1)

if __name__ == '__main__':
    # When run as a script, execute the non-periodic test first and the
    # periodic test second.
    for run_test in (non_periodic_test, periodic_test):
        run_test()
