diff --git a/pkg/srun/api.go b/pkg/srun/api.go index ef00a8a..733be1e 100644 --- a/pkg/srun/api.go +++ b/pkg/srun/api.go @@ -10,8 +10,10 @@ import ( "math/rand" "net/http" "net/url" + "regexp" "strings" "time" + "unsafe" ) type Api struct { @@ -93,7 +95,26 @@ func (a *Api) GetUserInfo() (map[string]interface{}, error) { return a.request("cgi-bin/rad_user_info", nil) } -func (a *Api) FollowRedirect(addr *url.URL, onNext func(addr *url.URL) error) (*url.URL, error) { +func (a *Api) _JoinRedirectLocation(addr *url.URL, loc string) (*url.URL, error) { + if loc == "" { + return nil, errors.New("目标跳转地址缺失") + } + if strings.HasPrefix(loc, "/") { + addr.Path = strings.TrimPrefix(loc, "/") + return addr, nil + } else { + return url.Parse(loc) + } +} + +type _FollowRedirectConfig struct { + // 覆盖响应处理逻辑,设置后 onNextAddr 无效 + onResponse func(res *http.Response) (next *url.URL, err error) + // 获取到下一个请求地址时触发 + onNextAddr func(addr *url.URL) error +} + +func (a *Api) _FollowRedirect(addr *url.URL, conf _FollowRedirectConfig) (*url.URL, error) { addrCopy := *addr addr = &addrCopy for { @@ -109,26 +130,31 @@ func (a *Api) FollowRedirect(addr *url.URL, onNext func(addr *url.URL) error) (* if err != nil { return nil, err } + if conf.onResponse != nil { + var nextAddr *url.URL + nextAddr, err = conf.onResponse(res) + if err != nil { + return nil, err + } else if nextAddr == nil { + break + } + addr = nextAddr + continue + } _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() - loc := res.Header.Get("location") if res.StatusCode < 300 { break } else if res.StatusCode < 400 { - if loc == "" { - return nil, errors.New("目标跳转地址缺失") + addr, err = a._JoinRedirectLocation(addr, res.Header.Get("location")) + if err != nil { + return nil, err } - if strings.HasPrefix(loc, "/") { - addr.Path = strings.TrimPrefix(loc, "/") - } else { - addr, err = url.Parse(loc) - if err != nil { + if conf.onNextAddr != nil { + if err = conf.onNextAddr(addr); err != nil { return nil, err } } - if err = onNext(addr); err != nil { - return nil, err - } } else { return nil, fmt.Errorf("server return http status %d", res.StatusCode) } @@ -136,7 +162,7 @@ func (a *Api) FollowRedirect(addr *url.URL, onNext func(addr *url.URL) error) (* return addr, nil } -func (a *Api) searchAcid(query url.Values) (string, bool) { +func (a *Api) _SearchAcid(query url.Values) (string, bool) { addr := query.Get(`ac_id`) return addr, addr != "" } @@ -150,13 +176,15 @@ func (a *Api) DetectAcid() (string, error) { var AcidFound = errors.New("acid found") var acid string - _, err = a.FollowRedirect(baseUrl, func(addr *url.URL) error { - var ok bool - acid, ok = a.searchAcid(addr.Query()) - if ok { - return AcidFound - } - return nil + _, err = a._FollowRedirect(baseUrl, _FollowRedirectConfig{ + onNextAddr: func(addr *url.URL) error { + var ok bool + acid, ok = a._SearchAcid(addr.Query()) + if ok { + return AcidFound + } + return nil + }, }) if err != nil { if errors.Is(err, AcidFound) { @@ -175,15 +203,36 @@ func (a *Api) Reality(addr string, getAcid bool) (acid string, online bool, err } var AlreadyOnline = errors.New("already online") var finalUrl *url.URL - finalUrl, err = a.FollowRedirect(startUrl, func(addr *url.URL) error { - // 任一跳转没有跳出初始域名说明已经在线 - if addr.Host == startUrl.Host { - return AlreadyOnline - } - if getAcid { - acid, _ = a.searchAcid(addr.Query()) - } - return nil + finalUrl, err = a._FollowRedirect(startUrl, _FollowRedirectConfig{ + onResponse: func(res *http.Response) (next *url.URL, err error) { + defer res.Body.Close() + if res.StatusCode < 300 { + var body []byte + body, err = io.ReadAll(res.Body) + if err != nil { + return + } + + var reg *regexp.Regexp + reg, err = regexp.Compile(``) + if err != nil { + return + } + + result := reg.FindSubmatch(body) + if len(result) == 2 { + nextBytes := result[1] + nextAddr := unsafe.String(unsafe.SliceData(nextBytes), len(nextBytes)) + next, err = url.Parse(nextAddr) + } + } else if res.StatusCode < 400 { + next, err = a._JoinRedirectLocation(res.Request.URL, res.Header.Get("location")) + } + if getAcid && next != nil { + acid, _ = a._SearchAcid(next.Query()) + } + return + }, }) if err != nil { if errors.Is(err, AlreadyOnline) {