-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathconfig.py
73 lines (59 loc) · 2.56 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict
from ..config import ScenarioConfig
# Logger used by the training scenario configuration.
LOGGER = getLogger("training")

# Default training arguments. Entries supplied through
# `TrainingConfig.training_arguments` take precedence over these.
TRAINING_ARGUMENT: Dict[str, Any] = {
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 1,
    "output_dir": "./trainer_output",
    # NOTE(review): both spellings of the evaluation-strategy key are set,
    # presumably to cover different transformers versions -- confirm.
    "evaluation_strategy": "no",
    "eval_strategy": "no",
    "save_strategy": "no",
    "do_train": True,
    "use_cpu": False,
    # -1 acts as "not set": TrainingConfig.__post_init__ replaces it with
    # the scenario-level `max_steps`.
    "max_steps": -1,
    # no evaluation / prediction passes
    "do_eval": False,
    "do_predict": False,
    # no custom logging
    "report_to": "none",
    # disable transformers memory metrics
    "skip_memory_metrics": True,
    # from pytorch warning: "this flag results in an extra traversal of the
    # autograd graph every iteration which can adversely affect performance."
    "ddp_find_unused_parameters": False,
}

# Default dataset shapes. Entries supplied through
# `TrainingConfig.dataset_shapes` take precedence over these.
DATASET_SHAPES: Dict[str, Any] = {
    "dataset_size": 500,
    "sequence_length": 16,
    "num_choices": 1,
}
@dataclass
class TrainingConfig(ScenarioConfig):
    """Configuration of the training scenario.

    After initialization, ``dataset_shapes`` and ``training_arguments`` are
    merged over the module-level defaults (``DATASET_SHAPES`` and
    ``TRAINING_ARGUMENT``), and ``max_steps`` is reconciled with
    ``training_arguments["max_steps"]`` so that both hold the same value.

    Raises:
        ValueError: if ``warmup_steps`` is greater than the resolved
            ``max_steps``.
    """

    name: str = "training"
    _target_: str = "optimum_benchmark.scenarios.training.scenario.TrainingScenario"

    # training options
    max_steps: int = 140
    warmup_steps: int = 40

    # dataset options (merged over DATASET_SHAPES in __post_init__)
    dataset_shapes: Dict[str, Any] = field(default_factory=dict)
    # training arguments (merged over TRAINING_ARGUMENT in __post_init__)
    training_arguments: Dict[str, Any] = field(default_factory=dict)

    # tracking options
    latency: bool = field(default=True, metadata={"help": "Measure latencies and throughputs"})
    memory: bool = field(default=False, metadata={"help": "Measure max memory usage"})
    energy: bool = field(default=False, metadata={"help": "Measure energy usage"})

    def __post_init__(self):
        """Apply defaults, reconcile the two ``max_steps`` values and validate ``warmup_steps``."""
        super().__post_init__()

        # User-provided entries win over the defaults. Fresh dicts are built,
        # so the shared module-level defaults are never mutated.
        self.dataset_shapes = {**DATASET_SHAPES, **self.dataset_shapes}
        self.training_arguments = {**TRAINING_ARGUMENT, **self.training_arguments}

        # -1 (the default in TRAINING_ARGUMENT) is treated as "not set":
        # inherit the scenario-level value.
        if self.training_arguments["max_steps"] == -1:
            self.training_arguments["max_steps"] = self.max_steps

        # An explicit training_arguments["max_steps"] takes precedence over
        # the scenario-level field; warn so the override is visible.
        if self.max_steps != self.training_arguments["max_steps"]:
            LOGGER.warning(
                f"`scenario.max_steps` ({self.max_steps}) and `scenario.training_arguments.max_steps` "
                f"({self.training_arguments['max_steps']}) are different. "
                "Using `scenario.training_arguments.max_steps`."
            )
            self.max_steps = self.training_arguments["max_steps"]

        # The guard accepts warmup_steps == max_steps, so the message states
        # exactly that (it previously read "must be smaller than").
        # NOTE(review): with equality no steps remain after warmup; confirm
        # whether a strict inequality was intended.
        if self.warmup_steps > self.max_steps:
            raise ValueError(
                f"`scenario.warmup_steps` ({self.warmup_steps}) must not be greater than "
                f"`scenario.max_steps` ({self.max_steps})"
            )