Clone the repository and install the required packages:
git clone https://github.com/ivy3h/SCL.git
cd SCL
pip install -r requirements.txt
Both the datasets and models are listed here. Before running the code, please manually download the MMT-Bench dataset. All other datasets and models will be automatically downloaded during code execution.
To execute the intrinsic self-correction process, run the following command:
python inference.py --model [model name] --prompt [self-correction prompt] --dataset [evaluation dataset] --num_test [number of tasks]
To construct preference data through the intrinsic self-correction process, run the following command:
python data_construction.py --model [model name] --prompt [self-correction prompt] --dataset [construction dataset]
Our DPO code is based on SWIFT. To set up the SWIFT environment, run the following commands:
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
pip install -e '.[eval]'
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
Set the dataset path to customize the dataset before initiating the optimization process. For more detailed information, please refer to the SWIFT documentation. You can also explore additional alignment training methods.
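As an illustration of what a custom preference dataset might look like, the sketch below writes (query, chosen, rejected) triples to a JSONL file. The key names `query`, `response`, and `rejected_response`, the helper names, and the file name `preference_data.jsonl` are assumptions made for illustration and are not confirmed by this repository. Check the SWIFT documentation for the schema expected by the version you install.

```python
import json
from pathlib import Path


def make_record(query, chosen, rejected):
    """Build one preference pair. The key names are an assumed schema."""
    return {
        "query": query,                 # prompt shown to the model
        "response": chosen,             # preferred answer
        "rejected_response": rejected,  # dispreferred answer
    }


def write_jsonl(records, path):
    """Write one JSON object per line and return the output path."""
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


if __name__ == "__main__":
    # Toy placeholder pair, not taken from any real dataset.
    toy = [make_record("What is 2 + 2?", "4", "5")]
    write_jsonl(toy, "preference_data.jsonl")  # hypothetical file name
```

The resulting file path is what you would then point the dataset setting at, assuming your SWIFT version accepts a local JSONL file.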
Note
We recommend using the WebUI for training to enhance convenience and avoid potential bugs.
To run DPO training, use the following command:
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
--rlhf_type dpo \
--model_type <model> \
--beta 0.1 \
--sft_beta 0.1 \
--sft_type lora \
--dataset <dataset> \
--num_train_epochs 3 \
--lora_target_modules DEFAULT \
--gradient_checkpointing true \
--batch_size 1 \
--learning_rate 5e-5 \
--gradient_accumulation_steps 16 \
--warmup_ratio 0.03 \
--save_total_limit 2
After training, check the output path of the trained model to locate the checkpoint directory. To run the evaluation, use the following command:
CUDA_VISIBLE_DEVICES=0,1 \
swift eval \
--model_type <trained model> \
--eval_dataset <dataset name> \
--eval_limit <evaluation limit> \
--ckpt_dir <checkpoint path> \
--log_file <output file path> \
--ignore_args_error true
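
To help locate the checkpoint directory passed to `--ckpt_dir`, the following is a minimal sketch that searches a training output folder for subdirectories named `checkpoint-<step>` and returns the one with the largest step. The `checkpoint-<step>` naming and the `output` root are assumptions about how training results are laid out on disk, and `latest_checkpoint` is a hypothetical helper, not part of SWIFT or this repository.

```python
import re
from pathlib import Path


def latest_checkpoint(output_root):
    """Return the checkpoint-<step> directory with the largest step, or None."""
    pattern = re.compile(r"checkpoint-(\d+)$")
    candidates = []
    # Search recursively, since checkpoints may sit inside per-run subfolders.
    for path in Path(output_root).rglob("checkpoint-*"):
        match = pattern.match(path.name)
        if path.is_dir() and match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[0])[1]


if __name__ == "__main__":
    # "output" is an assumed root directory; adjust it to your setup.
    print(latest_checkpoint("output"))
```

If the root contains several training runs, the helper compares step numbers across all of them, so pass the folder of the specific run you want to evaluate.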