Skip to content

Commit

Permalink
Compare train loss without sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan-Zhou committed May 27, 2024
1 parent 28c9720 commit e35bf7e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 28 deletions.
24 changes: 16 additions & 8 deletions analysis/compare_marin_to_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,33 @@

DATA_DIR = "scratch/data"
OUTPUT_DIR = "scratch/output"
OLMO_1B_DATA_PATH = f"{DATA_DIR}/OLMo-1B.csv"
OLMO_1B_TRAIN_DATA_PATH = f"{DATA_DIR}/OLMo-1B_train.csv"
OLMO_1B_EVAL_DATA_PATH = f"{DATA_DIR}/OLMo-1B_eval.csv"
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)


def compare_marin_to_olmo(marin_run_id: str = "eo302w0523"):
marin_data_path = f"{DATA_DIR}/marin_{marin_run_id}.csv"
assert Path(marin_data_path).exists(), f"Marin data not found at {marin_data_path}"
assert Path(OLMO_1B_DATA_PATH).exists(), f"OLMo-1B data not found at {OLMO_1B_DATA_PATH}"
df_marin = pd.read_csv(marin_data_path)
df_olmo = pd.read_csv(OLMO_1B_DATA_PATH)
marin_train_data_path = f"{DATA_DIR}/marin_{marin_run_id}_train.csv"
marin_eval_data_path = f"{DATA_DIR}/marin_{marin_run_id}_eval.csv"
for path in [marin_train_data_path, marin_eval_data_path, OLMO_1B_TRAIN_DATA_PATH, OLMO_1B_EVAL_DATA_PATH]:
assert Path(path).exists(), f"File not found at {path}"

df_marin_train = pd.read_csv(marin_train_data_path)
df_marin_eval = pd.read_csv(marin_eval_data_path)
df_olmo_train = pd.read_csv(OLMO_1B_TRAIN_DATA_PATH)
df_olmo_eval = pd.read_csv(OLMO_1B_EVAL_DATA_PATH)

# Limit OLMo-1B data to the same steps as Marin data
max_step = df_marin["_step"].max()
max_step = df_marin_train["_step"].max()
print(f"Limiting OLMo-1B data to steps <= {max_step}")
df_olmo = df_olmo[df_olmo["_step"] <= max_step]
df_olmo_train = df_olmo_train[df_olmo_train["_step"] <= max_step]
df_olmo_eval = df_olmo_eval[df_olmo_eval["_step"] <= max_step]

# compare metrics
metrics_mapping = get_marin_olmo_metrics_mapping()
for marin_key, olmo_key in metrics_mapping.items():
df_marin = df_marin_train if "train" in marin_key else df_marin_eval
df_olmo = df_olmo_train if "train" in olmo_key else df_olmo_eval
if marin_key not in df_marin:
print(f"Missing key {marin_key} in Marin data")
continue
Expand Down
54 changes: 34 additions & 20 deletions analysis/wandb_parse_runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path

from typing import List
import fire
import pandas as pd
import wandb
Expand All @@ -24,18 +24,23 @@ def check_missing_steps(df: pd.DataFrame, step_size: int = 1000):
print(f"Missing steps: {missing_steps}")


def get_all_olmo_runs(project_name: str = PROJECT_OLMO_1B):
def get_all_olmo_runs(project_name: str = PROJECT_OLMO_1B, target: str = "eval"):
# we need to extract train/loss and eval/loss separately; otherwise, the train loss will be sampled
if target == "eval":
keys = get_olmo_metrics_keys()
elif target == "train":
keys = ["train/CrossEntropyLoss"]
else:
raise ValueError(f"Unknown target: {target}")

api = wandb.Api(timeout=TIMEOUT)
path = f"{ENTITY_OLMO}/{project_name}"
runs = api.runs(path=path, per_page=1000)
olmo_keys = get_olmo_metrics_keys()
dfs = []
for _, run in tqdm(enumerate(runs)):
df_history = run.history(samples=N_SAMPLES, keys=olmo_keys)
if len(df_history) == 0:
print(f"Skipping run {run.id} with no data")
continue
dfs.append(df_history)
df = run.history(samples=N_SAMPLES, keys=keys)
if len(df) > 0:
dfs.append(df)

# merge all runs
df_all = pd.concat(dfs)
Expand All @@ -44,35 +49,44 @@ def get_all_olmo_runs(project_name: str = PROJECT_OLMO_1B):
min_step = df_all["_step"].min()

# check for steps
print(f"Found {len(dfs)} runs with {df_all.shape[0]} rows, steps: {min_step} - {max_step}")
check_missing_steps(df_all)
print(f"Found {len(dfs)} runs for {target} with {df_all.shape[0]} rows, steps: {min_step} - {max_step}")
step_size = 1000 if target == "eval" else 1
check_missing_steps(df_all, step_size=step_size)

# save to file
out_file = f"{OUT_DIR}/{project_name}.csv"
out_file = f"{OUT_DIR}/{project_name}_{target}.csv"
print(f"Saving {df_all.shape[0]} rows to {out_file}")
df_all.to_csv(out_file, index=False)


def smooth_column(df: pd.DataFrame, cols: List[str], window_size: int = 256):
    """Smooth the given columns in-place with a centered rolling average.

    Args:
        df: DataFrame whose columns are smoothed. Mutated in place.
        cols: Names of the columns to smooth; each is replaced by its
            rolling mean.
        window_size: Width of the rolling window.

    Returns:
        The same (mutated) DataFrame, for convenient chaining.
    """
    for col in cols:
        # center=True averages symmetrically around each row, so the rows
        # near both edges (incomplete windows) become NaN.
        df[col] = df[col].rolling(window=window_size, center=True).mean()
    return df


def get_marin_run(run_id: str, target: str = "eval", smooth: bool = False, window_size: int = 256):
    """Download one Marin wandb run's metric history and save it to a CSV.

    Args:
        run_id: wandb run id under ENTITY_MARIN/PROJECT_MARIN.
        target: Which metrics to fetch — "eval" (all eval metric keys) or
            "train" (train/loss only, fetched separately so it is not sampled).
        smooth: If True, apply a centered rolling mean to the train loss
            before saving.
        window_size: Rolling window used when smooth is True.

    Raises:
        ValueError: If target is neither "eval" nor "train".
    """
    if target == "eval":
        keys = get_marin_metrics_keys()
    elif target == "train":
        keys = ["train/loss"]
    else:
        # Fail fast with a clear message (consistent with get_all_olmo_runs)
        # instead of hitting a NameError on the undefined `keys` below.
        raise ValueError(f"Unknown target: {target}")

    api = wandb.Api(timeout=TIMEOUT)
    path = f"{ENTITY_MARIN}/{PROJECT_MARIN}"
    run = api.run(f"{path}/{run_id}")
    df_history = run.history(samples=N_SAMPLES, keys=keys)

    # Smoothing is only applied to the noisy per-step train loss.
    if smooth and target == "train":
        df_history = smooth_column(df_history, keys, window_size=window_size)

    out_name = f"marin_{run_id}_{target}"
    if smooth:
        # NOTE(review): the suffix is added whenever smooth=True, even for
        # target="eval" where no smoothing actually happens — confirm intent.
        out_name += f"_smooth_{window_size}"
    out_file = f"{OUT_DIR}/{out_name}.csv"
    print(f"Saving {df_history.shape[0]} rows to {out_file}")
    df_history.to_csv(out_file, index=False)


if __name__ == "__main__":
# get_all_olmo_runs()
# get_all_olmo_runs(target="train")
# get_all_olmo_runs(target="eval")
fire.Fire(get_marin_run)

0 comments on commit e35bf7e

Please sign in to comment.