Thank you for sharing the code. I tried to run deepspeed_graph.py but ran into some errors. I made a small modification to the code; it is as follows:
import lightning.pytorch as pl
import torch
from lightning.pytorch import Trainer
from graph_weather import GraphWeatherForecaster
from lightning.pytorch.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam

# One node per integer lat/lon pair: 180 * 360 = 64,800 nodes
lat_lons = []
for lat in range(-90, 90, 1):
    for lon in range(0, 360, 1):
        lat_lons.append((lat, lon))


class LitModel(pl.LightningModule):
    def __init__(self, lat_lons, feature_dim, aux_dim):
        super().__init__()
        self.model = GraphWeatherForecaster(
            lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim
        )

    def training_step(self, batch):
        x, y = batch
        x = x.half()
        y = y.half()
        out = self.model(x)
        criterion = torch.nn.MSELoss()
        loss = criterion(out, y)
        return loss

    def configure_optimizers(self):
        optimizer = DeepSpeedCPUAdam(self.parameters())
        # optimizer = torch.optim.AdamW(self.parameters())
        return optimizer

    # def forward(self, x):
    #     return self.model(x)


# Fake data
from torch.utils.data import DataLoader, Dataset


class FakeDataset(Dataset):
    def __init__(self):
        super(FakeDataset, self).__init__()

    def __len__(self):
        return 64000

    def __getitem__(self, item):
        # One sample per grid node: (64800 nodes, 78 features + 0 aux features)
        return torch.randn((64800, 78 + 0)), torch.randn((64800, 78))


model = LitModel(lat_lons=lat_lons, feature_dim=78, aux_dim=0)
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    strategy="deepspeed_stage_2_offload",
    precision=16,
    max_epochs=10,
    limit_train_batches=2000,
)
dataset = FakeDataset()
train_dataloader = DataLoader(
    dataset, batch_size=1, num_workers=1, pin_memory=True, prefetch_factor=1
)
trainer.fit(model=model, train_dataloaders=train_dataloader)
The error is:
| Name | Type | Params
-------------------------------------------------
0 | model | GraphWeatherForecaster | 7.6 M
-------------------------------------------------
7.6 M Trainable params
0 Non-trainable params
7.6 M Total params
30.342 Total estimated model params size (MB)
/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 40 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
Epoch 0: 0%| | 0/2000 [00:00<?, ?it/s]Traceback (most recent call last):
File "/media/lk/lksgcc/lk_git/21_RenewablePower/WeatherForecast/graph_weather/train/deepspeed_graph.py", line 73, in <module>
trainer.fit(model=model, train_dataloaders=train_dataloader)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
call._call_and_handle_interrupt(
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 92, in launch
return function(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 935, in _run
results = self._run_stage()
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 978, in _run_stage
self.fit_loop.run()
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 201, in run
self.advance()
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 354, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
self.advance(data_fetcher)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 218, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
self._optimizer_step(kwargs.get("batch_idx", 0), closure)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 261, in _optimizer_step
call._call_lightning_module_hook(
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/core/module.py", line 1266, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/core/optimizer.py", line 158, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/strategies/ddp.py", line 257, in optimizer_step
optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 224, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 92, in optimizer_step
closure_result = closure()
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 126, in closure
step_output = self._step_fn()
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 308, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 288, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/strategies/ddp.py", line 329, in training_step
return self.model(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
loss = self.module(*inputs, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/lightning/pytorch/overrides/base.py", line 90, in forward
output = self._forward_module.training_step(*inputs, **kwargs)
File "/media/lk/lksgcc/lk_git/21_RenewablePower/WeatherForecast/graph_weather/train/deepspeed_graph.py", line 28, in training_step
out = self.model(x)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/graph_weather/models/forecast.py", line 109, in forward
x, edge_idx, edge_attr = self.encoder(features)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/graph_weather/models/layers/encoder.py", line 172, in forward
edge_attr = self.edge_encoder(self.graph.edge_attr) # Update attributes based on distance
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/graph_weather/models/layers/graph_net_block.py", line 75, in forward
out = self.model(x)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
input = module(input)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lk/anaconda3/envs/rlai/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 115, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype
Stepping through the run, I found that the input to the failing linear layer is torch.float32, while its weight and bias are torch.float16. How can I fix this problem? Thanks.
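
For reference, the traceback points at self.edge_encoder(self.graph.edge_attr) in encoder.py: under DeepSpeed's fp16 mode the Linear weights are converted to half, but graph.edge_attr is a precomputed tensor cached on the encoder rather than a parameter or registered buffer, so it stays float32. Below is a minimal, untested sketch of one possible workaround: cast that cached tensor to half before the first forward pass. The attribute path (model.encoder.graph.edge_attr) is inferred from the traceback and may differ between graph_weather versions, and other layers that cache similar graph tensors may need the same treatment.

# Untested sketch: make the cached graph tensor match the fp16 weights
# created by DeepSpeed under precision=16. Attribute names are assumptions
# based on the traceback above.
class LitModelHalfGraph(LitModel):
    def on_fit_start(self):
        encoder = self.model.encoder
        # edge_attr feeds self.edge_encoder (whose weights are fp16),
        # so it must be half as well to satisfy F.linear.
        encoder.graph.edge_attr = encoder.graph.edge_attr.half()


model = LitModelHalfGraph(lat_lons=lat_lons, feature_dim=78, aux_dim=0)
trainer.fit(model=model, train_dataloaders=train_dataloader)

This keeps the rest of the training script unchanged; the manual x.half()/y.half() casts in training_step stay as they are.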