pkg/eval: Deduplicate range implementations with generics.

This commit is contained in:
Qi Xiao 2022-08-08 12:28:47 +01:00
parent 551e246d96
commit 1b8ccdbdbc

View File

@ -791,21 +791,23 @@ func rangeFn(fm *Frame, opts rangeOpts, args ...vals.Num) error {
switch nums := nums.(type) {
case []int:
return rangeInt(nums, out)
return rangeBuiltinNum(nums, out)
case []*big.Int:
return rangeBigInt(nums, out)
return rangeBigNum(nums, out, bigIntDesc)
case []*big.Rat:
return rangeBitRat(nums, out)
return rangeBigNum(nums, out, bigRatDesc)
case []float64:
return rangeFloat64(nums, out)
return rangeBuiltinNum(nums, out)
default:
panic("unreachable")
}
}
func rangeInt(nums []int, out ValueOutput) error {
type builtinNum interface{ int | float64 }
func rangeBuiltinNum[T builtinNum](nums []T, out ValueOutput) error {
start, end := nums[0], nums[1]
var step int
var step T
if start <= end {
if len(nums) == 3 {
step = nums[2]
@ -848,61 +850,33 @@ func rangeInt(nums []int, out ValueOutput) error {
return nil
}
// TODO: Use type parameters to deduplicate this with rangeInt when Elvish
// requires Go 1.18.
func rangeFloat64(nums []float64, out ValueOutput) error {
start, end := nums[0], nums[1]
var step float64
if start <= end {
if len(nums) == 3 {
step = nums[2]
if step <= 0 {
return errs.BadValue{
What: "step", Valid: "positive", Actual: vals.ToString(step)}
}
} else {
step = 1
}
for cur := start; cur < end; cur += step {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
if cur+step <= cur {
break
}
}
} else {
if len(nums) == 3 {
step = nums[2]
if step >= 0 {
return errs.BadValue{
What: "step", Valid: "negative", Actual: vals.ToString(step)}
}
} else {
step = -1
}
for cur := start; cur > end; cur += step {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
if cur+step >= cur {
break
}
}
}
return nil
type bigNum[T any] interface {
Cmp(T) int
Sign() int
Add(T, T) T
}
var (
bigInt1 = big.NewInt(1)
bigIntNeg1 = big.NewInt(-1)
)
type bigNumDesc[T any] struct {
one T
negOne T
newZero func() T
}
func rangeBigInt(nums []*big.Int, out ValueOutput) error {
var bigIntDesc = bigNumDesc[*big.Int]{
one: big.NewInt(1),
negOne: big.NewInt(-1),
newZero: func() *big.Int { return &big.Int{} },
}
var bigRatDesc = bigNumDesc[*big.Rat]{
one: big.NewRat(1, 1),
negOne: big.NewRat(-1, 1),
newZero: func() *big.Rat { return &big.Rat{} },
}
func rangeBigNum[T bigNum[T]](nums []T, out ValueOutput, d bigNumDesc[T]) error {
start, end := nums[0], nums[1]
var step *big.Int
var step T
if start.Cmp(end) <= 0 {
if len(nums) == 3 {
step = nums[2]
@ -911,15 +885,15 @@ func rangeBigInt(nums []*big.Int, out ValueOutput) error {
What: "step", Valid: "positive", Actual: vals.ToString(step)}
}
} else {
step = bigInt1
step = d.one
}
var cur, next *big.Int
var cur, next T
for cur = start; cur.Cmp(end) < 0; cur = next {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
next = &big.Int{}
next = d.newZero()
next.Add(cur, step)
cur = next
}
@ -931,69 +905,15 @@ func rangeBigInt(nums []*big.Int, out ValueOutput) error {
What: "step", Valid: "negative", Actual: vals.ToString(step)}
}
} else {
step = bigIntNeg1
step = d.negOne
}
var cur, next *big.Int
var cur, next T
for cur = start; cur.Cmp(end) > 0; cur = next {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
next = &big.Int{}
next.Add(cur, step)
cur = next
}
}
return nil
}
var (
bigRat1 = big.NewRat(1, 1)
bigRatNeg1 = big.NewRat(-1, 1)
)
// TODO: Use type parameters to deduplicate this with rangeBitInt when Elvish
// requires Go 1.18.
func rangeBitRat(nums []*big.Rat, out ValueOutput) error {
start, end := nums[0], nums[1]
var step *big.Rat
if start.Cmp(end) <= 0 {
if len(nums) == 3 {
step = nums[2]
if step.Sign() <= 0 {
return errs.BadValue{
What: "step", Valid: "positive", Actual: vals.ToString(step)}
}
} else {
step = bigRat1
}
var cur, next *big.Rat
for cur = start; cur.Cmp(end) < 0; cur = next {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
next = &big.Rat{}
next.Add(cur, step)
cur = next
}
} else {
if len(nums) == 3 {
step = nums[2]
if step.Sign() >= 0 {
return errs.BadValue{
What: "step", Valid: "negative", Actual: vals.ToString(step)}
}
} else {
step = bigRatNeg1
}
var cur, next *big.Rat
for cur = start; cur.Cmp(end) > 0; cur = next {
err := out.Put(vals.FromGo(cur))
if err != nil {
return err
}
next = &big.Rat{}
next = d.newZero()
next.Add(cur, step)
cur = next
}