From 5197ccf93cb165eda01ad2e6d0a3c33b72ffe8b7 Mon Sep 17 00:00:00 2001 From: admiralhr99 Date: Sun, 4 Aug 2024 16:54:21 +0330 Subject: [PATCH] Changes added --- main.go | 205 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 160 insertions(+), 45 deletions(-) diff --git a/main.go b/main.go index f53dde0..db09a81 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "encoding/json" "flag" @@ -27,6 +28,11 @@ type PullRequest struct { } `json:"head"` } +type PullRequestData struct { + LastRun time.Time `json:"last_run"` + PRs []PullRequest `json:"prs"` +} + func main() { owner := "projectdiscovery" repo := "nuclei-templates" @@ -39,22 +45,49 @@ func main() { client := github.NewClient(nil) ctx := context.Background() - oneMonthAgo := time.Now().AddDate(0, -1, 0) + data, err := loadPreviousPRs(*outputFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading previous pull requests: %v\n", err) + data = &PullRequestData{LastRun: time.Now().AddDate(0, -1, 0), PRs: []PullRequest{}} + } - if err := fetchPullRequests(ctx, client, owner, repo, oneMonthAgo, *outputFile, *silent); err != nil { - fmt.Fprintf(os.Stderr, "Error fetching pull requests: %v\n", err) + newPRs, err := fetchNewPullRequests(ctx, client, owner, repo, data.LastRun, *silent) + if err != nil { + fmt.Fprintf(os.Stderr, "Error fetching new pull requests: %v\n", err) os.Exit(1) } - if *download { - if err := downloadYAMLFiles(*outputFile); err != nil { - fmt.Fprintf(os.Stderr, "Error downloading YAML files: %v\n", err) - os.Exit(1) + if len(newPRs) > 0 { + fmt.Println("New pull requests:") + for _, pr := range newPRs { + fmt.Printf("- %s\n", pr.Title) + } + + if *download { + if err := downloadYAMLFiles(newPRs); err != nil { + fmt.Fprintf(os.Stderr, "Error downloading YAML files: %v\n", err) + } + } + + data.PRs = append(newPRs, data.PRs...) + data.LastRun = time.Now() + if err := writeToFile(data, *outputFile); err != nil { + fmt.Fprintf(os.Stderr, "Error saving pull requests: %v\n", err) } + } else { + fmt.Println("No new pull requests found.") } } -func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, outputFile string, silent bool) error { +func getPRSlice(prMap map[string]PullRequest) []PullRequest { + prSlice := make([]PullRequest, 0, len(prMap)) + for _, pr := range prMap { + prSlice = append(prSlice, pr) + } + return prSlice +} + +func fetchNewPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, silent bool) ([]PullRequest, error) { opts := &github.PullRequestListOptions{ State: "open", Sort: "created", @@ -64,21 +97,17 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s }, } - var allPRs []PullRequest + var newPRs []PullRequest for { prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { - return err + return nil, err } for _, pr := range prs { - if pr.CreatedAt.Before(since) { - break - } - - if strings.Contains(strings.ToLower(*pr.Title), "cve") { - allPRs = append(allPRs, PullRequest{ + if pr.CreatedAt.After(since) && strings.Contains(strings.ToLower(*pr.Title), "cve") { + newPR := PullRequest{ Title: *pr.Title, CreatedAt: *pr.CreatedAt, HTMLURL: *pr.HTMLURL, @@ -92,11 +121,15 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s }{ SHA: *pr.Head.SHA, }, - }) - - if !silent { - fmt.Printf("New PR: %s\n", *pr.Title) } + + newPRs = append(newPRs, newPR) + //if !silent { + // fmt.Printf("New PR: %s\n", newPR.Title) + //} + } else if pr.CreatedAt.Before(since) || pr.CreatedAt.Equal(since) { + // We've reached PRs that are older than or equal to our last run, so we can stop + return newPRs, nil } } @@ -106,33 +139,96 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s opts.Page = resp.NextPage } - return writeToFile(allPRs, outputFile) + return newPRs, nil } -func writeToFile(prs []PullRequest, filename string) error { - file, err := os.Create(filename) +func loadPreviousPRs(filename string) (*PullRequestData, error) { + file, err := os.Open(filename) if err != nil { - return err + if os.IsNotExist(err) { + return &PullRequestData{LastRun: time.Now().AddDate(0, -1, 0), PRs: []PullRequest{}}, nil + } + return nil, err } defer file.Close() - encoder := json.NewEncoder(file) - encoder.SetIndent("", " ") - return encoder.Encode(prs) + var data PullRequestData + decoder := json.NewDecoder(file) + if err := decoder.Decode(&data); err != nil { + return nil, err + } + + return &data, nil } -func downloadYAMLFiles(inputFile string) error { - file, err := os.Open(inputFile) +//func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, outputFile string, silent bool) error { +// opts := &github.PullRequestListOptions{ +// State: "open", +// Sort: "created", +// Direction: "desc", +// ListOptions: github.ListOptions{ +// PerPage: 100, +// }, +// } +// +// var allPRs []PullRequest +// +// for { +// prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) +// if err != nil { +// return err +// } +// +// for _, pr := range prs { +// if pr.CreatedAt.Before(since) { +// break +// } +// +// if strings.Contains(strings.ToLower(*pr.Title), "cve") { +// allPRs = append(allPRs, PullRequest{ +// Title: *pr.Title, +// CreatedAt: *pr.CreatedAt, +// HTMLURL: *pr.HTMLURL, +// User: struct { +// Login string `json:"login"` +// }{ +// Login: *pr.User.Login, +// }, +// Head: struct { +// SHA string `json:"sha"` +// }{ +// SHA: *pr.Head.SHA, +// }, +// }) +// +// if !silent { +// fmt.Printf("New PR: %s\n", *pr.Title) +// } +// } +// } +// +// if resp.NextPage == 0 { +// break +// } +// opts.Page = resp.NextPage +// } +// +// return writeToFile(allPRs, outputFile) +//} + +func writeToFile(data *PullRequestData, filename string) error { + file, err := os.Create(filename) if err != nil { return err } defer file.Close() - var prs []PullRequest - if err := json.NewDecoder(file).Decode(&prs); err != nil { - return err - } + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + return encoder.Encode(data) +} +func downloadYAMLFiles(prs []PullRequest) error { paths := []string{ "http/cves", "network/cves", @@ -161,6 +257,7 @@ func downloadYAMLFiles(inputFile string) error { lowercaseFilename := strings.ToLower(cveUpper) + ".yaml" var downloadedURL string + var remoteContent []byte for _, path := range paths { url := fmt.Sprintf("https://raw.githubusercontent.com/projectdiscovery/nuclei-templates/%s/%s/%s/%s", pr.Head.SHA, @@ -168,8 +265,10 @@ func downloadYAMLFiles(inputFile string) error { year, filename) - if err := downloadFile(url, filename); err == nil { + content, err := fetchFileContent(url) + if err == nil { downloadedURL = url + remoteContent = content break } } @@ -180,8 +279,10 @@ func downloadYAMLFiles(inputFile string) error { pr.Head.SHA, filename) - if err := downloadFile(url, filename); err == nil { + content, err := fetchFileContent(url) + if err == nil { downloadedURL = url + remoteContent = content } } @@ -191,13 +292,23 @@ func downloadYAMLFiles(inputFile string) error { pr.Head.SHA, lowercaseFilename) - if err := downloadFile(url, filename); err == nil { + content, err := fetchFileContent(url) + if err == nil { downloadedURL = url + remoteContent = content } } if downloadedURL != "" { - fmt.Printf("Downloaded: %s\n", downloadedURL) + if shouldUpdateFile(filename, remoteContent) { + if err := os.WriteFile(filename, remoteContent, 0644); err != nil { + fmt.Printf("Failed to write file %s: %v\n", filename, err) + } else { + fmt.Printf("Updated: %s\n", downloadedURL) + } + } else { + fmt.Printf("Skipped: %s (No changes)\n", filename) + } } else { fmt.Printf("Failed to download %s from any path\n", filename) } @@ -235,23 +346,27 @@ func extractYear(cve string) string { return "" } -func downloadFile(url, filename string) error { +func fetchFileContent(url string) ([]byte, error) { resp, err := http.Get(url) if err != nil { - return err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) + return nil, fmt.Errorf("bad status: %s", resp.Status) } - out, err := os.Create(filename) + return io.ReadAll(resp.Body) +} + +func shouldUpdateFile(filename string, remoteContent []byte) bool { + localContent, err := os.ReadFile(filename) if err != nil { - return err + // File doesn't exist locally, should download + return true } - defer out.Close() - _, err = io.Copy(out, resp.Body) - return err + // Compare content + return !bytes.Equal(localContent, remoteContent) }