Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 85 additions & 100 deletions metrics_evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# import string\n",
Expand All @@ -36,8 +38,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
Expand All @@ -64,8 +68,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"metricslist = ['Brier', 'LogLoss']\n",
Expand All @@ -90,8 +96,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"mystery = {}\n",
Expand All @@ -108,8 +116,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"snphotcc = {}\n",
Expand All @@ -131,8 +141,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"plasticc = {}\n",
Expand All @@ -141,38 +153,16 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"# old_snphotcc_names = []\n",
"# for prefix in ['templates_', 'wavelets_']:\n",
"# for suffix in ['boost_forest', 'knn', 'nb', 'neural_network', 'svm']:\n",
"# old_snphotcc_names.append(prefix+suffix+'.dat')\n",
"\n",
"# for i in range(len(snphotcc_names)):\n",
"# name = old_snphotcc_names[i]\n",
"# fileloc = dirname+'classifications/'+name\n",
"# snphotcc_info = pd.read_csv(fileloc, sep=' ')\n",
"# full = snphotcc_info.set_index('Object').join(truth_snphotcc.set_index('Object'))\n",
"# name = snphotcc_names[i]\n",
" \n",
"# truth = full['Type'] - 1\n",
"# snphotcc_truth_table = proclam.metrics.util.det_to_prob(truth)\n",
"# fileloc = 'examples/'+name+'/truth_table_'+name+'.csv'\n",
"# with open(fileloc, 'wb') as truth_place:\n",
"# np.savetxt(fileloc, snphotcc_truth_table, delimiter=' ')\n",
" \n",
"# probs = full[['1', '2', '3']]\n",
"# fileloc = 'examples/'+name+'/predicted_prob_'+name+'.csv'\n",
"# probs.to_csv(fileloc, sep=' ', index=False, header=True)"
]
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# more_names = snphotcc_names\n",
Expand All @@ -183,8 +173,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def make_class_pairs(data_info_dict):\n",
Expand All @@ -201,11 +193,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 'ProClaM', 'names': ['Idealized', 'Guess', 'Tunnel', 'Broadbrush', 'Cruise', 'SubsumedTo', 'SubsumedFrom'], 'dirname': 'examples/ProClaM/', 'classifications': ['Idealized/predicted_prob_Idealized.csv', 'Guess/predicted_prob_Guess.csv', 'Tunnel/predicted_prob_Tunnel.csv', 'Broadbrush/predicted_prob_Broadbrush.csv', 'Cruise/predicted_prob_Cruise.csv', 'SubsumedTo/predicted_prob_SubsumedTo.csv', 'SubsumedFrom/predicted_prob_SubsumedFrom.csv'], 'truth_tables': ['Idealized/truth_table_Idealized.csv', 'Guess/truth_table_Guess.csv', 'Tunnel/truth_table_Tunnel.csv', 'Broadbrush/truth_table_Broadbrush.csv', 'Cruise/truth_table_Cruise.csv', 'SubsumedTo/truth_table_SubsumedTo.csv', 'SubsumedFrom/truth_table_SubsumedFrom.csv']}\n"
]
}
],
"source": [
"for dataset in [mystery, snphotcc, plasticc]:\n",
"for dataset in [ plasticc]: #mystery, snphotcc,\n",
" dataset = make_file_locs(dataset)\n",
" dataset['class_pairs'] = make_class_pairs(dataset)"
]
Expand All @@ -221,8 +223,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def plot_cm(probs, truth, name, loc=''):\n",
Expand All @@ -234,14 +238,17 @@
" plt.ylabel('true class')\n",
" plt.colorbar()\n",
" plt.title(name)\n",
" plt.savefig(loc+name+'_cm.png')\n",
" plt.close()"
" #plt.savefig(loc+name+'_cm.png')\n",
" plt.show()\n",
" #plt.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": []
},
Expand All @@ -263,6 +270,7 @@
" nobj_truth = np.shape(truth_values)[0]\n",
" nclass_truth = np.shape(truth_values)[1]\n",
" tvec = np.where(truth_values==1)[1]\n",
" print(tvec)\n",
"# if nclass_truth!= nclass:\n",
"# print('Truth table of size %i x %i and prob matrix of size %i x %i do not match up in size'%(nobj,nclass,nobj_truth,nclass_truth))\n",
"# else:\n",
Expand All @@ -274,8 +282,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def make_patch_spines_invisible(ax):\n",
Expand Down Expand Up @@ -324,6 +334,7 @@
" plt.legend(handles, metric_names)\n",
" plt.suptitle(title)\n",
" plt.savefig(fileloc)\n",
" plt.show()\n",
" return"
]
},
Expand All @@ -336,13 +347,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"ename": "KeyError",
"evalue": "'class_pairs'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-15f56f172c88>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mmystery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msnphotcc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplasticc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetricslist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'names'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mcc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpair\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'class_pairs'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpair\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprobm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtruthv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_class_pairs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m#loc=dataset['dirname'], title=dataset['label']+' '+dataset['names'][cc])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'class_pairs'"
]
}
],
"source": [
"for dataset in [mystery, snphotcc, plasticc]:\n",
" data = np.empty((len(metricslist), len(dataset['names'])))\n",
" for cc, pair in enumerate(dataset['class_pairs']):\n",
" print(pair)\n",
" probm, truthv = read_class_pairs(pair, dataset, cc)#loc=dataset['dirname'], title=dataset['label']+' '+dataset['names'][cc])\n",
"# plot_cm(probm, truthv, str(cc), loc='./sandbox/')\n",
" det = proclam.metrics.util.prob_to_det(probm)\n",
Expand All @@ -361,55 +387,14 @@
"# metric_plot(dataset, metricslist, markerlist, colors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# more_data = np.empty((len(metricslist), len(more_names)))\n",
"# for cc, pair in enumerate(more_class_pairs):\n",
"# probm, truthv = read_class_pairs(pair, dirname)\n",
"# for count, metric in enumerate(metricslist):\n",
"# D = getattr(proclam.metrics, metric)()\n",
"# hm = D.evaluate(probm, truthv)\n",
"# more_data[count][cc] = hm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# metric_plot(more_names, metricslist, more_data, markerlist, colors, title='SNPhotCC', fileloc=dirname+'snphotccdata.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# data = np.empty((len(metricslist), len(names)))\n",
"# for cc, pair in enumerate(class_pairs):\n",
"# probm, truthv = read_class_pairs(pair, dirname)\n",
"# for count, metric in enumerate(metricslist):\n",
"# D = getattr(proclam.metrics, metric)()\n",
"# hm = D.evaluate(probm, truthv)\n",
"# data[count][cc] = hm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# metric_plot(names, metricslist, data, markerlist, colors, title='Mystery Dataset', fileloc=dirname+'mysterydata.png')"
]
"source": []
},
{
"cell_type": "code",
Expand Down Expand Up @@ -443,7 +428,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.6.8"
}
},
"nbformat": 4,
Expand Down
Loading