Skip to content

Commit 858ce68

Browse files
ArthurZuckerzucchini-nlphiyouga
authored
make it go brrrr (#38409)
* make it go brrrr * date time * update * fix * up * uppp * up * no number i * udpate * fix * [paligemma] fix processor with suffix (#38365) fix pg processor * [video utils] group and reorder by number of frames (#38374) fix * Fix convert to original state dict for VLMs (#38385) * fix convert to original state dict * fix * lint * Update modeling_utils.py * update * warn * no verbose * fginal * ouft * style --------- Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
1 parent ab5067e commit 858ce68

File tree

1 file changed

+111
-51
lines changed

1 file changed

+111
-51
lines changed

utils/patch_helper.py

Lines changed: 111 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -35,61 +35,121 @@
3535
```
3636
"""
3737

38-
import argparse
38+
import json
39+
import subprocess
3940

40-
from git import GitCommandError, Repo
41-
from packaging import version
41+
import transformers
4242

4343

44-
def get_merge_commit(repo, pr_number, since_tag):
44+
LABEL = "for patch" # Replace with your label
45+
REPO = "huggingface/transformers" # Optional if already in correct repo
46+
47+
48+
def get_release_branch_name():
49+
"""Derive branch name from transformers version."""
50+
major, minor, *_ = transformers.__version__.split(".")
51+
major = int(major)
52+
minor = int(minor)
53+
54+
if minor == 0:
55+
# Handle major version rollback, e.g., from 5.0 to 4.latest (if ever needed)
56+
major -= 1
57+
# You'll need logic to determine the last minor of the previous major version
58+
raise ValueError("Minor version is 0; need logic to find previous major version's last minor")
59+
else:
60+
minor -= 1
61+
62+
return f"v{major}.{minor}-release"
63+
64+
65+
def checkout_branch(branch):
66+
"""Checkout the target branch."""
4567
try:
46-
# Use git log to find the merge commit for the PR within the given tag range
47-
merge_commit = next(repo.iter_commits(f"v{since_tag}...origin/main", grep=f"#{pr_number}"))
48-
return merge_commit
49-
except StopIteration:
50-
print(f"No merge commit found for PR #{pr_number} between tags {since_tag} and {main}")
51-
return None
52-
except GitCommandError as e:
53-
print(f"Error finding merge commit for PR #{pr_number}: {str(e)}")
54-
return None
55-
56-
57-
def main(pr_numbers):
58-
repo = Repo(".") # Initialize the Repo object for the current directory
59-
merge_commits = []
60-
61-
tags = {}
62-
for tag in repo.tags:
63-
try:
64-
# Parse and sort tags, skip invalid ones
65-
tag_ver = version.parse(tag.name)
66-
tags[tag_ver] = tag
67-
except Exception:
68-
print(f"Skipping invalid version tag: {tag.name}")
69-
70-
last_tag = sorted(tags)[-1]
71-
major_minor = f"{last_tag.major}.{last_tag.minor}.0"
72-
# Iterate through tag ranges to find the merge commits
73-
for pr in pr_numbers:
74-
pr = pr.split("https://github.com/huggingface/transformers/pull/")[-1]
75-
commit = get_merge_commit(repo, pr, major_minor)
76-
if commit:
77-
merge_commits.append(commit)
78-
79-
# Sort commits by date
80-
merge_commits.sort(key=lambda commit: commit.committed_datetime)
81-
82-
# Output the git cherry-pick commands
83-
print("Git cherry-pick commands to run:")
84-
for commit in merge_commits:
85-
print(f"git cherry-pick {commit.hexsha} #{commit.committed_datetime}")
68+
subprocess.run(["git", "checkout", branch], check=True)
69+
print(f"✅ Checked out branch: {branch}")
70+
except subprocess.CalledProcessError:
71+
print(f"❌ Failed to checkout branch: {branch}. Does it exist?")
72+
exit(1)
8673

8774

88-
if __name__ == "__main__":
89-
parser = argparse.ArgumentParser(description="Find and sort merge commits for specified PRs.")
90-
parser.add_argument("--prs", nargs="+", required=False, type=str, help="PR numbers to find merge commits for")
75+
def get_prs_by_label(label):
76+
"""Call gh CLI to get PRs with a specific label."""
77+
cmd = [
78+
"gh",
79+
"pr",
80+
"list",
81+
"--label",
82+
label,
83+
"--state",
84+
"all",
85+
"--json",
86+
"number,title,mergeCommit,url",
87+
"--limit",
88+
"100",
89+
]
90+
result = subprocess.run(cmd, capture_output=True, text=True)
91+
result.check_returncode()
92+
prs = json.loads(result.stdout)
93+
for pr in prs:
94+
is_merged = pr.get("mergeCommit", {})
95+
if is_merged:
96+
pr["oid"] = is_merged.get("oid")
97+
return prs
98+
99+
100+
def get_commit_timestamp(commit_sha):
101+
"""Get UNIX timestamp of a commit using git."""
102+
result = subprocess.run(["git", "show", "-s", "--format=%ct", commit_sha], capture_output=True, text=True)
103+
result.check_returncode()
104+
return int(result.stdout.strip())
91105

92-
args = parser.parse_args()
93-
if args.prs is None:
94-
args.prs = "https://github.com/huggingface/transformers/pull/33753 https://github.com/huggingface/transformers/pull/33861 https://github.com/huggingface/transformers/pull/33906 https://github.com/huggingface/transformers/pull/33761 https://github.com/huggingface/transformers/pull/33586 https://github.com/huggingface/transformers/pull/33766 https://github.com/huggingface/transformers/pull/33958 https://github.com/huggingface/transformers/pull/33965".split()
95-
main(args.prs)
106+
107+
def cherry_pick_commit(sha):
108+
"""Cherry-pick a given commit SHA."""
109+
try:
110+
subprocess.run(["git", "cherry-pick", sha], check=True)
111+
print(f"✅ Cherry-picked commit {sha}")
112+
except subprocess.CalledProcessError:
113+
print(f"⚠️ Failed to cherry-pick {sha}. Manual intervention required.")
114+
115+
116+
def commit_in_history(commit_sha, base_branch="HEAD"):
117+
"""Return True if commit is already part of base_branch history."""
118+
result = subprocess.run(
119+
["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
120+
stdout=subprocess.DEVNULL,
121+
stderr=subprocess.DEVNULL,
122+
)
123+
return result.returncode == 0
124+
125+
126+
def main(verbose=False):
127+
branch = get_release_branch_name()
128+
checkout_branch(branch)
129+
prs = get_prs_by_label(LABEL)
130+
# Attach commit timestamps
131+
for pr in prs:
132+
sha = pr.get("oid")
133+
if sha:
134+
pr["timestamp"] = get_commit_timestamp(sha)
135+
else:
136+
print("\n" + "=" * 80)
137+
print(f"⚠️ WARNING: PR #{pr['number']} ({sha}) is NOT in main!")
138+
print("⚠️ A core maintainer must review this before cherry-picking.")
139+
print("=" * 80 + "\n")
140+
# Sort by commit timestamp (ascending)
141+
prs = [pr for pr in prs if pr.get("timestamp") is not None]
142+
prs.sort(key=lambda pr: pr["timestamp"])
143+
for pr in prs:
144+
sha = pr.get("oid")
145+
if sha:
146+
if commit_in_history(sha):
147+
if verbose:
148+
print(f"🔁 PR #{pr['number']} ({pr['title']}) already in history. Skipping.")
149+
else:
150+
print(f"🚀 PR #{pr['number']} ({pr['title']}) not in history. Cherry-picking...")
151+
cherry_pick_commit(sha)
152+
153+
154+
if __name__ == "__main__":
155+
main()

0 commit comments

Comments
 (0)