Skip to content

Commit

Permalink
Refactor with a Context struct
Browse files Browse the repository at this point in the history
Unclear why the rpo tests are now failing.
  • Loading branch information
maddyblue committed Dec 20, 2016
1 parent 6dc58fe commit 32925ed
Show file tree
Hide file tree
Showing 10 changed files with 618 additions and 548 deletions.
477 changes: 477 additions & 0 deletions context.go

Large diffs are not rendered by default.

462 changes: 9 additions & 453 deletions decimal.go

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ import (
"testing"
)

var (
testCtx = &Context{}
)

func (d *Decimal) GoString() string {
return fmt.Sprintf(`{Coeff: %s, Exponent: %d, MaxExponent: %d, MinExponent: %d, Precision: %d}`, d.Coeff.String(), d.Exponent, d.MaxExponent, d.MinExponent, d.Precision)
return fmt.Sprintf(`{Coeff: %s, Exponent: %d}`, d.Coeff.String(), d.Exponent)
}

func TestNewFromString(t *testing.T) {
Expand Down Expand Up @@ -135,7 +139,7 @@ func TestAdd(t *testing.T) {
x := newDecimal(t, tc.x)
y := newDecimal(t, tc.y)
d := new(Decimal)
err := d.Add(x, y)
err := testCtx.Add(d, x, y)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -224,7 +228,7 @@ func TestModf(t *testing.T) {
t.Fatalf("frac: expected: %s, got: %s", tc.f, frac)
}
a := new(Decimal)
if err := a.Add(integ, frac); err != nil {
if err := testCtx.Add(a, integ, frac); err != nil {
t.Fatal(err)
}
if c, err := a.Cmp(x); err != nil {
Expand Down
64 changes: 36 additions & 28 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,38 @@

package apd

// NewErrDecimal creates a ErrDecimal with given context.
func NewErrDecimal(c *Context) *ErrDecimal {
return &ErrDecimal{
Ctx: c,
}
}

// ErrDecimal performs operations on decimals and collects errors during
// operations. If an error is already set, the operation is skipped. Designed to
// be used for many operations in a row, with a single error check at the end.
type ErrDecimal struct {
Err error
Ctx *Context
}

// Abs performs d.Abs(x).
// Abs performs e.Ctx.Abs(d, x).
func (e *ErrDecimal) Abs(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Abs(x)
e.Err = e.Ctx.Abs(d, x)
}

// Add performs d.Add(x, y).
// Add performs e.Ctx.Add(d, x, y).
func (e *ErrDecimal) Add(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Add(x, y)
e.Err = e.Ctx.Add(d, x, y)
}

// Cmp returns 0 if Err is set. Otherwise returns a.Cmp(b).
// Cmp returns 0 if Err is set. Otherwise returns e.Ctx.Cmp(a, b).
func (e *ErrDecimal) Cmp(a, b *Decimal) int {
if e.Err != nil {
return 0
Expand All @@ -47,12 +55,12 @@ func (e *ErrDecimal) Cmp(a, b *Decimal) int {
return c
}

// Exp performs d.Exp(x).
// Exp performs e.Ctx.Exp(d, x).
func (e *ErrDecimal) Exp(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Exp(x)
e.Err = e.Ctx.Exp(d, x)
}

// Int64 returns 0 if Err is set. Otherwise returns d.Int64().
Expand All @@ -65,90 +73,90 @@ func (e *ErrDecimal) Int64(d *Decimal) int64 {
return r
}

// Ln performs d.Ln(x).
// Ln performs e.Ctx.Ln(d, x).
func (e *ErrDecimal) Ln(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Ln(x)
e.Err = e.Ctx.Ln(d, x)
}

// Log10 performs d.Log10(x).
func (e *ErrDecimal) Log10(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Log10(x)
e.Err = e.Ctx.Log10(d, x)
}

// Mul performs d.Mul(x, y).
// Mul performs e.Ctx.Mul(d, x, y).
func (e *ErrDecimal) Mul(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Mul(x, y)
e.Err = e.Ctx.Mul(d, x, y)
}

// Neg performs d.Neg(x).
// Neg performs e.Ctx.Neg(d, x).
func (e *ErrDecimal) Neg(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Neg(x)
e.Err = e.Ctx.Neg(d, x)
}

// Pow performs d.Pow(x, y).
// Pow performs e.Ctx.Pow(d, x, y).
func (e *ErrDecimal) Pow(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Pow(x, y)
e.Err = e.Ctx.Pow(d, x, y)
}

// Quo performs d.Quo(x, y).
// Quo performs e.Ctx.Quo(d, x, y).
func (e *ErrDecimal) Quo(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Quo(x, y)
e.Err = e.Ctx.Quo(d, x, y)
}

// QuoInteger performs d.QuoInteger(x, y).
// QuoInteger performs e.Ctx.QuoInteger(d, x, y).
func (e *ErrDecimal) QuoInteger(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.QuoInteger(x, y)
e.Err = e.Ctx.QuoInteger(d, x, y)
}

// Rem performs d.Rem(x, y).
// Rem performs e.Ctx.Rem(d, x, y).
func (e *ErrDecimal) Rem(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Rem(x, y)
e.Err = e.Ctx.Rem(d, x, y)
}

// Round performs d.Round(x).
// Round performs e.Ctx.Round(d, x).
func (e *ErrDecimal) Round(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Round(x)
e.Err = e.Ctx.Round(d, x)
}

// Sqrt performs d.Sqrt(x).
// Sqrt performs e.Ctx.Sqrt(d, x).
func (e *ErrDecimal) Sqrt(d, x *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Sqrt(x)
e.Err = e.Ctx.Sqrt(d, x)
}

// Sub performs d.Sub(x, y).
// Sub performs e.Ctx.Sub(d, x, y).
func (e *ErrDecimal) Sub(d, x, y *Decimal) {
if e.Err != nil {
return
}
e.Err = d.Sub(x, y)
e.Err = e.Ctx.Sub(d, x, y)
}
20 changes: 19 additions & 1 deletion error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import "testing"
// Appease the unused test.
// TODO(mjibson): actually test all the ErrDecimal methods.
func TestErrDecimal(t *testing.T) {
var ed ErrDecimal
ed := NewErrDecimal(&Context{})
a := New(1, 0)
ed.Abs(a, a)
ed.Exp(a, a)
Expand All @@ -30,3 +30,21 @@ func TestErrDecimal(t *testing.T) {
ed.QuoInteger(a, a, a)
ed.Rem(a, a, a)
}

func TestNewErrDecimal(t *testing.T) {
c := &Context{
Precision: 5,
MaxExponent: 2,
}
nc := c.WithPrecision(c.Precision * 2)
ed := NewErrDecimal(&nc)
if ed.Ctx.Precision != 10 {
t.Fatalf("expected %d, got %d", 10, ed.Ctx.Precision)
}
if c.Precision != 5 {
t.Fatalf("expected %d, got %d", 5, c.Precision)
}
if c.MaxExponent != 2 {
t.Fatalf("expected %d, got %d", 2, c.MaxExponent)
}
}
46 changes: 26 additions & 20 deletions gda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,12 @@ func gdaTest(t *testing.T, name string) (int, int, int, int, int) {
operands[i] = newDecimal(t, o)
}
d := new(Decimal)
d.Precision = uint32(tc.Precision)
d.MaxExponent = int32(tc.MaxExponent)
d.MinExponent = int32(tc.MinExponent)
d.Rounding = mode
c := Context{
Precision: uint32(tc.Precision),
MaxExponent: int32(tc.MaxExponent),
MinExponent: int32(tc.MinExponent),
Rounding: mode,
}
// helpful acme address link
t.Logf("%s %s = %s (prec: %d, round: %s)", tc.Operation, strings.Join(tc.Operands, " "), tc.Result, tc.Precision, tc.Rounding)
start := time.Now()
Expand All @@ -299,37 +301,37 @@ func gdaTest(t *testing.T, name string) (int, int, int, int, int) {
go func() {
switch tc.Operation {
case "abs":
err = d.Abs(operands[0])
err = c.Abs(d, operands[0])
case "add":
err = d.Add(operands[0], operands[1])
err = c.Add(d, operands[0], operands[1])
case "compare":
var c int
c, err = operands[0].Cmp(operands[1])
d.SetInt64(int64(c))
case "divide":
err = d.Quo(operands[0], operands[1])
err = c.Quo(d, operands[0], operands[1])
case "divideint":
err = d.QuoInteger(operands[0], operands[1])
err = c.QuoInteger(d, operands[0], operands[1])
case "exp":
err = d.Exp(operands[0])
err = c.Exp(d, operands[0])
case "ln":
err = d.Ln(operands[0])
err = c.Ln(d, operands[0])
case "log10":
err = d.Log10(operands[0])
err = c.Log10(d, operands[0])
case "minus":
err = d.Neg(operands[0])
err = c.Neg(d, operands[0])
case "multiply":
err = d.Mul(operands[0], operands[1])
err = c.Mul(d, operands[0], operands[1])
case "plus":
err = d.Add(operands[0], decimalZero)
err = c.Add(d, operands[0], decimalZero)
case "power":
err = d.Pow(operands[0], operands[1])
err = c.Pow(d, operands[0], operands[1])
case "remainder":
err = d.Rem(operands[0], operands[1])
err = c.Rem(d, operands[0], operands[1])
case "squareroot":
err = d.Sqrt(operands[0])
err = c.Sqrt(d, operands[0])
case "subtract":
err = d.Sub(operands[0], operands[1])
err = c.Sub(d, operands[0], operands[1])
default:
done <- fmt.Errorf("unknown operation: %s", tc.Operation)
}
Expand Down Expand Up @@ -368,11 +370,11 @@ func gdaTest(t *testing.T, name string) (int, int, int, int, int) {
t.Fatalf("%+v", err)
}
r := newDecimal(t, tc.Result)
c, err := d.Cmp(r)
p, err := d.Cmp(r)
if err != nil {
t.Fatal(err)
}
if c != 0 {
if p != 0 {
if *flagPython {
if tc.CheckPython(t, d) {
return
Expand Down Expand Up @@ -671,6 +673,10 @@ var GDAignore = map[string]bool{
"pow2056": true,

// incorrect rounding
"rpo107": true,
"rpo213": true,
"rpo412": true,
"rpo507": true,
"rpo607": true,
"rpo707": true,
}
17 changes: 8 additions & 9 deletions loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
// on an *Decimal has converged. It was adapted from robpike.io/ivy/value's loop
// implementation.
type loop struct {
c *Context
name string // The name of the function we are evaluating.
i uint64 // Loop count.
maxIterations uint64 // When to give up.
Expand All @@ -40,39 +41,37 @@ const digitsToBitsRatio = math.Ln10 / math.Ln2
// of the function being evaluated, the argument to the function,
// and the desired scale of the result, and the iterations
// per bit.
func newLoop(name string, x *Decimal, itersPerBit int) *loop {
func (c *Context) newLoop(name string, x *Decimal, itersPerBit int) *loop {
bits := x.Coeff.BitLen()
incrPrec := float64(x.Precision) + float64(x.Exponent)
incrPrec := float64(c.Precision) + float64(x.Exponent)
if incrPrec > 0 {
bits += int(incrPrec * digitsToBitsRatio)
}
if scaleBits := int(float64(x.Precision) * digitsToBitsRatio); scaleBits > bits {
if scaleBits := int(float64(c.Precision) * digitsToBitsRatio); scaleBits > bits {
bits = scaleBits
}
l := &loop{
c: c,
name: name,
maxIterations: 10 + uint64(itersPerBit*bits),
}
l.arg = new(Decimal)
l.arg.Set(x)
l.stallThresh = New(1, -int32(x.Precision+1))
l.stallThresh = New(1, -int32(c.Precision+1))
l.prevZ = new(Decimal)
l.delta = new(Decimal)
p := x.Precision + 2
l.prevZ.Precision = p
l.delta.Precision = p
return l
}

// done reports whether the loop is done. If it does not converge
// after the maximum number of iterations, it returns an error.
func (l *loop) done(z *Decimal) (bool, error) {
l.delta.Sub(l.prevZ, z)
l.c.Sub(l.delta, l.prevZ, z)
switch l.delta.Sign() {
case 0:
return true, nil
case -1:
l.delta.Neg(l.delta)
l.c.Neg(l.delta, l.delta)
}
if c, err := l.delta.Cmp(l.stallThresh); err != nil {
return false, err
Expand Down
Loading

0 comments on commit 32925ed

Please sign in to comment.