Commit

Apply review comments

thvasilo committed Jan 9, 2025
1 parent 6045c20 commit d980349

Showing 4 changed files with 201 additions and 47 deletions.
2 changes: 0 additions & 2 deletions python/graphstorm/gconstruct/file_io.py
@@ -88,8 +88,6 @@ def expand_wildcard(data_files: List[str]) -> List[str]:
"""
expanded_files = []
if len(data_files) == 1 and os.path.isdir(data_files[0]):
data_files = [os.path.join(data_files[0], "*")]
for item in data_files:
if '*' in item:
matched_files = sorted(glob.glob(item))
61 changes: 58 additions & 3 deletions sagemaker/pipeline/README.md
@@ -32,7 +32,7 @@ This project simplifies the process of running GraphStorm workflows on Amazon Sa
for detailed permissions needed to create and run SageMaker Pipelines.
- Familiarity with SageMaker AI and
[SageMaker Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines.html).
- Basic understanding of graph neural networks and [GraphStorm](https://graphstorm.readthedocs.io/en/latest/index.html).

## Project Structure

@@ -124,7 +124,7 @@ For a full list of execution options:
python execute_sm_pipeline.py --help
```

For more fine-grained execution options, like selective execution, please refer to
[SageMaker AI documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-selective-ex.html).
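
For reference, here is a minimal, hedged sketch of selective execution using the SageMaker Python SDK. It assumes the pipeline has already been created, and the pipeline name, step name, and source execution ARN below are placeholders rather than values from this project:

```
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig

# Reference an existing pipeline by name (placeholder name).
pipeline = Pipeline(name="my-graphstorm-pipeline")

# Re-run only the selected step(s), reusing outputs from a previous execution.
selective_config = SelectiveExecutionConfig(
    source_pipeline_execution_arn=(
        "arn:aws:sagemaker:us-east-1:123456789012:pipeline/"
        "my-graphstorm-pipeline/execution/abc123"
    ),
    selected_steps=["Training"],  # placeholder step name
)

execution = pipeline.start(selective_execution_config=selective_config)
execution.wait()
```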

## Pipeline Components
@@ -149,7 +149,62 @@ The pipeline's behavior is controlled by various configuration parameters, inclu
- Task configuration (graph name, input/output locations)
- Training and inference configurations

Refer to the `PipelineArgs` class in `pipeline_parameters.py` for a complete list of configurable options. An example invocation that combines several of these flags is shown after the lists below.

### AWS Configuration
- `--execution-role`: SageMaker execution IAM role ARN. (Required)
- `--region`: AWS region. (Required)
- `--graphstorm-pytorch-cpu-image-url`: GraphStorm GConstruct/dist_part/train/inference CPU ECR image URL. (Required)
- `--graphstorm-pytorch-gpu-image-url`: GraphStorm GConstruct/dist_part/train/inference GPU ECR image URL.
- `--gsprocessing-pyspark-image-url`: GSProcessing SageMaker PySpark ECR image URL.

### Instance Configuration
- `--instance-count` / `--num-parts`: Number of worker instances/partitions for partition, training, inference. (Required)
- `--cpu-instance-type`: CPU instance type. (Default: ml.m5.4xlarge)
- `--gpu-instance-type`: GPU instance type. (Default: ml.g5.4xlarge)
- `--train-on-cpu`: Run training and inference on CPU instances instead of GPU. (Flag)
- `--graph-construction-instance-type`: Instance type for graph construction.
- `--gsprocessing-instance-count`: Number of GSProcessing instances.
- `--volume-size-gb`: Additional volume size for SageMaker instances in GB. (Default: 100)

### Task Configuration
- `--graph-name`: Name of the graph. (Required)
- `--input-data-s3`: S3 path to the input graph data. (Required)
- `--output-prefix-s3`: S3 prefix for the output data. (Required)
- `--pipeline-name`: Name for the pipeline.
- `--base-job-name`: Base job name for SageMaker jobs. (Default: 'gs')
- `--jobs-to-run`: Space-separated string of jobs to run in the pipeline. (Required)
- `--log-level`: Logging level for the jobs. (Default: INFO)
- `--step-cache-expiration`: Expiration time for the step cache. (Default: 30d)
- `--update-pipeline`: Update an existing pipeline instead of creating a new one. (Flag)

### Graph Construction Configuration
- `--graph-construction-config-filename`: Filename for the graph construction config.
- `--graph-construction-args`: Parameters to be passed directly to the GConstruct job.

### Partition Configuration
- `--partition-algorithm`: Partitioning algorithm. (Default: random)
- `--partition-output-json`: Name for the output JSON file that describes the partitioned data. (Default: metadata.json)
- `--partition-input-json`: Name for the JSON file that describes the input data for partitioning. (Default: updated_row_counts_metadata.json)

### Training Configuration
- `--model-output-path`: S3 path for model output.
- `--num-trainers`: Number of trainers to use during training/inference. (Default: 4)
- `--train-inference-task`: Task type for training and inference. (Required)
- `--train-yaml-s3`: S3 path to train YAML configuration file.
- `--use-graphbolt`: Whether to use GraphBolt. (Default: false)

### Inference Configuration
- `--inference-yaml-s3`: S3 path to inference YAML configuration file.
- `--inference-model-snapshot`: Which model snapshot to run inference with.
- `--save-predictions`: Whether to save predictions to S3 during inference. (Flag)
- `--save-embeddings`: Whether to save embeddings to S3 during inference. (Flag)

### Script Paths
- `--dist-part-script`: Path to DistPartition SageMaker entry point script.
- `--gb-convert-script`: Path to GraphBolt partition script.
- `--train-script`: Path to training SageMaker entry point script.
- `--inference-script`: Path to inference SageMaker entry point script.
- `--gconstruct-script`: Path to GConstruct SageMaker entry point script.
- `--gsprocessing-script`: Path to GSProcessing SageMaker entry point script.
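
As a rough illustration of how these flags combine, the command below sketches a pipeline-creation call. The script name (`create_sm_pipeline.py`), the job names passed to `--jobs-to-run`, the task type, and every ARN, image URI, and S3 path are placeholder assumptions for this example, not values documented in this commit:

```
python create_sm_pipeline.py \
    --execution-role arn:aws:iam::123456789012:role/SageMakerExecutionRole \
    --region us-east-1 \
    --graphstorm-pytorch-cpu-image-url 123456789012.dkr.ecr.us-east-1.amazonaws.com/graphstorm:sagemaker-cpu \
    --instance-count 4 \
    --graph-name my-graph \
    --input-data-s3 s3://my-bucket/raw-graph-data/ \
    --output-prefix-s3 s3://my-bucket/pipeline-output/ \
    --jobs-to-run gconstruct train inference \
    --train-inference-task node_classification \
    --train-yaml-s3 s3://my-bucket/config/train.yaml \
    --inference-yaml-s3 s3://my-bucket/config/inference.yaml \
    --pipeline-name my-graphstorm-pipeline
```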

## Advanced Usage

14 changes: 9 additions & 5 deletions sagemaker/pipeline/execute_sm_pipeline.py
Expand Up @@ -64,8 +64,9 @@ def parse_args():
"--pipeline-args-json-file",
type=str,
help=(
"When executing locally, optionally provide a JSON representation of the pipeline arguments. "
"By default we look for '<pipeline-name>-pipeline-args.json' in the working dir."
"When executing locally, optionally provide a JSON representation of the pipeline "
"arguments. By default we look for '<pipeline-name>-pipeline-args.json' "
"in the working dir."
),
)

@@ -265,14 +266,17 @@ def main():

logging.info("Pipeline execution started: %s", execution.describe())
logging.info("Execution ARN: %s", execution.arn)
logging.info(f"Output will be created under: {execution_params['ExecutionSubpath']}")
logging.info(
"Output will be created under: %s", execution_params["ExecutionSubpath"]
)

if not args.async_execution:
logging.info("Waiting for pipeline execution to complete...")
execution.wait()
logging.info("Pipeline execution completed.")
logging.info("Final status: %s",
execution.describe()['PipelineExecutionStatus'])
logging.info(
"Final status: %s", execution.describe()["PipelineExecutionStatus"]
)


if __name__ == "__main__":
171 changes: 134 additions & 37 deletions sagemaker/pipeline/pipeline_parameters.py
@@ -39,7 +39,21 @@

@dataclass
class AWSConfig:
"""AWS-related configuration"""
"""AWS-related configuration.
Parameters
----------
execution_role : str
SageMaker execution IAM role ARN.
region : str
AWS region.
graphstorm_pytorch_cpu_image_url : str
GraphStorm GConstruct/dist_part/train/inference CPU ECR image URL.
graphstorm_pytorch_gpu_image_url : str
GraphStorm GConstruct/dist_part/train/inference GPU ECR image URL.
gsprocessing_pyspark_image_url : str
GSProcessing SageMaker PySpark ECR image URL.
"""

execution_role: str
region: str
@@ -50,7 +64,25 @@ class AWSConfig:

@dataclass
class InstanceConfig:
"""Configuration for SageMaker instances"""
"""Configuration for SageMaker instances.
Parameters
----------
train_infer_instance_count : int
Number of worker instances/partitions for partition, training, inference.
cpu_instance_type : str
CPU instance type.
gpu_instance_type : str
GPU instance type.
graph_construction_instance_type : str
Instance type for graph construction.
gsprocessing_instance_count : int
Number of GSProcessing instances.
train_on_cpu : bool
Whether to run training and inference on CPU instances.
volume_size_gb : int
Additional volume size for SageMaker instances in GB.
"""

train_infer_instance_count: int
cpu_instance_type: str
@@ -87,7 +119,25 @@ def __post_init__(self):

@dataclass
class TaskConfig:
"""Pipeline/task-level configuration"""
"""Pipeline/task-level configuration.
Parameters
----------
base_job_name : str
Base job name for SageMaker jobs.
graph_name : str
Name of the graph.
input_data_s3 : str
S3 path to the input graph data.
jobs_to_run : List[str]
List of jobs to run in the pipeline.
log_level : str
Logging level for the jobs.
output_prefix : str
S3 prefix for the output data.
pipeline_name : str
Name for the pipeline.
"""

base_job_name: str
graph_name: str
@@ -129,15 +179,33 @@ def __post_init__(self):

@dataclass
class GraphConstructionConfig:
"""Configuration for the graph construction step"""
"""Configuration for the graph construction step.
Parameters
----------
config_filename : str
Filename for the graph construction config.
graph_construction_args : str
Parameters to be passed directly to the GConstruct job.
"""

config_filename: str
graph_construction_args: str


@dataclass
class PartitionConfig:
"""Configuration for the partition step"""
"""Configuration for the partition step.
Parameters
----------
partition_algorithm : str
Partitioning algorithm.
input_json_filename : str
Name for the JSON file that describes the input data for partitioning.
output_json_filename : str
Name for the output JSON file that describes the partitioned data.
"""

partition_algorithm: str
input_json_filename: str
@@ -146,7 +214,21 @@ class PartitionConfig:

@dataclass
class TrainingConfig:
"""Configuration for the training step"""
"""Configuration for the training step.
Parameters
----------
model_output_path : str
S3 path for model output.
train_inference_task : str
Task type for training and inference.
train_yaml_file : str
S3 path to train YAML configuration file.
num_trainers : int
Number of trainers to use during training/inference.
use_graphbolt_str : str
Whether to use GraphBolt ('true' or 'false').
"""

model_output_path: str
train_inference_task: str
Expand All @@ -160,7 +242,19 @@ def __post_init__(self):

@dataclass
class InferenceConfig:
"""Configuration for the inference step"""
"""Configuration for the inference step.
Parameters
----------
save_embeddings : bool
Whether to save embeddings to S3 during inference.
save_predictions : bool
Whether to save predictions to S3 during inference.
inference_model_snapshot : str
Which model snapshot to run inference with.
inference_yaml_file : str
S3 path to inference YAML configuration file.
"""

save_embeddings: bool
save_predictions: bool
@@ -170,7 +264,23 @@ class ScriptPaths:

@dataclass
class ScriptPaths:
"""Entry point script locations"""
"""Entry point script locations.
Parameters
----------
dist_part_script : str
Path to DistPartition SageMaker entry point script.
gb_convert_script : str
Path to GraphBolt partition script.
train_script : str
Path to training SageMaker entry point script.
inference_script : str
Path to inference SageMaker entry point script.
gconstruct_script : str
Path to GConstruct SageMaker entry point script.
gsprocessing_script : str
Path to GSProcessing SageMaker entry point script.
"""

dist_part_script: str
gb_convert_script: str
@@ -187,25 +297,25 @@ class PipelineArgs:
Parameters
----------
aws_config : AWSConfig
AWS configuration settings.
graph_construction_config : GraphConstructionConfig
Graph construction configuration.
instance_config : InstanceConfig
Instance configuration settings.
task_config : TaskConfig
Task-level configuration settings.
partition_config : PartitionConfig
Partition configuration settings.
training_config : TrainingConfig
Training configuration settings.
inference_config : InferenceConfig
Inference configuration settings.
script_paths : ScriptPaths
Paths to SageMaker entry point scripts.
step_cache_expiration : str
Cache expiration for pipeline steps.
update : bool
Whether to update existing pipeline or create a new one.
"""

aws_config: AWSConfig
@@ -280,23 +390,11 @@ def __post_init__(self):
"when running graph construction."
)

# TODO: When using GConstruct+DistPart/GBConvert (possible?) provide
# the correct partition input JSON filename?
# if "gconstruct" in self.task_config.jobs_to_run:
# if "dist_part" in self.task_config.jobs_to_run or "gb_convert" in self.task_config.jobs_to_run:
# if self.partition_config.input_json_filename != f"{self.task_config.graph_name}.json":
# logging.warning(
# "When running GConstruct, the partition input JSON "
# "filename should be '<graph_name>.json'. "
# "Got %s, setting to %s instead",
# self.partition_config.input_json_filename,
# f"{self.task_config.graph_name}.json",
# )
# self.partition_config.input_json_filename = f"{self.task_config.graph_name}.json"


# When using DistPart and GBConvert ensure the metadata.json filename is used
if "dist_part" in self.task_config.jobs_to_run and "gb_convert" in self.task_config.jobs_to_run:
if (
"dist_part" in self.task_config.jobs_to_run
and "gb_convert" in self.task_config.jobs_to_run
):
if self.partition_config.output_json_filename != "metadata.json":
logging.warning(
"When running DistPart or GBConvert, the partition output JSON "
Expand All @@ -306,7 +404,6 @@ def __post_init__(self):
)
self.partition_config.output_json_filename = "metadata.json"


# Ensure we have a GSProcessing image to run GSProcessing
if "gsprocessing" in self.task_config.jobs_to_run:
assert (
@@ -349,8 +446,8 @@

# GConstruct uses 'metis', so just translate that if needed
if (
"gconstruct" in self.task_config.jobs_to_run and
self.partition_config.partition_algorithm.lower() == "parmetis"
"gconstruct" in self.task_config.jobs_to_run
and self.partition_config.partition_algorithm.lower() == "parmetis"
):
self.partition_config.partition_algorithm = "metis"

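
To make the new docstrings above concrete, here is a small, hedged usage sketch showing how two of the documented dataclasses might be instantiated. The import path and all field values are placeholders for illustration; only fields listed in this diff are used:

```
# Hypothetical usage sketch; values are placeholders, not project defaults.
from pipeline_parameters import AWSConfig, PartitionConfig

aws_config = AWSConfig(
    execution_role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
    region="us-east-1",
    graphstorm_pytorch_cpu_image_url="123456789012.dkr.ecr.us-east-1.amazonaws.com/graphstorm:cpu",
    graphstorm_pytorch_gpu_image_url="123456789012.dkr.ecr.us-east-1.amazonaws.com/graphstorm:gpu",
    gsprocessing_pyspark_image_url="123456789012.dkr.ecr.us-east-1.amazonaws.com/gsprocessing:latest",
)

partition_config = PartitionConfig(
    partition_algorithm="random",
    input_json_filename="updated_row_counts_metadata.json",
    output_json_filename="metadata.json",
)
```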
