自定义 SageMaker 算法如何确定是否启用了检查点?

0

【以下的问题经过翻译处理】 根据SageMaker环境变量文档,算法应该将模型产物保存到由SM_MODEL_DIR 变量指定的文件夹中。

SageMaker容器文档描述了额外的环境变量,包括SM_OUTPUT_DATA_DIR用于写入非模型训练产物。

...但是算法如何确定是否已请求存储检查点?

在Amazon SageMaker中使用检查点文档只指定了一个默认本地保存路径,我找不到任何环境变量可以指示是否要保存检查点。我看到有一段代码检查默认本地路径是否存在,但我并不确定它是否有效(请求检查点时存在,不请求时不存在)。

将检查点参数化是件好事,可以避免在不需要的时候浪费EBS空间(和宝贵的IOPS);根据其他I/O(如模型和数据文件夹)的惯例,我认为SageMaker有特定的机制来传递这个指令,而不仅仅是定义一个算法超参数?

profile picture
专家
已提问 1 年前54 查看次数
1 回答
0

【以下的回答经过翻译处理】 您好,

对于自定义 Sagemaker 容器或深度学习框架,我倾向于这样做。以下是我尝试过的pytorch的例子

  • Entry point 文件:
# 1. Define a custom argument, say checkpointdir
 parser.add_argument("--checkpointdir", help="The checkpoint dir", type=str,
                     default=None)
# 2. You can additional params for checkpoint frequency etc

# 3. Code for checkpointing
if checkpointdir is not None:
   #TODO: save mode
  • Jupyter 笔记本示例 Sagemaker estimator
# 1. Define local and remote variables for checkpoints
checkpoint_s3 = "s3://{}/{}t/".format(bucket, "checkpoints")
localcheckpoint_dir="/opt/ml/checkpoints/"

hyperparameters = {

    "batchsize": "8",
    "epochs" : "1000",
    "learning_rate":.0001,
    "weight_decay":5e-5,
    "momentum":.9,
    "patience": 20,
    "log-level" : "INFO",
    "commit_id":commit_id,
    "model" :"FasterRcnnFactory",
    "accumulation_steps": 8,
# 2.  define hp for checkpoint dir
    "checkpointdir": localcheckpoint_dir
}

# In the Sagemaker estimator fit, specify the local and remote path
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
     entry_point='experiment_train.py',
                    source_dir = 'src',
                    dependencies =['src/datasets', 'src/evaluators', 'src/models'],
                    role=role,
                    framework_version ="1.0.0",
                    py_version='py3',
                    git_config= git_config,
                    image_name= docker_repo,
                    train_instance_count=1,
                    train_instance_type=instance_type,
# 3. The entrypoint file will pick up the checkpoint location from here
                    hyperparameters =hyperparameters,
                    output_path=s3_output_path,
                    metric_definitions=metric_definitions,
                    train_use_spot_instances = use_spot,
                    train_max_run =  train_max_run_secs,
                    train_max_wait = max_wait_time_secs,   
                    base_job_name ="object-detection",
# 4. Sagemaker knows that the checkpoints will need to be periodically copied from the localcheckpoint_dir to s3 pointed to by checkpoint_s3
                    checkpoint_s3_uri=checkpoint_s3,
                    checkpoint_local_path=localcheckpoint_dir)
profile picture
专家
已回答 1 年前

您未登录。 登录 发布回答。

一个好的回答可以清楚地解答问题和提供建设性反馈,并能促进提问者的职业发展。

回答问题的准则