Skip to content

Commit

Permalink
Changes added
Browse files Browse the repository at this point in the history
  • Loading branch information
admiralhr99 committed Aug 4, 2024
1 parent c57da98 commit 5197ccf
Showing 1 changed file with 160 additions and 45 deletions.
205 changes: 160 additions & 45 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"context"
"encoding/json"
"flag"
Expand All @@ -27,6 +28,11 @@ type PullRequest struct {
} `json:"head"`
}

type PullRequestData struct {
LastRun time.Time `json:"last_run"`
PRs []PullRequest `json:"prs"`
}

func main() {
owner := "projectdiscovery"
repo := "nuclei-templates"
Expand All @@ -39,22 +45,49 @@ func main() {
client := github.NewClient(nil)
ctx := context.Background()

oneMonthAgo := time.Now().AddDate(0, -1, 0)
data, err := loadPreviousPRs(*outputFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading previous pull requests: %v\n", err)
data = &PullRequestData{LastRun: time.Now().AddDate(0, -1, 0), PRs: []PullRequest{}}
}

if err := fetchPullRequests(ctx, client, owner, repo, oneMonthAgo, *outputFile, *silent); err != nil {
fmt.Fprintf(os.Stderr, "Error fetching pull requests: %v\n", err)
newPRs, err := fetchNewPullRequests(ctx, client, owner, repo, data.LastRun, *silent)
if err != nil {
fmt.Fprintf(os.Stderr, "Error fetching new pull requests: %v\n", err)
os.Exit(1)
}

if *download {
if err := downloadYAMLFiles(*outputFile); err != nil {
fmt.Fprintf(os.Stderr, "Error downloading YAML files: %v\n", err)
os.Exit(1)
if len(newPRs) > 0 {
fmt.Println("New pull requests:")
for _, pr := range newPRs {
fmt.Printf("- %s\n", pr.Title)
}

if *download {
if err := downloadYAMLFiles(newPRs); err != nil {
fmt.Fprintf(os.Stderr, "Error downloading YAML files: %v\n", err)
}
}

data.PRs = append(newPRs, data.PRs...)
data.LastRun = time.Now()
if err := writeToFile(data, *outputFile); err != nil {
fmt.Fprintf(os.Stderr, "Error saving pull requests: %v\n", err)
}
} else {
fmt.Println("No new pull requests found.")
}
}

func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, outputFile string, silent bool) error {
func getPRSlice(prMap map[string]PullRequest) []PullRequest {
prSlice := make([]PullRequest, 0, len(prMap))
for _, pr := range prMap {
prSlice = append(prSlice, pr)
}
return prSlice
}

func fetchNewPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, silent bool) ([]PullRequest, error) {
opts := &github.PullRequestListOptions{
State: "open",
Sort: "created",
Expand All @@ -64,21 +97,17 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s
},
}

var allPRs []PullRequest
var newPRs []PullRequest

for {
prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts)
if err != nil {
return err
return nil, err
}

for _, pr := range prs {
if pr.CreatedAt.Before(since) {
break
}

if strings.Contains(strings.ToLower(*pr.Title), "cve") {
allPRs = append(allPRs, PullRequest{
if pr.CreatedAt.After(since) && strings.Contains(strings.ToLower(*pr.Title), "cve") {
newPR := PullRequest{
Title: *pr.Title,
CreatedAt: *pr.CreatedAt,
HTMLURL: *pr.HTMLURL,
Expand All @@ -92,11 +121,15 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s
}{
SHA: *pr.Head.SHA,
},
})

if !silent {
fmt.Printf("New PR: %s\n", *pr.Title)
}

newPRs = append(newPRs, newPR)
//if !silent {
// fmt.Printf("New PR: %s\n", newPR.Title)
//}
} else if pr.CreatedAt.Before(since) || pr.CreatedAt.Equal(since) {
// We've reached PRs that are older than or equal to our last run, so we can stop
return newPRs, nil
}
}

Expand All @@ -106,33 +139,96 @@ func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo s
opts.Page = resp.NextPage
}

return writeToFile(allPRs, outputFile)
return newPRs, nil
}

func writeToFile(prs []PullRequest, filename string) error {
file, err := os.Create(filename)
func loadPreviousPRs(filename string) (*PullRequestData, error) {
file, err := os.Open(filename)
if err != nil {
return err
if os.IsNotExist(err) {
return &PullRequestData{LastRun: time.Now().AddDate(0, -1, 0), PRs: []PullRequest{}}, nil
}
return nil, err
}
defer file.Close()

encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
return encoder.Encode(prs)
var data PullRequestData
decoder := json.NewDecoder(file)
if err := decoder.Decode(&data); err != nil {
return nil, err
}

return &data, nil
}

func downloadYAMLFiles(inputFile string) error {
file, err := os.Open(inputFile)
//func fetchPullRequests(ctx context.Context, client *github.Client, owner, repo string, since time.Time, outputFile string, silent bool) error {
// opts := &github.PullRequestListOptions{
// State: "open",
// Sort: "created",
// Direction: "desc",
// ListOptions: github.ListOptions{
// PerPage: 100,
// },
// }
//
// var allPRs []PullRequest
//
// for {
// prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts)
// if err != nil {
// return err
// }
//
// for _, pr := range prs {
// if pr.CreatedAt.Before(since) {
// break
// }
//
// if strings.Contains(strings.ToLower(*pr.Title), "cve") {
// allPRs = append(allPRs, PullRequest{
// Title: *pr.Title,
// CreatedAt: *pr.CreatedAt,
// HTMLURL: *pr.HTMLURL,
// User: struct {
// Login string `json:"login"`
// }{
// Login: *pr.User.Login,
// },
// Head: struct {
// SHA string `json:"sha"`
// }{
// SHA: *pr.Head.SHA,
// },
// })
//
// if !silent {
// fmt.Printf("New PR: %s\n", *pr.Title)
// }
// }
// }
//
// if resp.NextPage == 0 {
// break
// }
// opts.Page = resp.NextPage
// }
//
// return writeToFile(allPRs, outputFile)
//}

func writeToFile(data *PullRequestData, filename string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()

var prs []PullRequest
if err := json.NewDecoder(file).Decode(&prs); err != nil {
return err
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
return encoder.Encode(data)
}

func downloadYAMLFiles(prs []PullRequest) error {
paths := []string{
"http/cves",
"network/cves",
Expand Down Expand Up @@ -161,15 +257,18 @@ func downloadYAMLFiles(inputFile string) error {
lowercaseFilename := strings.ToLower(cveUpper) + ".yaml"

var downloadedURL string
var remoteContent []byte
for _, path := range paths {
url := fmt.Sprintf("https://raw.githubusercontent.com/projectdiscovery/nuclei-templates/%s/%s/%s/%s",
pr.Head.SHA,
path,
year,
filename)

if err := downloadFile(url, filename); err == nil {
content, err := fetchFileContent(url)
if err == nil {
downloadedURL = url
remoteContent = content
break
}
}
Expand All @@ -180,8 +279,10 @@ func downloadYAMLFiles(inputFile string) error {
pr.Head.SHA,
filename)

if err := downloadFile(url, filename); err == nil {
content, err := fetchFileContent(url)
if err == nil {
downloadedURL = url
remoteContent = content
}
}

Expand All @@ -191,13 +292,23 @@ func downloadYAMLFiles(inputFile string) error {
pr.Head.SHA,
lowercaseFilename)

if err := downloadFile(url, filename); err == nil {
content, err := fetchFileContent(url)
if err == nil {
downloadedURL = url
remoteContent = content
}
}

if downloadedURL != "" {
fmt.Printf("Downloaded: %s\n", downloadedURL)
if shouldUpdateFile(filename, remoteContent) {
if err := os.WriteFile(filename, remoteContent, 0644); err != nil {
fmt.Printf("Failed to write file %s: %v\n", filename, err)
} else {
fmt.Printf("Updated: %s\n", downloadedURL)
}
} else {
fmt.Printf("Skipped: %s (No changes)\n", filename)
}
} else {
fmt.Printf("Failed to download %s from any path\n", filename)
}
Expand Down Expand Up @@ -235,23 +346,27 @@ func extractYear(cve string) string {
return ""
}

func downloadFile(url, filename string) error {
func fetchFileContent(url string) ([]byte, error) {
resp, err := http.Get(url)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
return nil, fmt.Errorf("bad status: %s", resp.Status)
}

out, err := os.Create(filename)
return io.ReadAll(resp.Body)
}

func shouldUpdateFile(filename string, remoteContent []byte) bool {
localContent, err := os.ReadFile(filename)
if err != nil {
return err
// File doesn't exist locally, should download
return true
}
defer out.Close()

_, err = io.Copy(out, resp.Body)
return err
// Compare content
return !bytes.Equal(localContent, remoteContent)
}

0 comments on commit 5197ccf

Please sign in to comment.