summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc-Andre Lafortune <github@marc-andre.ca>2020-12-04 01:57:40 -0500
committerMarc-André Lafortune <github@marc-andre.ca>2020-12-05 00:56:58 -0500
commita83a51932dbc31b549e11b9da8967f2f52a8b07c (patch)
tree272eea99e5f40150af2a52114d7f8c9d8f325264
parent3b5b309b7b3724849c27dc1c836b5348a8a82e23 (diff)
[ruby/matrix] Optimize **
Avoiding recursive call would imply iterating bits starting from most significant, which is not easy to do efficiently. Any saving would be dwarfed by the multiplications anyways. [Feature #15233]
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/3844
-rw-r--r--lib/matrix.rb53
-rw-r--r--test/matrix/test_matrix.rb6
2 files changed, 44 insertions, 15 deletions
diff --git a/lib/matrix.rb b/lib/matrix.rb
index 336a92877b..c6193ebee1 100644
--- a/lib/matrix.rb
+++ b/lib/matrix.rb
@@ -1233,26 +1233,49 @@ class Matrix
# # => 67 96
# # 48 99
#
- def **(other)
- case other
+ def **(exp)
+ case exp
when Integer
- x = self
- if other <= 0
- x = self.inverse
- return self.class.identity(self.column_count) if other == 0
- other = -other
- end
- z = nil
- loop do
- z = z ? z * x : x if other[0] == 1
- return z if (other >>= 1).zero?
- x *= x
+ case
+ when exp == 0
+ _make_sure_it_is_invertible = inverse
+ self.class.identity(column_count)
+ when exp < 0
+ inverse.power_int(-exp)
+ else
+ power_int(exp)
end
when Numeric
v, d, v_inv = eigensystem
- v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** other}) * v_inv
+ v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** exp}) * v_inv
+ else
+ raise ErrOperationNotDefined, ["**", self.class, exp.class]
+ end
+ end
+
+ protected def power_int(exp)
+ # assumes `exp` is an Integer > 0
+ #
+ # Previous algorithm:
+ # build M**2, M**4 = (M**2)**2, M**8, ... and multiplying those you need
+ # e.g. M**0b1011 = M**11 = M * M**2 * M**8
+ # ^ ^
+ # (highlighted the 2 out of 5 multiplications involving `M * x`)
+ #
+ # Current algorithm has same number of multiplications but with lower exponents:
+ # M**11 = M * (M * M**4)**2
+ # ^ ^ ^
+ # (highlighted the 3 out of 5 multiplications involving `M * x`)
+ #
+ # This should be faster for all (non nil-potent) matrices.
+ case
+ when exp == 1
+ self
+ when exp.odd?
+ self * power_int(exp - 1)
else
- raise ErrOperationNotDefined, ["**", self.class, other.class]
+ sqrt = power_int(exp / 2)
+ sqrt * sqrt
end
end
diff --git a/test/matrix/test_matrix.rb b/test/matrix/test_matrix.rb
index b134bfb3a1..8125fb2bcb 100644
--- a/test/matrix/test_matrix.rb
+++ b/test/matrix/test_matrix.rb
@@ -448,6 +448,12 @@ class TestMatrix < Test::Unit::TestCase
assert_equal(Matrix[[67,96],[48,99]], Matrix[[7,6],[3,9]] ** 2)
assert_equal(Matrix.I(5), Matrix.I(5) ** -1)
assert_raise(Matrix::ErrOperationNotDefined) { Matrix.I(5) ** Object.new }
+
+ m = Matrix[[0,2],[1,0]]
+ exp = 0b11101000
+ assert_equal(Matrix.scalar(2, 1 << (exp/2)), m ** exp)
+ exp = 0b11101001
+ assert_equal(Matrix[[0, 2 << (exp/2)], [1 << (exp/2), 0]], m ** exp)
end
def test_det