diff --git a/adapter/provider/provider.go b/adapter/provider/provider.go index 4381c24d9..79a752a65 100644 --- a/adapter/provider/provider.go +++ b/adapter/provider/provider.go @@ -71,19 +71,15 @@ func (pp *proxySetProvider) HealthCheck() { } func (pp *proxySetProvider) Update() error { - elm, same, err := pp.Fetcher.Update() - if err == nil && !same { - pp.OnUpdate(elm) - } + _, _, err := pp.Fetcher.Update() return err } func (pp *proxySetProvider) Initial() error { - elm, err := pp.Fetcher.Initial() + _, err := pp.Fetcher.Initial() if err != nil { return err } - pp.OnUpdate(elm) pp.getSubscriptionInfo() pp.closeAllConnections() return nil diff --git a/component/profile/cachefile/cache.go b/component/profile/cachefile/cache.go index e3da03699..0591c92b7 100644 --- a/component/profile/cachefile/cache.go +++ b/component/profile/cachefile/cache.go @@ -1,6 +1,7 @@ package cachefile import ( + "math" "os" "sync" "time" @@ -19,6 +20,7 @@ var ( bucketSelected = []byte("selected") bucketFakeip = []byte("fakeip") + bucketETag = []byte("etag") ) // CacheFile store and update the cache file @@ -143,6 +145,59 @@ func (c *CacheFile) FlushFakeIP() error { return err } +func (c *CacheFile) SetETagWithHash(url string, hash []byte, etag string) { + if c.DB == nil { + return + } + + lenHash := len(hash) + if lenHash > math.MaxUint8 { + return // maybe panic is better + } + + data := make([]byte, 1, 1+lenHash+len(etag)) + data[0] = uint8(lenHash) + data = append(data, hash...) + data = append(data, etag...) + + err := c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := t.CreateBucketIfNotExists(bucketETag) + if err != nil { + return err + } + + return bucket.Put([]byte(url), data) + }) + if err != nil { + log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) + return + } +} +func (c *CacheFile) GetETagWithHash(key string) (hash []byte, etag string) { + if c.DB == nil { + return + } + var value []byte + c.DB.View(func(t *bbolt.Tx) error { + if bucket := t.Bucket(bucketETag); bucket != nil { + if v := bucket.Get([]byte(key)); v != nil { + value = v + } + } + return nil + }) + if len(value) == 0 { + return + } + lenHash := int(value[0]) + if len(value) < 1+lenHash { + return + } + hash = value[1 : 1+lenHash] + etag = string(value[1+lenHash:]) + return +} + func (c *CacheFile) Close() error { return c.DB.Close() } diff --git a/component/resource/fetcher.go b/component/resource/fetcher.go index fec9fe771..0b15e6c32 100644 --- a/component/resource/fetcher.go +++ b/component/resource/fetcher.go @@ -1,9 +1,7 @@ package resource import ( - "bytes" "context" - "crypto/md5" "os" "path/filepath" "time" @@ -29,10 +27,10 @@ type Fetcher[V any] struct { name string vehicle types.Vehicle updatedAt time.Time - hash [16]byte + hash types.HashType parser Parser[V] interval time.Duration - OnUpdate func(V) + onUpdate func(V) watcher *fswatch.Watcher } @@ -54,92 +52,63 @@ func (f *Fetcher[V]) UpdatedAt() time.Time { func (f *Fetcher[V]) Initial() (V, error) { var ( - buf []byte - err error - isLocal bool - forceUpdate bool + buf []byte + contents V + err error ) if stat, fErr := os.Stat(f.vehicle.Path()); fErr == nil { + // local file exists, use it first buf, err = os.ReadFile(f.vehicle.Path()) modTime := stat.ModTime() - f.updatedAt = modTime - isLocal = true - if time.Since(modTime) > f.interval { - forceUpdate = true + contents, _, err = f.loadBuf(buf, types.MakeHash(buf), false) + f.updatedAt = modTime // reset updatedAt to file's modTime + + if err == nil { + err = f.startPullLoop(time.Since(modTime) > f.interval) + if err != nil { + return lo.Empty[V](), err + } + return contents, nil } - } else { - buf, err = f.vehicle.Read(f.ctx) - f.updatedAt = time.Now() } + // parse local file error, fallback to remote + contents, _, err = f.Update() + if err != nil { return lo.Empty[V](), err } - - contents, err := f.parser(buf) + err = f.startPullLoop(false) if err != nil { - if !isLocal { - return lo.Empty[V](), err - } - - // parse local file error, fallback to remote - buf, err = f.vehicle.Read(f.ctx) - if err != nil { - return lo.Empty[V](), err - } - - contents, err = f.parser(buf) - if err != nil { - return lo.Empty[V](), err - } - - isLocal = false - } - - if f.vehicle.Type() != types.File && !isLocal { - if err := safeWrite(f.vehicle.Path(), buf); err != nil { - return lo.Empty[V](), err - } - } - - f.hash = md5.Sum(buf) - - // pull contents automatically - if f.vehicle.Type() == types.File { - f.watcher, err = fswatch.NewWatcher(fswatch.Options{ - Path: []string{f.vehicle.Path()}, - Direct: true, - Callback: f.update, - }) - if err != nil { - return lo.Empty[V](), err - } - err = f.watcher.Start() - if err != nil { - return lo.Empty[V](), err - } - } else if f.interval > 0 { - go f.pullLoop(forceUpdate) + return lo.Empty[V](), err } - return contents, nil } func (f *Fetcher[V]) Update() (V, bool, error) { - buf, err := f.vehicle.Read(f.ctx) + buf, hash, err := f.vehicle.Read(f.ctx, f.hash) if err != nil { return lo.Empty[V](), false, err } - return f.SideUpdate(buf) + return f.loadBuf(buf, hash, f.vehicle.Type() != types.File) } func (f *Fetcher[V]) SideUpdate(buf []byte) (V, bool, error) { + return f.loadBuf(buf, types.MakeHash(buf), true) +} + +func (f *Fetcher[V]) loadBuf(buf []byte, hash types.HashType, updateFile bool) (V, bool, error) { now := time.Now() - hash := md5.Sum(buf) - if bytes.Equal(f.hash[:], hash[:]) { + if f.hash.Equal(hash) { + if updateFile { + _ = os.Chtimes(f.vehicle.Path(), now, now) + } f.updatedAt = now - _ = os.Chtimes(f.vehicle.Path(), now, now) + return lo.Empty[V](), true, nil + } + + if buf == nil { // f.hash has been changed between f.vehicle.Read but should not happen (cause by concurrent) return lo.Empty[V](), true, nil } @@ -148,15 +117,18 @@ func (f *Fetcher[V]) SideUpdate(buf []byte) (V, bool, error) { return lo.Empty[V](), false, err } - if f.vehicle.Type() != types.File { + if updateFile { if err = safeWrite(f.vehicle.Path(), buf); err != nil { return lo.Empty[V](), false, err } } - f.updatedAt = now f.hash = hash + if f.onUpdate != nil { + f.onUpdate(contents) + } + return contents, false, nil } @@ -176,7 +148,7 @@ func (f *Fetcher[V]) pullLoop(forceUpdate bool) { if forceUpdate { log.Warnln("[Provider] %s not updated for a long time, force refresh", f.Name()) - f.update(f.vehicle.Path()) + f.updateWithLog() } timer := time.NewTimer(initialInterval) @@ -185,15 +157,40 @@ func (f *Fetcher[V]) pullLoop(forceUpdate bool) { select { case <-timer.C: timer.Reset(f.interval) - f.update(f.vehicle.Path()) + f.updateWithLog() case <-f.ctx.Done(): return } } } -func (f *Fetcher[V]) update(path string) { - elm, same, err := f.Update() +func (f *Fetcher[V]) startPullLoop(forceUpdate bool) (err error) { + // pull contents automatically + if f.vehicle.Type() == types.File { + f.watcher, err = fswatch.NewWatcher(fswatch.Options{ + Path: []string{f.vehicle.Path()}, + Direct: true, + Callback: f.updateCallback, + }) + if err != nil { + return err + } + err = f.watcher.Start() + if err != nil { + return err + } + } else if f.interval > 0 { + go f.pullLoop(forceUpdate) + } + return +} + +func (f *Fetcher[V]) updateCallback(path string) { + f.updateWithLog() +} + +func (f *Fetcher[V]) updateWithLog() { + _, same, err := f.Update() if err != nil { log.Errorln("[Provider] %s pull error: %s", f.Name(), err.Error()) return @@ -205,9 +202,7 @@ func (f *Fetcher[V]) update(path string) { } log.Infoln("[Provider] %s's content update", f.Name()) - if f.OnUpdate != nil { - f.OnUpdate(elm) - } + return } func safeWrite(path string, buf []byte) error { @@ -230,7 +225,7 @@ func NewFetcher[V any](name string, interval time.Duration, vehicle types.Vehicl name: name, vehicle: vehicle, parser: parser, - OnUpdate: onUpdate, + onUpdate: onUpdate, interval: interval, } } diff --git a/component/resource/vehicle.go b/component/resource/vehicle.go index 74324e6d3..ccc59aece 100644 --- a/component/resource/vehicle.go +++ b/component/resource/vehicle.go @@ -9,6 +9,7 @@ import ( "time" mihomoHttp "github.com/metacubex/mihomo/component/http" + "github.com/metacubex/mihomo/component/profile/cachefile" types "github.com/metacubex/mihomo/constant/provider" ) @@ -28,8 +29,13 @@ func (f *FileVehicle) Url() string { return "file://" + f.path } -func (f *FileVehicle) Read(ctx context.Context) ([]byte, error) { - return os.ReadFile(f.path) +func (f *FileVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) { + buf, err = os.ReadFile(f.path) + if err != nil { + return + } + hash = types.MakeHash(buf) + return } func (f *FileVehicle) Proxy() string { @@ -63,24 +69,49 @@ func (h *HTTPVehicle) Proxy() string { return h.proxy } -func (h *HTTPVehicle) Read(ctx context.Context) ([]byte, error) { +func (h *HTTPVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() - resp, err := mihomoHttp.HttpRequestWithProxy(ctx, h.url, http.MethodGet, h.header, nil, h.proxy) + header := h.header + setIfNoneMatch := false + if oldHash.IsValid() { + hashBytes, etag := cachefile.Cache().GetETagWithHash(h.url) + if oldHash.EqualBytes(hashBytes) && etag != "" { + if header == nil { + header = http.Header{} + } else { + header = header.Clone() + } + header.Set("If-None-Match", etag) + setIfNoneMatch = true + } + } + resp, err := mihomoHttp.HttpRequestWithProxy(ctx, h.url, http.MethodGet, header, nil, h.proxy) if err != nil { - return nil, err + return } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode > 299 { - return nil, errors.New(resp.Status) + if setIfNoneMatch && resp.StatusCode == http.StatusNotModified { + return nil, oldHash, nil + } + err = errors.New(resp.Status) + return } - buf, err := io.ReadAll(resp.Body) + buf, err = io.ReadAll(resp.Body) if err != nil { - return nil, err + return } - return buf, nil + hash = types.MakeHash(buf) + cachefile.Cache().SetETagWithHash(h.url, hash.Bytes(), resp.Header.Get("ETag")) + return } func NewHTTPVehicle(url string, path string, proxy string, header http.Header) *HTTPVehicle { - return &HTTPVehicle{url, path, proxy, header} + return &HTTPVehicle{ + url: url, + path: path, + proxy: proxy, + header: header, + } } diff --git a/constant/provider/hash.go b/constant/provider/hash.go new file mode 100644 index 000000000..b95ffe234 --- /dev/null +++ b/constant/provider/hash.go @@ -0,0 +1,29 @@ +package provider + +import ( + "bytes" + "crypto/md5" +) + +type HashType [md5.Size]byte // MD5 + +func MakeHash(data []byte) HashType { + return md5.Sum(data) +} + +func (h HashType) Equal(hash HashType) bool { + return h == hash +} + +func (h HashType) EqualBytes(hashBytes []byte) bool { + return bytes.Equal(hashBytes, h[:]) +} + +func (h HashType) Bytes() []byte { + return h[:] +} + +func (h HashType) IsValid() bool { + var zero HashType + return h != zero +} diff --git a/constant/provider/interface.go b/constant/provider/interface.go index 9d24f6917..2c83a1b88 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -32,7 +32,7 @@ func (v VehicleType) String() string { } type Vehicle interface { - Read(ctx context.Context) ([]byte, error) + Read(ctx context.Context, oldHash HashType) (buf []byte, hash HashType, err error) Path() string Url() string Proxy() string diff --git a/rules/provider/provider.go b/rules/provider/provider.go index ad720d477..0cbf83bac 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -66,22 +66,12 @@ func (rp *ruleSetProvider) Type() P.ProviderType { } func (rp *ruleSetProvider) Initial() error { - elm, err := rp.Fetcher.Initial() - if err != nil { - return err - } - - rp.OnUpdate(elm) - return nil + _, err := rp.Fetcher.Initial() + return err } func (rp *ruleSetProvider) Update() error { - elm, same, err := rp.Fetcher.Update() - if err == nil && !same { - rp.OnUpdate(elm) - return nil - } - + _, _, err := rp.Fetcher.Update() return err }