import torch
import pandas as pd
import numpy as np

from matplotlib import cm
import matplotlib.pyplot as plt
import scipy
import torch.nn.functional as F
import torchvision

from sklearn.metrics import (explained_variance_score, mean_squared_error, mean_absolute_error,
                             r2_score, precision_score, recall_score, f1_score, roc_auc_score,
                             roc_curve, auc, confusion_matrix)
from sklearn.feature_selection import r_regression

from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from math import pi as PI

def scipy_spanning_tree(edge_index, edge_weight, num_nodes):
    """Compute a minimum spanning tree with SciPy and return its edges as a numpy array of shape [2, num_tree_edges]."""
    row, col = edge_index.cpu()
    edge_weight = edge_weight.cpu()
    # Build a sparse adjacency matrix and extract the minimum spanning tree.
    cgraph = csr_matrix((edge_weight, (row, col)), shape=(num_nodes, num_nodes))
    Tcsr = minimum_spanning_tree(cgraph)
    tree_row, tree_col = Tcsr.nonzero()
    spanning_edges = np.stack([tree_row, tree_col], 0)
    return spanning_edges

def build_spanning_tree_edge(edge_index, edge_weight, num_nodes):
    """Build the spanning-tree edge set as an undirected LongTensor on the same device as edge_index."""
    spanning_edges = scipy_spanning_tree(edge_index, edge_weight, num_nodes)
    spanning_edges = torch.tensor(spanning_edges, dtype=torch.long, device=edge_index.device)
    # Append the reversed edges so every tree edge appears in both directions.
    spanning_edges_undirected = torch.cat([spanning_edges, torch.stack([spanning_edges[1], spanning_edges[0]])], 1)
    return spanning_edges_undirected
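
# Illustrative usage (a minimal sketch, not part of the original code; the graph and
# weights below are assumed): extract the undirected MST edges of a 3-node triangle.
#
#   edge_index = torch.tensor([[0, 1, 1, 2, 0, 2],
#                              [1, 0, 2, 1, 2, 0]])
#   edge_weight = torch.tensor([1.0, 1.0, 2.0, 2.0, 4.0, 4.0])
#   tree = build_spanning_tree_edge(edge_index, edge_weight, num_nodes=3)
#   # tree keeps edges (0, 1) and (1, 2) in both directions, i.e. shape [2, 4].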




def record(values, epoch, writer, phase="Train"):
    """Write each metric in `values` to TensorBoard under the tag '<key>/<phase>' at the given epoch."""
    for key, value in values.items():
        writer.add_scalar(key + "/" + phase, value, epoch)
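
# Illustrative usage (a sketch; the writer instance and metric names are assumed):
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter()
#   record({"loss": 0.42, "acc": 0.91}, epoch=3, writer=writer, phase="Val")
#   # logs the scalars under the tags "loss/Val" and "acc/Val" at step 3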

def calculate(y_hat, y_true, y_hat_logit):
    """Compute classification accuracy from predicted and true labels (other metrics are left commented out)."""
    train_acc = (np.array(y_hat) == np.array(y_true)).sum() / len(y_true)
    # recall = recall_score(y_true, y_hat, zero_division=0, average='micro')
    # precision = precision_score(y_true, y_hat, zero_division=0, average='micro')
    # Fscore = f1_score(y_true, y_hat, zero_division=0, average='micro')
    # roc = roc_auc_score(y_true, scipy.special.softmax(np.array(y_hat_logit), axis=1)[:, 1], average='micro', multi_class='ovr')
    # one_hot_encoded_labels = np.zeros((len(y_true), 100))
    # one_hot_encoded_labels[np.arange(len(y_true)), y_true] = 1
    # roc = roc_auc_score(one_hot_encoded_labels, scipy.special.softmax(np.array(y_hat_logit), axis=1), average='micro', multi_class='ovr')
    return train_acc


def print_1(epoch, phase, values, color=None):
    """Print a one-line summary of the epoch's metrics, optionally wrapped in a color function."""
    msg = f"epoch[{epoch:d}] {phase} " + " ".join(f"{key}={value:.3f}" for key, value in values.items())
    if color is not None:
        print(color(msg))
    else:
        print(msg)

def get_angle(v1, v2):
    """Return the unsigned angle between the row vectors of v1 and v2 (2-D inputs are zero-padded to 3-D)."""
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    # atan2(|v1 x v2|, v1 . v2) is numerically more stable than acos of the normalized dot product.
    return torch.atan2(torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))
def get_theta(v1, v2):
    """Signed angle from v1 to v2: positive when the rotation follows the right-hand rule (z-component of v1 x v2 points up), negative otherwise."""
    angle = get_angle(v1, v2)
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    v = torch.cross(v1, v2, dim=1)[..., 2]
    flag = torch.sign(v)
    flag[flag == 0] = -1  # Treat collinear vectors as a negative rotation.
    return angle * flag
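
# Worked example (a sketch with assumed inputs): rotating (1, 0) onto (0, 1) is
# counter-clockwise, so the signed angle is +pi/2; flipping the target to (0, -1)
# makes the rotation clockwise and the sign negative.
#
#   v1 = torch.tensor([[1.0, 0.0]])
#   v2 = torch.tensor([[0.0, 1.0]])
#   get_angle(v1, v2)   # tensor([1.5708])   (pi / 2, unsigned)
#   get_theta(v1, v2)   # tensor([1.5708])   (+pi / 2)
#   get_theta(v1, -v2)  # tensor([-1.5708])  (-pi / 2)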

def triplets(edge_index, num_nodes):
    """Enumerate directed triplets k -> j -> i over the graph.

    Returns the stacked node indices [idx_i, idx_j, idx_k], a running count of
    kept triplets per edge, and the edge ids of the first hop (k -> j) and the
    second hop (j -> i).
    """
    row, col = edge_index

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_col = adj_t[:, row]
    num_triplets = adj_t_col.set_value(None).sum(dim=0).to(torch.long)

    # Node indices of each triplet and the edge ids of its two hops.
    idx_j = row.repeat_interleave(num_triplets)
    idx_i = col.repeat_interleave(num_triplets)
    edx_2nd = value.repeat_interleave(num_triplets)
    idx_k = adj_t_col.t().storage.col()
    edx_1st = adj_t_col.t().storage.value()
    mask1 = (idx_i == idx_k) & (idx_j != idx_i)  # Remove "go back" triplets (k == i).
    mask2 = (idx_i == idx_j) & (idx_j != idx_k)  # Remove repeated self-loop triplets (i == j).
    mask3 = (idx_j == idx_k) & (idx_i != idx_k)  # Remove self-loop neighbours (j == k).
    mask = ~(mask1 | mask2 | mask3)
    idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]

    # Cumulative number of surviving triplets up to (and including) each edge.
    num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0) - 1]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd
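

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original training code):
    # enumerate triplets on a tiny path graph 0 - 1 - 2 with both edge directions
    # present and print the resulting tensors.
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    triplet_idx, num_triplets_real, edx_1st, edx_2nd = triplets(edge_index, num_nodes=3)
    print("triplet node indices (i, j, k):\n", triplet_idx)
    print("running count of kept triplets per edge:", num_triplets_real)
    print("first-hop / second-hop edge ids:", edx_1st, edx_2nd)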