improve: 优化 acid 嗅探逻辑

This commit is contained in:
Mmx233
2024-01-01 15:17:04 +08:00
parent 5e3e26ff93
commit 55ebcba831

View File

@@ -2,6 +2,7 @@ package srun
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/Mmx233/tool" "github.com/Mmx233/tool"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -92,46 +93,68 @@ func (a *Api) GetUserInfo() (map[string]interface{}, error) {
return a.request("cgi-bin/rad_user_info", nil) return a.request("cgi-bin/rad_user_info", nil)
} }
func (a *Api) DetectAcid() (string, error) { func (a *Api) FollowRedirect(addr *url.URL, onNext func(addr *url.URL) error) (*url.URL, error) {
addr := a.BaseUrl addrCopy := *addr
addr = &addrCopy
for { for {
log.Debugln("HTTP GET ", addr) log.Debugln("HTTP GET ", addr)
req, err := http.NewRequest("GET", addr, nil) req, err := http.NewRequest("GET", addr.String(), nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
for k, v := range a.CustomHeader { for k, v := range a.CustomHeader {
req.Header.Set(k, fmt.Sprint(v)) req.Header.Set(k, fmt.Sprint(v))
} }
res, err := a.NoDirect.Do(req) res, err := a.NoDirect.Do(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
_, _ = io.Copy(io.Discard, res.Body) _, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close() _ = res.Body.Close()
loc := res.Header.Get("location") loc := res.Header.Get("location")
if res.StatusCode == 302 && loc != "" { if res.StatusCode < 300 {
if strings.HasPrefix(loc, "/") { break
addr = a.BaseUrl + strings.TrimPrefix(loc, "/") } else if res.StatusCode < 400 {
} else { if loc == "" {
addr = loc return nil, errors.New("目标跳转地址缺失")
} }
if strings.HasPrefix(loc, "/") {
addr.Path = strings.TrimPrefix(loc, "/")
} else {
addr, err = url.Parse(loc)
if err != nil {
return nil, err
}
}
if err = onNext(addr); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("server return http status %d", res.StatusCode)
}
}
return addr, nil
}
var u *url.URL func (a *Api) DetectAcid() (string, error) {
u, err = url.Parse(addr) baseUrl, err := url.Parse(a.BaseUrl)
if err != nil { if err != nil {
return "", err return "", err
} }
acid := u.Query().Get(`ac_id`)
if acid != "" {
return acid, nil
}
continue var AcidFound = errors.New("acid found")
var acid string
_, err = a.FollowRedirect(baseUrl, func(addr *url.URL) error {
acid = addr.Query().Get(`ac_id`)
if acid != "" {
return AcidFound
} }
break return nil
})
if err != nil && !errors.Is(err, AcidFound) {
return "", err
} }
return "", ErrAcidCannotFound return acid, nil
} }
type LoginRequest struct { type LoginRequest struct {