{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# NEP vs DFT results\n\nThis example shows how to use a NEP potential to calculate the energies and forces of structures and compare them with DFT reference results (stress is not compared here).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as ticker\n",
        "import numpy as np\n",
        "from ase.io import read\n",
        "from pynep.calculate import NEP"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Parity-plot helpers\n",
        "\n",
        "Each figure scatters the NEP prediction (y-axis) against the DFT reference (x-axis). Both axes share the same limits and the dashed line is the identity $y = x$, so points lying on it are reproduced exactly by NEP. The RMSE between the two sets is written on the plot, and each figure is also saved as a PNG file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _style_parity_axes(ax, title, xlabel, ylabel, spine_width, tick_size):\n",
        "    \"\"\"Apply the look shared by the parity plots to `ax`.\n",
        "\n",
        "    Sets the title and axis labels, a 1:1 aspect ratio, at most 5 tick\n",
        "    intervals per axis with labels formatted to one decimal, the width of\n",
        "    all four spines and the tick-label font size.\n",
        "    \"\"\"\n",
        "    ax.set_title(title, fontsize=16)\n",
        "    ax.set_aspect(1)\n",
        "    for axis in (ax.xaxis, ax.yaxis):\n",
        "        axis.set_major_locator(ticker.MaxNLocator(5))\n",
        "        axis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))\n",
        "    ax.set_xlabel(xlabel, fontsize=14)\n",
        "    ax.set_ylabel(ylabel, fontsize=14)\n",
        "    for side in ('bottom', 'left', 'right', 'top'):\n",
        "        ax.spines[side].set_linewidth(spine_width)\n",
        "    ax.tick_params(labelsize=tick_size)\n",
        "\n",
        "\n",
        "def _parity_plot(ax, ref, pred, unit, line_width, **scatter_kw):\n",
        "    \"\"\"Scatter `pred` against `ref` on `ax` and annotate the RMSE.\n",
        "\n",
        "    Both inputs are flattened. The x- and y-limits are set to the common\n",
        "    range of the two sets and the dashed guide is the identity line\n",
        "    y = x, so a perfect prediction lies exactly on it. Extra keyword\n",
        "    arguments are passed on to `ax.scatter`.\n",
        "\n",
        "    Returns the RMSE between `ref` and `pred`.\n",
        "    Raises ValueError if the inputs are empty or differ in size.\n",
        "    \"\"\"\n",
        "    ref = np.ravel(ref)\n",
        "    pred = np.ravel(pred)\n",
        "    if ref.size == 0 or ref.size != pred.size:\n",
        "        raise ValueError('need two non-empty sets of equal size, got {} and {}'.format(ref.size, pred.size))\n",
        "    lo = min(ref.min(), pred.min())\n",
        "    hi = max(ref.max(), pred.max())\n",
        "    # Identity line over the shared range (not a line through each axis' extrema).\n",
        "    ax.plot([lo, hi], [lo, hi], color='black', linewidth=line_width, linestyle='--')\n",
        "    ax.scatter(ref, pred, **scatter_kw)\n",
        "    ax.set_xlim(lo, hi)\n",
        "    ax.set_ylim(lo, hi)\n",
        "    rmse = np.sqrt(np.mean((ref - pred) ** 2))\n",
        "    # Axes-fraction coordinates keep the label at the same spot inside the\n",
        "    # plot whatever the data range is.\n",
        "    ax.text(0.15, 0.85, 'RMSE: {:.3f} {}'.format(rmse, unit),\n",
        "            fontsize=14, transform=ax.transAxes)\n",
        "    return rmse\n",
        "\n",
        "\n",
        "def plot_e(ed, er, path='e.png'):\n",
        "    \"\"\"Parity plot of per-atom energies: NEP (`er`) vs DFT (`ed`).\n",
        "\n",
        "    Each set is shifted by its own mean before plotting, so the figure and\n",
        "    the RMSE compare relative energies: a constant offset between NEP and\n",
        "    DFT is not counted as error. The figure is saved to `path` and\n",
        "    returned.\n",
        "    \"\"\"\n",
        "    ed = np.asarray(ed, dtype=float)\n",
        "    er = np.asarray(er, dtype=float)\n",
        "    fig, ax = plt.subplots()\n",
        "    _style_parity_axes(ax, 'NEP energy vs DFT energy',\n",
        "                       'DFT energy (eV/atom)', 'NEP energy (eV/atom)',\n",
        "                       spine_width=3, tick_size=16)\n",
        "    _parity_plot(ax, ed - ed.mean(), er - er.mean(), 'eV/atom',\n",
        "                 line_width=3, zorder=200)\n",
        "    fig.savefig(path)\n",
        "    return fig\n",
        "\n",
        "\n",
        "def plot_f(fd, fr, path='f.png'):\n",
        "    \"\"\"Parity plot of forces: NEP (`fr`) vs DFT (`fd`).\n",
        "\n",
        "    Inputs are flattened, so every force component is one point.\n",
        "    The figure is saved to `path` and returned.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots()\n",
        "    _style_parity_axes(ax, 'NEP forces vs DFT forces',\n",
        "                       'DFT forces (eV/A)', 'NEP forces (eV/A)',\n",
        "                       spine_width=2, tick_size=14)\n",
        "    _parity_plot(ax, fd, fr, 'eV/A', line_width=2, s=2)\n",
        "    fig.savefig(path)\n",
        "    return fig"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluate NEP on the reference structures\n",
        "\n",
        "Every frame of the trajectory is expected to carry its DFT reference in `atoms.info['energy']` and `atoms.info['forces']`. Energies are divided by the number of atoms so that structures of different sizes are comparable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "DATA_FILE = 'data.traj'       # trajectory holding the DFT reference data\n",
        "NEP_FILE = 'C_2022_NEP3.txt'  # NEP potential to evaluate\n",
        "\n",
        "frames = read(DATA_FILE, ':')\n",
        "calc = NEP(NEP_FILE)\n",
        "print(calc)\n",
        "\n",
        "e_nep, e_dft, f_nep, f_dft = [], [], [], []\n",
        "for atoms in frames:\n",
        "    atoms.calc = calc  # Atoms.set_calculator() is deprecated in ASE\n",
        "    n_atoms = len(atoms)\n",
        "    e_nep.append(atoms.get_potential_energy() / n_atoms)\n",
        "    e_dft.append(atoms.info['energy'] / n_atoms)\n",
        "    f_nep.append(atoms.get_forces().reshape(-1))\n",
        "    f_dft.append(atoms.info['forces'].reshape(-1))\n",
        "\n",
        "e_nep = np.array(e_nep)\n",
        "e_dft = np.array(e_dft)\n",
        "f_nep = np.concatenate(f_nep)\n",
        "f_dft = np.concatenate(f_dft)\n",
        "assert f_nep.shape == f_dft.shape, 'NEP and DFT force arrays differ in size'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Parity plots\n",
        "\n",
        "Energies are shown relative to their mean (separately for NEP and DFT), so the energy RMSE ignores a constant offset between the two methods. Forces are compared component by component."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "plot_e(e_dft, e_nep)\n",
        "plot_f(f_dft, f_nep)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}