This is an official repository for our paper, Teaching Models to Understand (but not Generate) High-risk Data. The repository is organized by the figures and tables in the paper. Please refer to each accordingly.
- General Preparation
- Replicating Figure 2
- Replicating Table 2
- Replicating Figure 3
- Replicating Table 4
We recommend using a conda environment for this repository. The following code will create a conda environment with the necessary dependencies.
conda create -n decouple python=3.9 && conda activate decouple
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
conda env update --name decouple --file environment.yml
cd OLMo && pip install -e .[all] && cd ..
Toxic data comes from Pushshift Reddit snapshots: Reddit Comments (RC) from March through December 2023 and Reddit Submissions (RS) from March through May 2023. These snapshots are not publicly available, but can be torrented.
Pushshift snapshots should be saved as .zst files in a directory called data/documents. The following script extracts, tags, and filters documents from the December 2023 RC snapshot as an example (data/documents/RC_2023-03.zst).
bash preprocess_reddit.sh
The script will output filtered toxic documents into data/toxic_reddit and non-toxic documents into data/non_toxic_reddit.
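Before running the preprocessing script, it can help to confirm that every snapshot is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_snapshots` is hypothetical, and it assumes all files follow the `RC_YYYY-MM.zst` / `RS_YYYY-MM.zst` naming seen in the example path above.

```shell
#!/usr/bin/env bash
# Print the names of expected Pushshift snapshots that are not yet present.
# Assumption: files are named RC_YYYY-MM.zst and RS_YYYY-MM.zst, following
# the RC_2023-03.zst example; adjust if your downloads are named differently.
missing_snapshots() {
  local dir="$1" month name
  # Reddit Comments: March through December 2023
  for month in 03 04 05 06 07 08 09 10 11 12; do
    name="RC_2023-${month}.zst"
    [ -f "${dir}/${name}" ] || echo "${name}"
  done
  # Reddit Submissions: March through May 2023
  for month in 03 04 05; do
    name="RS_2023-${month}.zst"
    [ -f "${dir}/${name}" ] || echo "${name}"
  done
}

# Prints one line per missing file; prints nothing when all 13 are present.
missing_snapshots data/documents
```

An empty output means all thirteen snapshots were found in data/documents.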
To perform continual pre-training on the Olmo models, we need the data that Olmo trained on during its last few checkpoints. The following code will download the data that Olmo 1B was exposed to from ckpt 737000 to ckpt 738020.
bash download_olmo_data.sh
We download the Olmo ckpt 737000 model using the following bash script.
bash download_olmo_ckpt.sh
For later evaluation, you should also download Olmo ckpt 738020 by changing the "checkpoint_num" variable in the script.
To convert an Olmo checkpoint into hf format, use the following script.
bash convert_to_hf.sh
We first need to continually pre-train the Olmo model on the toxic data.
The following code will merge toxic Reddit data into Dolma. Change the partition variable to create different data variants, which makes it possible to compute confidence intervals. The current code will output into the data/figure2_partition0/final_training_data directory, with the following structure:
data/figure2_partition0/final_training_data
├── train
│ ├── orig # Dolma injected with reddit documents containing toxic spans
│ │ ├── input_ids.npy
│ │ ├── label_mask.npy # Mask indicating which tokens are toxic. (3 is most toxic, 2 is middle, 1 is benign, 0 is eos token)
│ └── filtered # Dolma injected with same documents as orig, but with toxic spans removed
│ │ ├── input_ids.npy
│ │ ├── label_mask.npy # Mask indicating which tokens are toxic. (3 is most toxic, 2 is middle, 1 is benign, 0 is eos token). Because this is filtered, all values are benign (1) or eos (0).
├── test
│ ├── unseen_data.jsonl # Unseen dolma data for later evaluation
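Once the preparation script below has finished, the layout above can be checked before launching training. This is a sketch only: the helper name `check_training_layout` is hypothetical and not part of the repository, and it checks only that the files listed in the tree exist, not that their contents are valid.

```shell
#!/usr/bin/env bash
# Report which of the files from the directory tree above are missing.
# Returns 0 when every file is present and 1 otherwise.
check_training_layout() {
  local root="$1" status=0 rel
  for rel in \
    train/orig/input_ids.npy \
    train/orig/label_mask.npy \
    train/filtered/input_ids.npy \
    train/filtered/label_mask.npy \
    test/unseen_data.jsonl; do
    if [ ! -f "${root}/${rel}" ]; then
      echo "missing: ${root}/${rel}"
      status=1
    fi
  done
  return "${status}"
}

# Change the partition number in the path to check other data variants.
check_training_layout data/figure2_partition0/final_training_data \
  || echo "layout incomplete"
```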
bash figure2/prepare_figure2_trainingdata.sh
We then train the following models on the training data. Please make sure to specify the correct "partition" and "mode".
bash figure2/train_figure2_olmo_continual.sh
To replicate Figure 2(b), we then fine-tune each partition-mode model variant on the Tulu dataset. First, convert the checkpoints to hf format. Then, follow the instructions in the file open-instruct/README.md to set up the Open-Instruct environment. Finally, execute the training.
bash convert_to_hf.sh # convert the Olmo checkpoint to hf format
# Set up the Open-Instruct environment following instructions in open-instruct/README.md.
export CUDA_HOME={path_to_your_conda_environment} # i.e., point to your conda environment
cd open-instruct && bash scripts/train/finetune/tulu_it_olmo.sh && cd .. # start training
We then evaluate the model on CivilComments and RealToxicityPrompts. Note: For RealToxicityPrompts, you will need to obtain a Perspective API key and save it in the API_KEYS.py file.
bash figure2/eval_figure2.sh
To plot the results, copy the results from the evaluation into the plotting/figure2.py script and run it to recreate the same plot.
We use the batches that were substituted out from Dolma to evaluate model performance on unseen Dolma data. This data was already stored in data/figure2_partition0/final_training_data/test/unseen_data.jsonl during the data processing for Figure 2.
We now need to collect non-toxic Reddit documents that are unseen. These documents have also already been collected and stored in data/non_toxic_reddit/ when we ran bash preprocess_reddit.sh. Due to limited compute, we only used non-toxic data from RC_2023-12_extracted.jsonl in our experiments. To tokenize and chunk this data for evaluation, run the following script:
bash table2/prepare_table2_evaldata.sh
To evaluate each of the models trained in Figure 2 on unseen Dolma and unseen non-toxic Reddit data, run the following script. Make sure to specify the correct partition and mode. Each model variant needs its matching non-toxic Reddit document partition, because each variant swapped out a different set of Reddit documents.
bash table2/eval_table2.sh
To best isolate the effect of toxic data quantity, we first perform a strict filtering of the existing Dolma dataset. In particular, we run the following steps:
# download dolma data seen from steps 735000 to 738020 to ensure that the data we have after filtering exceeds 1B tokens
bash figure3/download_figure3_olmo_data.sh
# perform strict toxicity filtering of Dolma data
bash figure3/run_filter_figure3_dolma_data.sh
# merge the filtered Dolma data with toxic Reddit data. Adjust "insert_data_percentage" to change the amount of toxic data injected into Dolma.
bash figure3/prepare_figure3_trainingdata.sh
Then, we download checkpoint 735000 to initialize continual pre-training from.
bash figure3/download_figure3_olmo_ckpt.sh
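For a rough sense of how much raw data a range of training steps covers before any filtering, the step count can be multiplied by the tokens consumed per step. The sketch below is illustrative only: the helper name `tokens_between` is hypothetical, and the batch size and sequence length are assumed placeholder values that should be confirmed against the Olmo training config you actually use.

```shell
#!/usr/bin/env bash
# Assumed placeholder values, NOT taken from this repository:
GLOBAL_BATCH_SIZE=2048   # sequences per optimizer step (assumption)
SEQUENCE_LENGTH=2048     # tokens per sequence (assumption)

# Estimate the number of raw tokens seen between two checkpoint steps.
tokens_between() {
  local start="$1" end="$2"
  echo $(( (end - start) * GLOBAL_BATCH_SIZE * SEQUENCE_LENGTH ))
}

echo "raw tokens, steps 735000 to 738020: $(tokens_between 735000 738020)"
echo "raw tokens, steps 737000 to 738020: $(tokens_between 737000 738020)"
```

The number of tokens that survive strict filtering will be smaller than this raw estimate, which is why the wider step range is downloaded for this experiment.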
Finally, we launch the continual pre-training run.
bash figure3/train_figure3_olmo_continual.sh
First, convert the checkpoints to hf format.
We then evaluate the model on CivilComments and RealToxicityPrompts. Note: For RealToxicityPrompts, you will need to obtain a Perspective API key and save it in the API_KEYS.py file.
bash figure3/eval_figure3.sh
To plot the results, copy the results from the evaluation into the plotting/figure3.py script and run it to recreate the same plot.
To run training, change the mode and seed variables in config/train_table4_hf.yaml and run the following script:
bash table4/train_table4_hf.sh
To evaluate the trained models, run the following script:
bash table4/eval_table4.sh