From 00c194d1a47ae89ea0d343317d0079c4a8305d5e Mon Sep 17 00:00:00 2001 From: Peng Cheng Date: Sun, 27 Mar 2022 17:59:00 -0400 Subject: [PATCH] VectorOps.dot and MatrixOps.matMul are slightly faster --- .../scala/shapesafe/core/shape/Names.scala | 4 ++++ .../shapesafe/core/shape/ops/MatrixOps.scala | 20 ++++++++++++++----- .../shapesafe/core/shape/ops/VectorOps.scala | 11 +++++++--- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/shapesafe/core/shape/Names.scala b/core/src/main/scala/shapesafe/core/shape/Names.scala index 8686b938..090eb8f5 100644 --- a/core/src/main/scala/shapesafe/core/shape/Names.scala +++ b/core/src/main/scala/shapesafe/core/shape/Names.scala @@ -74,7 +74,11 @@ object Names extends Tuples with ApplyLiterals.ToNames { object Syntax extends Syntax val i = Names("i") + val ii = Names("i", "i") + val ij = Names("i", "j") val jk = Names("j", "k") + val ijjk = Names("i", "j", "j", "k") + val ik = Names("i", "k") } diff --git a/core/src/main/scala/shapesafe/core/shape/ops/MatrixOps.scala b/core/src/main/scala/shapesafe/core/shape/ops/MatrixOps.scala index ec07629f..5722ac13 100644 --- a/core/src/main/scala/shapesafe/core/shape/ops/MatrixOps.scala +++ b/core/src/main/scala/shapesafe/core/shape/ops/MatrixOps.scala @@ -1,15 +1,21 @@ package shapesafe.core.shape.ops import shapeless.Nat -import shapesafe.core.shape.{Index, Indices, Names, Shape} +import shapesafe.core.shape.Index.LtoR +import shapesafe.core.shape.{Indices, Names, Shape} trait MatrixOps extends HasShape { def matMul[THAT <: Shape](that: THAT) = { - val s1 = shape :<<= Names.ij - val s2 = that :<<= Names.jk +// val s1 = shape :<<= Names.ij +// val s2 = that :<<= Names.jk +// +// s1.einSum(s2) --> Names.ik - s1.einSum(s2) --> Names.ik + // TODO: above is slight slower, need to optimise + + val outer = (shape >< that) :<<= Names.ijjk + outer.einSum --> Names.ik } def mat_*[THAT <: Shape](that: THAT) = { @@ -17,7 +23,11 @@ trait MatrixOps extends HasShape { } def transpose = { - shape.requireNumDim(Nat._2).rearrangeBy(Indices & Index.LtoR(1) & Index.LtoR(0)) + shape + .requireNumDim(Nat._2) + .rearrangeBy( + Indices & LtoR(1) & LtoR(0) + ) } def `^T` = { diff --git a/core/src/main/scala/shapesafe/core/shape/ops/VectorOps.scala b/core/src/main/scala/shapesafe/core/shape/ops/VectorOps.scala index 25ecb567..12166e81 100644 --- a/core/src/main/scala/shapesafe/core/shape/ops/VectorOps.scala +++ b/core/src/main/scala/shapesafe/core/shape/ops/VectorOps.scala @@ -5,10 +5,15 @@ import shapesafe.core.shape.{Names, Shape} trait VectorOps extends HasShape { def dot[THAT <: Shape](that: THAT) = { - val s1 = shape :<<= Names.i - val s2 = that :<<= Names.i +// val s1 = shape :<<= Names.i +// val s2 = that :<<= Names.i +// +// (s1.einSum(s2) --> Names.Eye) >< Shape(1) - (s1.einSum(s2) --> Names.Eye) >< Shape(1) + // TODO: above is slight slower, need to optimise + + val outer = (shape >< that) :<<= Names.ii + (outer.einSum --> Names.Eye) >< Shape._1 } def cross[THAT <: Shape](that: THAT) = {