From 55ebcba83109d134444ff6997cc718a5b7d31982 Mon Sep 17 00:00:00 2001 From: Mmx233 Date: Mon, 1 Jan 2024 15:17:04 +0800 Subject: [PATCH] =?UTF-8?q?improve:=20=E4=BC=98=E5=8C=96=20acid=20?= =?UTF-8?q?=E5=97=85=E6=8E=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/srun/api.go | 65 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/pkg/srun/api.go b/pkg/srun/api.go index 628bd3c..f79de8a 100644 --- a/pkg/srun/api.go +++ b/pkg/srun/api.go @@ -2,6 +2,7 @@ package srun import ( "encoding/json" + "errors" "fmt" "github.com/Mmx233/tool" log "github.com/sirupsen/logrus" @@ -92,46 +93,68 @@ func (a *Api) GetUserInfo() (map[string]interface{}, error) { return a.request("cgi-bin/rad_user_info", nil) } -func (a *Api) DetectAcid() (string, error) { - addr := a.BaseUrl +func (a *Api) FollowRedirect(addr *url.URL, onNext func(addr *url.URL) error) (*url.URL, error) { + addrCopy := *addr + addr = &addrCopy for { log.Debugln("HTTP GET ", addr) - req, err := http.NewRequest("GET", addr, nil) + req, err := http.NewRequest("GET", addr.String(), nil) if err != nil { - return "", err + return nil, err } for k, v := range a.CustomHeader { req.Header.Set(k, fmt.Sprint(v)) } res, err := a.NoDirect.Do(req) if err != nil { - return "", err + return nil, err } _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() loc := res.Header.Get("location") - if res.StatusCode == 302 && loc != "" { + if res.StatusCode < 300 { + break + } else if res.StatusCode < 400 { + if loc == "" { + return nil, errors.New("目标跳转地址缺失") + } if strings.HasPrefix(loc, "/") { - addr = a.BaseUrl + strings.TrimPrefix(loc, "/") + addr.Path = strings.TrimPrefix(loc, "/") } else { - addr = loc + addr, err = url.Parse(loc) + if err != nil { + return nil, err + } } - - var u *url.URL - u, err = url.Parse(addr) - if err != nil { - return "", err + if err = onNext(addr); err != nil { + return nil, err } - acid := u.Query().Get(`ac_id`) - if acid != "" { - return acid, nil - } - - continue + } else { + return nil, fmt.Errorf("server return http status %d", res.StatusCode) } - break } - return "", ErrAcidCannotFound + return addr, nil +} + +func (a *Api) DetectAcid() (string, error) { + baseUrl, err := url.Parse(a.BaseUrl) + if err != nil { + return "", err + } + + var AcidFound = errors.New("acid found") + var acid string + _, err = a.FollowRedirect(baseUrl, func(addr *url.URL) error { + acid = addr.Query().Get(`ac_id`) + if acid != "" { + return AcidFound + } + return nil + }) + if err != nil && !errors.Is(err, AcidFound) { + return "", err + } + return acid, nil } type LoginRequest struct {