File size: 2,302 Bytes
9a67fbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import pandas as pd
from sklearn.model_selection import train_test_split
import importlib.resources as pkg_resources
import polyatomic_complexes
import numpy as np
from typing import Tuple
from pathlib import Path
# Datasets shipped as package resources:
# lower-case key -> (resource package, CSV file name, target column).
_PACKAGED_DATASETS = {
    "esol": (
        "polyatomic_complexes.dataset.esol",
        "ESOL.csv",
        "measured log solubility in mols per litre",
    ),
    "freesolv": (
        "polyatomic_complexes.dataset.free_solv",
        "FreeSolv.csv",
        "expt",
    ),
    "lipophil": (
        "polyatomic_complexes.dataset.lipophilicity",
        "Lipophilicity.csv",
        "exp",
    ),
}

# Datasets read from the ``benchmark_csv`` directory located two levels above
# this file: lower-case key -> (CSV file name, target column).
_BENCHMARK_DATASETS = {
    "boilingpoint": ("boiling_point.csv", "boiling_point_K"),
    "qm9": ("qm9_subset.csv", "cv"),
    "ic50": ("ic_50_subset.csv", "pIC50"),
    "bindingdb": ("bindingdb.csv", "pIC50"),
}


def load_dataset(name) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load a named dataset and split it into train and test arrays.

    The CSV for ``name`` is read, rows with a missing ``smiles`` value or a
    missing target value are dropped, and the remaining rows are split 80/20
    with a fixed ``random_state`` of 42.

    Args:
        name: Dataset identifier, matched case-insensitively. Accepted values
            are "esol", "freesolv", "lipophil", "boilingpoint", "qm9",
            "ic50" and "bindingdb".

    Returns:
        A tuple ``(X_train, X_test, y_train, y_test)`` of NumPy arrays, where
        the ``X`` arrays hold the ``smiles`` column and the ``y`` arrays hold
        the dataset's target column.

    Raises:
        ValueError: If ``name`` is not one of the accepted identifiers.
    """
    key = name.lower()

    if key in _PACKAGED_DATASETS:
        package, file_name, target_col = _PACKAGED_DATASETS[key]
        resource = pkg_resources.files(package) / file_name
        # as_file yields a real filesystem path even when the package is not
        # installed as a plain directory, which str(resource) cannot promise.
        with pkg_resources.as_file(resource) as csv_path:
            df = pd.read_csv(csv_path)
    elif key in _BENCHMARK_DATASETS:
        file_name, target_col = _BENCHMARK_DATASETS[key]
        csv_path = Path(__file__).parent.parent / "benchmark_csv" / file_name
        df = pd.read_csv(csv_path)
    else:
        raise ValueError(f"Unknown dataset: {name}")

    # Rows lacking either the input or the target cannot be used.
    df = df.dropna(subset=["smiles", target_col])

    X_train, X_test, y_train, y_test = train_test_split(
        df["smiles"], df[target_col], test_size=0.2, random_state=42
    )
    return (
        X_train.to_numpy(),
        X_test.to_numpy(),
        y_train.to_numpy(),
        y_test.to_numpy(),
    )
|