Skip to content

Inference API

The FastAPI service wraps the graph neural network to expose prediction endpoints.

Inference Service

project_name.api.InferenceService

Handles model loading and predictions.

Source code in src/project_name/api.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class InferenceService:
    """Handles model loading and predictions."""

    def __init__(self, model_path: str | Path) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = GraphNeuralNetwork(
            num_node_features=11,
            hidden_dim=128,
            num_layers=3,
            output_dim=1,
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

    def predict(
        self,
        node_features: list[list[float]],
        edge_index: list[list[int]],
    ) -> list[float]:
        """Generate prediction."""
        with torch.no_grad():
            x = torch.tensor(node_features, dtype=torch.float32).to(self.device)
            edge_idx = torch.tensor(edge_index, dtype=torch.long).to(self.device)
            data = Data(x=x, edge_index=edge_idx)
            output = self.model(data)
            return [float(output.squeeze().cpu())]

predict

predict(node_features: list[list[float]], edge_index: list[list[int]]) -> list[float]

Generate prediction.

Source code in src/project_name/api.py
67
68
69
70
71
72
73
74
75
76
77
78
def predict(
    self,
    node_features: list[list[float]],
    edge_index: list[list[int]],
) -> list[float]:
    """Generate prediction."""
    with torch.no_grad():
        x = torch.tensor(node_features, dtype=torch.float32).to(self.device)
        edge_idx = torch.tensor(edge_index, dtype=torch.long).to(self.device)
        data = Data(x=x, edge_index=edge_idx)
        output = self.model(data)
        return [float(output.squeeze().cpu())]

Request/Response Schemas

project_name.api.PredictionRequest

Bases: BaseModel

Input for prediction.

Source code in src/project_name/api.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class PredictionRequest(BaseModel):
    """Input for prediction."""

    node_features: list[list[float]]
    edge_index: list[list[int]]
    save_prediction: bool = True

    @field_validator("node_features")
    def validate_node_features(cls, v: list[list[float]]) -> list[list[float]]:
        """Validate node features have correct number of features.

        Args:
            v: Node features matrix.

        Returns:
            Validated node features.

        Raises:
            ValueError: If node features don't have correct number of dimensions.
        """
        if not v:
            raise ValueError("node_features cannot be empty")
        num_features = len(v[0])
        if num_features != 11:
            raise ValueError(f"Each node must have exactly 11 features, got {num_features}")
        if not all(len(features) == num_features for features in v):
            raise ValueError("All nodes must have the same number of features")
        return v

validate_node_features

validate_node_features(v: list[list[float]]) -> list[list[float]]

Validate that each node has the correct number of features.

Parameters:

Name Type Description Default
v list[list[float]]

Node features matrix.

required

Returns:

Type Description
list[list[float]]

Validated node features.

Raises:

Type Description
ValueError

If the node feature matrix is empty, a node does not have exactly 11 features, or nodes have inconsistent feature counts.

Source code in src/project_name/api.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@field_validator("node_features")
def validate_node_features(cls, v: list[list[float]]) -> list[list[float]]:
    """Validate node features have correct number of features.

    Args:
        v: Node features matrix.

    Returns:
        Validated node features.

    Raises:
        ValueError: If node features don't have correct number of dimensions.
    """
    if not v:
        raise ValueError("node_features cannot be empty")
    num_features = len(v[0])
    if num_features != 11:
        raise ValueError(f"Each node must have exactly 11 features, got {num_features}")
    if not all(len(features) == num_features for features in v):
        raise ValueError("All nodes must have the same number of features")
    return v

project_name.api.PredictionResponse

Bases: BaseModel

Prediction output.

Source code in src/project_name/api.py
46
47
48
49
class PredictionResponse(BaseModel):
    """Prediction output."""

    prediction: list[float]

Routes

project_name.api.health_check

health_check()

Health check.

Source code in src/project_name/api.py
121
122
123
124
@app.get("/")
def health_check():
    """Health check."""
    return {"status": "healthy"}

project_name.api.predict

predict(request: PredictionRequest, background_tasks: BackgroundTasks)

Generate prediction.

Source code in src/project_name/api.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    """Generate prediction."""
    if service is None:
        raise HTTPException(status_code=503, detail="Model not ready")

    try:
        prediction = service.predict(
            request.node_features,
            request.edge_index,
        )
        if request.save_prediction:
            background_tasks.add_task(save_prediction_to_gcp, request.node_features, request.edge_index, prediction)
        return PredictionResponse(prediction=prediction)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

Modules

api

project_name.api

InferenceService

Handles model loading and predictions.

Source code in src/project_name/api.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class InferenceService:
    """Handles model loading and predictions."""

    def __init__(self, model_path: str | Path) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = GraphNeuralNetwork(
            num_node_features=11,
            hidden_dim=128,
            num_layers=3,
            output_dim=1,
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

    def predict(
        self,
        node_features: list[list[float]],
        edge_index: list[list[int]],
    ) -> list[float]:
        """Generate prediction."""
        with torch.no_grad():
            x = torch.tensor(node_features, dtype=torch.float32).to(self.device)
            edge_idx = torch.tensor(edge_index, dtype=torch.long).to(self.device)
            data = Data(x=x, edge_index=edge_idx)
            output = self.model(data)
            return [float(output.squeeze().cpu())]

predict

predict(node_features: list[list[float]], edge_index: list[list[int]]) -> list[float]

Generate prediction.

Source code in src/project_name/api.py
67
68
69
70
71
72
73
74
75
76
77
78
def predict(
    self,
    node_features: list[list[float]],
    edge_index: list[list[int]],
) -> list[float]:
    """Generate prediction."""
    with torch.no_grad():
        x = torch.tensor(node_features, dtype=torch.float32).to(self.device)
        edge_idx = torch.tensor(edge_index, dtype=torch.long).to(self.device)
        data = Data(x=x, edge_index=edge_idx)
        output = self.model(data)
        return [float(output.squeeze().cpu())]

PredictionRequest

Bases: BaseModel

Input for prediction.

Source code in src/project_name/api.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class PredictionRequest(BaseModel):
    """Input for prediction."""

    node_features: list[list[float]]
    edge_index: list[list[int]]
    save_prediction: bool = True

    @field_validator("node_features")
    def validate_node_features(cls, v: list[list[float]]) -> list[list[float]]:
        """Validate node features have correct number of features.

        Args:
            v: Node features matrix.

        Returns:
            Validated node features.

        Raises:
            ValueError: If node features don't have correct number of dimensions.
        """
        if not v:
            raise ValueError("node_features cannot be empty")
        num_features = len(v[0])
        if num_features != 11:
            raise ValueError(f"Each node must have exactly 11 features, got {num_features}")
        if not all(len(features) == num_features for features in v):
            raise ValueError("All nodes must have the same number of features")
        return v

validate_node_features

validate_node_features(v: list[list[float]]) -> list[list[float]]

Validate that each node has the correct number of features.

Parameters:

Name Type Description Default
v list[list[float]]

Node features matrix.

required

Returns:

Type Description
list[list[float]]

Validated node features.

Raises:

Type Description
ValueError

If the node feature matrix is empty, a node does not have exactly 11 features, or nodes have inconsistent feature counts.

Source code in src/project_name/api.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@field_validator("node_features")
def validate_node_features(cls, v: list[list[float]]) -> list[list[float]]:
    """Validate node features have correct number of features.

    Args:
        v: Node features matrix.

    Returns:
        Validated node features.

    Raises:
        ValueError: If node features don't have correct number of dimensions.
    """
    if not v:
        raise ValueError("node_features cannot be empty")
    num_features = len(v[0])
    if num_features != 11:
        raise ValueError(f"Each node must have exactly 11 features, got {num_features}")
    if not all(len(features) == num_features for features in v):
        raise ValueError("All nodes must have the same number of features")
    return v

PredictionResponse

Bases: BaseModel

Prediction output.

Source code in src/project_name/api.py
46
47
48
49
class PredictionResponse(BaseModel):
    """Prediction output."""

    prediction: list[float]

health_check

health_check()

Health check.

Source code in src/project_name/api.py
121
122
123
124
@app.get("/")
def health_check():
    """Health check."""
    return {"status": "healthy"}

lifespan async

lifespan(app: FastAPI)

Load model on startup.

Source code in src/project_name/api.py
87
88
89
90
91
92
93
94
95
96
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model on startup."""
    global service
    model_path = "best_model.pt"
    service = InferenceService(MODEL_FOLDER + model_path)

    yield

    del service

predict

predict(request: PredictionRequest, background_tasks: BackgroundTasks)

Generate prediction.

Source code in src/project_name/api.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    """Generate prediction."""
    if service is None:
        raise HTTPException(status_code=503, detail="Model not ready")

    try:
        prediction = service.predict(
            request.node_features,
            request.edge_index,
        )
        if request.save_prediction:
            background_tasks.add_task(save_prediction_to_gcp, request.node_features, request.edge_index, prediction)
        return PredictionResponse(prediction=prediction)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

save_prediction_to_gcp

save_prediction_to_gcp(node_features: list[list[float]], edge_index: list[list[int]], outputs: list[float])

Save the prediction results to GCP bucket.

Source code in src/project_name/api.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def save_prediction_to_gcp(node_features: list[list[float]], edge_index: list[list[int]], outputs: list[float]):
    """Save the prediction results to GCP bucket."""
    client = storage.Client()
    bucket = client.bucket(PRED_FOLDER)
    time = datetime.now(tz=timezone.utc).isoformat()
    # Prepare prediction data
    data = {
        "node_features": node_features,
        "edge_index": edge_index,
        "prediction": outputs,
        "timestamp": time,
    }
    blob = bucket.blob(f"predictions/prediction_{time}.json")
    blob.upload_from_string(json.dumps(data))
    print("Prediction saved to GCP bucket.")

compare_promote

project_name.compare_promote

data

project_name.data

QM9Dataset

Bases: Dataset

QM9 dataset wrapper from torch_geometric.

Source code in src/project_name/data.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class QM9Dataset(Dataset):
    """QM9 dataset wrapper from torch_geometric."""

    def __init__(self, data_path: Path) -> None:
        """Initialize the QM9 dataset.

        Args:
            data_path: Path to the data directory where QM9 will be stored.
        """
        self.data_path = Path(data_path)
        self.dataset = self._load_dataset()

    def _load_dataset(self) -> QM9:
        """Load QM9 dataset, checking if it already exists locally.

        Downloads the dataset on first instantiation if it doesn't exist.

        Returns:
            QM9 dataset from torch_geometric.
        """
        raw_path = self.data_path / "raw"
        raw_path.mkdir(parents=True, exist_ok=True)

        print("Loading QM9 dataset (downloading if not already present)...")
        dataset = QM9(root=str(self.data_path))
        print(f"Dataset ready at {self.data_path}")
        return dataset

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, index: int):
        """Return a given sample from the dataset."""
        return self.dataset[index]

    def preprocess(self, output_folder: Path) -> None:
        """Preprocess the raw data and save it to the output folder."""

__getitem__

__getitem__(index: int)

Return a given sample from the dataset.

Source code in src/project_name/data.py
39
40
41
def __getitem__(self, index: int):
    """Return a given sample from the dataset."""
    return self.dataset[index]

__init__

__init__(data_path: Path) -> None

Initialize the QM9 dataset.

Parameters:

Name Type Description Default
data_path Path

Path to the data directory where QM9 will be stored.

required
Source code in src/project_name/data.py
10
11
12
13
14
15
16
17
def __init__(self, data_path: Path) -> None:
    """Initialize the QM9 dataset.

    Args:
        data_path: Path to the data directory where QM9 will be stored.
    """
    self.data_path = Path(data_path)
    self.dataset = self._load_dataset()

__len__

__len__() -> int

Return the length of the dataset.

Source code in src/project_name/data.py
35
36
37
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)

preprocess

preprocess(output_folder: Path) -> None

Preprocess the raw data and save it to the output folder.

Source code in src/project_name/data.py
43
44
def preprocess(self, output_folder: Path) -> None:
    """Preprocess the raw data and save it to the output folder."""

download_model

project_name.download_model

evaluate

project_name.evaluate

evaluate

evaluate(model: GraphNeuralNetwork, loader: DataLoader, device: torch.device, target_indices: Sequence[int]) -> float

Evaluate model on a dataloader.

Computes mean MSE loss per graph over the entire loader, matching train_epoch.

Parameters:

Name Type Description Default
model GraphNeuralNetwork

Trained GNN model.

required
loader DataLoader

DataLoader for validation/test set.

required
device device

Torch device.

required
target_indices Sequence[int]

Indices of target properties in batch.y.

required

Returns:

Type Description
float

Mean MSE loss per graph.

Source code in src/project_name/evaluate.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@torch.no_grad()
def evaluate(
    model: GraphNeuralNetwork,
    loader: DataLoader,
    device: torch.device,
    target_indices: Sequence[int],
) -> float:
    """Evaluate model on a dataloader.

    Computes mean MSE loss per graph over the entire loader, matching train_epoch.

    Args:
        model: Trained GNN model.
        loader: DataLoader for validation/test set.
        device: Torch device.
        target_indices: Indices of target properties in batch.y.

    Returns:
        Mean MSE loss per graph.
    """
    model.eval()

    total_loss: float = 0.0
    num_samples: int = 0

    target_idx = list(target_indices)

    for batch in loader:
        batch = batch.to(device)

        pred: torch.Tensor = model(batch)
        target: torch.Tensor = batch.y[:, target_idx]

        loss: torch.Tensor = F.mse_loss(pred, target)
        total_loss += loss.item() * batch.num_graphs
        num_samples += batch.num_graphs

    if num_samples == 0:
        return 0.0

    return total_loss / num_samples

evaluate_with_metrics

evaluate_with_metrics(model: GraphNeuralNetwork, loader: DataLoader, device: torch.device, target_indices: Sequence[int]) -> dict[str, float]

Evaluate model on a dataloader with multiple metrics.

Parameters:

Name Type Description Default
model GraphNeuralNetwork

Trained GNN model.

required
loader DataLoader

DataLoader for validation/test set.

required
device device

Torch device.

required
target_indices Sequence[int]

Indices of target properties in batch.y.

required

Returns:

Type Description
dict[str, float]

Dictionary with metrics: mse, rmse, mae, r2.

Source code in src/project_name/evaluate.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@torch.no_grad()
def evaluate_with_metrics(
    model: GraphNeuralNetwork,
    loader: DataLoader,
    device: torch.device,
    target_indices: Sequence[int],
) -> dict[str, float]:
    """Evaluate model on a dataloader with multiple metrics.

    Args:
        model: Trained GNN model.
        loader: DataLoader for validation/test set.
        device: Torch device.
        target_indices: Indices of target properties in batch.y.

    Returns:
        Dictionary with metrics: mse, rmse, mae, r2.
    """
    model.eval()

    all_preds: list[torch.Tensor] = []
    all_targets: list[torch.Tensor] = []

    target_idx = list(target_indices)

    for batch in loader:
        batch = batch.to(device)

        pred: torch.Tensor = model(batch)
        target: torch.Tensor = batch.y[:, target_idx]

        all_preds.append(pred)
        all_targets.append(target)

    if len(all_preds) == 0:
        return {"mse": 0.0, "rmse": 0.0, "mae": 0.0, "r2": 0.0}

    predictions = torch.cat(all_preds, dim=0)
    targets = torch.cat(all_targets, dim=0)

    # MSE
    mse = F.mse_loss(predictions, targets).item()

    # RMSE
    rmse = torch.sqrt(F.mse_loss(predictions, targets)).item()

    # MAE
    mae = F.l1_loss(predictions, targets).item()

    # R² score
    ss_res = torch.sum((targets - predictions) ** 2).item()
    ss_tot = torch.sum((targets - torch.mean(targets)) ** 2).item()
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0

    return {
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
    }

get_device

get_device() -> torch.device

Get the best available device for computation.

Source code in src/project_name/evaluate.py
122
123
124
125
126
127
128
def get_device() -> torch.device:
    """Get the best available device for computation."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

main

main(cfg: DictConfig) -> None

Load best model and evaluate on test set with comprehensive metrics.

Source code in src/project_name/evaluate.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
@hydra.main(version_base=None, config_path="../../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Load best model and evaluate on test set with comprehensive metrics."""
    device = get_device()

    # Load dataset
    dataset = QM9Dataset(cfg.training.data_path)
    dataset.transform = NormalizeScale()

    n = len(dataset)
    train_size = int(cfg.training.train_ratio * n)
    val_size = int(cfg.training.val_ratio * n)
    test_size = n - train_size - val_size

    _, _, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(cfg.seed),
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.training.batch_size,
        shuffle=False,
    )

    target_indices = list(cfg.training.target_indices)
    num_targets = len(target_indices)

    # Build model
    model = GraphNeuralNetwork(
        num_node_features=cfg.model.num_node_features,
        hidden_dim=cfg.model.hidden_dim,
        num_layers=cfg.model.num_layers,
        output_dim=num_targets,
    ).to(device)

    # Load best model
    best_model_path = Path(cfg.training.model_dir) / "best_model.pt"
    print(f"Loading model from: {best_model_path}")

    try:
        state = torch.load(best_model_path, weights_only=True)
    except TypeError:
        state = torch.load(best_model_path)

    model.load_state_dict(state)

    # Evaluate with multiple metrics
    metrics = evaluate_with_metrics(model, test_loader, device, target_indices)

    print("\n" + "=" * 50)
    print("Test Set Evaluation (Best Model)")
    print("=" * 50)
    print(f"MSE:  {metrics['mse']:.6f}")
    print(f"RMSE: {metrics['rmse']:.6f}")
    print(f"MAE:  {metrics['mae']:.6f}")
    print(f"R²:   {metrics['r2']:.6f}")
    print("=" * 50 + "\n")

model

project_name.model

GraphNeuralNetwork

Bases: Module

Graph Neural Network for molecular property regression.

Source code in src/project_name/model.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class GraphNeuralNetwork(nn.Module):
    """Graph Neural Network for molecular property regression."""

    def __init__(
        self,
        num_node_features: int = 11,
        num_edge_features: int = 4,
        hidden_dim: int = 128,
        num_layers: int = 3,
        output_dim: int = 1,
        dropout: float = 0.1,
    ) -> None:
        """Initialize the GNN model.

        Args:
            num_node_features: Number of node (atom) features.
            num_edge_features: Number of edge (bond) features.
            hidden_dim: Number of hidden channels.
            num_layers: Number of GraphConv layers.


            output_dim: Output dimension (1 for single property regression).
            dropout: Dropout rate for regularization.
        """
        super().__init__()
        self.dropout_rate = dropout

        self.initial_embedding = nn.Linear(num_node_features, hidden_dim)

        self.conv_layers = nn.ModuleList([GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)])

        self.pool = global_mean_pool

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, data) -> torch.Tensor:
        """Forward pass through the model.

        Args:
            data: PyTorch Geometric Data object with x, edge_index, and batch attributes.

        Returns:
            Predicted property values.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.initial_embedding(x)
        x = F.relu(x)

        for conv in self.conv_layers:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.pool(x, batch)

        x = self.mlp(x)

        return x

__init__

__init__(num_node_features: int = 11, num_edge_features: int = 4, hidden_dim: int = 128, num_layers: int = 3, output_dim: int = 1, dropout: float = 0.1) -> None

Initialize the GNN model.

Parameters:

Name Type Description Default
num_node_features int

Number of node (atom) features.

11
num_edge_features int

Number of edge (bond) features.

4
hidden_dim int

Number of hidden channels.

128
num_layers int

Number of GraphConv layers.

3
output_dim int

Output dimension (1 for single property regression).

1
dropout float

Dropout rate for regularization.

0.1
Source code in src/project_name/model.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def __init__(
    self,
    num_node_features: int = 11,
    num_edge_features: int = 4,
    hidden_dim: int = 128,
    num_layers: int = 3,
    output_dim: int = 1,
    dropout: float = 0.1,
) -> None:
    """Initialize the GNN model.

    Args:
        num_node_features: Number of node (atom) features.
        num_edge_features: Number of edge (bond) features.
        hidden_dim: Number of hidden channels.
        num_layers: Number of GraphConv layers.


        output_dim: Output dimension (1 for single property regression).
        dropout: Dropout rate for regularization.
    """
    super().__init__()
    self.dropout_rate = dropout

    self.initial_embedding = nn.Linear(num_node_features, hidden_dim)

    self.conv_layers = nn.ModuleList([GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)])

    self.pool = global_mean_pool

    self.mlp = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, hidden_dim // 2),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim // 2, output_dim),
    )

forward

forward(data) -> torch.Tensor

Forward pass through the model.

Parameters:

Name Type Description Default
data

PyTorch Geometric Data object with x, edge_index, and batch attributes.

required

Returns:

Type Description
Tensor

Predicted property values.

Source code in src/project_name/model.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def forward(self, data) -> torch.Tensor:
    """Forward pass through the model.

    Args:
        data: PyTorch Geometric Data object with x, edge_index, and batch attributes.

    Returns:
        Predicted property values.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    x = self.initial_embedding(x)
    x = F.relu(x)

    for conv in self.conv_layers:
        x = conv(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

    x = self.pool(x, batch)

    x = self.mlp(x)

    return x

profiling

project_name.profiling

Profiling utilities for training and evaluation.

TrainingProfiler

Manages profiling across entire training session.

Source code in src/project_name/profiling.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class TrainingProfiler:
    """Manages profiling across entire training session."""

    def __init__(
        self,
        enabled: bool = False,
        output_dir: Optional[Path] = None,
        warmup_steps: int = 1,
        active_steps: int = 10,
        repeat_steps: int = 1,
    ) -> None:
        """Initialize the training profiler.

        Args:
            enabled: Whether to enable profiling.
            output_dir: Directory to save profiling results.
        """
        self.enabled = enabled
        self.output_dir = output_dir or Path("profiling_results/run")
        self.prof: Optional[profile] = None

        if self.enabled:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            self.prof = profile(
                activities=[ProfilerActivity.CPU]
                if not torch.cuda.is_available()
                else [ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                schedule=torch.profiler.schedule(
                    wait=0,
                    warmup=warmup_steps,
                    active=active_steps,
                    repeat=repeat_steps,
                ),
                on_trace_ready=tensorboard_trace_handler(output_dir),
            )
            self.prof.__enter__()

    def step(self) -> None:
        """Record a step (epoch) in the profiler."""
        if self.prof:
            self.prof.step()

    def finalize(self) -> None:
        """Finalize profiling and export trace."""
        if self.prof:
            self.prof.__exit__(None, None, None)
            print(f"✅ Profiling trace saved to {self.output_dir}")

__init__

__init__(enabled: bool = False, output_dir: Optional[Path] = None, warmup_steps: int = 1, active_steps: int = 10, repeat_steps: int = 1) -> None

Initialize the training profiler.

Parameters:

Name Type Description Default
enabled bool

Whether to enable profiling.

False
output_dir Optional[Path]

Directory to save profiling results.

None
Source code in src/project_name/profiling.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(
    self,
    enabled: bool = False,
    output_dir: Optional[Path] = None,
    warmup_steps: int = 1,
    active_steps: int = 10,
    repeat_steps: int = 1,
) -> None:
    """Initialize the training profiler.

    Args:
        enabled: Whether to enable profiling.
        output_dir: Directory to save profiling results.
    """
    self.enabled = enabled
    self.output_dir = output_dir or Path("profiling_results/run")
    self.prof: Optional[profile] = None

    if self.enabled:
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.prof = profile(
            activities=[ProfilerActivity.CPU]
            if not torch.cuda.is_available()
            else [ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            schedule=torch.profiler.schedule(
                wait=0,
                warmup=warmup_steps,
                active=active_steps,
                repeat=repeat_steps,
            ),
            on_trace_ready=tensorboard_trace_handler(output_dir),
        )
        self.prof.__enter__()

finalize

finalize() -> None

Finalize profiling and export trace.

Source code in src/project_name/profiling.py
57
58
59
60
61
def finalize(self) -> None:
    """Finalize profiling and export trace."""
    if self.prof:
        self.prof.__exit__(None, None, None)
        print(f"✅ Profiling trace saved to {self.output_dir}")

step

step() -> None

Record a step (epoch) in the profiler.

Source code in src/project_name/profiling.py
52
53
54
55
def step(self) -> None:
    """Advance the underlying torch profiler by one step when profiling is active."""
    prof = self.prof
    if prof:
        prof.step()

timing_checkpoint

timing_checkpoint(name: str, enabled: bool = True) -> Generator

Context manager for simple timing measurements.

Parameters:

Name Type Description Default
name str

Name for this checkpoint.

required
enabled bool

Whether to enable timing.

True

Yields:

Type Description
Generator

Dictionary with timing results.

Source code in src/project_name/profiling.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@contextmanager
def timing_checkpoint(name: str, enabled: bool = True) -> Generator:
    """Time a code region and report how long it took.

    Args:
        name: Label used when reporting the measurement.
        enabled: When False, the block runs untimed and duration stays 0.0.

    Yields:
        Dict with keys ``"name"`` and ``"duration"`` (seconds; filled in on exit).
    """
    result = {"name": name, "duration": 0.0}

    if enabled:
        started = time.perf_counter()
        try:
            yield result
        finally:
            # Record the elapsed time even if the body raised.
            result["duration"] = time.perf_counter() - started
            print(f"⏱️  {name}: {result['duration']:.4f}s")
    else:
        yield result

prune

project_name.prune

apply_unstructured_pruning

apply_unstructured_pruning(model: torch.nn.Module, amount: float) -> dict[str, Any]

Apply unstructured L1 pruning to FC layers only, then make it permanent.

Source code in src/project_name/prune.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def apply_unstructured_pruning(
    model: torch.nn.Module,
    amount: float,
) -> dict[str, Any]:
    """Prune FC-layer weights with unstructured L1 pruning and bake the masks in.

    Returns a summary dict with the number of pruned modules and the resulting
    global sparsity (fraction of zeroed weight elements across pruned layers).
    """
    if not (0.0 <= amount < 1.0):
        raise ValueError(f"Prune amount must be in [0, 1). Got: {amount}")

    targets = list(_iter_prunable_weight_params(model))
    if not targets:
        logger.warning("No prunable fully-connected (nn.Linear) layers found.")
        return {"modules_pruned": 0, "global_sparsity": 0.0}

    # First attach L1 pruning masks, then make them permanent.
    for module, param_name in targets:
        prune.l1_unstructured(module, name=param_name, amount=amount)
    for module, param_name in targets:
        prune.remove(module, param_name)

    # Tally zeroed elements to report the achieved global sparsity.
    total = 0
    zeros = 0
    for module, param_name in targets:
        weight = getattr(module, param_name)
        total += weight.numel()
        zeros += int((weight == 0).sum().item())

    sparsity = zeros / total if total > 0 else 0.0
    return {
        "modules_pruned": len(targets),
        "global_sparsity": sparsity,
        "zero_elems": zeros,
        "total_elems": total,
    }

evaluate_mse

evaluate_mse(model: torch.nn.Module, loader: DataLoader, device: torch.device, target_indices: Sequence[int]) -> float

Mean MSE per graph (matches your train/eval convention).

Source code in src/project_name/prune.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@torch.no_grad()
def evaluate_mse(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    target_indices: Sequence[int],
) -> float:
    """Average per-graph MSE of *model* over *loader* (same convention as train/eval)."""
    model.eval()
    cols = list(target_indices)

    running = 0.0
    seen = 0
    for batch in loader:
        batch = batch.to(device)
        loss = torch.nn.functional.mse_loss(model(batch), batch.y[:, cols])
        running += float(loss.item()) * batch.num_graphs
        seen += batch.num_graphs

    # Guard against an empty loader (avoids division by zero).
    return running / max(1, seen)

measure_inference_latency

measure_inference_latency(model: torch.nn.Module, loader: DataLoader, device: torch.device, *, warmup_batches: int = 10, timed_batches: int = 50) -> dict[str, float]

Measures average latency per batch (ms) over a fixed number of batches.

Notes: - Disables gradients via @torch.no_grad() + model.eval() - Syncs CUDA for accurate timing

Source code in src/project_name/prune.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
@torch.no_grad()
def measure_inference_latency(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    *,
    warmup_batches: int = 10,
    timed_batches: int = 50,
) -> dict[str, float]:
    """
    Measures average latency per batch (ms) over a fixed number of batches.

    Runs ``warmup_batches`` untimed forward passes, then times ``timed_batches``
    passes, cycling through ``loader`` as often as needed.

    Notes:
    - Gradients are disabled via @torch.no_grad() + model.eval()
    - Syncs CUDA for accurate timing

    Returns:
        Dict with ``ms_per_batch`` (mean latency, ms) and ``batches``
        (number of timed batches, as a float).
    """
    model.eval()

    def _sync() -> None:
        # CUDA kernels launch asynchronously; synchronize so perf_counter
        # brackets the real work.
        if device.type == "cuda":
            torch.cuda.synchronize()

    it = iter(loader)

    def _next_batch():
        # Cycle through the loader: restart it once exhausted.
        nonlocal it
        try:
            return next(it)
        except StopIteration:
            it = iter(loader)
            return next(it)

    # Warmup (untimed) passes amortize lazy initialization and caching.
    for _ in range(warmup_batches):
        _ = model(_next_batch().to(device))
    _sync()

    # Timed passes.
    times: list[float] = []
    for _ in range(timed_batches):
        batch = _next_batch().to(device)

        _sync()
        t0 = time.perf_counter()
        _ = model(batch)
        _sync()
        times.append(time.perf_counter() - t0)

    if not times:
        # BUGFIX: return 0.0 (float) instead of 0 (int) so the value matches
        # the dict[str, float] annotation and the quantize.py counterpart.
        return {"ms_per_batch": 0.0, "batches": 0.0}

    avg_s = sum(times) / len(times)
    return {"ms_per_batch": avg_s * 1000.0, "batches": float(len(times))}

quantize

project_name.quantize

measure_inference_latency

measure_inference_latency(model: torch.nn.Module, loader: DataLoader, device: torch.device, *, warmup_batches: int = 10, timed_batches: int = 50) -> dict[str, float]

Average latency per batch in ms. Note: for quantized CPU models this is the typical use case.

Source code in src/project_name/quantize.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@torch.no_grad()
def measure_inference_latency(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    *,
    warmup_batches: int = 10,
    timed_batches: int = 50,
) -> dict[str, float]:
    """
    Average forward latency per batch, in milliseconds.
    Note: for quantized CPU models this is the typical use case.
    """
    model.eval()
    it = iter(loader)

    def _draw():
        # Cycle through the loader: restart it once exhausted.
        nonlocal it
        try:
            return next(it)
        except StopIteration:
            it = iter(loader)
            return next(it)

    # Untimed warmup passes.
    for _ in range(warmup_batches):
        _ = model(_draw().to(device))

    # Timed passes.
    samples: list[float] = []
    for _ in range(timed_batches):
        batch = _draw().to(device)
        start = time.perf_counter()
        _ = model(batch)
        samples.append(time.perf_counter() - start)

    if not samples:
        return {"ms_per_batch": 0.0, "batches": 0.0}

    mean_s = sum(samples) / len(samples)
    return {"ms_per_batch": mean_s * 1000.0, "batches": float(len(samples))}

quantize_full_model

quantize_full_model(model: torch.nn.Module, scheme: str) -> torch.nn.Module

Apply weight-only INT8 quantization to all linear layers in the model, including those nested inside GraphConv blocks. Uses torchao when available and falls back to the torch.ao dynamic quantization API otherwise.

scheme
  • "torchao_int8_weight_only" (default)
  • "torch_ao_dynamic" (fallback-style dynamic quantization)
Source code in src/project_name/quantize.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def quantize_full_model(model: torch.nn.Module, scheme: str) -> torch.nn.Module:
    """
    Apply weight-only INT8 quantization to *all* linear layers in the model,
    including those nested inside GraphConv blocks.

    scheme:
      - "torchao_int8_weight_only" (default): torchao's quantize_ API.
      - "torch_ao_dynamic": torch.ao dynamic quantization (also used as the
        fallback when torchao is not installed).
    """
    if scheme == "torch_ao_dynamic":
        from torch.ao.quantization import quantize_dynamic  # older weights-only API

        return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    # Default path: torchao, if it can be imported.
    try:
        from torchao.quantization import Int8WeightOnlyConfig, quantize_
    except Exception as e:
        logger.warning("torchao not available (%s). Falling back to torch.ao.quantization.quantize_dynamic.", e)
        from torch.ao.quantization import quantize_dynamic

        return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    # torchao's quantize_ mutates the model in place (its return value varies
    # across versions), so return our own reference explicitly.
    quantize_(model, Int8WeightOnlyConfig())
    return model

train

project_name.train

train

train(cfg: DictConfig) -> None

Train the GNN model on QM9 dataset.

Parameters:

Name Type Description Default
cfg DictConfig

Hydra configuration object containing all parameters.

required
Source code in src/project_name/train.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
@hydra.main(version_base=None, config_path=_CONFIG_PATH, config_name="config")
def train(cfg: DictConfig) -> None:
    """Train the GNN model on QM9 dataset.

    Loads and splits QM9, trains with early stopping on validation MSE,
    restores the best checkpoint for the final test evaluation, and
    optionally logs metrics and a model artifact to Weights & Biases.

    Args:
        cfg: Hydra configuration object containing all parameters.
    """
    device: torch.device = _get_device()
    logger.info("Using device: %s", device)

    # Resolve the model output directory (local path, or rooted under /gcs
    # when a bucket is configured and the cloud mount exists).
    model_dir: Path = get_data_path(
        cfg.training.model_dir,
        gcs_bucket=OmegaConf.select(cfg, "training.gcs_bucket"),
    )
    model_dir.mkdir(parents=True, exist_ok=True)
    print(cfg)  # echo the resolved config for reproducibility/debugging

    profile: bool = cfg.training.profile
    profiler_run_dir: str = cfg.training.profiler_run_dir
    run = _init_wandb(cfg)  # None when wandb logging is disabled
    if run is not None:
        logger.info("wandb logging enabled (run: %s)", run.id)
    else:
        logger.info("wandb logging disabled")
    with timing_checkpoint("Load dataset", enabled=profile):
        logger.info("Loading QM9 dataset...")
        data_path = get_data_path(
            cfg.training.data_path,
            gcs_bucket=OmegaConf.select(cfg, "training.gcs_bucket"),
        )
        dataset: Dataset = QM9Dataset(data_path)

    # Apply normalization transform to each sample on access.
    dataset.transform = NormalizeScale()

    # Split dataset; test gets the remainder so the sizes always sum to n.
    n: int = len(dataset)
    train_size: int = int(cfg.training.train_ratio * n)
    val_size: int = int(cfg.training.val_ratio * n)
    test_size: int = n - train_size - val_size

    # Seeded generator keeps the split reproducible across runs.
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(cfg.seed),
    )

    # Create data loaders (parallel loading)
    workers = _num_workers(cfg)
    logger.info("DataLoader num_workers=%d", workers)

    train_loader: DataLoader = DataLoader(
        train_dataset,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(workers > 0),
    )

    val_loader: DataLoader = DataLoader(
        val_dataset,
        batch_size=cfg.training.batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(workers > 0),
    )

    test_loader: DataLoader = DataLoader(
        test_dataset,
        batch_size=cfg.training.batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(workers > 0),
    )

    logger.info("Dataset split - Train: %d, Val: %d, Test: %d", len(train_dataset), len(val_dataset), len(test_dataset))

    # Get target indices; the number of targets determines the model's output dim.
    target_indices: list[int] = list(cfg.training.target_indices)
    num_targets: int = len(target_indices)
    logger.info("Predicting %d target(s): %s", num_targets, target_indices)

    # Initialize model
    model: GraphNeuralNetwork | nn.DataParallel = GraphNeuralNetwork(
        num_node_features=cfg.model.num_node_features,
        hidden_dim=cfg.model.hidden_dim,
        num_layers=cfg.model.num_layers,
        output_dim=num_targets,
    ).to(device)

    # Wrap in DataParallel when multiple GPUs are visible.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    optimizer: Optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.learning_rate)

    # Early stopping state.
    best_val_loss: float = float("inf")
    patience: int = cfg.training.patience
    patience_counter: int = 0

    logger.info(
        "Starting training for %d epochs (batch_size=%d, lr=%g, patience=%d)",
        cfg.training.epochs,
        cfg.training.batch_size,
        cfg.training.learning_rate,
        patience,
    )
    profiler = TrainingProfiler(enabled=profile, output_dir=Path(f"profiling_results/{profiler_run_dir}"))

    for epoch in range(1, cfg.training.epochs + 1):
        train_loss: float = train_epoch(model, train_loader, optimizer, device, target_indices)

        # compute validation metrics (mse/rmse/mae/r2)
        val_metrics = evaluate_with_metrics(model, val_loader, device, target_indices)
        val_loss: float = float(val_metrics["mse"])  # keep early-stopping tied to MSE

        if epoch % LOG_INTERVAL == 0 or epoch == 1:
            logger.info(
                "Epoch %3d | Train Loss: %.6f | Val MSE: %.6f | Val RMSE: %.6f | Val MAE: %.6f | Val R2: %.6f",
                epoch,
                train_loss,
                val_metrics["mse"],
                val_metrics["rmse"],
                val_metrics["mae"],
                val_metrics["r2"],
            )

        improved = val_loss < best_val_loss
        if improved:
            # Checkpoint the best model; unwrap DataParallel so the saved
            # state_dict keys are not prefixed with "module.".
            best_val_loss = val_loss
            patience_counter = 0
            best_model_path: Path = model_dir / "best_model.pt"
            torch.save(
                model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(),
                best_model_path,
            )
            logger.debug("Saved best model to %s", best_model_path)
        else:
            patience_counter += 1

        if run is not None:
            wandb.log(
                {
                    "epoch": epoch,
                    "loss/train": train_loss,
                    # keep the existing val loss key for dashboard compatibility
                    "loss/val": val_loss,
                    # full validation metrics
                    "val/mse": val_metrics["mse"],
                    "val/rmse": val_metrics["rmse"],
                    "val/mae": val_metrics["mae"],
                    "val/r2": val_metrics["r2"],
                    "early_stopping/patience_counter": patience_counter,
                    "early_stopping/best_val_loss": best_val_loss,
                }
            )

        if patience_counter >= patience:
            logger.info("Early stopping triggered at epoch %d (best_val_loss=%.6f)", epoch, best_val_loss)
            break

        # NOTE: placed after the early-stopping break, so the final (stopped)
        # epoch is not recorded as a profiler step.
        profiler.step()
    profiler.finalize()

    # Load best model and evaluate on test set
    best_model_path = model_dir / "best_model.pt"

    # weights_only is unavailable on older torch versions; fall back if needed.
    try:
        state = torch.load(best_model_path, weights_only=True)
    except TypeError:
        state = torch.load(best_model_path)

    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(state)
    else:
        model.load_state_dict(state)

    test_metrics: dict[str, float] = evaluate_with_metrics(model, test_loader, device, target_indices)
    test_loss: float = float(test_metrics["mse"])
    logger.info("Final test loss: %.6f", test_loss)

    # Save final (last-epoch) model alongside the best checkpoint.
    final_model_path: Path = model_dir / "final_model.pt"
    torch.save(
        model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(),
        final_model_path,
    )
    logger.info("Training complete. Models saved to %s", model_dir)

    # wandb: final logs
    if run is not None:
        # Log full test metrics
        wandb.log(
            {
                "loss/test": test_loss,  # keep compatibility (MSE)
                "test/mse": test_metrics["mse"],
                "test/rmse": test_metrics["rmse"],
                "test/mae": test_metrics["mae"],
                "test/r2": test_metrics["r2"],
            }
        )

        # Optionally log model artifact
        artifact = None
        if bool(OmegaConf.select(cfg, "wandb.log_artifacts", default=True)):
            artifact = wandb.Artifact(
                name="qm9-gnn",
                type="model",
                description="Trained model",
                metadata={
                    "target_indices": target_indices,
                    "best_val_loss": best_val_loss,
                    "test_mse": test_metrics["mse"],
                    "test_rmse": test_metrics["rmse"],
                    "test_mae": test_metrics["mae"],
                    "test_r2": test_metrics["r2"],
                },
            )
            artifact.add_file(str(best_model_path))
            run.log_artifact(artifact)

            # Link only if we actually created an artifact
            run.link_artifact(
                artifact=artifact,
                target_path="model-registry/mlops-molecules",
                aliases=["latest"],
            )

        wandb.finish()

train_epoch

train_epoch(model: GraphNeuralNetwork, loader: DataLoader, optimizer: Optimizer, device: torch.device, target_indices: list[int]) -> float

Train for one epoch.

Source code in src/project_name/train.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def train_epoch(
    model: GraphNeuralNetwork,
    loader: DataLoader,
    optimizer: Optimizer,
    device: torch.device,
    target_indices: list[int],
) -> float:
    """Run one optimization pass over *loader*; return the mean per-graph MSE."""
    model.train()

    loss_sum = 0.0
    graph_count = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad(set_to_none=True)

        prediction = model(batch)
        loss = F.mse_loss(prediction, batch.y[:, target_indices])
        loss.backward()
        optimizer.step()

        # Weight by graphs per batch so the epoch average is per-graph.
        loss_sum += loss.item() * batch.num_graphs
        graph_count += batch.num_graphs

    return loss_sum / graph_count

utils

project_name.utils

Utility functions for data loading and environment detection.

get_data_path

get_data_path(config_path: str | Path, gcs_bucket: str | None = None) -> Path

Get the appropriate data path based on the environment.

In cloud environments (GCP/Vertex AI), data is mounted to /gcs/&lt;bucket-name&gt;. In local environments, data is loaded from the configured path.

Parameters:

Name Type Description Default
config_path str | Path

The data path from config (e.g., 'data' or 'data/processed').

required
gcs_bucket str | None

Optional GCS bucket name. If provided and running in cloud, will use /gcs/&lt;bucket-name&gt; as the base path.

None

Returns:

Type Description
Path

Path object pointing to the correct data location.

Source code in src/project_name/utils.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def get_data_path(config_path: str | Path, gcs_bucket: str | None = None) -> Path:
    """Get the appropriate data path based on the environment.

    In cloud environments (GCP/Vertex AI), data is mounted to /gcs/<bucket-name>.
    In local environments, data is loaded from the configured path.

    Args:
        config_path: The data path from config (e.g., 'data' or 'data/processed').
        gcs_bucket: Optional GCS bucket name. If provided and running in cloud,
                    will use /gcs/<bucket-name> as the base path.

    Returns:
        Path object pointing to the correct data location.
    """
    gcs_mount = Path("/gcs")

    if gcs_mount.exists() and gcs_bucket:
        data_path = gcs_mount / gcs_bucket / config_path
    else:
        data_path = Path(config_path)

    return data_path

visualize

project_name.visualize