Obtain Magnetic Laplacian PE via PyG Pre-transform

We provide a transform that computes the magnetic Laplacian positional encoding (MagLap PE). It follows the torch_geometric.transforms interface, and our code is built on AddLaplacianEigenvectorPE.

An example of calling our transform to obtain the MagLap PE is as follows:

class DataProcessor(InMemoryDataset):
    """Example dataset that applies the MagLap PE as a pre-transform."""

    def __init__(self, config):
        # All MagLap PE hyper-parameters are read from the 'model' section
        # of the config.
        model_cfg = config['model']
        self.mag_pre_transform = Compose([
            AddMagLaplacianEigenvectorPE(
                k=model_cfg['mag_pe_dim_input'],
                q=model_cfg['q'],
                multiple_q=model_cfg['q_dim'],
                attr_name='mag_pe',
            ),
        ])

    def process(self):
        # Excerpt: `data` stands for the graph object built earlier in
        # process(); its construction is omitted from this example.
        if self.mag_pre_transform is not None:
            data = self.mag_pre_transform(data)

The class is located at ./maglap/get_mag_lap.py and is as follows:

@functional_transform('add_mag_laplacian_eigenvector_pe')
class AddMagLaplacianEigenvectorPE(BaseTransform):
    r"""Adds the Magnetic Laplacian eigenvector positional encoding.

    :func:`get_mag_laplacian` is called with ``normalization='sym'`` and
    returns one shared ``edge_index`` plus a list of edge-weight vectors.
    For each edge-weight vector the Laplacian is densified and decomposed
    with :func:`numpy.linalg.eigh`, and the :obj:`k` eigenpairs with the
    smallest eigenvalues are kept.

    Two attributes are written to the data object:

    * :obj:`attr_name`: eigenvectors of shape ``[num_nodes, Q, k]``, where
      ``Q`` is the number of edge-weight vectors returned by
      :func:`get_mag_laplacian`. The eigenvectors are stored exactly as
      returned by ``eigh`` (they are *not* split into real and imaginary
      channels); presumably they are complex-valued whenever the Laplacian
      is -- confirm against :func:`get_mag_laplacian`.
    * ``'Lambda'``: the matching eigenvalues of shape ``[1, Q, k]``.

    If the graph has fewer than :obj:`k` nodes, both outputs are
    zero-padded along the last dimension up to :obj:`k`.

    Args:
        k (int): The number of eigenvectors (smallest eigenvalues first) to
            keep per Laplacian.
        q (float, optional): Forwarded unchanged to
            :func:`get_mag_laplacian` as ``q``. (default: :obj:`0.1`)
        dynamic_q (bool, optional): Forwarded unchanged to
            :func:`get_mag_laplacian` as ``dynamic_q``.
            (default: :obj:`False`)
        multiple_q (int, optional): Forwarded unchanged to
            :func:`get_mag_laplacian` as ``multiple_q``; presumably controls
            how many Laplacians (``Q``) are produced -- confirm against
            :func:`get_mag_laplacian`. (default: :obj:`1`)
        attr_name (str, optional): The attribute name of the data object to
            add positional encodings to; passed to :func:`add_node_attr`.
            (default: :obj:`"laplacian_eigenvector_pe"`)
        **kwargs (optional): Stored on :obj:`self.kwargs` for backward
            compatibility. Currently unused: the decomposition always runs
            the dense :func:`numpy.linalg.eigh` solver, which takes none of
            these options.
    """
    def __init__(
            self,
            k: int,
            q: float = 0.1,
            dynamic_q: bool = False,
            multiple_q: int = 1,
            attr_name: Optional[str] = 'laplacian_eigenvector_pe',
            **kwargs,
    ):
        self.k = k
        self.q = q
        self.dynamic_q = dynamic_q
        self.multiple_q = multiple_q
        self.attr_name = attr_name
        self.kwargs = kwargs

    def __call__(self, data: Data) -> Data:
        """Compute the encoding for ``data`` and attach it.

        Args:
            data: Graph whose ``edge_index``, ``edge_weight`` and
                ``num_nodes`` are read.

        Returns:
            The object returned by :func:`add_node_attr` after adding the
            eigenvectors under :obj:`self.attr_name` and the eigenvalues
            under ``'Lambda'``.
        """
        num_nodes = data.num_nodes
        edge_index, edge_weight_list = get_mag_laplacian(
            data.edge_index,
            data.edge_weight,
            normalization='sym',
            num_nodes=num_nodes,
            q=self.q,
            dynamic_q=self.dynamic_q,
            multiple_q=self.multiple_q,
        )

        pe_list = []
        eigvals_list = []
        for edge_weight in edge_weight_list:
            L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

            # Full dense decomposition; eigh treats the matrix as Hermitian.
            eig_vals, eig_vecs = np.linalg.eigh(L.toarray())

            # Keep the (at most) k eigenpairs with the smallest eigenvalues.
            sort = eig_vals.argsort()[:self.k]
            eig_vals = eig_vals[sort]
            eig_vecs = eig_vecs[:, sort]

            # Graphs with fewer than k nodes yield fewer than k eigenpairs;
            # zero-pad so every graph has the same PE dimension.
            missing = self.k - len(eig_vals)
            if missing > 0:
                eig_vals = np.pad(eig_vals, (0, missing))
                eig_vecs = np.pad(eig_vecs, ((0, 0), (0, missing)))

            # Insert a per-Laplacian axis: [N, k] -> [N, 1, k] for the
            # eigenvectors and [k] -> [1, 1, k] for the eigenvalues.
            pe_list.append(torch.from_numpy(np.expand_dims(eig_vecs, 1)))
            eigvals_list.append(
                torch.from_numpy(eig_vals[np.newaxis, np.newaxis, :]))

        # Stack along the per-Laplacian axis: [N, Q, k] and [1, Q, k].
        pe = torch.cat(pe_list, dim=1)
        eig_vals = torch.cat(eigvals_list, dim=1)

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        data = add_node_attr(data, eig_vals, attr_name='Lambda')
        return data